diff --git a/examples/sentence_embeddings_local.rs b/examples/sentence_embeddings_local.rs
index c54f378..f148f94 100644
--- a/examples/sentence_embeddings_local.rs
+++ b/examples/sentence_embeddings_local.rs
@@ -9,6 +9,19 @@ use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;
 /// ```sh
 /// python ./utils/convert_model.py resources/all-MiniLM-L12-v2/pytorch_model.bin
 /// ```
+///
+/// For models missing the prefix in their saved weights (e.g. Distil-based models), the
+/// conversion needs to be updated to include this prefix so that the weights can be found:
+/// ```sh
+/// python ./utils/convert_model.py resources/path/to/pytorch_model.bin --prefix distilbert.
+/// ```
+///
+/// For models including a dense projection layer (e.g. Distil-based models), these weights
+/// need to be converted as well:
+/// ```sh
+/// python ./utils/convert_model.py resources/path/to/2_Dense/pytorch_model.bin --suffix
+/// ```
+///
 fn main() -> anyhow::Result<()> {
     // Set-up sentence embeddings model
     let model = SentenceEmbeddingsBuilder::local("resources/all-MiniLM-L12-v2")
diff --git a/src/m2m_100/m2m_100_model.rs b/src/m2m_100/m2m_100_model.rs
index 119d6f2..6ba9c4e 100644
--- a/src/m2m_100/m2m_100_model.rs
+++ b/src/m2m_100/m2m_100_model.rs
@@ -119,7 +119,7 @@ fn _shift_tokens_right(
         (Kind::Int64, input_ids.device()),
     );
     let _ = shifted_input_ids.select(1, 0).fill_(decoder_start_token_id);
-    let _ = shifted_input_ids
+    shifted_input_ids
         .slice(1, 1, *shifted_input_ids.size().last().unwrap(), 1)
         .copy_(&input_ids.slice(1, 0, *input_ids.size().last().unwrap() - 1, 1));
     shifted_input_ids.masked_fill(&shifted_input_ids.eq(-100), pad_token_id)
diff --git a/src/pegasus/pegasus_model.rs b/src/pegasus/pegasus_model.rs
index 746e4a8..7a3c372 100644
--- a/src/pegasus/pegasus_model.rs
+++ b/src/pegasus/pegasus_model.rs
@@ -77,7 +77,7 @@ fn _shift_tokens_right(
         input_ids.size().as_slice(),
         (input_ids.kind(), input_ids.device()),
     );
-    let _ = shifted_input_ids
+    shifted_input_ids
         .slice(1, 1, input_ids_length, 1)
         .copy_(&input_ids.slice(1, 0, input_ids_length - 1, 1));
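Note on the doc-comment addition in `examples/sentence_embeddings_local.rs`: once the weights (and the `2_Dense` projection, for models that have one) are converted, loading follows the pattern already shown in the example's context lines. A minimal sketch, assuming a hypothetical local path `resources/path/to/model`:

```rust
use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;

fn main() -> anyhow::Result<()> {
    // Hypothetical path; assumes the weights (and the 2_Dense projection
    // layer, if the model includes one) were converted as described above.
    let model = SentenceEmbeddingsBuilder::local("resources/path/to/model")
        .with_device(tch::Device::cuda_if_available())
        .create_model()?;

    let embeddings = model.encode(&["This is a sentence to embed."])?;
    println!("embedding length: {}", embeddings[0].len());
    Ok(())
}
```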
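The `let _ =` removals in `m2m_100_model.rs` and `pegasus_model.rs` read as lint cleanups: in `tch`, the in-place `Tensor::copy_` returns `()` (unlike `fill_`, which returns a `Tensor`), so binding its result trips clippy's `let_unit_value` lint. That is my reading; the patch itself states no rationale. For context, here is a dependency-free sketch of what `_shift_tokens_right` computes, on plain vectors rather than `tch::Tensor` (a hypothetical helper for illustration, not the library code):

```rust
/// Hypothetical stand-in for `_shift_tokens_right`, operating on plain vectors
/// instead of `tch::Tensor`, and assuming non-empty rows. Each row is shifted
/// right by one: the decoder start token goes to position 0, the last token is
/// dropped, and any -100 (the ignore index used for loss masking) is replaced
/// by the pad token.
fn shift_tokens_right(
    input_ids: &[Vec<i64>],
    pad_token_id: i64,
    decoder_start_token_id: i64,
) -> Vec<Vec<i64>> {
    input_ids
        .iter()
        .map(|row| {
            let mut shifted = Vec::with_capacity(row.len());
            shifted.push(decoder_start_token_id);
            shifted.extend_from_slice(&row[..row.len() - 1]);
            shifted
                .into_iter()
                .map(|token| if token == -100 { pad_token_id } else { token })
                .collect()
        })
        .collect()
}

fn main() {
    // With decoder_start_token_id = 0 and pad_token_id = 1,
    // [5, 6, 7, 2] becomes [0, 5, 6, 7].
    assert_eq!(
        shift_tokens_right(&[vec![5, 6, 7, 2]], 1, 0),
        vec![vec![0, 5, 6, 7]]
    );
}
```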