updated tch-rs to 0.13.0 (#380)

* updated tch-rs to 0.13.0
find-replaced `of_slice` with `from_slice` throughout, as per
008fff6cc0/CHANGELOG.md

* fixed formatting

* Add download feature and update CI

* add build script, update CI

* updated changelog, readme, convert script

* fixed wrong position for build script

* added libtorch download to dependencies download test script

* reordered args

---------

Co-authored-by: josephhajduk <joseph@solidys.dev>
Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
Joseph Hajduk 2023-05-21 12:41:18 -05:00 committed by GitHub
parent 5f9500c54a
commit 2bff63b2ee
50 changed files with 178 additions and 128 deletions

View File

@ -20,6 +20,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: build
args: --features download-libtorch
build-no-defaults:
name: Build no defaults
@ -34,7 +35,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: build
args: --no-default-features
args: --no-default-features --features download-libtorch
build-windows:
name: Build Windows
@ -49,6 +50,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: build
args: --features download-libtorch
build-mac-os:
name: Build macOS
@ -63,6 +65,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: build
args: --features download-libtorch
test-batch-0:
name: Integration tests (batch 0)
@ -89,6 +92,7 @@ jobs:
--test fnet
--test deberta
--test deberta_v2
--features download-libtorch
test-batch-1:
name: Integration tests (batch 1)
@ -114,6 +118,7 @@ jobs:
--test longformer
--test pegasus
--test gpt_neo
--features download-libtorch
test-batch-2:
name: Integration tests (batch 2)
@ -133,6 +138,7 @@ jobs:
--test longt5
--test gpt_j
--test nllb
--features download-libtorch
convert-model:
name: Model conversion test

View File

@ -11,7 +11,7 @@ All notable changes to this project will be documented in this file. The format
## Changed
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
- (BREAKING) Simplified the generation traits (removal of LMHeadModel and elimination of unnecessary specification for LanguageGenerator)
- Upgraded to `torch` 2.0 (via `tch` 0.12.0).
- (BREAKING) Upgraded to `torch` 2.0 (via `tch` 0.13.0). The process to automatically download the dependencies has changed; it must now be enabled via the `download-libtorch` feature flag.
## Fixed
- MIN/MAX computation for float-like (was set to infinity instead of min/max)

View File

@ -8,6 +8,7 @@ repository = "https://github.com/guillaume-be/rust-bert"
documentation = "https://docs.rs/rust-bert"
license = "Apache-2.0"
readme = "README.md"
build = "build.rs"
keywords = [
"nlp",
"deep-learning",
@ -64,13 +65,14 @@ default = ["remote"]
doc-only = ["tch/doc-only"]
all-tests = []
remote = ["cached-path", "dirs", "lazy_static"]
download-libtorch = ["torch-sys/download-libtorch"]
[package.metadata.docs.rs]
features = ["doc-only"]
[dependencies]
rust_tokenizers = "8.1"
tch = "0.12.0"
tch = "0.13.0"
serde_json = "1"
serde = { version = "1", features = ["derive"] }
ordered-float = "3"
@ -88,6 +90,6 @@ anyhow = "1"
csv = "1"
criterion = "0.4"
tokio = { version = "1.24", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "0.12.0"
torch-sys = "0.13.0"
tempfile = "3"
itertools = "0.10"

View File

@ -95,7 +95,7 @@ $Env:Path += ";X:\path\to\libtorch\lib"
### Automatic installation
Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
Alternatively, you can let the `build` script automatically download the `libtorch` library for you. The `download-libtorch` feature flag needs to be enabled.
The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu118`.
Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
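
In practice, enabling the download means passing the flag at build time, e.g. `cargo build --features download-libtorch`; this is the same argument the updated CI jobs above pass to `actions-rs/cargo`.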

build.rs (new file, 29 lines)
View File

@ -0,0 +1,29 @@
// Copyright 2023 Laurent Mazare
// https://github.com/LaurentMazare/diffusers-rs/blob/main/build.rs
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
fn main() {
let os = std::env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
match os.as_str() {
"linux" | "windows" => {
if let Some(lib_path) = std::env::var_os("DEP_TCH_LIBTORCH_LIB") {
println!(
"cargo:rustc-link-arg=-Wl,-rpath={}",
lib_path.to_string_lossy()
);
}
println!("cargo:rustc-link-arg=-Wl,--no-as-needed");
println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries");
println!("cargo:rustc-link-arg=-ltorch");
}
_ => {}
}
}
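
A note on the linker flags above (adapted, per the header, from the diffusers-rs build script): the rpath points at the libtorch directory that `torch-sys` exports through `DEP_TCH_LIBTORCH_LIB`, so the downloaded shared libraries can be found at run time, while `--no-as-needed` and `--copy-dt-needed-entries` presumably keep the transitive libtorch libraries linked even when no symbol is referenced directly.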

View File

@ -63,7 +63,7 @@ fn main() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
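
The hunk above is representative of most of this commit: token-id vectors are padded to a common length, converted with the renamed constructor (`Tensor::of_slice` becomes `Tensor::from_slice`, with an unchanged signature), and stacked into a batch. A self-contained sketch of the pattern under tch 0.13; the input values and `Device::Cpu` are illustrative, not taken from the diff:

use tch::{Device, Tensor};

fn main() {
    let device = Device::Cpu;
    let tokenized_input: Vec<Vec<i64>> = vec![vec![101, 2023, 102], vec![101, 102]];
    let max_len = tokenized_input.iter().map(Vec::len).max().unwrap();

    let tensors = tokenized_input
        .into_iter()
        .map(|mut input| {
            // Pad each sequence with zeros up to the batch maximum.
            input.extend(vec![0; max_len - input.len()]);
            // tch 0.13: `from_slice` replaces the removed `of_slice`.
            Tensor::from_slice(&input)
        })
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tensors.as_slice(), 0).to(device);
    assert_eq!(input_tensor.size(), vec![2, max_len as i64]);
}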

View File

@ -1032,7 +1032,7 @@ impl BartGenerator {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
}
}
@ -1207,7 +1207,7 @@ impl PrivateLanguageGenerator for BartGenerator {
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
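
This `impossible_tokens` construction recurs in the M2M100, Marian, MBart, and Pegasus generators below. Condensed into a standalone sketch; the `vocab_size` and `allowed` parameters are illustrative names, not from the diff:

use tch::Tensor;

// Set the score of every vocabulary id outside `allowed` to -inf, so the
// sampler can only pick allowed tokens. `scores` is [batch, vocab_size].
fn mask_disallowed(scores: &mut Tensor, vocab_size: i64, allowed: &[i64]) {
    let impossible: Vec<i64> = (0..vocab_size)
        .filter(|id| !allowed.contains(id))
        .collect();
    let impossible = Tensor::from_slice(&impossible).to_device(scores.device());
    let _ = scores.index_fill_(1, &impossible, f64::NEG_INFINITY);
}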

View File

@ -42,7 +42,7 @@ where
);
}
}
let temp_vec = Tensor::of_slice(&temp_vec);
let temp_vec = Tensor::from_slice(&temp_vec);
sinusoidal_embedding.push(temp_vec);
}
let sinusoidal_embedding = Tensor::stack(&sinusoidal_embedding, 0)
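
Both sinusoidal-embedding implementations touched by this commit (here and in Pegasus below) build the table row by row before stacking. A standalone sketch under the standard transformer formulation; `num_positions`, `dim`, and the base 10000 are assumptions, not visible in the hunk:

use tch::Tensor;

fn sinusoidal_embedding(num_positions: i64, dim: i64) -> Tensor {
    let mut rows = Vec::new();
    for pos in 0..num_positions {
        let mut row = Vec::with_capacity(dim as usize);
        for i in 0..dim {
            // sin on even indices, cos on odd, sharing the same frequency.
            let base = pos as f64 / 10000f64.powf((2 * (i / 2)) as f64 / dim as f64);
            let value = if i % 2 == 0 { base.sin() } else { base.cos() };
            row.push(value as f32);
        }
        // tch 0.13: `from_slice` replaces `of_slice`.
        rows.push(Tensor::from_slice(&row));
    }
    Tensor::stack(&rows, 0)
}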

View File

@ -173,7 +173,7 @@ impl GptNeoSelfAttention {
let mut attention_weights = attention_weights.where_self(
causal_mask,
&Tensor::of_slice(&[-1e9f32]).to_device(attention_weights.device()),
&Tensor::from_slice(&[-1e9f32]).to_device(attention_weights.device()),
);
if let Some(attention_mask_value) = attention_mask {
attention_weights = attention_weights + attention_mask_value;

View File

@ -105,7 +105,7 @@
//!
//! ### Automatic installation
//!
//! Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
//! Alternatively, you can let the `build` script automatically download the `libtorch` library for you. The `download-libtorch` feature flag needs to be enabled.
//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu118`.
//! Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
//!

View File

@ -775,7 +775,7 @@ impl PrivateLanguageGenerator for LongT5Generator {
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)

View File

@ -578,7 +578,7 @@ impl M2M100Generator {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
}
}
@ -750,7 +750,7 @@ impl PrivateLanguageGenerator for M2M100Generator {
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)

View File

@ -799,7 +799,7 @@ impl MarianGenerator {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
}
}
@ -895,7 +895,7 @@ impl PrivateLanguageGenerator for MarianGenerator {
) {
let _ = scores.index_fill_(
1,
&Tensor::of_slice(&[self.get_pad_id().unwrap()])
&Tensor::from_slice(&[self.get_pad_id().unwrap()])
.to_kind(Kind::Int64)
.to_device(scores.device()),
f64::NEG_INFINITY,
@ -975,7 +975,7 @@ impl PrivateLanguageGenerator for MarianGenerator {
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)

View File

@ -831,7 +831,7 @@ impl MBartGenerator {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
}
}
@ -1004,7 +1004,7 @@ impl PrivateLanguageGenerator for MBartGenerator {
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)

View File

@ -66,7 +66,7 @@ impl SinusoidalPositionalEmbedding {
temp_vec.push(base_value.cos());
}
}
let temp_vec = Tensor::of_slice(&temp_vec);
let temp_vec = Tensor::from_slice(&temp_vec);
sinusoidal_embedding.push(temp_vec);
}

View File

@ -540,7 +540,7 @@ impl PegasusConditionalGenerator {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(
1,
&impossible_tokens,
@ -716,7 +716,7 @@ impl PrivateLanguageGenerator for PegasusConditionalGenerator {
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)

View File

@ -1027,7 +1027,7 @@ impl ConversationModel {
padded_input.extend(input);
padded_input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.device))
.map(|tokens| Tensor::from_slice(&tokens).to(self.device))
.collect::<Vec<Tensor>>();
(Tensor::stack(&concatenated_inputs, 0), attention_mask)

View File

@ -404,7 +404,7 @@ pub(crate) mod private_generation_utils {
temp.extend(input);
temp
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
@ -424,7 +424,7 @@ pub(crate) mod private_generation_utils {
if updated_value < &0f64 {
let _ = next_token_logits.get(i).index_fill_(
0,
&Tensor::of_slice(&[token])
&Tensor::from_slice(&[token])
.to_kind(Kind::Int64)
.to_device(next_token_logits.device()),
updated_value * repetition_penalty,
@ -432,7 +432,7 @@ pub(crate) mod private_generation_utils {
} else {
let _ = next_token_logits.get(i).index_fill_(
0,
&Tensor::of_slice(&[token])
&Tensor::from_slice(&[token])
.to_kind(Kind::Int64)
.to_device(next_token_logits.device()),
updated_value / repetition_penalty,
@ -536,7 +536,7 @@ pub(crate) mod private_generation_utils {
);
let _ = sorted_indices_to_remove.index_fill_(
1,
&Tensor::of_slice(&[0])
&Tensor::from_slice(&[0])
.to_kind(Kind::Int64)
.to_device(sorted_indices_to_remove.device()),
0,
@ -600,7 +600,7 @@ pub(crate) mod private_generation_utils {
prefix_allowed_tokens_fn(batch_id, &input_ids.get(idx));
let _ = mask.get(idx).index_fill_(
0,
&Tensor::of_slice(allowed_tokens.as_slice()).to(scores.device()),
&Tensor::from_slice(allowed_tokens.as_slice()).to(scores.device()),
0,
);
}
@ -657,7 +657,7 @@ pub(crate) mod private_generation_utils {
Tensor::zeros([scores.size()[1]], (Kind::Int8, scores.device()));
let _ = static_bad_words_mask.index_fill_(
0,
&Tensor::of_slice(bad_words_id_length_1).to_device(scores.device()),
&Tensor::from_slice(bad_words_id_length_1).to_device(scores.device()),
1,
);
static_bad_words_mask.unsqueeze(0).totype(Kind::Bool)
@ -720,7 +720,7 @@ pub(crate) mod private_generation_utils {
if !sequence_ban_tokens.is_empty() {
let _ = dynamic_banned_mask.get(sequence_index as i64).index_fill_(
0,
&Tensor::of_slice(sequence_ban_tokens).to_device(scores.device()),
&Tensor::from_slice(sequence_ban_tokens).to_device(scores.device()),
1,
);
}
@ -847,7 +847,7 @@ pub(crate) mod private_generation_utils {
{
let _ = next_token_logits.get(batch_index).index_fill_(
0,
&Tensor::of_slice(&index_banned_token)
&Tensor::from_slice(&index_banned_token)
.to_device(next_token_logits.device()),
f64::NEG_INFINITY,
);
@ -868,7 +868,7 @@ pub(crate) mod private_generation_utils {
if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
let _ = next_token_logits.index_fill_(
1,
&Tensor::of_slice(gen_opt.eos_token_ids.as_ref().unwrap())
&Tensor::from_slice(gen_opt.eos_token_ids.as_ref().unwrap())
.to(next_token_logits.device()),
f64::NEG_INFINITY,
);
@ -1094,7 +1094,8 @@ pub(crate) mod private_generation_utils {
)
}
let batch_group_indices =
Tensor::of_slice(batch_group_indices.as_slice()).to(input_ids.device());
Tensor::from_slice(batch_group_indices.as_slice())
.to(input_ids.device());
(
Some(input_ids.index_select(0, &batch_group_indices)),
Some(batch_group_indices),
@ -1137,7 +1138,7 @@ pub(crate) mod private_generation_utils {
if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
let _ = scores.index_fill_(
1,
&Tensor::of_slice(gen_opt.eos_token_ids.as_ref().unwrap())
&Tensor::from_slice(gen_opt.eos_token_ids.as_ref().unwrap())
.to(scores.device()),
f64::NEG_INFINITY,
);
@ -1173,7 +1174,7 @@ pub(crate) mod private_generation_utils {
{
let _ = scores.get(batch_index).index_fill_(
0,
&Tensor::of_slice(&index_banned_token)
&Tensor::from_slice(&index_banned_token)
.to_device(next_token_logits.device()),
f64::NEG_INFINITY,
);
@ -1438,7 +1439,7 @@ pub(crate) mod private_generation_utils {
sorted_hypotheses.beams.pop().unwrap();
let _ = sentence_lengths.index_fill_(
0,
&Tensor::of_slice(&[effective_batch_index]).to(sentence_lengths.device()),
&Tensor::from_slice(&[effective_batch_index]).to(sentence_lengths.device()),
*best_hyp.size().first().unwrap(),
);
best_ids.push(best_hyp);
@ -1497,7 +1498,7 @@ pub(crate) mod private_generation_utils {
if sentence_length < sentence_length_max {
let _ = decoded.get(hypothesis_index as i64).index_fill_(
0,
&Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()),
&Tensor::from_slice(&[sentence_length]).to_device(input_ids.device()),
gen_opt.eos_token_ids.as_ref().unwrap()[0],
);
}
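
Among the many renamed call sites in this file, the repetition-penalty update (the `@ -424` hunk above) is worth restating on its own. A minimal sketch for a single 1-D `logits` tensor; the function name and arguments are illustrative:

use tch::{Kind, Tensor};

fn apply_repetition_penalty(logits: &mut Tensor, prev_tokens: &[i64], penalty: f64) {
    for &token in prev_tokens {
        let value = logits.double_value(&[token]);
        let index = Tensor::from_slice(&[token])
            .to_kind(Kind::Int64)
            .to_device(logits.device());
        // Negative logits are made more negative and positive ones smaller,
        // so already-generated tokens become less likely either way.
        let new_value = if value < 0.0 { value * penalty } else { value / penalty };
        let _ = logits.index_fill_(0, &index, new_value);
    }
}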

View File

@ -103,14 +103,14 @@ fn maximal_margin_relevance_score(
let _ = candidate_indices.remove(keyword_indices[0] as usize);
for _ in 0..min(num_keywords - 1, word_embeddings.size()[0] as usize) {
let candidate_indices_tensor =
Tensor::of_slice(&candidate_indices).to(word_document_similarities.device());
Tensor::from_slice(&candidate_indices).to(word_document_similarities.device());
let candidate_similarities =
word_document_similarities.index_select(0, &candidate_indices_tensor);
let (target_similarities, _) = word_similarities
.index_select(0, &candidate_indices_tensor)
.index_select(
1,
&Tensor::of_slice(&keyword_indices).to(word_similarities.device()),
&Tensor::from_slice(&keyword_indices).to(word_similarities.device()),
)
.max_dim(1, false);
let mmr = candidate_similarities * (1.0 - diversity) - target_similarities * diversity;

View File

@ -504,7 +504,7 @@ impl MaskedLanguageModel {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
}

View File

@ -782,7 +782,7 @@ impl QuestionAnsweringModel {
let example = &qa_inputs[example_id];
for feature_idx in feature_id_start..max_feature_id {
let feature = &batch_features[feature_idx as usize];
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1)
let p_mask = (Tensor::from_slice(&feature.p_mask) - 1)
.abs()
.to_device(start_logits.device())
.eq(0);
@ -964,7 +964,7 @@ impl QuestionAnsweringModel {
attention_mask.resize(max_len, 0);
attention_mask
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
for feature in features.iter_mut() {
@ -975,7 +975,7 @@ impl QuestionAnsweringModel {
let padded_input_ids = features
.iter_mut()
.map(|input| Tensor::of_slice(input.input_ids.as_slice()))
.map(|input| Tensor::from_slice(input.input_ids.as_slice()))
.collect::<Vec<_>>();
let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.var_store.device());

View File

@ -318,7 +318,7 @@ impl SentenceEmbeddingsModel {
let tokens_masks = tokens_ids
.iter()
.map(|input| {
Tensor::of_slice(
Tensor::from_slice(
&input
.iter()
.map(|&e| i64::from(e != pad_token_id))
@ -329,7 +329,7 @@ impl SentenceEmbeddingsModel {
let tokens_ids = tokens_ids
.into_iter()
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
SentenceEmbeddingsTokenizerOutput {
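
The mask construction above (the same pattern reappears in the GPT-J test further down) derives an attention mask directly from the padded ids. As a standalone sketch, with `pad_token_id` as an illustrative parameter:

use tch::Tensor;

// 1 where the token id is a real token, 0 where it is padding.
fn attention_masks(tokens_ids: &[Vec<i64>], pad_token_id: i64) -> Vec<Tensor> {
    tokens_ids
        .iter()
        .map(|input| {
            Tensor::from_slice(
                &input
                    .iter()
                    .map(|&e| i64::from(e != pad_token_id))
                    .collect::<Vec<i64>>(),
            )
        })
        .collect()
}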

View File

@ -670,7 +670,7 @@ impl SequenceClassificationModel {
.into_iter()
.map(|mut input| {
input.token_ids.resize(max_len, pad_id);
Tensor::of_slice(&(input.token_ids))
Tensor::from_slice(&(input.token_ids))
})
.collect::<Vec<_>>();
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())

View File

@ -1002,7 +1002,7 @@ impl TokenClassificationModel {
attention_mask.resize(max_len, 0);
attention_mask
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let padding_index = self
@ -1017,7 +1017,7 @@ impl TokenClassificationModel {
let padded_input_ids = features
.iter()
.map(|input| Tensor::of_slice(input.input_ids.as_slice()))
.map(|input| Tensor::from_slice(input.input_ids.as_slice()))
.collect::<Vec<_>>();
let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.var_store.device());

View File

@ -680,7 +680,7 @@ impl ZeroShotClassificationModel {
.into_iter()
.map(|mut input| {
input.token_ids.resize(max_len, pad_id);
Tensor::of_slice(&(input.token_ids))
Tensor::from_slice(&(input.token_ids))
})
.collect::<Vec<_>>();

View File

@ -1097,7 +1097,7 @@ impl PrivateLanguageGenerator for ProphetNetConditionalGenerator {
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)

View File

@ -210,15 +210,19 @@ impl LSHSelfAttention {
let query_key = nn::linear(p / "query_key", hidden_size, all_head_size, linear_config);
let value = nn::linear(p / "value", hidden_size, all_head_size, linear_config);
let self_mask_value_fp32 = Tensor::of_slice(&[-1e5])
let self_mask_value_fp32 = Tensor::from_slice(&[-1e5])
.to_kind(Kind::Float)
.to(p.device());
let mask_value_fp32 = Tensor::of_slice(&[-1e9])
let mask_value_fp32 = Tensor::from_slice(&[-1e9])
.to_kind(Kind::Float)
.to(p.device());
let self_mask_value_fp16 = Tensor::of_slice(&[-1e3]).to_kind(Kind::Half).to(p.device());
let mask_value_fp16 = Tensor::of_slice(&[-1e4]).to_kind(Kind::Half).to(p.device());
let self_mask_value_fp16 = Tensor::from_slice(&[-1e3])
.to_kind(Kind::Half)
.to(p.device());
let mask_value_fp16 = Tensor::from_slice(&[-1e4])
.to_kind(Kind::Half)
.to(p.device());
Ok(LSHSelfAttention {
chunk_length,
@ -359,7 +363,7 @@ impl LSHSelfAttention {
.to_kind(Kind::Bool);
buckets = buckets.where_self(
&buckets_mask,
&Tensor::of_slice(&[num_buckets - 1])
&Tensor::from_slice(&[num_buckets - 1])
.to_kind(buckets.kind())
.to(buckets_mask.device()),
)
@ -667,7 +671,7 @@ impl LSHSelfAttention {
fn len_and_dim_norm(&self, input_tensor: &Tensor) -> Tensor {
self.len_norm(input_tensor, 1e-6)
* Tensor::of_slice(&[self.attention_head_size])
* Tensor::from_slice(&[self.attention_head_size])
.to_kind(input_tensor.kind())
.to_device(input_tensor.device())
.rsqrt()
@ -994,11 +998,13 @@ impl LocalSelfAttention {
let key = nn::linear(p / "key", hidden_size, all_head_size, linear_config);
let value = nn::linear(p / "value", hidden_size, all_head_size, linear_config);
let mask_value_fp32 = Tensor::of_slice(&[-1e9])
let mask_value_fp32 = Tensor::from_slice(&[-1e9])
.to_kind(Kind::Float)
.to(p.device());
let mask_value_fp16 = Tensor::of_slice(&[-1e4]).to_kind(Kind::Half).to(p.device());
let mask_value_fp16 = Tensor::from_slice(&[-1e4])
.to_kind(Kind::Half)
.to(p.device());
LocalSelfAttention {
chunk_length,
@ -1096,7 +1102,7 @@ impl LocalSelfAttention {
let key_kind_device = (key_vectors.kind(), key_vectors.device());
let mut key_vectors = key_vectors
/ Tensor::of_slice(&[self.attention_head_size])
/ Tensor::from_slice(&[self.attention_head_size])
.to_kind(key_kind_device.0)
.to(key_kind_device.1)
.sqrt();

View File

@ -943,7 +943,7 @@ impl PrivateLanguageGenerator for T5Generator {
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)

View File

@ -56,7 +56,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -129,7 +129,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -190,7 +190,7 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
@ -262,7 +262,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -323,7 +323,7 @@ fn albert_for_question_answering() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -59,7 +59,7 @@ fn bart_lm_model() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -63,7 +63,7 @@ fn bert_masked_lm() -> anyhow::Result<()> {
tokenized_input[1][6] = 103;
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(input))
.map(|input| Tensor::from_slice(input))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -182,7 +182,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -239,7 +239,7 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
@ -303,7 +303,7 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -360,7 +360,7 @@ fn bert_for_question_answering() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -66,7 +66,7 @@ fn deberta_natural_language_inference() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -187,7 +187,7 @@ fn deberta_for_token_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -247,7 +247,7 @@ fn deberta_for_question_answering() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -105,7 +105,7 @@ fn deberta_v2_for_sequence_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -159,7 +159,7 @@ fn deberta_v2_for_token_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -212,7 +212,7 @@ fn deberta_v2_for_question_answering() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -89,7 +89,7 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
tokenized_input[1][6] = 103;
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(input))
.map(|input| Tensor::from_slice(input))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -160,7 +160,7 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -231,7 +231,7 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -55,7 +55,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -53,7 +53,7 @@ fn electra_masked_lm() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -132,7 +132,7 @@ fn electra_discriminator() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);

View File

@ -52,7 +52,7 @@ fn fnet_masked_lm() -> anyhow::Result<()> {
input.extend(vec![3; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -164,7 +164,7 @@ fn fnet_for_multiple_choice() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
@ -227,7 +227,7 @@ fn fnet_for_token_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -282,7 +282,7 @@ fn fnet_for_question_answering() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -53,7 +53,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -102,7 +102,7 @@ fn gpt_j_correctness() -> anyhow::Result<()> {
let token_masks = token_ids
.iter()
.map(|input| {
Tensor::of_slice(
Tensor::from_slice(
&input
.iter()
.map(|&e| i64::from(e != pad_token))
@ -114,7 +114,7 @@ fn gpt_j_correctness() -> anyhow::Result<()> {
let token_ids = token_ids
.into_iter()
.map(|tokens| Tensor::of_slice(&tokens).to(device))
.map(|tokens| Tensor::from_slice(&tokens).to(device))
.collect::<Vec<Tensor>>();
let input_tensor = Tensor::stack(&token_ids, 0);

View File

@ -58,7 +58,7 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -75,12 +75,12 @@ fn longformer_masked_lm() -> anyhow::Result<()> {
]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
let mut global_attention_mask_vector = vec![0; max_len];
global_attention_mask_vector[0] = 1;
let global_attention_mask = Tensor::of_slice(global_attention_mask_vector.as_slice());
let global_attention_mask = Tensor::from_slice(global_attention_mask_vector.as_slice());
let global_attention_mask = Tensor::stack(
vec![&global_attention_mask; tokenized_input.len()].as_slice(),
0,
@ -217,7 +217,7 @@ fn longformer_for_sequence_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -287,7 +287,7 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
@ -357,7 +357,7 @@ fn longformer_for_token_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -48,7 +48,7 @@ fn m2m100_lm_model() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -47,7 +47,7 @@ fn mbart_lm_model() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -57,7 +57,7 @@ fn mobilebert_masked_model() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -148,7 +148,7 @@ fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -203,7 +203,7 @@ fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
@ -258,7 +258,7 @@ fn mobilebert_for_token_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -310,7 +310,7 @@ fn mobilebert_for_question_answering() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -57,7 +57,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -118,7 +118,7 @@ fn reformer_for_sequence_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -179,7 +179,7 @@ fn reformer_for_question_answering() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -69,7 +69,7 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
tokenized_input[1][5] = 103;
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(input))
.map(|input| Tensor::from_slice(input))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -156,7 +156,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -221,7 +221,7 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
@ -293,7 +293,7 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -54,7 +54,7 @@ fn xlnet_base_model() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input[..input.len() - 2])))
.map(|input| Tensor::from_slice(&(input[..input.len() - 2])))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -159,7 +159,7 @@ fn xlnet_lm_model() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input[..input.len() - 2])))
.map(|input| Tensor::from_slice(&(input[..input.len() - 2])))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -274,7 +274,7 @@ fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -348,7 +348,7 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
@ -413,7 +413,7 @@ fn xlnet_for_token_classification() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
@ -475,7 +475,7 @@ fn xlnet_for_question_answering() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -30,6 +30,11 @@ if __name__ == "__main__":
"--dtype",
help="Convert weights to a specific numpy DataType (float32, float16, ...)",
)
parser.add_argument(
"--download_libtorch",
action="store_true",
help="Use this flag to enable automatic download of the libtorch library.",
)
args = parser.parse_args()
nps = {}
@ -73,14 +78,15 @@ if __name__ == "__main__":
target = str(target_folder / "rust_model.ot")
toml_location = (Path(__file__).resolve() / ".." / ".." / "Cargo.toml").resolve()
subprocess.run(
[
"cargo",
"run",
"--bin=convert-tensor",
"--manifest-path=%s" % toml_location,
"--",
source,
target,
],
)
cargo_args = [
"cargo",
"run",
"--bin=convert-tensor",
"--manifest-path=%s" % toml_location,
"--",
source,
target,
]
if args.download_libtorch:
cargo_args += ["--features", "download-libtorch"]
subprocess.run(cargo_args)
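
With this change, the conversion script only requests the automatic libtorch download when the caller opts in via `--download_libtorch`; otherwise `cargo run` builds against a locally installed libtorch, as before.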

View File

@ -27,7 +27,7 @@ if __name__ == "__main__":
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
['cargo', 'run', '--bin=convert-tensor', '--features', 'download-libtorch', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'pytorch_model.bin'))
os.remove(str(target_path / 'model.npz'))