Updated sentence embeddings example (#263)

* Added conversion information for Distil-based sentence embedding models

* Fix Clippy warnings
guillaume-be 2022-07-03 08:48:31 +01:00 committed by GitHub
parent 4d8a298586
commit a1595e6dfd
3 changed files with 15 additions and 2 deletions


@@ -9,6 +9,19 @@ use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;
/// ```sh
/// python ./utils/convert_model.py resources/all-MiniLM-L12-v2/pytorch_model.bin
/// ```
///
/// For models whose saved weights are missing the expected variable prefix (e.g. Distil-based
/// models), the conversion needs to include this prefix so that the weights can be found:
/// ```sh
/// python ./utils/convert_model.py resources/path/to/pytorch_model.bin --prefix distilbert.
/// ```
///
/// For models that include a dense projection layer (e.g. Distil-based models), the weights of
/// that layer need to be converted as well:
/// ```sh
/// python ./utils/convert_model.py resources/path/to/2_Dense/pytorch_model.bin --suffix
/// ```
///
fn main() -> anyhow::Result<()> {
// Set up the sentence embeddings model
let model = SentenceEmbeddingsBuilder::local("resources/all-MiniLM-L12-v2")
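
Once the weights (and, where present, the `2_Dense` projection) have been converted, the example loads the model straight from the local directory. Below is a minimal sketch of how the example might continue past the line shown above, assuming the builder's `with_device`/`create_model` methods and the pipeline's `encode` method; the sentences are placeholder inputs.

```rust
use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;

fn main() -> anyhow::Result<()> {
    // Load the locally converted model (same path as in the example above).
    let model = SentenceEmbeddingsBuilder::local("resources/all-MiniLM-L12-v2")
        .with_device(tch::Device::cuda_if_available())
        .create_model()?;

    // Placeholder sentences; any slice of &str works here.
    let sentences = ["This is an example sentence", "Each sentence is converted"];

    // One fixed-size embedding vector per input sentence.
    let embeddings = model.encode(&sentences)?;
    println!("{:?}", embeddings);

    Ok(())
}
```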


@@ -119,7 +119,7 @@ fn _shift_tokens_right(
(Kind::Int64, input_ids.device()),
);
let _ = shifted_input_ids.select(1, 0).fill_(decoder_start_token_id);
let _ = shifted_input_ids
shifted_input_ids
.slice(1, 1, *shifted_input_ids.size().last().unwrap(), 1)
.copy_(&input_ids.slice(1, 0, *input_ids.size().last().unwrap() - 1, 1));
shifted_input_ids.masked_fill(&shifted_input_ids.eq(-100), pad_token_id)
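
This hunk and the similar one below both touch `_shift_tokens_right`, which builds decoder inputs by writing `decoder_start_token_id` into the first column, copying the input ids shifted one position to the right, and replacing any `-100` label padding with `pad_token_id`; the Clippy fix only drops the redundant `let _ =` binding on the copy statement. A standalone sketch of the same shift on a toy tensor, assuming the `tch` API used by the crate (`Tensor::of_slice`, `full`, `slice`, `copy_`, `masked_fill`) and hypothetical token ids:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Hypothetical label ids for one sequence; -100 marks ignored label positions.
    let input_ids = Tensor::of_slice(&[15_i64, 27, -100, -100]).view([1, 4]);
    let decoder_start_token_id = 2_i64;
    let pad_token_id = 1_i64;

    // Fill every position with the start token, then copy the inputs shifted
    // right by one column, mirroring _shift_tokens_right above.
    let shifted_input_ids = Tensor::full(
        input_ids.size().as_slice(),
        decoder_start_token_id,
        (Kind::Int64, Device::Cpu),
    );
    let input_ids_length = *input_ids.size().last().unwrap();
    shifted_input_ids
        .slice(1, 1, input_ids_length, 1)
        .copy_(&input_ids.slice(1, 0, input_ids_length - 1, 1));

    // Replace any remaining -100 label padding with the real pad token id.
    let shifted = shifted_input_ids.masked_fill(&shifted_input_ids.eq(-100), pad_token_id);
    shifted.print(); // expected: [[2, 15, 27, 1]]
}
```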


@@ -77,7 +77,7 @@ fn _shift_tokens_right(
input_ids.size().as_slice(),
(input_ids.kind(), input_ids.device()),
);
let _ = shifted_input_ids
shifted_input_ids
.slice(1, 1, input_ids_length, 1)
.copy_(&input_ids.slice(1, 0, input_ids_length - 1, 1));