diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index 0a8f13d..1afa042 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -100,6 +100,22 @@ jobs: --test pegasus --test gpt_neo + test-batch-2: + name: Integration tests (batch 2) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + - uses: actions-rs/cargo@v1 + with: + command: test + args: --package rust-bert + --test sentence_embeddings + convert-model: name: Model conversion test runs-on: ubuntu-latest diff --git a/CHANGELOG.md b/CHANGELOG.md index 662cafe..2581198 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [Unreleased] +## Added +- Support for sentence embeddings models and pipelines, based on [SentenceTransformers](https://www.sbert.net). ## [0.18.0] - 2022-05-29 ## Added diff --git a/examples/sentence_embeddings.rs b/examples/sentence_embeddings.rs new file mode 100644 index 0000000..9bf8304 --- /dev/null +++ b/examples/sentence_embeddings.rs @@ -0,0 +1,17 @@ +use rust_bert::pipelines::sentence_embeddings::{ + SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType, +}; + +fn main() -> anyhow::Result<()> { + // Set-up sentence embeddings model + let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2) + .create_model()?; + + // Define input + let sentences = ["this is an example sentence", "each sentence is converted"]; + + // Generate Embeddings + let embeddings = model.encode(&sentences)?; + println!("{:?}", embeddings); + Ok(()) +} diff --git a/examples/sentence_embeddings_local.rs b/examples/sentence_embeddings_local.rs new file mode 100644 index 0000000..c54f378 --- /dev/null +++ b/examples/sentence_embeddings_local.rs @@ -0,0 +1,25 @@ +use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder; + +/// Download model: +/// ```sh +/// git lfs install +/// git -C resources clone https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 +/// ``` +/// Prepare model: +/// ```sh +/// python ./utils/convert_model.py resources/all-MiniLM-L12-v2/pytorch_model.bin +/// ``` +fn main() -> anyhow::Result<()> { + // Set-up sentence embeddings model + let model = SentenceEmbeddingsBuilder::local("resources/all-MiniLM-L12-v2") + .with_device(tch::Device::cuda_if_available()) + .create_model()?; + + // Define input + let sentences = ["this is an example sentence", "each sentence is converted"]; + + // Generate Embeddings + let embeddings = model.encode(&sentences)?; + println!("{:?}", embeddings); + Ok(()) +} diff --git a/src/albert/albert_model.rs b/src/albert/albert_model.rs index 3de93a6..26ce193 100644 --- a/src/albert/albert_model.rs +++ b/src/albert/albert_model.rs @@ -37,6 +37,11 @@ impl AlbertModelResources { "albert-base-v2/model", "https://huggingface.co/albert-base-v2/resolve/main/rust_model.ot", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = ( + "paraphrase-albert-small-v2/model", + "https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/rust_model.ot", + ); } impl AlbertConfigResources { @@ -45,6 +50,11 @@ impl AlbertConfigResources { "albert-base-v2/config", "https://huggingface.co/albert-base-v2/resolve/main/config.json", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = ( + "paraphrase-albert-small-v2/config", + "https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/config.json", + ); } impl AlbertVocabResources { @@ -53,6 +63,11 @@ impl AlbertVocabResources { "albert-base-v2/spiece", "https://huggingface.co/albert-base-v2/resolve/main/spiece.model", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = ( + "paraphrase-albert-small-v2/spiece", + "https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/spiece.model", + ); } #[derive(Debug, Serialize, Deserialize, Clone)] @@ -1048,6 +1063,10 @@ impl AlbertForMultipleChoice { } } +/// # ALBERT for sentence embeddings +/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel). +pub type AlbertForSentenceEmbeddings = AlbertModel; + /// Container for the ALBERT model output. pub struct AlbertOutput { /// Last hidden states from the model diff --git a/src/albert/mod.rs b/src/albert/mod.rs index 5c11ddf..a2a48c5 100644 --- a/src/albert/mod.rs +++ b/src/albert/mod.rs @@ -59,8 +59,8 @@ mod encoder; pub use albert_model::{ AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice, - AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification, - AlbertMaskedLMOutput, AlbertModel, AlbertModelResources, AlbertOutput, - AlbertQuestionAnsweringOutput, AlbertSequenceClassificationOutput, + AlbertForQuestionAnswering, AlbertForSentenceEmbeddings, AlbertForSequenceClassification, + AlbertForTokenClassification, AlbertMaskedLMOutput, AlbertModel, AlbertModelResources, + AlbertOutput, AlbertQuestionAnsweringOutput, AlbertSequenceClassificationOutput, AlbertTokenClassificationOutput, AlbertVocabResources, }; diff --git a/src/bert/bert_model.rs b/src/bert/bert_model.rs index 39ef5fc..e225855 100644 --- a/src/bert/bert_model.rs +++ b/src/bert/bert_model.rs @@ -52,6 +52,16 @@ impl BertModelResources { "bert-qa/model", "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/rust_model.ot", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = ( + "bert-base-nli-mean-tokens/model", + "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/rust_model.ot", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = ( + "all-mini-lm-l12-v2/model", + "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/rust_model.ot", + ); } impl BertConfigResources { @@ -70,6 +80,16 @@ impl BertConfigResources { "bert-qa/config", "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = ( + "bert-base-nli-mean-tokens/config", + "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = ( + "all-mini-lm-l12-v2/config", + "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/config.json", + ); } impl BertVocabResources { @@ -88,6 +108,16 @@ impl BertVocabResources { "bert-qa/vocab", "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = ( + "bert-base-nli-mean-tokens/vocab", + "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/vocab.txt", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = ( + "all-mini-lm-l12-v2/vocab", + "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/vocab.txt", + ); } #[derive(Debug, Serialize, Deserialize, Clone)] @@ -1179,6 +1209,10 @@ impl BertForQuestionAnswering { } } +/// # BERT for sentence embeddings +/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel). +pub type BertForSentenceEmbeddings = BertModel; + /// Container for the BERT model output. pub struct BertModelOutput { /// Last hidden states from the model diff --git a/src/bert/mod.rs b/src/bert/mod.rs index 743e3af..7a6672b 100644 --- a/src/bert/mod.rs +++ b/src/bert/mod.rs @@ -59,8 +59,8 @@ pub(crate) mod encoder; pub use bert_model::{ BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice, - BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, - BertMaskedLMOutput, BertModel, BertModelOutput, BertModelResources, + BertForQuestionAnswering, BertForSentenceEmbeddings, BertForSequenceClassification, + BertForTokenClassification, BertMaskedLMOutput, BertModel, BertModelOutput, BertModelResources, BertQuestionAnsweringOutput, BertSequenceClassificationOutput, BertTokenClassificationOutput, BertVocabResources, }; diff --git a/src/common/activations.rs b/src/common/activations.rs index 3deedf6..7b3dd27 100644 --- a/src/common/activations.rs +++ b/src/common/activations.rs @@ -26,6 +26,10 @@ pub fn _tanh(x: &Tensor) -> Tensor { x.tanh() } +pub fn _identity(x: &Tensor) -> Tensor { + x.shallow_clone() +} + pub struct TensorFunction(Box Tensor>); impl TensorFunction { @@ -58,6 +62,8 @@ pub enum Activation { gelu_new, /// Tanh tanh, + /// Identity + identity, } impl Activation { @@ -69,6 +75,7 @@ impl Activation { Activation::gelu_new => _gelu_new, Activation::mish => _mish, Activation::tanh => _tanh, + Activation::identity => _identity, })) } } diff --git a/src/common/resources/local.rs b/src/common/resources/local.rs index 9f44a3d..3cefeb0 100644 --- a/src/common/resources/local.rs +++ b/src/common/resources/local.rs @@ -30,3 +30,15 @@ impl ResourceProvider for LocalResource { Ok(self.local_path.clone()) } } + +impl From for LocalResource { + fn from(local_path: PathBuf) -> Self { + Self { local_path } + } +} + +impl From for Box { + fn from(local_path: PathBuf) -> Self { + Box::new(LocalResource { local_path }) + } +} diff --git a/src/distilbert/distilbert_model.rs b/src/distilbert/distilbert_model.rs index dd115ca..a0d949d 100644 --- a/src/distilbert/distilbert_model.rs +++ b/src/distilbert/distilbert_model.rs @@ -46,6 +46,11 @@ impl DistilBertModelResources { "distilbert-qa/model", "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/rust_model.ot", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = ( + "distiluse-base-multilingual-cased/model", + "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/rust_model.ot", + ); } impl DistilBertConfigResources { @@ -64,6 +69,11 @@ impl DistilBertConfigResources { "distilbert-qa/config", "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/config.json", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = ( + "distiluse-base-multilingual-cased/config", + "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/config.json", + ); } impl DistilBertVocabResources { @@ -82,6 +92,11 @@ impl DistilBertVocabResources { "distilbert-qa/vocab", "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = ( + "distiluse-base-multilingual-cased/vocab", + "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/vocab.txt", + ); } #[derive(Debug, Serialize, Deserialize, Clone)] @@ -755,6 +770,10 @@ impl DistilBertForTokenClassification { } } +/// # DistilBERT for sentence embeddings +/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel). +pub type DistilBertForSentenceEmbeddings = DistilBertModel; + /// Container for the DistilBERT masked LM model output. pub struct DistilBertMaskedLMOutput { /// Logits for the vocabulary items at each sequence position diff --git a/src/distilbert/mod.rs b/src/distilbert/mod.rs index d0d420e..9f38909 100644 --- a/src/distilbert/mod.rs +++ b/src/distilbert/mod.rs @@ -60,8 +60,8 @@ mod transformer; pub use distilbert_model::{ DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering, - DistilBertForTokenClassification, DistilBertMaskedLMOutput, DistilBertModel, - DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertModelResources, + DistilBertForSentenceEmbeddings, DistilBertForTokenClassification, DistilBertMaskedLMOutput, + DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertModelResources, DistilBertQuestionAnsweringOutput, DistilBertSequenceClassificationOutput, DistilBertTokenClassificationOutput, DistilBertVocabResources, }; diff --git a/src/lib.rs b/src/lib.rs index 88c6321..f7b1938 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -146,7 +146,7 @@ //! ``` //! //! -//!   +//!   //!
//! 2. Translation //! @@ -198,7 +198,7 @@ //! ``` //! //!
-//!   +//!   //!
//! 3. Summarization //! @@ -249,7 +249,7 @@ //! ``` //! //!
-//!   +//!   //!
//! 4. Dialogue Model //! @@ -281,7 +281,7 @@ //! ``` //! //!
-//!   +//!   //!
//! 5. Natural Language Generation //! @@ -319,7 +319,7 @@ //! ``` //! //!
-//!   +//!   //!
//! 6. Zero-shot classification //! @@ -402,7 +402,7 @@ //! ``` //! //!
-//!   +//!   //!
//! 7. Sentiment analysis //! @@ -445,7 +445,7 @@ //! ``` //! //!
-//!   +//!   //!
//! 8. Named Entity Recognition //! @@ -502,7 +502,7 @@ //! ``` //! //!
-//!   +//!   //!
//! 9. Part of Speech tagging //! diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index c0b2c36..a886c3b 100644 --- a/src/pipelines/common.rs +++ b/src/pipelines/common.rs @@ -54,22 +54,28 @@ use rust_tokenizers::vocab::{ use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::convert::TryFrom; use std::path::Path; #[derive(Clone, Copy, Serialize, Deserialize, Debug, PartialEq)] /// # Identifies the type of model pub enum ModelType { Bart, + #[serde(alias = "bert")] Bert, + #[serde(alias = "distilbert")] DistilBert, Deberta, DebertaV2, + #[serde(alias = "roberta")] Roberta, XLMRoberta, Electra, Marian, MobileBert, + #[serde(alias = "t5")] T5, + #[serde(alias = "albert")] Albert, XLNet, GPT2, @@ -309,6 +315,62 @@ impl ConfigOption { } } +impl TryFrom<&ConfigOption> for BertConfig { + type Error = RustBertError; + + fn try_from(config: &ConfigOption) -> Result { + match config { + ConfigOption::Bert(config) | ConfigOption::Roberta(config) => Ok(config.clone()), + _ => Err(RustBertError::InvalidConfigurationError( + "You can only supply a BertConfig for Bert or a RobertaConfig for Roberta!" + .to_string(), + )), + } + } +} + +impl TryFrom<&ConfigOption> for DistilBertConfig { + type Error = RustBertError; + + fn try_from(config: &ConfigOption) -> Result { + if let ConfigOption::DistilBert(config) = config { + Ok(config.clone()) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply a DistilBertConfig for DistilBert!".to_string(), + )) + } + } +} + +impl TryFrom<&ConfigOption> for AlbertConfig { + type Error = RustBertError; + + fn try_from(config: &ConfigOption) -> Result { + if let ConfigOption::Albert(config) = config { + Ok(config.clone()) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply an AlbertConfig for Albert!".to_string(), + )) + } + } +} + +impl TryFrom<&ConfigOption> for T5Config { + type Error = RustBertError; + + fn try_from(config: &ConfigOption) -> Result { + if let ConfigOption::T5(config) = config { + Ok(config.clone()) + } else { + Err(RustBertError::InvalidConfigurationError( + "You can only supply a T5Config for T5!".to_string(), + )) + } + } +} + impl TokenizerOption { /// Interface method to load a tokenizer from file pub fn from_file( diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index ee84bdc..002453f 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -416,6 +416,7 @@ pub mod generation_utils; pub mod ner; pub mod pos_tagging; pub mod question_answering; +pub mod sentence_embeddings; pub mod sentiment; pub mod sequence_classification; pub mod summarization; diff --git a/src/pipelines/sentence_embeddings/builder.rs b/src/pipelines/sentence_embeddings/builder.rs new file mode 100644 index 0000000..a35947a --- /dev/null +++ b/src/pipelines/sentence_embeddings/builder.rs @@ -0,0 +1,185 @@ +use std::path::PathBuf; + +use serde::Deserialize; +use tch::Device; + +use crate::pipelines::common::ModelType; +use crate::pipelines::sentence_embeddings::{ + SentenceEmbeddingsConfig, SentenceEmbeddingsModel, SentenceEmbeddingsModulesConfig, +}; +use crate::{Config, RustBertError}; + +#[cfg(feature = "remote")] +use crate::{ + pipelines::sentence_embeddings::resources::SentenceEmbeddingsModelType, + resources::RemoteResource, +}; + +/// # SentenceEmbeddings Model Builder +/// +/// Allows the user to build a model from standard Sentence-Transformer files +/// (configuration and weights). +pub struct SentenceEmbeddingsBuilder { + device: Device, + inner: T, +} + +impl SentenceEmbeddingsBuilder { + pub fn with_device(mut self, device: Device) -> Self { + self.device = device; + self + } +} + +pub struct Local { + model_dir: PathBuf, +} + +#[derive(Debug, Deserialize)] +struct ModelConfig { + model_type: ModelType, +} + +impl Config for ModelConfig {} + +impl SentenceEmbeddingsBuilder { + pub fn local>(model_dir: P) -> Self { + Self { + device: Device::cuda_if_available(), + inner: Local { + model_dir: model_dir.into(), + }, + } + } + + pub fn create_model(self) -> Result { + let model_dir = self.inner.model_dir; + + let modules_config = model_dir.join("modules.json"); + let modules = SentenceEmbeddingsModulesConfig::from_file(&modules_config).validate()?; + + let transformer_config = model_dir.join("config.json"); + let transformer_type = ModelConfig::from_file(&transformer_config).model_type; + let transformer_weights = model_dir.join("rust_model.ot"); + + let pooling_config = model_dir + .join(&modules.pooling_module().path) + .join("config.json"); + + let (dense_config, dense_weights) = modules + .dense_module() + .map(|m| { + ( + Some(model_dir.join(&m.path).join("config.json")), + Some(model_dir.join(&m.path).join("rust_model.ot")), + ) + }) + .unwrap_or((None, None)); + + let tokenizer_config = model_dir.join("tokenizer_config.json"); + let sentence_bert_config = model_dir.join("sentence_bert_config.json"); + let (tokenizer_vocab, tokenizer_merges) = match transformer_type { + ModelType::Bert | ModelType::DistilBert => (model_dir.join("vocab.txt"), None), + ModelType::Roberta => ( + model_dir.join("vocab.json"), + Some(model_dir.join("merges.txt")), + ), + ModelType::Albert => (model_dir.join("spiece.model"), None), + ModelType::T5 => (model_dir.join("spiece.model"), None), + _ => { + return Err(RustBertError::InvalidConfigurationError(format!( + "Unsupported transformer model {:?} for Sentence Embeddings", + transformer_type + ))); + } + }; + + let config = SentenceEmbeddingsConfig { + modules_config_resource: modules_config.into(), + transformer_type, + transformer_config_resource: transformer_config.into(), + transformer_weights_resource: transformer_weights.into(), + pooling_config_resource: pooling_config.into(), + dense_config_resource: dense_config.map(|r| r.into()), + dense_weights_resource: dense_weights.map(|r| r.into()), + sentence_bert_config_resource: sentence_bert_config.into(), + tokenizer_config_resource: tokenizer_config.into(), + tokenizer_vocab_resource: tokenizer_vocab.into(), + tokenizer_merges_resource: tokenizer_merges.map(|r| r.into()), + device: self.device, + }; + + SentenceEmbeddingsModel::new(config) + } +} + +#[cfg(feature = "remote")] +pub struct Remote { + config: SentenceEmbeddingsConfig, +} + +#[cfg(feature = "remote")] +impl SentenceEmbeddingsBuilder { + pub fn remote(model_type: SentenceEmbeddingsModelType) -> Self { + Self { + device: Device::cuda_if_available(), + inner: Remote { + config: SentenceEmbeddingsConfig::from(model_type), + }, + } + } + + pub fn modules_config(mut self, resource: RemoteResource) -> Self { + self.inner.config.modules_config_resource = Box::new(resource); + self + } + + pub fn transformer_config(mut self, resource: RemoteResource) -> Self { + self.inner.config.transformer_config_resource = Box::new(resource); + self + } + + pub fn transformer_weights(mut self, resource: RemoteResource) -> Self { + self.inner.config.transformer_weights_resource = Box::new(resource); + self + } + + pub fn pooling_config(mut self, resource: RemoteResource) -> Self { + self.inner.config.pooling_config_resource = Box::new(resource); + self + } + + pub fn dense_config(mut self, resource: RemoteResource) -> Self { + self.inner.config.dense_config_resource = Some(Box::new(resource)); + self + } + + pub fn dense_weights(mut self, resource: RemoteResource) -> Self { + self.inner.config.dense_weights_resource = Some(Box::new(resource)); + self + } + + pub fn sentence_bert_config(mut self, resource: RemoteResource) -> Self { + self.inner.config.sentence_bert_config_resource = Box::new(resource); + self + } + + pub fn tokenizer_config(mut self, resource: RemoteResource) -> Self { + self.inner.config.tokenizer_config_resource = Box::new(resource); + self + } + + pub fn tokenizer_vocab(mut self, resource: RemoteResource) -> Self { + self.inner.config.tokenizer_vocab_resource = Box::new(resource); + self + } + + pub fn tokenizer_merges(mut self, resource: RemoteResource) -> Self { + self.inner.config.tokenizer_merges_resource = Some(Box::new(resource)); + self + } + + pub fn create_model(self) -> Result { + SentenceEmbeddingsModel::new(self.inner.config) + } +} diff --git a/src/pipelines/sentence_embeddings/config.rs b/src/pipelines/sentence_embeddings/config.rs new file mode 100644 index 0000000..cbeecd6 --- /dev/null +++ b/src/pipelines/sentence_embeddings/config.rs @@ -0,0 +1,429 @@ +use serde::{Deserialize, Serialize}; +use tch::Device; + +use crate::pipelines::common::ModelType; +use crate::resources::ResourceProvider; +use crate::{Config, RustBertError}; + +#[cfg(feature = "remote")] +use crate::{ + albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources}, + bert::{BertConfigResources, BertModelResources, BertVocabResources}, + distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources}, + pipelines::sentence_embeddings::resources::{ + SentenceEmbeddingsConfigResources, SentenceEmbeddingsModelType, + SentenceEmbeddingsModulesConfigResources, SentenceEmbeddingsPoolingConfigResources, + SentenceEmbeddingsTokenizerConfigResources, + }, + pipelines::sentence_embeddings::{ + SentenceEmbeddingsDenseConfigResources, SentenceEmbeddingsDenseResources, + }, + resources::RemoteResource, + roberta::{ + RobertaConfigResources, RobertaMergesResources, RobertaModelResources, + RobertaVocabResources, + }, + t5::{T5ConfigResources, T5ModelResources, T5VocabResources}, +}; + +/// # Configuration for sentence embeddings +/// +/// Contains information regarding the transformer model to load, the optional extra +/// layers, and device to place the model on. +pub struct SentenceEmbeddingsConfig { + /// Modules configuration resource, contains layers definition + pub modules_config_resource: Box, + /// Transformer model type + pub transformer_type: ModelType, + /// Transformer model configuration resource + pub transformer_config_resource: Box, + /// Transformer weights resource + pub transformer_weights_resource: Box, + /// Pooling layer configuration resource + pub pooling_config_resource: Box, + /// Optional dense layer configuration resource + pub dense_config_resource: Option>, + /// Optional dense layer weights resource + pub dense_weights_resource: Option>, + /// Sentence BERT specific configuration resource + pub sentence_bert_config_resource: Box, + /// Transformer's tokenizer configuration resource + pub tokenizer_config_resource: Box, + /// Transformer's tokenizer vocab resource + pub tokenizer_vocab_resource: Box, + /// Optional transformer's tokenizer merges resource + pub tokenizer_merges_resource: Option>, + /// Device to place the transformer model on + pub device: Device, +} + +#[cfg(feature = "remote")] +impl From for SentenceEmbeddingsConfig { + fn from(model_type: SentenceEmbeddingsModelType) -> Self { + match model_type { + SentenceEmbeddingsModelType::DistiluseBaseMultilingualCased => SentenceEmbeddingsConfig { + modules_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsModulesConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED, + )), + transformer_type: ModelType::DistilBert, + transformer_config_resource: Box::new(RemoteResource::from_pretrained( + DistilBertConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED, + )), + transformer_weights_resource: Box::new(RemoteResource::from_pretrained( + DistilBertModelResources::DISTILUSE_BASE_MULTILINGUAL_CASED, + )), + pooling_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsPoolingConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED, + )), + dense_config_resource: Some(Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsDenseConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED, + ))), + dense_weights_resource: Some(Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsDenseResources::DISTILUSE_BASE_MULTILINGUAL_CASED, + ))), + sentence_bert_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED, + )), + tokenizer_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsTokenizerConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED, + )), + tokenizer_vocab_resource: Box::new(RemoteResource::from_pretrained( + DistilBertVocabResources::DISTILUSE_BASE_MULTILINGUAL_CASED, + )), + tokenizer_merges_resource: None, + device: Device::cuda_if_available(), + }, + + SentenceEmbeddingsModelType::BertBaseNliMeanTokens => SentenceEmbeddingsConfig { + modules_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsModulesConfigResources::BERT_BASE_NLI_MEAN_TOKENS, + )), + transformer_type: ModelType::Bert, + transformer_config_resource: Box::new(RemoteResource::from_pretrained( + BertConfigResources::BERT_BASE_NLI_MEAN_TOKENS, + )), + transformer_weights_resource: Box::new(RemoteResource::from_pretrained( + BertModelResources::BERT_BASE_NLI_MEAN_TOKENS, + )), + pooling_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsPoolingConfigResources::BERT_BASE_NLI_MEAN_TOKENS, + )), + dense_config_resource: None, + dense_weights_resource: None, + sentence_bert_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsConfigResources::BERT_BASE_NLI_MEAN_TOKENS, + )), + tokenizer_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsTokenizerConfigResources::BERT_BASE_NLI_MEAN_TOKENS, + )), + tokenizer_vocab_resource: Box::new(RemoteResource::from_pretrained( + BertVocabResources::BERT_BASE_NLI_MEAN_TOKENS, + )), + tokenizer_merges_resource: None, + device: Device::cuda_if_available(), + }, + + SentenceEmbeddingsModelType::AllMiniLmL12V2 => SentenceEmbeddingsConfig { + modules_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsModulesConfigResources::ALL_MINI_LM_L12_V2, + )), + transformer_type: ModelType::Bert, + transformer_config_resource: Box::new(RemoteResource::from_pretrained( + BertConfigResources::ALL_MINI_LM_L12_V2, + )), + transformer_weights_resource: Box::new(RemoteResource::from_pretrained( + BertModelResources::ALL_MINI_LM_L12_V2, + )), + pooling_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsPoolingConfigResources::ALL_MINI_LM_L12_V2, + )), + dense_config_resource: None, + dense_weights_resource: None, + sentence_bert_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsConfigResources::ALL_MINI_LM_L12_V2, + )), + tokenizer_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsTokenizerConfigResources::ALL_MINI_LM_L12_V2, + )), + tokenizer_vocab_resource: Box::new(RemoteResource::from_pretrained( + BertVocabResources::ALL_MINI_LM_L12_V2, + )), + tokenizer_merges_resource: None, + device: Device::cuda_if_available(), + }, + + SentenceEmbeddingsModelType::AllDistilrobertaV1 => SentenceEmbeddingsConfig { + modules_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsModulesConfigResources::ALL_DISTILROBERTA_V1, + )), + transformer_type: ModelType::Roberta, + transformer_config_resource: Box::new(RemoteResource::from_pretrained( + RobertaConfigResources::ALL_DISTILROBERTA_V1, + )), + transformer_weights_resource: Box::new(RemoteResource::from_pretrained( + RobertaModelResources::ALL_DISTILROBERTA_V1, + )), + pooling_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsPoolingConfigResources::ALL_DISTILROBERTA_V1, + )), + dense_config_resource: None, + dense_weights_resource: None, + sentence_bert_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsConfigResources::ALL_DISTILROBERTA_V1, + )), + tokenizer_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsTokenizerConfigResources::ALL_DISTILROBERTA_V1, + )), + tokenizer_vocab_resource: Box::new(RemoteResource::from_pretrained( + RobertaVocabResources::ALL_DISTILROBERTA_V1, + )), + tokenizer_merges_resource: Some(Box::new(RemoteResource::from_pretrained( + RobertaMergesResources::ALL_DISTILROBERTA_V1, + ))), + device: Device::cuda_if_available(), + }, + + SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2 => SentenceEmbeddingsConfig { + modules_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsModulesConfigResources::PARAPHRASE_ALBERT_SMALL_V2, + )), + transformer_type: ModelType::Albert, + transformer_config_resource: Box::new(RemoteResource::from_pretrained( + AlbertConfigResources::PARAPHRASE_ALBERT_SMALL_V2, + )), + transformer_weights_resource: Box::new(RemoteResource::from_pretrained( + AlbertModelResources::PARAPHRASE_ALBERT_SMALL_V2, + )), + pooling_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsPoolingConfigResources::PARAPHRASE_ALBERT_SMALL_V2, + )), + dense_config_resource: None, + dense_weights_resource: None, + sentence_bert_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsConfigResources::PARAPHRASE_ALBERT_SMALL_V2, + )), + tokenizer_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsTokenizerConfigResources::PARAPHRASE_ALBERT_SMALL_V2, + )), + tokenizer_vocab_resource: Box::new(RemoteResource::from_pretrained( + AlbertVocabResources::PARAPHRASE_ALBERT_SMALL_V2, + )), + tokenizer_merges_resource: None, + device: Device::cuda_if_available(), + }, + + SentenceEmbeddingsModelType::SentenceT5Base => SentenceEmbeddingsConfig { + modules_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsModulesConfigResources::SENTENCE_T5_BASE, + )), + transformer_type: ModelType::T5, + transformer_config_resource: Box::new(RemoteResource::from_pretrained( + T5ConfigResources::SENTENCE_T5_BASE, + )), + transformer_weights_resource: Box::new(RemoteResource::from_pretrained( + T5ModelResources::SENTENCE_T5_BASE, + )), + pooling_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsPoolingConfigResources::SENTENCE_T5_BASE, + )), + dense_config_resource: Some(Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsDenseConfigResources::SENTENCE_T5_BASE, + ))), + dense_weights_resource: Some(Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsDenseResources::SENTENCE_T5_BASE, + ))), + sentence_bert_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsConfigResources::SENTENCE_T5_BASE, + )), + tokenizer_config_resource: Box::new(RemoteResource::from_pretrained( + SentenceEmbeddingsTokenizerConfigResources::SENTENCE_T5_BASE, + )), + tokenizer_vocab_resource: Box::new(RemoteResource::from_pretrained( + T5VocabResources::SENTENCE_T5_BASE, + )), + tokenizer_merges_resource: None, + device: Device::cuda_if_available(), + }, + } + } +} + +/// Configuration for the modules that define the model's layers +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SentenceEmbeddingsModulesConfig(pub Vec); + +impl std::ops::Deref for SentenceEmbeddingsModulesConfig { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for SentenceEmbeddingsModulesConfig { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From> for SentenceEmbeddingsModulesConfig { + fn from(source: Vec) -> Self { + Self(source) + } +} + +impl Config for SentenceEmbeddingsModulesConfig {} + +impl SentenceEmbeddingsModulesConfig { + pub fn validate(self) -> Result { + match self.get(0) { + Some(SentenceEmbeddingsModuleConfig { + module_type: SentenceEmbeddingsModuleType::Transformer, + .. + }) => (), + Some(_) => { + return Err(RustBertError::InvalidConfigurationError( + "First module defined in modules.json must be a Transformer".to_string(), + )); + } + None => { + return Err(RustBertError::InvalidConfigurationError( + "No modules found in modules.json".to_string(), + )); + } + } + + match self.get(1) { + Some(SentenceEmbeddingsModuleConfig { + module_type: SentenceEmbeddingsModuleType::Pooling, + .. + }) => (), + Some(_) => { + return Err(RustBertError::InvalidConfigurationError( + "Second module defined in modules.json must be a Pooling".to_string(), + )); + } + None => { + return Err(RustBertError::InvalidConfigurationError( + "Pooling module not found in second position in modules.json".to_string(), + )); + } + } + + Ok(self) + } + + pub fn transformer_module(&self) -> &SentenceEmbeddingsModuleConfig { + self.get(0).as_ref().unwrap() + } + + pub fn pooling_module(&self) -> &SentenceEmbeddingsModuleConfig { + self.get(1).as_ref().unwrap() + } + + pub fn dense_module(&self) -> Option<&SentenceEmbeddingsModuleConfig> { + for i in 2..=3 { + if let Some(SentenceEmbeddingsModuleConfig { + module_type: SentenceEmbeddingsModuleType::Dense, + .. + }) = self.get(i) + { + return self.get(i); + } + } + None + } + + pub fn has_normalization(&self) -> bool { + for i in 2..=3 { + if let Some(SentenceEmbeddingsModuleConfig { + module_type: SentenceEmbeddingsModuleType::Normalize, + .. + }) = self.get(i) + { + return true; + } + } + false + } +} + +/// Configuration defining a single module (model's layer) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SentenceEmbeddingsModuleConfig { + pub idx: usize, + pub name: String, + pub path: String, + #[serde(rename = "type")] + #[serde(with = "serde_sentence_embeddings_module_type")] + pub module_type: SentenceEmbeddingsModuleType, +} + +/// Available module types, based on Sentence-Transformers +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum SentenceEmbeddingsModuleType { + Transformer, + Pooling, + Dense, + Normalize, +} + +mod serde_sentence_embeddings_module_type { + use super::SentenceEmbeddingsModuleType; + use serde::{de, Deserializer, Serializer}; + + pub fn serialize( + module_type: &SentenceEmbeddingsModuleType, + serializer: S, + ) -> Result + where + S: Serializer, + { + serializer.serialize_str(&format!("sentence_transformers.models.{:?}", module_type)) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct SentenceEmbeddingsModuleTypeVisitor; + + impl de::Visitor<'_> for SentenceEmbeddingsModuleTypeVisitor { + type Value = SentenceEmbeddingsModuleType; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a sentence embeddings module type") + } + + fn visit_str(self, s: &str) -> Result { + s.split('.') + .last() + .map(|s| serde_json::from_value(serde_json::Value::String(s.to_string()))) + .transpose() + .map_err(de::Error::custom)? + .ok_or_else(|| format!("Invalid SentenceEmbeddingsModuleType: {}", s)) + .map_err(de::Error::custom) + } + } + + deserializer.deserialize_str(SentenceEmbeddingsModuleTypeVisitor) + } +} + +/// Configuration for Sentence-Transformers specific parameters +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SentenceEmbeddingsSentenceBertConfig { + pub max_seq_length: usize, + pub do_lower_case: bool, +} + +impl Config for SentenceEmbeddingsSentenceBertConfig {} + +/// Configuration for transformer's tokenizer +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SentenceEmbeddingsTokenizerConfig { + pub add_prefix_space: Option, + pub strip_accents: Option, +} + +impl Config for SentenceEmbeddingsTokenizerConfig {} diff --git a/src/pipelines/sentence_embeddings/layers.rs b/src/pipelines/sentence_embeddings/layers.rs new file mode 100644 index 0000000..31b6a99 --- /dev/null +++ b/src/pipelines/sentence_embeddings/layers.rs @@ -0,0 +1,151 @@ +use std::path::Path; + +use serde::{de, Deserialize, Deserializer}; +use tch::{nn, Device, Kind, Tensor}; + +use crate::common::activations::{Activation, TensorFunction}; +use crate::{Config, RustBertError}; + +/// Configuration for [`Pooling`](Pooling) layer. +#[derive(Debug, Deserialize)] +pub struct PoolingConfig { + /// Dimensions for the word embeddings + pub word_embedding_dimension: i64, + /// Use the first token (CLS token) as text representations + pub pooling_mode_cls_token: bool, + /// Use max in each dimension over all tokens + pub pooling_mode_max_tokens: bool, + /// Perform mean-pooling + pub pooling_mode_mean_tokens: bool, + /// Perform mean-pooling, but devide by sqrt(input_length) + pub pooling_mode_mean_sqrt_len_tokens: bool, +} + +impl Config for PoolingConfig {} + +/// Performs pooling (max or mean) on the token embeddings. +/// +/// Using pooling, it generates from a variable sized sentence a fixed sized sentence +/// embedding. You can concatenate multiple poolings together. +pub struct Pooling { + conf: PoolingConfig, +} + +impl Pooling { + pub fn new(conf: PoolingConfig) -> Pooling { + Pooling { conf } + } + + pub fn forward(&self, mut token_embeddings: Tensor, attention_mask: &Tensor) -> Tensor { + let mut output_vectors = Vec::new(); + + if self.conf.pooling_mode_cls_token { + let cls_token = token_embeddings.select(1, 0); // Take first token by default + output_vectors.push(cls_token); + } + + if self.conf.pooling_mode_max_tokens { + let input_mask_expanded = attention_mask.unsqueeze(-1).expand_as(&token_embeddings); + // Set padding tokens to large negative value + token_embeddings = token_embeddings.masked_fill_(&input_mask_expanded.eq(0), -1e9); + let max_over_time = token_embeddings.max_dim(1, true).0; + output_vectors.push(max_over_time); + } + + if self.conf.pooling_mode_mean_tokens || self.conf.pooling_mode_mean_sqrt_len_tokens { + let input_mask_expanded = attention_mask.unsqueeze(-1).expand_as(&token_embeddings); + let sum_embeddings = + (token_embeddings * &input_mask_expanded).sum_dim_intlist(&[1], false, Kind::Float); + + let sum_mask = input_mask_expanded.sum_dim_intlist(&[1], false, Kind::Float); + let sum_mask = sum_mask.clamp_min(10e-9); + + if self.conf.pooling_mode_mean_tokens { + output_vectors.push(&sum_embeddings / &sum_mask); + } + if self.conf.pooling_mode_mean_sqrt_len_tokens { + output_vectors.push(sum_embeddings / sum_mask.sqrt()); + } + } + + Tensor::cat(&output_vectors, 1) + } +} + +/// Configuration for [`Dense`](Dense) layer. +#[derive(Debug, Deserialize)] +pub struct DenseConfig { + /// Size of the input dimension + pub in_features: i64, + /// Output size + pub out_features: i64, + /// Add a bias vector + pub bias: bool, + /// Activation function applied on output + #[serde(deserialize_with = "last_part")] + pub activation_function: Activation, +} + +impl Config for DenseConfig {} + +/// Split the given string on `.` and try to construct an `Activation` from the last part +fn last_part<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let activation = String::deserialize(deserializer)?; + activation + .split('.') + .last() + .map(|s| serde_json::from_value(serde_json::Value::String(s.to_lowercase()))) + .transpose() + .map_err(de::Error::custom)? + .ok_or_else(|| format!("Invalid Activation: {}", activation)) + .map_err(de::Error::custom) +} + +/// Feed-forward function with activiation function. +/// +/// This layer takes a fixed-sized sentence embedding and passes it through a +/// feed-forward layer. Can be used to generate deep averaging networs (DAN). +pub struct Dense { + linear: nn::Linear, + activation: TensorFunction, + _var_store: nn::VarStore, +} + +impl Dense { + pub fn new>( + dense_conf: DenseConfig, + dense_weights: P, + device: Device, + ) -> Result { + let mut vs_dense = nn::VarStore::new(device); + + let linear_conf = nn::LinearConfig { + ws_init: nn::Init::Const(0.), + bs_init: Some(nn::Init::Const(0.)), + bias: dense_conf.bias, + }; + let linear = nn::linear( + &vs_dense.root(), + dense_conf.in_features, + dense_conf.out_features, + linear_conf, + ); + + let activation = dense_conf.activation_function.get_function(); + + vs_dense.load(dense_weights)?; + + Ok(Dense { + linear, + activation, + _var_store: vs_dense, + }) + } + + pub fn forward(&self, x: &Tensor) -> Tensor { + self.activation.get_fn()(&x.apply(&self.linear)) + } +} diff --git a/src/pipelines/sentence_embeddings/mod.rs b/src/pipelines/sentence_embeddings/mod.rs new file mode 100644 index 0000000..3ac4bcf --- /dev/null +++ b/src/pipelines/sentence_embeddings/mod.rs @@ -0,0 +1,64 @@ +//! # Sentence Embeddings pipeline +//! +//! Compute sentence/text embeddings that can be compared (e.g. with +//! cosine-similarity) to find sentences with a similar meaning. This can be useful for +//! semantic textual similar, semantic search, or paraphrase mining. +//! +//! The implementation is based on [Sentence-Transformers][sbert] and pretrained models +//! available on [Hugging Face Hub][sbert-hub] can be used. It's however necessary to +//! convert them using the script `utils/convert_model.py` beforehand, see +//! `tests/sentence_embeddings.rs` for such examples. +//! +//! [sbert]: https://sbert.net/ +//! [sbert-hub]: https://huggingface.co/sentence-transformers/ +//! +//! Basic usage is as follows: +//! +//! ```no_run +//! use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder; +//! +//! # fn main() -> anyhow::Result<()> { +//! let model = SentenceEmbeddingsBuilder::local("local/path/to/distiluse-base-multilingual-cased") +//! .with_device(tch::Device::cuda_if_available()) +//! .create_model()?; +//! +//! let sentences = ["This is an example sentence", "Each sentence is converted"]; +//! let embeddings = model.encode(&sentences)?; +//! # Ok(()) +//! # } +//! ``` + +pub mod builder; +mod config; +pub mod layers; +mod pipeline; +mod resources; + +pub use builder::SentenceEmbeddingsBuilder; +pub use config::{ + SentenceEmbeddingsConfig, SentenceEmbeddingsModuleConfig, SentenceEmbeddingsModuleType, + SentenceEmbeddingsModulesConfig, SentenceEmbeddingsSentenceBertConfig, + SentenceEmbeddingsTokenizerConfig, +}; +pub use pipeline::{ + SentenceEmbeddingsModel, SentenceEmbeddingsModelOuput, SentenceEmbeddingsOption, + SentenceEmbeddingsTokenizerOuput, +}; + +pub use resources::{ + SentenceEmbeddingsConfigResources, SentenceEmbeddingsDenseConfigResources, + SentenceEmbeddingsDenseResources, SentenceEmbeddingsModelType, + SentenceEmbeddingsModulesConfigResources, SentenceEmbeddingsPoolingConfigResources, + SentenceEmbeddingsTokenizerConfigResources, +}; + +/// Length = sequence length +pub type Attention = Vec; +/// Length = sequence length +pub type AttentionHead = Vec; +/// Length = number of heads per attention layer +pub type AttentionLayer = Vec; +/// Length = number of attention layers +pub type AttentionOutput = Vec; + +pub type Embedding = Vec; diff --git a/src/pipelines/sentence_embeddings/pipeline.rs b/src/pipelines/sentence_embeddings/pipeline.rs new file mode 100644 index 0000000..f492a10 --- /dev/null +++ b/src/pipelines/sentence_embeddings/pipeline.rs @@ -0,0 +1,461 @@ +use std::borrow::Borrow; +use std::convert::TryInto; + +use rust_tokenizers::tokenizer::TruncationStrategy; +use tch::{nn, Tensor}; + +use crate::albert::AlbertForSentenceEmbeddings; +use crate::bert::BertForSentenceEmbeddings; +use crate::distilbert::DistilBertForSentenceEmbeddings; +use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption}; +use crate::pipelines::sentence_embeddings::layers::{Dense, DenseConfig, Pooling, PoolingConfig}; +use crate::pipelines::sentence_embeddings::{ + AttentionHead, AttentionLayer, AttentionOutput, Embedding, SentenceEmbeddingsConfig, + SentenceEmbeddingsModulesConfig, SentenceEmbeddingsSentenceBertConfig, + SentenceEmbeddingsTokenizerConfig, +}; +use crate::roberta::RobertaForSentenceEmbeddings; +use crate::t5::T5ForSentenceEmbeddings; +use crate::{Config, RustBertError}; + +/// # Abstraction that holds one particular sentence embeddings model, for any of the supported models +pub enum SentenceEmbeddingsOption { + /// Bert for Sentence Embeddings + Bert(BertForSentenceEmbeddings), + /// DistilBert for Sentence Embeddings + DistilBert(DistilBertForSentenceEmbeddings), + /// Roberta for Sentence Embeddings + Roberta(RobertaForSentenceEmbeddings), + /// Albert for Sentence Embeddings + Albert(AlbertForSentenceEmbeddings), + /// T5 for Sentence Embeddings + T5(T5ForSentenceEmbeddings), +} + +impl SentenceEmbeddingsOption { + /// Instantiate a new sentence embeddings transformer of the supplied type. + /// + /// # Arguments + /// + /// * `transformer_type` - `ModelType` indicating the transformer model type to load (must match with the actual data to be loaded) + /// * `p` - `tch::nn::Path` path to the model file to load (e.g. rust_model.ot) + /// * `config` - A configuration (the transformer model type of the configuration must be compatible with the value for `transformer_type`) + pub fn new<'p, P>( + transformer_type: ModelType, + p: P, + config: &ConfigOption, + ) -> Result + where + P: Borrow>, + { + use SentenceEmbeddingsOption::*; + + let option = match transformer_type { + ModelType::Bert => Bert(BertForSentenceEmbeddings::new(p, &(config.try_into()?))), + ModelType::DistilBert => DistilBert(DistilBertForSentenceEmbeddings::new( + p, + &(config.try_into()?), + )), + ModelType::Roberta => Roberta(RobertaForSentenceEmbeddings::new_with_optional_pooler( + p, + &(config.try_into()?), + false, + )), + ModelType::Albert => Albert(AlbertForSentenceEmbeddings::new(p, &(config.try_into()?))), + ModelType::T5 => T5(T5ForSentenceEmbeddings::new(p, &(config.try_into()?))), + _ => { + return Err(RustBertError::InvalidConfigurationError(format!( + "Unsupported transformer model {:?} for Sentence Embeddings", + transformer_type + ))); + } + }; + + Ok(option) + } + + /// Interface method to forward() of the particular transformer models. + pub fn forward( + &self, + tokens_ids: &Tensor, + tokens_masks: &Tensor, + ) -> Result<(Tensor, Option>), RustBertError> { + match self { + Self::Bert(transformer) => transformer + .forward_t( + Some(tokens_ids), + Some(tokens_masks), + None, + None, + None, + None, + None, + false, + ) + .map(|transformer_output| { + ( + transformer_output.hidden_state, + transformer_output.all_attentions, + ) + }), + Self::DistilBert(transformer) => transformer + .forward_t(Some(tokens_ids), Some(tokens_masks), None, false) + .map(|transformer_output| { + ( + transformer_output.hidden_state, + transformer_output.all_attentions, + ) + }), + Self::Roberta(transformer) => transformer + .forward_t( + Some(tokens_ids), + Some(tokens_masks), + None, + None, + None, + None, + None, + false, + ) + .map(|transformer_output| { + ( + transformer_output.hidden_state, + transformer_output.all_attentions, + ) + }), + Self::Albert(transformer) => transformer + .forward_t( + Some(tokens_ids), + Some(tokens_masks), + None, + None, + None, + false, + ) + .map(|transformer_output| { + ( + transformer_output.hidden_state, + transformer_output.all_attentions.map(|attentions| { + attentions + .into_iter() + .map(|tensors| { + let num_inner_groups = tensors.len() as f64; + tensors.into_iter().sum::() / num_inner_groups + }) + .collect() + }), + ) + }), + Self::T5(transformer) => transformer.forward(tokens_ids, tokens_masks), + } + } +} + +/// # SentenceEmbeddingsModel to perform sentence embeddings +/// +/// It is made of the following blocks: +/// - `transformer`: Base transformer model +/// - `pooling`: Pooling layer +/// - `dense` _(optional)_: Linear (feed forward) layer +/// - `normalization` _(optional)_: Embeddings normalization +pub struct SentenceEmbeddingsModel { + sentence_bert_config: SentenceEmbeddingsSentenceBertConfig, + tokenizer: TokenizerOption, + tokenizer_truncation_strategy: TruncationStrategy, + var_store: nn::VarStore, + transformer: SentenceEmbeddingsOption, + transformer_config: ConfigOption, + pooling_layer: Pooling, + dense_layer: Option, + normalize_embeddings: bool, +} + +impl SentenceEmbeddingsModel { + /// Build a new `SentenceEmbeddingsModel` + /// + /// # Arguments + /// + /// * `config` - `SentenceEmbeddingsConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU) + pub fn new(config: SentenceEmbeddingsConfig) -> Result { + let SentenceEmbeddingsConfig { + modules_config_resource, + sentence_bert_config_resource, + tokenizer_config_resource, + tokenizer_vocab_resource, + tokenizer_merges_resource, + transformer_type, + transformer_config_resource, + transformer_weights_resource, + pooling_config_resource, + dense_config_resource, + dense_weights_resource, + device, + } = config; + + let modules = + SentenceEmbeddingsModulesConfig::from_file(modules_config_resource.get_local_path()?) + .validate()?; + + // Setup tokenizer + + let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file( + tokenizer_config_resource.get_local_path()?, + ); + let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file( + sentence_bert_config_resource.get_local_path()?, + ); + let tokenizer = TokenizerOption::from_file( + transformer_type, + tokenizer_vocab_resource + .get_local_path()? + .to_string_lossy() + .as_ref(), + tokenizer_merges_resource + .as_ref() + .map(|resource| resource.get_local_path()) + .transpose()? + .map(|path| path.to_string_lossy().into_owned()) + .as_deref(), + sentence_bert_config.do_lower_case, + tokenizer_config.strip_accents, + tokenizer_config.add_prefix_space, + )?; + + // Setup transformer + + let mut var_store = nn::VarStore::new(device); + let transformer_config = ConfigOption::from_file( + transformer_type, + transformer_config_resource.get_local_path()?, + ); + let transformer = SentenceEmbeddingsOption::new( + transformer_type, + &var_store.root(), + &transformer_config, + )?; + var_store.load(transformer_weights_resource.get_local_path()?)?; + + // Setup pooling layer + + let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?); + let pooling_layer = Pooling::new(pooling_config); + + // Setup dense layer + + let dense_layer = if modules.dense_module().is_some() { + let dense_config = + DenseConfig::from_file(dense_config_resource.unwrap().get_local_path()?); + Some(Dense::new( + dense_config, + dense_weights_resource.unwrap().get_local_path()?, + device, + )?) + } else { + None + }; + + let normalize_embeddings = modules.has_normalization(); + + Ok(Self { + tokenizer, + sentence_bert_config, + tokenizer_truncation_strategy: TruncationStrategy::LongestFirst, + var_store, + transformer, + transformer_config, + pooling_layer, + dense_layer, + normalize_embeddings, + }) + } + + /// Sets the tokenizer's truncation strategy + pub fn set_tokenizer_truncation(&mut self, truncation_strategy: TruncationStrategy) { + self.tokenizer_truncation_strategy = truncation_strategy; + } + + /// Tokenizes the inputs + pub fn tokenize(&self, inputs: &[S]) -> SentenceEmbeddingsTokenizerOuput + where + S: AsRef + Sync, + { + let tokenized_input = self.tokenizer.encode_list( + inputs, + self.sentence_bert_config.max_seq_length, + &self.tokenizer_truncation_strategy, + 0, + ); + + let max_len = tokenized_input + .iter() + .map(|input| input.token_ids.len()) + .max() + .unwrap_or(0); + + let pad_token_id = self.tokenizer.get_pad_id().unwrap_or(0); + let tokens_ids = tokenized_input + .into_iter() + .map(|input| { + let mut token_ids = input.token_ids; + token_ids.extend(vec![pad_token_id; max_len - token_ids.len()]); + token_ids + }) + .collect::>(); + + let tokens_masks = tokens_ids + .iter() + .map(|input| { + Tensor::of_slice( + &input + .iter() + .map(|&e| if e == pad_token_id { 0_i64 } else { 1_i64 }) + .collect::>(), + ) + }) + .collect::>(); + + let tokens_ids = tokens_ids + .into_iter() + .map(|input| Tensor::of_slice(&(input))) + .collect::>(); + + SentenceEmbeddingsTokenizerOuput { + tokens_ids, + tokens_masks, + } + } + + /// Computes sentence embeddings, outputs `Tensor`. + pub fn encode_as_tensor( + &self, + inputs: &[S], + ) -> Result + where + S: AsRef + Sync, + { + let SentenceEmbeddingsTokenizerOuput { + tokens_ids, + tokens_masks, + } = self.tokenize(inputs); + let tokens_ids = Tensor::stack(&tokens_ids, 0).to(self.var_store.device()); + let tokens_masks = Tensor::stack(&tokens_masks, 0).to(self.var_store.device()); + + let (tokens_embeddings, all_attentions) = + tch::no_grad(|| self.transformer.forward(&tokens_ids, &tokens_masks))?; + + let mean_pool = + tch::no_grad(|| self.pooling_layer.forward(tokens_embeddings, &tokens_masks)); + let maybe_linear = if let Some(dense_layer) = &self.dense_layer { + tch::no_grad(|| dense_layer.forward(&mean_pool)) + } else { + mean_pool + }; + let maybe_normalized = if self.normalize_embeddings { + let norm = &maybe_linear + .norm_scalaropt_dim(2, &[1], true) + .clamp_min(1e-12) + .expand_as(&maybe_linear); + maybe_linear / norm + } else { + maybe_linear + }; + + Ok(SentenceEmbeddingsModelOuput { + embeddings: maybe_normalized, + all_attentions, + }) + } + + /// Computes sentence embeddings. + pub fn encode(&self, inputs: &[S]) -> Result, RustBertError> + where + S: AsRef + Sync, + { + let SentenceEmbeddingsModelOuput { embeddings, .. } = self.encode_as_tensor(inputs)?; + Ok(Vec::from(embeddings)) + } + + fn nb_layers(&self) -> usize { + use SentenceEmbeddingsOption::*; + match (&self.transformer, &self.transformer_config) { + (Bert(_), ConfigOption::Bert(conf)) => conf.num_hidden_layers as usize, + (Bert(_), _) => unreachable!(), + (DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_layers as usize, + (DistilBert(_), _) => unreachable!(), + (Roberta(_), ConfigOption::Bert(conf)) => conf.num_hidden_layers as usize, + (Roberta(_), _) => unreachable!(), + (Albert(_), ConfigOption::Albert(conf)) => conf.num_hidden_layers as usize, + (Albert(_), _) => unreachable!(), + (T5(_), ConfigOption::T5(conf)) => conf.num_layers as usize, + (T5(_), _) => unreachable!(), + } + } + + fn nb_heads(&self) -> usize { + use SentenceEmbeddingsOption::*; + match (&self.transformer, &self.transformer_config) { + (Bert(_), ConfigOption::Bert(conf)) => conf.num_attention_heads as usize, + (Bert(_), _) => unreachable!(), + (DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_heads as usize, + (DistilBert(_), _) => unreachable!(), + (Roberta(_), ConfigOption::Roberta(conf)) => conf.num_attention_heads as usize, + (Roberta(_), _) => unreachable!(), + (Albert(_), ConfigOption::Albert(conf)) => conf.num_attention_heads as usize, + (Albert(_), _) => unreachable!(), + (T5(_), ConfigOption::T5(conf)) => conf.num_heads as usize, + (T5(_), _) => unreachable!(), + } + } + + /// Computes sentence embeddings, also outputs `AttentionOutput`s. + pub fn encode_with_attention( + &self, + inputs: &[S], + ) -> Result<(Vec, Vec), RustBertError> + where + S: AsRef + Sync, + { + let SentenceEmbeddingsModelOuput { + embeddings, + all_attentions, + } = self.encode_as_tensor(inputs)?; + + let embeddings = Vec::from(embeddings); + let all_attentions = all_attentions.ok_or_else(|| { + RustBertError::InvalidConfigurationError("No attention outputted".into()) + })?; + + let attention_outputs = (0..inputs.len() as i64) + .map(|i| { + let mut attention_output = AttentionOutput::with_capacity(self.nb_layers()); + for layer in all_attentions.iter() { + let mut attention_layer = AttentionLayer::with_capacity(self.nb_heads()); + for head in 0..self.nb_heads() { + let attention_slice = layer + .slice(0, i, i + 1, 1) + .slice(1, head as i64, head as i64 + 1, 1) + .squeeze(); + let attention_head = AttentionHead::from(attention_slice); + attention_layer.push(attention_head); + } + attention_output.push(attention_layer); + } + attention_output + }) + .collect::>(); + + Ok((embeddings, attention_outputs)) + } +} + +/// Container for the SentenceEmbeddings tokenizer output. +pub struct SentenceEmbeddingsTokenizerOuput { + pub tokens_ids: Vec, + pub tokens_masks: Vec, +} + +/// Container for the SentenceEmbeddings model output. +pub struct SentenceEmbeddingsModelOuput { + pub embeddings: Tensor, + pub all_attentions: Option>, +} diff --git a/src/pipelines/sentence_embeddings/resources.rs b/src/pipelines/sentence_embeddings/resources.rs new file mode 100644 index 0000000..d284b20 --- /dev/null +++ b/src/pipelines/sentence_embeddings/resources.rs @@ -0,0 +1,184 @@ +/// # Pretrained config files for sentence embeddings +pub struct SentenceEmbeddingsModulesConfigResources; + +/// # Pretrained dense weights files for sentence embeddings +pub struct SentenceEmbeddingsDenseResources; + +/// # Pretrained dense config files for sentence embeddings +pub struct SentenceEmbeddingsDenseConfigResources; + +/// # Pretrained pooling config files for sentence embeddings +pub struct SentenceEmbeddingsPoolingConfigResources; + +/// # Pretrained config files for sentence embeddings +pub struct SentenceEmbeddingsConfigResources; + +/// # Pretrained tokenizer config files for sentence embeddings +pub struct SentenceEmbeddingsTokenizerConfigResources; + +pub enum SentenceEmbeddingsModelType { + DistiluseBaseMultilingualCased, + BertBaseNliMeanTokens, + AllMiniLmL12V2, + AllDistilrobertaV1, + ParaphraseAlbertSmallV2, + SentenceT5Base, +} + +impl SentenceEmbeddingsModulesConfigResources { + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = ( + "distiluse-base-multilingual-cased/sbert-config", + "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/modules.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = ( + "bert-base-nli-mean-tokens/sbert-config", + "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/modules.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = ( + "all-mini-lm-l12-v2/sbert-config", + "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/modules.json", + ); + /// Shared under Apache 2.0 licenseat . Modified with conversion to C-array format. + pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = ( + "all-distilroberta-v1/sbert-config", + "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/modules.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = ( + "paraphrase-albert-small-v2/sbert-config", + "https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/modules.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const SENTENCE_T5_BASE: (&'static str, &'static str) = ( + "sentence-t5-base/sbert-config", + "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/modules.json", + ); +} + +impl SentenceEmbeddingsDenseResources { + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = ( + "distiluse-base-multilingual-cased/sbert-dense", + "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/2_Dense/rust_model.ot", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const SENTENCE_T5_BASE: (&'static str, &'static str) = ( + "sentence-t5-base/sbert-dense", + "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/2_Dense/rust_model.ot", + ); +} + +impl SentenceEmbeddingsDenseConfigResources { + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = ( + "distiluse-base-multilingual-cased/sbert-dense-config", + "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/2_Dense/config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const SENTENCE_T5_BASE: (&'static str, &'static str) = ( + "sentence-t5-base/sbert-dense-config", + "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/2_Dense/config.json", + ); +} + +impl SentenceEmbeddingsPoolingConfigResources { + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = ( + "distiluse-base-multilingual-cased/sbert-pooling-config", + "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/1_Pooling/config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = ( + "bert-base-nli-mean-tokens/sbert-pooling-config", + "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/1_Pooling/config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = ( + "all-mini-lm-l12-v2/sbert-pooling-config", + "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/1_Pooling/config.json", + ); + /// Shared under Apache 2.0 licenseat . Modified with conversion to C-array format. + pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = ( + "all-distilroberta-v1/sbert-pooling-config", + "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/1_Pooling/config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = ( + "paraphrase-albert-small-v2/sbert-pooling-config", + "https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/1_Pooling/config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const SENTENCE_T5_BASE: (&'static str, &'static str) = ( + "sentence-t5-base/sbert-pooling-config", + "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/1_Pooling/config.json", + ); +} + +impl SentenceEmbeddingsConfigResources { + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = ( + "distiluse-base-multilingual-cased/sbert-config", + "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/sentence_bert_config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = ( + "bert-base-nli-mean-tokens/sbert-config", + "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/sentence_bert_config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = ( + "all-mini-lm-l12-v2/sbert-config", + "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/sentence_bert_config.json", + ); + /// Shared under Apache 2.0 licenseat . Modified with conversion to C-array format. + pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = ( + "all-distilroberta-v1/sbert-config", + "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/sentence_bert_config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = ( + "paraphrase-albert-small-v2/sbert-config", + "https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/sentence_bert_config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const SENTENCE_T5_BASE: (&'static str, &'static str) = ( + "sentence-t5-base/sbert-config", + "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/sentence_bert_config.json", + ); +} + +impl SentenceEmbeddingsTokenizerConfigResources { + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = ( + "distiluse-base-multilingual-cased/tokenizer-config", + "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/tokenizer_config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = ( + "bert-base-nli-mean-tokens/tokenizer-config", + "https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/tokenizer_config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = ( + "all-mini-lm-l12-v2/tokenizer-config", + "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/tokenizer_config.json", + ); + /// Shared under Apache 2.0 licenseat . Modified with conversion to C-array format. + pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = ( + "all-distilroberta-v1/tokenizer-config", + "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/tokenizer_config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = ( + "paraphrase-albert-small-v2/tokenizer-config", + "https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/tokenizer_config.json", + ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const SENTENCE_T5_BASE: (&'static str, &'static str) = ( + "sentence-t5-base/tokenizer-config", + "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/tokenizer_config.json", + ); +} diff --git a/src/roberta/mod.rs b/src/roberta/mod.rs index ad86149..1471f33 100644 --- a/src/roberta/mod.rs +++ b/src/roberta/mod.rs @@ -67,6 +67,7 @@ mod roberta_model; pub use embeddings::RobertaEmbeddings; pub use roberta_model::{ RobertaConfig, RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice, - RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification, - RobertaMergesResources, RobertaModelResources, RobertaVocabResources, + RobertaForQuestionAnswering, RobertaForSentenceEmbeddings, RobertaForSequenceClassification, + RobertaForTokenClassification, RobertaMergesResources, RobertaModelResources, + RobertaVocabResources, }; diff --git a/src/roberta/roberta_model.rs b/src/roberta/roberta_model.rs index 03ab68e..b384665 100644 --- a/src/roberta/roberta_model.rs +++ b/src/roberta/roberta_model.rs @@ -68,6 +68,11 @@ impl RobertaModelResources { "xlm-roberta-ner-es/model", "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/rust_model.ot", ); + /// Shared under Apache 2.0 licenseat . Modified with conversion to C-array format. + pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = ( + "all-distilroberta-v1/model", + "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/rust_model.ot", + ); } impl RobertaConfigResources { @@ -106,6 +111,11 @@ impl RobertaConfigResources { "xlm-roberta-ner-es/config", "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json", ); + /// Shared under Apache 2.0 licenseat . Modified with conversion to C-array format. + pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = ( + "all-distilroberta-v1/config", + "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/config.json", + ); } impl RobertaVocabResources { @@ -144,6 +154,11 @@ impl RobertaVocabResources { "xlm-roberta-ner-es/spiece", "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model", ); + /// Shared under Apache 2.0 licenseat . Modified with conversion to C-array format. + pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = ( + "all-distilroberta-v1/vocab", + "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/vocab.json", + ); } impl RobertaMergesResources { @@ -162,6 +177,11 @@ impl RobertaMergesResources { "roberta-qa/merges", "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/merges.txt", ); + /// Shared under Apache 2.0 licenseat . Modified with conversion to C-array format. + pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = ( + "all-distilroberta-v1/merges", + "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/merges.txt", + ); } pub struct RobertaLMHead { @@ -972,6 +992,10 @@ impl RobertaForQuestionAnswering { } } +/// # RoBERTa for sentence embeddings +/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel). +pub type RobertaForSentenceEmbeddings = BertModel; + /// Container for the RoBERTa masked LM model output. pub struct RobertaMaskedLMOutput { /// Logits for the vocabulary items at each sequence position diff --git a/src/t5/mod.rs b/src/t5/mod.rs index 6e8c521..3c47908 100644 --- a/src/t5/mod.rs +++ b/src/t5/mod.rs @@ -41,7 +41,7 @@ //! let mut vs = nn::VarStore::new(device); //! let tokenizer = T5Tokenizer::from_file(spiece_path.to_str().unwrap(), true); //! let config = T5Config::from_file(config_path); -//! let t5_model = T5ForConditionalGeneration::new(&vs.root(), &config, false, false); +//! let t5_model = T5ForConditionalGeneration::new(&vs.root(), &config); //! vs.load(weights_path)?; //! //! # Ok(()) @@ -55,6 +55,7 @@ mod t5_model; pub use attention::LayerState; pub use t5_model::{ - T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Generator, T5Model, T5ModelOutput, - T5ModelResources, T5Prefix, T5SourceLanguages, T5TargetLanguages, T5VocabResources, + T5Config, T5ConfigResources, T5ForConditionalGeneration, T5ForSentenceEmbeddings, T5Generator, + T5Model, T5ModelOutput, T5ModelResources, T5Prefix, T5SourceLanguages, T5TargetLanguages, + T5VocabResources, }; diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index 1d53b3d..4adf843 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -59,6 +59,11 @@ impl T5ModelResources { "t5-base/model", "https://huggingface.co/t5-base/resolve/main/rust_model.ot", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const SENTENCE_T5_BASE: (&'static str, &'static str) = ( + "sentence-t5-base/model", + "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/rust_model.ot", + ); } impl T5ConfigResources { @@ -72,6 +77,11 @@ impl T5ConfigResources { "t5-base/config", "https://huggingface.co/t5-base/resolve/main/config.json", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const SENTENCE_T5_BASE: (&'static str, &'static str) = ( + "sentence-t5-base/config", + "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/config.json", + ); } impl T5VocabResources { @@ -85,6 +95,11 @@ impl T5VocabResources { "t5-base/spiece", "https://huggingface.co/t5-base/resolve/main/spiece.model", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const SENTENCE_T5_BASE: (&'static str, &'static str) = ( + "sentence-t5-base/spiece", + "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model", + ); } const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German]; @@ -133,6 +148,8 @@ pub struct T5Config { pub feed_forward_proj: Option, pub tie_word_embeddings: Option, task_specific_params: Option, + pub output_attentions: Option, + pub output_hidden_states: Option, } /// # T5 task-specific configurations @@ -209,6 +226,8 @@ impl Default for T5Config { feed_forward_proj: Some(FeedForwardProj::Relu), tie_word_embeddings: None, task_specific_params: None, + output_attentions: None, + output_hidden_states: None, } } } @@ -233,8 +252,6 @@ impl T5Model { /// /// * `p` - Variable store path for the root of the BART model /// * `config` - `T5Config` object defining the model architecture - /// * `output_attention` - flag indicating if the model should output the attention weights of intermediate layers - /// * `output_hidden_states` - flag indicating if the model should output the hidden states weights of intermediate layers /// /// # Example /// @@ -248,21 +265,9 @@ impl T5Model { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = T5Config::from_file(config_path); - /// let output_attentions = true; - /// let output_hidden_states = true; - /// let t5: T5Model = T5Model::new( - /// &p.root() / "t5", - /// &config, - /// output_attentions, - /// output_hidden_states, - /// ); + /// let t5: T5Model = T5Model::new(&p.root() / "t5", &config); /// ``` - pub fn new<'p, P>( - p: P, - config: &T5Config, - output_attentions: bool, - output_hidden_states: bool, - ) -> T5Model + pub fn new<'p, P>(p: P, config: &T5Config) -> T5Model where P: Borrow>, { @@ -280,16 +285,16 @@ impl T5Model { config, false, false, - output_attentions, - output_hidden_states, + config.output_attentions.unwrap_or(false), + config.output_hidden_states.unwrap_or(false), ); let decoder = T5Stack::new( p / "decoder", config, true, true, - output_attentions, - output_hidden_states, + config.output_attentions.unwrap_or(false), + config.output_hidden_states.unwrap_or(false), ); T5Model { @@ -338,7 +343,7 @@ impl T5Model { /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = T5Config::from_file(config_path); - /// # let t5_model: T5Model = T5Model::new(&vs.root(), &config, false, false); + /// # let t5_model: T5Model = T5Model::new(&vs.root(), &config); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); @@ -450,8 +455,6 @@ impl T5ForConditionalGeneration { /// /// * `p` - Variable store path for the root of the BART model /// * `config` - `T5Config` object defining the model architecture - /// * `output_attention` - flag indicating if the model should output the attention weights of intermediate layers - /// * `output_hidden_states` - flag indicating if the model should output the hidden states weights of intermediate layers /// /// # Example /// @@ -465,27 +468,15 @@ impl T5ForConditionalGeneration { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = T5Config::from_file(config_path); - /// let output_attentions = true; - /// let output_hidden_states = true; - /// let t5 = T5ForConditionalGeneration::new( - /// &p.root() / "t5", - /// &config, - /// output_attentions, - /// output_hidden_states, - /// ); + /// let t5 = T5ForConditionalGeneration::new(&p.root() / "t5", &config); /// ``` - pub fn new<'p, P>( - p: P, - config: &T5Config, - output_attentions: bool, - output_hidden_states: bool, - ) -> T5ForConditionalGeneration + pub fn new<'p, P>(p: P, config: &T5Config) -> T5ForConditionalGeneration where P: Borrow>, { let p = p.borrow(); - let base_model = T5Model::new(p, config, output_attentions, output_hidden_states); + let base_model = T5Model::new(p, config); let tie_word_embeddings = config.tie_word_embeddings.unwrap_or(true); let lm_head = if !tie_word_embeddings { @@ -549,7 +540,7 @@ impl T5ForConditionalGeneration { /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = T5Config::from_file(config_path); - /// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config, false, false); + /// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); @@ -666,7 +657,7 @@ impl LMHeadModel for T5ForConditionalGeneration { /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = T5Config::from_file(config_path); - /// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config, false, false); + /// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); @@ -749,6 +740,84 @@ impl LMHeadModel for T5ForConditionalGeneration { } } +/// # T5 for sentence embeddings +/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel). +pub struct T5ForSentenceEmbeddings { + embeddings: nn::Embedding, + encoder: T5Stack, +} + +impl T5ForSentenceEmbeddings { + /// Build a new `T5ForSentenceEmbeddings` + /// + /// # Arguments + /// + /// * `p` - Variable store path for the root of the BART model + /// * `config` - `T5Config` object defining the model architecture + /// + /// It consists of only an encoder (there is no decoder). + pub fn new<'p, P>(p: P, config: &T5Config) -> Self + where + P: Borrow>, + { + let p = p.borrow(); + + let embeddings: nn::Embedding = embedding( + p / "shared", + config.vocab_size, + config.d_model, + Default::default(), + ); + + let encoder = T5Stack::new( + p / "encoder", + config, + false, + false, + config.output_attentions.unwrap_or(false), + config.output_hidden_states.unwrap_or(false), + ); + + Self { + embeddings, + encoder, + } + } + + /// Forward pass through the model + /// + /// # Arguments + /// + /// * `input_ids` - Input of shape (*batch size*, *source_sequence_length*). + /// * `mask` - Attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. + /// + /// # Returns + /// + /// * Tuple containing: + /// - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state + /// - `Option>` of length *num_encoder_layers* of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing attention weights for all layers of the encoder + pub fn forward( + &self, + input_ids: &Tensor, + mask: &Tensor, + ) -> Result<(Tensor, Option>), RustBertError> { + let transformer_output = self.encoder.forward_t( + Some(input_ids), + Some(mask), + None, + None, + None, + &self.embeddings, + None, + false, + )?; + Ok(( + transformer_output.hidden_state, + transformer_output.all_attentions, + )) + } +} + /// Container holding a T5 model output. The decoder output may hold the hidden state of /// the last layer of the decoder, or may hold logits for a custom head module after the /// decoder (e.g. for language modeling tasks) @@ -812,7 +881,7 @@ impl T5Generator { let mut var_store = nn::VarStore::new(device); let config = T5Config::from_file(config_path); - let model = T5ForConditionalGeneration::new(&var_store.root(), &config, false, false); + let model = T5ForConditionalGeneration::new(&var_store.root(), &config); var_store.load(weights_path)?; let bos_token_id = Some(config.bos_token_id.unwrap_or(-1)); diff --git a/tests/sentence_embeddings.rs b/tests/sentence_embeddings.rs new file mode 100644 index 0000000..b333047 --- /dev/null +++ b/tests/sentence_embeddings.rs @@ -0,0 +1,157 @@ +use rust_bert::pipelines::sentence_embeddings::{ + SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType, +}; + +#[test] +fn sbert_distilbert() -> anyhow::Result<()> { + let model = SentenceEmbeddingsBuilder::remote( + SentenceEmbeddingsModelType::DistiluseBaseMultilingualCased, + ) + .create_model()?; + + let sentences = ["This is an example sentence", "Each sentence is converted"]; + let embeddings = model.encode(&sentences)?; + + assert!((embeddings[0][0] as f64 - -0.03479306).abs() < 1e-4); + assert!((embeddings[0][1] as f64 - 0.02635195).abs() < 1e-4); + assert!((embeddings[0][2] as f64 - -0.04427199).abs() < 1e-4); + assert!((embeddings[0][509] as f64 - 0.01743882).abs() < 1e-4); + assert!((embeddings[0][510] as f64 - -0.01952395).abs() < 1e-4); + assert!((embeddings[0][511] as f64 - -0.00118101).abs() < 1e-4); + + assert!((embeddings[1][0] as f64 - 0.02096637).abs() < 1e-4); + assert!((embeddings[1][1] as f64 - -0.00401743).abs() < 1e-4); + assert!((embeddings[1][2] as f64 - -0.05093712).abs() < 1e-4); + assert!((embeddings[1][509] as f64 - 0.03618195).abs() < 1e-4); + assert!((embeddings[1][510] as f64 - 0.0294408).abs() < 1e-4); + assert!((embeddings[1][511] as f64 - -0.04497765).abs() < 1e-4); + + Ok(()) +} + +#[test] +fn sbert_bert() -> anyhow::Result<()> { + let model = + SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::BertBaseNliMeanTokens) + .create_model()?; + + let sentences = ["this is an example sentence", "each sentence is converted"]; + let embeddings = model.encode(&sentences)?; + + assert!((embeddings[0][0] as f64 - -0.393099815).abs() < 1e-4); + assert!((embeddings[0][1] as f64 - 0.0388629436).abs() < 1e-4); + assert!((embeddings[0][2] as f64 - 1.98742473).abs() < 1e-4); + assert!((embeddings[0][765] as f64 - -0.609367728).abs() < 1e-4); + assert!((embeddings[0][766] as f64 - -1.09462142).abs() < 1e-4); + assert!((embeddings[0][767] as f64 - 0.326490253).abs() < 1e-4); + + assert!((embeddings[1][0] as f64 - 0.0615336187).abs() < 1e-4); + assert!((embeddings[1][1] as f64 - 0.32736221).abs() < 1e-4); + assert!((embeddings[1][2] as f64 - 1.8332324).abs() < 1e-4); + assert!((embeddings[1][765] as f64 - -0.129853949).abs() < 1e-4); + assert!((embeddings[1][766] as f64 - 0.460893631).abs() < 1e-4); + assert!((embeddings[1][767] as f64 - 0.240354523).abs() < 1e-4); + + Ok(()) +} + +#[test] +fn sbert_bert_small() -> anyhow::Result<()> { + let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2) + .create_model()?; + + let sentences = ["this is an example sentence", "each sentence is converted"]; + let embeddings = model.encode(&sentences)?; + + assert!((embeddings[0][0] as f64 - -2.02682902e-04).abs() < 1e-4); + assert!((embeddings[0][1] as f64 - 8.14802647e-02).abs() < 1e-4); + assert!((embeddings[0][2] as f64 - 3.13617811e-02).abs() < 1e-4); + assert!((embeddings[0][381] as f64 - 6.20930083e-02).abs() < 1e-4); + assert!((embeddings[0][382] as f64 - 4.91031967e-02).abs() < 1e-4); + assert!((embeddings[0][383] as f64 - -2.90199649e-04).abs() < 1e-4); + + assert!((embeddings[1][0] as f64 - 6.47571534e-02).abs() < 1e-4); + assert!((embeddings[1][1] as f64 - 4.85198125e-02).abs() < 1e-4); + assert!((embeddings[1][2] as f64 - -1.78603437e-02).abs() < 1e-4); + assert!((embeddings[1][381] as f64 - 3.37569155e-02).abs() < 1e-4); + assert!((embeddings[1][382] as f64 - 8.43371451e-03).abs() < 1e-4); + assert!((embeddings[1][383] as f64 - -6.00359812e-02).abs() < 1e-4); + + Ok(()) +} + +#[test] +fn sbert_distilroberta() -> anyhow::Result<()> { + let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllDistilrobertaV1) + .create_model()?; + + let sentences = ["This is an example sentence", "Each sentence is converted"]; + let embeddings = model.encode(&sentences)?; + + assert!((embeddings[0][0] as f64 - -0.03375624).abs() < 1e-4); + assert!((embeddings[0][1] as f64 - -0.06316338).abs() < 1e-4); + assert!((embeddings[0][2] as f64 - -0.0316612).abs() < 1e-4); + assert!((embeddings[0][765] as f64 - 0.03684864).abs() < 1e-4); + assert!((embeddings[0][766] as f64 - -0.02036646).abs() < 1e-4); + assert!((embeddings[0][767] as f64 - -0.01574).abs() < 1e-4); + + assert!((embeddings[1][0] as f64 - -0.01409588).abs() < 1e-4); + assert!((embeddings[1][1] as f64 - 0.00091114).abs() < 1e-4); + assert!((embeddings[1][2] as f64 - -0.00096315).abs() < 1e-4); + assert!((embeddings[1][765] as f64 - -0.02571585).abs() < 1e-4); + assert!((embeddings[1][766] as f64 - -0.00289072).abs() < 1e-4); + assert!((embeddings[1][767] as f64 - -0.00579975).abs() < 1e-4); + + Ok(()) +} + +#[test] +fn sbert_albert() -> anyhow::Result<()> { + let model = + SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2) + .create_model()?; + + let sentences = ["this is an example sentence", "each sentence is converted"]; + let embeddings = model.encode(&sentences)?; + + assert!((embeddings[0][0] as f64 - 0.20412037).abs() < 1e-4); + assert!((embeddings[0][1] as f64 - 0.48823047).abs() < 1e-4); + assert!((embeddings[0][2] as f64 - 0.5664698).abs() < 1e-4); + assert!((embeddings[0][765] as f64 - -0.37474486).abs() < 1e-4); + assert!((embeddings[0][766] as f64 - 0.0254627).abs() < 1e-4); + assert!((embeddings[0][767] as f64 - -0.6846024).abs() < 1e-4); + + assert!((embeddings[1][0] as f64 - 0.25720373).abs() < 1e-4); + assert!((embeddings[1][1] as f64 - 0.24648172).abs() < 1e-4); + assert!((embeddings[1][2] as f64 - -0.2521183).abs() < 1e-4); + assert!((embeddings[1][765] as f64 - 0.4667896).abs() < 1e-4); + assert!((embeddings[1][766] as f64 - 0.14219822).abs() < 1e-4); + assert!((embeddings[1][767] as f64 - 0.3986863).abs() < 1e-4); + + Ok(()) +} + +#[test] +fn sbert_t5() -> anyhow::Result<()> { + let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::SentenceT5Base) + .create_model()?; + + let sentences = ["This is an example sentence", "Each sentence is converted"]; + let embeddings = model.encode(&sentences)?; + + assert!((embeddings[0][0] as f64 - -0.00904849).abs() < 1e-4); + assert!((embeddings[0][1] as f64 - 0.0191336).abs() < 1e-4); + assert!((embeddings[0][2] as f64 - 0.02657794).abs() < 1e-4); + assert!((embeddings[0][765] as f64 - -0.00876413).abs() < 1e-4); + assert!((embeddings[0][766] as f64 - -0.05602207).abs() < 1e-4); + assert!((embeddings[0][767] as f64 - -0.02163094).abs() < 1e-4); + + assert!((embeddings[1][0] as f64 - -0.00785422).abs() < 1e-4); + assert!((embeddings[1][1] as f64 - 0.03018173).abs() < 1e-4); + assert!((embeddings[1][2] as f64 - 0.03129675).abs() < 1e-4); + assert!((embeddings[1][765] as f64 - -0.01246878).abs() < 1e-4); + assert!((embeddings[1][766] as f64 - -0.06240674).abs() < 1e-4); + assert!((embeddings[1][767] as f64 - -0.00590969).abs() < 1e-4); + + Ok(()) +} diff --git a/utils/convert_model.py b/utils/convert_model.py index fa490cd..dedc9f2 100644 --- a/utils/convert_model.py +++ b/utils/convert_model.py @@ -11,6 +11,8 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert") parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings / language model head") + parser.add_argument("--prefix", help="Add a prefix on weight names") + parser.add_argument("--suffix", action="store_true", help="Split weight names on '.' and keep only last part") args = parser.parse_args() source_file = Path(args.source_file) @@ -24,6 +26,10 @@ if __name__ == "__main__": if args.skip_embeddings: if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}: continue + if args.prefix: + k = args.prefix + k + if args.suffix: + k = k.split('.')[-1] if isinstance(v, Tensor): nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32)) print(f'converted {k} - {str(sys.getsizeof(nps[k]))} bytes')