Introduce in-memory resource abstraction (#375)

* Introduce in-memory resource abstraction

This follows from discussion in #366.

The goal of this change is to allow for weights to be loaded from a copy
of `rust_model.ot` that is already present in memory. There are two ways
in which that data might be present:

1. As a `HashMap<String, Tensor>` from previous interaction with `tch`
2. As a contiguous buffer of the file data

One or the other mechanism might be preferable depending on how user
code is using the model data. In some sense, implementing a provider
based on the second option is more of a convenience method for the user
to avoid the `tch::nn::VarStore::load_from_stream` interaction.
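To make the two starting points concrete, here is a minimal, self-contained sketch. All names are stand-ins: `Tensor` and `load_from_stream` are simplified placeholders for the `tch` equivalents, not the crate's actual API.

```rust
use std::collections::HashMap;
use std::io::{Cursor, Read};

// Stand-in for `tch::Tensor` in case 1.
type Tensor = Vec<f32>;

// Placeholder in the spirit of `tch::nn::VarStore::load_from_stream`:
// it only needs something that implements `Read`.
fn load_from_stream(mut stream: impl Read) -> usize {
    let mut buf = Vec::new();
    stream.read_to_end(&mut buf).unwrap();
    buf.len()
}

fn main() {
    // Case 1: weights already materialised as named tensors.
    let mut tensors: HashMap<String, Tensor> = HashMap::new();
    tensors.insert("embeddings.weight".to_string(), vec![0.0; 4]);
    assert_eq!(tensors.len(), 1);

    // Case 2: the raw bytes of `rust_model.ot`, already in memory.
    let file_data: Vec<u8> = vec![0u8; 16];

    // User code could drive the stream loader itself; a buffer-backed
    // provider simply hides this plumbing behind the resource API.
    assert_eq!(load_from_stream(Cursor::new(&file_data)), 16);
}
```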

I've changed the definition of the `ResourceProvider` trait to require
that implementors be both `Send` and `Sync`. Certain contexts already
require `dyn ResourceProvider + Send`, but before this change nothing
guaranteed that an implementation was `Send` (or `Sync`). The existing
providers are both, and it seems reasonable (if technically incorrect)
for user code to assume this. I don't see a downside to making the
requirement explicit, but that part of this change might be better
suited for separate discussion; I am not trying to sneak it in.
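The effect of the stricter bound can be seen in a small self-contained sketch (trait and type names are simplified stand-ins for the real ones): with `Send + Sync` as supertraits, a shared trait object can cross thread boundaries.

```rust
use std::sync::Arc;
use std::thread;

// With these supertraits, `Arc<dyn ResourceProvider>` is itself `Send`,
// so it can be moved into a spawned thread.
trait ResourceProvider: Send + Sync {
    fn name(&self) -> String;
}

struct LocalResource;

impl ResourceProvider for LocalResource {
    fn name(&self) -> String {
        "local".to_string()
    }
}

fn main() {
    let provider: Arc<dyn ResourceProvider> = Arc::new(LocalResource);
    // Without `Send + Sync` on the trait, this line would not compile.
    let handle = thread::spawn(move || provider.name());
    assert_eq!(handle.join().unwrap(), "local");
}
```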

The `enum Resource` data type is used here as a means to abstract over
the possible ways a `ResourceProvider` might represent an underlying
resource. Without this, it would be necessary either to call different
trait methods until one succeeded or to implement `as_any` and downcast
in order to implement `load_weights` much as it is implemented now. Both
options seemed less appealing than creating a wrapper.
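The dispatch this enables can be sketched in a self-contained form. The types below are simplified stand-ins for the real ones (the real `Resource::Buffer` holds a lock guard and the real loader feeds a `tch` `VarStore`), but the shape of the `match` is the point.

```rust
use std::path::PathBuf;

// The wrapper abstracts over how a provider exposes its resource.
enum Resource {
    PathBuf(PathBuf),
    Buffer(Vec<u8>),
}

trait ResourceProvider {
    fn get_resource(&self) -> Resource;
}

struct LocalResource {
    local_path: PathBuf,
}

struct BufferResource {
    data: Vec<u8>,
}

impl ResourceProvider for LocalResource {
    fn get_resource(&self) -> Resource {
        Resource::PathBuf(self.local_path.clone())
    }
}

impl ResourceProvider for BufferResource {
    fn get_resource(&self) -> Resource {
        Resource::Buffer(self.data.clone())
    }
}

// A single loader matches on the wrapper instead of trying different
// trait methods or downcasting via `as_any`.
fn describe(rp: &dyn ResourceProvider) -> String {
    match rp.get_resource() {
        Resource::PathBuf(p) => format!("load from path {}", p.display()),
        Resource::Buffer(b) => format!("load from {}-byte buffer", b.len()),
    }
}

fn main() {
    let local = LocalResource { local_path: PathBuf::from("rust_model.ot") };
    let buffered = BufferResource { data: vec![0u8; 8] };
    assert_eq!(describe(&local), "load from path rust_model.ot");
    assert_eq!(describe(&buffered), "load from 8-byte buffer");
}
```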

While it would be possible to replace all calls to `get_local_path` with
the `get_resource` API, removing the existing function would be a major
breaking change. As such, this change also introduces
`RustBertError::UnsupportedError` to allow for the different methods to
coexist. An alternative would be for the new `ResourceProvider`s to
write their resources to a temporary disk location and return an
appropriate path, but that is counter to the purpose of the new
`ResourceProvider`s and so I chose not to implement that.
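The coexistence pattern looks like this in a self-contained sketch (stand-in types; the real error enum has many more variants):

```rust
use std::path::PathBuf;

#[derive(Debug)]
enum RustBertError {
    UnsupportedError,
}

trait ResourceProvider {
    // Path-based API kept for backwards compatibility.
    fn get_local_path(&self) -> Result<PathBuf, RustBertError>;
}

struct BufferResource {
    data: Vec<u8>,
}

impl ResourceProvider for BufferResource {
    // An in-memory resource has no path; rather than writing the buffer
    // to a temporary file, it reports the call as unsupported.
    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
        Err(RustBertError::UnsupportedError)
    }
}

fn main() {
    let resource = BufferResource { data: vec![1, 2, 3] };
    assert_eq!(resource.data.len(), 3);
    assert!(matches!(
        resource.get_local_path(),
        Err(RustBertError::UnsupportedError)
    ));
}
```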

* Add `impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T>`,
  remove `Resource::NamedTensors`, and change `BufferResource` to contain
  a `&[u8]` rather than `Vec<u8>`

* Further rework proposal for resources

* Use mutable references and locks

* Make model resources mutable in tests/examples

* Remove unnecessary mutability and TensorResource references

* Add `BufferResource` example

---------

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
Matt Weber 2023-05-26 13:23:28 -04:00, committed by GitHub
parent f591dc30b9
commit ba57704c6f
28 changed files with 286 additions and 50 deletions

@@ -0,0 +1,97 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use std::sync::{Arc, RwLock};
use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{BufferResource, RemoteResource, ResourceProvider};
use tch::Device;
fn main() -> anyhow::Result<()> {
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \
habitable zone not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \
used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \
passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, \
weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \
contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \
and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \
but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \
\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \
said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
said Ryan Cloutier of the Harvard-Smithsonian Center for Astrophysics, who was not one of either study's authors. \
\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
a potentially habitable planet, but further observations will be required to say for sure. \" \
K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \
telescope, scheduled for launch in 2021, and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
let weights = Arc::new(RwLock::new(get_weights()?));
let summarization_model = SummarizationModel::new(config(Device::Cpu, weights.clone()))?;
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input);
for sentence in output {
println!("{sentence}");
}
let summarization_model =
SummarizationModel::new(config(Device::cuda_if_available(), weights))?;
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input);
for sentence in output {
println!("{sentence}");
}
Ok(())
}
fn get_weights() -> anyhow::Result<Vec<u8>, anyhow::Error> {
let model_resource = RemoteResource::from_pretrained(BartModelResources::DISTILBART_CNN_6_6);
Ok(std::fs::read(model_resource.get_local_path()?)?)
}
fn config(device: Device, model_data: Arc<RwLock<Vec<u8>>>) -> SummarizationConfig {
let config_resource = Box::new(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6,
));
let vocab_resource = Box::new(RemoteResource::from_pretrained(
BartVocabResources::DISTILBART_CNN_6_6,
));
let merges_resource = Box::new(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6,
));
let model_resource = Box::new(BufferResource { data: model_data });
SummarizationConfig {
model_resource,
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
num_beams: 1,
length_penalty: 1.0,
min_length: 56,
max_length: Some(142),
device,
..Default::default()
}
}

@@ -4,7 +4,7 @@ use rust_bert::deberta::{
DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::resources::{load_weights, RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Kind, Tensor};
@@ -27,7 +27,6 @@ fn main() -> anyhow::Result<()> {
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
@@ -39,7 +38,7 @@ fn main() -> anyhow::Result<()> {
)?;
let config = DebertaConfig::from_file(config_path);
let model = DebertaForSequenceClassification::new(vs.root(), &config)?;
vs.load(weights_path)?;
load_weights(&model_resource, &mut vs)?;
// Define input
let input = [("I love you.", "I like you.")];

@@ -993,14 +993,13 @@ impl BartGenerator {
tokenizer: TokenizerOption,
) -> Result<BartGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = BartConfig::from_file(config_path);
let model = BartForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = Some(match config.eos_token_id {

@@ -22,6 +22,9 @@ pub enum RustBertError {
#[error("Value error: {0}")]
ValueError(String),
#[error("Unsupported operation")]
UnsupportedError,
}
impl From<std::io::Error> for RustBertError {

@@ -0,0 +1,71 @@
use crate::common::error::RustBertError;
use crate::resources::{Resource, ResourceProvider};
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
/// # In-memory raw buffer resource
pub struct BufferResource {
/// The data representing the underlying resource
pub data: Arc<RwLock<Vec<u8>>>,
}
impl ResourceProvider for BufferResource {
/// Not implemented for this resource type
///
/// # Returns
///
/// * `RustBertError::UnsupportedError`
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
Err(RustBertError::UnsupportedError)
}
/// Gets a wrapper referring to the in-memory resource.
///
/// # Returns
///
/// * `Resource` referring to the resource data
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{BufferResource, ResourceProvider};
/// let data = std::fs::read("path/to/rust_model.ot").unwrap();
/// let weights_resource = BufferResource::from(data);
/// let weights = weights_resource.get_resource();
/// ```
fn get_resource(&self) -> Result<Resource, RustBertError> {
Ok(Resource::Buffer(self.data.write().unwrap()))
}
}
impl From<Vec<u8>> for BufferResource {
fn from(data: Vec<u8>) -> Self {
Self {
data: Arc::new(RwLock::new(data)),
}
}
}
impl From<Vec<u8>> for Box<dyn ResourceProvider> {
fn from(data: Vec<u8>) -> Self {
Box::new(BufferResource {
data: Arc::new(RwLock::new(data)),
})
}
}
impl From<RwLock<Vec<u8>>> for BufferResource {
fn from(lock: RwLock<Vec<u8>>) -> Self {
Self {
data: Arc::new(lock),
}
}
}
impl From<RwLock<Vec<u8>>> for Box<dyn ResourceProvider> {
fn from(lock: RwLock<Vec<u8>>) -> Self {
Box::new(BufferResource {
data: Arc::new(lock),
})
}
}

@@ -1,5 +1,5 @@
use crate::common::error::RustBertError;
use crate::resources::ResourceProvider;
use crate::resources::{Resource, ResourceProvider};
use std::path::PathBuf;
/// # Local resource
@@ -29,6 +29,26 @@ impl ResourceProvider for LocalResource {
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
Ok(self.local_path.clone())
}
/// Gets a wrapper around the path for a local resource.
///
/// # Returns
///
/// * `Resource` wrapping a `PathBuf` pointing to the resource file
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{LocalResource, ResourceProvider};
/// use std::path::PathBuf;
/// let config_resource = LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// };
/// let config_path = config_resource.get_resource();
/// ```
fn get_resource(&self) -> Result<Resource, RustBertError> {
Ok(Resource::PathBuf(self.local_path.clone()))
}
}
impl From<PathBuf> for LocalResource {

@@ -1,6 +1,6 @@
//! # Resource definitions for model weights, vocabularies and configuration files
//!
//! This crate relies on the concept of Resources to access the files used by the models.
//! This crate relies on the concept of Resources to access the data used by the models.
//! This includes:
//! - model weights
//! - configuration files
@@ -11,20 +11,33 @@
//! resource location. Two types of resources are pre-defined:
//! - LocalResource: points to a local file
//! - RemoteResource: points to a remote file via a URL
//! - BufferResource: refers to a buffer that contains file contents for a resource (currently only
//! usable for weights)
//!
//! For both types of resources, the local location of the file can be retrieved using
//! For `LocalResource` and `RemoteResource`, the local location of the file can be retrieved using
//! `get_local_path`, allowing to reference the resource file location regardless if it is a remote
//! or local resource. Default implementations for a number of `RemoteResources` are available as
//! pre-trained models in each model module.
mod buffer;
mod local;
use crate::common::error::RustBertError;
pub use buffer::BufferResource;
pub use local::LocalResource;
use std::ops::DerefMut;
use std::path::PathBuf;
use std::sync::RwLockWriteGuard;
use tch::nn::VarStore;
/// # Resource Trait that can provide the location of the model, configuration or vocabulary resources
pub trait ResourceProvider {
pub enum Resource<'a> {
PathBuf(PathBuf),
Buffer(RwLockWriteGuard<'a, Vec<u8>>),
}
/// # Resource Trait that can provide the location or data for the model, and location of
/// configuration or vocabulary resources
pub trait ResourceProvider: Send + Sync {
/// Provides the local path for a resource.
///
/// # Returns
@@ -42,6 +55,42 @@ pub trait ResourceProvider {
/// let config_path = config_resource.get_local_path();
/// ```
fn get_local_path(&self) -> Result<PathBuf, RustBertError>;
/// Provides access to an underlying resource.
///
/// # Returns
///
/// * `Resource` wrapping a representation of a resource.
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{BufferResource, LocalResource, ResourceProvider};
/// use std::path::PathBuf;
/// let config_resource = LocalResource {
///     local_path: PathBuf::from("path/to/config.json"),
/// };
/// let config_resource_wrapper = config_resource.get_resource();
/// ```
fn get_resource(&self) -> Result<Resource, RustBertError>;
}
impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T> {
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
T::get_local_path(self)
}
fn get_resource(&self) -> Result<Resource, RustBertError> {
T::get_resource(self)
}
}
/// Load the provided `VarStore` with model weights from the provided `ResourceProvider`
pub fn load_weights(
rp: &(impl ResourceProvider + ?Sized),
vs: &mut VarStore,
) -> Result<(), RustBertError> {
match rp.get_resource()? {
Resource::Buffer(mut data) => {
vs.load_from_stream(std::io::Cursor::new(data.deref_mut()))?;
Ok(())
}
Resource::PathBuf(path) => Ok(vs.load(path)?),
}
}
#[cfg(feature = "remote")]

@@ -93,6 +93,23 @@ impl ResourceProvider for RemoteResource {
.cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?;
Ok(cached_path)
}
/// Gets a wrapper around the local path for a remote resource.
///
/// # Returns
///
/// * `Resource` wrapping a `PathBuf` pointing to the resource file
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{RemoteResource, ResourceProvider};
/// let config_resource = RemoteResource::new("http://config_json_location", "configs");
/// let config_path = config_resource.get_resource();
/// ```
fn get_resource(&self) -> Result<Resource, RustBertError> {
Ok(Resource::PathBuf(self.get_local_path()?))
}
}
lazy_static! {

@@ -639,7 +639,6 @@ impl GPT2Generator {
tokenizer: TokenizerOption,
) -> Result<GPT2Generator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@@ -647,7 +646,7 @@
let config = Gpt2Config::from_file(config_path);
let model = GPT2LMHeadModel::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = tokenizer.get_bos_id();
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);

@@ -609,7 +609,6 @@ impl GptJGenerator {
tokenizer: TokenizerOption,
) -> Result<GptJGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@@ -620,7 +619,7 @@
if config.preload_on_cpu && device != Device::Cpu {
var_store.set_device(Device::Cpu);
}
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
if device != Device::Cpu {
var_store.set_device(device);
}

@@ -660,14 +660,13 @@ impl GptNeoGenerator {
tokenizer: TokenizerOption,
) -> Result<GptNeoGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = GptNeoConfig::from_file(config_path);
let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = tokenizer.get_bos_id();
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);

@@ -583,7 +583,6 @@ impl LongT5Generator {
tokenizer: TokenizerOption,
) -> Result<LongT5Generator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@@ -591,7 +590,7 @@
let config = LongT5Config::from_file(config_path);
let model = LongT5ForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = config.bos_token_id;
let eos_token_ids = Some(match config.eos_token_id {

@@ -538,7 +538,6 @@ impl M2M100Generator {
tokenizer: TokenizerOption,
) -> Result<M2M100Generator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@@ -546,7 +545,7 @@
let config = M2M100Config::from_file(config_path);
let model = M2M100ForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = Some(match config.eos_token_id {

@@ -755,7 +755,6 @@ impl MarianGenerator {
tokenizer: TokenizerOption,
) -> Result<MarianGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@@ -763,7 +762,7 @@
let config = BartConfig::from_file(config_path);
let model = MarianForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = Some(match config.eos_token_id {

@@ -791,7 +791,6 @@ impl MBartGenerator {
tokenizer: TokenizerOption,
) -> Result<MBartGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@@ -799,7 +798,7 @@
let config = MBartConfig::from_file(config_path);
let model = MBartForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = Some(match config.eos_token_id {

@@ -493,13 +493,12 @@ impl OpenAIGenerator {
generate_config.validate();
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
let mut var_store = nn::VarStore::new(device);
let config = Gpt2Config::from_file(config_path);
let model = OpenAIGPTLMHeadModel::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = tokenizer.get_bos_id();
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);

@@ -501,14 +501,13 @@ impl PegasusConditionalGenerator {
tokenizer: TokenizerOption,
) -> Result<PegasusConditionalGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = PegasusConfig::from_file(config_path);
let model = PegasusForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = Some(match config.eos_token_id {

@@ -429,7 +429,6 @@ impl MaskedLanguageModel {
tokenizer: TokenizerOption,
) -> Result<MaskedLanguageModel, RustBertError> {
let config_path = config.config_resource.get_local_path()?;
let weights_path = config.model_resource.get_local_path()?;
let device = config.device;
let mut var_store = VarStore::new(device);
@@ -441,7 +440,7 @@
let language_encode =
MaskedLanguageOption::new(config.model_type, var_store.root(), &model_config)?;
var_store.load(weights_path)?;
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
let mask_token = config.mask_token;
Ok(MaskedLanguageModel {
tokenizer,

@@ -632,7 +632,6 @@ impl QuestionAnsweringModel {
tokenizer: TokenizerOption,
) -> Result<QuestionAnsweringModel, RustBertError> {
let config_path = question_answering_config.config_resource.get_local_path()?;
let weights_path = question_answering_config.model_resource.get_local_path()?;
let device = question_answering_config.device;
let pad_idx = tokenizer
@@ -670,7 +669,7 @@
)));
}
var_store.load(weights_path)?;
crate::resources::load_weights(&question_answering_config.model_resource, &mut var_store)?;
Ok(QuestionAnsweringModel {
tokenizer,
pad_idx,

@@ -230,7 +230,7 @@ impl SentenceEmbeddingsModel {
);
let transformer =
SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
var_store.load(transformer_weights_resource.get_local_path()?)?;
crate::resources::load_weights(&transformer_weights_resource, &mut var_store)?;
// Setup pooling layer
let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?);

@@ -615,7 +615,6 @@ impl SequenceClassificationModel {
tokenizer: TokenizerOption,
) -> Result<SequenceClassificationModel, RustBertError> {
let config_path = config.config_resource.get_local_path()?;
let weights_path = config.model_resource.get_local_path()?;
let device = config.device;
let mut var_store = VarStore::new(device);
@@ -627,7 +626,7 @@
let sequence_classifier =
SequenceClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
let label_mapping = model_config.get_label_mapping().clone();
var_store.load(weights_path)?;
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
Ok(SequenceClassificationModel {
tokenizer,
sequence_classifier,

@@ -720,7 +720,6 @@ impl TokenClassificationModel {
tokenizer: TokenizerOption,
) -> Result<TokenClassificationModel, RustBertError> {
let config_path = config.config_resource.get_local_path()?;
let weights_path = config.model_resource.get_local_path()?;
let device = config.device;
let label_aggregation_function = config.label_aggregation_function;
@@ -734,7 +733,7 @@
TokenClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
let label_mapping = model_config.get_label_mapping().clone();
let batch_size = config.batch_size;
var_store.load(weights_path)?;
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
Ok(TokenClassificationModel {
tokenizer,
token_sequence_classifier,

@@ -601,14 +601,13 @@ impl ZeroShotClassificationModel {
tokenizer: TokenizerOption,
) -> Result<ZeroShotClassificationModel, RustBertError> {
let config_path = config.config_resource.get_local_path()?;
let weights_path = config.model_resource.get_local_path()?;
let device = config.device;
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let zero_shot_classifier =
ZeroShotClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
var_store.load(weights_path)?;
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
Ok(ZeroShotClassificationModel {
tokenizer,
zero_shot_classifier,

@@ -906,14 +906,13 @@ impl ProphetNetConditionalGenerator {
tokenizer: TokenizerOption,
) -> Result<ProphetNetConditionalGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = ProphetNetConfig::from_file(config_path);
let model = ProphetNetForConditionalGeneration::new(var_store.root(), &config)?;
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = Some(config.bos_token_id);
let eos_token_ids = Some(vec![config.eos_token_id]);

@@ -1044,14 +1044,13 @@ impl ReformerGenerator {
tokenizer: TokenizerOption,
) -> Result<ReformerGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = ReformerConfig::from_file(config_path);
let model = ReformerModelWithLMHead::new(var_store.root(), &config)?;
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = tokenizer.get_bos_id();
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);

@@ -753,7 +753,6 @@ impl T5Generator {
tokenizer: TokenizerOption,
) -> Result<T5Generator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@@ -761,7 +760,7 @@
let config = T5Config::from_file(config_path);
let model = T5ForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(-1));
let eos_token_ids = Some(match config.eos_token_id {

@@ -1553,7 +1553,6 @@ impl XLNetGenerator {
tokenizer: TokenizerOption,
) -> Result<XLNetGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@@ -1561,7 +1560,7 @@
let config = XLNetConfig::from_file(config_path);
let model = XLNetLMHeadModel::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = Some(config.bos_token_id);
let eos_token_ids = Some(vec![config.eos_token_id]);

@@ -6,7 +6,7 @@ use rust_bert::albert::{
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
AlbertModelResources, AlbertVocabResources,
};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::resources::{load_weights, RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{AlbertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
@@ -27,7 +27,6 @@ fn albert_masked_lm() -> anyhow::Result<()> {
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
@@ -36,7 +35,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?;
let config = AlbertConfig::from_file(config_path);
let albert_model = AlbertForMaskedLM::new(vs.root(), &config);
vs.load(weights_path)?;
load_weights(&weights_resource, &mut vs)?;
// Define input
let input = [