diff --git a/Cargo.toml b/Cargo.toml
index fa85db3..25d6bd9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -70,7 +70,7 @@ features = ["doc-only"]
 
 [dependencies]
 rust_tokenizers = "~7.0.2"
-tch = "~0.9.0"
+tch = "~0.10.1"
 serde_json = "1.0.82"
 serde = { version = "1.0.140", features = ["derive"] }
 ordered-float = "3.0.0"
@@ -88,6 +88,6 @@ anyhow = "1.0.58"
 csv = "1.1.6"
 criterion = "0.3.6"
 tokio = { version = "1.20.0", features = ["sync", "rt-multi-thread", "macros"] }
-torch-sys = "0.9.0"
+torch-sys = "0.10.0"
 tempfile = "3.3.0"
 itertools = "0.10.3"
diff --git a/README.md b/README.md
index a860c02..4b41100 100644
--- a/README.md
+++ b/README.md
@@ -76,8 +76,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
 ### Manual installation (recommended)
 
-1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.13.0`: if this version is no longer available on the "get started" page,
-the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip` for a Linux version with CUDA11.
+1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.13.1`: if this version is no longer available on the "get started" page,
+the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.1%2Bcu117.zip` for a Linux version with CUDA11.
 **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
 2. Extract the library to a location of your choice
 3. Set the following environment variables
 ##### Linux:
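Because this bump moves the required `libtorch` from v1.13.0 to v1.13.1, a stale `LIBTORCH` environment variable is the most likely breakage after updating. A quick sanity check (a sketch, assuming a scratch binary crate depending on the updated `tch`): any trivial tensor program will fail to link or to start if the pointed-to libtorch does not match.

```rust
use tch::Tensor;

fn main() {
    // Linking and loading succeed only if LIBTORCH points at a libtorch
    // compatible with tch ~0.10 (v1.13.1); a mismatch fails before main runs.
    let t = Tensor::of_slice(&[1i64, 2, 3]);
    t.print();
}
```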
diff --git a/examples/natural_language_inference_deberta.rs b/examples/natural_language_inference_deberta.rs
index d63fa9e..e779a07 100644
--- a/examples/natural_language_inference_deberta.rs
+++ b/examples/natural_language_inference_deberta.rs
@@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
         false,
     )?;
     let config = DebertaConfig::from_file(config_path);
-    let model = DebertaForSequenceClassification::new(&vs.root(), &config);
+    let model = DebertaForSequenceClassification::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs
index edb4f69..2b9eb32 100644
--- a/src/bart/bart_model.rs
+++ b/src/bart/bart_model.rs
@@ -1101,7 +1101,7 @@ impl BartGenerator {
         generate_config.validate();
         let mut var_store = nn::VarStore::new(device);
         let config = BartConfig::from_file(config_path);
-        let model = BartForConditionalGeneration::new(&var_store.root(), &config);
+        let model = BartForConditionalGeneration::new(var_store.root(), &config);
         var_store.load(weights_path)?;
 
         let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -1131,7 +1131,7 @@ impl BartGenerator {
     }
 
     fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
-        let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
+        let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
             .filter(|pos| !token_ids.contains(pos))
             .collect();
         let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
@@ -1337,6 +1337,6 @@ mod test {
         let vs = tch::nn::VarStore::new(device);
         let config = BartConfig::from_file(config_path);
 
-        let _: Box<dyn Send> = Box::new(BartModel::new(&vs.root(), &config));
+        let _: Box<dyn Send> = Box::new(BartModel::new(vs.root(), &config));
     }
 }
diff --git a/src/bert/bert_model.rs b/src/bert/bert_model.rs
index 6868c1b..6b001bf 100644
--- a/src/bert/bert_model.rs
+++ b/src/bert/bert_model.rs
@@ -24,7 +24,7 @@ use crate::{Config, RustBertError};
 use serde::{Deserialize, Serialize};
 use std::borrow::Borrow;
 use std::collections::HashMap;
-use tch::nn::Init;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::{nn, Kind, Tensor};
 
 /// # BERT Pretrained model weight files
@@ -507,7 +507,7 @@ impl BertLMPredictionHead {
             config.vocab_size,
             Default::default(),
         );
-        let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
+        let bias = p.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
 
         BertLMPredictionHead {
             transform,
@@ -1301,9 +1301,9 @@ mod test {
         // Set-up masked LM model
         let device = Device::cuda_if_available();
-        let vs = tch::nn::VarStore::new(device);
+        let vs = nn::VarStore::new(device);
         let config = BertConfig::from_file(config_path);
 
-        let _: Box<dyn Send> = Box::new(BertModel::<BertEmbeddings>::new(&vs.root(), &config));
+        let _: Box<dyn Send> = Box::new(BertModel::<BertEmbeddings>::new(vs.root(), &config));
     }
 }
diff --git a/src/common/linear.rs b/src/common/linear.rs
index 48ad6c0..83cac42 100644
--- a/src/common/linear.rs
+++ b/src/common/linear.rs
@@ -11,6 +11,7 @@
 // limitations under the License.
 
 use std::borrow::Borrow;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::nn::{Init, Module, Path};
 use tch::Tensor;
 
@@ -22,7 +23,7 @@ pub struct LinearNoBiasConfig {
 impl Default for LinearNoBiasConfig {
     fn default() -> Self {
         LinearNoBiasConfig {
-            ws_init: Init::KaimingUniform,
+            ws_init: DEFAULT_KAIMING_UNIFORM,
         }
     }
 }
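The `Init::KaimingUniform` → `DEFAULT_KAIMING_UNIFORM` substitutions here and in the model files below follow a `tch` 0.10 API change: the old unit variant was replaced by a configurable Kaiming variant, with `nn::init::DEFAULT_KAIMING_UNIFORM` provided as a drop-in constant reproducing the previous default. A minimal sketch of the new spelling (assuming `tch = "~0.10.1"`):

```rust
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::{nn, Device};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    // `DEFAULT_KAIMING_UNIFORM` stands in for the removed `Init::KaimingUniform`
    // unit variant; the underlying initialization is unchanged.
    let bias = vs.root().var("bias", &[128], DEFAULT_KAIMING_UNIFORM);
    println!("{:?}", bias.size());
}
```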
diff --git a/src/fnet/fnet_model.rs b/src/fnet/fnet_model.rs
index 063c3d6..46c137b 100644
--- a/src/fnet/fnet_model.rs
+++ b/src/fnet/fnet_model.rs
@@ -1070,6 +1070,6 @@ mod test {
         let vs = tch::nn::VarStore::new(device);
         let config = FNetConfig::from_file(config_path);
 
-        let _: Box<dyn Send> = Box::new(FNetModel::new(&vs.root(), &config, true));
+        let _: Box<dyn Send> = Box::new(FNetModel::new(vs.root(), &config, true));
     }
 }
diff --git a/src/gpt2/gpt2_model.rs b/src/gpt2/gpt2_model.rs
index cebdf09..cbf442d 100644
--- a/src/gpt2/gpt2_model.rs
+++ b/src/gpt2/gpt2_model.rs
@@ -742,7 +742,7 @@ impl GPT2Generator {
 
         let mut var_store = nn::VarStore::new(device);
         let config = Gpt2Config::from_file(config_path);
-        let model = GPT2LMHeadModel::new(&var_store.root(), &config);
+        let model = GPT2LMHeadModel::new(var_store.root(), &config);
         var_store.load(weights_path)?;
 
         let bos_token_id = tokenizer.get_bos_id();
diff --git a/src/gpt_neo/gpt_neo_model.rs b/src/gpt_neo/gpt_neo_model.rs
index 795e3bb..97f50fe 100644
--- a/src/gpt_neo/gpt_neo_model.rs
+++ b/src/gpt_neo/gpt_neo_model.rs
@@ -716,7 +716,7 @@ impl GptNeoGenerator {
         generate_config.validate();
         let mut var_store = nn::VarStore::new(device);
         let config = GptNeoConfig::from_file(config_path);
-        let model = GptNeoForCausalLM::new(&var_store.root(), &config)?;
+        let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
         var_store.load(weights_path)?;
 
         let bos_token_id = tokenizer.get_bos_id();
diff --git a/src/m2m_100/embeddings.rs b/src/m2m_100/embeddings.rs
index 8531999..be46de6 100644
--- a/src/m2m_100/embeddings.rs
+++ b/src/m2m_100/embeddings.rs
@@ -71,7 +71,7 @@ impl SinusoidalPositionalEmbedding {
     ) -> Tensor {
         let half_dim = embedding_dim / 2;
 
-        let emb = -(10000f64.ln() as f64) / ((half_dim - 1) as f64);
+        let emb = -(10000f64.ln()) / ((half_dim - 1) as f64);
         let emb = (Tensor::arange(half_dim, (Kind::Float, device)) * emb).exp();
         let emb = Tensor::arange(num_embeddings, (Kind::Float, device)).unsqueeze(1)
             * emb.unsqueeze(0);
diff --git a/src/m2m_100/m2m_100_model.rs b/src/m2m_100/m2m_100_model.rs
index f9d8e24..0bffbe8 100644
--- a/src/m2m_100/m2m_100_model.rs
+++ b/src/m2m_100/m2m_100_model.rs
@@ -651,7 +651,7 @@ impl M2M100Generator {
 
         let mut var_store = nn::VarStore::new(device);
         let config = M2M100Config::from_file(config_path);
-        let model = M2M100ForConditionalGeneration::new(&var_store.root(), &config);
+        let model = M2M100ForConditionalGeneration::new(var_store.root(), &config);
         var_store.load(weights_path)?;
 
         let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -681,7 +681,7 @@ impl M2M100Generator {
     }
 
     fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
-        let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
+        let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
             .filter(|pos| !token_ids.contains(pos))
             .collect();
         let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
@@ -887,6 +887,6 @@ mod test {
         let vs = tch::nn::VarStore::new(device);
         let config = M2M100Config::from_file(config_path);
 
-        let _: Box<dyn Send> = Box::new(M2M100Model::new(&vs.root(), &config));
+        let _: Box<dyn Send> = Box::new(M2M100Model::new(vs.root(), &config));
     }
 }
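The recurring `&var_store.root()` → `var_store.root()` edits above and below are mechanical: `VarStore::root()` returns an owned `nn::Path`, and the model constructors accept any `impl Borrow<nn::Path>`, so the extra borrow was never needed and trips clippy's `needless_borrow` lint after the `tch` bump. A minimal sketch of the constructor pattern (with a made-up `TinyHead` model, not an actual rust-bert type):

```rust
use std::borrow::Borrow;
use tch::{nn, Device, Kind, Tensor};

// Constructor bound mirroring the pattern used across this crate: accepting
// `impl Borrow<nn::Path>` means both `vs.root()` (owned) and `&vs.root()`
// (borrowed) compile, which is why dropping the `&` is a pure cleanup.
struct TinyHead {
    linear: nn::Linear,
}

impl TinyHead {
    fn new<'p, P: Borrow<nn::Path<'p>>>(p: P, in_dim: i64, out_dim: i64) -> Self {
        let p = p.borrow();
        TinyHead {
            linear: nn::linear(p / "linear", in_dim, out_dim, Default::default()),
        }
    }
}

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    let head = TinyHead::new(vs.root(), 8, 2); // no `&` needed
    let output = Tensor::zeros(&[1, 8], (Kind::Float, Device::Cpu)).apply(&head.linear);
    println!("{:?}", output.size());
}
```

Both the owned and the borrowed forms compile under this bound; the patch simply standardizes on the owned one.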
diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs
index 6368d01..a2ea85c 100644
--- a/src/marian/marian_model.rs
+++ b/src/marian/marian_model.rs
@@ -872,7 +872,7 @@ impl MarianGenerator {
 
         let mut var_store = nn::VarStore::new(device);
         let config = BartConfig::from_file(config_path);
-        let model = MarianForConditionalGeneration::new(&var_store.root(), &config);
+        let model = MarianForConditionalGeneration::new(var_store.root(), &config);
         var_store.load(weights_path)?;
 
         let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -904,7 +904,7 @@ impl MarianGenerator {
     }
 
     fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
-        let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
+        let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
             .filter(|pos| !token_ids.contains(pos))
             .collect();
         let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
diff --git a/src/mbart/mbart_model.rs b/src/mbart/mbart_model.rs
index 5a7cc72..f09c706 100644
--- a/src/mbart/mbart_model.rs
+++ b/src/mbart/mbart_model.rs
@@ -900,7 +900,7 @@ impl MBartGenerator {
 
         let mut var_store = nn::VarStore::new(device);
         let config = MBartConfig::from_file(config_path);
-        let model = MBartForConditionalGeneration::new(&var_store.root(), &config);
+        let model = MBartForConditionalGeneration::new(var_store.root(), &config);
         var_store.load(weights_path)?;
 
         let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -930,7 +930,7 @@ impl MBartGenerator {
     }
 
     fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
-        let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
+        let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
             .filter(|pos| !token_ids.contains(pos))
             .collect();
         let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
@@ -1136,6 +1136,6 @@ mod test {
         let vs = tch::nn::VarStore::new(device);
         let config = MBartConfig::from_file(config_path);
 
-        let _: Box<dyn Send> = Box::new(MBartModel::new(&vs.root(), &config));
+        let _: Box<dyn Send> = Box::new(MBartModel::new(vs.root(), &config));
     }
 }
diff --git a/src/mobilebert/mobilebert_model.rs b/src/mobilebert/mobilebert_model.rs
index 23ea8d7..5cd836c 100644
--- a/src/mobilebert/mobilebert_model.rs
+++ b/src/mobilebert/mobilebert_model.rs
@@ -19,6 +19,7 @@ use crate::{Config, RustBertError};
 use serde::{Deserialize, Serialize};
 use std::borrow::Borrow;
 use std::collections::HashMap;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::nn::{Init, LayerNormConfig, Module};
 use tch::{nn, Kind, Tensor};
 
@@ -292,7 +293,7 @@ impl MobileBertLMPredictionHead {
                 config.hidden_size - config.embedding_size,
                 config.vocab_size,
             ],
-            Init::KaimingUniform,
+            DEFAULT_KAIMING_UNIFORM,
         );
         let bias = p.var("bias", &[config.vocab_size], Init::Const(0.0));
 
diff --git a/src/openai_gpt/openai_gpt_model.rs b/src/openai_gpt/openai_gpt_model.rs
index 038809c..52adde7 100644
--- a/src/openai_gpt/openai_gpt_model.rs
+++ b/src/openai_gpt/openai_gpt_model.rs
@@ -504,7 +504,7 @@ impl OpenAIGenerator {
 
         let mut var_store = nn::VarStore::new(device);
         let config = Gpt2Config::from_file(config_path);
-        let model = OpenAIGPTLMHeadModel::new(&var_store.root(), &config);
+        let model = OpenAIGPTLMHeadModel::new(var_store.root(), &config);
         var_store.load(weights_path)?;
 
         let bos_token_id = tokenizer.get_bos_id();
diff --git a/src/pegasus/pegasus_model.rs b/src/pegasus/pegasus_model.rs
index df1bbba..bed3108 100644
--- a/src/pegasus/pegasus_model.rs
+++ b/src/pegasus/pegasus_model.rs
@@ -624,7 +624,7 @@ impl PegasusConditionalGenerator {
         generate_config.validate();
         let mut var_store = nn::VarStore::new(device);
         let config = PegasusConfig::from_file(config_path);
-        let model = PegasusForConditionalGeneration::new(&var_store.root(), &config);
+        let model = PegasusForConditionalGeneration::new(var_store.root(), &config);
         var_store.load(weights_path)?;
 
         let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -654,7 +654,7 @@ impl PegasusConditionalGenerator {
     }
 
     fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
-        let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
+        let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
             .filter(|pos| !token_ids.contains(pos))
             .collect();
         let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs
index d9a4ee3..710365c 100644
--- a/src/pipelines/generation_utils.rs
+++ b/src/pipelines/generation_utils.rs
@@ -403,7 +403,7 @@ pub(crate) mod private_generation_utils {
         prev_output_tokens: &Tensor,
         repetition_penalty: f64,
     ) {
-        for i in 0..(batch_size * num_beams as i64) {
+        for i in 0..(batch_size * num_beams) {
             for token_position in 0..prev_output_tokens.get(i).size()[0] {
                 let token = prev_output_tokens.get(i).int64_value(&[token_position]);
                 let updated_value = &next_token_logits.double_value(&[i, token]);
@@ -826,8 +826,8 @@ pub(crate) mod private_generation_utils {
                     if gen_opt.no_repeat_ngram_size > 0 {
                         let banned_tokens = self.get_banned_tokens(
                             &input_ids,
-                            gen_opt.no_repeat_ngram_size as i64,
-                            current_length as i64,
+                            gen_opt.no_repeat_ngram_size,
+                            current_length,
                         );
                         for (batch_index, index_banned_token) in
                             (0..banned_tokens.len() as i64).zip(banned_tokens)
@@ -875,7 +875,7 @@ pub(crate) mod private_generation_utils {
                         }
                         self.top_k_top_p_filtering(
                             &mut next_token_logits,
-                            gen_opt.top_k as i64,
+                            gen_opt.top_k,
                             gen_opt.top_p,
                             1,
                         );
@@ -915,7 +915,7 @@ pub(crate) mod private_generation_utils {
                             &sentence_with_eos
                                 .to_kind(Kind::Bool)
                                 .to_device(sentence_lengths.device()),
-                            current_length as i64 + 1,
+                            current_length + 1,
                         );
                         unfinished_sentences = -unfinished_sentences * (sentence_with_eos - 1);
                     }
@@ -943,7 +943,7 @@ pub(crate) mod private_generation_utils {
                             &unfinished_sentences
                                 .to_kind(Kind::Bool)
                                 .to_device(sentence_lengths.device()),
-                            current_length as i64,
+                            current_length,
                         );
                         break;
                     }
@@ -1927,10 +1927,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
         let batch_size = *input_ids.size().first().unwrap();
 
         let (effective_batch_size, effective_batch_mult) = match do_sample {
-            true => (
-                batch_size * num_return_sequences as i64,
-                num_return_sequences as i64,
-            ),
+            true => (batch_size * num_return_sequences, num_return_sequences),
             false => (batch_size, 1),
         };
 
@@ -1946,7 +1943,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
             let encoder_outputs = self.encode(&input_ids, Some(&attention_mask)).unwrap();
             let expanded_batch_indices = Tensor::arange(batch_size, (Int64, input_ids.device()))
                 .view((-1, 1))
-                .repeat(&[1, num_beams as i64 * effective_batch_mult])
+                .repeat(&[1, num_beams * effective_batch_mult])
                 .view(-1);
             Some(encoder_outputs.index_select(0, &expanded_batch_indices))
         } else {
@@ -1959,19 +1956,19 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
                 input_ids
                     .unsqueeze(1)
                     .expand(
-                        &[batch_size, effective_batch_mult * num_beams as i64, cur_len],
+                        &[batch_size, effective_batch_mult * num_beams, cur_len],
                         true,
                     )
                     .contiguous()
-                    .view((effective_batch_size * num_beams as i64, cur_len)),
+                    .view((effective_batch_size * num_beams, cur_len)),
                 attention_mask
                     .unsqueeze(1)
                     .expand(
-                        &[batch_size, effective_batch_mult * num_beams as i64, cur_len],
+                        &[batch_size, effective_batch_mult * num_beams, cur_len],
                         true,
                     )
                     .contiguous()
-                    .view((effective_batch_size * num_beams as i64, cur_len)),
+                    .view((effective_batch_size * num_beams, cur_len)),
             )
         } else {
             (input_ids, attention_mask)
@@ -1982,7 +1979,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
                     .expect("decoder start id must be specified for encoder decoders")
             });
             let input_ids = Tensor::full(
-                &[effective_batch_size * num_beams as i64, 1],
+                &[effective_batch_size * num_beams, 1],
                 decoder_start_token_id,
                 (Int64, input_ids.device()),
             );
@@ -1990,15 +1987,11 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
                 attention_mask
                     .unsqueeze(1)
                     .expand(
-                        &[
-                            batch_size,
-                            effective_batch_mult * num_beams as i64,
-                            input_ids_len,
-                        ],
+                        &[batch_size, effective_batch_mult * num_beams, input_ids_len],
                         true,
                     )
                     .contiguous()
-                    .view((effective_batch_size * num_beams as i64, input_ids_len))
+                    .view((effective_batch_size * num_beams, input_ids_len))
             } else {
                 attention_mask
             };
diff --git a/src/pipelines/keywords_extraction/pipeline.rs b/src/pipelines/keywords_extraction/pipeline.rs
index 51ccd86..3f74825 100644
--- a/src/pipelines/keywords_extraction/pipeline.rs
+++ b/src/pipelines/keywords_extraction/pipeline.rs
@@ -144,13 +144,13 @@ impl<'a> KeywordExtractionModel<'a> {
         config: KeywordExtractionConfig<'a>,
     ) -> Result<KeywordExtractionModel<'a>, RustBertError> {
         let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
-            &config
+            config
                 .sentence_embeddings_config
                 .tokenizer_config_resource
                 .get_local_path()?,
         );
         let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
-            &config
+            config
                 .sentence_embeddings_config
                 .sentence_bert_config_resource
                 .get_local_path()?,
diff --git a/src/pipelines/masked_language.rs b/src/pipelines/masked_language.rs
index 2ffad11..600446d 100644
--- a/src/pipelines/masked_language.rs
+++ b/src/pipelines/masked_language.rs
@@ -407,7 +407,7 @@ impl MaskedLanguageModel {
             .unwrap_or(usize::MAX);
 
         let language_encode =
-            MaskedLanguageOption::new(config.model_type, &var_store.root(), &model_config)?;
+            MaskedLanguageOption::new(config.model_type, var_store.root(), &model_config)?;
         var_store.load(weights_path)?;
         let mask_token = config.mask_token;
         Ok(MaskedLanguageModel {
diff --git a/src/pipelines/question_answering.rs b/src/pipelines/question_answering.rs
index 742501e..d6b6c30 100644
--- a/src/pipelines/question_answering.rs
+++ b/src/pipelines/question_answering.rs
@@ -620,7 +620,7 @@ impl QuestionAnsweringModel {
 
         let qa_model = QuestionAnsweringOption::new(
             question_answering_config.model_type,
-            &var_store.root(),
+            var_store.root(),
             &model_config,
         )?;
 
@@ -878,7 +878,7 @@ impl QuestionAnsweringModel {
                 max_seq_length - sequence_pair_added_tokens - encoded_query.ids.len();
 
             let mut start_token = 0_usize;
-            while (spans.len() * doc_stride as usize) < encoded_context.ids.len() {
+            while (spans.len() * doc_stride) < encoded_context.ids.len() {
                 let end_token = min(start_token + max_context_length, encoded_context.ids.len());
                 let sub_encoded_context = TokenIdsWithOffsets {
                     ids: encoded_context.ids[start_token..end_token].to_vec(),
diff --git a/src/pipelines/sentence_embeddings/layers.rs b/src/pipelines/sentence_embeddings/layers.rs
index fffdc7a..fa63282 100644
--- a/src/pipelines/sentence_embeddings/layers.rs
+++ b/src/pipelines/sentence_embeddings/layers.rs
@@ -130,7 +130,7 @@ impl Dense {
             bias: dense_conf.bias,
         };
         let linear = nn::linear(
-            &vs_dense.root(),
+            vs_dense.root(),
             dense_conf.in_features,
             dense_conf.out_features,
             linear_conf,
diff --git a/src/pipelines/sentence_embeddings/pipeline.rs b/src/pipelines/sentence_embeddings/pipeline.rs
index 29e2e52..fa4d35b 100644
--- a/src/pipelines/sentence_embeddings/pipeline.rs
+++ b/src/pipelines/sentence_embeddings/pipeline.rs
@@ -230,11 +230,8 @@ impl SentenceEmbeddingsModel {
             transformer_type,
             transformer_config_resource.get_local_path()?,
         );
-        let transformer = SentenceEmbeddingsOption::new(
-            transformer_type,
-            &var_store.root(),
-            &transformer_config,
-        )?;
+        let transformer =
+            SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
         var_store.load(transformer_weights_resource.get_local_path()?)?;
 
         // Setup pooling layer
@@ -310,7 +307,7 @@ impl SentenceEmbeddingsModel {
                 Tensor::of_slice(
                     &input
                         .iter()
-                        .map(|&e| if e == pad_token_id { 0_i64 } else { 1_i64 })
+                        .map(|&e| i64::from(e != pad_token_id))
                        .collect::<Vec<_>>(),
                 )
             })
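The last hunk above replaces an `if`/`else` with `i64::from(e != pad_token_id)` when building attention masks: real tokens map to 1 and padding to 0. A standalone sketch of the idiom (`pad_token_id = 0` and the token ids are stand-in values, not taken from any model config):

```rust
use tch::Tensor;

fn main() {
    let pad_token_id = 0i64;
    let input: Vec<i64> = vec![101, 2023, 2003, 102, 0, 0];
    // `bool as i64` via `i64::from`: non-padding positions become 1, padding 0.
    let mask = Tensor::of_slice(
        &input
            .iter()
            .map(|&e| i64::from(e != pad_token_id))
            .collect::<Vec<_>>(),
    );
    mask.print(); // prints a tensor holding [1, 1, 1, 1, 0, 0]
}
```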
diff --git a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs
index 700fb27..db241f3 100644
--- a/src/pipelines/sequence_classification.rs
+++ b/src/pipelines/sequence_classification.rs
@@ -593,7 +593,7 @@ impl SequenceClassificationModel {
             .map(|v| v as usize)
             .unwrap_or(usize::MAX);
         let sequence_classifier =
-            SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config)?;
+            SequenceClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
         let label_mapping = model_config.get_label_mapping().clone();
         var_store.load(weights_path)?;
         Ok(SequenceClassificationModel {
diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs
index 59cd1e3..7766ceb 100644
--- a/src/pipelines/token_classification.rs
+++ b/src/pipelines/token_classification.rs
@@ -699,7 +699,7 @@ impl TokenClassificationModel {
             .map(|v| v as usize)
             .unwrap_or(usize::MAX);
         let token_sequence_classifier =
-            TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config)?;
+            TokenClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
         let label_mapping = model_config.get_label_mapping().clone();
         let batch_size = config.batch_size;
         var_store.load(weights_path)?;
@@ -749,7 +749,7 @@ impl TokenClassificationModel {
         let mut start_token = 0_usize;
         let total_length = encoded_input.ids.len();
 
-        while (spans.len() * doc_stride as usize) < encoded_input.ids.len() {
+        while (spans.len() * doc_stride) < encoded_input.ids.len() {
            let end_token = min(start_token + max_content_length, total_length);
            let sub_encoded_input = TokenIdsWithOffsets {
                ids: encoded_input.ids[start_token..end_token].to_vec(),
@@ -994,8 +994,8 @@ impl TokenClassificationModel {
         position_idx: i64,
         word_index: u16,
     ) -> Token {
-        let label_id = labels.int64_value(&[position_idx as i64]);
-        let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
+        let label_id = labels.int64_value(&[position_idx]);
+        let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
 
         let offsets = &sentence_tokens.offsets[position_idx as usize];
 
diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs
index c538324..0a69f40 100644
--- a/src/pipelines/zero_shot_classification.rs
+++ b/src/pipelines/zero_shot_classification.rs
@@ -575,7 +575,7 @@ impl ZeroShotClassificationModel {
         let mut var_store = VarStore::new(device);
         let model_config = ConfigOption::from_file(config.model_type, config_path);
         let zero_shot_classifier =
-            ZeroShotClassificationOption::new(config.model_type, &var_store.root(), &model_config)?;
+            ZeroShotClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
         var_store.load(weights_path)?;
         Ok(ZeroShotClassificationModel {
             tokenizer,
diff --git a/src/prophetnet/decoder.rs b/src/prophetnet/decoder.rs
index 434d4a0..6e83c12 100644
--- a/src/prophetnet/decoder.rs
+++ b/src/prophetnet/decoder.rs
@@ -21,7 +21,7 @@ use crate::prophetnet::embeddings::ProphetNetPositionalEmbeddings;
 use crate::prophetnet::ProphetNetConfig;
 use crate::RustBertError;
 use std::borrow::{Borrow, BorrowMut};
-use tch::nn::Init;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::{nn, Device, Kind, Tensor};
 
 fn ngram_attention_bias(sequence_length: i64, ngram: i64, device: Device, kind: Kind) -> Tensor {
@@ -210,7 +210,7 @@ impl ProphetNetDecoder {
         let ngram_embeddings = p_ngram_embedding.var(
             "weight",
             &[config.ngram, config.hidden_size],
-            Init::KaimingUniform,
+            DEFAULT_KAIMING_UNIFORM,
         );
 
         let output_attentions = config.output_attentions.unwrap_or(false);
diff --git a/src/prophetnet/prophetnet_model.rs b/src/prophetnet/prophetnet_model.rs
index c035c64..cb3efd2 100644
--- a/src/prophetnet/prophetnet_model.rs
+++ b/src/prophetnet/prophetnet_model.rs
@@ -965,7 +965,7 @@ impl ProphetNetConditionalGenerator {
         generate_config.validate();
         let mut var_store = nn::VarStore::new(device);
         let config = ProphetNetConfig::from_file(config_path);
-        let model = ProphetNetForConditionalGeneration::new(&var_store.root(), &config)?;
+        let model = ProphetNetForConditionalGeneration::new(var_store.root(), &config)?;
         var_store.load(weights_path)?;
 
         let bos_token_id = Some(config.bos_token_id);
diff --git a/src/reformer/reformer_model.rs b/src/reformer/reformer_model.rs
index fb0b304..c367254 100644
--- a/src/reformer/reformer_model.rs
+++ b/src/reformer/reformer_model.rs
@@ -354,7 +354,7 @@ impl ReformerModel {
         let must_pad_to_match_chunk_length = (input_shape.last().unwrap()
             % self.least_common_mult_chunk_length
             != 0)
-            & (*input_shape.last().unwrap() as i64 > self.min_chunk_length)
+            & (*input_shape.last().unwrap() > self.min_chunk_length)
             & old_layer_states.is_none();
 
         let start_idx_pos_encodings = if let Some(layer_states) = &old_layer_states {
@@ -1091,7 +1091,7 @@ impl ReformerGenerator {
         generate_config.validate();
         let mut var_store = nn::VarStore::new(device);
         let config = ReformerConfig::from_file(config_path);
-        let model = ReformerModelWithLMHead::new(&var_store.root(), &config)?;
+        let model = ReformerModelWithLMHead::new(var_store.root(), &config)?;
         var_store.load(weights_path)?;
 
         let bos_token_id = tokenizer.get_bos_id();
diff --git a/src/roberta/roberta_model.rs b/src/roberta/roberta_model.rs
index b384665..a96d72e 100644
--- a/src/roberta/roberta_model.rs
+++ b/src/roberta/roberta_model.rs
@@ -17,7 +17,7 @@ use crate::common::dropout::Dropout;
 use crate::common::linear::{linear_no_bias, LinearNoBias};
 use crate::roberta::embeddings::RobertaEmbeddings;
 use std::borrow::Borrow;
-use tch::nn::Init;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::{nn, Tensor};
 
 /// # RoBERTa Pretrained model weight files
@@ -218,7 +218,7 @@ impl RobertaLMHead {
             config.vocab_size,
             Default::default(),
         );
-        let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
+        let bias = p.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
 
         RobertaLMHead {
             dense,
diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs
index d33b3b9..3e40581 100644
--- a/src/t5/t5_model.rs
+++ b/src/t5/t5_model.rs
@@ -881,7 +881,7 @@ impl T5Generator {
         let mut var_store = nn::VarStore::new(device);
         let config = T5Config::from_file(config_path);
-        let model = T5ForConditionalGeneration::new(&var_store.root(), &config);
+        let model = T5ForConditionalGeneration::new(var_store.root(), &config);
         var_store.load(weights_path)?;
 
         let bos_token_id = Some(config.bos_token_id.unwrap_or(-1));
diff --git a/src/xlnet/attention.rs b/src/xlnet/attention.rs
index f5e4cba..4492a58 100644
--- a/src/xlnet/attention.rs
+++ b/src/xlnet/attention.rs
@@ -15,7 +15,7 @@ use crate::common::dropout::Dropout;
 use crate::xlnet::XLNetConfig;
 use std::borrow::Borrow;
-use tch::nn::Init;
+use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
 use tch::{nn, Kind, Tensor};
 
 #[derive(Debug)]
@@ -72,52 +72,52 @@ impl XLNetRelativeAttention {
         let query = p.var(
             "q",
             &[config.d_model, config.n_head, config.d_head],
-            Init::KaimingUniform,
+            DEFAULT_KAIMING_UNIFORM,
         );
         let key = p.var(
             "k",
             &[config.d_model, config.n_head, config.d_head],
-            Init::KaimingUniform,
+            DEFAULT_KAIMING_UNIFORM,
         );
         let value = p.var(
             "v",
             &[config.d_model, config.n_head, config.d_head],
-            Init::KaimingUniform,
+            DEFAULT_KAIMING_UNIFORM,
         );
         let output = p.var(
             "o",
             &[config.d_model, config.n_head, config.d_head],
-            Init::KaimingUniform,
+            DEFAULT_KAIMING_UNIFORM,
         );
         let pos = p.var(
             "r",
             &[config.d_model, config.n_head, config.d_head],
-            Init::KaimingUniform,
+            DEFAULT_KAIMING_UNIFORM,
         );
         let r_r_bias = p.var(
             "r_r_bias",
             &[config.n_head, config.d_head],
-            Init::KaimingUniform,
+            DEFAULT_KAIMING_UNIFORM,
         );
         let r_s_bias = p.var(
             "r_s_bias",
             &[config.n_head, config.d_head],
-            Init::KaimingUniform,
+            DEFAULT_KAIMING_UNIFORM,
         );
         let r_w_bias = p.var(
             "r_w_bias",
             &[config.n_head, config.d_head],
-            Init::KaimingUniform,
+            DEFAULT_KAIMING_UNIFORM,
         );
         let seg_embed = p.var(
             "seg_embed",
             &[2, config.n_head, config.d_head],
-            Init::KaimingUniform,
+            DEFAULT_KAIMING_UNIFORM,
         );
 
         let dropout = Dropout::new(config.dropout);
diff --git a/src/xlnet/xlnet_model.rs b/src/xlnet/xlnet_model.rs
index 38ab301..60a0d7b 100644
--- a/src/xlnet/xlnet_model.rs
+++ b/src/xlnet/xlnet_model.rs
@@ -1648,7 +1648,7 @@ impl XLNetGenerator {
 
         let mut var_store = nn::VarStore::new(device);
         let config = XLNetConfig::from_file(config_path);
-        let model = XLNetLMHeadModel::new(&var_store.root(), &config);
+        let model = XLNetLMHeadModel::new(var_store.root(), &config);
         var_store.load(weights_path)?;
 
         let bos_token_id = Some(config.bos_token_id);
diff --git a/tests/albert.rs b/tests/albert.rs
index 9d787ec..b95b040 100644
--- a/tests/albert.rs
+++ b/tests/albert.rs
@@ -35,7 +35,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
     let tokenizer: AlbertTokenizer =
         AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?;
     let config = AlbertConfig::from_file(config_path);
-    let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
+    let albert_model = AlbertForMaskedLM::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
@@ -109,7 +109,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
     config.id2label = Some(dummy_label_mapping);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let albert_model = AlbertForSequenceClassification::new(&vs.root(), &config);
+    let albert_model = AlbertForSequenceClassification::new(vs.root(), &config);
 
     // Define input
     let input = [
@@ -170,7 +170,7 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
     let mut config = AlbertConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let albert_model = AlbertForMultipleChoice::new(&vs.root(), &config);
+    let albert_model = AlbertForMultipleChoice::new(vs.root(), &config);
 
     // Define input
     let input = [
@@ -242,7 +242,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
     config.id2label = Some(dummy_label_mapping);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let albert_model = AlbertForTokenClassification::new(&vs.root(), &config);
+    let albert_model = AlbertForTokenClassification::new(vs.root(), &config);
 
     // Define input
     let input = [
@@ -303,7 +303,7 @@ fn albert_for_question_answering() -> anyhow::Result<()> {
     let mut config = AlbertConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let albert_model = AlbertForQuestionAnswering::new(&vs.root(), &config);
+    let albert_model = AlbertForQuestionAnswering::new(vs.root(), &config);
 
     // Define input
     let input = [
diff --git a/tests/bert.rs b/tests/bert.rs
index d790922..318d4bb 100644
--- a/tests/bert.rs
+++ b/tests/bert.rs
@@ -35,7 +35,7 @@ fn bert_masked_lm() -> anyhow::Result<()> {
     let tokenizer: BertTokenizer =
         BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
     let config = BertConfig::from_file(config_path);
-    let bert_model = BertForMaskedLM::new(&vs.root(), &config);
+    let bert_model = BertForMaskedLM::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
@@ -162,7 +162,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
     config.id2label = Some(dummy_label_mapping);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let bert_model = BertForSequenceClassification::new(&vs.root(), &config);
+    let bert_model = BertForSequenceClassification::new(vs.root(), &config);
 
     // Define input
     let input = [
@@ -219,7 +219,7 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
     let mut config = BertConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let bert_model = BertForMultipleChoice::new(&vs.root(), &config);
+    let bert_model = BertForMultipleChoice::new(vs.root(), &config);
 
     // Define input
     let input = [
@@ -283,7 +283,7 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
     config.id2label = Some(dummy_label_mapping);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let bert_model = BertForTokenClassification::new(&vs.root(), &config);
+    let bert_model = BertForTokenClassification::new(vs.root(), &config);
 
     // Define input
     let input = [
@@ -340,7 +340,7 @@ fn bert_for_question_answering() -> anyhow::Result<()> {
     let mut config = BertConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let bert_model = BertForQuestionAnswering::new(&vs.root(), &config);
+    let bert_model = BertForQuestionAnswering::new(vs.root(), &config);
 
     // Define input
     let input = [
diff --git a/tests/deberta.rs b/tests/deberta.rs
index 9cd8537..f696c14 100644
--- a/tests/deberta.rs
+++ b/tests/deberta.rs
@@ -41,7 +41,7 @@ fn deberta_natural_language_inference() -> anyhow::Result<()> {
         false,
     )?;
     let config = DebertaConfig::from_file(config_path);
-    let model = DebertaForSequenceClassification::new(&vs.root(), &config);
+    let model = DebertaForSequenceClassification::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
@@ -96,7 +96,7 @@ fn deberta_masked_lm() -> anyhow::Result<()> {
     let mut config = DebertaConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let deberta_model = DebertaForMaskedLM::new(&vs.root(), &config);
+    let deberta_model = DebertaForMaskedLM::new(vs.root(), &config);
 
     // Generate random input
     let input_tensor = Tensor::randint(42, &[32, 128], (Kind::Int64, device));
@@ -170,7 +170,7 @@ fn deberta_for_token_classification() -> anyhow::Result<()> {
     dummy_label_mapping.insert(2, String::from("PER"));
     dummy_label_mapping.insert(3, String::from("ORG"));
     config.id2label = Some(dummy_label_mapping);
-    let model = DebertaForTokenClassification::new(&vs.root(), &config);
+    let model = DebertaForTokenClassification::new(vs.root(), &config);
 
     // Define input
     let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -225,7 +225,7 @@ fn deberta_for_question_answering() -> anyhow::Result<()> {
         false,
     )?;
     let config = DebertaConfig::from_file(config_path);
-    let model = DebertaForQuestionAnswering::new(&vs.root(), &config);
+    let model = DebertaForQuestionAnswering::new(vs.root(), &config);
 
     // Define input
     let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];
diff --git a/tests/deberta_v2.rs b/tests/deberta_v2.rs
index 12ceedf..0a46ee9 100644
--- a/tests/deberta_v2.rs
+++ b/tests/deberta_v2.rs
@@ -22,7 +22,7 @@ fn deberta_v2_masked_lm() -> anyhow::Result<()> {
     let mut config = DebertaV2Config::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let deberta_model = DebertaV2ForMaskedLM::new(&vs.root(), &config);
+    let deberta_model = DebertaV2ForMaskedLM::new(vs.root(), &config);
 
     // Generate random input
     let input_tensor = Tensor::randint(42, &[32, 128], (Kind::Int64, device));
@@ -88,7 +88,7 @@ fn deberta_v2_for_sequence_classification() -> anyhow::Result<()> {
     dummy_label_mapping.insert(1, String::from("Neutral"));
     dummy_label_mapping.insert(2, String::from("Negative"));
     config.id2label = Some(dummy_label_mapping);
-    let model = DebertaV2ForSequenceClassification::new(&vs.root(), &config);
+    let model = DebertaV2ForSequenceClassification::new(vs.root(), &config);
 
     // Define input
     let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -142,7 +142,7 @@ fn deberta_v2_for_token_classification() -> anyhow::Result<()> {
     dummy_label_mapping.insert(2, String::from("PER"));
     dummy_label_mapping.insert(3, String::from("ORG"));
     config.id2label = Some(dummy_label_mapping);
-    let model = DebertaV2ForTokenClassification::new(&vs.root(), &config);
+    let model = DebertaV2ForTokenClassification::new(vs.root(), &config);
 
     // Define input
     let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -190,7 +190,7 @@ fn deberta_v2_for_question_answering() -> anyhow::Result<()> {
     let tokenizer =
         DeBERTaV2Tokenizer::from_file(vocab_path.to_str().unwrap(), false, false, false)?;
     let config = DebertaV2Config::from_file(config_path);
-    let model = DebertaV2ForQuestionAnswering::new(&vs.root(), &config);
+    let model = DebertaV2ForQuestionAnswering::new(vs.root(), &config);
 
     // Define input
     let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];
diff --git a/tests/distilbert.rs b/tests/distilbert.rs
index ae9c563..52d14bb 100644
--- a/tests/distilbert.rs
+++ b/tests/distilbert.rs
@@ -61,7 +61,7 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
     let tokenizer: BertTokenizer =
         BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
     let config = DistilBertConfig::from_file(config_path);
-    let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
+    let distil_bert_model = DistilBertModelMaskedLM::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
@@ -140,7 +140,7 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
     let mut config = DistilBertConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let distil_bert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
+    let distil_bert_model = DistilBertForQuestionAnswering::new(vs.root(), &config);
 
     // Define input
     let input = [
@@ -211,7 +211,7 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
     dummy_label_mapping.insert(2, String::from("PER"));
     dummy_label_mapping.insert(3, String::from("ORG"));
     config.id2label = Some(dummy_label_mapping);
-    let distil_bert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
+    let distil_bert_model = DistilBertForTokenClassification::new(vs.root(), &config);
 
     // Define input
     let input = [
diff --git a/tests/distilgpt2.rs b/tests/distilgpt2.rs
index a20bf91..1ca7096 100644
--- a/tests/distilgpt2.rs
+++ b/tests/distilgpt2.rs
@@ -37,7 +37,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
         false,
     )?;
     let config = Gpt2Config::from_file(config_path);
-    let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
+    let gpt2_model = GPT2LMHeadModel::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
diff --git a/tests/electra.rs b/tests/electra.rs
index 0f25b34..be50622 100644
--- a/tests/electra.rs
+++ b/tests/electra.rs
@@ -32,7 +32,7 @@ fn electra_masked_lm() -> anyhow::Result<()> {
     let mut config = ElectraConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
+    let electra_model = ElectraForMaskedLM::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
@@ -114,7 +114,7 @@ fn electra_discriminator() -> anyhow::Result<()> {
     let tokenizer: BertTokenizer =
         BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
     let config = ElectraConfig::from_file(config_path);
-    let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
+    let electra_model = ElectraDiscriminator::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
diff --git a/tests/fnet.rs b/tests/fnet.rs
index 64d36b3..d1323a3 100644
--- a/tests/fnet.rs
+++ b/tests/fnet.rs
@@ -30,7 +30,7 @@ fn fnet_masked_lm() -> anyhow::Result<()> {
     let tokenizer: FNetTokenizer =
         FNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, false)?;
     let config = FNetConfig::from_file(config_path);
-    let fnet_model = FNetForMaskedLM::new(&vs.root(), &config);
+    let fnet_model = FNetForMaskedLM::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
@@ -138,7 +138,7 @@ fn fnet_for_multiple_choice() -> anyhow::Result<()> {
     let mut config = FNetConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let fnet_model = FNetForMultipleChoice::new(&vs.root(), &config);
+    let fnet_model = FNetForMultipleChoice::new(vs.root(), &config);
 
     // Define input
     let input = [
@@ -201,7 +201,7 @@ fn fnet_for_token_classification() -> anyhow::Result<()> {
     dummy_label_mapping.insert(3, String::from("ORG"));
     config.id2label = Some(dummy_label_mapping);
     config.output_hidden_states = Some(true);
-    let fnet_model = FNetForTokenClassification::new(&vs.root(), &config);
+    let fnet_model = FNetForTokenClassification::new(vs.root(), &config);
 
     // Define input
     let input = [
@@ -256,7 +256,7 @@ fn fnet_for_question_answering() -> anyhow::Result<()> {
         FNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, false)?;
     let mut config = FNetConfig::from_file(config_path);
     config.output_hidden_states = Some(true);
-    let fnet_model = FNetForQuestionAnswering::new(&vs.root(), &config);
+    let fnet_model = FNetForQuestionAnswering::new(vs.root(), &config);
 
     // Define input
     let input = [
diff --git a/tests/gpt2.rs b/tests/gpt2.rs
index 07beb02..143ccd2 100644
--- a/tests/gpt2.rs
+++ b/tests/gpt2.rs
@@ -35,7 +35,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
         false,
     )?;
     let config = Gpt2Config::from_file(config_path);
-    let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
+    let gpt2_model = GPT2LMHeadModel::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
diff --git a/tests/gpt_neo.rs b/tests/gpt_neo.rs
index 37a8fce..3618e00 100644
--- a/tests/gpt_neo.rs
+++ b/tests/gpt_neo.rs
@@ -40,7 +40,7 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
     let mut config = GptNeoConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let gpt_neo_model = GptNeoForCausalLM::new(&vs.root(), &config)?;
+    let gpt_neo_model = GptNeoForCausalLM::new(vs.root(), &config)?;
     vs.load(weights_path)?;
 
     // Define input
diff --git a/tests/longformer.rs b/tests/longformer.rs
index dc235a1..b7295c7 100644
--- a/tests/longformer.rs
+++ b/tests/longformer.rs
@@ -197,7 +197,7 @@ fn longformer_for_sequence_classification() -> anyhow::Result<()> {
     dummy_label_mapping.insert(1, String::from("Negative"));
     dummy_label_mapping.insert(3, String::from("Neutral"));
     config.id2label = Some(dummy_label_mapping);
-    let model = LongformerForSequenceClassification::new(&vs.root(), &config);
+    let model = LongformerForSequenceClassification::new(vs.root(), &config);
 
     // Define input
     let input = ["Very positive sentence", "Second sentence input"];
@@ -258,7 +258,7 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> {
         false,
     )?;
     let config = LongformerConfig::from_file(config_path);
-    let model = LongformerForMultipleChoice::new(&vs.root(), &config);
+    let model = LongformerForMultipleChoice::new(vs.root(), &config);
 
     // Define input
     let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
@@ -337,7 +337,7 @@ fn longformer_for_token_classification() -> anyhow::Result<()> {
     dummy_label_mapping.insert(2, String::from("PER"));
     dummy_label_mapping.insert(3, String::from("ORG"));
     config.id2label = Some(dummy_label_mapping);
-    let model = LongformerForTokenClassification::new(&vs.root(), &config);
+    let model = LongformerForTokenClassification::new(vs.root(), &config);
 
     // Define input
     let inputs = ["Where's Paris?", "In Kentucky, United States"];
diff --git a/tests/mobilebert.rs b/tests/mobilebert.rs
index 50c3d04..09a4b86 100644
--- a/tests/mobilebert.rs
+++ b/tests/mobilebert.rs
@@ -35,7 +35,7 @@ fn mobilebert_masked_model() -> anyhow::Result<()> {
     let mut config = MobileBertConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let mobilebert_model = MobileBertForMaskedLM::new(&vs.root(), &config);
+    let mobilebert_model = MobileBertForMaskedLM::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
@@ -130,7 +130,7 @@ fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
     dummy_label_mapping.insert(1, String::from("Negative"));
dummy_label_mapping.insert(3, String::from("Neutral")); config.id2label = Some(dummy_label_mapping); - let model = MobileBertForSequenceClassification::new(&vs.root(), &config); + let model = MobileBertForSequenceClassification::new(vs.root(), &config); // Define input let input = ["Very positive sentence", "Second sentence input"]; @@ -176,7 +176,7 @@ fn mobilebert_for_multiple_choice() -> anyhow::Result<()> { let vs = nn::VarStore::new(device); let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?; let config = MobileBertConfig::from_file(config_path); - let model = MobileBertForMultipleChoice::new(&vs.root(), &config); + let model = MobileBertForMultipleChoice::new(vs.root(), &config); // Define input let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."; @@ -240,7 +240,7 @@ fn mobilebert_for_token_classification() -> anyhow::Result<()> { dummy_label_mapping.insert(2, String::from("PER")); dummy_label_mapping.insert(3, String::from("ORG")); config.id2label = Some(dummy_label_mapping); - let model = MobileBertForTokenClassification::new(&vs.root(), &config); + let model = MobileBertForTokenClassification::new(vs.root(), &config); // Define input let inputs = ["Where's Paris?", "In Kentucky, United States"]; @@ -287,7 +287,7 @@ fn mobilebert_for_question_answering() -> anyhow::Result<()> { let vs = nn::VarStore::new(device); let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?; let config = MobileBertConfig::from_file(config_path); - let model = MobileBertForQuestionAnswering::new(&vs.root(), &config); + let model = MobileBertForQuestionAnswering::new(vs.root(), &config); // Define input let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"]; diff --git a/tests/openai_gpt.rs b/tests/openai_gpt.rs index 7680cca..0e12f5d 100644 --- a/tests/openai_gpt.rs +++ b/tests/openai_gpt.rs @@ -39,7 +39,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> { true, )?; let config = OpenAiGptConfig::from_file(config_path); - let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config); + let openai_gpt = OpenAIGPTLMHeadModel::new(vs.root(), &config); vs.load(weights_path)?; // Define input diff --git a/tests/reformer.rs b/tests/reformer.rs index 5ac52a8..50b0e1d 100644 --- a/tests/reformer.rs +++ b/tests/reformer.rs @@ -98,7 +98,7 @@ fn reformer_for_sequence_classification() -> anyhow::Result<()> { config.id2label = Some(dummy_label_mapping); config.output_attentions = Some(true); config.output_hidden_states = Some(true); - let reformer_model = ReformerForSequenceClassification::new(&vs.root(), &config)?; + let reformer_model = ReformerForSequenceClassification::new(vs.root(), &config)?; // Define input let input = [ @@ -159,7 +159,7 @@ fn reformer_for_question_answering() -> anyhow::Result<()> { let mut config = ReformerConfig::from_file(config_path); config.output_attentions = Some(true); config.output_hidden_states = Some(true); - let reformer_model = ReformerForQuestionAnswering::new(&vs.root(), &config)?; + let reformer_model = ReformerForQuestionAnswering::new(vs.root(), &config)?; // Define input let input = [ diff --git a/tests/roberta.rs b/tests/roberta.rs index dbabeec..eb8976c 100644 --- a/tests/roberta.rs +++ b/tests/roberta.rs @@ -41,7 +41,7 @@ fn roberta_masked_lm() -> anyhow::Result<()> { false, )?; let config = RobertaConfig::from_file(config_path); - let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config); + let roberta_model = 
     vs.load(weights_path)?;
 
     // Define input
@@ -136,7 +136,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
     config.id2label = Some(dummy_label_mapping);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
+    let roberta_model = RobertaForSequenceClassification::new(vs.root(), &config);
 
     // Define input
     let input = [
@@ -201,7 +201,7 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
     let mut config = RobertaConfig::from_file(config_path);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
+    let roberta_model = RobertaForMultipleChoice::new(vs.root(), &config);
 
     // Define input
     let input = [
@@ -273,7 +273,7 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {
     config.id2label = Some(dummy_label_mapping);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
+    let roberta_model = RobertaForTokenClassification::new(vs.root(), &config);
 
     // Define input
     let input = [
diff --git a/tests/xlnet.rs b/tests/xlnet.rs
index 2561c29..ba0d6b5 100644
--- a/tests/xlnet.rs
+++ b/tests/xlnet.rs
@@ -141,7 +141,7 @@ fn xlnet_lm_model() -> anyhow::Result<()> {
     let tokenizer: XLNetTokenizer =
         XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, true)?;
     let config = XLNetConfig::from_file(config_path);
-    let xlnet_model = XLNetLMHeadModel::new(&vs.root(), &config);
+    let xlnet_model = XLNetLMHeadModel::new(vs.root(), &config);
     vs.load(weights_path)?;
 
     // Define input
@@ -257,7 +257,7 @@ fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
     config.id2label = Some(dummy_label_mapping);
     config.output_attentions = Some(true);
     config.output_hidden_states = Some(true);
-    let xlnet_model = XLNetForSequenceClassification::new(&vs.root(), &config)?;
+    let xlnet_model = XLNetForSequenceClassification::new(vs.root(), &config)?;
 
     // Define input
     let input = ["Very positive sentence", "Second sentence input"];
@@ -322,7 +322,7 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
     let vs = nn::VarStore::new(device);
     let tokenizer = XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
     let config = XLNetConfig::from_file(config_path);
-    let xlnet_model = XLNetForMultipleChoice::new(&vs.root(), &config)?;
+    let xlnet_model = XLNetForMultipleChoice::new(vs.root(), &config)?;
 
     // Define input
     let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
@@ -396,7 +396,7 @@ fn xlnet_for_token_classification() -> anyhow::Result<()> {
     dummy_label_mapping.insert(2, String::from("PER"));
     dummy_label_mapping.insert(3, String::from("ORG"));
     config.id2label = Some(dummy_label_mapping);
-    let xlnet_model = XLNetForTokenClassification::new(&vs.root(), &config)?;
+    let xlnet_model = XLNetForTokenClassification::new(vs.root(), &config)?;
 
     // Define input
     let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -453,7 +453,7 @@ fn xlnet_for_question_answering() -> anyhow::Result<()> {
     let vs = nn::VarStore::new(device);
     let tokenizer = XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
     let config = XLNetConfig::from_file(config_path);
-    let xlnet_model = XLNetForQuestionAnswering::new(&vs.root(), &config)?;
+    let xlnet_model = XLNetForQuestionAnswering::new(vs.root(), &config)?;
 
     // Define input
     let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];
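A note on the `let _: Box<dyn Send> = Box::new(...)` lines touched in the `mod test` hunks above: coercing a freshly built model into `Box<dyn Send>` is a compile-time assertion that the model type is `Send`; the test never needs to execute for the check to hold. A minimal sketch of the pattern with a stand-in type (not a rust-bert model):

```rust
// Hypothetical model type standing in for BartModel, BertModel, etc.
struct Model {
    weights: Vec<f32>,
}

fn main() {
    let model = Model { weights: vec![0.0; 4] };
    // The unsizing coercion below compiles only if `Model: Send`,
    // so this single line is the whole thread-safety assertion.
    let _: Box<dyn Send> = Box::new(model);
}
```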