diff --git a/examples/buffer_resource.rs b/examples/buffer_resource.rs new file mode 100644 index 0000000..0f9e858 --- /dev/null +++ b/examples/buffer_resource.rs @@ -0,0 +1,97 @@ +// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +// Copyright 2019 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +extern crate anyhow; + +use std::sync::{Arc, RwLock}; + +use rust_bert::bart::{ + BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources, +}; +use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel}; +use rust_bert::resources::{BufferResource, RemoteResource, ResourceProvider}; +use tch::Device; + +fn main() -> anyhow::Result<()> { + let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \ +from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \ +from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \ +a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \ +habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \ +used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \ +passed between it and Earth. 
They found that certain wavelengths of light, which are usually absorbed by water, \ +weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \ +contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \ +and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \ +but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \ +\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \ +said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \ +said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors. \ +\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \ +a potentially habitable planet, but further observations will be required to say for sure. \" \ +K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \ +but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \ +on K2-18b lasts 33 Earth days. 
According to The Guardian, astronomers were optimistic that NASA's James Webb space \
telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."]; + + let weights = Arc::new(RwLock::new(get_weights()?)); + let summarization_model = SummarizationModel::new(config(Device::Cpu, weights.clone()))?; + + // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b) + let output = summarization_model.summarize(&input); + for sentence in output { + println!("{sentence}"); + } + + let summarization_model = + SummarizationModel::new(config(Device::cuda_if_available(), weights))?; + + // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b) + let output = summarization_model.summarize(&input); + for sentence in output { + println!("{sentence}"); + } + + Ok(()) +} + +fn get_weights() -> anyhow::Result<Vec<u8>, anyhow::Error> { + let model_resource = RemoteResource::from_pretrained(BartModelResources::DISTILBART_CNN_6_6); + Ok(std::fs::read(model_resource.get_local_path()?)?) 
+} + +fn config(device: Device, model_data: Arc<RwLock<Vec<u8>>>) -> SummarizationConfig { + let config_resource = Box::new(RemoteResource::from_pretrained( + BartConfigResources::DISTILBART_CNN_6_6, + )); + let vocab_resource = Box::new(RemoteResource::from_pretrained( + BartVocabResources::DISTILBART_CNN_6_6, + )); + let merges_resource = Box::new(RemoteResource::from_pretrained( + BartMergesResources::DISTILBART_CNN_6_6, + )); + let model_resource = Box::new(BufferResource { data: model_data }); + + SummarizationConfig { + model_resource, + config_resource, + vocab_resource, + merges_resource: Some(merges_resource), + num_beams: 1, + length_penalty: 1.0, + min_length: 56, + max_length: Some(142), + device, + ..Default::default() + } +} diff --git a/examples/natural_language_inference_deberta.rs b/examples/natural_language_inference_deberta.rs index 1ce6585..debe53b 100644 --- a/examples/natural_language_inference_deberta.rs +++ b/examples/natural_language_inference_deberta.rs @@ -4,7 +4,7 @@ use rust_bert::deberta::{ DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification, DebertaMergesResources, DebertaModelResources, DebertaVocabResources, }; -use rust_bert::resources::{RemoteResource, ResourceProvider}; +use rust_bert::resources::{load_weights, RemoteResource, ResourceProvider}; use rust_bert::Config; use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy}; use tch::{nn, no_grad, Device, Kind, Tensor}; @@ -27,7 +27,6 @@ fn main() -> anyhow::Result<()> { let config_path = config_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?; let merges_path = merges_resource.get_local_path()?; - let weights_path = model_resource.get_local_path()?; // Set-up model let device = Device::Cpu; @@ -39,7 +38,7 @@ fn main() -> anyhow::Result<()> { )?; let config = DebertaConfig::from_file(config_path); let model = DebertaForSequenceClassification::new(vs.root(), &config)?; - vs.load(weights_path)?; + 
load_weights(&model_resource, &mut vs)?; // Define input let input = [("I love you.", "I like you.")]; diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs index cb34c1e..60aa3cd 100644 --- a/src/bart/bart_model.rs +++ b/src/bart/bart_model.rs @@ -993,14 +993,13 @@ impl BartGenerator { tokenizer: TokenizerOption, ) -> Result<BartGenerator, RustBertError> { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); let mut var_store = nn::VarStore::new(device); let config = BartConfig::from_file(config_path); let model = BartForConditionalGeneration::new(var_store.root(), &config); - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = Some(config.bos_token_id.unwrap_or(0)); let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/common/error.rs b/src/common/error.rs index 0019567..5f4e115 100644 --- a/src/common/error.rs +++ b/src/common/error.rs @@ -22,6 +22,9 @@ pub enum RustBertError { #[error("Value error: {0}")] ValueError(String), + + #[error("Unsupported operation")] + UnsupportedError, } impl From for RustBertError { diff --git a/src/common/resources/buffer.rs b/src/common/resources/buffer.rs new file mode 100644 index 0000000..e0a2f76 --- /dev/null +++ b/src/common/resources/buffer.rs @@ -0,0 +1,71 @@ +use crate::common::error::RustBertError; +use crate::resources::{Resource, ResourceProvider}; +use std::path::PathBuf; +use std::sync::{Arc, RwLock}; + +/// # In-memory raw buffer resource +pub struct BufferResource { + /// The data representing the underlying resource + pub data: Arc<RwLock<Vec<u8>>>, +} + +impl ResourceProvider for BufferResource { + /// Not implemented for this resource type + /// + /// # Returns + /// + /// * `RustBertError::UnsupportedError` + fn get_local_path(&self) -> Result<PathBuf, RustBertError> { + Err(RustBertError::UnsupportedError) + } + + /// Gets a 
wrapper referring to the in-memory resource. + /// + /// # Returns + /// + /// * `Resource` referring to the resource data + /// + /// # Example + /// + /// ```no_run + /// use rust_bert::resources::{BufferResource, ResourceProvider}; + /// let data = std::fs::read("path/to/rust_model.ot").unwrap(); + /// let weights_resource = BufferResource::from(data); + /// let weights = weights_resource.get_resource(); + /// ``` + fn get_resource(&self) -> Result<Resource, RustBertError> { + Ok(Resource::Buffer(self.data.write().unwrap())) + } +} + +impl From<Vec<u8>> for BufferResource { + fn from(data: Vec<u8>) -> Self { + Self { + data: Arc::new(RwLock::new(data)), + } + } +} + +impl From<Vec<u8>> for Box<dyn ResourceProvider> { + fn from(data: Vec<u8>) -> Self { + Box::new(BufferResource { + data: Arc::new(RwLock::new(data)), + }) + } +} + +impl From<RwLock<Vec<u8>>> for BufferResource { + fn from(lock: RwLock<Vec<u8>>) -> Self { + Self { + data: Arc::new(lock), + } + } +} + +impl From<RwLock<Vec<u8>>> for Box<dyn ResourceProvider> { + fn from(lock: RwLock<Vec<u8>>) -> Self { + Box::new(BufferResource { + data: Arc::new(lock), + }) + } +} diff --git a/src/common/resources/local.rs b/src/common/resources/local.rs index d06dfbb..00ee2d9 100644 --- a/src/common/resources/local.rs +++ b/src/common/resources/local.rs @@ -1,5 +1,5 @@ use crate::common::error::RustBertError; -use crate::resources::ResourceProvider; +use crate::resources::{Resource, ResourceProvider}; use std::path::PathBuf; /// # Local resource @@ -29,6 +29,26 @@ impl ResourceProvider for LocalResource { fn get_local_path(&self) -> Result<PathBuf, RustBertError> { Ok(self.local_path.clone()) } + + /// Gets a wrapper around the path for a local resource. 
+ /// + /// # Returns + /// + /// * `Resource` wrapping a `PathBuf` pointing to the resource file + /// + /// # Example + /// + /// ```no_run + /// use rust_bert::resources::{LocalResource, ResourceProvider}; + /// use std::path::PathBuf; + /// let config_resource = LocalResource { + /// local_path: PathBuf::from("path/to/config.json"), + /// }; + /// let config_path = config_resource.get_resource(); + /// ``` + fn get_resource(&self) -> Result<Resource, RustBertError> { + Ok(Resource::PathBuf(self.local_path.clone())) + } } impl From<PathBuf> for LocalResource { diff --git a/src/common/resources/mod.rs b/src/common/resources/mod.rs index 96e9e52..5f921fa 100644 --- a/src/common/resources/mod.rs +++ b/src/common/resources/mod.rs @@ -1,6 +1,6 @@ //! # Resource definitions for model weights, vocabularies and configuration files //! -//! This crate relies on the concept of Resources to access the files used by the models. +//! This crate relies on the concept of Resources to access the data used by the models. //! This includes: //! - model weights //! - configuration files @@ -11,20 +11,33 @@ //! resource location. Two types of resources are pre-defined: //! - LocalResource: points to a local file //! - RemoteResource: points to a remote file via a URL +//! - BufferResource: refers to a buffer that contains file contents for a resource (currently only +//! usable for weights) //! -//! For both types of resources, the local location of the file can be retrieved using +//! For `LocalResource` and `RemoteResource`, the local location of the file can be retrieved using //! `get_local_path`, allowing to reference the resource file location regardless if it is a remote //! or local resource. Default implementations for a number of `RemoteResources` are available as //! pre-trained models in each model module. 
+mod buffer; mod local; use crate::common::error::RustBertError; +pub use buffer::BufferResource; pub use local::LocalResource; +use std::ops::DerefMut; use std::path::PathBuf; +use std::sync::RwLockWriteGuard; +use tch::nn::VarStore; -/// # Resource Trait that can provide the location of the model, configuration or vocabulary resources -pub trait ResourceProvider { +pub enum Resource<'a> { + PathBuf(PathBuf), + Buffer(RwLockWriteGuard<'a, Vec<u8>>), +} + +/// # Resource Trait that can provide the location or data for the model, and location of +/// configuration or vocabulary resources +pub trait ResourceProvider: Send + Sync { /// Provides the local path for a resource. /// /// # Returns /// @@ -42,6 +55,42 @@ pub trait ResourceProvider { /// let config_path = config_resource.get_local_path(); /// ``` fn get_local_path(&self) -> Result<PathBuf, RustBertError>; + + /// Provides access to an underlying resource. + /// + /// # Returns + /// + /// * `Resource` wrapping a representation of a resource. + /// + /// # Example + /// + /// ```no_run + /// use rust_bert::resources::{BufferResource, LocalResource, ResourceProvider}; + /// ``` + fn get_resource(&self) -> Result<Resource, RustBertError>; +} + +impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T> { + fn get_local_path(&self) -> Result<PathBuf, RustBertError> { + T::get_local_path(self) + } + fn get_resource(&self) -> Result<Resource, RustBertError> { + T::get_resource(self) + } +} + +/// Load the provided `VarStore` with model weights from the provided `ResourceProvider` +pub fn load_weights( + rp: &(impl ResourceProvider + ?Sized), + vs: &mut VarStore, +) -> Result<(), RustBertError> { + match rp.get_resource()? 
{ + Resource::Buffer(mut data) => { + vs.load_from_stream(std::io::Cursor::new(data.deref_mut()))?; + Ok(()) + } + Resource::PathBuf(path) => Ok(vs.load(path)?), + } } #[cfg(feature = "remote")] diff --git a/src/common/resources/remote.rs b/src/common/resources/remote.rs index 096d691..83fa9af 100644 --- a/src/common/resources/remote.rs +++ b/src/common/resources/remote.rs @@ -93,6 +93,23 @@ impl ResourceProvider for RemoteResource { .cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?; Ok(cached_path) } + + /// Gets a wrapper around the local path for a remote resource. + /// + /// # Returns + /// + /// * `Resource` wrapping a `PathBuf` pointing to the resource file + /// + /// # Example + /// + /// ```no_run + /// use rust_bert::resources::{RemoteResource, ResourceProvider}; + /// let config_resource = RemoteResource::new("http://config_json_location", "configs"); + /// let config_path = config_resource.get_resource(); + /// ``` + fn get_resource(&self) -> Result<Resource, RustBertError> { + Ok(Resource::PathBuf(self.get_local_path()?)) + } } lazy_static! 
{ diff --git a/src/gpt2/gpt2_model.rs b/src/gpt2/gpt2_model.rs index 415b486..31ff74f 100644 --- a/src/gpt2/gpt2_model.rs +++ b/src/gpt2/gpt2_model.rs @@ -639,7 +639,6 @@ impl GPT2Generator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); @@ -647,7 +646,7 @@ impl GPT2Generator { let config = Gpt2Config::from_file(config_path); let model = GPT2LMHeadModel::new(var_store.root(), &config); - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = tokenizer.get_bos_id(); let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]); diff --git a/src/gpt_j/gpt_j_model.rs b/src/gpt_j/gpt_j_model.rs index b01e72d..cadc7bb 100644 --- a/src/gpt_j/gpt_j_model.rs +++ b/src/gpt_j/gpt_j_model.rs @@ -609,7 +609,6 @@ impl GptJGenerator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); @@ -620,7 +619,7 @@ impl GptJGenerator { if config.preload_on_cpu && device != Device::Cpu { var_store.set_device(Device::Cpu); } - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; if device != Device::Cpu { var_store.set_device(device); } diff --git a/src/gpt_neo/gpt_neo_model.rs b/src/gpt_neo/gpt_neo_model.rs index 71c8f2d..a058cc9 100644 --- a/src/gpt_neo/gpt_neo_model.rs +++ b/src/gpt_neo/gpt_neo_model.rs @@ -660,14 +660,13 @@ impl GptNeoGenerator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = 
generate_config.device; generate_config.validate(); let mut var_store = nn::VarStore::new(device); let config = GptNeoConfig::from_file(config_path); let model = GptNeoForCausalLM::new(var_store.root(), &config)?; - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = tokenizer.get_bos_id(); let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]); diff --git a/src/longt5/longt5_model.rs b/src/longt5/longt5_model.rs index 3e37f9a..4be1dc4 100644 --- a/src/longt5/longt5_model.rs +++ b/src/longt5/longt5_model.rs @@ -583,7 +583,6 @@ impl LongT5Generator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); @@ -591,7 +590,7 @@ impl LongT5Generator { let config = LongT5Config::from_file(config_path); let model = LongT5ForConditionalGeneration::new(var_store.root(), &config); - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = config.bos_token_id; let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/m2m_100/m2m_100_model.rs b/src/m2m_100/m2m_100_model.rs index cce3ec8..0fc379f 100644 --- a/src/m2m_100/m2m_100_model.rs +++ b/src/m2m_100/m2m_100_model.rs @@ -538,7 +538,6 @@ impl M2M100Generator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); @@ -546,7 +545,7 @@ impl M2M100Generator { let config = M2M100Config::from_file(config_path); let model = M2M100ForConditionalGeneration::new(var_store.root(), &config); - var_store.load(weights_path)?; + 
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = Some(config.bos_token_id.unwrap_or(0)); let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs index 40c21d9..e3d7984 100644 --- a/src/marian/marian_model.rs +++ b/src/marian/marian_model.rs @@ -755,7 +755,6 @@ impl MarianGenerator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); @@ -763,7 +762,7 @@ impl MarianGenerator { let config = BartConfig::from_file(config_path); let model = MarianForConditionalGeneration::new(var_store.root(), &config); - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = Some(config.bos_token_id.unwrap_or(0)); let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/mbart/mbart_model.rs b/src/mbart/mbart_model.rs index cba9017..16f5648 100644 --- a/src/mbart/mbart_model.rs +++ b/src/mbart/mbart_model.rs @@ -791,7 +791,6 @@ impl MBartGenerator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); @@ -799,7 +798,7 @@ impl MBartGenerator { let config = MBartConfig::from_file(config_path); let model = MBartForConditionalGeneration::new(var_store.root(), &config); - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = Some(config.bos_token_id.unwrap_or(0)); let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/openai_gpt/openai_gpt_model.rs b/src/openai_gpt/openai_gpt_model.rs index 58f831a..d43ab2d 
100644 --- a/src/openai_gpt/openai_gpt_model.rs +++ b/src/openai_gpt/openai_gpt_model.rs @@ -493,13 +493,12 @@ impl OpenAIGenerator { generate_config.validate(); let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; let mut var_store = nn::VarStore::new(device); let config = Gpt2Config::from_file(config_path); let model = OpenAIGPTLMHeadModel::new(var_store.root(), &config); - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = tokenizer.get_bos_id(); let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]); diff --git a/src/pegasus/pegasus_model.rs b/src/pegasus/pegasus_model.rs index 6fe50ce..09c4277 100644 --- a/src/pegasus/pegasus_model.rs +++ b/src/pegasus/pegasus_model.rs @@ -501,14 +501,13 @@ impl PegasusConditionalGenerator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); let mut var_store = nn::VarStore::new(device); let config = PegasusConfig::from_file(config_path); let model = PegasusForConditionalGeneration::new(var_store.root(), &config); - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = Some(config.bos_token_id.unwrap_or(0)); let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/pipelines/masked_language.rs b/src/pipelines/masked_language.rs index a1b0318..c7075ca 100644 --- a/src/pipelines/masked_language.rs +++ b/src/pipelines/masked_language.rs @@ -429,7 +429,6 @@ impl MaskedLanguageModel { tokenizer: TokenizerOption, ) -> Result { let config_path = config.config_resource.get_local_path()?; - let weights_path = 
config.model_resource.get_local_path()?; let device = config.device; let mut var_store = VarStore::new(device); @@ -441,7 +440,7 @@ impl MaskedLanguageModel { let language_encode = MaskedLanguageOption::new(config.model_type, var_store.root(), &model_config)?; - var_store.load(weights_path)?; + crate::resources::load_weights(&config.model_resource, &mut var_store)?; let mask_token = config.mask_token; Ok(MaskedLanguageModel { tokenizer, diff --git a/src/pipelines/question_answering.rs b/src/pipelines/question_answering.rs index bedc350..fb1a3a7 100644 --- a/src/pipelines/question_answering.rs +++ b/src/pipelines/question_answering.rs @@ -632,7 +632,6 @@ impl QuestionAnsweringModel { tokenizer: TokenizerOption, ) -> Result { let config_path = question_answering_config.config_resource.get_local_path()?; - let weights_path = question_answering_config.model_resource.get_local_path()?; let device = question_answering_config.device; let pad_idx = tokenizer @@ -670,7 +669,7 @@ impl QuestionAnsweringModel { ))); } - var_store.load(weights_path)?; + crate::resources::load_weights(&question_answering_config.model_resource, &mut var_store)?; Ok(QuestionAnsweringModel { tokenizer, pad_idx, diff --git a/src/pipelines/sentence_embeddings/pipeline.rs b/src/pipelines/sentence_embeddings/pipeline.rs index 703c09c..5faa976 100644 --- a/src/pipelines/sentence_embeddings/pipeline.rs +++ b/src/pipelines/sentence_embeddings/pipeline.rs @@ -230,7 +230,7 @@ impl SentenceEmbeddingsModel { ); let transformer = SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?; - var_store.load(transformer_weights_resource.get_local_path()?)?; + crate::resources::load_weights(&transformer_weights_resource, &mut var_store)?; // Setup pooling layer let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?); diff --git a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs index 1dc7498..d11021b 100644 --- 
a/src/pipelines/sequence_classification.rs +++ b/src/pipelines/sequence_classification.rs @@ -615,7 +615,6 @@ impl SequenceClassificationModel { tokenizer: TokenizerOption, ) -> Result { let config_path = config.config_resource.get_local_path()?; - let weights_path = config.model_resource.get_local_path()?; let device = config.device; let mut var_store = VarStore::new(device); @@ -627,7 +626,7 @@ impl SequenceClassificationModel { let sequence_classifier = SequenceClassificationOption::new(config.model_type, var_store.root(), &model_config)?; let label_mapping = model_config.get_label_mapping().clone(); - var_store.load(weights_path)?; + crate::resources::load_weights(&config.model_resource, &mut var_store)?; Ok(SequenceClassificationModel { tokenizer, sequence_classifier, diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs index 593bc4c..05b6e54 100644 --- a/src/pipelines/token_classification.rs +++ b/src/pipelines/token_classification.rs @@ -720,7 +720,6 @@ impl TokenClassificationModel { tokenizer: TokenizerOption, ) -> Result { let config_path = config.config_resource.get_local_path()?; - let weights_path = config.model_resource.get_local_path()?; let device = config.device; let label_aggregation_function = config.label_aggregation_function; @@ -734,7 +733,7 @@ impl TokenClassificationModel { TokenClassificationOption::new(config.model_type, var_store.root(), &model_config)?; let label_mapping = model_config.get_label_mapping().clone(); let batch_size = config.batch_size; - var_store.load(weights_path)?; + crate::resources::load_weights(&config.model_resource, &mut var_store)?; Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index d26af3e..1f0d604 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -601,14 +601,13 @@ impl ZeroShotClassificationModel { 
tokenizer: TokenizerOption, ) -> Result { let config_path = config.config_resource.get_local_path()?; - let weights_path = config.model_resource.get_local_path()?; let device = config.device; let mut var_store = VarStore::new(device); let model_config = ConfigOption::from_file(config.model_type, config_path); let zero_shot_classifier = ZeroShotClassificationOption::new(config.model_type, var_store.root(), &model_config)?; - var_store.load(weights_path)?; + crate::resources::load_weights(&config.model_resource, &mut var_store)?; Ok(ZeroShotClassificationModel { tokenizer, zero_shot_classifier, diff --git a/src/prophetnet/prophetnet_model.rs b/src/prophetnet/prophetnet_model.rs index ad9bf79..0c21ea6 100644 --- a/src/prophetnet/prophetnet_model.rs +++ b/src/prophetnet/prophetnet_model.rs @@ -906,14 +906,13 @@ impl ProphetNetConditionalGenerator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); let mut var_store = nn::VarStore::new(device); let config = ProphetNetConfig::from_file(config_path); let model = ProphetNetForConditionalGeneration::new(var_store.root(), &config)?; - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = Some(config.bos_token_id); let eos_token_ids = Some(vec![config.eos_token_id]); diff --git a/src/reformer/reformer_model.rs b/src/reformer/reformer_model.rs index 130e28b..8576eb7 100644 --- a/src/reformer/reformer_model.rs +++ b/src/reformer/reformer_model.rs @@ -1044,14 +1044,13 @@ impl ReformerGenerator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); let mut var_store = 
nn::VarStore::new(device); let config = ReformerConfig::from_file(config_path); let model = ReformerModelWithLMHead::new(var_store.root(), &config)?; - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = tokenizer.get_bos_id(); let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]); diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index 4eba98e..7be33bb 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -753,7 +753,6 @@ impl T5Generator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); @@ -761,7 +760,7 @@ impl T5Generator { let config = T5Config::from_file(config_path); let model = T5ForConditionalGeneration::new(var_store.root(), &config); - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = Some(config.bos_token_id.unwrap_or(-1)); let eos_token_ids = Some(match config.eos_token_id { diff --git a/src/xlnet/xlnet_model.rs b/src/xlnet/xlnet_model.rs index 28cff18..8edc735 100644 --- a/src/xlnet/xlnet_model.rs +++ b/src/xlnet/xlnet_model.rs @@ -1553,7 +1553,6 @@ impl XLNetGenerator { tokenizer: TokenizerOption, ) -> Result { let config_path = generate_config.config_resource.get_local_path()?; - let weights_path = generate_config.model_resource.get_local_path()?; let device = generate_config.device; generate_config.validate(); @@ -1561,7 +1560,7 @@ impl XLNetGenerator { let config = XLNetConfig::from_file(config_path); let model = XLNetLMHeadModel::new(var_store.root(), &config); - var_store.load(weights_path)?; + crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?; let bos_token_id = Some(config.bos_token_id); let eos_token_ids = Some(vec![config.eos_token_id]); 
diff --git a/tests/albert.rs b/tests/albert.rs index 40e6866..8990f4d 100644 --- a/tests/albert.rs +++ b/tests/albert.rs @@ -6,7 +6,7 @@ use rust_bert::albert::{ AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification, AlbertModelResources, AlbertVocabResources, }; -use rust_bert::resources::{RemoteResource, ResourceProvider}; +use rust_bert::resources::{load_weights, RemoteResource, ResourceProvider}; use rust_bert::Config; use rust_tokenizers::tokenizer::{AlbertTokenizer, MultiThreadedTokenizer, TruncationStrategy}; use rust_tokenizers::vocab::Vocab; @@ -27,7 +27,6 @@ fn albert_masked_lm() -> anyhow::Result<()> { )); let config_path = config_resource.get_local_path()?; let vocab_path = vocab_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; // Set-up masked LM model let device = Device::Cpu; @@ -36,7 +35,7 @@ fn albert_masked_lm() -> anyhow::Result<()> { AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?; let config = AlbertConfig::from_file(config_path); let albert_model = AlbertForMaskedLM::new(vs.root(), &config); - vs.load(weights_path)?; + load_weights(&weights_resource, &mut vs)?; // Define input let input = [