Merge remote-tracking branch 'origin/main' into onnxruntime_update

# Conflicts:
#	CHANGELOG.md
This commit is contained in:
Guillaume Becquin 2023-12-02 13:27:36 +00:00
commit f77b38072d
No known key found for this signature in database
GPG Key ID: D23E3F3D92A4157D
49 changed files with 310 additions and 141 deletions

View File

@ -5,11 +5,18 @@ All notable changes to this project will be documented in this file. The format
## Added
- Addition of a `new_with_tokenizer` constructor for `SentenceEmbeddingsModel`, allowing custom tokenizers to be passed to sentence embeddings pipelines.
- Support for [Tokenizers](https://github.com/huggingface/tokenizers) in pipelines, allowing `tokenizer.json` and `special_tokens_map.json` tokenizer files to be loaded.
- (BREAKING) Most model configurations can now take an optional `kind` parameter to specify the model weight precision. If not provided, this defaults to full precision on CPU, or to the serialized weights precision otherwise.
## Fixed
- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Added a new configuration option, `tokenizer_forbidden_ngram_chars`, to specify characters that should be excluded from n-grams (allows filtering n-grams spanning multiple sentences).
- Improved MPS device compatibility by setting the `sparse_grad` flag to false for `gather` operations.
- Updated ONNX runtime backend version to 1.15.x
- Issue with incorrect results for QA models with a tokenizer not using segment ids
- Issue with GPT-J that was incorrectly tracking the gradients for the attention bias
## Changed
- (BREAKING) Upgraded to `torch` 2.1 (via `tch` 0.14.0).
- Updated ONNX runtime backend version to 1.15.x
## [0.21.0] - 2023-06-03
## Added

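The new `kind` option described in the changelog above applies to the pipeline configurations touched later in this commit. A minimal sketch (not part of this commit) of opting into half precision through a pipeline config, assuming the crate's `remote` feature supplies the default GPT-2 resources:

```rust
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use tch::{Device, Kind};

fn main() -> anyhow::Result<()> {
    // `kind: None` keeps full precision on CPU or the serialized precision on
    // other devices; an explicit kind overrides both.
    let config = TextGenerationConfig {
        device: Device::cuda_if_available(),
        kind: Some(Kind::Half),
        ..Default::default()
    };
    let _model = TextGenerationModel::new(config)?;
    Ok(())
}
```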
View File

@ -75,8 +75,8 @@ hf-tokenizers = ["tokenizers"]
features = ["doc-only"]
[dependencies]
rust_tokenizers = "8.1"
tch = "0.13.0"
rust_tokenizers = "8.1.1"
tch = "0.14.0"
serde_json = "1"
serde = { version = "1", features = ["derive"] }
ordered-float = "3"
@ -97,7 +97,7 @@ anyhow = "1"
csv = "1"
criterion = "0.4"
tokio = { version = "1.24", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "0.13.0"
torch-sys = "0.14.0"
tempfile = "3"
itertools = "0.10"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }

View File

@ -80,8 +80,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
### Manual installation (recommended)
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.0.0`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip` for a Linux version with CUDA11. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.1`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
2. Extract the library to a location of your choice
3. Set the following environment variables
##### Linux:

View File

@ -37,6 +37,7 @@ fn create_text_generation_model() -> TextGenerationModel {
diversity_penalty: None,
num_return_sequences: 5,
device: Device::cuda_if_available(),
kind: None,
};
TextGenerationModel::new(config).unwrap()
}

View File

@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
)?;
let config = DebertaConfig::from_file(config_path);
let model = DebertaForSequenceClassification::new(vs.root(), &config)?;
load_weights(&model_resource, &mut vs)?;
load_weights(&model_resource, &mut vs, None, device)?;
// Define input
let input = [("I love you.", "I like you.")];

View File

@ -30,6 +30,7 @@ use std::ops::DerefMut;
use std::path::PathBuf;
use std::sync::RwLockWriteGuard;
use tch::nn::VarStore;
use tch::{Device, Kind};
pub enum Resource<'a> {
PathBuf(PathBuf),
@ -84,17 +85,19 @@ impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T> {
pub fn load_weights(
rp: &(impl ResourceProvider + ?Sized),
vs: &mut VarStore,
kind: Option<Kind>,
device: Device,
) -> Result<(), RustBertError> {
match rp.get_resource()? {
Resource::Buffer(mut data) => {
vs.load_from_stream(std::io::Cursor::new(data.deref_mut()))?;
Ok(())
}
Resource::PathBuf(path) => Ok(vs.load(path)?),
}
Resource::Buffer(mut data) => vs.load_from_stream(std::io::Cursor::new(data.deref_mut())),
Resource::PathBuf(path) => vs.load(path),
}?;
cast_var_store(vs, kind, device);
Ok(())
}
#[cfg(feature = "remote")]
mod remote;
use crate::pipelines::common::cast_var_store;
#[cfg(feature = "remote")]
pub use remote::RemoteResource;

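`load_weights` now takes the target `kind` and `device` so the var store can be cast right after loading. A minimal usage sketch, assuming a hypothetical local weights file (the `LocalResource` path below is purely illustrative):

```rust
use std::path::PathBuf;

use rust_bert::resources::{load_weights, LocalResource};
use tch::{nn::VarStore, Device, Kind};

fn load_fp16_weights(weights_path: &str) -> anyhow::Result<VarStore> {
    // Any ResourceProvider works here; a local file keeps the sketch self-contained.
    let resource = LocalResource {
        local_path: PathBuf::from(weights_path),
    };
    let device = Device::cuda_if_available();
    let mut vs = VarStore::new(device);
    // An explicit kind overrides the stored precision; `None` falls back to
    // Float on CPU and leaves the loaded precision untouched elsewhere.
    load_weights(&resource, &mut vs, Some(Kind::Half), device)?;
    Ok(vs)
}
```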
View File

@ -90,8 +90,8 @@
//!
//! ### Manual installation (recommended)
//!
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.0`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip` for a Linux version with CUDA11.
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.1`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11.
//! 2. Extract the library to a location of your choice
//! 3. Set the following environment variables
//! ##### Linux:

View File

@ -1004,7 +1004,12 @@ impl BartGenerator {
let mut var_store = nn::VarStore::new(device);
let config = BartConfig::from_file(config_path);
let model = BartForConditionalGeneration::new(var_store.root(), &config);
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = Some(match config.eos_token_id {

View File

@ -396,9 +396,11 @@ impl<T: BertEmbedding> BertModel<T> {
train,
)?;
let extended_attention_mask: Tensor =
((extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0)
.to_kind(embedding_output.kind());
let extended_attention_mask: Tensor = ((extended_attention_mask
.ones_like()
.bitwise_xor_tensor(&extended_attention_mask))
* -10000.0)
.to_kind(embedding_output.kind());
let encoder_extended_attention_mask: Option<Tensor> =
if self.is_decoder & encoder_hidden_states.is_some() {

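The hunk above swaps `ones_like() - mask` for a bitwise XOR when inverting the 0/1 attention mask, which keeps the operation valid for boolean and integer mask kinds. A standalone `tch` sketch of the equivalence (values are illustrative):

```rust
use tch::{Kind, Tensor};

fn main() {
    // A 0/1 attention mask: 1 = attend, 0 = padding.
    let mask = Tensor::from_slice(&[1i64, 1, 0, 0]);
    // XOR with ones flips 0 <-> 1, the same result as `ones_like() - mask`,
    // but subtraction is not defined for boolean tensors while XOR is.
    let inverted = mask.ones_like().bitwise_xor_tensor(&mask);
    // Additive bias: 0 where attention is allowed, -10000 where it is masked.
    let bias = (inverted * -10000.0).to_kind(Kind::Float);
    bias.print();
}
```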
View File

@ -235,7 +235,7 @@ impl DebertaV2Encoder {
.unsqueeze(-1)
.to_kind(Kind::Uint8)
}
value if value == 3 => attention_mask.unsqueeze(1),
3 => attention_mask.unsqueeze(1),
_ => attention_mask.shallow_clone(),
}
}

View File

@ -652,7 +652,12 @@ impl GPT2Generator {
let config = Gpt2Config::from_file(config_path);
let model = GPT2LMHeadModel::new(var_store.root(), &config);
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = tokenizer.get_bos_id();
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);

View File

@ -68,11 +68,16 @@ impl GptJAttention {
let p = p.borrow();
let max_positions = config.n_positions;
let bias = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device()))
let bias_value = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device()))
.tril(0)
.view([1, 1, max_positions, max_positions])
.requires_grad_(false);
let bias = p.var_copy("bias", &bias);
let mut bias = p
.f_ones_no_train("bias", &[1, 1, max_positions, max_positions])
.unwrap()
.to_kind(Kind::Uint8)
.to_device(p.device());
bias.copy_(&bias_value);
let attn_pdrop = config.attn_pdrop.unwrap_or(0.1);
let resid_pdrop = config.resid_pdrop.unwrap_or(0.1);
@ -95,21 +100,9 @@ impl GptJAttention {
..Default::default()
};
let k_proj = nn::linear(p / "k_proj", config.n_embd, config.n_embd, linear_config);
if config.use_float16 {
(p / "k_proj").half();
}
let v_proj = nn::linear(p / "v_proj", config.n_embd, config.n_embd, linear_config);
if config.use_float16 {
(p / "v_proj").half();
}
let q_proj = nn::linear(p / "q_proj", config.n_embd, config.n_embd, linear_config);
if config.use_float16 {
(p / "q_proj").half();
}
let out_proj = nn::linear(p / "out_proj", config.n_embd, config.n_embd, linear_config);
if config.use_float16 {
(p / "out_proj").half();
}
GptJAttention {
bias,

View File

@ -131,8 +131,6 @@ pub struct GptJConfig {
pub rotary_dim: Option<i64>,
pub vocab_size: i64,
pub scale_attn_weights: Option<bool>,
#[serde(default = "default_use_float16")]
pub use_float16: bool,
#[serde(default = "default_preload_on_cpu")]
pub preload_on_cpu: bool,
pub decoder_start_token_id: Option<i64>,
@ -164,7 +162,6 @@ impl Default for GptJConfig {
rotary_dim: Some(64),
vocab_size: 50400,
scale_attn_weights: Some(true),
use_float16: default_use_float16(),
preload_on_cpu: default_preload_on_cpu(),
decoder_start_token_id: None,
forced_bos_token_id: None,
@ -173,10 +170,6 @@ impl Default for GptJConfig {
}
}
fn default_use_float16() -> bool {
true
}
fn default_preload_on_cpu() -> bool {
true
}
@ -233,9 +226,6 @@ impl GptJModel {
config.n_embd,
Default::default(),
);
if config.use_float16 {
(&(&p / "wte") / "weight").half()
};
let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
let drop = Dropout::new(embd_pdrop);
@ -245,9 +235,6 @@ impl GptJModel {
..Default::default()
};
let ln_f = nn::layer_norm(&p / "ln_f", vec![config.n_embd], layer_norm_config);
if config.use_float16 {
(&p / "ln_f").half()
};
let mut h: Vec<GptJBlock> = vec![];
let h_path = &p / "h";
@ -475,9 +462,6 @@ impl GptJLMHeadModel {
config.vocab_size,
Default::default(),
);
if config.use_float16 {
(p / "lm_head").half();
}
GptJLMHeadModel {
transformer,
@ -625,7 +609,12 @@ impl GptJGenerator {
if config.preload_on_cpu && device != Device::Cpu {
var_store.set_device(Device::Cpu);
}
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
if device != Device::Cpu {
var_store.set_device(device);
}

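With `use_float16` removed from `GptJConfig`, half precision for GPT-J is now requested through the generation config's `kind` field instead (the pattern the updated `gpt_j_correctness` test uses further below). A sketch assuming the `remote` feature for the default resources; the model resources would still need to point at GPT-J weights:

```rust
use rust_bert::pipelines::generation_utils::GenerateConfig;
use tch::{Device, Kind};

// Half precision only when a CUDA device is available; full/stored precision otherwise.
fn half_precision_generate_config() -> GenerateConfig {
    let device = Device::cuda_if_available();
    GenerateConfig {
        device,
        kind: match device {
            Device::Cpu => None,
            _ => Some(Kind::Half),
        },
        ..Default::default()
    }
}
```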
View File

@ -43,18 +43,12 @@ impl GptJMLP {
intermediate_size,
Default::default(),
);
if config.use_float16 {
(p / "fc_in").half()
};
let fc_out = nn::linear(
p / "fc_out",
intermediate_size,
config.n_embd,
Default::default(),
);
if config.use_float16 {
(p / "fc_out").half()
};
let activation = match &config.afn {
Some(activation_enum) => match activation_enum {
@ -100,9 +94,6 @@ impl GptJBlock {
..Default::default()
};
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
if config.use_float16 {
(p / "ln_1").half()
};
let attn = GptJAttention::new(p / "attn", config);
let mlp = GptJMLP::new(p / "mlp", config);

View File

@ -672,7 +672,12 @@ impl GptNeoGenerator {
let mut var_store = nn::VarStore::new(device);
let config = GptNeoConfig::from_file(config_path);
let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = tokenizer.get_bos_id();
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);

View File

@ -288,8 +288,8 @@ impl LongT5Stack {
let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
let mask_seq_length = if old_layer_states.is_some() {
if old_layer_states.as_ref().unwrap()[0].0.is_some() {
let mask_seq_length = if let Some(old_layer_states_value) = &old_layer_states {
if old_layer_states_value[0].0.is_some() {
old_layer_states.as_ref().unwrap()[0]
.0
.as_ref()

View File

@ -595,7 +595,12 @@ impl LongT5Generator {
let config = LongT5Config::from_file(config_path);
let model = LongT5ForConditionalGeneration::new(var_store.root(), &config);
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = config.bos_token_id;
let eos_token_ids = Some(match config.eos_token_id {

View File

@ -544,7 +544,12 @@ impl M2M100Generator {
let config = M2M100Config::from_file(config_path);
let model = M2M100ForConditionalGeneration::new(var_store.root(), &config);
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = Some(match config.eos_token_id {

View File

@ -761,7 +761,12 @@ impl MarianGenerator {
let config = BartConfig::from_file(config_path);
let model = MarianForConditionalGeneration::new(var_store.root(), &config);
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = Some(match config.eos_token_id {

View File

@ -450,7 +450,7 @@ impl MBartForConditionalGeneration {
{
let p = p.borrow();
let base_model = MBartModel::new(p.borrow() / "model", config);
let base_model = MBartModel::new(p / "model", config);
let final_logits_bias = p.var(
"final_logits_bias",
&[1, config.vocab_size],
@ -650,7 +650,7 @@ impl MBartForSequenceClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = MBartConfig::from_file(config_path);
/// # let mbart_model: MBartForSequenceClassification = MBartForSequenceClassification::new(&vs.root(), &config).unwrap();;
/// # let mbart_model: MBartForSequenceClassification = MBartForSequenceClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -800,7 +800,12 @@ impl MBartGenerator {
let config = MBartConfig::from_file(config_path);
let model = MBartForConditionalGeneration::new(var_store.root(), &config);
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = Some(match config.eos_token_id {

View File

@ -498,7 +498,12 @@ impl OpenAIGenerator {
let mut var_store = nn::VarStore::new(device);
let config = Gpt2Config::from_file(config_path);
let model = OpenAIGPTLMHeadModel::new(var_store.root(), &config);
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = tokenizer.get_bos_id();
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);

View File

@ -505,7 +505,12 @@ impl PegasusConditionalGenerator {
let mut var_store = nn::VarStore::new(device);
let config = PegasusConfig::from_file(config_path);
let model = PegasusForConditionalGeneration::new(var_store.root(), &config);
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = config

View File

@ -919,7 +919,12 @@ impl ProphetNetConditionalGenerator {
let mut var_store = nn::VarStore::new(device);
let config = ProphetNetConfig::from_file(config_path);
let model = ProphetNetForConditionalGeneration::new(var_store.root(), &config)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = Some(config.bos_token_id);
let eos_token_ids = Some(vec![config.eos_token_id]);

View File

@ -1368,8 +1368,8 @@ impl ReformerAttention {
let new_layer_state = if self.use_past {
let prev_buckets = if let Some(buckets_value) = &buckets {
if layer_state.is_none() | {
if layer_state.is_some() {
layer_state.as_ref().unwrap().prev_buckets.is_none()
if let Some(layer_state_value) = &layer_state {
layer_state_value.prev_buckets.is_none()
} else {
false
}

View File

@ -1056,7 +1056,12 @@ impl ReformerGenerator {
let mut var_store = nn::VarStore::new(device);
let config = ReformerConfig::from_file(config_path);
let model = ReformerModelWithLMHead::new(var_store.root(), &config)?;
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = tokenizer.get_bos_id();
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);

View File

@ -191,15 +191,15 @@ impl T5Attention {
let q: Tensor = self.shape(hidden_states.as_ref().apply(&self.query), bs);
let (mut k, mut v) = if key_value_states.is_none() {
let (mut k, mut v) = if let Some(key_value_states_value) = key_value_states {
(
self.shape(hidden_states.apply(&self.key), bs),
self.shape(hidden_states.apply(&self.value), bs),
self.shape(key_value_states_value.apply(&self.key), bs),
self.shape(key_value_states_value.apply(&self.value), bs),
)
} else {
(
self.shape(key_value_states.as_ref().unwrap().apply(&self.key), bs),
self.shape(key_value_states.as_ref().unwrap().apply(&self.value), bs),
self.shape(hidden_states.apply(&self.key), bs),
self.shape(hidden_states.apply(&self.value), bs),
)
};

View File

@ -383,9 +383,9 @@ impl T5Stack {
let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
let mask_seq_length = if old_layer_states.is_some() {
if old_layer_states.as_ref().unwrap()[0].0.is_some() {
old_layer_states.as_ref().unwrap()[0]
let mask_seq_length = if let Some(old_layer_states_value) = &old_layer_states {
if old_layer_states_value[0].0.is_some() {
old_layer_states_value[0]
.0
.as_ref()
.unwrap()

View File

@ -763,7 +763,12 @@ impl T5Generator {
let config = T5Config::from_file(config_path);
let model = T5ForConditionalGeneration::new(var_store.root(), &config);
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(-1));
let eos_token_ids = Some(match config.eos_token_id {

View File

@ -1560,7 +1560,12 @@ impl XLNetGenerator {
let config = XLNetConfig::from_file(config_path);
let model = XLNetLMHeadModel::new(var_store.root(), &config);
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = Some(config.bos_token_id);
let eos_token_ids = Some(vec![config.eos_token_id]);

View File

@ -60,6 +60,7 @@ use std::convert::TryFrom;
use std::fmt::Debug;
use std::path::{Path, PathBuf};
use tch::nn::VarStore;
use tch::{Device, Kind, Tensor};
#[cfg(feature = "onnx")]
@ -2348,3 +2349,11 @@ impl TokenizerOption {
}
}
}
pub fn cast_var_store(varstore: &mut VarStore, kind: Option<Kind>, device: Device) {
match (kind, device) {
(Some(kind), _) => varstore.set_kind(kind),
(None, Device::Cpu) => varstore.set_kind(Kind::Float),
(None, _) => {}
}
}

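`cast_var_store` centralizes the precision decision used by the pipelines in this commit: an explicit `kind` always wins, `None` forces `Float` on CPU, and anything else is left at the loaded precision. A minimal sketch of calling it directly:

```rust
use rust_bert::pipelines::common::cast_var_store;
use tch::{nn::VarStore, Device, Kind};

fn main() {
    let device = Device::cuda_if_available();
    let mut vs = VarStore::new(device);
    // Explicit kind: cast every variable in the store to half precision.
    cast_var_store(&mut vs, Some(Kind::Half), device);
    // No kind: cast to Float on CPU, no-op on CUDA/MPS devices.
    cast_var_store(&mut vs, None, device);
}
```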
View File

@ -115,6 +115,8 @@ pub struct ConversationConfig {
pub diversity_penalty: Option<f64>,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
pub kind: Option<Kind>,
}
#[cfg(feature = "remote")]
@ -150,6 +152,7 @@ impl Default for ConversationConfig {
num_beam_groups: None,
diversity_penalty: None,
device: Device::cuda_if_available(),
kind: None,
}
}
}
@ -177,6 +180,7 @@ impl From<ConversationConfig> for GenerateConfig {
num_beam_groups: config.num_beam_groups,
diversity_penalty: config.diversity_penalty,
device: config.device,
kind: config.kind,
}
}
}

View File

@ -67,7 +67,7 @@
//! ```
use tch::kind::Kind::Int64;
use tch::{no_grad, Device, Tensor};
use tch::{no_grad, Device, Kind, Tensor};
use crate::bart::LayerState as BartLayerState;
use crate::common::resources::ResourceProvider;
@ -136,6 +136,8 @@ pub struct GenerateConfig {
pub diversity_penalty: Option<f64>,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
pub kind: Option<Kind>,
}
#[cfg(feature = "remote")]
@ -166,6 +168,7 @@ impl Default for GenerateConfig {
num_beam_groups: None,
diversity_penalty: None,
device: Device::cuda_if_available(),
kind: None,
}
}
}

View File

@ -52,7 +52,7 @@ use crate::deberta::DebertaForMaskedLM;
use crate::deberta_v2::DebertaV2ForMaskedLM;
use crate::fnet::FNetForMaskedLM;
use crate::pipelines::common::{
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForMaskedLM;
@ -67,7 +67,7 @@ use crate::{
resources::RemoteResource,
};
use tch::nn::VarStore;
use tch::{no_grad, Device, Tensor};
use tch::{no_grad, Device, Kind, Tensor};
#[derive(Debug, Clone)]
/// Output container for masked language model pipeline.
@ -103,6 +103,8 @@ pub struct MaskedLanguageConfig {
pub mask_token: Option<String>,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
pub kind: Option<Kind>,
}
impl MaskedLanguageConfig {
@ -143,6 +145,7 @@ impl MaskedLanguageConfig {
add_prefix_space: add_prefix_space.into(),
mask_token: mask_token.into(),
device: Device::cuda_if_available(),
kind: None,
}
}
}
@ -285,6 +288,7 @@ impl MaskedLanguageOption {
))),
}?;
var_store.load(weights_path)?;
cast_var_store(&mut var_store, config.kind, device);
Ok(model)
}

View File

@ -138,6 +138,7 @@ impl Default for POSConfig {
strip_accents: Some(true),
add_prefix_space: None,
device: Device::cuda_if_available(),
kind: None,
label_aggregation_function: LabelAggregationOption::First,
batch_size: 64,
},

View File

@ -52,7 +52,7 @@ use crate::fnet::FNetForQuestionAnswering;
use crate::longformer::LongformerForQuestionAnswering;
use crate::mobilebert::MobileBertForQuestionAnswering;
use crate::pipelines::common::{
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::reformer::ReformerForQuestionAnswering;
use crate::resources::ResourceProvider;
@ -64,7 +64,6 @@ use std::cmp::min;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use tch::kind::Kind::Float;
use tch::nn::VarStore;
use tch::{no_grad, Device, Kind, Tensor};
@ -72,6 +71,7 @@ use crate::deberta_v2::DebertaV2ForQuestionAnswering;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
use crate::common::kind::get_min;
#[cfg(feature = "remote")]
use crate::{
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
@ -158,6 +158,8 @@ pub struct QuestionAnsweringConfig {
pub max_query_length: usize,
/// Maximum length for the answer
pub max_answer_length: usize,
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
pub kind: Option<Kind>,
}
impl QuestionAnsweringConfig {
@ -199,6 +201,7 @@ impl QuestionAnsweringConfig {
doc_stride: 128,
max_query_length: 64,
max_answer_length: 15,
kind: None,
}
}
@ -248,6 +251,7 @@ impl QuestionAnsweringConfig {
doc_stride: doc_stride.into().unwrap_or(128),
max_query_length: max_query_length.into().unwrap_or(64),
max_answer_length: max_answer_length.into().unwrap_or(15),
kind: None,
}
}
}
@ -267,6 +271,7 @@ impl Default for QuestionAnsweringConfig {
)),
merges_resource: None,
device: Device::cuda_if_available(),
kind: None,
model_type: ModelType::DistilBert,
lower_case: false,
add_prefix_space: None,
@ -474,6 +479,7 @@ impl QuestionAnsweringOption {
))),
}?;
var_store.load(weights_path)?;
cast_var_store(&mut var_store, config.kind, device);
Ok(model)
}
@ -830,11 +836,15 @@ impl QuestionAnsweringModel {
.to_device(start_logits.device())
.eq(0);
let start = start_logits.get(feature_idx).masked_fill(&p_mask, -10000);
let end = end_logits.get(feature_idx).masked_fill(&p_mask, -10000);
let start = start_logits
.get(feature_idx)
.masked_fill(&p_mask, get_min(start_logits.kind()).unwrap());
let end = end_logits
.get(feature_idx)
.masked_fill(&p_mask, get_min(start_logits.kind()).unwrap());
let start = start.exp() / start.exp().sum(Float);
let end = end.exp() / end.exp().sum(Float);
let start = start.softmax(0, start.kind());
let end = end.softmax(0, end.kind());
let (starts, ends, scores) = self.decode(&start, &end, top_k);
@ -861,9 +871,7 @@ impl QuestionAnsweringModel {
}
}
feature_id_start = max_feature_id;
let example_answers = example_top_k_answers_map
.entry(example_id)
.or_insert_with(Vec::new);
let example_answers = example_top_k_answers_map.entry(example_id).or_default();
example_answers.extend(answers);
}
});
@ -928,6 +936,20 @@ impl QuestionAnsweringModel {
masks: encoded_query.masks,
};
let sequence_added_tokens = self
.tokenizer
.build_input_with_special_tokens(
TokenIdsWithOffsets {
ids: vec![],
offsets: vec![],
reference_offsets: vec![],
masks: vec![],
},
None,
)
.token_ids
.len();
let sequence_pair_added_tokens = self
.tokenizer
.build_input_with_special_tokens(
@ -975,7 +997,10 @@ impl QuestionAnsweringModel {
let encoded_span = self
.tokenizer
.build_input_with_special_tokens(encoded_query.clone(), Some(sub_encoded_context));
let p_mask = self.get_mask(&encoded_span);
let p_mask = self.get_mask(
&encoded_span,
encoded_query.ids.len() + sequence_added_tokens,
);
let qa_feature = QaFeature {
input_ids: encoded_span.token_ids,
offsets: encoded_span.token_offsets,
@ -1038,7 +1063,7 @@ impl QuestionAnsweringModel {
(input_ids, attention_masks, token_type_ids)
}
fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> {
fn get_mask(&self, encoded_span: &TokenizedInput, question_length: usize) -> Vec<i8> {
let sep_indices: Vec<usize> = encoded_span
.token_ids
.iter()
@ -1047,12 +1072,9 @@ impl QuestionAnsweringModel {
.map(|(position, _)| position)
.collect();
let mut p_mask: Vec<i8> = encoded_span
.segment_ids
.iter()
.map(|v| min(v, &1i8))
.map(|&v| 1i8 - v)
.collect();
let mut p_mask: Vec<i8> = Vec::with_capacity(encoded_span.token_ids.len());
p_mask.extend(vec![1; question_length]);
p_mask.extend(vec![0; encoded_span.token_ids.len() - question_length]);
for sep_position in sep_indices {
p_mask[sep_position] = 1;
}

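The rewritten `get_mask` no longer derives `p_mask` from `segment_ids` (which broke for tokenizers that do not emit them); it marks the question span (including its special tokens) and every separator with 1, leaving only context tokens selectable as answers. A minimal restatement of that rule on plain vectors, with made-up lengths:

```rust
// Positions set to 1 can never be selected as answer tokens.
fn build_p_mask(total_len: usize, question_len: usize, sep_positions: &[usize]) -> Vec<i8> {
    let mut p_mask: Vec<i8> = Vec::with_capacity(total_len);
    p_mask.extend(vec![1; question_len]);
    p_mask.extend(vec![0; total_len - question_len]);
    for &sep_position in sep_positions {
        p_mask[sep_position] = 1;
    }
    p_mask
}

fn main() {
    // e.g. 4 question tokens (specials included), 6 context tokens, [SEP] at 3 and 9.
    assert_eq!(
        build_p_mask(10, 4, &[3, 9]),
        vec![1, 1, 1, 1, 0, 0, 0, 0, 0, 1]
    );
}
```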
View File

@ -1,7 +1,7 @@
use std::path::PathBuf;
use serde::Deserialize;
use tch::Device;
use tch::{Device, Kind};
use crate::pipelines::common::ModelType;
use crate::pipelines::sentence_embeddings::{
@ -21,6 +21,7 @@ use crate::{
/// (configuration and weights).
pub struct SentenceEmbeddingsBuilder<T> {
device: Device,
kind: Option<Kind>,
inner: T,
}
@ -29,6 +30,11 @@ impl<T> SentenceEmbeddingsBuilder<T> {
self.device = device;
self
}
pub fn with_kind(mut self, kind: Kind) -> Self {
self.kind = Some(kind);
self
}
}
pub struct Local {
@ -46,6 +52,7 @@ impl SentenceEmbeddingsBuilder<Local> {
pub fn local<P: Into<PathBuf>>(model_dir: P) -> Self {
Self {
device: Device::cuda_if_available(),
kind: None,
inner: Local {
model_dir: model_dir.into(),
},
@ -106,6 +113,7 @@ impl SentenceEmbeddingsBuilder<Local> {
tokenizer_vocab_resource: tokenizer_vocab.into(),
tokenizer_merges_resource: tokenizer_merges.map(|r| r.into()),
device: self.device,
kind: self.kind,
};
SentenceEmbeddingsModel::new(config)
@ -122,6 +130,7 @@ impl SentenceEmbeddingsBuilder<Remote> {
pub fn remote(model_type: SentenceEmbeddingsModelType) -> Self {
Self {
device: Device::cuda_if_available(),
kind: None,
inner: Remote {
config: SentenceEmbeddingsConfig::from(model_type),
},

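The builder's new `with_kind` slots in next to `with_device`. A sketch of requesting a half-precision sentence embeddings model, assuming the `remote` feature:

```rust
use rust_bert::pipelines::sentence_embeddings::{
    SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
};
use tch::{Device, Kind};

fn main() -> anyhow::Result<()> {
    let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL6V2)
        .with_device(Device::cuda_if_available())
        .with_kind(Kind::Half)
        .create_model()?;
    let embeddings = model.encode(&["The builder now also controls weight precision."])?;
    println!("embedding dimension: {}", embeddings[0].len());
    Ok(())
}
```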
View File

@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize};
use tch::Device;
use tch::{Device, Kind};
use crate::pipelines::common::ModelType;
use crate::resources::ResourceProvider;
@ -55,6 +55,8 @@ pub struct SentenceEmbeddingsConfig {
pub tokenizer_merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Device to place the transformer model on
pub device: Device,
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
pub kind: Option<Kind>,
}
#[cfg(feature = "remote")]
@ -92,6 +94,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
)),
tokenizer_merges_resource: None,
device: Device::cuda_if_available(),
kind: None,
},
SentenceEmbeddingsModelType::BertBaseNliMeanTokens => SentenceEmbeddingsConfig {
@ -121,6 +124,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
)),
tokenizer_merges_resource: None,
device: Device::cuda_if_available(),
kind: None,
},
SentenceEmbeddingsModelType::AllMiniLmL12V2 => SentenceEmbeddingsConfig {
@ -149,7 +153,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
BertVocabResources::ALL_MINI_LM_L12_V2,
)),
tokenizer_merges_resource: None,
device: Device::cuda_if_available(),
device: Device::cuda_if_available(),
kind: None,
},
SentenceEmbeddingsModelType::AllMiniLmL6V2 => SentenceEmbeddingsConfig {
@ -178,7 +182,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
BertVocabResources::ALL_MINI_LM_L6_V2,
)),
tokenizer_merges_resource: None,
device: Device::cuda_if_available(),
device: Device::cuda_if_available(),
kind: None,
},
SentenceEmbeddingsModelType::AllDistilrobertaV1 => SentenceEmbeddingsConfig {
@ -209,7 +213,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
tokenizer_merges_resource: Some(Box::new(RemoteResource::from_pretrained(
RobertaMergesResources::ALL_DISTILROBERTA_V1,
))),
device: Device::cuda_if_available(),
device: Device::cuda_if_available(),
kind: None,
},
SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2 => SentenceEmbeddingsConfig {
@ -238,7 +242,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
AlbertVocabResources::PARAPHRASE_ALBERT_SMALL_V2,
)),
tokenizer_merges_resource: None,
device: Device::cuda_if_available(),
device: Device::cuda_if_available(),
kind: None,
},
SentenceEmbeddingsModelType::SentenceT5Base => SentenceEmbeddingsConfig {
@ -271,7 +275,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
T5VocabResources::SENTENCE_T5_BASE,
)),
tokenizer_merges_resource: None,
device: Device::cuda_if_available(),
device: Device::cuda_if_available(),
kind: None,
},
}
}

View File

@ -236,6 +236,7 @@ impl SentenceEmbeddingsModel {
dense_config_resource,
dense_weights_resource,
device,
kind,
} = config;
let modules =
@ -254,7 +255,12 @@ impl SentenceEmbeddingsModel {
);
let transformer =
SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
crate::resources::load_weights(&transformer_weights_resource, &mut var_store)?;
crate::resources::load_weights(
&transformer_weights_resource,
&mut var_store,
kind,
device,
)?;
// Setup pooling layer
let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?);

View File

@ -68,7 +68,7 @@ use crate::fnet::FNetForSequenceClassification;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::reformer::ReformerForSequenceClassification;
use crate::resources::ResourceProvider;
@ -123,6 +123,8 @@ pub struct SequenceClassificationConfig {
pub add_prefix_space: Option<bool>,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
pub kind: Option<Kind>,
}
impl SequenceClassificationConfig {
@ -160,6 +162,7 @@ impl SequenceClassificationConfig {
strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(),
device: Device::cuda_if_available(),
kind: None,
}
}
}
@ -392,6 +395,7 @@ impl SequenceClassificationOption {
))),
}?;
var_store.load(weights_path)?;
cast_var_store(&mut var_store, config.kind, device);
Ok(model)
}

View File

@ -62,7 +62,7 @@
//! # ;
//! ```
use tch::Device;
use tch::{Device, Kind};
use crate::bart::BartGenerator;
use crate::common::error::RustBertError;
@ -126,6 +126,8 @@ pub struct SummarizationConfig {
pub diversity_penalty: Option<f64>,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
pub kind: Option<Kind>,
}
impl SummarizationConfig {
@ -170,6 +172,7 @@ impl SummarizationConfig {
num_beam_groups: None,
diversity_penalty: None,
device: Device::cuda_if_available(),
kind: None,
}
}
}
@ -214,6 +217,7 @@ impl From<SummarizationConfig> for GenerateConfig {
num_beam_groups: config.num_beam_groups,
diversity_penalty: config.diversity_penalty,
device: config.device,
kind: config.kind,
}
}
}

View File

@ -31,7 +31,7 @@
//!
//! Customized text generation models can be loaded by overwriting the resources in the configuration.
//! The dependencies will be downloaded to the user's home directory, e.g. under ~/.cache/.rustbert/gpt2
use tch::Device;
use tch::{Device, Kind};
use crate::common::error::RustBertError;
use crate::gpt2::GPT2Generator;
@ -97,6 +97,8 @@ pub struct TextGenerationConfig {
pub diversity_penalty: Option<f64>,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
pub kind: Option<Kind>,
}
impl TextGenerationConfig {
@ -141,6 +143,7 @@ impl TextGenerationConfig {
num_beam_groups: None,
diversity_penalty: None,
device: Device::cuda_if_available(),
kind: None,
}
}
}
@ -185,6 +188,7 @@ impl From<TextGenerationConfig> for GenerateConfig {
num_beam_groups: config.num_beam_groups,
diversity_penalty: config.diversity_penalty,
device: config.device,
kind: config.kind,
}
}
}

View File

@ -122,7 +122,7 @@ use crate::fnet::FNetForTokenClassification;
use crate::longformer::LongformerForTokenClassification;
use crate::mobilebert::MobileBertForTokenClassification;
use crate::pipelines::common::{
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForTokenClassification;
@ -242,6 +242,8 @@ pub struct TokenClassificationConfig {
pub add_prefix_space: Option<bool>,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
pub kind: Option<Kind>,
/// Sub-tokens aggregation method (default: `LabelAggregationOption::First`)
pub label_aggregation_function: LabelAggregationOption,
/// Batch size for predictions
@ -284,6 +286,7 @@ impl TokenClassificationConfig {
strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(),
device: Device::cuda_if_available(),
kind: None,
label_aggregation_function,
batch_size: 64,
}
@ -506,6 +509,7 @@ impl TokenClassificationOption {
))),
}?;
var_store.load(weights_path)?;
cast_var_store(&mut var_store, config.kind, device);
Ok(model)
}

View File

@ -336,12 +336,10 @@ impl TranslationModelBuilder {
) {
(Some(ModelType::M2M100), source_languages, target_languages) => {
match self.model_size {
Some(value) if value == ModelSize::XLarge => {
model_fetchers::get_m2m100_xlarge_resources(
source_languages.as_ref(),
target_languages.as_ref(),
)?
}
Some(ModelSize::XLarge) => model_fetchers::get_m2m100_xlarge_resources(
source_languages.as_ref(),
target_languages.as_ref(),
)?,
_ => model_fetchers::get_m2m100_large_resources(
source_languages.as_ref(),
target_languages.as_ref(),
@ -447,7 +445,7 @@ mod model_fetchers {
Ok(match get_marian_model(source_languages, target_languages) {
Ok(marian_resources) => marian_resources,
Err(_) => match model_size {
Some(value) if value == &ModelSize::XLarge => {
Some(ModelSize::XLarge) => {
get_m2m100_xlarge_resources(source_languages, target_languages)?
}
_ => get_m2m100_large_resources(source_languages, target_languages)?,

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::Device;
use tch::{Device, Kind};
use crate::common::error::RustBertError;
use crate::m2m_100::M2M100Generator;
@ -978,6 +978,8 @@ pub struct TranslationConfig {
pub num_beam_groups: Option<i64>,
/// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5)
pub diversity_penalty: Option<f64>,
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
pub kind: Option<Kind>,
}
impl TranslationConfig {
@ -1065,6 +1067,7 @@ impl TranslationConfig {
num_return_sequences: 1,
num_beam_groups: None,
diversity_penalty: None,
kind: None,
}
}
}
@ -1092,6 +1095,7 @@ impl From<TranslationConfig> for GenerateConfig {
num_beam_groups: config.num_beam_groups,
diversity_penalty: config.diversity_penalty,
device: config.device,
kind: config.kind,
}
}
}

View File

@ -102,10 +102,13 @@ use crate::albert::AlbertForSequenceClassification;
use crate::bart::BartForSequenceClassification;
use crate::bert::BertForSequenceClassification;
use crate::deberta::DebertaForSequenceClassification;
use crate::deberta_v2::DebertaV2ForSequenceClassification;
use crate::distilbert::DistilBertModelClassifier;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{ConfigOption, ModelResource, ModelType, TokenizerOption};
use crate::pipelines::common::{
cast_var_store, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::pipelines::sequence_classification::Label;
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForSequenceClassification;
@ -146,6 +149,8 @@ pub struct ZeroShotClassificationConfig {
pub add_prefix_space: Option<bool>,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
pub kind: Option<Kind>,
}
impl ZeroShotClassificationConfig {
@ -183,6 +188,7 @@ impl ZeroShotClassificationConfig {
strip_accents: strip_accents.into(),
add_prefix_space: add_prefix_space.into(),
device: Device::cuda_if_available(),
kind: None,
}
}
}
@ -209,6 +215,7 @@ impl Default for ZeroShotClassificationConfig {
strip_accents: None,
add_prefix_space: None,
device: Device::cuda_if_available(),
kind: None,
}
}
}
@ -217,11 +224,14 @@ impl Default for ZeroShotClassificationConfig {
/// The models are using a classification architecture that should be trained on Natural Language Inference.
/// The models should output a Tensor of size > 2 in the label dimension, with the first logit corresponding
/// to contradiction and the last logit corresponding to entailment.
#[allow(clippy::large_enum_variant)]
pub enum ZeroShotClassificationOption {
/// Bart for Sequence Classification
Bart(BartForSequenceClassification),
/// DeBERTa for Sequence Classification
Deberta(DebertaForSequenceClassification),
/// DeBERTaV2 for Sequence Classification
DebertaV2(DebertaV2ForSequenceClassification),
/// Bert for Sequence Classification
Bert(BertForSequenceClassification),
/// DistilBert for Sequence Classification
@ -288,6 +298,17 @@ impl ZeroShotClassificationOption {
))
}
}
ModelType::DebertaV2 => {
if let ConfigOption::DebertaV2(config) = model_config {
Ok(Self::DebertaV2(
DebertaV2ForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a DebertaConfig for DeBERTaV2!".to_string(),
))
}
}
ModelType::Bert => {
if let ConfigOption::Bert(config) = model_config {
Ok(Self::Bert(
@ -322,13 +343,13 @@ impl ZeroShotClassificationOption {
}
}
ModelType::Roberta => {
if let ConfigOption::Bert(config) = model_config {
if let ConfigOption::Roberta(config) = model_config {
Ok(Self::Roberta(
RobertaForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a BertConfig for Roberta!".to_string(),
"You can only supply a RobertaConfig for Roberta!".to_string(),
))
}
}
@ -385,6 +406,7 @@ impl ZeroShotClassificationOption {
))),
}?;
var_store.load(weights_path)?;
cast_var_store(&mut var_store, config.kind, device);
Ok(model)
}
@ -413,6 +435,7 @@ impl ZeroShotClassificationOption {
match *self {
Self::Bart(_) => ModelType::Bart,
Self::Deberta(_) => ModelType::Deberta,
Self::DebertaV2(_) => ModelType::DebertaV2,
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::XLMRoberta(_) => ModelType::Roberta,
@ -474,6 +497,19 @@ impl ZeroShotClassificationOption {
.expect("Error in DeBERTa forward_t")
.logits
}
Self::DebertaV2(ref model) => {
model
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.expect("Error in DeBERTaV2 forward_t")
.logits
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
@ -643,7 +679,7 @@ impl ZeroShotClassificationModel {
/// # Ok(())
/// # }
/// ```
fn new_with_tokenizer(
pub fn new_with_tokenizer(
config: ZeroShotClassificationConfig,
tokenizer: TokenizerOption,
) -> Result<ZeroShotClassificationModel, RustBertError> {

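`new_with_tokenizer` is now public, so a zero-shot model can be built around a caller-supplied `TokenizerOption`. A sketch with hypothetical local vocabulary/merges paths (the default config assumes the `remote` feature, and the `TokenizerOption::from_file` arguments shown here follow the pipeline doc examples):

```rust
use rust_bert::pipelines::common::{ModelType, TokenizerOption};
use rust_bert::pipelines::zero_shot_classification::{
    ZeroShotClassificationConfig, ZeroShotClassificationModel,
};

fn main() -> anyhow::Result<()> {
    let config = ZeroShotClassificationConfig::default();
    // Hypothetical paths: any tokenizer files matching the model's vocabulary work here.
    let tokenizer = TokenizerOption::from_file(
        ModelType::Bart,
        "path/to/vocab.json",
        Some("path/to/merges.txt"),
        false,
        None,
        None,
    )?;
    let _model = ZeroShotClassificationModel::new_with_tokenizer(config, tokenizer)?;
    Ok(())
}
```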
View File

@ -35,7 +35,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?;
let config = AlbertConfig::from_file(config_path);
let albert_model = AlbertForMaskedLM::new(vs.root(), &config);
load_weights(&weights_resource, &mut vs)?;
load_weights(&weights_resource, &mut vs, None, device)?;
// Define input
let input = [

View File

@ -2,7 +2,7 @@ use rust_bert::bart::{
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
use rust_bert::pipelines::common::ModelResource;
use rust_bert::pipelines::common::{cast_var_store, ModelResource};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::pipelines::zero_shot_classification::{
ZeroShotClassificationConfig, ZeroShotClassificationModel,
@ -44,6 +44,7 @@ fn bart_lm_model() -> anyhow::Result<()> {
let config = BartConfig::from_file(config_path);
let bart_model = BartModel::new(&vs.root() / "model", &config);
vs.load(weights_path)?;
cast_var_store(&mut vs, None, device);
// Define input
let input = ["One two three four"];

View File

@ -3,12 +3,12 @@ use rust_bert::gpt_j::{
GptJVocabResources,
};
use rust_bert::pipelines::generation_utils::Cache;
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::resources::{load_weights, RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer};
use rust_tokenizers::vocab::Vocab;
use std::convert::TryFrom;
use tch::{nn, Device, Tensor};
use tch::{nn, Device, Kind, Tensor};
/// Equivalent Python code:
///
@ -67,14 +67,15 @@ fn gpt_j_correctness() -> anyhow::Result<()> {
let mut vs = nn::VarStore::new(device);
let config_path = config_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let mut config = GptJConfig::from_file(config_path);
config.use_float16 = matches!(device, Device::Cuda(_));
let config = GptJConfig::from_file(config_path);
let model = GptJLMHeadModel::new(vs.root(), &config);
vs.load(weights_path)?;
let kind = match device {
Device::Cpu => None,
_ => Some(Kind::Half),
};
load_weights(&model_resource, &mut vs, kind, device)?;
// Tokenize prompts
let prompts = [
"It was a very nice and sunny",
"It was a gloom winter night, and",

View File

@ -234,15 +234,15 @@ mod tests {
ModelType::M2M100,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/encoder_model.onnx",
"https://huggingface.co/optimum/m2m100_418M/resolve/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/encoder_model.onnx",
"onnx-m2m100_418M",
))),
decoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_model.onnx",
"https://huggingface.co/optimum/m2m100_418M/resolve/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/decoder_model.onnx",
"onnx-m2m100_418M",
))),
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_with_past_model.onnx",
"https://huggingface.co/optimum/m2m100_418M/resolve/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/decoder_with_past_model.onnx",
"onnx-m2m100_418M",
))),
}),