Update to tch=0.9.0 (#293)

* Fixed short sentence input and added documentation

* Fixed Clippy warnings

* Updated CI Python version

* Cleaner dim specification
guillaume-be 2022-11-07 17:45:52 +00:00 committed by GitHub
parent 340be36ed9
commit c6771d3992
29 changed files with 103 additions and 65 deletions

@@ -128,7 +128,7 @@ jobs:
       override: true
   - uses: actions/setup-python@v2
     with:
-      python-version: '3.7'
+      python-version: '3.10'
   - run: |
      pip install -r requirements.txt --progress-bar off
      python ./utils/download-dependencies_distilbert.py

@@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file. The format
 - Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`)
 - (BREAKING) `merges_resource` now optional for all pipelines
 - Allow mixing local and remote resources in pipelines
+- Upgraded to `torch` 1.13 (via `tch` 0.9.0)

 ## Fixed
 - Fixed configuration check for RoBERTa models for sentence classification.

@@ -64,7 +64,7 @@ features = ["doc-only"]

[dependencies]
rust_tokenizers = "~7.0.2"
-tch = "~0.8.0"
+tch = "~0.9.0"
serde_json = "1.0.82"
serde = { version = "1.0.140", features = ["derive"] }
ordered-float = "3.0.0"
@@ -81,6 +81,6 @@ anyhow = "1.0.58"
csv = "1.1.6"
criterion = "0.3.6"
tokio = { version = "1.20.0", features = ["sync", "rt-multi-thread", "macros"] }
-torch-sys = "~0.8.0"
+torch-sys = "~0.9.0"
tempfile = "3.3.0"
itertools = "0.10.3"

@@ -75,8 +75,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett

 ### Manual installation (recommended)

-1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.12.0`: if this version is no longer available on the "get started" page,
-the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu116/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu116.zip` for a Linux version with CUDA11.
+1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.13.0`: if this version is no longer available on the "get started" page,
+the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip` for a Linux version with CUDA11.
 2. Extract the library to a location of your choice
 3. Set the following environment variables
 ##### Linux:
@@ -94,7 +94,7 @@ $Env:Path += ";X:\path\to\libtorch\lib"
 ### Automatic installation

 Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
-The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu113`.
+The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu117`.
 Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.

 ## Ready-to-use pipelines
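
Either installation path can be validated with a small smoke test (a minimal sketch, not part of this commit); `Tensor::of_slice` and `tch::Cuda::is_available` are existing `tch` APIs in the 0.9 series:

    use tch::Tensor;

    fn main() {
        // This only builds and runs if the crate links against a compatible
        // libtorch (v1.13.0 for tch 0.9.0).
        let t = Tensor::of_slice(&[1i64, 2, 3]) * 2;
        t.print(); // expected values: 2, 4, 6
        println!("CUDA available: {}", tch::Cuda::is_available());
    }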

@@ -26,7 +26,7 @@ fn main() -> anyhow::Result<()> {
     ];

     // Run model
-    let output = sentiment_classifier.predict(&input);
+    let output = sentiment_classifier.predict(input);
     for sentiment in output {
         println!("{:?}", sentiment);
     }
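
The example call sites drop the leading `&` and pass the input array by value. This is consistent with the pipeline `predict` methods becoming generic over their input (an `AsRef<[&str]>`-style bound that both arrays and slices satisfy — an assumption here, since the diff only shows the call sites). A minimal sketch of the migrated usage, with illustrative input strings and the pipeline defaults (weights are fetched on first use):

    use rust_bert::pipelines::sentiment::SentimentModel;

    fn main() -> anyhow::Result<()> {
        let sentiment_classifier = SentimentModel::new(Default::default())?;
        let input = [
            "Probably my all-time favorite movie.",
            "Forget it, the plot is paper thin.",
        ];
        // Post-update call shape: the array is passed by value, no leading `&`.
        let output = sentiment_classifier.predict(input);
        for sentiment in output {
            println!("{:?}", sentiment);
        }
        Ok(())
    }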

@@ -47,7 +47,7 @@ fn main() -> anyhow::Result<()> {
     ];

     // Run model
-    let output = sentiment_classifier.predict(&input);
+    let output = sentiment_classifier.predict(input);
     for sentiment in output {
         println!("{:?}", sentiment);
     }

@@ -26,7 +26,7 @@ fn main() -> anyhow::Result<()> {
     ];

     // Run model
-    let output = sequence_classification_model.predict(&input);
+    let output = sequence_classification_model.predict(input);
     for label in output {
         println!("{:?}", label);
     }

@@ -23,7 +23,7 @@ fn main() -> anyhow::Result<()> {
     let candidate_labels = &["politics", "public health", "economy", "sports"];

     let output = sequence_classification_model.predict_multilabel(
-        &[input_sentence, input_sequence_2],
+        [input_sentence, input_sequence_2],
         candidate_labels,
         Some(Box::new(|label: &str| {
             format!("This example is about {}.", label)

@@ -1,2 +1,3 @@
-torch == 1.8.1
+torch == 1.13.0
 requests == 2.25.1
+numpy == 1.23.4

@@ -131,7 +131,7 @@ impl AlbertSelfAttention {
         ));

         let context: Tensor =
-            Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + self.dense.bs.as_ref().unwrap();
+            Tensor::einsum("bfnd,ndh->bfh", &[context, w], None) + self.dense.bs.as_ref().unwrap();
         let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);

         if !self.output_attentions {
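
The `None` argument that recurs throughout the model files comes from a signature change in `tch` 0.9.0: `Tensor::einsum` now takes an optional contraction-path hint as a third parameter (apparently mirroring the `path` argument `torch.einsum` gained in PyTorch 1.13), and passing `None` keeps the previous behavior of letting libtorch pick the contraction order. A standalone sketch of the new call shape:

    use tch::{Device, Kind, Tensor};

    fn main() {
        let a = Tensor::rand(&[2, 3], (Kind::Float, Device::Cpu));
        let b = Tensor::rand(&[3, 4], (Kind::Float, Device::Cpu));
        // tch 0.9.0: the third argument is the optional contraction-path hint;
        // `None` delegates the ordering to libtorch, as before the upgrade.
        let product = Tensor::einsum("ij,jk->ik", &[a, b], None);
        assert_eq!(product.size(), [2, 4]);
    }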

@@ -351,10 +351,11 @@ pub(crate) fn _prepare_decoder_attention_mask(
 }

 fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
-    let index_eos: Tensor = input_ids
-        .ne(pad_token_id)
-        .sum_dim_intlist(&[-1], true, Kind::Int64)
-        - 1;
+    let index_eos: Tensor =
+        input_ids
+            .ne(pad_token_id)
+            .sum_dim_intlist([-1].as_slice(), true, Kind::Int64)
+            - 1;
     let output = input_ids.empty_like().to_kind(Kind::Int64);
     output
         .select(1, 0)
@@ -857,7 +858,7 @@ impl BartForSequenceClassification {
             train,
         );
         let eos_mask = input_ids.eq(self.eos_token_id);
-        let reshape = eos_mask.sum_dim_intlist(&[1], true, input_ids.kind());
+        let reshape = eos_mask.sum_dim_intlist([1].as_slice(), true, input_ids.kind());
         let sentence_representation = base_model_output
             .decoder_output
             .permute(&[2, 0, 1])
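
Similarly, the repeated `&[1]` → `[1].as_slice()` rewrites follow from `tch` 0.9.0 making the dimension argument of the reduction ops (`sum_dim_intlist`, `mean_dim`, `logsumexp`, ...) generic: a reference to a fixed-size array such as `&[1]` no longer satisfies the new bound, while an explicit `&[i64]` slice does. A minimal standalone sketch of the migrated pattern:

    use tch::{Kind, Tensor};

    fn main() {
        let x = Tensor::of_slice(&[1f32, 2.0, 3.0, 4.0, 5.0, 6.0]).view([2, 3]);
        // tch 0.9.0: pass an explicit `&[i64]` slice instead of `&[1]`.
        let row_sums = x.sum_dim_intlist([1].as_slice(), true, Kind::Float);
        assert_eq!(row_sums.size(), [2, 1]);
        let row_means = x.mean_dim([1].as_slice(), false, Kind::Float);
        assert_eq!(row_means.size(), [2]);
    }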

@@ -132,7 +132,10 @@ impl SequenceSummary {
         let mut output = match self.summary_type {
             SummaryType::last => hidden_states.select(1, -1),
             SummaryType::first => hidden_states.select(1, 0),
-            SummaryType::mean => hidden_states.mean_dim(&[1], false, hidden_states.kind()),
+            SummaryType::mean => {
+                hidden_states.mean_dim([1].as_slice(), false, hidden_states.kind())
+            }
             SummaryType::cls_index => {
                 let cls_index = if let Some(cls_index_value) = cls_index {
                     let mut expand_dim = vec![-1i64; cls_index_value.dim() - 1];

@@ -303,9 +303,9 @@ impl Module for DebertaLayerNorm {
     fn forward(&self, hidden_states: &Tensor) -> Tensor {
         let input_type = hidden_states.kind();
         let hidden_states = hidden_states.to_kind(Kind::Float);
-        let mean = hidden_states.mean_dim(&[-1], true, hidden_states.kind());
+        let mean = hidden_states.mean_dim([-1].as_slice(), true, hidden_states.kind());
         let variance = (&hidden_states - &mean).pow_tensor_scalar(2.0).mean_dim(
-            &[-1],
+            [-1].as_slice(),
             true,
             hidden_states.kind(),
         );

@@ -291,7 +291,7 @@ impl DebertaV2Encoder {
             attention_mask.shallow_clone()
         } else {
             attention_mask
-                .sum_dim_intlist(&[-2], false, attention_mask.kind())
+                .sum_dim_intlist([-2].as_slice(), false, attention_mask.kind())
                 .gt(0)
                 .to_kind(Kind::Uint8)
         };

@@ -85,8 +85,8 @@
 //!
 //! ### Manual installation (recommended)
 //!
-//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.12.0`: if this version is no longer available on the "get started" page,
-//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu116/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu116.zip` for a Linux version with CUDA11.
+//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.13.0`: if this version is no longer available on the "get started" page,
+//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip` for a Linux version with CUDA11.
 //! 2. Extract the library to a location of your choice
 //! 3. Set the following environment variables
 //! ##### Linux:
@@ -104,7 +104,7 @@
 //! ### Automatic installation
 //!
 //! Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
-//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu116`.
+//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu117`.
 //! Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
 //!
 //! # Ready-to-use pipelines

@@ -223,7 +223,7 @@ impl LongformerSelfAttention {
         let key = self.chunk(&key, window_overlap);

         let diagonal_chunked_attention_scores = self.pad_and_transpose_last_two_dims(
-            &Tensor::einsum("bcxd,bcyd->bcxy", &[query, key]),
+            &Tensor::einsum("bcxd,bcyd->bcxy", &[query, key], None),
             &[0, 0, 0, 1],
         );
@@ -353,6 +353,7 @@
         Tensor::einsum(
             "bcwd,bcdh->bcwh",
             &[chunked_attention_probas, chunked_value],
+            None,
         )
         .view([batch_size, num_heads, sequence_length, head_dim])
         .transpose(1, 2)
@@ -363,7 +364,7 @@
         is_index_global_attn: &Tensor,
     ) -> GlobalAttentionIndices {
         let num_global_attention_indices =
-            is_index_global_attn.sum_dim_intlist(&[1], false, Kind::Int64);
+            is_index_global_attn.sum_dim_intlist([1].as_slice(), false, Kind::Int64);
         let max_num_global_attention_indices = i64::from(num_global_attention_indices.max());
         let is_index_global_attn_nonzero = is_index_global_attn
             .nonzero_numpy()
@@ -428,6 +429,7 @@
         let attention_probas_from_global_key = Tensor::einsum(
             "blhd,bshd->blhs",
             &[query_vectors, &key_vectors_only_global],
+            None,
         );

         let _ = attention_probas_from_global_key

@@ -160,7 +160,7 @@ fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
     let output = input_ids.masked_fill(&input_ids.eq(-100), pad_token_id);
     let index_eos: Tensor = input_ids
         .ne(pad_token_id)
-        .sum_dim_intlist(&[1], true, Int64)
+        .sum_dim_intlist([1].as_slice(), true, Int64)
         - 1;
     output
         .select(1, 0)
@@ -682,7 +682,7 @@ impl MBartForSequenceClassification {
             train,
         );
         let eos_mask = input_ids.eq(self.eos_token_id);
-        let reshape = eos_mask.sum_dim_intlist(&[1], true, Int64);
+        let reshape = eos_mask.sum_dim_intlist([1].as_slice(), true, Int64);
         let sentence_representation = base_model_output
             .decoder_output
             .permute(&[2, 0, 1])

@@ -932,8 +932,11 @@ pub(crate) mod private_generation_utils {
                     current_length += 1;
                 }
                 let scores_output = token_scores_output.as_ref().map(|scores_tensor| {
-                    (Tensor::stack(scores_tensor, 1).sum_dim_intlist(&[1], false, Kind::Float)
-                        / sentence_lengths.pow_tensor_scalar(gen_opt.length_penalty))
+                    (Tensor::stack(scores_tensor, 1).sum_dim_intlist(
+                        [1].as_slice(),
+                        false,
+                        Kind::Float,
+                    ) / sentence_lengths.pow_tensor_scalar(gen_opt.length_penalty))
                     .iter::<f64>()
                     .unwrap()
                     .collect::<Vec<f64>>()

@@ -54,10 +54,12 @@ impl Pooling {
         if self.conf.pooling_mode_mean_tokens || self.conf.pooling_mode_mean_sqrt_len_tokens {
             let input_mask_expanded = attention_mask.unsqueeze(-1).expand_as(&token_embeddings);
-            let sum_embeddings =
-                (token_embeddings * &input_mask_expanded).sum_dim_intlist(&[1], false, Kind::Float);
-
-            let sum_mask = input_mask_expanded.sum_dim_intlist(&[1], false, Kind::Float);
+            let sum_embeddings = (token_embeddings * &input_mask_expanded).sum_dim_intlist(
+                [1].as_slice(),
+                false,
+                Kind::Float,
+            );
+            let sum_mask = input_mask_expanded.sum_dim_intlist([1].as_slice(), false, Kind::Float);
             let sum_mask = sum_mask.clamp_min(10e-9);

             if self.conf.pooling_mode_mean_tokens {

@@ -891,7 +891,10 @@ impl TokenClassificationModel {
             None,
             false,
         );
-        let score = output.exp() / output.exp().sum_dim_intlist(&[-1], true, Kind::Float);
+        let score = output.exp()
+            / output
+                .exp()
+                .sum_dim_intlist([-1].as_slice(), true, Kind::Float);
         let label_indices = score.argmax(-1, true);
         for sentence_idx in 0..label_indices.size()[0] {
             let labels = label_indices.get(sentence_idx);

@@ -531,6 +531,7 @@ impl ProphetNetNgramAttention {
         let predict_attention_weights = Tensor::einsum(
             "nbtc,nbsc->nbts",
             &[predict_query_states, predict_key_states],
+            None,
         );

         let predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
@@ -554,6 +555,7 @@
         let predict_attention_output = Tensor::einsum(
             "nbts,nbsc->nbtc",
             &[&predict_attention_probas, &predict_value_states],
+            None,
         )
         .transpose(1, 2)
         .contiguous()

@@ -253,7 +253,11 @@ impl LSHSelfAttention {
                 self.hidden_size,
             ])
             .transpose(-2, -1);
-        Tensor::einsum("balh,ahr->balr", &[hidden_states, &per_head_query_key])
+        Tensor::einsum(
+            "balh,ahr->balr",
+            &[hidden_states, &per_head_query_key],
+            None,
+        )
     }

     fn value_per_attention_head(&self, hidden_states: &Tensor) -> Tensor {
@@ -266,7 +270,7 @@
                 self.hidden_size,
             ])
             .transpose(-2, -1);
-        Tensor::einsum("balh,ahr->balr", &[hidden_states, &per_head_value])
+        Tensor::einsum("balh,ahr->balr", &[hidden_states, &per_head_value], None)
     }

     fn hash_vectors(
@@ -304,7 +308,8 @@
             rotation_size / 2,
         ];
         let random_rotations = Tensor::randn(&rotations_shape, (vectors.kind(), vectors.device()));
-        let rotated_vectors = Tensor::einsum("bmtd,mdhr->bmhtr", &[vectors, random_rotations]);
+        let rotated_vectors =
+            Tensor::einsum("bmtd,mdhr->bmhtr", &[vectors, random_rotations], None);

         let mut buckets = match &self.num_buckets {
             NumBuckets::Integer(_) => {
@@ -647,7 +652,8 @@
     }

     fn len_norm(&self, input_tensor: &Tensor, epsilon: f64) -> Tensor {
-        let variance = (input_tensor * input_tensor).mean_dim(&[-1], true, input_tensor.kind());
+        let variance =
+            (input_tensor * input_tensor).mean_dim([-1].as_slice(), true, input_tensor.kind());
         input_tensor * (variance + epsilon).rsqrt()
     }
@@ -903,9 +909,10 @@
             Some(self.attention_head_size),
         )?
         .unsqueeze(-1);
-        let probs_vectors = (&logits - &logits.logsumexp(&[2], true)).exp();
+        let probs_vectors = (&logits - &logits.logsumexp([2].as_slice(), true)).exp();
         let out_kind = out_vectors.kind();
-        out_vectors = (out_vectors * probs_vectors).sum_dim_intlist(&[2], false, out_kind);
+        out_vectors =
+            (out_vectors * probs_vectors).sum_dim_intlist([2].as_slice(), false, out_kind);
     }

     out_vectors = merge_hidden_size_dim(

@@ -192,7 +192,7 @@ impl T5Attention {
             None
         };

-        let mut scores = Tensor::einsum("bnqd,bnkd->bnqk", &[q, k]);
+        let mut scores = Tensor::einsum("bnqd,bnkd->bnqk", &[q, k], None);

         let calculated_position_bias = if position_bias.is_none() {
             let mut temp_value = if self.has_relative_attention_bias {

@@ -33,10 +33,11 @@ impl T5LayerNorm {
 impl Module for T5LayerNorm {
     fn forward(&self, x: &Tensor) -> Tensor {
         let input_type = x.kind();
-        let variance =
-            x.to_kind(Kind::Float)
-                .pow_tensor_scalar(2.0_f64)
-                .mean_dim(&[-1], true, Kind::Float);
+        let variance = x.to_kind(Kind::Float).pow_tensor_scalar(2.0_f64).mean_dim(
+            [-1].as_slice(),
+            true,
+            Kind::Float,
+        );
         let x = x * (variance + self.epsilon).rsqrt();
         if input_type != Kind::Float {
             (&self.weight * x).to_kind(input_type)

@@ -166,9 +166,17 @@ impl XLNetRelativeAttention {
         attention_mask: Option<&Tensor>,
         train: bool,
     ) -> (Tensor, Option<Tensor>) {
-        let ac = Tensor::einsum("ibnd,jbnd->bnij", &[&(q_head + &self.r_w_bias), k_head_h]);
+        let ac = Tensor::einsum(
+            "ibnd,jbnd->bnij",
+            &[&(q_head + &self.r_w_bias), k_head_h],
+            None,
+        );
         let bd = self.rel_shift_bnij(
-            &Tensor::einsum("ibnd,jbnd->bnij", &[&(q_head + &self.r_r_bias), k_head_r]),
+            &Tensor::einsum(
+                "ibnd,jbnd->bnij",
+                &[&(q_head + &self.r_r_bias), k_head_r],
+                None,
+            ),
             ac.size()[3],
         );
@@ -177,8 +185,9 @@
                 let ef = Tensor::einsum(
                     "ibnd,snd->ibns",
                     &[&(q_head + &self.r_s_bias), &self.seg_embed],
+                    None,
                 );
-                Tensor::einsum("ijbs,ibns->bnij", &[seg_mat, &ef])
+                Tensor::einsum("ijbs,ibns->bnij", &[seg_mat, &ef], None)
             }
             None => Tensor::zeros(&[1], (ac.kind(), ac.device())),
         };
@@ -193,7 +202,8 @@
             .softmax(3, attention_score.kind())
             .apply_t(&self.dropout, train);

-        let attention_vector = Tensor::einsum("bnij,jbnd->ibnd", &[&attention_probas, v_head_h]);
+        let attention_vector =
+            Tensor::einsum("bnij,jbnd->ibnd", &[&attention_probas, v_head_h], None);

         if self.output_attentions {
             (
@@ -212,8 +222,9 @@
         residual: bool,
         train: bool,
     ) -> Tensor {
-        let mut attention_out = Tensor::einsum("ibnd,hnd->ibh", &[attention_vector, &self.output])
-            .apply_t(&self.dropout, train);
+        let mut attention_out =
+            Tensor::einsum("ibnd,hnd->ibh", &[attention_vector, &self.output], None)
+                .apply_t(&self.dropout, train);
         if residual {
             attention_out = attention_out + h;
         };
@@ -245,10 +256,10 @@
             Some(value) => value,
             None => h,
         };
-        let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query]);
-        let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key]);
-        let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value]);
-        let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos]);
+        let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query], None);
+        let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key], None);
+        let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value], None);
+        let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos], None);

         let (attention_vec_h, attention_probas_h) = self.rel_attention_core(
             &q_head_h,
@@ -262,11 +273,12 @@
         let output_h = self.post_attention(h, &attention_vec_h, true, train);

         let (output_g, attention_probas_g) = if let Some(g) = g {
-            let q_head_g = Tensor::einsum("ibh,hnd->ibnd", &[g, &self.query]);
+            let q_head_g = Tensor::einsum("ibh,hnd->ibnd", &[g, &self.query], None);
             let (attention_vec_g, attention_probas_g) = match target_mapping {
                 Some(target_mapping) => {
-                    let q_head_g = Tensor::einsum("mbnd,mlb->lbnd", &[&q_head_g, target_mapping]);
+                    let q_head_g =
+                        Tensor::einsum("mbnd,mlb->lbnd", &[&q_head_g, target_mapping], None);
                     let (attention_vec_g, attention_probas_g) = self.rel_attention_core(
                         &q_head_g,
                         &k_head_h,
@@ -277,7 +289,7 @@
                         train,
                     );
                     let attention_vec_g =
-                        Tensor::einsum("lbnd,mlb->mbnd", &[&attention_vec_g, target_mapping]);
+                        Tensor::einsum("lbnd,mlb->mbnd", &[&attention_vec_g, target_mapping], None);
                     (attention_vec_g, attention_probas_g)
                 }
                 None => self.rel_attention_core(

@@ -311,7 +311,7 @@ impl XLNetModel {
         inverse_frequency: &Tensor,
         batch_size: Option<i64>,
     ) -> Tensor {
-        let sinusoid = Tensor::einsum("i,d->id", &[position_sequence, inverse_frequency]);
+        let sinusoid = Tensor::einsum("i,d->id", &[position_sequence, inverse_frequency], None);
         let mut positional_embeddings =
             Tensor::cat(&[sinusoid.sin(), sinusoid.cos()], -1).unsqueeze(1);

@@ -212,7 +212,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
     let candidate_labels = &["politics", "public health", "economy", "sports"];

     let output = sequence_classification_model.predict(
-        &[input_sentence, input_sequence_2],
+        [input_sentence, input_sequence_2],
         candidate_labels,
         Some(Box::new(|label: &str| {
             format!("This example is about {}.", label)
@@ -245,7 +245,7 @@ fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
     let candidate_labels = &["politics", "public health", "economy", "sports"];

     let output = sequence_classification_model.predict_multilabel(
-        &[input_sentence, input_sequence_2],
+        [input_sentence, input_sequence_2],
         candidate_labels,
         Some(Box::new(|label: &str| {
             format!("This example is about {}.", label)

@@ -26,7 +26,7 @@ fn distilbert_sentiment_classifier() -> anyhow::Result<()> {
         "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
     ];

-    let output = sentiment_classifier.predict(&input);
+    let output = sentiment_classifier.predict(input);

     assert_eq!(output.len(), 3usize);
     assert_eq!(output[0].polarity, SentimentPolarity::Positive);

@@ -109,7 +109,7 @@ fn fnet_for_sequence_classification() -> anyhow::Result<()> {
         "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
     ];

-    let output = sentiment_classifier.predict(&input);
+    let output = sentiment_classifier.predict(input);

     assert_eq!(output.len(), 3usize);
     assert_eq!(output[0].polarity, SentimentPolarity::Negative);