mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
Introduce in-memory resource abstraction (#375)
* Introduce in-memory resource abstraction

This follows from discussion in #366. The goal of this change is to allow weights to be loaded from a copy of `rust_model.ot` that is already present in memory. There are two ways in which that data might be present:

1. As a `HashMap<String, Tensor>` from previous interaction with `tch`
2. As a contiguous buffer of the file data

One or the other mechanism might be preferable depending on how user code is using the model data. In some sense, implementing a provider based on the second option is more of a convenience method for the user, to avoid the `tch::nn::VarStore::load_from_stream` interaction.

I've changed the definition of the `ResourceProvider` trait to require that it be both `Send` and `Sync`. There are currently certain contexts where `dyn ResourceProvider + Send` is required, but in theory, before this change, an implementation might not have been `Send` (or `Sync`). The existing providers are both `Send` and `Sync`, and it seems reasonable (if technically incorrect) for user code to assume this to be true. I don't see a downside to making this explicit, but that part of this change might be better suited for separate discussion. I am not trying to sneak it in.

The `enum Resource` data type is used here as a means to abstract over the possible ways a `ResourceProvider` might represent an underlying resource. Without it, it would be necessary to either call different trait methods until one succeeded, or to implement `as_any` and downcast in order to implement `load_weights` similarly to how it is done now. Those options seemed less preferable than creating a wrapper.

While it would be possible to replace all calls to `get_local_path` with the `get_resource` API, removal of the existing function would be a very big breaking change. As such, this change also introduces `RustBertError::UnsupportedError` to allow the two methods to coexist.

An alternative would be for the new `ResourceProvider`s to write their resources to a temporary disk location and return an appropriate path, but that is counter to the purpose of the new `ResourceProvider`s, so I chose not to implement it.

* - Add `impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T>`
  - Remove `Resource::NamedTensors`
  - Change `BufferResource` to contain a `&[u8]` rather than a `Vec<u8>`

* Further rework proposal for resources

* Use mutable references and locks

* Make model resources mutable in tests/examples

* Remove unnecessary mutability and TensorResource references

* Add `BufferResource` example

---------

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
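The shape of the change described above can be sketched without pulling in rust-bert or tch at all. The following standalone program mirrors the design: a `Resource` enum wrapping either a path or a write guard over an in-memory buffer, a `ResourceProvider` trait bound by `Send + Sync`, and a loader that dispatches on the variant. The names mirror the PR, but the bodies here are simplified stand-ins for illustration, not the actual rust-bert API (`load_bytes` stands in for `load_weights`/`VarStore::load_from_stream`):

```rust
use std::io::Read;
use std::path::PathBuf;
use std::sync::{Arc, RwLock, RwLockWriteGuard};

// A resource is either a filesystem path or an exclusive guard over an
// in-memory buffer (simplified stand-in for rust-bert's `enum Resource`).
enum Resource<'a> {
    PathBuf(PathBuf),
    Buffer(RwLockWriteGuard<'a, Vec<u8>>),
}

trait ResourceProvider: Send + Sync {
    fn get_resource(&self) -> Resource<'_>;
}

// Stand-in for rust-bert's `BufferResource`: shared, lockable bytes.
struct BufferResource {
    data: Arc<RwLock<Vec<u8>>>,
}

impl ResourceProvider for BufferResource {
    fn get_resource(&self) -> Resource<'_> {
        Resource::Buffer(self.data.write().unwrap())
    }
}

// Stand-in for `load_weights`: read the bytes from whichever representation
// the provider hands back, streaming from memory via a `Cursor` when possible.
fn load_bytes(rp: &dyn ResourceProvider) -> std::io::Result<Vec<u8>> {
    match rp.get_resource() {
        Resource::Buffer(mut guard) => {
            let mut out = Vec::new();
            std::io::Cursor::new(guard.as_mut_slice()).read_to_end(&mut out)?;
            Ok(out)
        }
        Resource::PathBuf(path) => std::fs::read(path),
    }
}

fn main() {
    let provider = BufferResource {
        data: Arc::new(RwLock::new(vec![1, 2, 3])),
    };
    let bytes = load_bytes(&provider).unwrap();
    assert_eq!(bytes, vec![1, 2, 3]);
    println!("loaded {} bytes from the in-memory buffer", bytes.len());
}
```

The `RwLockWriteGuard` in the `Buffer` variant is what lets the loader take a mutable view of the caller's buffer without copying it, which is why the actual PR moved `BufferResource` to `Arc<RwLock<Vec<u8>>>`.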
This commit is contained in:
parent f591dc30b9
commit ba57704c6f
97
examples/buffer_resource.rs
Normal file
@@ -0,0 +1,97 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

extern crate anyhow;

use std::sync::{Arc, RwLock};

use rust_bert::bart::{
    BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{BufferResource, RemoteResource, ResourceProvider};
use tch::Device;

fn main() -> anyhow::Result<()> {
    let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
        from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
        from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
        a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \
        habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \
        used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \
        passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, \
        weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \
        contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \
        and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \
        but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \
        \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \
        said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
        said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors. \
        \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
        a potentially habitable planet, but further observations will be required to say for sure. \" \
        K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
        but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
        on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \
        telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
        about exoplanets like K2-18b."];

    let weights = Arc::new(RwLock::new(get_weights()?));
    let summarization_model = SummarizationModel::new(config(Device::Cpu, weights.clone()))?;

    // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
    let output = summarization_model.summarize(&input);
    for sentence in output {
        println!("{sentence}");
    }

    let summarization_model =
        SummarizationModel::new(config(Device::cuda_if_available(), weights))?;

    // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
    let output = summarization_model.summarize(&input);
    for sentence in output {
        println!("{sentence}");
    }

    Ok(())
}

fn get_weights() -> anyhow::Result<Vec<u8>, anyhow::Error> {
    let model_resource = RemoteResource::from_pretrained(BartModelResources::DISTILBART_CNN_6_6);
    Ok(std::fs::read(model_resource.get_local_path()?)?)
}

fn config(device: Device, model_data: Arc<RwLock<Vec<u8>>>) -> SummarizationConfig {
    let config_resource = Box::new(RemoteResource::from_pretrained(
        BartConfigResources::DISTILBART_CNN_6_6,
    ));
    let vocab_resource = Box::new(RemoteResource::from_pretrained(
        BartVocabResources::DISTILBART_CNN_6_6,
    ));
    let merges_resource = Box::new(RemoteResource::from_pretrained(
        BartMergesResources::DISTILBART_CNN_6_6,
    ));
    let model_resource = Box::new(BufferResource { data: model_data });

    SummarizationConfig {
        model_resource,
        config_resource,
        vocab_resource,
        merges_resource: Some(merges_resource),
        num_beams: 1,
        length_penalty: 1.0,
        min_length: 56,
        max_length: Some(142),
        device,
        ..Default::default()
    }
}
@@ -4,7 +4,7 @@ use rust_bert::deberta::{
    DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
    DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
};
-use rust_bert::resources::{RemoteResource, ResourceProvider};
+use rust_bert::resources::{load_weights, RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Kind, Tensor};

@@ -27,7 +27,6 @@ fn main() -> anyhow::Result<()> {
    let config_path = config_resource.get_local_path()?;
    let vocab_path = vocab_resource.get_local_path()?;
    let merges_path = merges_resource.get_local_path()?;
-    let weights_path = model_resource.get_local_path()?;

    // Set-up model
    let device = Device::Cpu;

@@ -39,7 +38,7 @@ fn main() -> anyhow::Result<()> {
    )?;
    let config = DebertaConfig::from_file(config_path);
    let model = DebertaForSequenceClassification::new(vs.root(), &config)?;
-    vs.load(weights_path)?;
+    load_weights(&model_resource, &mut vs)?;

    // Define input
    let input = [("I love you.", "I like you.")];
@@ -993,14 +993,13 @@ impl BartGenerator {
        tokenizer: TokenizerOption,
    ) -> Result<BartGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);
        let config = BartConfig::from_file(config_path);
        let model = BartForConditionalGeneration::new(var_store.root(), &config);
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
        let eos_token_ids = Some(match config.eos_token_id {
@@ -22,6 +22,9 @@ pub enum RustBertError {

    #[error("Value error: {0}")]
    ValueError(String),
+
+    #[error("Unsupported operation")]
+    UnsupportedError,
}

impl From<std::io::Error> for RustBertError {
71
src/common/resources/buffer.rs
Normal file
@@ -0,0 +1,71 @@
use crate::common::error::RustBertError;
use crate::resources::{Resource, ResourceProvider};
use std::path::PathBuf;
use std::sync::{Arc, RwLock};

/// # In-memory raw buffer resource
pub struct BufferResource {
    /// The data representing the underlying resource
    pub data: Arc<RwLock<Vec<u8>>>,
}

impl ResourceProvider for BufferResource {
    /// Not implemented for this resource type
    ///
    /// # Returns
    ///
    /// * `RustBertError::UnsupportedError`
    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
        Err(RustBertError::UnsupportedError)
    }

    /// Gets a wrapper referring to the in-memory resource.
    ///
    /// # Returns
    ///
    /// * `Resource` referring to the resource data
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::{BufferResource, ResourceProvider};
    /// let data = std::fs::read("path/to/rust_model.ot").unwrap();
    /// let weights_resource = BufferResource::from(data);
    /// let weights = weights_resource.get_resource();
    /// ```
    fn get_resource(&self) -> Result<Resource, RustBertError> {
        Ok(Resource::Buffer(self.data.write().unwrap()))
    }
}

impl From<Vec<u8>> for BufferResource {
    fn from(data: Vec<u8>) -> Self {
        Self {
            data: Arc::new(RwLock::new(data)),
        }
    }
}

impl From<Vec<u8>> for Box<dyn ResourceProvider> {
    fn from(data: Vec<u8>) -> Self {
        Box::new(BufferResource {
            data: Arc::new(RwLock::new(data)),
        })
    }
}

impl From<RwLock<Vec<u8>>> for BufferResource {
    fn from(lock: RwLock<Vec<u8>>) -> Self {
        Self {
            data: Arc::new(lock),
        }
    }
}

impl From<RwLock<Vec<u8>>> for Box<dyn ResourceProvider> {
    fn from(lock: RwLock<Vec<u8>>) -> Self {
        Box::new(BufferResource {
            data: Arc::new(lock),
        })
    }
}
@@ -1,5 +1,5 @@
use crate::common::error::RustBertError;
-use crate::resources::ResourceProvider;
+use crate::resources::{Resource, ResourceProvider};
use std::path::PathBuf;

/// # Local resource

@@ -29,6 +29,26 @@ impl ResourceProvider for LocalResource {
    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
        Ok(self.local_path.clone())
    }

+    /// Gets a wrapper around the path for a local resource.
+    ///
+    /// # Returns
+    ///
+    /// * `Resource` wrapping a `PathBuf` pointing to the resource file
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use rust_bert::resources::{LocalResource, ResourceProvider};
+    /// use std::path::PathBuf;
+    /// let config_resource = LocalResource {
+    ///     local_path: PathBuf::from("path/to/config.json"),
+    /// };
+    /// let config_path = config_resource.get_resource();
+    /// ```
+    fn get_resource(&self) -> Result<Resource, RustBertError> {
+        Ok(Resource::PathBuf(self.local_path.clone()))
+    }
}

impl From<PathBuf> for LocalResource {
@@ -1,6 +1,6 @@
//! # Resource definitions for model weights, vocabularies and configuration files
//!
-//! This crate relies on the concept of Resources to access the files used by the models.
+//! This crate relies on the concept of Resources to access the data used by the models.
//! This includes:
//! - model weights
//! - configuration files

@@ -11,20 +11,33 @@
//! resource location. Two types of resources are pre-defined:
//! - LocalResource: points to a local file
//! - RemoteResource: points to a remote file via a URL
+//! - BufferResource: refers to a buffer that contains file contents for a resource (currently only
+//!   usable for weights)
//!
-//! For both types of resources, the local location of the file can be retrieved using
+//! For `LocalResource` and `RemoteResource`, the local location of the file can be retrieved using
//! `get_local_path`, allowing to reference the resource file location regardless if it is a remote
//! or local resource. Default implementations for a number of `RemoteResources` are available as
//! pre-trained models in each model module.

+mod buffer;
mod local;

use crate::common::error::RustBertError;
+pub use buffer::BufferResource;
pub use local::LocalResource;
+use std::ops::DerefMut;
use std::path::PathBuf;
+use std::sync::RwLockWriteGuard;
+use tch::nn::VarStore;

-/// # Resource Trait that can provide the location of the model, configuration or vocabulary resources
-pub trait ResourceProvider {
+pub enum Resource<'a> {
+    PathBuf(PathBuf),
+    Buffer(RwLockWriteGuard<'a, Vec<u8>>),
+}
+
+/// # Resource Trait that can provide the location or data for the model, and location of
+/// configuration or vocabulary resources
+pub trait ResourceProvider: Send + Sync {
    /// Provides the local path for a resource.
    ///
    /// # Returns

@@ -42,6 +55,42 @@ pub trait ResourceProvider {
    /// let config_path = config_resource.get_local_path();
    /// ```
    fn get_local_path(&self) -> Result<PathBuf, RustBertError>;

+    /// Provides access to an underlying resource.
+    ///
+    /// # Returns
+    ///
+    /// * `Resource` wrapping a representation of a resource.
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use rust_bert::resources::{BufferResource, LocalResource, ResourceProvider};
+    /// ```
+    fn get_resource(&self) -> Result<Resource, RustBertError>;
}

+impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T> {
+    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
+        T::get_local_path(self)
+    }
+    fn get_resource(&self) -> Result<Resource, RustBertError> {
+        T::get_resource(self)
+    }
+}
+
+/// Load the provided `VarStore` with model weights from the provided `ResourceProvider`
+pub fn load_weights(
+    rp: &(impl ResourceProvider + ?Sized),
+    vs: &mut VarStore,
+) -> Result<(), RustBertError> {
+    match rp.get_resource()? {
+        Resource::Buffer(mut data) => {
+            vs.load_from_stream(std::io::Cursor::new(data.deref_mut()))?;
+            Ok(())
+        }
+        Resource::PathBuf(path) => Ok(vs.load(path)?),
+    }
+}

#[cfg(feature = "remote")]
@@ -93,6 +93,23 @@ impl ResourceProvider for RemoteResource {
            .cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?;
        Ok(cached_path)
    }

+    /// Gets a wrapper around the local path for a remote resource.
+    ///
+    /// # Returns
+    ///
+    /// * `Resource` wrapping a `PathBuf` pointing to the resource file
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use rust_bert::resources::{RemoteResource, ResourceProvider};
+    /// let config_resource = RemoteResource::new("http://config_json_location", "configs");
+    /// let config_path = config_resource.get_resource();
+    /// ```
+    fn get_resource(&self) -> Result<Resource, RustBertError> {
+        Ok(Resource::PathBuf(self.get_local_path()?))
+    }
}

lazy_static! {
@@ -639,7 +639,6 @@ impl GPT2Generator {
        tokenizer: TokenizerOption,
    ) -> Result<GPT2Generator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();

@@ -647,7 +646,7 @@ impl GPT2Generator {

        let config = Gpt2Config::from_file(config_path);
        let model = GPT2LMHeadModel::new(var_store.root(), &config);
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = tokenizer.get_bos_id();
        let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
@@ -609,7 +609,6 @@ impl GptJGenerator {
        tokenizer: TokenizerOption,
    ) -> Result<GptJGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();

@@ -620,7 +619,7 @@ impl GptJGenerator {
        if config.preload_on_cpu && device != Device::Cpu {
            var_store.set_device(Device::Cpu);
        }
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
        if device != Device::Cpu {
            var_store.set_device(device);
        }
@@ -660,14 +660,13 @@ impl GptNeoGenerator {
        tokenizer: TokenizerOption,
    ) -> Result<GptNeoGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);
        let config = GptNeoConfig::from_file(config_path);
        let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = tokenizer.get_bos_id();
        let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
@@ -583,7 +583,6 @@ impl LongT5Generator {
        tokenizer: TokenizerOption,
    ) -> Result<LongT5Generator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();

@@ -591,7 +590,7 @@ impl LongT5Generator {

        let config = LongT5Config::from_file(config_path);
        let model = LongT5ForConditionalGeneration::new(var_store.root(), &config);
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = config.bos_token_id;
        let eos_token_ids = Some(match config.eos_token_id {
@@ -538,7 +538,6 @@ impl M2M100Generator {
        tokenizer: TokenizerOption,
    ) -> Result<M2M100Generator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();

@@ -546,7 +545,7 @@ impl M2M100Generator {

        let config = M2M100Config::from_file(config_path);
        let model = M2M100ForConditionalGeneration::new(var_store.root(), &config);
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
        let eos_token_ids = Some(match config.eos_token_id {
@@ -755,7 +755,6 @@ impl MarianGenerator {
        tokenizer: TokenizerOption,
    ) -> Result<MarianGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();

@@ -763,7 +762,7 @@ impl MarianGenerator {

        let config = BartConfig::from_file(config_path);
        let model = MarianForConditionalGeneration::new(var_store.root(), &config);
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
        let eos_token_ids = Some(match config.eos_token_id {
@@ -791,7 +791,6 @@ impl MBartGenerator {
        tokenizer: TokenizerOption,
    ) -> Result<MBartGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();

@@ -799,7 +798,7 @@ impl MBartGenerator {

        let config = MBartConfig::from_file(config_path);
        let model = MBartForConditionalGeneration::new(var_store.root(), &config);
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
        let eos_token_ids = Some(match config.eos_token_id {
@@ -493,13 +493,12 @@ impl OpenAIGenerator {
        generate_config.validate();

        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        let mut var_store = nn::VarStore::new(device);
        let config = Gpt2Config::from_file(config_path);
        let model = OpenAIGPTLMHeadModel::new(var_store.root(), &config);
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = tokenizer.get_bos_id();
        let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
@@ -501,14 +501,13 @@ impl PegasusConditionalGenerator {
        tokenizer: TokenizerOption,
    ) -> Result<PegasusConditionalGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);
        let config = PegasusConfig::from_file(config_path);
        let model = PegasusForConditionalGeneration::new(var_store.root(), &config);
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
        let eos_token_ids = Some(match config.eos_token_id {
@@ -429,7 +429,6 @@ impl MaskedLanguageModel {
        tokenizer: TokenizerOption,
    ) -> Result<MaskedLanguageModel, RustBertError> {
        let config_path = config.config_resource.get_local_path()?;
-        let weights_path = config.model_resource.get_local_path()?;
        let device = config.device;

        let mut var_store = VarStore::new(device);

@@ -441,7 +440,7 @@ impl MaskedLanguageModel {

        let language_encode =
            MaskedLanguageOption::new(config.model_type, var_store.root(), &model_config)?;
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&config.model_resource, &mut var_store)?;
        let mask_token = config.mask_token;
        Ok(MaskedLanguageModel {
            tokenizer,
@@ -632,7 +632,6 @@ impl QuestionAnsweringModel {
        tokenizer: TokenizerOption,
    ) -> Result<QuestionAnsweringModel, RustBertError> {
        let config_path = question_answering_config.config_resource.get_local_path()?;
-        let weights_path = question_answering_config.model_resource.get_local_path()?;
        let device = question_answering_config.device;

        let pad_idx = tokenizer

@@ -670,7 +669,7 @@ impl QuestionAnsweringModel {
            )));
        }

-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&question_answering_config.model_resource, &mut var_store)?;
        Ok(QuestionAnsweringModel {
            tokenizer,
            pad_idx,
@@ -230,7 +230,7 @@ impl SentenceEmbeddingsModel {
        );
        let transformer =
            SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
-        var_store.load(transformer_weights_resource.get_local_path()?)?;
+        crate::resources::load_weights(&transformer_weights_resource, &mut var_store)?;

        // Setup pooling layer
        let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?);
@@ -615,7 +615,6 @@ impl SequenceClassificationModel {
        tokenizer: TokenizerOption,
    ) -> Result<SequenceClassificationModel, RustBertError> {
        let config_path = config.config_resource.get_local_path()?;
-        let weights_path = config.model_resource.get_local_path()?;
        let device = config.device;

        let mut var_store = VarStore::new(device);

@@ -627,7 +626,7 @@ impl SequenceClassificationModel {
        let sequence_classifier =
            SequenceClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
        let label_mapping = model_config.get_label_mapping().clone();
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&config.model_resource, &mut var_store)?;
        Ok(SequenceClassificationModel {
            tokenizer,
            sequence_classifier,
@@ -720,7 +720,6 @@ impl TokenClassificationModel {
        tokenizer: TokenizerOption,
    ) -> Result<TokenClassificationModel, RustBertError> {
        let config_path = config.config_resource.get_local_path()?;
-        let weights_path = config.model_resource.get_local_path()?;
        let device = config.device;
        let label_aggregation_function = config.label_aggregation_function;

@@ -734,7 +733,7 @@ impl TokenClassificationModel {
            TokenClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
        let label_mapping = model_config.get_label_mapping().clone();
        let batch_size = config.batch_size;
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&config.model_resource, &mut var_store)?;
        Ok(TokenClassificationModel {
            tokenizer,
            token_sequence_classifier,
@@ -601,14 +601,13 @@ impl ZeroShotClassificationModel {
        tokenizer: TokenizerOption,
    ) -> Result<ZeroShotClassificationModel, RustBertError> {
        let config_path = config.config_resource.get_local_path()?;
-        let weights_path = config.model_resource.get_local_path()?;
        let device = config.device;

        let mut var_store = VarStore::new(device);
        let model_config = ConfigOption::from_file(config.model_type, config_path);
        let zero_shot_classifier =
            ZeroShotClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&config.model_resource, &mut var_store)?;
        Ok(ZeroShotClassificationModel {
            tokenizer,
            zero_shot_classifier,
@@ -906,14 +906,13 @@ impl ProphetNetConditionalGenerator {
        tokenizer: TokenizerOption,
    ) -> Result<ProphetNetConditionalGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);
        let config = ProphetNetConfig::from_file(config_path);
        let model = ProphetNetForConditionalGeneration::new(var_store.root(), &config)?;
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = Some(config.bos_token_id);
        let eos_token_ids = Some(vec![config.eos_token_id]);
@@ -1044,14 +1044,13 @@ impl ReformerGenerator {
        tokenizer: TokenizerOption,
    ) -> Result<ReformerGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);
        let config = ReformerConfig::from_file(config_path);
        let model = ReformerModelWithLMHead::new(var_store.root(), &config)?;
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = tokenizer.get_bos_id();
        let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
@@ -753,7 +753,6 @@ impl T5Generator {
        tokenizer: TokenizerOption,
    ) -> Result<T5Generator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();

@@ -761,7 +760,7 @@ impl T5Generator {

        let config = T5Config::from_file(config_path);
        let model = T5ForConditionalGeneration::new(var_store.root(), &config);
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = Some(config.bos_token_id.unwrap_or(-1));
        let eos_token_ids = Some(match config.eos_token_id {
@@ -1553,7 +1553,6 @@ impl XLNetGenerator {
        tokenizer: TokenizerOption,
    ) -> Result<XLNetGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
-        let weights_path = generate_config.model_resource.get_local_path()?;
        let device = generate_config.device;

        generate_config.validate();

@@ -1561,7 +1560,7 @@ impl XLNetGenerator {

        let config = XLNetConfig::from_file(config_path);
        let model = XLNetLMHeadModel::new(var_store.root(), &config);
-        var_store.load(weights_path)?;
+        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;

        let bos_token_id = Some(config.bos_token_id);
        let eos_token_ids = Some(vec![config.eos_token_id]);
@@ -6,7 +6,7 @@ use rust_bert::albert::{
    AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
    AlbertModelResources, AlbertVocabResources,
};
-use rust_bert::resources::{RemoteResource, ResourceProvider};
+use rust_bert::resources::{load_weights, RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{AlbertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;

@@ -27,7 +27,6 @@ fn albert_masked_lm() -> anyhow::Result<()> {
    ));
    let config_path = config_resource.get_local_path()?;
    let vocab_path = vocab_resource.get_local_path()?;
-    let weights_path = weights_resource.get_local_path()?;

    // Set-up masked LM model
    let device = Device::Cpu;

@@ -36,7 +35,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
        AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?;
    let config = AlbertConfig::from_file(config_path);
    let albert_model = AlbertForMaskedLM::new(vs.root(), &config);
-    vs.load(weights_path)?;
+    load_weights(&weights_resource, &mut vs)?;

    // Define input
    let input = [