From c6771d39929d6490db38848d0761b57996992130 Mon Sep 17 00:00:00 2001
From: guillaume-be
Date: Mon, 7 Nov 2022 17:45:52 +0000
Subject: [PATCH] Update to `tch=0.9.0` (#293)

* Fixed short sentence input and added documentation

* Fixed Clippy warnings

* Updated CI Python version

* cleaner dim specification
---
 .github/workflows/continuous-integration.yml |  2 +-
 CHANGELOG.md                                 |  1 +
 Cargo.toml                                   |  4 +--
 README.md                                    |  6 ++--
 examples/sentiment_analysis.rs               |  2 +-
 examples/sentiment_analysis_fnet.rs          |  2 +-
 examples/sequence_classification.rs          |  2 +-
 examples/zero_shot_classification.rs         |  2 +-
 requirements.txt                             |  5 +--
 src/albert/attention.rs                      |  2 +-
 src/bart/bart_model.rs                       | 11 +++---
 src/common/summary.rs                        |  5 ++-
 src/deberta/deberta_model.rs                 |  4 +--
 src/deberta_v2/encoder.rs                    |  2 +-
 src/lib.rs                                   |  6 ++--
 src/longformer/attention.rs                  |  6 ++--
 src/mbart/mbart_model.rs                     |  4 +--
 src/pipelines/generation_utils.rs            |  7 ++--
 src/pipelines/sentence_embeddings/layers.rs  | 10 +++---
 src/pipelines/token_classification.rs        |  5 ++-
 src/prophetnet/attention.rs                  |  2 ++
 src/reformer/attention.rs                    | 19 ++++++----
 src/t5/attention.rs                          |  2 +-
 src/t5/layer_norm.rs                         |  9 ++---
 src/xlnet/attention.rs                       | 38 +++++++++++++-------
 src/xlnet/xlnet_model.rs                     |  2 +-
 tests/bart.rs                                |  4 +--
 tests/distilbert.rs                          |  2 +-
 tests/fnet.rs                                |  2 +-
 29 files changed, 103 insertions(+), 65 deletions(-)

diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml
index 1afa042..b942996 100644
--- a/.github/workflows/continuous-integration.yml
+++ b/.github/workflows/continuous-integration.yml
@@ -128,7 +128,7 @@ jobs:
           override: true
       - uses: actions/setup-python@v2
         with:
-          python-version: '3.7'
+          python-version: '3.10'
      - run: |
          pip install -r requirements.txt --progress-bar off
          python ./utils/download-dependencies_distilbert.py
diff --git a/CHANGELOG.md b/CHANGELOG.md
index aa43e7f..b68bdcf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file. The format
 - Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`)
 - (BREAKING) `merges_resource` now optional for all pipelines
 - Allow mixing local and remote resources in pipelines
+- Upgraded to `torch` 1.13 (via `tch` 0.9.0)
 
 ## Fixed
 - Fixed configuration check for RoBERTa models for sentence classification.
diff --git a/Cargo.toml b/Cargo.toml
index b94031b..6e3ae14 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -64,7 +64,7 @@ features = ["doc-only"]
 
 [dependencies]
 rust_tokenizers = "~7.0.2"
-tch = "~0.8.0"
+tch = "~0.9.0"
 serde_json = "1.0.82"
 serde = { version = "1.0.140", features = ["derive"] }
 ordered-float = "3.0.0"
@@ -81,6 +81,6 @@ anyhow = "1.0.58"
 csv = "1.1.6"
 criterion = "0.3.6"
 tokio = { version = "1.20.0", features = ["sync", "rt-multi-thread", "macros"] }
-torch-sys = "~0.8.0"
+torch-sys = "~0.9.0"
 tempfile = "3.3.0"
 itertools = "0.10.3"
diff --git a/README.md b/README.md
index 9839cbe..f91c0ce 100644
--- a/README.md
+++ b/README.md
@@ -75,8 +75,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
 
 ### Manual installation (recommended)
 
-1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.12.0`: if this version is no longer available on the "get started" page,
-the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu116/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu116.zip` for a Linux version with CUDA11.
+1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.13.0`: if this version is no longer available on the "get started" page,
+the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip` for a Linux version with CUDA11.
 2. Extract the library to a location of your choice
 3. Set the following environment variables
 ##### Linux:
@@ -94,7 +94,7 @@ $Env:Path += ";X:\path\to\libtorch\lib"
 ### Automatic installation
 
 Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
-The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu113`.
+The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu117`.
 Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
 
 ## Ready-to-use pipelines
diff --git a/examples/sentiment_analysis.rs b/examples/sentiment_analysis.rs
index 5a5b910..f339bbf 100644
--- a/examples/sentiment_analysis.rs
+++ b/examples/sentiment_analysis.rs
@@ -26,7 +26,7 @@ fn main() -> anyhow::Result<()> {
     ];
 
     // Run model
-    let output = sentiment_classifier.predict(&input);
+    let output = sentiment_classifier.predict(input);
     for sentiment in output {
         println!("{:?}", sentiment);
     }
diff --git a/examples/sentiment_analysis_fnet.rs b/examples/sentiment_analysis_fnet.rs
index 03f00f8..7d942ce 100644
--- a/examples/sentiment_analysis_fnet.rs
+++ b/examples/sentiment_analysis_fnet.rs
@@ -47,7 +47,7 @@ fn main() -> anyhow::Result<()> {
     ];
 
     // Run model
-    let output = sentiment_classifier.predict(&input);
+    let output = sentiment_classifier.predict(input);
     for sentiment in output {
         println!("{:?}", sentiment);
     }
diff --git a/examples/sequence_classification.rs b/examples/sequence_classification.rs
index fd705dd..200d4e5 100644
--- a/examples/sequence_classification.rs
+++ b/examples/sequence_classification.rs
@@ -26,7 +26,7 @@ fn main() -> anyhow::Result<()> {
     ];
 
     // Run model
-    let output = sequence_classification_model.predict(&input);
+    let output = sequence_classification_model.predict(input);
     for label in output {
         println!("{:?}", label);
     }
diff --git a/examples/zero_shot_classification.rs b/examples/zero_shot_classification.rs
index 631d185..ec08178 100644
--- a/examples/zero_shot_classification.rs
+++ b/examples/zero_shot_classification.rs
@@ -23,7 +23,7 @@ fn main() -> anyhow::Result<()> {
     let candidate_labels = &["politics", "public health", "economy", "sports"];
 
     let output = sequence_classification_model.predict_multilabel(
-        &[input_sentence, input_sequence_2],
+        [input_sentence, input_sequence_2],
         candidate_labels,
         Some(Box::new(|label: &str| {
             format!("This example is about {}.", label)
diff --git a/requirements.txt b/requirements.txt
index d5d812b..ffd2dc9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
-torch == 1.8.1
-requests == 2.25.1
\ No newline at end of file
+torch == 1.13.0
+requests == 2.25.1
+numpy == 1.23.4
\ No newline at end of file
diff --git a/src/albert/attention.rs b/src/albert/attention.rs
index 5d6fbee..97fa539 100644
--- a/src/albert/attention.rs
+++ b/src/albert/attention.rs
@@ -131,7 +131,7 @@ impl AlbertSelfAttention {
         ));
 
         let context: Tensor =
-            Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + self.dense.bs.as_ref().unwrap();
+            Tensor::einsum("bfnd,ndh->bfh", &[context, w], None) + self.dense.bs.as_ref().unwrap();
 
         let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
         if !self.output_attentions {
diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs
index fc2eb33..eed0cf4 100644
--- a/src/bart/bart_model.rs
+++ b/src/bart/bart_model.rs
@@ -351,10 +351,11 @@ pub(crate) fn _prepare_decoder_attention_mask(
 }
 
 fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
-    let index_eos: Tensor = input_ids
-        .ne(pad_token_id)
-        .sum_dim_intlist(&[-1], true, Kind::Int64)
-        - 1;
+    let index_eos: Tensor =
+        input_ids
+            .ne(pad_token_id)
+            .sum_dim_intlist([-1].as_slice(), true, Kind::Int64)
+            - 1;
     let output = input_ids.empty_like().to_kind(Kind::Int64);
     output
         .select(1, 0)
@@ -857,7 +858,7 @@ impl BartForSequenceClassification {
             train,
         );
         let eos_mask = input_ids.eq(self.eos_token_id);
-        let reshape = eos_mask.sum_dim_intlist(&[1], true, input_ids.kind());
+        let reshape = eos_mask.sum_dim_intlist([1].as_slice(), true, input_ids.kind());
         let sentence_representation = base_model_output
             .decoder_output
             .permute(&[2, 0, 1])
diff --git a/src/common/summary.rs b/src/common/summary.rs
index a937bf7..d06d773 100644
--- a/src/common/summary.rs
+++ b/src/common/summary.rs
@@ -132,7 +132,10 @@ impl SequenceSummary {
         let mut output = match self.summary_type {
             SummaryType::last => hidden_states.select(1, -1),
             SummaryType::first => hidden_states.select(1, 0),
-            SummaryType::mean => hidden_states.mean_dim(&[1], false, hidden_states.kind()),
+
+            SummaryType::mean => {
+                hidden_states.mean_dim([1].as_slice(), false, hidden_states.kind())
+            }
             SummaryType::cls_index => {
                 let cls_index = if let Some(cls_index_value) = cls_index {
                     let mut expand_dim = vec![-1i64; cls_index_value.dim() - 1];
diff --git a/src/deberta/deberta_model.rs b/src/deberta/deberta_model.rs
index 257d9f2..ee2086e 100644
--- a/src/deberta/deberta_model.rs
+++ b/src/deberta/deberta_model.rs
@@ -303,9 +303,9 @@ impl Module for DebertaLayerNorm {
     fn forward(&self, hidden_states: &Tensor) -> Tensor {
         let input_type = hidden_states.kind();
         let hidden_states = hidden_states.to_kind(Kind::Float);
-        let mean = hidden_states.mean_dim(&[-1], true, hidden_states.kind());
+        let mean = hidden_states.mean_dim([-1].as_slice(), true, hidden_states.kind());
         let variance = (&hidden_states - &mean).pow_tensor_scalar(2.0).mean_dim(
-            &[-1],
+            [-1].as_slice(),
             true,
             hidden_states.kind(),
         );
diff --git a/src/deberta_v2/encoder.rs b/src/deberta_v2/encoder.rs
index 9c3a2b8..4318627 100644
--- a/src/deberta_v2/encoder.rs
+++ b/src/deberta_v2/encoder.rs
@@ -291,7 +291,7 @@ impl DebertaV2Encoder {
             attention_mask.shallow_clone()
         } else {
             attention_mask
-                .sum_dim_intlist(&[-2], false, attention_mask.kind())
+                .sum_dim_intlist([-2].as_slice(), false, attention_mask.kind())
                 .gt(0)
                 .to_kind(Kind::Uint8)
         };
diff --git a/src/lib.rs b/src/lib.rs
index 1654dd4..fc88a44 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -85,8 +85,8 @@
 //!
 //! ### Manual installation (recommended)
 //!
-//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.12.0`: if this version is no longer available on the "get started" page,
-//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu116/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu116.zip` for a Linux version with CUDA11.
+//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.13.0`: if this version is no longer available on the "get started" page,
+//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip` for a Linux version with CUDA11.
 //! 2. Extract the library to a location of your choice
 //! 3. Set the following environment variables
 //! ##### Linux:
@@ -104,7 +104,7 @@
 //! ### Automatic installation
 //!
 //! Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
-//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu116`.
+//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu117`.
 //! Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
 //!
 //! # Ready-to-use pipelines
diff --git a/src/longformer/attention.rs b/src/longformer/attention.rs
index abba66b..adcd2f7 100644
--- a/src/longformer/attention.rs
+++ b/src/longformer/attention.rs
@@ -223,7 +223,7 @@ impl LongformerSelfAttention {
         let key = self.chunk(&key, window_overlap);
 
         let diagonal_chunked_attention_scores = self.pad_and_transpose_last_two_dims(
-            &Tensor::einsum("bcxd,bcyd->bcxy", &[query, key]),
+            &Tensor::einsum("bcxd,bcyd->bcxy", &[query, key], None),
             &[0, 0, 0, 1],
         );
 
@@ -353,6 +353,7 @@ impl LongformerSelfAttention {
         Tensor::einsum(
             "bcwd,bcdh->bcwh",
             &[chunked_attention_probas, chunked_value],
+            None,
         )
         .view([batch_size, num_heads, sequence_length, head_dim])
         .transpose(1, 2)
@@ -363,7 +364,7 @@ impl LongformerSelfAttention {
         is_index_global_attn: &Tensor,
     ) -> GlobalAttentionIndices {
         let num_global_attention_indices =
-            is_index_global_attn.sum_dim_intlist(&[1], false, Kind::Int64);
+            is_index_global_attn.sum_dim_intlist([1].as_slice(), false, Kind::Int64);
         let max_num_global_attention_indices = i64::from(num_global_attention_indices.max());
         let is_index_global_attn_nonzero = is_index_global_attn
             .nonzero_numpy()
@@ -428,6 +429,7 @@ impl LongformerSelfAttention {
         let attention_probas_from_global_key = Tensor::einsum(
             "blhd,bshd->blhs",
             &[query_vectors, &key_vectors_only_global],
+            None,
         );
 
         let _ = attention_probas_from_global_key
diff --git a/src/mbart/mbart_model.rs b/src/mbart/mbart_model.rs
index f1f2247..9a1bb1c 100644
--- a/src/mbart/mbart_model.rs
+++ b/src/mbart/mbart_model.rs
@@ -160,7 +160,7 @@ fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
     let output = input_ids.masked_fill(&input_ids.eq(-100), pad_token_id);
     let index_eos: Tensor = input_ids
         .ne(pad_token_id)
-        .sum_dim_intlist(&[1], true, Int64)
+        .sum_dim_intlist([1].as_slice(), true, Int64)
         - 1;
     output
         .select(1, 0)
@@ -682,7 +682,7 @@ impl MBartForSequenceClassification {
             train,
         );
         let eos_mask = input_ids.eq(self.eos_token_id);
-        let reshape = eos_mask.sum_dim_intlist(&[1], true, Int64);
+        let reshape = eos_mask.sum_dim_intlist([1].as_slice(), true, Int64);
         let sentence_representation = base_model_output
            .decoder_output
            .permute(&[2, 0, 1])
diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs
index 6fd21a7..cf5d853 100644
--- a/src/pipelines/generation_utils.rs
+++ b/src/pipelines/generation_utils.rs
@@ -932,8 +932,11 @@ pub(crate) mod private_generation_utils {
                 current_length += 1;
             }
             let scores_output = token_scores_output.as_ref().map(|scores_tensor| {
-                (Tensor::stack(scores_tensor, 1).sum_dim_intlist(&[1], false, Kind::Float)
-                    / sentence_lengths.pow_tensor_scalar(gen_opt.length_penalty))
+                (Tensor::stack(scores_tensor, 1).sum_dim_intlist(
+                    [1].as_slice(),
+                    false,
+                    Kind::Float,
+                ) / sentence_lengths.pow_tensor_scalar(gen_opt.length_penalty))
                 .iter::<f64>()
                 .unwrap()
                 .collect::<Vec<f64>>()
diff --git a/src/pipelines/sentence_embeddings/layers.rs b/src/pipelines/sentence_embeddings/layers.rs
index 31b6a99..fffdc7a 100644
--- a/src/pipelines/sentence_embeddings/layers.rs
+++ b/src/pipelines/sentence_embeddings/layers.rs
@@ -54,10 +54,12 @@ impl Pooling {
         if self.conf.pooling_mode_mean_tokens || self.conf.pooling_mode_mean_sqrt_len_tokens {
             let input_mask_expanded = attention_mask.unsqueeze(-1).expand_as(&token_embeddings);
 
-            let sum_embeddings =
-                (token_embeddings * &input_mask_expanded).sum_dim_intlist(&[1], false, Kind::Float);
-
-            let sum_mask = input_mask_expanded.sum_dim_intlist(&[1], false, Kind::Float);
+            let sum_embeddings = (token_embeddings * &input_mask_expanded).sum_dim_intlist(
+                [1].as_slice(),
+                false,
+                Kind::Float,
+            );
+            let sum_mask = input_mask_expanded.sum_dim_intlist([1].as_slice(), false, Kind::Float);
             let sum_mask = sum_mask.clamp_min(10e-9);
 
             if self.conf.pooling_mode_mean_tokens {
diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs
index 6850383..59cd1e3 100644
--- a/src/pipelines/token_classification.rs
+++ b/src/pipelines/token_classification.rs
@@ -891,7 +891,10 @@ impl TokenClassificationModel {
             None,
             false,
         );
-        let score = output.exp() / output.exp().sum_dim_intlist(&[-1], true, Kind::Float);
+        let score = output.exp()
+            / output
+                .exp()
+                .sum_dim_intlist([-1].as_slice(), true, Kind::Float);
         let label_indices = score.argmax(-1, true);
         for sentence_idx in 0..label_indices.size()[0] {
             let labels = label_indices.get(sentence_idx);
diff --git a/src/prophetnet/attention.rs b/src/prophetnet/attention.rs
index f464010..21d67bb 100644
--- a/src/prophetnet/attention.rs
+++ b/src/prophetnet/attention.rs
@@ -531,6 +531,7 @@ impl ProphetNetNgramAttention {
         let predict_attention_weights = Tensor::einsum(
             "nbtc,nbsc->nbts",
             &[predict_query_states, predict_key_states],
+            None,
         );
 
         let predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
@@ -554,6 +555,7 @@ impl ProphetNetNgramAttention {
         let predict_attention_output = Tensor::einsum(
             "nbts,nbsc->nbtc",
             &[&predict_attention_probas, &predict_value_states],
+            None,
         )
         .transpose(1, 2)
         .contiguous()
diff --git a/src/reformer/attention.rs b/src/reformer/attention.rs
index b9b6035..b3ca9df 100644
--- a/src/reformer/attention.rs
+++ b/src/reformer/attention.rs
@@ -253,7 +253,11 @@ impl LSHSelfAttention {
                 self.hidden_size,
             ])
             .transpose(-2, -1);
-        Tensor::einsum("balh,ahr->balr", &[hidden_states, &per_head_query_key])
+        Tensor::einsum(
+            "balh,ahr->balr",
+            &[hidden_states, &per_head_query_key],
+            None,
+        )
     }
 
     fn value_per_attention_head(&self, hidden_states: &Tensor) -> Tensor {
@@ -266,7 +270,7 @@ impl LSHSelfAttention {
                 self.hidden_size,
             ])
             .transpose(-2, -1);
-        Tensor::einsum("balh,ahr->balr", &[hidden_states, &per_head_value])
+        Tensor::einsum("balh,ahr->balr", &[hidden_states, &per_head_value], None)
     }
 
     fn hash_vectors(
@@ -304,7 +308,8 @@ impl LSHSelfAttention {
            rotation_size / 2,
         ];
         let random_rotations = Tensor::randn(&rotations_shape, (vectors.kind(), vectors.device()));
-        let rotated_vectors = Tensor::einsum("bmtd,mdhr->bmhtr", &[vectors, random_rotations]);
+        let rotated_vectors =
+            Tensor::einsum("bmtd,mdhr->bmhtr", &[vectors, random_rotations], None);
 
         let mut buckets = match &self.num_buckets {
             NumBuckets::Integer(_) => {
@@ -647,7 +652,8 @@ impl LSHSelfAttention {
     }
 
     fn len_norm(&self, input_tensor: &Tensor, epsilon: f64) -> Tensor {
-        let variance = (input_tensor * input_tensor).mean_dim(&[-1], true, input_tensor.kind());
+        let variance =
+            (input_tensor * input_tensor).mean_dim([-1].as_slice(), true, input_tensor.kind());
         input_tensor * (variance + epsilon).rsqrt()
     }
 
@@ -903,9 +909,10 @@ impl LSHSelfAttention {
                Some(self.attention_head_size),
            )?
            .unsqueeze(-1);
-            let probs_vectors = (&logits - &logits.logsumexp(&[2], true)).exp();
+            let probs_vectors = (&logits - &logits.logsumexp([2].as_slice(), true)).exp();
             let out_kind = out_vectors.kind();
-            out_vectors = (out_vectors * probs_vectors).sum_dim_intlist(&[2], false, out_kind);
+            out_vectors =
+                (out_vectors * probs_vectors).sum_dim_intlist([2].as_slice(), false, out_kind);
         }
 
         out_vectors = merge_hidden_size_dim(
diff --git a/src/t5/attention.rs b/src/t5/attention.rs
index dc0feaa..b5be649 100644
--- a/src/t5/attention.rs
+++ b/src/t5/attention.rs
@@ -192,7 +192,7 @@ impl T5Attention {
             None
         };
 
-        let mut scores = Tensor::einsum("bnqd,bnkd->bnqk", &[q, k]);
+        let mut scores = Tensor::einsum("bnqd,bnkd->bnqk", &[q, k], None);
 
         let calculated_position_bias = if position_bias.is_none() {
             let mut temp_value = if self.has_relative_attention_bias {
diff --git a/src/t5/layer_norm.rs b/src/t5/layer_norm.rs
index 7f82f9b..a493567 100644
--- a/src/t5/layer_norm.rs
+++ b/src/t5/layer_norm.rs
@@ -33,10 +33,11 @@ impl T5LayerNorm {
 impl Module for T5LayerNorm {
     fn forward(&self, x: &Tensor) -> Tensor {
         let input_type = x.kind();
-        let variance =
-            x.to_kind(Kind::Float)
-                .pow_tensor_scalar(2.0_f64)
-                .mean_dim(&[-1], true, Kind::Float);
+        let variance = x.to_kind(Kind::Float).pow_tensor_scalar(2.0_f64).mean_dim(
+            [-1].as_slice(),
+            true,
+            Kind::Float,
+        );
         let x = x * (variance + self.epsilon).rsqrt();
         if input_type != Kind::Float {
             (&self.weight * x).to_kind(input_type)
diff --git a/src/xlnet/attention.rs b/src/xlnet/attention.rs
index 88dddac..f5e4cba 100644
--- a/src/xlnet/attention.rs
+++ b/src/xlnet/attention.rs
@@ -166,9 +166,17 @@ impl XLNetRelativeAttention {
         attention_mask: Option<&Tensor>,
         train: bool,
     ) -> (Tensor, Option<Tensor>) {
-        let ac = Tensor::einsum("ibnd,jbnd->bnij", &[&(q_head + &self.r_w_bias), k_head_h]);
+        let ac = Tensor::einsum(
+            "ibnd,jbnd->bnij",
+            &[&(q_head + &self.r_w_bias), k_head_h],
+            None,
+        );
         let bd = self.rel_shift_bnij(
-            &Tensor::einsum("ibnd,jbnd->bnij", &[&(q_head + &self.r_r_bias), k_head_r]),
+            &Tensor::einsum(
+                "ibnd,jbnd->bnij",
+                &[&(q_head + &self.r_r_bias), k_head_r],
+                None,
+            ),
             ac.size()[3],
         );
 
@@ -177,8 +185,9 @@ impl XLNetRelativeAttention {
                 let ef = Tensor::einsum(
                     "ibnd,snd->ibns",
                     &[&(q_head + &self.r_s_bias), &self.seg_embed],
+                    None,
                 );
-                Tensor::einsum("ijbs,ibns->bnij", &[seg_mat, &ef])
+                Tensor::einsum("ijbs,ibns->bnij", &[seg_mat, &ef], None)
             }
             None => Tensor::zeros(&[1], (ac.kind(), ac.device())),
         };
@@ -193,7 +202,8 @@ impl XLNetRelativeAttention {
            .softmax(3, attention_score.kind())
            .apply_t(&self.dropout, train);
 
-        let attention_vector = Tensor::einsum("bnij,jbnd->ibnd", &[&attention_probas, v_head_h]);
+        let attention_vector =
+            Tensor::einsum("bnij,jbnd->ibnd", &[&attention_probas, v_head_h], None);
 
         if self.output_attentions {
             (
@@ -212,8 +222,9 @@ impl XLNetRelativeAttention {
         residual: bool,
         train: bool,
     ) -> Tensor {
-        let mut attention_out = Tensor::einsum("ibnd,hnd->ibh", &[attention_vector, &self.output])
-            .apply_t(&self.dropout, train);
+        let mut attention_out =
+            Tensor::einsum("ibnd,hnd->ibh", &[attention_vector, &self.output], None)
+                .apply_t(&self.dropout, train);
         if residual {
             attention_out = attention_out + h;
         };
@@ -245,10 +256,10 @@ impl XLNetRelativeAttention {
             Some(value) => value,
             None => h,
         };
-        let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query]);
-        let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key]);
-        let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value]);
-        let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos]);
+        let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query], None);
+        let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key], None);
+        let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value], None);
+        let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos], None);
 
         let (attention_vec_h, attention_probas_h) = self.rel_attention_core(
             &q_head_h,
@@ -262,11 +273,12 @@ impl XLNetRelativeAttention {
         let output_h = self.post_attention(h, &attention_vec_h, true, train);
 
         let (output_g, attention_probas_g) = if let Some(g) = g {
-            let q_head_g = Tensor::einsum("ibh,hnd->ibnd", &[g, &self.query]);
+            let q_head_g = Tensor::einsum("ibh,hnd->ibnd", &[g, &self.query], None);
             let (attention_vec_g, attention_probas_g) = match target_mapping {
                 Some(target_mapping) => {
-                    let q_head_g = Tensor::einsum("mbnd,mlb->lbnd", &[&q_head_g, target_mapping]);
+                    let q_head_g =
+                        Tensor::einsum("mbnd,mlb->lbnd", &[&q_head_g, target_mapping], None);
                     let (attention_vec_g, attention_probas_g) = self.rel_attention_core(
                         &q_head_g,
                         &k_head_h,
@@ -277,7 +289,7 @@ impl XLNetRelativeAttention {
                         train,
                     );
                     let attention_vec_g =
-                        Tensor::einsum("lbnd,mlb->mbnd", &[&attention_vec_g, target_mapping]);
+                        Tensor::einsum("lbnd,mlb->mbnd", &[&attention_vec_g, target_mapping], None);
                     (attention_vec_g, attention_probas_g)
                 }
                 None => self.rel_attention_core(
diff --git a/src/xlnet/xlnet_model.rs b/src/xlnet/xlnet_model.rs
index f0f5a0e..cf5b727 100644
--- a/src/xlnet/xlnet_model.rs
+++ b/src/xlnet/xlnet_model.rs
@@ -311,7 +311,7 @@ impl XLNetModel {
         inverse_frequency: &Tensor,
         batch_size: Option<i64>,
     ) -> Tensor {
-        let sinusoid = Tensor::einsum("i,d->id", &[position_sequence, inverse_frequency]);
+        let sinusoid = Tensor::einsum("i,d->id", &[position_sequence, inverse_frequency], None);
         let mut positional_embeddings =
             Tensor::cat(&[sinusoid.sin(), sinusoid.cos()], -1).unsqueeze(1);
 
diff --git a/tests/bart.rs b/tests/bart.rs
index fc00914..2ceee53 100644
--- a/tests/bart.rs
+++ b/tests/bart.rs
@@ -212,7 +212,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
     let candidate_labels = &["politics", "public health", "economy", "sports"];
 
     let output = sequence_classification_model.predict(
-        &[input_sentence, input_sequence_2],
+        [input_sentence, input_sequence_2],
         candidate_labels,
         Some(Box::new(|label: &str| {
             format!("This example is about {}.", label)
@@ -245,7 +245,7 @@ fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
     let candidate_labels = &["politics", "public health", "economy", "sports"];
&["politics", "public health", "economy", "sports"]; let output = sequence_classification_model.predict_multilabel( - &[input_sentence, input_sequence_2], + [input_sentence, input_sequence_2], candidate_labels, Some(Box::new(|label: &str| { format!("This example is about {}.", label) diff --git a/tests/distilbert.rs b/tests/distilbert.rs index b206ef2..ae9c563 100644 --- a/tests/distilbert.rs +++ b/tests/distilbert.rs @@ -26,7 +26,7 @@ fn distilbert_sentiment_classifier() -> anyhow::Result<()> { "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", ]; - let output = sentiment_classifier.predict(&input); + let output = sentiment_classifier.predict(input); assert_eq!(output.len(), 3usize); assert_eq!(output[0].polarity, SentimentPolarity::Positive); diff --git a/tests/fnet.rs b/tests/fnet.rs index 874cd85..64d36b3 100644 --- a/tests/fnet.rs +++ b/tests/fnet.rs @@ -109,7 +109,7 @@ fn fnet_for_sequence_classification() -> anyhow::Result<()> { "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", ]; - let output = sentiment_classifier.predict(&input); + let output = sentiment_classifier.predict(input); assert_eq!(output.len(), 3usize); assert_eq!(output[0].polarity, SentimentPolarity::Negative);