diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0143edf..c9d5766 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,9 @@ All notable changes to this project will be documented in this file. The format
## [Unreleased]
+## Changed
+- (BREAKING) Upgraded to `torch` 2.2 (via `tch` 0.15.0).
+
## [0.22.0] - 2024-01-20
## Added
- Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines.
@@ -10,7 +13,7 @@ All notable changes to this project will be documented in this file. The format
- (BREAKING) Most model configuration can now take an optional `kind` parameter to specify the model weight precision. If not provided, will default to full precision on CPU, or the serialized weights precision otherwise.
## Fixed
-- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering m-grams spanning multiple sentences).
+- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering n-grams spanning multiple sentences).
- Improved MPS device compatibility setting the `sparse_grad` flag to false for `gather` operations
- Updated ONNX runtime backend version to 1.15.x
- Issue with incorrect results for QA models with a tokenizer not using segment ids
@@ -449,4 +452,4 @@ All notable changes to this project will be documented in this file. The format
- Tensor conversion tools from Pytorch to Libtorch format
- DistilBERT model architecture
-- Ready-to-use `SentimentClassifier` using a DistilBERT model fine-tuned on SST2
\ No newline at end of file
+- Ready-to-use `SentimentClassifier` using a DistilBERT model fine-tuned on SST2
diff --git a/Cargo.toml b/Cargo.toml
index 4bf0baf..273ce3a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -76,7 +76,7 @@ features = ["doc-only"]
[dependencies]
rust_tokenizers = "8.1.1"
-tch = "0.14.0"
+tch = "0.15.0"
serde_json = "1"
serde = { version = "1", features = ["derive"] }
ordered-float = "3"
@@ -97,7 +97,6 @@ anyhow = "1"
csv = "1"
criterion = "0.5"
tokio = { version = "1.35", features = ["sync", "rt-multi-thread", "macros"] }
-torch-sys = "0.14.0"
tempfile = "3"
itertools = "0.12"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
diff --git a/README.md b/README.md
index 2fb596c..0d28157 100644
--- a/README.md
+++ b/README.md
@@ -80,8 +80,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
### Manual installation (recommended)
-1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.1`: if this version is no longer available on the "get started" page,
-the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
+1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.2`: if this version is no longer available on the "get started" page,
+the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip` for a Linux version with CUDA12. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
2. Extract the library to a location of your choice
3. Set the following environment variables
##### Linux:
diff --git a/benches/generation_benchmark.rs b/benches/generation_benchmark.rs
index e54a971..f6e1f81 100644
--- a/benches/generation_benchmark.rs
+++ b/benches/generation_benchmark.rs
@@ -53,10 +53,6 @@ fn generation_forward_pass(iters: u64, model: &TextGenerationModel, data: &[&str
}
fn bench_generation(c: &mut Criterion) {
- // Set-up summarization model
- unsafe {
- torch_sys::dummy_cuda_dependency();
- }
let model = create_text_generation_model();
// Define input
diff --git a/benches/squad_benchmark.rs b/benches/squad_benchmark.rs
index 5be9f20..5d8b847 100644
--- a/benches/squad_benchmark.rs
+++ b/benches/squad_benchmark.rs
@@ -73,9 +73,7 @@ fn qa_load_model(iters: u64) -> Duration {
fn bench_squad(c: &mut Criterion) {
// Set-up QA model
let model = create_qa_model();
- unsafe {
- torch_sys::dummy_cuda_dependency();
- }
+
// Define input
let mut squad_path = PathBuf::from(env::var("squad_dataset")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
diff --git a/benches/sst2_benchmark.rs b/benches/sst2_benchmark.rs
index 465618a..37e982d 100644
--- a/benches/sst2_benchmark.rs
+++ b/benches/sst2_benchmark.rs
@@ -79,9 +79,7 @@ fn sst2_load_model(iters: u64) -> Duration {
fn bench_sst2(c: &mut Criterion) {
// Set-up classifier
let model = create_sentiment_model();
- unsafe {
- torch_sys::dummy_cuda_dependency();
- }
+
// Define input
let mut sst2_path = PathBuf::from(env::var("SST2_PATH").expect(
"Please set the \"SST2_PATH\" environment variable pointing to the SST2 dataset folder",
diff --git a/benches/summarization_benchmark.rs b/benches/summarization_benchmark.rs
index e80f880..fc2509e 100644
--- a/benches/summarization_benchmark.rs
+++ b/benches/summarization_benchmark.rs
@@ -40,9 +40,6 @@ fn summarization_load_model(iters: u64) -> Duration {
fn bench_squad(c: &mut Criterion) {
// Set-up summarization model
- unsafe {
- torch_sys::dummy_cuda_dependency();
- }
let model = create_summarization_model();
// Define input
diff --git a/benches/tensor_operations_benchmark.rs b/benches/tensor_operations_benchmark.rs
index 6c47acb..979852f 100644
--- a/benches/tensor_operations_benchmark.rs
+++ b/benches/tensor_operations_benchmark.rs
@@ -17,10 +17,6 @@ fn matrix_multiply(iters: u64, input: &Tensor, weights: &Tensor) -> Duration {
}
fn bench_tensor_ops(c: &mut Criterion) {
- // Set-up summarization model
- unsafe {
- torch_sys::dummy_cuda_dependency();
- }
let input = Tensor::rand([32, 128, 512], (Kind::Float, Device::cuda_if_available()));
let weights = Tensor::rand([512, 512], (Kind::Float, Device::cuda_if_available()));
diff --git a/benches/token_classification_benchmark.rs b/benches/token_classification_benchmark.rs
index 955a73e..2e3729c 100644
--- a/benches/token_classification_benchmark.rs
+++ b/benches/token_classification_benchmark.rs
@@ -14,9 +14,6 @@ fn create_model() -> TokenClassificationModel {
fn bench_token_classification_predict(c: &mut Criterion) {
// Set-up model
- unsafe {
- torch_sys::dummy_cuda_dependency();
- }
let model = create_model();
// Define input
diff --git a/benches/translation_benchmark.rs b/benches/translation_benchmark.rs
index b2c4d08..3ce2f4f 100644
--- a/benches/translation_benchmark.rs
+++ b/benches/translation_benchmark.rs
@@ -73,9 +73,6 @@ fn translation_load_model(iters: u64) -> Duration {
fn bench_squad(c: &mut Criterion) {
// Set-up translation model
- unsafe {
- torch_sys::dummy_cuda_dependency();
- }
let model = create_translation_model();
// Define input
diff --git a/src/lib.rs b/src/lib.rs
index e73b16b..45f7ef0 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -90,8 +90,8 @@
//!
//! ### Manual installation (recommended)
//!
-//! 1. Download `libtorch` from . This package requires `v2.1`: if this version is no longer available on the "get started" page,
-//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11.
+//! 1. Download `libtorch` from . This package requires `v2.2`: if this version is no longer available on the "get started" page,
+//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip` for a Linux version with CUDA12.
//! 2. Extract the library to a location of your choice
//! 3. Set the following environment variables
//! ##### Linux: