Update to tch=0.9.0 (#293)

* Fixed short sentence input and added documentation

* Fixed Clippy warnings

* Updated CI Python version

* Cleaner dim specification
guillaume-be 2022-11-07 17:45:52 +00:00 committed by GitHub
parent 340be36ed9
commit c6771d3992
29 changed files with 103 additions and 65 deletions
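The two API shifts behind most of the churn in this commit are visible in the hunks below: `Tensor::einsum` takes a new trailing argument (the optional contraction path; `None` keeps the previous behaviour), and reduction dims for `sum_dim_intlist` / `mean_dim` / `logsumexp` are now passed as explicit slices, hence the recurring `[1].as_slice()` pattern. A minimal sketch of both patterns against tch 0.9 (hypothetical tensors, not code from the repository):

use tch::{Device, Kind, Tensor};

fn main() {
    let a = Tensor::rand(&[2, 3], (Kind::Float, Device::Cpu));
    let b = Tensor::rand(&[3, 4], (Kind::Float, Device::Cpu));
    // New in tch 0.9: einsum takes an optional contraction path; None keeps the old behaviour.
    let product = Tensor::einsum("ij,jk->ik", &[&a, &b], None);
    // Dim lists are passed as slices; fixed-size array references such as &[1] no longer coerce.
    let row_sums = product.sum_dim_intlist([1].as_slice(), false, Kind::Float);
    row_sums.print();
}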

View File

@@ -128,7 +128,7 @@ jobs:
override: true
- uses: actions/setup-python@v2
with:
python-version: '3.7'
python-version: '3.10'
- run: |
pip install -r requirements.txt --progress-bar off
python ./utils/download-dependencies_distilbert.py

View File

@@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file. The format
- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`)
- (BREAKING) `merges_resource` now optional for all pipelines
- Allow mixing local and remote resources in pipelines
- Upgraded to `torch` 1.13 (via `tch` 0.9.0)
## Fixed
- Fixed the configuration check for RoBERTa models used for sentence classification.

View File

@@ -64,7 +64,7 @@ features = ["doc-only"]
[dependencies]
rust_tokenizers = "~7.0.2"
tch = "~0.8.0"
tch = "~0.9.0"
serde_json = "1.0.82"
serde = { version = "1.0.140", features = ["derive"] }
ordered-float = "3.0.0"
@@ -81,6 +81,6 @@ anyhow = "1.0.58"
csv = "1.1.6"
criterion = "0.3.6"
tokio = { version = "1.20.0", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "~0.8.0"
torch-sys = "~0.9.0"
tempfile = "3.3.0"
itertools = "0.10.3"

View File

@@ -75,8 +75,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
### Manual installation (recommended)
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.12.0`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu116/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu116.zip` for a Linux version with CUDA11.
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.13.0`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip` for a Linux version with CUDA11.
2. Extract the library to a location of your choice
3. Set the following environment variables
##### Linux:
@@ -94,7 +94,7 @@ $Env:Path += ";X:\path\to\libtorch\lib"
### Automatic installation
Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu113`.
The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu117`.
Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
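Once `libtorch` v1.13 is in place (manually or via the automatic download), a quick linkage check is a one-tensor smoke test; this is a sketch for verification only, not part of the README:

use tch::Tensor;

fn main() {
    // Prints the tensor if libtorch 1.13 was found and linked correctly.
    let t = Tensor::of_slice(&[1i64, 2, 3]);
    t.print();
}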
## Ready-to-use pipelines

View File

@@ -26,7 +26,7 @@ fn main() -> anyhow::Result<()> {
];
// Run model
let output = sentiment_classifier.predict(&input);
let output = sentiment_classifier.predict(input);
for sentiment in output {
println!("{:?}", sentiment);
}
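The example updates in this commit change `predict(&input)` to `predict(input)`: the input array is now passed by value. A condensed sketch of the new call pattern (default model configuration assumed):

use rust_bert::pipelines::sentiment::SentimentModel;

fn main() -> anyhow::Result<()> {
    let sentiment_classifier = SentimentModel::new(Default::default())?;
    // As of this commit, the array is passed directly, without the leading `&`.
    let input = ["This is a great movie."];
    for sentiment in sentiment_classifier.predict(input) {
        println!("{:?}", sentiment);
    }
    Ok(())
}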

View File

@@ -47,7 +47,7 @@ fn main() -> anyhow::Result<()> {
];
// Run model
let output = sentiment_classifier.predict(&input);
let output = sentiment_classifier.predict(input);
for sentiment in output {
println!("{:?}", sentiment);
}

View File

@@ -26,7 +26,7 @@ fn main() -> anyhow::Result<()> {
];
// Run model
let output = sequence_classification_model.predict(&input);
let output = sequence_classification_model.predict(input);
for label in output {
println!("{:?}", label);
}

View File

@@ -23,7 +23,7 @@ fn main() -> anyhow::Result<()> {
let candidate_labels = &["politics", "public health", "economy", "sports"];
let output = sequence_classification_model.predict_multilabel(
&[input_sentence, input_sequence_2],
[input_sentence, input_sequence_2],
candidate_labels,
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)

View File

@@ -1,2 +1,3 @@
torch == 1.8.1
requests == 2.25.1
torch == 1.13.0
requests == 2.25.1
numpy == 1.23.4

View File

@@ -131,7 +131,7 @@ impl AlbertSelfAttention {
));
let context: Tensor =
Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + self.dense.bs.as_ref().unwrap();
Tensor::einsum("bfnd,ndh->bfh", &[context, w], None) + self.dense.bs.as_ref().unwrap();
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
if !self.output_attentions {

View File

@@ -351,10 +351,11 @@ pub(crate) fn _prepare_decoder_attention_mask(
}
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
let index_eos: Tensor = input_ids
.ne(pad_token_id)
.sum_dim_intlist(&[-1], true, Kind::Int64)
- 1;
let index_eos: Tensor =
input_ids
.ne(pad_token_id)
.sum_dim_intlist([-1].as_slice(), true, Kind::Int64)
- 1;
let output = input_ids.empty_like().to_kind(Kind::Int64);
output
.select(1, 0)
@@ -857,7 +858,7 @@ impl BartForSequenceClassification {
train,
);
let eos_mask = input_ids.eq(self.eos_token_id);
let reshape = eos_mask.sum_dim_intlist(&[1], true, input_ids.kind());
let reshape = eos_mask.sum_dim_intlist([1].as_slice(), true, input_ids.kind());
let sentence_representation = base_model_output
.decoder_output
.permute(&[2, 0, 1])

View File

@@ -132,7 +132,10 @@ impl SequenceSummary {
let mut output = match self.summary_type {
SummaryType::last => hidden_states.select(1, -1),
SummaryType::first => hidden_states.select(1, 0),
SummaryType::mean => hidden_states.mean_dim(&[1], false, hidden_states.kind()),
SummaryType::mean => {
hidden_states.mean_dim([1].as_slice(), false, hidden_states.kind())
}
SummaryType::cls_index => {
let cls_index = if let Some(cls_index_value) = cls_index {
let mut expand_dim = vec![-1i64; cls_index_value.dim() - 1];

View File

@@ -303,9 +303,9 @@ impl Module for DebertaLayerNorm {
fn forward(&self, hidden_states: &Tensor) -> Tensor {
let input_type = hidden_states.kind();
let hidden_states = hidden_states.to_kind(Kind::Float);
let mean = hidden_states.mean_dim(&[-1], true, hidden_states.kind());
let mean = hidden_states.mean_dim([-1].as_slice(), true, hidden_states.kind());
let variance = (&hidden_states - &mean).pow_tensor_scalar(2.0).mean_dim(
&[-1],
[-1].as_slice(),
true,
hidden_states.kind(),
);
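The hunk stops at the statistics; for readers following along, the normalization these feed into is the standard layer-norm definition, sketched below with hypothetical variable names (not a quote of the file):

use tch::Tensor;

// Definitional layer-norm step given the mean and variance computed above
// (epsilon stands in for the layer's variance epsilon).
fn layer_norm_step(hidden_states: &Tensor, mean: &Tensor, variance: &Tensor, epsilon: f64) -> Tensor {
    (hidden_states - mean) / (variance + epsilon).sqrt()
}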

View File

@@ -291,7 +291,7 @@ impl DebertaV2Encoder {
attention_mask.shallow_clone()
} else {
attention_mask
.sum_dim_intlist(&[-2], false, attention_mask.kind())
.sum_dim_intlist([-2].as_slice(), false, attention_mask.kind())
.gt(0)
.to_kind(Kind::Uint8)
};

View File

@@ -85,8 +85,8 @@
//!
//! ### Manual installation (recommended)
//!
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.12.0`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu116/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu116.zip` for a Linux version with CUDA11.
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.13.0`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip` for a Linux version with CUDA11.
//! 2. Extract the library to a location of your choice
//! 3. Set the following environment variables
//! ##### Linux:
@@ -104,7 +104,7 @@
//! ### Automatic installation
//!
//! Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu116`.
//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu117`.
//! Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
//!
//! # Ready-to-use pipelines

View File

@@ -223,7 +223,7 @@ impl LongformerSelfAttention {
let key = self.chunk(&key, window_overlap);
let diagonal_chunked_attention_scores = self.pad_and_transpose_last_two_dims(
&Tensor::einsum("bcxd,bcyd->bcxy", &[query, key]),
&Tensor::einsum("bcxd,bcyd->bcxy", &[query, key], None),
&[0, 0, 0, 1],
);
@@ -353,6 +353,7 @@ impl LongformerSelfAttention {
Tensor::einsum(
"bcwd,bcdh->bcwh",
&[chunked_attention_probas, chunked_value],
None,
)
.view([batch_size, num_heads, sequence_length, head_dim])
.transpose(1, 2)
@@ -363,7 +364,7 @@ impl LongformerSelfAttention {
is_index_global_attn: &Tensor,
) -> GlobalAttentionIndices {
let num_global_attention_indices =
is_index_global_attn.sum_dim_intlist(&[1], false, Kind::Int64);
is_index_global_attn.sum_dim_intlist([1].as_slice(), false, Kind::Int64);
let max_num_global_attention_indices = i64::from(num_global_attention_indices.max());
let is_index_global_attn_nonzero = is_index_global_attn
.nonzero_numpy()
@@ -428,6 +429,7 @@ impl LongformerSelfAttention {
let attention_probas_from_global_key = Tensor::einsum(
"blhd,bshd->blhs",
&[query_vectors, &key_vectors_only_global],
None,
);
let _ = attention_probas_from_global_key

View File

@@ -160,7 +160,7 @@ fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
let output = input_ids.masked_fill(&input_ids.eq(-100), pad_token_id);
let index_eos: Tensor = input_ids
.ne(pad_token_id)
.sum_dim_intlist(&[1], true, Int64)
.sum_dim_intlist([1].as_slice(), true, Int64)
- 1;
output
.select(1, 0)
@@ -682,7 +682,7 @@ impl MBartForSequenceClassification {
train,
);
let eos_mask = input_ids.eq(self.eos_token_id);
let reshape = eos_mask.sum_dim_intlist(&[1], true, Int64);
let reshape = eos_mask.sum_dim_intlist([1].as_slice(), true, Int64);
let sentence_representation = base_model_output
.decoder_output
.permute(&[2, 0, 1])

View File

@@ -932,8 +932,11 @@ pub(crate) mod private_generation_utils {
current_length += 1;
}
let scores_output = token_scores_output.as_ref().map(|scores_tensor| {
(Tensor::stack(scores_tensor, 1).sum_dim_intlist(&[1], false, Kind::Float)
/ sentence_lengths.pow_tensor_scalar(gen_opt.length_penalty))
(Tensor::stack(scores_tensor, 1).sum_dim_intlist(
[1].as_slice(),
false,
Kind::Float,
) / sentence_lengths.pow_tensor_scalar(gen_opt.length_penalty))
.iter::<f64>()
.unwrap()
.collect::<Vec<f64>>()

View File

@@ -54,10 +54,12 @@ impl Pooling {
if self.conf.pooling_mode_mean_tokens || self.conf.pooling_mode_mean_sqrt_len_tokens {
let input_mask_expanded = attention_mask.unsqueeze(-1).expand_as(&token_embeddings);
let sum_embeddings =
(token_embeddings * &input_mask_expanded).sum_dim_intlist(&[1], false, Kind::Float);
let sum_mask = input_mask_expanded.sum_dim_intlist(&[1], false, Kind::Float);
let sum_embeddings = (token_embeddings * &input_mask_expanded).sum_dim_intlist(
[1].as_slice(),
false,
Kind::Float,
);
let sum_mask = input_mask_expanded.sum_dim_intlist([1].as_slice(), false, Kind::Float);
let sum_mask = sum_mask.clamp_min(10e-9);
if self.conf.pooling_mode_mean_tokens {
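For readers tracking what this pipeline computes: the branch above is masked mean pooling over token embeddings. Pulled out as a standalone sketch with the 0.9 slice-based reductions (hypothetical free function; shapes assumed [batch, seq, hidden] and [batch, seq]):

use tch::{Kind, Tensor};

fn masked_mean_pool(token_embeddings: &Tensor, attention_mask: &Tensor) -> Tensor {
    let mask = attention_mask.unsqueeze(-1).expand_as(token_embeddings);
    // Sum the embeddings where the mask is set, then divide by the token counts.
    let summed = (token_embeddings * &mask).sum_dim_intlist([1].as_slice(), false, Kind::Float);
    let counts = mask.sum_dim_intlist([1].as_slice(), false, Kind::Float).clamp_min(10e-9);
    summed / counts
}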

View File

@@ -891,7 +891,10 @@ impl TokenClassificationModel {
None,
false,
);
let score = output.exp() / output.exp().sum_dim_intlist(&[-1], true, Kind::Float);
let score = output.exp()
/ output
.exp()
.sum_dim_intlist([-1].as_slice(), true, Kind::Float);
let label_indices = score.argmax(-1, true);
for sentence_idx in 0..label_indices.size()[0] {
let labels = label_indices.get(sentence_idx);
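Aside: the reflowed `exp() / sum(exp())` expression is a softmax over the last dimension written out by hand; a hypothetical equivalent using the built-in kernel (not part of this commit) would be:

use tch::{Kind, Tensor};

// Equivalent of the hand-written exp / sum(exp) normalization above.
fn label_scores(output: &Tensor) -> Tensor {
    output.softmax(-1, Kind::Float)
}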

View File

@@ -531,6 +531,7 @@ impl ProphetNetNgramAttention {
let predict_attention_weights = Tensor::einsum(
"nbtc,nbsc->nbts",
&[predict_query_states, predict_key_states],
None,
);
let predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
@@ -554,6 +555,7 @@ impl ProphetNetNgramAttention {
let predict_attention_output = Tensor::einsum(
"nbts,nbsc->nbtc",
&[&predict_attention_probas, &predict_value_states],
None,
)
.transpose(1, 2)
.contiguous()

View File

@@ -253,7 +253,11 @@ impl LSHSelfAttention {
self.hidden_size,
])
.transpose(-2, -1);
Tensor::einsum("balh,ahr->balr", &[hidden_states, &per_head_query_key])
Tensor::einsum(
"balh,ahr->balr",
&[hidden_states, &per_head_query_key],
None,
)
}
fn value_per_attention_head(&self, hidden_states: &Tensor) -> Tensor {
@@ -266,7 +270,7 @@ impl LSHSelfAttention {
self.hidden_size,
])
.transpose(-2, -1);
Tensor::einsum("balh,ahr->balr", &[hidden_states, &per_head_value])
Tensor::einsum("balh,ahr->balr", &[hidden_states, &per_head_value], None)
}
fn hash_vectors(
@@ -304,7 +308,8 @@ impl LSHSelfAttention {
rotation_size / 2,
];
let random_rotations = Tensor::randn(&rotations_shape, (vectors.kind(), vectors.device()));
let rotated_vectors = Tensor::einsum("bmtd,mdhr->bmhtr", &[vectors, random_rotations]);
let rotated_vectors =
Tensor::einsum("bmtd,mdhr->bmhtr", &[vectors, random_rotations], None);
let mut buckets = match &self.num_buckets {
NumBuckets::Integer(_) => {
@@ -647,7 +652,8 @@ impl LSHSelfAttention {
}
fn len_norm(&self, input_tensor: &Tensor, epsilon: f64) -> Tensor {
let variance = (input_tensor * input_tensor).mean_dim(&[-1], true, input_tensor.kind());
let variance =
(input_tensor * input_tensor).mean_dim([-1].as_slice(), true, input_tensor.kind());
input_tensor * (variance + epsilon).rsqrt()
}
@@ -903,9 +909,10 @@ impl LSHSelfAttention {
Some(self.attention_head_size),
)?
.unsqueeze(-1);
let probs_vectors = (&logits - &logits.logsumexp(&[2], true)).exp();
let probs_vectors = (&logits - &logits.logsumexp([2].as_slice(), true)).exp();
let out_kind = out_vectors.kind();
out_vectors = (out_vectors * probs_vectors).sum_dim_intlist(&[2], false, out_kind);
out_vectors =
(out_vectors * probs_vectors).sum_dim_intlist([2].as_slice(), false, out_kind);
}
out_vectors = merge_hidden_size_dim(

View File

@@ -192,7 +192,7 @@ impl T5Attention {
None
};
let mut scores = Tensor::einsum("bnqd,bnkd->bnqk", &[q, k]);
let mut scores = Tensor::einsum("bnqd,bnkd->bnqk", &[q, k], None);
let calculated_position_bias = if position_bias.is_none() {
let mut temp_value = if self.has_relative_attention_bias {

View File

@@ -33,10 +33,11 @@ impl T5LayerNorm {
impl Module for T5LayerNorm {
fn forward(&self, x: &Tensor) -> Tensor {
let input_type = x.kind();
let variance =
x.to_kind(Kind::Float)
.pow_tensor_scalar(2.0_f64)
.mean_dim(&[-1], true, Kind::Float);
let variance = x.to_kind(Kind::Float).pow_tensor_scalar(2.0_f64).mean_dim(
[-1].as_slice(),
true,
Kind::Float,
);
let x = x * (variance + self.epsilon).rsqrt();
if input_type != Kind::Float {
(&self.weight * x).to_kind(input_type)
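Worth noting while this code is being touched: `T5LayerNorm` is RMS-style normalization, scaling by the root mean square with no mean subtraction. A condensed sketch of the forward pass above (kind casting elided):

use tch::{Kind, Tensor};

fn t5_rms_norm(x: &Tensor, weight: &Tensor, epsilon: f64) -> Tensor {
    // Mean of squares over the last dimension, then scale by its inverse square root.
    let variance = x.pow_tensor_scalar(2.0).mean_dim([-1].as_slice(), true, Kind::Float);
    weight * (x * (variance + epsilon).rsqrt())
}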

View File

@@ -166,9 +166,17 @@ impl XLNetRelativeAttention {
attention_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let ac = Tensor::einsum("ibnd,jbnd->bnij", &[&(q_head + &self.r_w_bias), k_head_h]);
let ac = Tensor::einsum(
"ibnd,jbnd->bnij",
&[&(q_head + &self.r_w_bias), k_head_h],
None,
);
let bd = self.rel_shift_bnij(
&Tensor::einsum("ibnd,jbnd->bnij", &[&(q_head + &self.r_r_bias), k_head_r]),
&Tensor::einsum(
"ibnd,jbnd->bnij",
&[&(q_head + &self.r_r_bias), k_head_r],
None,
),
ac.size()[3],
);
@@ -177,8 +185,9 @@
let ef = Tensor::einsum(
"ibnd,snd->ibns",
&[&(q_head + &self.r_s_bias), &self.seg_embed],
None,
);
Tensor::einsum("ijbs,ibns->bnij", &[seg_mat, &ef])
Tensor::einsum("ijbs,ibns->bnij", &[seg_mat, &ef], None)
}
None => Tensor::zeros(&[1], (ac.kind(), ac.device())),
};
@@ -193,7 +202,8 @@
.softmax(3, attention_score.kind())
.apply_t(&self.dropout, train);
let attention_vector = Tensor::einsum("bnij,jbnd->ibnd", &[&attention_probas, v_head_h]);
let attention_vector =
Tensor::einsum("bnij,jbnd->ibnd", &[&attention_probas, v_head_h], None);
if self.output_attentions {
(
@@ -212,8 +222,9 @@
residual: bool,
train: bool,
) -> Tensor {
let mut attention_out = Tensor::einsum("ibnd,hnd->ibh", &[attention_vector, &self.output])
.apply_t(&self.dropout, train);
let mut attention_out =
Tensor::einsum("ibnd,hnd->ibh", &[attention_vector, &self.output], None)
.apply_t(&self.dropout, train);
if residual {
attention_out = attention_out + h;
};
@@ -245,10 +256,10 @@
Some(value) => value,
None => h,
};
let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query]);
let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key]);
let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value]);
let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos]);
let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query], None);
let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key], None);
let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value], None);
let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos], None);
let (attention_vec_h, attention_probas_h) = self.rel_attention_core(
&q_head_h,
@@ -262,11 +273,12 @@
let output_h = self.post_attention(h, &attention_vec_h, true, train);
let (output_g, attention_probas_g) = if let Some(g) = g {
let q_head_g = Tensor::einsum("ibh,hnd->ibnd", &[g, &self.query]);
let q_head_g = Tensor::einsum("ibh,hnd->ibnd", &[g, &self.query], None);
let (attention_vec_g, attention_probas_g) = match target_mapping {
Some(target_mapping) => {
let q_head_g = Tensor::einsum("mbnd,mlb->lbnd", &[&q_head_g, target_mapping]);
let q_head_g =
Tensor::einsum("mbnd,mlb->lbnd", &[&q_head_g, target_mapping], None);
let (attention_vec_g, attention_probas_g) = self.rel_attention_core(
&q_head_g,
&k_head_h,
@@ -277,7 +289,7 @@
train,
);
let attention_vec_g =
Tensor::einsum("lbnd,mlb->mbnd", &[&attention_vec_g, target_mapping]);
Tensor::einsum("lbnd,mlb->mbnd", &[&attention_vec_g, target_mapping], None);
(attention_vec_g, attention_probas_g)
}
None => self.rel_attention_core(

View File

@@ -311,7 +311,7 @@ impl XLNetModel {
inverse_frequency: &Tensor,
batch_size: Option<i64>,
) -> Tensor {
let sinusoid = Tensor::einsum("i,d->id", &[position_sequence, inverse_frequency]);
let sinusoid = Tensor::einsum("i,d->id", &[position_sequence, inverse_frequency], None);
let mut positional_embeddings =
Tensor::cat(&[sinusoid.sin(), sinusoid.cos()], -1).unsqueeze(1);
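The einsum in this hunk is an outer product between the position sequence and the inverse frequencies, from which the sin/cos table is concatenated. A self-contained sketch under the 0.9 signature (hypothetical sizes and placeholder frequencies):

use tch::{Device, Kind, Tensor};

fn main() {
    let position_sequence = Tensor::arange(16, (Kind::Float, Device::Cpu));
    // Placeholder decaying frequencies, for illustration only.
    let inverse_frequency = Tensor::arange(8, (Kind::Float, Device::Cpu)).neg().exp();
    // Outer product: sinusoid[i][d] = position[i] * inverse_frequency[d].
    let sinusoid = Tensor::einsum("i,d->id", &[&position_sequence, &inverse_frequency], None);
    let positional_embeddings = Tensor::cat(&[sinusoid.sin(), sinusoid.cos()], -1);
    println!("{:?}", positional_embeddings.size()); // [16, 16]
}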

View File

@@ -212,7 +212,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
let candidate_labels = &["politics", "public health", "economy", "sports"];
let output = sequence_classification_model.predict(
&[input_sentence, input_sequence_2],
[input_sentence, input_sequence_2],
candidate_labels,
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)
@@ -245,7 +245,7 @@ fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
let candidate_labels = &["politics", "public health", "economy", "sports"];
let output = sequence_classification_model.predict_multilabel(
&[input_sentence, input_sequence_2],
[input_sentence, input_sequence_2],
candidate_labels,
Some(Box::new(|label: &str| {
format!("This example is about {}.", label)

View File

@@ -26,7 +26,7 @@ fn distilbert_sentiment_classifier() -> anyhow::Result<()> {
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
let output = sentiment_classifier.predict(&input);
let output = sentiment_classifier.predict(input);
assert_eq!(output.len(), 3usize);
assert_eq!(output[0].polarity, SentimentPolarity::Positive);

View File

@@ -109,7 +109,7 @@ fn fnet_for_sequence_classification() -> anyhow::Result<()> {
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
let output = sentiment_classifier.predict(&input);
let output = sentiment_classifier.predict(input);
assert_eq!(output.len(), 3usize);
assert_eq!(output[0].polarity, SentimentPolarity::Negative);