mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-03 23:57:15 +03:00
Merge remote-tracking branch 'origin/main' into onnxruntime_update
# Conflicts: # CHANGELOG.md
This commit is contained in:
commit
f77b38072d
@ -5,11 +5,18 @@ All notable changes to this project will be documented in this file. The format
|
||||
## Added
|
||||
- Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines.
|
||||
- Support for [Tokenizers](https://github.com/huggingface/tokenizers) in pipelines, allowing loading `tokenizer.json` and `special_token_map.json` tokenizer files.
|
||||
- (BREAKING) Most model configuration can now take an optional `kind` parameter to specify the model weight precision. If not provided, will default to full precision on CPU, or the serialized weights precision otherwise.
|
||||
|
||||
## Fixed
|
||||
- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering m-grams spanning multiple sentences).
|
||||
- Improved MPS device compatibility setting the `sparse_grad` flag to false for `gather` operations
|
||||
- Updated ONNX runtime backend version to 1.15.x
|
||||
- Issue with incorrect results for QA models with a tokenizer not using segment ids
|
||||
- Issue with GPT-J that was incorrectly tracking the gradients for the attention bias
|
||||
|
||||
## Changed
|
||||
- (BREAKING) Upgraded to `torch` 2.1 (via `tch` 0.14.0).
|
||||
- Updated ONNX runtime backend version to 1.15.x
|
||||
|
||||
## [0.21.0] - 2023-06-03
|
||||
## Added
|
||||
|
@ -75,8 +75,8 @@ hf-tokenizers = ["tokenizers"]
|
||||
features = ["doc-only"]
|
||||
|
||||
[dependencies]
|
||||
rust_tokenizers = "8.1"
|
||||
tch = "0.13.0"
|
||||
rust_tokenizers = "8.1.1"
|
||||
tch = "0.14.0"
|
||||
serde_json = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
ordered-float = "3"
|
||||
@ -97,7 +97,7 @@ anyhow = "1"
|
||||
csv = "1"
|
||||
criterion = "0.4"
|
||||
tokio = { version = "1.24", features = ["sync", "rt-multi-thread", "macros"] }
|
||||
torch-sys = "0.13.0"
|
||||
torch-sys = "0.14.0"
|
||||
tempfile = "3"
|
||||
itertools = "0.10"
|
||||
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
|
||||
|
@ -80,8 +80,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
|
||||
|
||||
### Manual installation (recommended)
|
||||
|
||||
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.0.0`: if this version is no longer available on the "get started" page,
|
||||
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip` for a Linux version with CUDA11. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
|
||||
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.1`: if this version is no longer available on the "get started" page,
|
||||
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
|
||||
2. Extract the library to a location of your choice
|
||||
3. Set the following environment variables
|
||||
##### Linux:
|
||||
|
@ -37,6 +37,7 @@ fn create_text_generation_model() -> TextGenerationModel {
|
||||
diversity_penalty: None,
|
||||
num_return_sequences: 5,
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
};
|
||||
TextGenerationModel::new(config).unwrap()
|
||||
}
|
||||
|
@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
|
||||
)?;
|
||||
let config = DebertaConfig::from_file(config_path);
|
||||
let model = DebertaForSequenceClassification::new(vs.root(), &config)?;
|
||||
load_weights(&model_resource, &mut vs)?;
|
||||
load_weights(&model_resource, &mut vs, None, device)?;
|
||||
|
||||
// Define input
|
||||
let input = [("I love you.", "I like you.")];
|
||||
|
@ -30,6 +30,7 @@ use std::ops::DerefMut;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::RwLockWriteGuard;
|
||||
use tch::nn::VarStore;
|
||||
use tch::{Device, Kind};
|
||||
|
||||
pub enum Resource<'a> {
|
||||
PathBuf(PathBuf),
|
||||
@ -84,17 +85,19 @@ impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T> {
|
||||
pub fn load_weights(
|
||||
rp: &(impl ResourceProvider + ?Sized),
|
||||
vs: &mut VarStore,
|
||||
kind: Option<Kind>,
|
||||
device: Device,
|
||||
) -> Result<(), RustBertError> {
|
||||
match rp.get_resource()? {
|
||||
Resource::Buffer(mut data) => {
|
||||
vs.load_from_stream(std::io::Cursor::new(data.deref_mut()))?;
|
||||
Ok(())
|
||||
}
|
||||
Resource::PathBuf(path) => Ok(vs.load(path)?),
|
||||
}
|
||||
Resource::Buffer(mut data) => vs.load_from_stream(std::io::Cursor::new(data.deref_mut())),
|
||||
Resource::PathBuf(path) => vs.load(path),
|
||||
}?;
|
||||
cast_var_store(vs, kind, device);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
mod remote;
|
||||
use crate::pipelines::common::cast_var_store;
|
||||
#[cfg(feature = "remote")]
|
||||
pub use remote::RemoteResource;
|
||||
|
@ -90,8 +90,8 @@
|
||||
//!
|
||||
//! ### Manual installation (recommended)
|
||||
//!
|
||||
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.0`: if this version is no longer available on the "get started" page,
|
||||
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip` for a Linux version with CUDA11.
|
||||
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.1`: if this version is no longer available on the "get started" page,
|
||||
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11.
|
||||
//! 2. Extract the library to a location of your choice
|
||||
//! 3. Set the following environment variables
|
||||
//! ##### Linux:
|
||||
|
@ -1004,7 +1004,12 @@ impl BartGenerator {
|
||||
let mut var_store = nn::VarStore::new(device);
|
||||
let config = BartConfig::from_file(config_path);
|
||||
let model = BartForConditionalGeneration::new(var_store.root(), &config);
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
|
||||
let eos_token_ids = Some(match config.eos_token_id {
|
||||
|
@ -396,9 +396,11 @@ impl<T: BertEmbedding> BertModel<T> {
|
||||
train,
|
||||
)?;
|
||||
|
||||
let extended_attention_mask: Tensor =
|
||||
((extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0)
|
||||
.to_kind(embedding_output.kind());
|
||||
let extended_attention_mask: Tensor = ((extended_attention_mask
|
||||
.ones_like()
|
||||
.bitwise_xor_tensor(&extended_attention_mask))
|
||||
* -10000.0)
|
||||
.to_kind(embedding_output.kind());
|
||||
|
||||
let encoder_extended_attention_mask: Option<Tensor> =
|
||||
if self.is_decoder & encoder_hidden_states.is_some() {
|
||||
|
@ -235,7 +235,7 @@ impl DebertaV2Encoder {
|
||||
.unsqueeze(-1)
|
||||
.to_kind(Kind::Uint8)
|
||||
}
|
||||
value if value == 3 => attention_mask.unsqueeze(1),
|
||||
3 => attention_mask.unsqueeze(1),
|
||||
_ => attention_mask.shallow_clone(),
|
||||
}
|
||||
}
|
||||
|
@ -652,7 +652,12 @@ impl GPT2Generator {
|
||||
|
||||
let config = Gpt2Config::from_file(config_path);
|
||||
let model = GPT2LMHeadModel::new(var_store.root(), &config);
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = tokenizer.get_bos_id();
|
||||
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
|
||||
|
@ -68,11 +68,16 @@ impl GptJAttention {
|
||||
let p = p.borrow();
|
||||
|
||||
let max_positions = config.n_positions;
|
||||
let bias = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device()))
|
||||
let bias_value = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device()))
|
||||
.tril(0)
|
||||
.view([1, 1, max_positions, max_positions])
|
||||
.requires_grad_(false);
|
||||
let bias = p.var_copy("bias", &bias);
|
||||
let mut bias = p
|
||||
.f_ones_no_train("bias", &[1, 1, max_positions, max_positions])
|
||||
.unwrap()
|
||||
.to_kind(Kind::Uint8)
|
||||
.to_device(p.device());
|
||||
bias.copy_(&bias_value);
|
||||
|
||||
let attn_pdrop = config.attn_pdrop.unwrap_or(0.1);
|
||||
let resid_pdrop = config.resid_pdrop.unwrap_or(0.1);
|
||||
@ -95,21 +100,9 @@ impl GptJAttention {
|
||||
..Default::default()
|
||||
};
|
||||
let k_proj = nn::linear(p / "k_proj", config.n_embd, config.n_embd, linear_config);
|
||||
if config.use_float16 {
|
||||
(p / "k_proj").half();
|
||||
}
|
||||
let v_proj = nn::linear(p / "v_proj", config.n_embd, config.n_embd, linear_config);
|
||||
if config.use_float16 {
|
||||
(p / "v_proj").half();
|
||||
}
|
||||
let q_proj = nn::linear(p / "q_proj", config.n_embd, config.n_embd, linear_config);
|
||||
if config.use_float16 {
|
||||
(p / "q_proj").half();
|
||||
}
|
||||
let out_proj = nn::linear(p / "out_proj", config.n_embd, config.n_embd, linear_config);
|
||||
if config.use_float16 {
|
||||
(p / "out_proj").half();
|
||||
}
|
||||
|
||||
GptJAttention {
|
||||
bias,
|
||||
|
@ -131,8 +131,6 @@ pub struct GptJConfig {
|
||||
pub rotary_dim: Option<i64>,
|
||||
pub vocab_size: i64,
|
||||
pub scale_attn_weights: Option<bool>,
|
||||
#[serde(default = "default_use_float16")]
|
||||
pub use_float16: bool,
|
||||
#[serde(default = "default_preload_on_cpu")]
|
||||
pub preload_on_cpu: bool,
|
||||
pub decoder_start_token_id: Option<i64>,
|
||||
@ -164,7 +162,6 @@ impl Default for GptJConfig {
|
||||
rotary_dim: Some(64),
|
||||
vocab_size: 50400,
|
||||
scale_attn_weights: Some(true),
|
||||
use_float16: default_use_float16(),
|
||||
preload_on_cpu: default_preload_on_cpu(),
|
||||
decoder_start_token_id: None,
|
||||
forced_bos_token_id: None,
|
||||
@ -173,10 +170,6 @@ impl Default for GptJConfig {
|
||||
}
|
||||
}
|
||||
|
||||
fn default_use_float16() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_preload_on_cpu() -> bool {
|
||||
true
|
||||
}
|
||||
@ -233,9 +226,6 @@ impl GptJModel {
|
||||
config.n_embd,
|
||||
Default::default(),
|
||||
);
|
||||
if config.use_float16 {
|
||||
(&(&p / "wte") / "weight").half()
|
||||
};
|
||||
|
||||
let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
|
||||
let drop = Dropout::new(embd_pdrop);
|
||||
@ -245,9 +235,6 @@ impl GptJModel {
|
||||
..Default::default()
|
||||
};
|
||||
let ln_f = nn::layer_norm(&p / "ln_f", vec![config.n_embd], layer_norm_config);
|
||||
if config.use_float16 {
|
||||
(&p / "ln_f").half()
|
||||
};
|
||||
|
||||
let mut h: Vec<GptJBlock> = vec![];
|
||||
let h_path = &p / "h";
|
||||
@ -475,9 +462,6 @@ impl GptJLMHeadModel {
|
||||
config.vocab_size,
|
||||
Default::default(),
|
||||
);
|
||||
if config.use_float16 {
|
||||
(p / "lm_head").half();
|
||||
}
|
||||
|
||||
GptJLMHeadModel {
|
||||
transformer,
|
||||
@ -625,7 +609,12 @@ impl GptJGenerator {
|
||||
if config.preload_on_cpu && device != Device::Cpu {
|
||||
var_store.set_device(Device::Cpu);
|
||||
}
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
if device != Device::Cpu {
|
||||
var_store.set_device(device);
|
||||
}
|
||||
|
@ -43,18 +43,12 @@ impl GptJMLP {
|
||||
intermediate_size,
|
||||
Default::default(),
|
||||
);
|
||||
if config.use_float16 {
|
||||
(p / "fc_in").half()
|
||||
};
|
||||
let fc_out = nn::linear(
|
||||
p / "fc_out",
|
||||
intermediate_size,
|
||||
config.n_embd,
|
||||
Default::default(),
|
||||
);
|
||||
if config.use_float16 {
|
||||
(p / "fc_out").half()
|
||||
};
|
||||
|
||||
let activation = match &config.afn {
|
||||
Some(activation_enum) => match activation_enum {
|
||||
@ -100,9 +94,6 @@ impl GptJBlock {
|
||||
..Default::default()
|
||||
};
|
||||
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
|
||||
if config.use_float16 {
|
||||
(p / "ln_1").half()
|
||||
};
|
||||
let attn = GptJAttention::new(p / "attn", config);
|
||||
let mlp = GptJMLP::new(p / "mlp", config);
|
||||
|
||||
|
@ -672,7 +672,12 @@ impl GptNeoGenerator {
|
||||
let mut var_store = nn::VarStore::new(device);
|
||||
let config = GptNeoConfig::from_file(config_path);
|
||||
let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = tokenizer.get_bos_id();
|
||||
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
|
||||
|
@ -288,8 +288,8 @@ impl LongT5Stack {
|
||||
|
||||
let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
|
||||
|
||||
let mask_seq_length = if old_layer_states.is_some() {
|
||||
if old_layer_states.as_ref().unwrap()[0].0.is_some() {
|
||||
let mask_seq_length = if let Some(old_layer_states_value) = &old_layer_states {
|
||||
if old_layer_states_value[0].0.is_some() {
|
||||
old_layer_states.as_ref().unwrap()[0]
|
||||
.0
|
||||
.as_ref()
|
||||
|
@ -595,7 +595,12 @@ impl LongT5Generator {
|
||||
|
||||
let config = LongT5Config::from_file(config_path);
|
||||
let model = LongT5ForConditionalGeneration::new(var_store.root(), &config);
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = config.bos_token_id;
|
||||
let eos_token_ids = Some(match config.eos_token_id {
|
||||
|
@ -544,7 +544,12 @@ impl M2M100Generator {
|
||||
|
||||
let config = M2M100Config::from_file(config_path);
|
||||
let model = M2M100ForConditionalGeneration::new(var_store.root(), &config);
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
|
||||
let eos_token_ids = Some(match config.eos_token_id {
|
||||
|
@ -761,7 +761,12 @@ impl MarianGenerator {
|
||||
|
||||
let config = BartConfig::from_file(config_path);
|
||||
let model = MarianForConditionalGeneration::new(var_store.root(), &config);
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
|
||||
let eos_token_ids = Some(match config.eos_token_id {
|
||||
|
@ -450,7 +450,7 @@ impl MBartForConditionalGeneration {
|
||||
{
|
||||
let p = p.borrow();
|
||||
|
||||
let base_model = MBartModel::new(p.borrow() / "model", config);
|
||||
let base_model = MBartModel::new(p / "model", config);
|
||||
let final_logits_bias = p.var(
|
||||
"final_logits_bias",
|
||||
&[1, config.vocab_size],
|
||||
@ -650,7 +650,7 @@ impl MBartForSequenceClassification {
|
||||
/// # let device = Device::Cpu;
|
||||
/// # let vs = nn::VarStore::new(device);
|
||||
/// # let config = MBartConfig::from_file(config_path);
|
||||
/// # let mbart_model: MBartForSequenceClassification = MBartForSequenceClassification::new(&vs.root(), &config).unwrap();;
|
||||
/// # let mbart_model: MBartForSequenceClassification = MBartForSequenceClassification::new(&vs.root(), &config).unwrap();
|
||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||
@ -800,7 +800,12 @@ impl MBartGenerator {
|
||||
|
||||
let config = MBartConfig::from_file(config_path);
|
||||
let model = MBartForConditionalGeneration::new(var_store.root(), &config);
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
|
||||
let eos_token_ids = Some(match config.eos_token_id {
|
||||
|
@ -498,7 +498,12 @@ impl OpenAIGenerator {
|
||||
let mut var_store = nn::VarStore::new(device);
|
||||
let config = Gpt2Config::from_file(config_path);
|
||||
let model = OpenAIGPTLMHeadModel::new(var_store.root(), &config);
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = tokenizer.get_bos_id();
|
||||
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
|
||||
|
@ -505,7 +505,12 @@ impl PegasusConditionalGenerator {
|
||||
let mut var_store = nn::VarStore::new(device);
|
||||
let config = PegasusConfig::from_file(config_path);
|
||||
let model = PegasusForConditionalGeneration::new(var_store.root(), &config);
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
|
||||
let eos_token_ids = config
|
||||
|
@ -919,7 +919,12 @@ impl ProphetNetConditionalGenerator {
|
||||
let mut var_store = nn::VarStore::new(device);
|
||||
let config = ProphetNetConfig::from_file(config_path);
|
||||
let model = ProphetNetForConditionalGeneration::new(var_store.root(), &config)?;
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = Some(config.bos_token_id);
|
||||
let eos_token_ids = Some(vec![config.eos_token_id]);
|
||||
|
@ -1368,8 +1368,8 @@ impl ReformerAttention {
|
||||
let new_layer_state = if self.use_past {
|
||||
let prev_buckets = if let Some(buckets_value) = &buckets {
|
||||
if layer_state.is_none() | {
|
||||
if layer_state.is_some() {
|
||||
layer_state.as_ref().unwrap().prev_buckets.is_none()
|
||||
if let Some(layer_state_value) = &layer_state {
|
||||
layer_state_value.prev_buckets.is_none()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
|
@ -1056,7 +1056,12 @@ impl ReformerGenerator {
|
||||
let mut var_store = nn::VarStore::new(device);
|
||||
let config = ReformerConfig::from_file(config_path);
|
||||
let model = ReformerModelWithLMHead::new(var_store.root(), &config)?;
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = tokenizer.get_bos_id();
|
||||
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
|
||||
|
@ -191,15 +191,15 @@ impl T5Attention {
|
||||
|
||||
let q: Tensor = self.shape(hidden_states.as_ref().apply(&self.query), bs);
|
||||
|
||||
let (mut k, mut v) = if key_value_states.is_none() {
|
||||
let (mut k, mut v) = if let Some(key_value_states_value) = key_value_states {
|
||||
(
|
||||
self.shape(hidden_states.apply(&self.key), bs),
|
||||
self.shape(hidden_states.apply(&self.value), bs),
|
||||
self.shape(key_value_states_value.apply(&self.key), bs),
|
||||
self.shape(key_value_states_value.apply(&self.value), bs),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
self.shape(key_value_states.as_ref().unwrap().apply(&self.key), bs),
|
||||
self.shape(key_value_states.as_ref().unwrap().apply(&self.value), bs),
|
||||
self.shape(hidden_states.apply(&self.key), bs),
|
||||
self.shape(hidden_states.apply(&self.value), bs),
|
||||
)
|
||||
};
|
||||
|
||||
|
@ -383,9 +383,9 @@ impl T5Stack {
|
||||
|
||||
let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
|
||||
|
||||
let mask_seq_length = if old_layer_states.is_some() {
|
||||
if old_layer_states.as_ref().unwrap()[0].0.is_some() {
|
||||
old_layer_states.as_ref().unwrap()[0]
|
||||
let mask_seq_length = if let Some(old_layer_states_value) = &old_layer_states {
|
||||
if old_layer_states_value[0].0.is_some() {
|
||||
old_layer_states_value[0]
|
||||
.0
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
|
@ -763,7 +763,12 @@ impl T5Generator {
|
||||
|
||||
let config = T5Config::from_file(config_path);
|
||||
let model = T5ForConditionalGeneration::new(var_store.root(), &config);
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = Some(config.bos_token_id.unwrap_or(-1));
|
||||
let eos_token_ids = Some(match config.eos_token_id {
|
||||
|
@ -1560,7 +1560,12 @@ impl XLNetGenerator {
|
||||
|
||||
let config = XLNetConfig::from_file(config_path);
|
||||
let model = XLNetLMHeadModel::new(var_store.root(), &config);
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&generate_config.model_resource,
|
||||
&mut var_store,
|
||||
generate_config.kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
let bos_token_id = Some(config.bos_token_id);
|
||||
let eos_token_ids = Some(vec![config.eos_token_id]);
|
||||
|
@ -60,6 +60,7 @@ use std::convert::TryFrom;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use tch::nn::VarStore;
|
||||
use tch::{Device, Kind, Tensor};
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
@ -2348,3 +2349,11 @@ impl TokenizerOption {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cast_var_store(varstore: &mut VarStore, kind: Option<Kind>, device: Device) {
|
||||
match (kind, device) {
|
||||
(Some(kind), _) => varstore.set_kind(kind),
|
||||
(None, Device::Cpu) => varstore.set_kind(Kind::Float),
|
||||
(None, _) => {}
|
||||
}
|
||||
}
|
||||
|
@ -115,6 +115,8 @@ pub struct ConversationConfig {
|
||||
pub diversity_penalty: Option<f64>,
|
||||
/// Device to place the model on (default: CUDA/GPU when available)
|
||||
pub device: Device,
|
||||
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
|
||||
pub kind: Option<Kind>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
@ -150,6 +152,7 @@ impl Default for ConversationConfig {
|
||||
num_beam_groups: None,
|
||||
diversity_penalty: None,
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -177,6 +180,7 @@ impl From<ConversationConfig> for GenerateConfig {
|
||||
num_beam_groups: config.num_beam_groups,
|
||||
diversity_penalty: config.diversity_penalty,
|
||||
device: config.device,
|
||||
kind: config.kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -67,7 +67,7 @@
|
||||
//! ```
|
||||
|
||||
use tch::kind::Kind::Int64;
|
||||
use tch::{no_grad, Device, Tensor};
|
||||
use tch::{no_grad, Device, Kind, Tensor};
|
||||
|
||||
use crate::bart::LayerState as BartLayerState;
|
||||
use crate::common::resources::ResourceProvider;
|
||||
@ -136,6 +136,8 @@ pub struct GenerateConfig {
|
||||
pub diversity_penalty: Option<f64>,
|
||||
/// Device to place the model on (default: CUDA/GPU when available)
|
||||
pub device: Device,
|
||||
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
|
||||
pub kind: Option<Kind>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
@ -166,6 +168,7 @@ impl Default for GenerateConfig {
|
||||
num_beam_groups: None,
|
||||
diversity_penalty: None,
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ use crate::deberta::DebertaForMaskedLM;
|
||||
use crate::deberta_v2::DebertaV2ForMaskedLM;
|
||||
use crate::fnet::FNetForMaskedLM;
|
||||
use crate::pipelines::common::{
|
||||
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
};
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::roberta::RobertaForMaskedLM;
|
||||
@ -67,7 +67,7 @@ use crate::{
|
||||
resources::RemoteResource,
|
||||
};
|
||||
use tch::nn::VarStore;
|
||||
use tch::{no_grad, Device, Tensor};
|
||||
use tch::{no_grad, Device, Kind, Tensor};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Output container for masked language model pipeline.
|
||||
@ -103,6 +103,8 @@ pub struct MaskedLanguageConfig {
|
||||
pub mask_token: Option<String>,
|
||||
/// Device to place the model on (default: CUDA/GPU when available)
|
||||
pub device: Device,
|
||||
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
|
||||
pub kind: Option<Kind>,
|
||||
}
|
||||
|
||||
impl MaskedLanguageConfig {
|
||||
@ -143,6 +145,7 @@ impl MaskedLanguageConfig {
|
||||
add_prefix_space: add_prefix_space.into(),
|
||||
mask_token: mask_token.into(),
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -285,6 +288,7 @@ impl MaskedLanguageOption {
|
||||
))),
|
||||
}?;
|
||||
var_store.load(weights_path)?;
|
||||
cast_var_store(&mut var_store, config.kind, device);
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
|
@ -138,6 +138,7 @@ impl Default for POSConfig {
|
||||
strip_accents: Some(true),
|
||||
add_prefix_space: None,
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
label_aggregation_function: LabelAggregationOption::First,
|
||||
batch_size: 64,
|
||||
},
|
||||
|
@ -52,7 +52,7 @@ use crate::fnet::FNetForQuestionAnswering;
|
||||
use crate::longformer::LongformerForQuestionAnswering;
|
||||
use crate::mobilebert::MobileBertForQuestionAnswering;
|
||||
use crate::pipelines::common::{
|
||||
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
};
|
||||
use crate::reformer::ReformerForQuestionAnswering;
|
||||
use crate::resources::ResourceProvider;
|
||||
@ -64,7 +64,6 @@ use std::cmp::min;
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tch::kind::Kind::Float;
|
||||
use tch::nn::VarStore;
|
||||
use tch::{no_grad, Device, Kind, Tensor};
|
||||
|
||||
@ -72,6 +71,7 @@ use crate::deberta_v2::DebertaV2ForQuestionAnswering;
|
||||
#[cfg(feature = "onnx")]
|
||||
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
|
||||
|
||||
use crate::common::kind::get_min;
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
|
||||
@ -158,6 +158,8 @@ pub struct QuestionAnsweringConfig {
|
||||
pub max_query_length: usize,
|
||||
/// Maximum length for the answer
|
||||
pub max_answer_length: usize,
|
||||
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
|
||||
pub kind: Option<Kind>,
|
||||
}
|
||||
|
||||
impl QuestionAnsweringConfig {
|
||||
@ -199,6 +201,7 @@ impl QuestionAnsweringConfig {
|
||||
doc_stride: 128,
|
||||
max_query_length: 64,
|
||||
max_answer_length: 15,
|
||||
kind: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -248,6 +251,7 @@ impl QuestionAnsweringConfig {
|
||||
doc_stride: doc_stride.into().unwrap_or(128),
|
||||
max_query_length: max_query_length.into().unwrap_or(64),
|
||||
max_answer_length: max_answer_length.into().unwrap_or(15),
|
||||
kind: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -267,6 +271,7 @@ impl Default for QuestionAnsweringConfig {
|
||||
)),
|
||||
merges_resource: None,
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
model_type: ModelType::DistilBert,
|
||||
lower_case: false,
|
||||
add_prefix_space: None,
|
||||
@ -474,6 +479,7 @@ impl QuestionAnsweringOption {
|
||||
))),
|
||||
}?;
|
||||
var_store.load(weights_path)?;
|
||||
cast_var_store(&mut var_store, config.kind, device);
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
@ -830,11 +836,15 @@ impl QuestionAnsweringModel {
|
||||
.to_device(start_logits.device())
|
||||
.eq(0);
|
||||
|
||||
let start = start_logits.get(feature_idx).masked_fill(&p_mask, -10000);
|
||||
let end = end_logits.get(feature_idx).masked_fill(&p_mask, -10000);
|
||||
let start = start_logits
|
||||
.get(feature_idx)
|
||||
.masked_fill(&p_mask, get_min(start_logits.kind()).unwrap());
|
||||
let end = end_logits
|
||||
.get(feature_idx)
|
||||
.masked_fill(&p_mask, get_min(start_logits.kind()).unwrap());
|
||||
|
||||
let start = start.exp() / start.exp().sum(Float);
|
||||
let end = end.exp() / end.exp().sum(Float);
|
||||
let start = start.softmax(0, start.kind());
|
||||
let end = end.softmax(0, end.kind());
|
||||
|
||||
let (starts, ends, scores) = self.decode(&start, &end, top_k);
|
||||
|
||||
@ -861,9 +871,7 @@ impl QuestionAnsweringModel {
|
||||
}
|
||||
}
|
||||
feature_id_start = max_feature_id;
|
||||
let example_answers = example_top_k_answers_map
|
||||
.entry(example_id)
|
||||
.or_insert_with(Vec::new);
|
||||
let example_answers = example_top_k_answers_map.entry(example_id).or_default();
|
||||
example_answers.extend(answers);
|
||||
}
|
||||
});
|
||||
@ -928,6 +936,20 @@ impl QuestionAnsweringModel {
|
||||
masks: encoded_query.masks,
|
||||
};
|
||||
|
||||
let sequence_added_tokens = self
|
||||
.tokenizer
|
||||
.build_input_with_special_tokens(
|
||||
TokenIdsWithOffsets {
|
||||
ids: vec![],
|
||||
offsets: vec![],
|
||||
reference_offsets: vec![],
|
||||
masks: vec![],
|
||||
},
|
||||
None,
|
||||
)
|
||||
.token_ids
|
||||
.len();
|
||||
|
||||
let sequence_pair_added_tokens = self
|
||||
.tokenizer
|
||||
.build_input_with_special_tokens(
|
||||
@ -975,7 +997,10 @@ impl QuestionAnsweringModel {
|
||||
let encoded_span = self
|
||||
.tokenizer
|
||||
.build_input_with_special_tokens(encoded_query.clone(), Some(sub_encoded_context));
|
||||
let p_mask = self.get_mask(&encoded_span);
|
||||
let p_mask = self.get_mask(
|
||||
&encoded_span,
|
||||
encoded_query.ids.len() + sequence_added_tokens,
|
||||
);
|
||||
let qa_feature = QaFeature {
|
||||
input_ids: encoded_span.token_ids,
|
||||
offsets: encoded_span.token_offsets,
|
||||
@ -1038,7 +1063,7 @@ impl QuestionAnsweringModel {
|
||||
(input_ids, attention_masks, token_type_ids)
|
||||
}
|
||||
|
||||
fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> {
|
||||
fn get_mask(&self, encoded_span: &TokenizedInput, question_length: usize) -> Vec<i8> {
|
||||
let sep_indices: Vec<usize> = encoded_span
|
||||
.token_ids
|
||||
.iter()
|
||||
@ -1047,12 +1072,9 @@ impl QuestionAnsweringModel {
|
||||
.map(|(position, _)| position)
|
||||
.collect();
|
||||
|
||||
let mut p_mask: Vec<i8> = encoded_span
|
||||
.segment_ids
|
||||
.iter()
|
||||
.map(|v| min(v, &1i8))
|
||||
.map(|&v| 1i8 - v)
|
||||
.collect();
|
||||
let mut p_mask: Vec<i8> = Vec::with_capacity(encoded_span.token_ids.len());
|
||||
p_mask.extend(vec![1; question_length]);
|
||||
p_mask.extend(vec![0; encoded_span.token_ids.len() - question_length]);
|
||||
for sep_position in sep_indices {
|
||||
p_mask[sep_position] = 1;
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use serde::Deserialize;
|
||||
use tch::Device;
|
||||
use tch::{Device, Kind};
|
||||
|
||||
use crate::pipelines::common::ModelType;
|
||||
use crate::pipelines::sentence_embeddings::{
|
||||
@ -21,6 +21,7 @@ use crate::{
|
||||
/// (configuration and weights).
|
||||
pub struct SentenceEmbeddingsBuilder<T> {
|
||||
device: Device,
|
||||
kind: Option<Kind>,
|
||||
inner: T,
|
||||
}
|
||||
|
||||
@ -29,6 +30,11 @@ impl<T> SentenceEmbeddingsBuilder<T> {
|
||||
self.device = device;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_kind(mut self, kind: Kind) -> Self {
|
||||
self.kind = Some(kind);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Local {
|
||||
@ -46,6 +52,7 @@ impl SentenceEmbeddingsBuilder<Local> {
|
||||
pub fn local<P: Into<PathBuf>>(model_dir: P) -> Self {
|
||||
Self {
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
inner: Local {
|
||||
model_dir: model_dir.into(),
|
||||
},
|
||||
@ -106,6 +113,7 @@ impl SentenceEmbeddingsBuilder<Local> {
|
||||
tokenizer_vocab_resource: tokenizer_vocab.into(),
|
||||
tokenizer_merges_resource: tokenizer_merges.map(|r| r.into()),
|
||||
device: self.device,
|
||||
kind: self.kind,
|
||||
};
|
||||
|
||||
SentenceEmbeddingsModel::new(config)
|
||||
@ -122,6 +130,7 @@ impl SentenceEmbeddingsBuilder<Remote> {
|
||||
pub fn remote(model_type: SentenceEmbeddingsModelType) -> Self {
|
||||
Self {
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
inner: Remote {
|
||||
config: SentenceEmbeddingsConfig::from(model_type),
|
||||
},
|
||||
|
@ -1,5 +1,5 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tch::Device;
|
||||
use tch::{Device, Kind};
|
||||
|
||||
use crate::pipelines::common::ModelType;
|
||||
use crate::resources::ResourceProvider;
|
||||
@ -55,6 +55,8 @@ pub struct SentenceEmbeddingsConfig {
|
||||
pub tokenizer_merges_resource: Option<Box<dyn ResourceProvider + Send>>,
|
||||
/// Device to place the transformer model on
|
||||
pub device: Device,
|
||||
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
|
||||
pub kind: Option<Kind>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
@ -92,6 +94,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
|
||||
)),
|
||||
tokenizer_merges_resource: None,
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
},
|
||||
|
||||
SentenceEmbeddingsModelType::BertBaseNliMeanTokens => SentenceEmbeddingsConfig {
|
||||
@ -121,6 +124,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
|
||||
)),
|
||||
tokenizer_merges_resource: None,
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
},
|
||||
|
||||
SentenceEmbeddingsModelType::AllMiniLmL12V2 => SentenceEmbeddingsConfig {
|
||||
@ -149,7 +153,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
|
||||
BertVocabResources::ALL_MINI_LM_L12_V2,
|
||||
)),
|
||||
tokenizer_merges_resource: None,
|
||||
device: Device::cuda_if_available(),
|
||||
device: Device::cuda_if_available(), kind: None,
|
||||
},
|
||||
|
||||
SentenceEmbeddingsModelType::AllMiniLmL6V2 => SentenceEmbeddingsConfig {
|
||||
@ -178,7 +182,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
|
||||
BertVocabResources::ALL_MINI_LM_L6_V2,
|
||||
)),
|
||||
tokenizer_merges_resource: None,
|
||||
device: Device::cuda_if_available(),
|
||||
device: Device::cuda_if_available(), kind: None,
|
||||
},
|
||||
|
||||
SentenceEmbeddingsModelType::AllDistilrobertaV1 => SentenceEmbeddingsConfig {
|
||||
@ -209,7 +213,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
|
||||
tokenizer_merges_resource: Some(Box::new(RemoteResource::from_pretrained(
|
||||
RobertaMergesResources::ALL_DISTILROBERTA_V1,
|
||||
))),
|
||||
device: Device::cuda_if_available(),
|
||||
device: Device::cuda_if_available(), kind: None,
|
||||
},
|
||||
|
||||
SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2 => SentenceEmbeddingsConfig {
|
||||
@ -238,7 +242,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
|
||||
AlbertVocabResources::PARAPHRASE_ALBERT_SMALL_V2,
|
||||
)),
|
||||
tokenizer_merges_resource: None,
|
||||
device: Device::cuda_if_available(),
|
||||
device: Device::cuda_if_available(), kind: None,
|
||||
},
|
||||
|
||||
SentenceEmbeddingsModelType::SentenceT5Base => SentenceEmbeddingsConfig {
|
||||
@ -271,7 +275,7 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
|
||||
T5VocabResources::SENTENCE_T5_BASE,
|
||||
)),
|
||||
tokenizer_merges_resource: None,
|
||||
device: Device::cuda_if_available(),
|
||||
device: Device::cuda_if_available(), kind: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -236,6 +236,7 @@ impl SentenceEmbeddingsModel {
|
||||
dense_config_resource,
|
||||
dense_weights_resource,
|
||||
device,
|
||||
kind,
|
||||
} = config;
|
||||
|
||||
let modules =
|
||||
@ -254,7 +255,12 @@ impl SentenceEmbeddingsModel {
|
||||
);
|
||||
let transformer =
|
||||
SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
|
||||
crate::resources::load_weights(&transformer_weights_resource, &mut var_store)?;
|
||||
crate::resources::load_weights(
|
||||
&transformer_weights_resource,
|
||||
&mut var_store,
|
||||
kind,
|
||||
device,
|
||||
)?;
|
||||
|
||||
// Setup pooling layer
|
||||
let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?);
|
||||
|
@ -68,7 +68,7 @@ use crate::fnet::FNetForSequenceClassification;
|
||||
use crate::longformer::LongformerForSequenceClassification;
|
||||
use crate::mobilebert::MobileBertForSequenceClassification;
|
||||
use crate::pipelines::common::{
|
||||
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
};
|
||||
use crate::reformer::ReformerForSequenceClassification;
|
||||
use crate::resources::ResourceProvider;
|
||||
@ -123,6 +123,8 @@ pub struct SequenceClassificationConfig {
|
||||
pub add_prefix_space: Option<bool>,
|
||||
/// Device to place the model on (default: CUDA/GPU when available)
|
||||
pub device: Device,
|
||||
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
|
||||
pub kind: Option<Kind>,
|
||||
}
|
||||
|
||||
impl SequenceClassificationConfig {
|
||||
@ -160,6 +162,7 @@ impl SequenceClassificationConfig {
|
||||
strip_accents: strip_accents.into(),
|
||||
add_prefix_space: add_prefix_space.into(),
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -392,6 +395,7 @@ impl SequenceClassificationOption {
|
||||
))),
|
||||
}?;
|
||||
var_store.load(weights_path)?;
|
||||
cast_var_store(&mut var_store, config.kind, device);
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
|
@ -62,7 +62,7 @@
|
||||
//! # ;
|
||||
//! ```
|
||||
|
||||
use tch::Device;
|
||||
use tch::{Device, Kind};
|
||||
|
||||
use crate::bart::BartGenerator;
|
||||
use crate::common::error::RustBertError;
|
||||
@ -126,6 +126,8 @@ pub struct SummarizationConfig {
|
||||
pub diversity_penalty: Option<f64>,
|
||||
/// Device to place the model on (default: CUDA/GPU when available)
|
||||
pub device: Device,
|
||||
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
|
||||
pub kind: Option<Kind>,
|
||||
}
|
||||
|
||||
impl SummarizationConfig {
|
||||
@ -170,6 +172,7 @@ impl SummarizationConfig {
|
||||
num_beam_groups: None,
|
||||
diversity_penalty: None,
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -214,6 +217,7 @@ impl From<SummarizationConfig> for GenerateConfig {
|
||||
num_beam_groups: config.num_beam_groups,
|
||||
diversity_penalty: config.diversity_penalty,
|
||||
device: config.device,
|
||||
kind: config.kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -31,7 +31,7 @@
|
||||
//!
|
||||
//! Customized text generation models models can be loaded by overwriting the resources in the configuration.
|
||||
//! The dependencies will be downloaded to the user's home directory, e.g. under ~/.cache/.rustbert/gpt2
|
||||
use tch::Device;
|
||||
use tch::{Device, Kind};
|
||||
|
||||
use crate::common::error::RustBertError;
|
||||
use crate::gpt2::GPT2Generator;
|
||||
@ -97,6 +97,8 @@ pub struct TextGenerationConfig {
|
||||
pub diversity_penalty: Option<f64>,
|
||||
/// Device to place the model on (default: CUDA/GPU when available)
|
||||
pub device: Device,
|
||||
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
|
||||
pub kind: Option<Kind>,
|
||||
}
|
||||
|
||||
impl TextGenerationConfig {
|
||||
@ -141,6 +143,7 @@ impl TextGenerationConfig {
|
||||
num_beam_groups: None,
|
||||
diversity_penalty: None,
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -185,6 +188,7 @@ impl From<TextGenerationConfig> for GenerateConfig {
|
||||
num_beam_groups: config.num_beam_groups,
|
||||
diversity_penalty: config.diversity_penalty,
|
||||
device: config.device,
|
||||
kind: config.kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -122,7 +122,7 @@ use crate::fnet::FNetForTokenClassification;
|
||||
use crate::longformer::LongformerForTokenClassification;
|
||||
use crate::mobilebert::MobileBertForTokenClassification;
|
||||
use crate::pipelines::common::{
|
||||
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
cast_var_store, get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
};
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::roberta::RobertaForTokenClassification;
|
||||
@ -242,6 +242,8 @@ pub struct TokenClassificationConfig {
|
||||
pub add_prefix_space: Option<bool>,
|
||||
/// Device to place the model on (default: CUDA/GPU when available)
|
||||
pub device: Device,
|
||||
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
|
||||
pub kind: Option<Kind>,
|
||||
/// Sub-tokens aggregation method (default: `LabelAggregationOption::First`)
|
||||
pub label_aggregation_function: LabelAggregationOption,
|
||||
/// Batch size for predictions
|
||||
@ -284,6 +286,7 @@ impl TokenClassificationConfig {
|
||||
strip_accents: strip_accents.into(),
|
||||
add_prefix_space: add_prefix_space.into(),
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
label_aggregation_function,
|
||||
batch_size: 64,
|
||||
}
|
||||
@ -506,6 +509,7 @@ impl TokenClassificationOption {
|
||||
))),
|
||||
}?;
|
||||
var_store.load(weights_path)?;
|
||||
cast_var_store(&mut var_store, config.kind, device);
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
|
@ -336,12 +336,10 @@ impl TranslationModelBuilder {
|
||||
) {
|
||||
(Some(ModelType::M2M100), source_languages, target_languages) => {
|
||||
match self.model_size {
|
||||
Some(value) if value == ModelSize::XLarge => {
|
||||
model_fetchers::get_m2m100_xlarge_resources(
|
||||
source_languages.as_ref(),
|
||||
target_languages.as_ref(),
|
||||
)?
|
||||
}
|
||||
Some(ModelSize::XLarge) => model_fetchers::get_m2m100_xlarge_resources(
|
||||
source_languages.as_ref(),
|
||||
target_languages.as_ref(),
|
||||
)?,
|
||||
_ => model_fetchers::get_m2m100_large_resources(
|
||||
source_languages.as_ref(),
|
||||
target_languages.as_ref(),
|
||||
@ -447,7 +445,7 @@ mod model_fetchers {
|
||||
Ok(match get_marian_model(source_languages, target_languages) {
|
||||
Ok(marian_resources) => marian_resources,
|
||||
Err(_) => match model_size {
|
||||
Some(value) if value == &ModelSize::XLarge => {
|
||||
Some(ModelSize::XLarge) => {
|
||||
get_m2m100_xlarge_resources(source_languages, target_languages)?
|
||||
}
|
||||
_ => get_m2m100_large_resources(source_languages, target_languages)?,
|
||||
|
@ -11,7 +11,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::Device;
|
||||
use tch::{Device, Kind};
|
||||
|
||||
use crate::common::error::RustBertError;
|
||||
use crate::m2m_100::M2M100Generator;
|
||||
@ -978,6 +978,8 @@ pub struct TranslationConfig {
|
||||
pub num_beam_groups: Option<i64>,
|
||||
/// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5)
|
||||
pub diversity_penalty: Option<f64>,
|
||||
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
|
||||
pub kind: Option<Kind>,
|
||||
}
|
||||
|
||||
impl TranslationConfig {
|
||||
@ -1065,6 +1067,7 @@ impl TranslationConfig {
|
||||
num_return_sequences: 1,
|
||||
num_beam_groups: None,
|
||||
diversity_penalty: None,
|
||||
kind: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1092,6 +1095,7 @@ impl From<TranslationConfig> for GenerateConfig {
|
||||
num_beam_groups: config.num_beam_groups,
|
||||
diversity_penalty: config.diversity_penalty,
|
||||
device: config.device,
|
||||
kind: config.kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -102,10 +102,13 @@ use crate::albert::AlbertForSequenceClassification;
|
||||
use crate::bart::BartForSequenceClassification;
|
||||
use crate::bert::BertForSequenceClassification;
|
||||
use crate::deberta::DebertaForSequenceClassification;
|
||||
use crate::deberta_v2::DebertaV2ForSequenceClassification;
|
||||
use crate::distilbert::DistilBertModelClassifier;
|
||||
use crate::longformer::LongformerForSequenceClassification;
|
||||
use crate::mobilebert::MobileBertForSequenceClassification;
|
||||
use crate::pipelines::common::{ConfigOption, ModelResource, ModelType, TokenizerOption};
|
||||
use crate::pipelines::common::{
|
||||
cast_var_store, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
};
|
||||
use crate::pipelines::sequence_classification::Label;
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::roberta::RobertaForSequenceClassification;
|
||||
@ -146,6 +149,8 @@ pub struct ZeroShotClassificationConfig {
|
||||
pub add_prefix_space: Option<bool>,
|
||||
/// Device to place the model on (default: CUDA/GPU when available)
|
||||
pub device: Device,
|
||||
/// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
|
||||
pub kind: Option<Kind>,
|
||||
}
|
||||
|
||||
impl ZeroShotClassificationConfig {
|
||||
@ -183,6 +188,7 @@ impl ZeroShotClassificationConfig {
|
||||
strip_accents: strip_accents.into(),
|
||||
add_prefix_space: add_prefix_space.into(),
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -209,6 +215,7 @@ impl Default for ZeroShotClassificationConfig {
|
||||
strip_accents: None,
|
||||
add_prefix_space: None,
|
||||
device: Device::cuda_if_available(),
|
||||
kind: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -217,11 +224,14 @@ impl Default for ZeroShotClassificationConfig {
|
||||
/// The models are using a classification architecture that should be trained on Natural Language Inference.
|
||||
/// The models should output a Tensor of size > 2 in the label dimension, with the first logit corresponding
|
||||
/// to contradiction and the last logit corresponding to entailment.
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum ZeroShotClassificationOption {
|
||||
/// Bart for Sequence Classification
|
||||
Bart(BartForSequenceClassification),
|
||||
/// DeBERTa for Sequence Classification
|
||||
Deberta(DebertaForSequenceClassification),
|
||||
/// DeBERTaV2 for Sequence Classification
|
||||
DebertaV2(DebertaV2ForSequenceClassification),
|
||||
/// Bert for Sequence Classification
|
||||
Bert(BertForSequenceClassification),
|
||||
/// DistilBert for Sequence Classification
|
||||
@ -288,6 +298,17 @@ impl ZeroShotClassificationOption {
|
||||
))
|
||||
}
|
||||
}
|
||||
ModelType::DebertaV2 => {
|
||||
if let ConfigOption::DebertaV2(config) = model_config {
|
||||
Ok(Self::DebertaV2(
|
||||
DebertaV2ForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
"You can only supply a DebertaConfig for DeBERTaV2!".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
ModelType::Bert => {
|
||||
if let ConfigOption::Bert(config) = model_config {
|
||||
Ok(Self::Bert(
|
||||
@ -322,13 +343,13 @@ impl ZeroShotClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Roberta => {
|
||||
if let ConfigOption::Bert(config) = model_config {
|
||||
if let ConfigOption::Roberta(config) = model_config {
|
||||
Ok(Self::Roberta(
|
||||
RobertaForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
"You can only supply a BertConfig for Roberta!".to_string(),
|
||||
"You can only supply a RobertaConfig for Roberta!".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
@ -385,6 +406,7 @@ impl ZeroShotClassificationOption {
|
||||
))),
|
||||
}?;
|
||||
var_store.load(weights_path)?;
|
||||
cast_var_store(&mut var_store, config.kind, device);
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
@ -413,6 +435,7 @@ impl ZeroShotClassificationOption {
|
||||
match *self {
|
||||
Self::Bart(_) => ModelType::Bart,
|
||||
Self::Deberta(_) => ModelType::Deberta,
|
||||
Self::DebertaV2(_) => ModelType::DebertaV2,
|
||||
Self::Bert(_) => ModelType::Bert,
|
||||
Self::Roberta(_) => ModelType::Roberta,
|
||||
Self::XLMRoberta(_) => ModelType::Roberta,
|
||||
@ -474,6 +497,19 @@ impl ZeroShotClassificationOption {
|
||||
.expect("Error in DeBERTa forward_t")
|
||||
.logits
|
||||
}
|
||||
Self::DebertaV2(ref model) => {
|
||||
model
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.expect("Error in DeBERTaV2 forward_t")
|
||||
.logits
|
||||
}
|
||||
Self::DistilBert(ref model) => {
|
||||
model
|
||||
.forward_t(input_ids, mask, input_embeds, train)
|
||||
@ -643,7 +679,7 @@ impl ZeroShotClassificationModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
fn new_with_tokenizer(
|
||||
pub fn new_with_tokenizer(
|
||||
config: ZeroShotClassificationConfig,
|
||||
tokenizer: TokenizerOption,
|
||||
) -> Result<ZeroShotClassificationModel, RustBertError> {
|
||||
|
@ -35,7 +35,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
|
||||
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?;
|
||||
let config = AlbertConfig::from_file(config_path);
|
||||
let albert_model = AlbertForMaskedLM::new(vs.root(), &config);
|
||||
load_weights(&weights_resource, &mut vs)?;
|
||||
load_weights(&weights_resource, &mut vs, None, device)?;
|
||||
|
||||
// Define input
|
||||
let input = [
|
||||
|
@ -2,7 +2,7 @@ use rust_bert::bart::{
|
||||
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
|
||||
BartVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelResource;
|
||||
use rust_bert::pipelines::common::{cast_var_store, ModelResource};
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::pipelines::zero_shot_classification::{
|
||||
ZeroShotClassificationConfig, ZeroShotClassificationModel,
|
||||
@ -44,6 +44,7 @@ fn bart_lm_model() -> anyhow::Result<()> {
|
||||
let config = BartConfig::from_file(config_path);
|
||||
let bart_model = BartModel::new(&vs.root() / "model", &config);
|
||||
vs.load(weights_path)?;
|
||||
cast_var_store(&mut vs, None, device);
|
||||
|
||||
// Define input
|
||||
let input = ["One two three four"];
|
||||
|
@ -3,12 +3,12 @@ use rust_bert::gpt_j::{
|
||||
GptJVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation_utils::Cache;
|
||||
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||
use rust_bert::resources::{load_weights, RemoteResource, ResourceProvider};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer};
|
||||
use rust_tokenizers::vocab::Vocab;
|
||||
use std::convert::TryFrom;
|
||||
use tch::{nn, Device, Tensor};
|
||||
use tch::{nn, Device, Kind, Tensor};
|
||||
|
||||
/// Equivalent Python code:
|
||||
///
|
||||
@ -67,14 +67,15 @@ fn gpt_j_correctness() -> anyhow::Result<()> {
|
||||
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let weights_path = model_resource.get_local_path()?;
|
||||
let mut config = GptJConfig::from_file(config_path);
|
||||
config.use_float16 = matches!(device, Device::Cuda(_));
|
||||
let config = GptJConfig::from_file(config_path);
|
||||
let model = GptJLMHeadModel::new(vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
let kind = match device {
|
||||
Device::Cpu => None,
|
||||
_ => Some(Kind::Half),
|
||||
};
|
||||
load_weights(&model_resource, &mut vs, kind, device)?;
|
||||
|
||||
// Tokenize prompts
|
||||
|
||||
let prompts = [
|
||||
"It was a very nice and sunny",
|
||||
"It was a gloom winter night, and",
|
||||
|
@ -234,15 +234,15 @@ mod tests {
|
||||
ModelType::M2M100,
|
||||
ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/encoder_model.onnx",
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/encoder_model.onnx",
|
||||
"onnx-m2m100_418M",
|
||||
))),
|
||||
decoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_model.onnx",
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/decoder_model.onnx",
|
||||
"onnx-m2m100_418M",
|
||||
))),
|
||||
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_with_past_model.onnx",
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/e775f50e63b178d82b8d736fc43fcf5ef15d2f6c/decoder_with_past_model.onnx",
|
||||
"onnx-m2m100_418M",
|
||||
))),
|
||||
}),
|
||||
|
Loading…
Reference in New Issue
Block a user