Fixed Clippy warnings (#309)

* - Fixed Clippy warnings
- Updated `tch` dependency
- Updated the README to avoid confusion between the `LIBTORCH` version required by the current repository and the versions required by the published packages

* Fixed Clippy warnings (2)

* Fixed Clippy warnings (3)
Committed by guillaume-be via GitHub on 2022-12-21 17:52:26 +00:00 (commit fdf5503163, parent dae899fea6)
48 changed files with 130 additions and 138 deletions

View File

@@ -70,7 +70,7 @@ features = ["doc-only"]
 [dependencies]
 rust_tokenizers = "~7.0.2"
-tch = "~0.9.0"
+tch = "~0.10.1"
 serde_json = "1.0.82"
 serde = { version = "1.0.140", features = ["derive"] }
 ordered-float = "3.0.0"
@@ -88,6 +88,6 @@ anyhow = "1.0.58"
 csv = "1.1.6"
 criterion = "0.3.6"
 tokio = { version = "1.20.0", features = ["sync", "rt-multi-thread", "macros"] }
-torch-sys = "0.9.0"
+torch-sys = "0.10.0"
 tempfile = "3.3.0"
 itertools = "0.10.3"

View File

@@ -76,8 +76,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
 ### Manual installation (recommended)
-1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.13.0`: if this version is no longer available on the "get started" page,
-the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip` for a Linux version with CUDA11.
+1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.13.1`: if this version is no longer available on the "get started" page,
+the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.1%2Bcu117.zip` for a Linux version with CUDA11. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
 2. Extract the library to a location of your choice
 3. Set the following environment variables
 ##### Linux:

View File

@@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
 false,
 )?;
 let config = DebertaConfig::from_file(config_path);
-let model = DebertaForSequenceClassification::new(&vs.root(), &config);
+let model = DebertaForSequenceClassification::new(vs.root(), &config);
 vs.load(weights_path)?;
 // Define input
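Most changes in this commit have this shape: `VarStore::root()` returns an owned `nn::Path`, and the rust-bert constructors are generic over `Borrow<Path>`, so taking a reference to the freshly returned value trips Clippy's `needless_borrow` lint. A minimal sketch of the pattern with placeholder types (not the real tch/rust-bert signatures):

```rust
use std::borrow::Borrow;

// Placeholder stand-ins for tch's `nn::Path` and a rust-bert model type.
struct Path;
struct Model;

impl Model {
    // Generic over `Borrow<Path>`, like the rust-bert constructors.
    fn new<P: Borrow<Path>>(p: P) -> Self {
        let _path: &Path = p.borrow();
        Model
    }
}

fn main() {
    let root = Path; // imagine this came from `vs.root()`
    // Clippy's `needless_borrow` lint flags `Model::new(&root)` here:
    // the owned value already satisfies the bound, so pass it directly.
    let _model = Model::new(root);
}
```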

View File

@@ -1101,7 +1101,7 @@ impl BartGenerator {
 generate_config.validate();
 let mut var_store = nn::VarStore::new(device);
 let config = BartConfig::from_file(config_path);
-let model = BartForConditionalGeneration::new(&var_store.root(), &config);
+let model = BartForConditionalGeneration::new(var_store.root(), &config);
 var_store.load(weights_path)?;
 let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -1131,7 +1131,7 @@ impl BartGenerator {
 }
 fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
-let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
+let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
 .filter(|pos| !token_ids.contains(pos))
 .collect();
 let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
@@ -1337,6 +1337,6 @@ mod test {
 let vs = tch::nn::VarStore::new(device);
 let config = BartConfig::from_file(config_path);
-let _: Box<dyn Send> = Box::new(BartModel::new(&vs.root(), &config));
+let _: Box<dyn Send> = Box::new(BartModel::new(vs.root(), &config));
 }
 }
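The `(0..self.get_vocab_size() as i64)` fix is Clippy's `unnecessary_cast`: the vocabulary size is already an `i64` here, so the cast is a no-op. The surrounding method masks out every token that is not explicitly allowed; a self-contained sketch of that idea, assuming `tch` ~0.10 (the names and values are illustrative):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let vocab_size = 8i64; // illustrative stand-in for `self.get_vocab_size()`
    let allowed_token_ids = [3i64, 5];
    // Was `(0..vocab_size as i64)`; casting an i64 to i64 is a no-op.
    let impossible_tokens: Vec<i64> = (0..vocab_size)
        .filter(|pos| !allowed_token_ids.contains(pos))
        .collect();
    let impossible_tokens = Tensor::of_slice(&impossible_tokens);
    // Push the scores of all non-allowed tokens to -inf so that only the
    // allowed ids can be generated at this step.
    let mut scores = Tensor::zeros(&[1, vocab_size], (Kind::Float, Device::Cpu));
    let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
    println!("{:?}", scores);
}
```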

View File

@@ -24,7 +24,7 @@ use crate::{Config, RustBertError};
 use serde::{Deserialize, Serialize};
 use std::borrow::Borrow;
 use std::collections::HashMap;
-use tch::nn::Init;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::{nn, Kind, Tensor};
 /// # BERT Pretrained model weight files
@@ -507,7 +507,7 @@ impl BertLMPredictionHead {
 config.vocab_size,
 Default::default(),
 );
-let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
+let bias = p.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
 BertLMPredictionHead {
 transform,
@@ -1301,9 +1301,9 @@ mod test {
 // Set-up masked LM model
 let device = Device::cuda_if_available();
-let vs = tch::nn::VarStore::new(device);
+let vs = nn::VarStore::new(device);
 let config = BertConfig::from_file(config_path);
-let _: Box<dyn Send> = Box::new(BertModel::<BertEmbeddings>::new(&vs.root(), &config));
+let _: Box<dyn Send> = Box::new(BertModel::<BertEmbeddings>::new(vs.root(), &config));
 }
 }
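The `Init::KaimingUniform` to `DEFAULT_KAIMING_UNIFORM` substitution (repeated across BERT, MobileBERT, ProphetNet, RoBERTa and XLNet below) tracks the `tch` 0.10 API: the Kaiming initializer is no longer a unit enum variant, and `tch::nn::init::DEFAULT_KAIMING_UNIFORM` is the exported drop-in default. A minimal sketch of the migrated call, assuming `tch` ~0.10 and a working libtorch install:

```rust
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::{nn, Device};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();
    // tch 0.9:  root.var("bias", &[768], Init::KaimingUniform)
    // tch 0.10: the unit variant is gone; use the exported default constant.
    let bias = root.var("bias", &[768], DEFAULT_KAIMING_UNIFORM);
    println!("{:?}", bias.size());
}
```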

View File

@@ -11,6 +11,7 @@
 // limitations under the License.
 use std::borrow::Borrow;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::nn::{Init, Module, Path};
 use tch::Tensor;
@@ -22,7 +23,7 @@ pub struct LinearNoBiasConfig {
 impl Default for LinearNoBiasConfig {
 fn default() -> Self {
 LinearNoBiasConfig {
-ws_init: Init::KaimingUniform,
+ws_init: DEFAULT_KAIMING_UNIFORM,
 }
 }
 }
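`LinearNoBias` is rust-bert's thin wrapper for weight-only projections. For reference, a comparable effect is available from tch's built-in layer by disabling the bias term; a sketch under the assumption of `tch` ~0.10:

```rust
use tch::nn::{self, LinearConfig, Module};
use tch::{Device, Kind, Tensor};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    // A weight-only linear layer: the bias term is simply disabled.
    let config = LinearConfig { bias: false, ..Default::default() };
    let layer = nn::linear(vs.root() / "proj", 4, 2, config);
    let input = Tensor::zeros(&[1, 4], (Kind::Float, Device::Cpu));
    let output = layer.forward(&input);
    assert_eq!(output.size(), vec![1, 2]);
}
```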

View File

@@ -1070,6 +1070,6 @@ mod test {
 let vs = tch::nn::VarStore::new(device);
 let config = FNetConfig::from_file(config_path);
-let _: Box<dyn Send> = Box::new(FNetModel::new(&vs.root(), &config, true));
+let _: Box<dyn Send> = Box::new(FNetModel::new(vs.root(), &config, true));
 }
 }

View File

@@ -742,7 +742,7 @@ impl GPT2Generator {
 let mut var_store = nn::VarStore::new(device);
 let config = Gpt2Config::from_file(config_path);
-let model = GPT2LMHeadModel::new(&var_store.root(), &config);
+let model = GPT2LMHeadModel::new(var_store.root(), &config);
 var_store.load(weights_path)?;
 let bos_token_id = tokenizer.get_bos_id();

View File

@@ -716,7 +716,7 @@ impl GptNeoGenerator {
 generate_config.validate();
 let mut var_store = nn::VarStore::new(device);
 let config = GptNeoConfig::from_file(config_path);
-let model = GptNeoForCausalLM::new(&var_store.root(), &config)?;
+let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
 var_store.load(weights_path)?;
 let bos_token_id = tokenizer.get_bos_id();

View File

@@ -71,7 +71,7 @@ impl SinusoidalPositionalEmbedding {
 ) -> Tensor {
 let half_dim = embedding_dim / 2;
-let emb = -(10000f64.ln() as f64) / ((half_dim - 1) as f64);
+let emb = -(10000f64.ln()) / ((half_dim - 1) as f64);
 let emb = (Tensor::arange(half_dim, (Kind::Float, device)) * emb).exp();
 let emb =
 Tensor::arange(num_embeddings, (Kind::Float, device)).unsqueeze(1) * emb.unsqueeze(0);
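Besides dropping a redundant cast (`ln()` on an `f64` already returns `f64`), this hunk is the standard fairseq-style sinusoidal embedding computation. As a sketch of what the code computes (not new behaviour), with `d` standing for `embedding_dim` and `p` for the position index:

```latex
\omega_j = \exp\!\left(-\frac{\ln 10000}{d/2 - 1}\, j\right), \qquad
\mathrm{emb}_{p,j} = p\,\omega_j, \qquad j = 0, \dots, d/2 - 1,
```

these angles are then the arguments of the sine/cosine tables that fill the positional embedding matrix.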

View File

@@ -651,7 +651,7 @@ impl M2M100Generator {
 let mut var_store = nn::VarStore::new(device);
 let config = M2M100Config::from_file(config_path);
-let model = M2M100ForConditionalGeneration::new(&var_store.root(), &config);
+let model = M2M100ForConditionalGeneration::new(var_store.root(), &config);
 var_store.load(weights_path)?;
 let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -681,7 +681,7 @@ impl M2M100Generator {
 }
 fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
-let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
+let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
 .filter(|pos| !token_ids.contains(pos))
 .collect();
 let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
@@ -887,6 +887,6 @@ mod test {
 let vs = tch::nn::VarStore::new(device);
 let config = M2M100Config::from_file(config_path);
-let _: Box<dyn Send> = Box::new(M2M100Model::new(&vs.root(), &config));
+let _: Box<dyn Send> = Box::new(M2M100Model::new(vs.root(), &config));
 }
 }

View File

@@ -872,7 +872,7 @@ impl MarianGenerator {
 let mut var_store = nn::VarStore::new(device);
 let config = BartConfig::from_file(config_path);
-let model = MarianForConditionalGeneration::new(&var_store.root(), &config);
+let model = MarianForConditionalGeneration::new(var_store.root(), &config);
 var_store.load(weights_path)?;
 let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -904,7 +904,7 @@ impl MarianGenerator {
 }
 fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
-let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
+let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
 .filter(|pos| !token_ids.contains(pos))
 .collect();
 let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());

View File

@@ -900,7 +900,7 @@ impl MBartGenerator {
 let mut var_store = nn::VarStore::new(device);
 let config = MBartConfig::from_file(config_path);
-let model = MBartForConditionalGeneration::new(&var_store.root(), &config);
+let model = MBartForConditionalGeneration::new(var_store.root(), &config);
 var_store.load(weights_path)?;
 let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -930,7 +930,7 @@ impl MBartGenerator {
 }
 fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
-let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
+let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
 .filter(|pos| !token_ids.contains(pos))
 .collect();
 let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
@@ -1136,6 +1136,6 @@ mod test {
 let vs = tch::nn::VarStore::new(device);
 let config = MBartConfig::from_file(config_path);
-let _: Box<dyn Send> = Box::new(MBartModel::new(&vs.root(), &config));
+let _: Box<dyn Send> = Box::new(MBartModel::new(vs.root(), &config));
 }
 }

View File

@@ -19,6 +19,7 @@ use crate::{Config, RustBertError};
 use serde::{Deserialize, Serialize};
 use std::borrow::Borrow;
 use std::collections::HashMap;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::nn::{Init, LayerNormConfig, Module};
 use tch::{nn, Kind, Tensor};
@@ -292,7 +293,7 @@ impl MobileBertLMPredictionHead {
 config.hidden_size - config.embedding_size,
 config.vocab_size,
 ],
-Init::KaimingUniform,
+DEFAULT_KAIMING_UNIFORM,
 );
 let bias = p.var("bias", &[config.vocab_size], Init::Const(0.0));

View File

@@ -504,7 +504,7 @@ impl OpenAIGenerator {
 let mut var_store = nn::VarStore::new(device);
 let config = Gpt2Config::from_file(config_path);
-let model = OpenAIGPTLMHeadModel::new(&var_store.root(), &config);
+let model = OpenAIGPTLMHeadModel::new(var_store.root(), &config);
 var_store.load(weights_path)?;
 let bos_token_id = tokenizer.get_bos_id();

View File

@@ -624,7 +624,7 @@ impl PegasusConditionalGenerator {
 generate_config.validate();
 let mut var_store = nn::VarStore::new(device);
 let config = PegasusConfig::from_file(config_path);
-let model = PegasusForConditionalGeneration::new(&var_store.root(), &config);
+let model = PegasusForConditionalGeneration::new(var_store.root(), &config);
 var_store.load(weights_path)?;
 let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -654,7 +654,7 @@ impl PegasusConditionalGenerator {
 }
 fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
-let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
+let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
 .filter(|pos| !token_ids.contains(pos))
 .collect();
 let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());

View File

@@ -403,7 +403,7 @@ pub(crate) mod private_generation_utils {
 prev_output_tokens: &Tensor,
 repetition_penalty: f64,
 ) {
-for i in 0..(batch_size * num_beams as i64) {
+for i in 0..(batch_size * num_beams) {
 for token_position in 0..prev_output_tokens.get(i).size()[0] {
 let token = prev_output_tokens.get(i).int64_value(&[token_position]);
 let updated_value = &next_token_logits.double_value(&[i, token]);
@@ -826,8 +826,8 @@ pub(crate) mod private_generation_utils {
 if gen_opt.no_repeat_ngram_size > 0 {
 let banned_tokens = self.get_banned_tokens(
 &input_ids,
-gen_opt.no_repeat_ngram_size as i64,
-current_length as i64,
+gen_opt.no_repeat_ngram_size,
+current_length,
 );
 for (batch_index, index_banned_token) in
 (0..banned_tokens.len() as i64).zip(banned_tokens)
@@ -875,7 +875,7 @@ pub(crate) mod private_generation_utils {
 }
 self.top_k_top_p_filtering(
 &mut next_token_logits,
-gen_opt.top_k as i64,
+gen_opt.top_k,
 gen_opt.top_p,
 1,
 );
@@ -915,7 +915,7 @@ pub(crate) mod private_generation_utils {
 &sentence_with_eos
 .to_kind(Kind::Bool)
 .to_device(sentence_lengths.device()),
-current_length as i64 + 1,
+current_length + 1,
 );
 unfinished_sentences = -unfinished_sentences * (sentence_with_eos - 1);
 }
@@ -943,7 +943,7 @@ pub(crate) mod private_generation_utils {
 &unfinished_sentences
 .to_kind(Kind::Bool)
 .to_device(sentence_lengths.device()),
-current_length as i64,
+current_length,
 );
 break;
 }
@@ -1927,10 +1927,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
 let batch_size = *input_ids.size().first().unwrap();
 let (effective_batch_size, effective_batch_mult) = match do_sample {
-true => (
-batch_size * num_return_sequences as i64,
-num_return_sequences as i64,
-),
+true => (batch_size * num_return_sequences, num_return_sequences),
 false => (batch_size, 1),
 };
@@ -1946,7 +1943,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
 let encoder_outputs = self.encode(&input_ids, Some(&attention_mask)).unwrap();
 let expanded_batch_indices = Tensor::arange(batch_size, (Int64, input_ids.device()))
 .view((-1, 1))
-.repeat(&[1, num_beams as i64 * effective_batch_mult])
+.repeat(&[1, num_beams * effective_batch_mult])
 .view(-1);
 Some(encoder_outputs.index_select(0, &expanded_batch_indices))
 } else {
@@ -1959,19 +1956,19 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
 input_ids
 .unsqueeze(1)
 .expand(
-&[batch_size, effective_batch_mult * num_beams as i64, cur_len],
+&[batch_size, effective_batch_mult * num_beams, cur_len],
 true,
 )
 .contiguous()
-.view((effective_batch_size * num_beams as i64, cur_len)),
+.view((effective_batch_size * num_beams, cur_len)),
 attention_mask
 .unsqueeze(1)
 .expand(
-&[batch_size, effective_batch_mult * num_beams as i64, cur_len],
+&[batch_size, effective_batch_mult * num_beams, cur_len],
 true,
 )
 .contiguous()
-.view((effective_batch_size * num_beams as i64, cur_len)),
+.view((effective_batch_size * num_beams, cur_len)),
 )
 } else {
 (input_ids, attention_mask)
@@ -1982,7 +1979,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
 .expect("decoder start id must be specified for encoder decoders")
 });
 let input_ids = Tensor::full(
-&[effective_batch_size * num_beams as i64, 1],
+&[effective_batch_size * num_beams, 1],
 decoder_start_token_id,
 (Int64, input_ids.device()),
 );
@@ -1990,15 +1987,11 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
 attention_mask
 .unsqueeze(1)
 .expand(
-&[
-batch_size,
-effective_batch_mult * num_beams as i64,
-input_ids_len,
-],
+&[batch_size, effective_batch_mult * num_beams, input_ids_len],
 true,
 )
 .contiguous()
-.view((effective_batch_size * num_beams as i64, input_ids_len))
+.view((effective_batch_size * num_beams, input_ids_len))
 } else {
 attention_mask
 };
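Beyond the removed casts (the beam and batch counts are plain `i64` values here, so `as i64` was a no-op flagged by `unnecessary_cast`), this section implements the usual beam-search setup: every input row is duplicated `num_beams` (and, when sampling, `num_return_sequences`) times before decoding. A self-contained sketch of that duplication, assuming `tch` ~0.10 with illustrative sizes:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let (batch_size, num_beams, cur_len) = (2i64, 3i64, 4i64);
    let input_ids = Tensor::arange(batch_size * cur_len, (Kind::Int64, Device::Cpu))
        .view((batch_size, cur_len));
    // Duplicate each row `num_beams` times so every beam starts from the same
    // prompt: (batch, len) -> (batch * num_beams, len).
    let expanded = input_ids
        .unsqueeze(1)
        .expand(&[batch_size, num_beams, cur_len], true)
        .contiguous()
        .view((batch_size * num_beams, cur_len));
    assert_eq!(expanded.size(), vec![batch_size * num_beams, cur_len]);
}
```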

View File

@@ -144,13 +144,13 @@ impl<'a> KeywordExtractionModel<'a> {
 config: KeywordExtractionConfig<'a>,
 ) -> Result<KeywordExtractionModel<'a>, RustBertError> {
 let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
-&config
+config
 .sentence_embeddings_config
 .tokenizer_config_resource
 .get_local_path()?,
 );
 let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
-&config
+config
 .sentence_embeddings_config
 .sentence_bert_config_resource
 .get_local_path()?,

View File

@@ -407,7 +407,7 @@ impl MaskedLanguageModel {
 .unwrap_or(usize::MAX);
 let language_encode =
-MaskedLanguageOption::new(config.model_type, &var_store.root(), &model_config)?;
+MaskedLanguageOption::new(config.model_type, var_store.root(), &model_config)?;
 var_store.load(weights_path)?;
 let mask_token = config.mask_token;
 Ok(MaskedLanguageModel {

View File

@@ -620,7 +620,7 @@ impl QuestionAnsweringModel {
 let qa_model = QuestionAnsweringOption::new(
 question_answering_config.model_type,
-&var_store.root(),
+var_store.root(),
 &model_config,
 )?;
@@ -878,7 +878,7 @@ impl QuestionAnsweringModel {
 max_seq_length - sequence_pair_added_tokens - encoded_query.ids.len();
 let mut start_token = 0_usize;
-while (spans.len() * doc_stride as usize) < encoded_context.ids.len() {
+while (spans.len() * doc_stride) < encoded_context.ids.len() {
 let end_token = min(start_token + max_context_length, encoded_context.ids.len());
 let sub_encoded_context = TokenIdsWithOffsets {
 ids: encoded_context.ids[start_token..end_token].to_vec(),
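The `doc_stride` fix drops another redundant cast (the value is already a `usize`); the surrounding loop is the sliding-window ("doc stride") chunking used when a context exceeds the model's maximum length. A pure-Rust sketch of that windowing with illustrative numbers (the increment step is assumed, as it is not shown in the hunk):

```rust
fn main() {
    // Illustrative token ids and window parameters.
    let ids: Vec<i64> = (0..24).collect();
    let max_context_length = 10usize;
    let doc_stride = 6usize;

    let mut spans: Vec<(usize, usize)> = Vec::new();
    let mut start_token = 0usize;
    // Same loop condition as the hunk above: emit overlapping windows until
    // the strides cover the whole encoded context.
    while spans.len() * doc_stride < ids.len() {
        let end_token = (start_token + max_context_length).min(ids.len());
        spans.push((start_token, end_token));
        start_token += doc_stride;
    }
    assert_eq!(spans, vec![(0, 10), (6, 16), (12, 22), (18, 24)]);
}
```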

View File

@@ -130,7 +130,7 @@ impl Dense {
 bias: dense_conf.bias,
 };
 let linear = nn::linear(
-&vs_dense.root(),
+vs_dense.root(),
 dense_conf.in_features,
 dense_conf.out_features,
 linear_conf,

View File

@@ -230,11 +230,8 @@ impl SentenceEmbeddingsModel {
 transformer_type,
 transformer_config_resource.get_local_path()?,
 );
-let transformer = SentenceEmbeddingsOption::new(
-transformer_type,
-&var_store.root(),
-&transformer_config,
-)?;
+let transformer =
+SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
 var_store.load(transformer_weights_resource.get_local_path()?)?;
 // Setup pooling layer
@@ -310,7 +307,7 @@ impl SentenceEmbeddingsModel {
 Tensor::of_slice(
 &input
 .iter()
-.map(|&e| if e == pad_token_id { 0_i64 } else { 1_i64 })
+.map(|&e| i64::from(e != pad_token_id))
 .collect::<Vec<_>>(),
 )
 })
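The attention-mask change is Clippy's `bool_to_int_with_if`: an `if`/`else` that returns `0`/`1` is clearer as a boolean-to-integer conversion. A standalone sketch of the same padding-mask construction:

```rust
fn main() {
    let pad_token_id = 0i64;
    let input = [5i64, 42, 7, 0, 0];
    // `i64::from(bool)` maps false -> 0 and true -> 1, replacing the
    // `if e == pad_token_id { 0_i64 } else { 1_i64 }` form flagged by Clippy.
    let attention_mask: Vec<i64> = input
        .iter()
        .map(|&e| i64::from(e != pad_token_id))
        .collect();
    assert_eq!(attention_mask, vec![1, 1, 1, 0, 0]);
}
```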

View File

@@ -593,7 +593,7 @@ impl SequenceClassificationModel {
 .map(|v| v as usize)
 .unwrap_or(usize::MAX);
 let sequence_classifier =
-SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config)?;
+SequenceClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
 let label_mapping = model_config.get_label_mapping().clone();
 var_store.load(weights_path)?;
 Ok(SequenceClassificationModel {

View File

@@ -699,7 +699,7 @@ impl TokenClassificationModel {
 .map(|v| v as usize)
 .unwrap_or(usize::MAX);
 let token_sequence_classifier =
-TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config)?;
+TokenClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
 let label_mapping = model_config.get_label_mapping().clone();
 let batch_size = config.batch_size;
 var_store.load(weights_path)?;
@@ -749,7 +749,7 @@ impl TokenClassificationModel {
 let mut start_token = 0_usize;
 let total_length = encoded_input.ids.len();
-while (spans.len() * doc_stride as usize) < encoded_input.ids.len() {
+while (spans.len() * doc_stride) < encoded_input.ids.len() {
 let end_token = min(start_token + max_content_length, total_length);
 let sub_encoded_input = TokenIdsWithOffsets {
 ids: encoded_input.ids[start_token..end_token].to_vec(),
@@ -994,8 +994,8 @@ impl TokenClassificationModel {
 position_idx: i64,
 word_index: u16,
 ) -> Token {
-let label_id = labels.int64_value(&[position_idx as i64]);
-let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
+let label_id = labels.int64_value(&[position_idx]);
+let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
 let offsets = &sentence_tokens.offsets[position_idx as usize];

View File

@@ -575,7 +575,7 @@ impl ZeroShotClassificationModel {
 let mut var_store = VarStore::new(device);
 let model_config = ConfigOption::from_file(config.model_type, config_path);
 let zero_shot_classifier =
-ZeroShotClassificationOption::new(config.model_type, &var_store.root(), &model_config)?;
+ZeroShotClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
 var_store.load(weights_path)?;
 Ok(ZeroShotClassificationModel {
 tokenizer,

View File

@@ -21,7 +21,7 @@ use crate::prophetnet::embeddings::ProphetNetPositionalEmbeddings;
 use crate::prophetnet::ProphetNetConfig;
 use crate::RustBertError;
 use std::borrow::{Borrow, BorrowMut};
-use tch::nn::Init;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::{nn, Device, Kind, Tensor};
 fn ngram_attention_bias(sequence_length: i64, ngram: i64, device: Device, kind: Kind) -> Tensor {
@@ -210,7 +210,7 @@ impl ProphetNetDecoder {
 let ngram_embeddings = p_ngram_embedding.var(
 "weight",
 &[config.ngram, config.hidden_size],
-Init::KaimingUniform,
+DEFAULT_KAIMING_UNIFORM,
 );
 let output_attentions = config.output_attentions.unwrap_or(false);

View File

@@ -965,7 +965,7 @@ impl ProphetNetConditionalGenerator {
 generate_config.validate();
 let mut var_store = nn::VarStore::new(device);
 let config = ProphetNetConfig::from_file(config_path);
-let model = ProphetNetForConditionalGeneration::new(&var_store.root(), &config)?;
+let model = ProphetNetForConditionalGeneration::new(var_store.root(), &config)?;
 var_store.load(weights_path)?;
 let bos_token_id = Some(config.bos_token_id);

View File

@@ -354,7 +354,7 @@ impl ReformerModel {
 let must_pad_to_match_chunk_length =
 (input_shape.last().unwrap() % self.least_common_mult_chunk_length != 0)
-& (*input_shape.last().unwrap() as i64 > self.min_chunk_length)
+& (*input_shape.last().unwrap() > self.min_chunk_length)
 & old_layer_states.is_none();
 let start_idx_pos_encodings = if let Some(layer_states) = &old_layer_states {
@@ -1091,7 +1091,7 @@ impl ReformerGenerator {
 generate_config.validate();
 let mut var_store = nn::VarStore::new(device);
 let config = ReformerConfig::from_file(config_path);
-let model = ReformerModelWithLMHead::new(&var_store.root(), &config)?;
+let model = ReformerModelWithLMHead::new(var_store.root(), &config)?;
 var_store.load(weights_path)?;
 let bos_token_id = tokenizer.get_bos_id();

View File

@@ -17,7 +17,7 @@ use crate::common::dropout::Dropout;
 use crate::common::linear::{linear_no_bias, LinearNoBias};
 use crate::roberta::embeddings::RobertaEmbeddings;
 use std::borrow::Borrow;
-use tch::nn::Init;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::{nn, Tensor};
 /// # RoBERTa Pretrained model weight files
@@ -218,7 +218,7 @@ impl RobertaLMHead {
 config.vocab_size,
 Default::default(),
 );
-let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
+let bias = p.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
 RobertaLMHead {
 dense,

View File

@@ -881,7 +881,7 @@ impl T5Generator {
 let mut var_store = nn::VarStore::new(device);
 let config = T5Config::from_file(config_path);
-let model = T5ForConditionalGeneration::new(&var_store.root(), &config);
+let model = T5ForConditionalGeneration::new(var_store.root(), &config);
 var_store.load(weights_path)?;
 let bos_token_id = Some(config.bos_token_id.unwrap_or(-1));

View File

@@ -15,7 +15,7 @@
 use crate::common::dropout::Dropout;
 use crate::xlnet::XLNetConfig;
 use std::borrow::Borrow;
-use tch::nn::Init;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::{nn, Kind, Tensor};
 #[derive(Debug)]
@@ -72,52 +72,52 @@ impl XLNetRelativeAttention {
 let query = p.var(
 "q",
 &[config.d_model, config.n_head, config.d_head],
-Init::KaimingUniform,
+DEFAULT_KAIMING_UNIFORM,
 );
 let key = p.var(
 "k",
 &[config.d_model, config.n_head, config.d_head],
-Init::KaimingUniform,
+DEFAULT_KAIMING_UNIFORM,
 );
 let value = p.var(
 "v",
 &[config.d_model, config.n_head, config.d_head],
-Init::KaimingUniform,
+DEFAULT_KAIMING_UNIFORM,
 );
 let output = p.var(
 "o",
 &[config.d_model, config.n_head, config.d_head],
-Init::KaimingUniform,
+DEFAULT_KAIMING_UNIFORM,
 );
 let pos = p.var(
 "r",
 &[config.d_model, config.n_head, config.d_head],
-Init::KaimingUniform,
+DEFAULT_KAIMING_UNIFORM,
 );
 let r_r_bias = p.var(
 "r_r_bias",
 &[config.n_head, config.d_head],
-Init::KaimingUniform,
+DEFAULT_KAIMING_UNIFORM,
 );
 let r_s_bias = p.var(
 "r_s_bias",
 &[config.n_head, config.d_head],
-Init::KaimingUniform,
+DEFAULT_KAIMING_UNIFORM,
 );
 let r_w_bias = p.var(
 "r_w_bias",
 &[config.n_head, config.d_head],
-Init::KaimingUniform,
+DEFAULT_KAIMING_UNIFORM,
 );
 let seg_embed = p.var(
 "seg_embed",
 &[2, config.n_head, config.d_head],
-Init::KaimingUniform,
+DEFAULT_KAIMING_UNIFORM,
 );
 let dropout = Dropout::new(config.dropout);

View File

@@ -1648,7 +1648,7 @@ impl XLNetGenerator {
 let mut var_store = nn::VarStore::new(device);
 let config = XLNetConfig::from_file(config_path);
-let model = XLNetLMHeadModel::new(&var_store.root(), &config);
+let model = XLNetLMHeadModel::new(var_store.root(), &config);
 var_store.load(weights_path)?;
 let bos_token_id = Some(config.bos_token_id);

View File

@@ -35,7 +35,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
 let tokenizer: AlbertTokenizer =
 AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?;
 let config = AlbertConfig::from_file(config_path);
-let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
+let albert_model = AlbertForMaskedLM::new(vs.root(), &config);
 vs.load(weights_path)?;
 // Define input
@@ -109,7 +109,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
 config.id2label = Some(dummy_label_mapping);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let albert_model = AlbertForSequenceClassification::new(&vs.root(), &config);
+let albert_model = AlbertForSequenceClassification::new(vs.root(), &config);
 // Define input
 let input = [
@@ -170,7 +170,7 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
 let mut config = AlbertConfig::from_file(config_path);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let albert_model = AlbertForMultipleChoice::new(&vs.root(), &config);
+let albert_model = AlbertForMultipleChoice::new(vs.root(), &config);
 // Define input
 let input = [
@@ -242,7 +242,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
 config.id2label = Some(dummy_label_mapping);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let albert_model = AlbertForTokenClassification::new(&vs.root(), &config);
+let albert_model = AlbertForTokenClassification::new(vs.root(), &config);
 // Define input
 let input = [
@@ -303,7 +303,7 @@ fn albert_for_question_answering() -> anyhow::Result<()> {
 let mut config = AlbertConfig::from_file(config_path);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let albert_model = AlbertForQuestionAnswering::new(&vs.root(), &config);
+let albert_model = AlbertForQuestionAnswering::new(vs.root(), &config);
 // Define input
 let input = [

View File

@@ -35,7 +35,7 @@ fn bert_masked_lm() -> anyhow::Result<()> {
 let tokenizer: BertTokenizer =
 BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
 let config = BertConfig::from_file(config_path);
-let bert_model = BertForMaskedLM::new(&vs.root(), &config);
+let bert_model = BertForMaskedLM::new(vs.root(), &config);
 vs.load(weights_path)?;
 // Define input
@@ -162,7 +162,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
 config.id2label = Some(dummy_label_mapping);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let bert_model = BertForSequenceClassification::new(&vs.root(), &config);
+let bert_model = BertForSequenceClassification::new(vs.root(), &config);
 // Define input
 let input = [
@@ -219,7 +219,7 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
 let mut config = BertConfig::from_file(config_path);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let bert_model = BertForMultipleChoice::new(&vs.root(), &config);
+let bert_model = BertForMultipleChoice::new(vs.root(), &config);
 // Define input
 let input = [
@@ -283,7 +283,7 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
 config.id2label = Some(dummy_label_mapping);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let bert_model = BertForTokenClassification::new(&vs.root(), &config);
+let bert_model = BertForTokenClassification::new(vs.root(), &config);
 // Define input
 let input = [
@@ -340,7 +340,7 @@ fn bert_for_question_answering() -> anyhow::Result<()> {
 let mut config = BertConfig::from_file(config_path);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let bert_model = BertForQuestionAnswering::new(&vs.root(), &config);
+let bert_model = BertForQuestionAnswering::new(vs.root(), &config);
 // Define input
 let input = [

View File

@@ -41,7 +41,7 @@ fn deberta_natural_language_inference() -> anyhow::Result<()> {
 false,
 )?;
 let config = DebertaConfig::from_file(config_path);
-let model = DebertaForSequenceClassification::new(&vs.root(), &config);
+let model = DebertaForSequenceClassification::new(vs.root(), &config);
 vs.load(weights_path)?;
 // Define input
@@ -96,7 +96,7 @@ fn deberta_masked_lm() -> anyhow::Result<()> {
 let mut config = DebertaConfig::from_file(config_path);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let deberta_model = DebertaForMaskedLM::new(&vs.root(), &config);
+let deberta_model = DebertaForMaskedLM::new(vs.root(), &config);
 // Generate random input
 let input_tensor = Tensor::randint(42, &[32, 128], (Kind::Int64, device));
@@ -170,7 +170,7 @@ fn deberta_for_token_classification() -> anyhow::Result<()> {
 dummy_label_mapping.insert(2, String::from("PER"));
 dummy_label_mapping.insert(3, String::from("ORG"));
 config.id2label = Some(dummy_label_mapping);
-let model = DebertaForTokenClassification::new(&vs.root(), &config);
+let model = DebertaForTokenClassification::new(vs.root(), &config);
 // Define input
 let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -225,7 +225,7 @@ fn deberta_for_question_answering() -> anyhow::Result<()> {
 false,
 )?;
 let config = DebertaConfig::from_file(config_path);
-let model = DebertaForQuestionAnswering::new(&vs.root(), &config);
+let model = DebertaForQuestionAnswering::new(vs.root(), &config);
 // Define input
 let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];

View File

@@ -22,7 +22,7 @@ fn deberta_v2_masked_lm() -> anyhow::Result<()> {
 let mut config = DebertaV2Config::from_file(config_path);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let deberta_model = DebertaV2ForMaskedLM::new(&vs.root(), &config);
+let deberta_model = DebertaV2ForMaskedLM::new(vs.root(), &config);
 // Generate random input
 let input_tensor = Tensor::randint(42, &[32, 128], (Kind::Int64, device));
@@ -88,7 +88,7 @@ fn deberta_v2_for_sequence_classification() -> anyhow::Result<()> {
 dummy_label_mapping.insert(1, String::from("Neutral"));
 dummy_label_mapping.insert(2, String::from("Negative"));
 config.id2label = Some(dummy_label_mapping);
-let model = DebertaV2ForSequenceClassification::new(&vs.root(), &config);
+let model = DebertaV2ForSequenceClassification::new(vs.root(), &config);
 // Define input
 let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -142,7 +142,7 @@ fn deberta_v2_for_token_classification() -> anyhow::Result<()> {
 dummy_label_mapping.insert(2, String::from("PER"));
 dummy_label_mapping.insert(3, String::from("ORG"));
 config.id2label = Some(dummy_label_mapping);
-let model = DebertaV2ForTokenClassification::new(&vs.root(), &config);
+let model = DebertaV2ForTokenClassification::new(vs.root(), &config);
 // Define input
 let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -190,7 +190,7 @@ fn deberta_v2_for_question_answering() -> anyhow::Result<()> {
 let tokenizer =
 DeBERTaV2Tokenizer::from_file(vocab_path.to_str().unwrap(), false, false, false)?;
 let config = DebertaV2Config::from_file(config_path);
-let model = DebertaV2ForQuestionAnswering::new(&vs.root(), &config);
+let model = DebertaV2ForQuestionAnswering::new(vs.root(), &config);
 // Define input
 let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];

View File

@@ -61,7 +61,7 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
 let tokenizer: BertTokenizer =
 BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
 let config = DistilBertConfig::from_file(config_path);
-let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
+let distil_bert_model = DistilBertModelMaskedLM::new(vs.root(), &config);
 vs.load(weights_path)?;
 // Define input
@@ -140,7 +140,7 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
 let mut config = DistilBertConfig::from_file(config_path);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let distil_bert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
+let distil_bert_model = DistilBertForQuestionAnswering::new(vs.root(), &config);
 // Define input
 let input = [
@@ -211,7 +211,7 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
 dummy_label_mapping.insert(2, String::from("PER"));
 dummy_label_mapping.insert(3, String::from("ORG"));
 config.id2label = Some(dummy_label_mapping);
-let distil_bert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
+let distil_bert_model = DistilBertForTokenClassification::new(vs.root(), &config);
 // Define input
 let input = [

View File

@@ -37,7 +37,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
 false,
 )?;
 let config = Gpt2Config::from_file(config_path);
-let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
+let gpt2_model = GPT2LMHeadModel::new(vs.root(), &config);
 vs.load(weights_path)?;
 // Define input

View File

@@ -32,7 +32,7 @@ fn electra_masked_lm() -> anyhow::Result<()> {
 let mut config = ElectraConfig::from_file(config_path);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
+let electra_model = ElectraForMaskedLM::new(vs.root(), &config);
 vs.load(weights_path)?;
 // Define input
@@ -114,7 +114,7 @@ fn electra_discriminator() -> anyhow::Result<()> {
 let tokenizer: BertTokenizer =
 BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
 let config = ElectraConfig::from_file(config_path);
-let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
+let electra_model = ElectraDiscriminator::new(vs.root(), &config);
 vs.load(weights_path)?;
 // Define input

View File

@@ -30,7 +30,7 @@ fn fnet_masked_lm() -> anyhow::Result<()> {
 let tokenizer: FNetTokenizer =
 FNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, false)?;
 let config = FNetConfig::from_file(config_path);
-let fnet_model = FNetForMaskedLM::new(&vs.root(), &config);
+let fnet_model = FNetForMaskedLM::new(vs.root(), &config);
 vs.load(weights_path)?;
 // Define input
@@ -138,7 +138,7 @@ fn fnet_for_multiple_choice() -> anyhow::Result<()> {
 let mut config = FNetConfig::from_file(config_path);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let fnet_model = FNetForMultipleChoice::new(&vs.root(), &config);
+let fnet_model = FNetForMultipleChoice::new(vs.root(), &config);
 // Define input
 let input = [
@@ -201,7 +201,7 @@ fn fnet_for_token_classification() -> anyhow::Result<()> {
 dummy_label_mapping.insert(3, String::from("ORG"));
 config.id2label = Some(dummy_label_mapping);
 config.output_hidden_states = Some(true);
-let fnet_model = FNetForTokenClassification::new(&vs.root(), &config);
+let fnet_model = FNetForTokenClassification::new(vs.root(), &config);
 // Define input
 let input = [
@@ -256,7 +256,7 @@ fn fnet_for_question_answering() -> anyhow::Result<()> {
 FNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, false)?;
 let mut config = FNetConfig::from_file(config_path);
 config.output_hidden_states = Some(true);
-let fnet_model = FNetForQuestionAnswering::new(&vs.root(), &config);
+let fnet_model = FNetForQuestionAnswering::new(vs.root(), &config);
 // Define input
 let input = [

View File

@@ -35,7 +35,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
 false,
 )?;
 let config = Gpt2Config::from_file(config_path);
-let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
+let gpt2_model = GPT2LMHeadModel::new(vs.root(), &config);
 vs.load(weights_path)?;
 // Define input

View File

@@ -40,7 +40,7 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
 let mut config = GptNeoConfig::from_file(config_path);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let gpt_neo_model = GptNeoForCausalLM::new(&vs.root(), &config)?;
+let gpt_neo_model = GptNeoForCausalLM::new(vs.root(), &config)?;
 vs.load(weights_path)?;
 // Define input

View File

@@ -197,7 +197,7 @@ fn longformer_for_sequence_classification() -> anyhow::Result<()> {
 dummy_label_mapping.insert(1, String::from("Negative"));
 dummy_label_mapping.insert(3, String::from("Neutral"));
 config.id2label = Some(dummy_label_mapping);
-let model = LongformerForSequenceClassification::new(&vs.root(), &config);
+let model = LongformerForSequenceClassification::new(vs.root(), &config);
 // Define input
 let input = ["Very positive sentence", "Second sentence input"];
@@ -258,7 +258,7 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> {
 false,
 )?;
 let config = LongformerConfig::from_file(config_path);
-let model = LongformerForMultipleChoice::new(&vs.root(), &config);
+let model = LongformerForMultipleChoice::new(vs.root(), &config);
 // Define input
 let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
@@ -337,7 +337,7 @@ fn longformer_for_token_classification() -> anyhow::Result<()> {
 dummy_label_mapping.insert(2, String::from("PER"));
 dummy_label_mapping.insert(3, String::from("ORG"));
 config.id2label = Some(dummy_label_mapping);
-let model = LongformerForTokenClassification::new(&vs.root(), &config);
+let model = LongformerForTokenClassification::new(vs.root(), &config);
 // Define input
 let inputs = ["Where's Paris?", "In Kentucky, United States"];

View File

@@ -35,7 +35,7 @@ fn mobilebert_masked_model() -> anyhow::Result<()> {
 let mut config = MobileBertConfig::from_file(config_path);
 config.output_attentions = Some(true);
 config.output_hidden_states = Some(true);
-let mobilebert_model = MobileBertForMaskedLM::new(&vs.root(), &config);
+let mobilebert_model = MobileBertForMaskedLM::new(vs.root(), &config);
 vs.load(weights_path)?;
 // Define input
@@ -130,7 +130,7 @@ fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
 dummy_label_mapping.insert(1, String::from("Negative"));
 dummy_label_mapping.insert(3, String::from("Neutral"));
 config.id2label = Some(dummy_label_mapping);
-let model = MobileBertForSequenceClassification::new(&vs.root(), &config);
+let model = MobileBertForSequenceClassification::new(vs.root(), &config);
 // Define input
 let input = ["Very positive sentence", "Second sentence input"];
@@ -176,7 +176,7 @@ fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
 let vs = nn::VarStore::new(device);
 let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
 let config = MobileBertConfig::from_file(config_path);
-let model = MobileBertForMultipleChoice::new(&vs.root(), &config);
+let model = MobileBertForMultipleChoice::new(vs.root(), &config);
 // Define input
 let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
@@ -240,7 +240,7 @@ fn mobilebert_for_token_classification() -> anyhow::Result<()> {
 dummy_label_mapping.insert(2, String::from("PER"));
 dummy_label_mapping.insert(3, String::from("ORG"));
 config.id2label = Some(dummy_label_mapping);
-let model = MobileBertForTokenClassification::new(&vs.root(), &config);
+let model = MobileBertForTokenClassification::new(vs.root(), &config);
 // Define input
 let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -287,7 +287,7 @@ fn mobilebert_for_question_answering() -> anyhow::Result<()> {
 let vs = nn::VarStore::new(device);
 let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
 let config = MobileBertConfig::from_file(config_path);
-let model = MobileBertForQuestionAnswering::new(&vs.root(), &config);
+let model = MobileBertForQuestionAnswering::new(vs.root(), &config);
 // Define input
 let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];

View File

@@ -39,7 +39,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
         true,
     )?;
     let config = OpenAiGptConfig::from_file(config_path);
-    let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
+    let openai_gpt = OpenAIGPTLMHeadModel::new(vs.root(), &config);
     vs.load(weights_path)?;
     // Define input

View File

@@ -98,7 +98,7 @@ fn reformer_for_sequence_classification() -> anyhow::Result<()> {
     config.id2label = Some(dummy_label_mapping);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let reformer_model = ReformerForSequenceClassification::new(&vs.root(), &config)?;
+    let reformer_model = ReformerForSequenceClassification::new(vs.root(), &config)?;
     // Define input
     let input = [
@@ -159,7 +159,7 @@ fn reformer_for_question_answering() -> anyhow::Result<()> {
     let mut config = ReformerConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let reformer_model = ReformerForQuestionAnswering::new(&vs.root(), &config)?;
+    let reformer_model = ReformerForQuestionAnswering::new(vs.root(), &config)?;
     // Define input
     let input = [

View File

@@ -41,7 +41,7 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
         false,
     )?;
     let config = RobertaConfig::from_file(config_path);
-    let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
+    let roberta_model = RobertaForMaskedLM::new(vs.root(), &config);
     vs.load(weights_path)?;
     // Define input
@@ -136,7 +136,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
     config.id2label = Some(dummy_label_mapping);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
+    let roberta_model = RobertaForSequenceClassification::new(vs.root(), &config);
     // Define input
     let input = [
@@ -201,7 +201,7 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
     let mut config = RobertaConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
+    let roberta_model = RobertaForMultipleChoice::new(vs.root(), &config);
     // Define input
     let input = [
@@ -273,7 +273,7 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {
     config.id2label = Some(dummy_label_mapping);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
+    let roberta_model = RobertaForTokenClassification::new(vs.root(), &config);
     // Define input
     let input = [

View File

@@ -141,7 +141,7 @@ fn xlnet_lm_model() -> anyhow::Result<()> {
     let tokenizer: XLNetTokenizer =
         XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, true)?;
     let config = XLNetConfig::from_file(config_path);
-    let xlnet_model = XLNetLMHeadModel::new(&vs.root(), &config);
+    let xlnet_model = XLNetLMHeadModel::new(vs.root(), &config);
     vs.load(weights_path)?;
     // Define input
@@ -257,7 +257,7 @@ fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
     config.id2label = Some(dummy_label_mapping);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let xlnet_model = XLNetForSequenceClassification::new(&vs.root(), &config)?;
+    let xlnet_model = XLNetForSequenceClassification::new(vs.root(), &config)?;
     // Define input
     let input = ["Very positive sentence", "Second sentence input"];
@@ -322,7 +322,7 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
     let vs = nn::VarStore::new(device);
     let tokenizer = XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
     let config = XLNetConfig::from_file(config_path);
-    let xlnet_model = XLNetForMultipleChoice::new(&vs.root(), &config)?;
+    let xlnet_model = XLNetForMultipleChoice::new(vs.root(), &config)?;
     // Define input
     let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
@@ -396,7 +396,7 @@ fn xlnet_for_token_classification() -> anyhow::Result<()> {
     dummy_label_mapping.insert(2, String::from("PER"));
     dummy_label_mapping.insert(3, String::from("ORG"));
     config.id2label = Some(dummy_label_mapping);
-    let xlnet_model = XLNetForTokenClassification::new(&vs.root(), &config)?;
+    let xlnet_model = XLNetForTokenClassification::new(vs.root(), &config)?;
     // Define input
     let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -453,7 +453,7 @@ fn xlnet_for_question_answering() -> anyhow::Result<()> {
     let vs = nn::VarStore::new(device);
     let tokenizer = XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
     let config = XLNetConfig::from_file(config_path);
-    let xlnet_model = XLNetForQuestionAnswering::new(&vs.root(), &config)?;
+    let xlnet_model = XLNetForQuestionAnswering::new(vs.root(), &config)?;
     // Define input
     let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];