Add sbert implementation for inference (#250)

* Add sbert implementation for inference

* Fix clippy warnings

* Refactor sentence embeddings into a dedicated pipeline

* Add output_attentions and output_hidden_states to T5Config

* Add sbert implementation for inference

* Fix clippy warnings

* Refactor sentence embeddings into a dedicated pipeline

* Add output_attentions and output_hidden_states to T5Config

* Improve sentence_embeddings implementation

* Dedicated tokenizer config for strip_accents and add_prefix_space

* Rename forward to encode_as_tensor

* Remove _conf from Dense layer

* Add sentence embeddings docs

* Addition of remote resources and tests update

* Merge feature branch and fix doctests

* Add SentenceEmbeddingsBuilder<Remote> and improve remote resources

* Use tch::no_grad in sentence embeddings

* Updated changelog, registration of sentence embeddings integration tests

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
This commit is contained in:
Romain Leroux 2022-06-21 21:24:09 +02:00 committed by GitHub
parent 6b20da41de
commit 4d8a298586
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 2007 additions and 61 deletions

View File

@ -100,6 +100,22 @@ jobs:
--test pegasus
--test gpt_neo
test-batch-2:
name: Integration tests (batch 2)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: test
args: --package rust-bert
--test sentence_embeddings
convert-model:
name: Model conversion test
runs-on: ubuntu-latest

View File

@ -2,6 +2,8 @@
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased]
## Added
- Support for sentence embeddings models and pipelines, based on [SentenceTransformers](https://www.sbert.net).
## [0.18.0] - 2022-05-29
## Added

View File

@ -0,0 +1,17 @@
use rust_bert::pipelines::sentence_embeddings::{
SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
};
/// Builds the `AllMiniLmL12V2` sentence embeddings pipeline from its remote resources and
/// prints the embeddings computed for two example sentences.
fn main() -> anyhow::Result<()> {
    // Remote builder for the chosen pretrained model
    let builder = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2);
    let model = builder.create_model()?;

    // Sentences to embed
    let sentences = ["this is an example sentence", "each sentence is converted"];

    // Encode the sentences and display the result
    let embeddings = model.encode(&sentences)?;
    println!("{:?}", embeddings);

    Ok(())
}

View File

@ -0,0 +1,25 @@
use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;
/// Encodes two example sentences with a sentence embeddings model read from a local directory.
///
/// Download model:
/// ```sh
/// git lfs install
/// git -C resources clone https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2
/// ```
/// Prepare model:
/// ```sh
/// python ./utils/convert_model.py resources/all-MiniLM-L12-v2/pytorch_model.bin
/// ```
fn main() -> anyhow::Result<()> {
    // Local builder pointing at the cloned model directory, with an explicit device
    let builder = SentenceEmbeddingsBuilder::local("resources/all-MiniLM-L12-v2")
        .with_device(tch::Device::cuda_if_available());
    let model = builder.create_model()?;

    // Sentences to embed
    let sentences = ["this is an example sentence", "each sentence is converted"];

    // Encode the sentences and display the result
    let embeddings = model.encode(&sentences)?;
    println!("{:?}", embeddings);

    Ok(())
}

View File

@ -37,6 +37,11 @@ impl AlbertModelResources {
"albert-base-v2/model",
"https://huggingface.co/albert-base-v2/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
"paraphrase-albert-small-v2/model",
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/rust_model.ot",
);
}
impl AlbertConfigResources {
@ -45,6 +50,11 @@ impl AlbertConfigResources {
"albert-base-v2/config",
"https://huggingface.co/albert-base-v2/resolve/main/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
"paraphrase-albert-small-v2/config",
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/config.json",
);
}
impl AlbertVocabResources {
@ -53,6 +63,11 @@ impl AlbertVocabResources {
"albert-base-v2/spiece",
"https://huggingface.co/albert-base-v2/resolve/main/spiece.model",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
"paraphrase-albert-small-v2/spiece",
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/spiece.model",
);
}
#[derive(Debug, Serialize, Deserialize, Clone)]
@ -1048,6 +1063,10 @@ impl AlbertForMultipleChoice {
}
}
/// # ALBERT for sentence embeddings
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
pub type AlbertForSentenceEmbeddings = AlbertModel;
/// Container for the ALBERT model output.
pub struct AlbertOutput {
/// Last hidden states from the model

View File

@ -59,8 +59,8 @@ mod encoder;
pub use albert_model::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
AlbertMaskedLMOutput, AlbertModel, AlbertModelResources, AlbertOutput,
AlbertQuestionAnsweringOutput, AlbertSequenceClassificationOutput,
AlbertForQuestionAnswering, AlbertForSentenceEmbeddings, AlbertForSequenceClassification,
AlbertForTokenClassification, AlbertMaskedLMOutput, AlbertModel, AlbertModelResources,
AlbertOutput, AlbertQuestionAnsweringOutput, AlbertSequenceClassificationOutput,
AlbertTokenClassificationOutput, AlbertVocabResources,
};

View File

@ -52,6 +52,16 @@ impl BertModelResources {
"bert-qa/model",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
"bert-base-nli-mean-tokens/model",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/model",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/rust_model.ot",
);
}
impl BertConfigResources {
@ -70,6 +80,16 @@ impl BertConfigResources {
"bert-qa/config",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
"bert-base-nli-mean-tokens/config",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/config.json",
);
}
impl BertVocabResources {
@ -88,6 +108,16 @@ impl BertVocabResources {
"bert-qa/vocab",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
"bert-base-nli-mean-tokens/vocab",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/vocab",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/vocab.txt",
);
}
#[derive(Debug, Serialize, Deserialize, Clone)]
@ -1179,6 +1209,10 @@ impl BertForQuestionAnswering {
}
}
/// # BERT for sentence embeddings
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
pub type BertForSentenceEmbeddings = BertModel<BertEmbeddings>;
/// Container for the BERT model output.
pub struct BertModelOutput {
/// Last hidden states from the model

View File

@ -59,8 +59,8 @@ pub(crate) mod encoder;
pub use bert_model::{
BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
BertMaskedLMOutput, BertModel, BertModelOutput, BertModelResources,
BertForQuestionAnswering, BertForSentenceEmbeddings, BertForSequenceClassification,
BertForTokenClassification, BertMaskedLMOutput, BertModel, BertModelOutput, BertModelResources,
BertQuestionAnsweringOutput, BertSequenceClassificationOutput, BertTokenClassificationOutput,
BertVocabResources,
};

View File

@ -26,6 +26,10 @@ pub fn _tanh(x: &Tensor) -> Tensor {
x.tanh()
}
/// Identity activation: leaves the values untouched and returns a shallow clone of `x`.
pub fn _identity(x: &Tensor) -> Tensor {
x.shallow_clone()
}
pub struct TensorFunction(Box<fn(&Tensor) -> Tensor>);
impl TensorFunction {
@ -58,6 +62,8 @@ pub enum Activation {
gelu_new,
/// Tanh
tanh,
/// Identity
identity,
}
impl Activation {
@ -69,6 +75,7 @@ impl Activation {
Activation::gelu_new => _gelu_new,
Activation::mish => _mish,
Activation::tanh => _tanh,
Activation::identity => _identity,
}))
}
}

View File

@ -30,3 +30,15 @@ impl ResourceProvider for LocalResource {
Ok(self.local_path.clone())
}
}
// Allows a `PathBuf` to be converted directly into a `LocalResource`.
impl From<PathBuf> for LocalResource {
/// Wraps `local_path` in a `LocalResource` without touching the file system.
fn from(local_path: PathBuf) -> Self {
Self { local_path }
}
}
// Allows a `PathBuf` to be used wherever a boxed resource provider is expected.
impl From<PathBuf> for Box<dyn ResourceProvider + Send> {
    /// Builds a `LocalResource` for `local_path` and boxes it as a resource provider.
    fn from(local_path: PathBuf) -> Self {
        let resource = LocalResource { local_path };
        Box::new(resource)
    }
}

View File

@ -46,6 +46,11 @@ impl DistilBertModelResources {
"distilbert-qa/model",
"https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
"distiluse-base-multilingual-cased/model",
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/rust_model.ot",
);
}
impl DistilBertConfigResources {
@ -64,6 +69,11 @@ impl DistilBertConfigResources {
"distilbert-qa/config",
"https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
"distiluse-base-multilingual-cased/config",
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/config.json",
);
}
impl DistilBertVocabResources {
@ -82,6 +92,11 @@ impl DistilBertVocabResources {
"distilbert-qa/vocab",
"https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
"distiluse-base-multilingual-cased/vocab",
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/vocab.txt",
);
}
#[derive(Debug, Serialize, Deserialize, Clone)]
@ -755,6 +770,10 @@ impl DistilBertForTokenClassification {
}
}
/// # DistilBERT for sentence embeddings
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
pub type DistilBertForSentenceEmbeddings = DistilBertModel;
/// Container for the DistilBERT masked LM model output.
pub struct DistilBertMaskedLMOutput {
/// Logits for the vocabulary items at each sequence position

View File

@ -60,8 +60,8 @@ mod transformer;
pub use distilbert_model::{
DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
DistilBertForTokenClassification, DistilBertMaskedLMOutput, DistilBertModel,
DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertModelResources,
DistilBertForSentenceEmbeddings, DistilBertForTokenClassification, DistilBertMaskedLMOutput,
DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertModelResources,
DistilBertQuestionAnsweringOutput, DistilBertSequenceClassificationOutput,
DistilBertTokenClassificationOutput, DistilBertVocabResources,
};

View File

@ -146,7 +146,7 @@
//! ```
//!
//! </details>
//! &nbsp;
//! &nbsp;
//! <details>
//! <summary> <b>2. Translation </b> </summary>
//!
@ -198,7 +198,7 @@
//! ```
//!
//! </details>
//! &nbsp;
//! &nbsp;
//! <details>
//! <summary> <b>3. Summarization </b> </summary>
//!
@ -249,7 +249,7 @@
//! ```
//!
//! </details>
//! &nbsp;
//! &nbsp;
//! <details>
//! <summary> <b>4. Dialogue Model </b> </summary>
//!
@ -281,7 +281,7 @@
//! ```
//!
//! </details>
//! &nbsp;
//! &nbsp;
//! <details>
//! <summary> <b>5. Natural Language Generation </b> </summary>
//!
@ -319,7 +319,7 @@
//! ```
//!
//! </details>
//! &nbsp;
//! &nbsp;
//! <details>
//! <summary> <b>6. Zero-shot classification </b> </summary>
//!
@ -402,7 +402,7 @@
//! ```
//!
//! </details>
//! &nbsp;
//! &nbsp;
//! <details>
//! <summary> <b>7. Sentiment analysis </b> </summary>
//!
@ -445,7 +445,7 @@
//! ```
//!
//! </details>
//! &nbsp;
//! &nbsp;
//! <details>
//! <summary> <b>8. Named Entity Recognition </b> </summary>
//!
@ -502,7 +502,7 @@
//! ```
//!
//! </details>
//! &nbsp;
//! &nbsp;
//! <details>
//! <summary> <b>9. Part of Speech tagging </b> </summary>
//!

View File

@ -54,22 +54,28 @@ use rust_tokenizers::vocab::{
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::path::Path;
#[derive(Clone, Copy, Serialize, Deserialize, Debug, PartialEq)]
/// # Identifies the type of model
pub enum ModelType {
Bart,
#[serde(alias = "bert")]
Bert,
#[serde(alias = "distilbert")]
DistilBert,
Deberta,
DebertaV2,
#[serde(alias = "roberta")]
Roberta,
XLMRoberta,
Electra,
Marian,
MobileBert,
#[serde(alias = "t5")]
T5,
#[serde(alias = "albert")]
Albert,
XLNet,
GPT2,
@ -309,6 +315,62 @@ impl ConfigOption {
}
}
// Extracts a `BertConfig` from a `ConfigOption` holding a BERT or RoBERTa configuration.
impl TryFrom<&ConfigOption> for BertConfig {
    type Error = RustBertError;

    /// Returns a clone of the wrapped configuration for the `Bert` and `Roberta` variants.
    ///
    /// # Errors
    ///
    /// Returns `RustBertError::InvalidConfigurationError` for every other variant.
    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
        if let ConfigOption::Bert(bert_config) | ConfigOption::Roberta(bert_config) = config {
            return Ok(bert_config.clone());
        }
        Err(RustBertError::InvalidConfigurationError(
            "You can only supply a BertConfig for Bert or a RobertaConfig for Roberta!"
                .to_string(),
        ))
    }
}
// Extracts a `DistilBertConfig` from a `ConfigOption` holding a DistilBERT configuration.
impl TryFrom<&ConfigOption> for DistilBertConfig {
    type Error = RustBertError;

    /// Returns a clone of the wrapped configuration for the `DistilBert` variant.
    ///
    /// # Errors
    ///
    /// Returns `RustBertError::InvalidConfigurationError` for every other variant.
    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
        match config {
            ConfigOption::DistilBert(distilbert_config) => Ok(distilbert_config.clone()),
            _ => Err(RustBertError::InvalidConfigurationError(
                "You can only supply a DistilBertConfig for DistilBert!".to_string(),
            )),
        }
    }
}
// Extracts an `AlbertConfig` from a `ConfigOption` holding an ALBERT configuration.
impl TryFrom<&ConfigOption> for AlbertConfig {
    type Error = RustBertError;

    /// Returns a clone of the wrapped configuration for the `Albert` variant.
    ///
    /// # Errors
    ///
    /// Returns `RustBertError::InvalidConfigurationError` for every other variant.
    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
        match config {
            ConfigOption::Albert(albert_config) => Ok(albert_config.clone()),
            _ => Err(RustBertError::InvalidConfigurationError(
                "You can only supply an AlbertConfig for Albert!".to_string(),
            )),
        }
    }
}
// Extracts a `T5Config` from a `ConfigOption` holding a T5 configuration.
impl TryFrom<&ConfigOption> for T5Config {
    type Error = RustBertError;

    /// Returns a clone of the wrapped configuration for the `T5` variant.
    ///
    /// # Errors
    ///
    /// Returns `RustBertError::InvalidConfigurationError` for every other variant.
    fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
        if let ConfigOption::T5(t5_config) = config {
            return Ok(t5_config.clone());
        }
        Err(RustBertError::InvalidConfigurationError(
            "You can only supply a T5Config for T5!".to_string(),
        ))
    }
}
impl TokenizerOption {
/// Interface method to load a tokenizer from file
pub fn from_file(

View File

@ -416,6 +416,7 @@ pub mod generation_utils;
pub mod ner;
pub mod pos_tagging;
pub mod question_answering;
pub mod sentence_embeddings;
pub mod sentiment;
pub mod sequence_classification;
pub mod summarization;

View File

@ -0,0 +1,185 @@
use std::path::PathBuf;
use serde::Deserialize;
use tch::Device;
use crate::pipelines::common::ModelType;
use crate::pipelines::sentence_embeddings::{
SentenceEmbeddingsConfig, SentenceEmbeddingsModel, SentenceEmbeddingsModulesConfig,
};
use crate::{Config, RustBertError};
#[cfg(feature = "remote")]
use crate::{
pipelines::sentence_embeddings::resources::SentenceEmbeddingsModelType,
resources::RemoteResource,
};
/// # SentenceEmbeddings Model Builder
///
/// Allows the user to build a model from standard Sentence-Transformer files
/// (configuration and weights).
pub struct SentenceEmbeddingsBuilder<T> {
device: Device,
inner: T,
}
impl<T> SentenceEmbeddingsBuilder<T> {
    /// Replaces the device stored in the builder and returns the updated builder.
    pub fn with_device(self, device: Device) -> Self {
        Self { device, ..self }
    }
}
/// Builder state for a model whose files are read from a local directory.
pub struct Local {
/// Directory against which all model file paths are resolved.
model_dir: PathBuf,
}
#[derive(Debug, Deserialize)]
/// Partial view of a configuration file, deserializing only its `model_type` entry.
struct ModelConfig {
/// Transformer model type declared in the configuration.
model_type: ModelType,
}
// `Config` is implemented so the builder can call `ModelConfig::from_file`.
impl Config for ModelConfig {}
impl SentenceEmbeddingsBuilder<Local> {
    /// Creates a builder for a model stored in the local directory `model_dir`.
    ///
    /// The device is initialised with `Device::cuda_if_available()`.
    pub fn local<P: Into<PathBuf>>(model_dir: P) -> Self {
        let device = Device::cuda_if_available();
        let inner = Local {
            model_dir: model_dir.into(),
        };
        Self { device, inner }
    }

    /// Resolves every resource path relative to the model directory and builds the model.
    ///
    /// # Errors
    ///
    /// Returns an error when `modules.json` fails validation, when the model type read from
    /// `config.json` is not Bert, DistilBert, Roberta, Albert or T5, or when
    /// `SentenceEmbeddingsModel::new` fails.
    pub fn create_model(self) -> Result<SentenceEmbeddingsModel, RustBertError> {
        let Self { device, inner } = self;
        let root = inner.model_dir;

        // Layer definitions: validated first, as the remaining paths depend on them
        let modules_config = root.join("modules.json");
        let modules = SentenceEmbeddingsModulesConfig::from_file(&modules_config).validate()?;

        // Transformer files
        let transformer_config = root.join("config.json");
        let transformer_type = ModelConfig::from_file(&transformer_config).model_type;
        let transformer_weights = root.join("rust_model.ot");

        // Pooling layer lives in the sub-directory named by its module entry
        let pooling_config = root
            .join(&modules.pooling_module().path)
            .join("config.json");

        // Dense layer is optional: both paths are present or both absent
        let (dense_config, dense_weights) = match modules.dense_module() {
            Some(dense) => {
                let dense_dir = root.join(&dense.path);
                (
                    Some(dense_dir.join("config.json")),
                    Some(dense_dir.join("rust_model.ot")),
                )
            }
            None => (None, None),
        };

        let tokenizer_config = root.join("tokenizer_config.json");
        let sentence_bert_config = root.join("sentence_bert_config.json");

        // Vocabulary (and merges, if any) file names depend on the transformer type
        let (tokenizer_vocab, tokenizer_merges) = match transformer_type {
            ModelType::Bert | ModelType::DistilBert => (root.join("vocab.txt"), None),
            ModelType::Roberta => (root.join("vocab.json"), Some(root.join("merges.txt"))),
            ModelType::Albert | ModelType::T5 => (root.join("spiece.model"), None),
            unsupported => {
                return Err(RustBertError::InvalidConfigurationError(format!(
                    "Unsupported transformer model {:?} for Sentence Embeddings",
                    unsupported
                )));
            }
        };

        SentenceEmbeddingsModel::new(SentenceEmbeddingsConfig {
            modules_config_resource: modules_config.into(),
            transformer_type,
            transformer_config_resource: transformer_config.into(),
            transformer_weights_resource: transformer_weights.into(),
            pooling_config_resource: pooling_config.into(),
            dense_config_resource: dense_config.map(|path| path.into()),
            dense_weights_resource: dense_weights.map(|path| path.into()),
            sentence_bert_config_resource: sentence_bert_config.into(),
            tokenizer_config_resource: tokenizer_config.into(),
            tokenizer_vocab_resource: tokenizer_vocab.into(),
            tokenizer_merges_resource: tokenizer_merges.map(|path| path.into()),
            device,
        })
    }
}
#[cfg(feature = "remote")]
/// Builder state for a model whose resources are described by a `SentenceEmbeddingsConfig`.
pub struct Remote {
/// Configuration handed to `SentenceEmbeddingsModel::new` when the model is created.
config: SentenceEmbeddingsConfig,
}
#[cfg(feature = "remote")]
impl SentenceEmbeddingsBuilder<Remote> {
    /// Creates a builder preconfigured with the resources of the pretrained `model_type`.
    ///
    /// The device is initialised with `Device::cuda_if_available()`; use `with_device` to
    /// select another one.
    pub fn remote(model_type: SentenceEmbeddingsModelType) -> Self {
        Self {
            device: Device::cuda_if_available(),
            inner: Remote {
                config: SentenceEmbeddingsConfig::from(model_type),
            },
        }
    }

    /// Overrides the modules configuration resource.
    pub fn modules_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.modules_config_resource = Box::new(resource);
        self
    }

    /// Overrides the transformer configuration resource.
    pub fn transformer_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.transformer_config_resource = Box::new(resource);
        self
    }

    /// Overrides the transformer weights resource.
    pub fn transformer_weights(mut self, resource: RemoteResource) -> Self {
        self.inner.config.transformer_weights_resource = Box::new(resource);
        self
    }

    /// Overrides the pooling layer configuration resource.
    pub fn pooling_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.pooling_config_resource = Box::new(resource);
        self
    }

    /// Sets the (optional) dense layer configuration resource.
    pub fn dense_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.dense_config_resource = Some(Box::new(resource));
        self
    }

    /// Sets the (optional) dense layer weights resource.
    pub fn dense_weights(mut self, resource: RemoteResource) -> Self {
        self.inner.config.dense_weights_resource = Some(Box::new(resource));
        self
    }

    /// Overrides the Sentence BERT specific configuration resource.
    pub fn sentence_bert_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.sentence_bert_config_resource = Box::new(resource);
        self
    }

    /// Overrides the tokenizer configuration resource.
    pub fn tokenizer_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.tokenizer_config_resource = Box::new(resource);
        self
    }

    /// Overrides the tokenizer vocabulary resource.
    pub fn tokenizer_vocab(mut self, resource: RemoteResource) -> Self {
        self.inner.config.tokenizer_vocab_resource = Box::new(resource);
        self
    }

    /// Sets the (optional) tokenizer merges resource.
    pub fn tokenizer_merges(mut self, resource: RemoteResource) -> Self {
        self.inner.config.tokenizer_merges_resource = Some(Box::new(resource));
        self
    }

    /// Builds the model from the configured resources, placing it on the builder's device.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SentenceEmbeddingsModel::new`.
    pub fn create_model(mut self) -> Result<SentenceEmbeddingsModel, RustBertError> {
        // The preset configuration carries its own default device. Overwrite it with the
        // builder's device, otherwise a call to `with_device` would be silently ignored.
        self.inner.config.device = self.device;
        SentenceEmbeddingsModel::new(self.inner.config)
    }
}

View File

@ -0,0 +1,429 @@
use serde::{Deserialize, Serialize};
use tch::Device;
use crate::pipelines::common::ModelType;
use crate::resources::ResourceProvider;
use crate::{Config, RustBertError};
#[cfg(feature = "remote")]
use crate::{
albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources},
bert::{BertConfigResources, BertModelResources, BertVocabResources},
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
pipelines::sentence_embeddings::resources::{
SentenceEmbeddingsConfigResources, SentenceEmbeddingsModelType,
SentenceEmbeddingsModulesConfigResources, SentenceEmbeddingsPoolingConfigResources,
SentenceEmbeddingsTokenizerConfigResources,
},
pipelines::sentence_embeddings::{
SentenceEmbeddingsDenseConfigResources, SentenceEmbeddingsDenseResources,
},
resources::RemoteResource,
roberta::{
RobertaConfigResources, RobertaMergesResources, RobertaModelResources,
RobertaVocabResources,
},
t5::{T5ConfigResources, T5ModelResources, T5VocabResources},
};
/// # Configuration for sentence embeddings
///
/// Contains information regarding the transformer model to load, the optional extra
/// layers, and device to place the model on.
pub struct SentenceEmbeddingsConfig {
/// Modules configuration resource, contains layers definition
pub modules_config_resource: Box<dyn ResourceProvider + Send>,
/// Transformer model type
pub transformer_type: ModelType,
/// Transformer model configuration resource
pub transformer_config_resource: Box<dyn ResourceProvider + Send>,
/// Transformer weights resource
pub transformer_weights_resource: Box<dyn ResourceProvider + Send>,
/// Pooling layer configuration resource
pub pooling_config_resource: Box<dyn ResourceProvider + Send>,
/// Optional dense layer configuration resource
pub dense_config_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Optional dense layer weights resource
pub dense_weights_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Sentence BERT specific configuration resource
pub sentence_bert_config_resource: Box<dyn ResourceProvider + Send>,
/// Transformer's tokenizer configuration resource
pub tokenizer_config_resource: Box<dyn ResourceProvider + Send>,
/// Transformer's tokenizer vocab resource
pub tokenizer_vocab_resource: Box<dyn ResourceProvider + Send>,
/// Optional transformer's tokenizer merges resource
pub tokenizer_merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Device to place the transformer model on
pub device: Device,
}
#[cfg(feature = "remote")]
impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
    /// Maps a pretrained model identifier to the configuration listing its remote resources.
    ///
    /// Every configuration produced here uses `Device::cuda_if_available()` as its device.
    fn from(model_type: SentenceEmbeddingsModelType) -> Self {
        /// Boxes the remote resource described by a pretrained resource constant.
        fn remote(
            resource: (&'static str, &'static str),
        ) -> Box<dyn ResourceProvider + Send> {
            Box::new(RemoteResource::from_pretrained(resource))
        }

        match model_type {
            // DistilBERT transformer followed by pooling and a dense layer
            SentenceEmbeddingsModelType::DistiluseBaseMultilingualCased => Self {
                modules_config_resource: remote(
                    SentenceEmbeddingsModulesConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                ),
                transformer_type: ModelType::DistilBert,
                transformer_config_resource: remote(
                    DistilBertConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                ),
                transformer_weights_resource: remote(
                    DistilBertModelResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                ),
                pooling_config_resource: remote(
                    SentenceEmbeddingsPoolingConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                ),
                dense_config_resource: Some(remote(
                    SentenceEmbeddingsDenseConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                )),
                dense_weights_resource: Some(remote(
                    SentenceEmbeddingsDenseResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                )),
                sentence_bert_config_resource: remote(
                    SentenceEmbeddingsConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                ),
                tokenizer_config_resource: remote(
                    SentenceEmbeddingsTokenizerConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                ),
                tokenizer_vocab_resource: remote(
                    DistilBertVocabResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                ),
                tokenizer_merges_resource: None,
                device: Device::cuda_if_available(),
            },

            // BERT transformer followed by pooling only
            SentenceEmbeddingsModelType::BertBaseNliMeanTokens => Self {
                modules_config_resource: remote(
                    SentenceEmbeddingsModulesConfigResources::BERT_BASE_NLI_MEAN_TOKENS,
                ),
                transformer_type: ModelType::Bert,
                transformer_config_resource: remote(
                    BertConfigResources::BERT_BASE_NLI_MEAN_TOKENS,
                ),
                transformer_weights_resource: remote(
                    BertModelResources::BERT_BASE_NLI_MEAN_TOKENS,
                ),
                pooling_config_resource: remote(
                    SentenceEmbeddingsPoolingConfigResources::BERT_BASE_NLI_MEAN_TOKENS,
                ),
                dense_config_resource: None,
                dense_weights_resource: None,
                sentence_bert_config_resource: remote(
                    SentenceEmbeddingsConfigResources::BERT_BASE_NLI_MEAN_TOKENS,
                ),
                tokenizer_config_resource: remote(
                    SentenceEmbeddingsTokenizerConfigResources::BERT_BASE_NLI_MEAN_TOKENS,
                ),
                tokenizer_vocab_resource: remote(BertVocabResources::BERT_BASE_NLI_MEAN_TOKENS),
                tokenizer_merges_resource: None,
                device: Device::cuda_if_available(),
            },

            // BERT transformer followed by pooling only
            SentenceEmbeddingsModelType::AllMiniLmL12V2 => Self {
                modules_config_resource: remote(
                    SentenceEmbeddingsModulesConfigResources::ALL_MINI_LM_L12_V2,
                ),
                transformer_type: ModelType::Bert,
                transformer_config_resource: remote(BertConfigResources::ALL_MINI_LM_L12_V2),
                transformer_weights_resource: remote(BertModelResources::ALL_MINI_LM_L12_V2),
                pooling_config_resource: remote(
                    SentenceEmbeddingsPoolingConfigResources::ALL_MINI_LM_L12_V2,
                ),
                dense_config_resource: None,
                dense_weights_resource: None,
                sentence_bert_config_resource: remote(
                    SentenceEmbeddingsConfigResources::ALL_MINI_LM_L12_V2,
                ),
                tokenizer_config_resource: remote(
                    SentenceEmbeddingsTokenizerConfigResources::ALL_MINI_LM_L12_V2,
                ),
                tokenizer_vocab_resource: remote(BertVocabResources::ALL_MINI_LM_L12_V2),
                tokenizer_merges_resource: None,
                device: Device::cuda_if_available(),
            },

            // RoBERTa transformer: the only preset that needs a merges file
            SentenceEmbeddingsModelType::AllDistilrobertaV1 => Self {
                modules_config_resource: remote(
                    SentenceEmbeddingsModulesConfigResources::ALL_DISTILROBERTA_V1,
                ),
                transformer_type: ModelType::Roberta,
                transformer_config_resource: remote(RobertaConfigResources::ALL_DISTILROBERTA_V1),
                transformer_weights_resource: remote(RobertaModelResources::ALL_DISTILROBERTA_V1),
                pooling_config_resource: remote(
                    SentenceEmbeddingsPoolingConfigResources::ALL_DISTILROBERTA_V1,
                ),
                dense_config_resource: None,
                dense_weights_resource: None,
                sentence_bert_config_resource: remote(
                    SentenceEmbeddingsConfigResources::ALL_DISTILROBERTA_V1,
                ),
                tokenizer_config_resource: remote(
                    SentenceEmbeddingsTokenizerConfigResources::ALL_DISTILROBERTA_V1,
                ),
                tokenizer_vocab_resource: remote(RobertaVocabResources::ALL_DISTILROBERTA_V1),
                tokenizer_merges_resource: Some(remote(
                    RobertaMergesResources::ALL_DISTILROBERTA_V1,
                )),
                device: Device::cuda_if_available(),
            },

            // ALBERT transformer followed by pooling only
            SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2 => Self {
                modules_config_resource: remote(
                    SentenceEmbeddingsModulesConfigResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                transformer_type: ModelType::Albert,
                transformer_config_resource: remote(
                    AlbertConfigResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                transformer_weights_resource: remote(
                    AlbertModelResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                pooling_config_resource: remote(
                    SentenceEmbeddingsPoolingConfigResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                dense_config_resource: None,
                dense_weights_resource: None,
                sentence_bert_config_resource: remote(
                    SentenceEmbeddingsConfigResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                tokenizer_config_resource: remote(
                    SentenceEmbeddingsTokenizerConfigResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                tokenizer_vocab_resource: remote(
                    AlbertVocabResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                tokenizer_merges_resource: None,
                device: Device::cuda_if_available(),
            },

            // T5 transformer followed by pooling and a dense layer
            SentenceEmbeddingsModelType::SentenceT5Base => Self {
                modules_config_resource: remote(
                    SentenceEmbeddingsModulesConfigResources::SENTENCE_T5_BASE,
                ),
                transformer_type: ModelType::T5,
                transformer_config_resource: remote(T5ConfigResources::SENTENCE_T5_BASE),
                transformer_weights_resource: remote(T5ModelResources::SENTENCE_T5_BASE),
                pooling_config_resource: remote(
                    SentenceEmbeddingsPoolingConfigResources::SENTENCE_T5_BASE,
                ),
                dense_config_resource: Some(remote(
                    SentenceEmbeddingsDenseConfigResources::SENTENCE_T5_BASE,
                )),
                dense_weights_resource: Some(remote(
                    SentenceEmbeddingsDenseResources::SENTENCE_T5_BASE,
                )),
                sentence_bert_config_resource: remote(
                    SentenceEmbeddingsConfigResources::SENTENCE_T5_BASE,
                ),
                tokenizer_config_resource: remote(
                    SentenceEmbeddingsTokenizerConfigResources::SENTENCE_T5_BASE,
                ),
                tokenizer_vocab_resource: remote(T5VocabResources::SENTENCE_T5_BASE),
                tokenizer_merges_resource: None,
                device: Device::cuda_if_available(),
            },
        }
    }
}
/// Configuration for the modules that define the model's layers.
///
/// Thin wrapper around the ordered list of [`SentenceEmbeddingsModuleConfig`] entries.
/// It dereferences to the inner `Vec`, so the list can be inspected with the usual
/// slice/vector methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SentenceEmbeddingsModulesConfig(pub Vec<SentenceEmbeddingsModuleConfig>);
// Read-only access to the wrapped list of modules.
impl std::ops::Deref for SentenceEmbeddingsModulesConfig {
type Target = Vec<SentenceEmbeddingsModuleConfig>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
// Mutable access to the wrapped list of modules.
impl std::ops::DerefMut for SentenceEmbeddingsModulesConfig {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
// Wraps an existing list of module configurations without modifying it.
impl From<Vec<SentenceEmbeddingsModuleConfig>> for SentenceEmbeddingsModulesConfig {
fn from(source: Vec<SentenceEmbeddingsModuleConfig>) -> Self {
Self(source)
}
}
// Implements the `Config` trait using only its default methods.
impl Config for SentenceEmbeddingsModulesConfig {}
impl SentenceEmbeddingsModulesConfig {
    /// Checks that the module list starts with a Transformer module followed by a
    /// Pooling module, and hands the configuration back unchanged when it does.
    ///
    /// # Errors
    ///
    /// Returns `RustBertError::InvalidConfigurationError` when the list is empty, when
    /// the first module is not a Transformer, or when the second module is missing or
    /// is not a Pooling module.
    pub fn validate(self) -> Result<Self, RustBertError> {
        let invalid =
            |message: &str| RustBertError::InvalidConfigurationError(message.to_string());

        match self.get(0).map(|module| &module.module_type) {
            Some(SentenceEmbeddingsModuleType::Transformer) => {}
            Some(_) => {
                return Err(invalid(
                    "First module defined in modules.json must be a Transformer",
                ))
            }
            None => return Err(invalid("No modules found in modules.json")),
        }

        match self.get(1).map(|module| &module.module_type) {
            Some(SentenceEmbeddingsModuleType::Pooling) => {}
            Some(_) => {
                return Err(invalid(
                    "Second module defined in modules.json must be a Pooling",
                ))
            }
            None => {
                return Err(invalid(
                    "Pooling module not found in second position in modules.json",
                ))
            }
        }

        Ok(self)
    }

    /// Returns the first module of the list.
    ///
    /// # Panics
    ///
    /// Panics when the list is empty; `validate` rejects such configurations.
    pub fn transformer_module(&self) -> &SentenceEmbeddingsModuleConfig {
        self.first().unwrap()
    }

    /// Returns the second module of the list.
    ///
    /// # Panics
    ///
    /// Panics when the list holds fewer than two modules; `validate` rejects such
    /// configurations.
    pub fn pooling_module(&self) -> &SentenceEmbeddingsModuleConfig {
        self.get(1).unwrap()
    }

    /// Returns the first Dense module found in the third or fourth position, if any.
    pub fn dense_module(&self) -> Option<&SentenceEmbeddingsModuleConfig> {
        self.iter()
            .skip(2)
            .take(2)
            .find(|module| matches!(module.module_type, SentenceEmbeddingsModuleType::Dense))
    }

    /// Tells whether a Normalize module sits in the third or fourth position.
    pub fn has_normalization(&self) -> bool {
        self.iter().skip(2).take(2).any(|module| {
            matches!(
                module.module_type,
                SentenceEmbeddingsModuleType::Normalize
            )
        })
    }
}
/// Configuration defining a single module (model's layer)
///
/// Field names map to the keys of one serialized module entry, except `module_type`
/// which is stored under the key `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SentenceEmbeddingsModuleConfig {
/// Value of the `idx` key (presumably the position of the module in the model; confirm against modules.json)
pub idx: usize,
/// Value of the `name` key
pub name: String,
/// Value of the `path` key (presumably the sub-directory holding the module's files; confirm against modules.json)
pub path: String,
/// Kind of module, (de)serialized through `serde_sentence_embeddings_module_type`
#[serde(rename = "type")]
#[serde(with = "serde_sentence_embeddings_module_type")]
pub module_type: SentenceEmbeddingsModuleType,
}
/// Available module types, based on Sentence-Transformers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SentenceEmbeddingsModuleType {
/// Base transformer model
Transformer,
/// Pooling layer
Pooling,
/// Linear (feed forward) layer
Dense,
/// Embeddings normalization
Normalize,
}
/// Custom (de)serialization of `SentenceEmbeddingsModuleType` as a dotted path whose
/// last segment is the variant name (e.g. `sentence_transformers.models.Pooling`).
mod serde_sentence_embeddings_module_type {
    use super::SentenceEmbeddingsModuleType;
    use serde::{de, Deserializer, Serializer};

    /// Writes the module type as `sentence_transformers.models.<Variant>`.
    pub fn serialize<S>(
        module_type: &SentenceEmbeddingsModuleType,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let qualified_name = format!("sentence_transformers.models.{:?}", module_type);
        serializer.serialize_str(&qualified_name)
    }

    /// Reads a string and builds the module type from the segment following the last `.`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<SentenceEmbeddingsModuleType, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct SentenceEmbeddingsModuleTypeVisitor;

        impl de::Visitor<'_> for SentenceEmbeddingsModuleTypeVisitor {
            type Value = SentenceEmbeddingsModuleType;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a sentence embeddings module type")
            }

            fn visit_str<E: de::Error>(self, s: &str) -> Result<Self::Value, E> {
                // Splitting always yields at least one segment, so the `None` arm is
                // only kept as a defensive fallback.
                match s.rsplit('.').next() {
                    Some(variant) => {
                        serde_json::from_value(serde_json::Value::String(variant.to_string()))
                            .map_err(de::Error::custom)
                    }
                    None => Err(de::Error::custom(format!(
                        "Invalid SentenceEmbeddingsModuleType: {}",
                        s
                    ))),
                }
            }
        }

        deserializer.deserialize_str(SentenceEmbeddingsModuleTypeVisitor)
    }
}
/// Configuration for Sentence-Transformers specific parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SentenceEmbeddingsSentenceBertConfig {
/// Maximum sequence length handed to the tokenizer when encoding inputs
pub max_seq_length: usize,
/// Lower-casing flag handed to the tokenizer when it is built
pub do_lower_case: bool,
}
// Implements the `Config` trait using only its default methods.
impl Config for SentenceEmbeddingsSentenceBertConfig {}
/// Configuration for transformer's tokenizer
///
/// Both values are optional and are forwarded as-is to the tokenizer when it is built.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SentenceEmbeddingsTokenizerConfig {
/// Optional `add_prefix_space` tokenizer setting
pub add_prefix_space: Option<bool>,
/// Optional `strip_accents` tokenizer setting
pub strip_accents: Option<bool>,
}
// Implements the `Config` trait using only its default methods.
impl Config for SentenceEmbeddingsTokenizerConfig {}

View File

@ -0,0 +1,151 @@
use std::path::Path;
use serde::{de, Deserialize, Deserializer};
use tch::{nn, Device, Kind, Tensor};
use crate::common::activations::{Activation, TensorFunction};
use crate::{Config, RustBertError};
/// Configuration for [`Pooling`](Pooling) layer.
///
/// Several pooling modes may be enabled at once; `Pooling::forward` concatenates the
/// result of every enabled mode.
#[derive(Debug, Deserialize)]
pub struct PoolingConfig {
/// Dimensions for the word embeddings
pub word_embedding_dimension: i64,
/// Use the first token (CLS token) as text representations
pub pooling_mode_cls_token: bool,
/// Use max in each dimension over all tokens
pub pooling_mode_max_tokens: bool,
/// Perform mean-pooling
pub pooling_mode_mean_tokens: bool,
/// Perform mean-pooling, but divide by sqrt(input_length)
pub pooling_mode_mean_sqrt_len_tokens: bool,
}
// Implements the `Config` trait using only its default methods.
impl Config for PoolingConfig {}
/// Performs pooling (max or mean) on the token embeddings.
///
/// Using pooling, it generates from a variable sized sentence a fixed sized sentence
/// embedding. You can concatenate multiple poolings together.
pub struct Pooling {
    conf: PoolingConfig,
}

impl Pooling {
    /// Builds a pooling layer applying the modes enabled in `conf`.
    pub fn new(conf: PoolingConfig) -> Pooling {
        Pooling { conf }
    }

    /// Pools the token embeddings into one vector per sentence.
    ///
    /// Every enabled pooling mode produces one tensor with the token dimension
    /// (dimension 1) removed; the results are concatenated along dimension 1.
    ///
    /// # Arguments
    ///
    /// * `token_embeddings` - Token embeddings (assumed to be `(batch, sequence, hidden)`;
    ///   confirm against the transformer output). The tensor is consumed and, when max
    ///   pooling is enabled, modified in place.
    /// * `attention_mask` - Mask with one entry per token (assumed `(batch, sequence)`),
    ///   `0` marking padding tokens.
    pub fn forward(&self, mut token_embeddings: Tensor, attention_mask: &Tensor) -> Tensor {
        let mut output_vectors = Vec::new();

        if self.conf.pooling_mode_cls_token {
            let cls_token = token_embeddings.select(1, 0); // Take first token by default
            output_vectors.push(cls_token);
        }

        if self.conf.pooling_mode_max_tokens {
            let input_mask_expanded = attention_mask.unsqueeze(-1).expand_as(&token_embeddings);
            // Set padding tokens to large negative value
            token_embeddings = token_embeddings.masked_fill_(&input_mask_expanded.eq(0), -1e9);
            // `keepdim` must be false: the CLS and mean branches drop the token
            // dimension, and keeping it here would make the final concatenation fail
            // (or return a tensor with an extra dimension when max pooling is used alone).
            let max_over_time = token_embeddings.max_dim(1, false).0;
            output_vectors.push(max_over_time);
        }

        if self.conf.pooling_mode_mean_tokens || self.conf.pooling_mode_mean_sqrt_len_tokens {
            let input_mask_expanded = attention_mask.unsqueeze(-1).expand_as(&token_embeddings);
            // Padding positions are zeroed out by the mask before summing.
            let sum_embeddings =
                (token_embeddings * &input_mask_expanded).sum_dim_intlist(&[1], false, Kind::Float);
            let sum_mask = input_mask_expanded.sum_dim_intlist(&[1], false, Kind::Float);
            // Avoid a division by zero for fully masked sequences.
            let sum_mask = sum_mask.clamp_min(10e-9);

            if self.conf.pooling_mode_mean_tokens {
                output_vectors.push(&sum_embeddings / &sum_mask);
            }
            if self.conf.pooling_mode_mean_sqrt_len_tokens {
                output_vectors.push(sum_embeddings / sum_mask.sqrt());
            }
        }

        Tensor::cat(&output_vectors, 1)
    }
}
/// Configuration for [`Dense`](Dense) layer.
#[derive(Debug, Deserialize)]
pub struct DenseConfig {
/// Size of the input dimension
pub in_features: i64,
/// Output size
pub out_features: i64,
/// Add a bias vector
pub bias: bool,
/// Activation function applied on output
///
/// Deserialized by `last_part`: only the segment after the last `.` of the
/// serialized string is used, lower-cased.
#[serde(deserialize_with = "last_part")]
pub activation_function: Activation,
}
// Implements the `Config` trait using only its default methods.
impl Config for DenseConfig {}
/// Split the given string on `.` and try to construct an `Activation` from the last part
fn last_part<'de, D>(deserializer: D) -> Result<Activation, D::Error>
where
D: Deserializer<'de>,
{
let activation = String::deserialize(deserializer)?;
activation
.split('.')
.last()
.map(|s| serde_json::from_value(serde_json::Value::String(s.to_lowercase())))
.transpose()
.map_err(de::Error::custom)?
.ok_or_else(|| format!("Invalid Activation: {}", activation))
.map_err(de::Error::custom)
}
/// Feed-forward function with activation function.
///
/// This layer takes a fixed-sized sentence embedding and passes it through a
/// feed-forward layer. Can be used to generate deep averaging networks (DAN).
pub struct Dense {
    linear: nn::Linear,
    activation: TensorFunction,
    // Kept alive alongside the linear layer whose variables it registered.
    _var_store: nn::VarStore,
}

impl Dense {
    /// Builds the layer described by `dense_conf` and loads its weights.
    ///
    /// # Arguments
    ///
    /// * `dense_conf` - Dimensions, bias flag and activation of the layer
    /// * `dense_weights` - Path of the weights file to load
    /// * `device` - Device on which the variables are created
    ///
    /// # Errors
    ///
    /// Returns an error when the weights cannot be loaded from `dense_weights`.
    pub fn new<P: AsRef<Path>>(
        dense_conf: DenseConfig,
        dense_weights: P,
        device: Device,
    ) -> Result<Dense, RustBertError> {
        let mut var_store = nn::VarStore::new(device);

        // The variables must be registered in the store before the weights are loaded.
        let linear = nn::linear(
            &var_store.root(),
            dense_conf.in_features,
            dense_conf.out_features,
            nn::LinearConfig {
                ws_init: nn::Init::Const(0.),
                bs_init: Some(nn::Init::Const(0.)),
                bias: dense_conf.bias,
            },
        );
        let activation = dense_conf.activation_function.get_function();

        var_store.load(dense_weights)?;

        Ok(Dense {
            linear,
            activation,
            _var_store: var_store,
        })
    }

    /// Applies the linear layer followed by the activation function.
    pub fn forward(&self, x: &Tensor) -> Tensor {
        let projected = x.apply(&self.linear);
        self.activation.get_fn()(&projected)
    }
}

View File

@ -0,0 +1,64 @@
//! # Sentence Embeddings pipeline
//!
//! Compute sentence/text embeddings that can be compared (e.g. with
//! cosine-similarity) to find sentences with a similar meaning. This can be useful for
//! semantic textual similarity, semantic search, or paraphrase mining.
//!
//! The implementation is based on [Sentence-Transformers][sbert] and pretrained models
//! available on [Hugging Face Hub][sbert-hub] can be used. It's however necessary to
//! convert them using the script `utils/convert_model.py` beforehand, see
//! `tests/sentence_embeddings.rs` for such examples.
//!
//! [sbert]: https://sbert.net/
//! [sbert-hub]: https://huggingface.co/sentence-transformers/
//!
//! Basic usage is as follows:
//!
//! ```no_run
//! use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;
//!
//! # fn main() -> anyhow::Result<()> {
//! let model = SentenceEmbeddingsBuilder::local("local/path/to/distiluse-base-multilingual-cased")
//! .with_device(tch::Device::cuda_if_available())
//! .create_model()?;
//!
//! let sentences = ["This is an example sentence", "Each sentence is converted"];
//! let embeddings = model.encode(&sentences)?;
//! # Ok(())
//! # }
//! ```
pub mod builder;
mod config;
pub mod layers;
mod pipeline;
mod resources;
pub use builder::SentenceEmbeddingsBuilder;
pub use config::{
SentenceEmbeddingsConfig, SentenceEmbeddingsModuleConfig, SentenceEmbeddingsModuleType,
SentenceEmbeddingsModulesConfig, SentenceEmbeddingsSentenceBertConfig,
SentenceEmbeddingsTokenizerConfig,
};
pub use pipeline::{
SentenceEmbeddingsModel, SentenceEmbeddingsModelOuput, SentenceEmbeddingsOption,
SentenceEmbeddingsTokenizerOuput,
};
pub use resources::{
SentenceEmbeddingsConfigResources, SentenceEmbeddingsDenseConfigResources,
SentenceEmbeddingsDenseResources, SentenceEmbeddingsModelType,
SentenceEmbeddingsModulesConfigResources, SentenceEmbeddingsPoolingConfigResources,
SentenceEmbeddingsTokenizerConfigResources,
};
/// Attention weights of a single token. Length = sequence length
pub type Attention = Vec<f32>;
/// Attention weights of a single head, one [`Attention`] per token. Length = sequence length
pub type AttentionHead = Vec<Attention>;
/// Attention weights of a single layer. Length = number of heads per attention layer
pub type AttentionLayer = Vec<AttentionHead>;
/// Attention weights of the whole model for one input. Length = number of attention layers
pub type AttentionOutput = Vec<AttentionLayer>;
/// A single sentence embedding, as returned by `SentenceEmbeddingsModel::encode`
pub type Embedding = Vec<f32>;

View File

@ -0,0 +1,461 @@
use std::borrow::Borrow;
use std::convert::TryInto;
use rust_tokenizers::tokenizer::TruncationStrategy;
use tch::{nn, Tensor};
use crate::albert::AlbertForSentenceEmbeddings;
use crate::bert::BertForSentenceEmbeddings;
use crate::distilbert::DistilBertForSentenceEmbeddings;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::sentence_embeddings::layers::{Dense, DenseConfig, Pooling, PoolingConfig};
use crate::pipelines::sentence_embeddings::{
AttentionHead, AttentionLayer, AttentionOutput, Embedding, SentenceEmbeddingsConfig,
SentenceEmbeddingsModulesConfig, SentenceEmbeddingsSentenceBertConfig,
SentenceEmbeddingsTokenizerConfig,
};
use crate::roberta::RobertaForSentenceEmbeddings;
use crate::t5::T5ForSentenceEmbeddings;
use crate::{Config, RustBertError};
/// # Abstraction that holds one particular sentence embeddings model, for any of the supported models
///
/// Each variant wraps the transformer used to produce the token embeddings; the
/// variant is selected from a `ModelType` in `SentenceEmbeddingsOption::new`.
pub enum SentenceEmbeddingsOption {
/// Bert for Sentence Embeddings
Bert(BertForSentenceEmbeddings),
/// DistilBert for Sentence Embeddings
DistilBert(DistilBertForSentenceEmbeddings),
/// Roberta for Sentence Embeddings
Roberta(RobertaForSentenceEmbeddings),
/// Albert for Sentence Embeddings
Albert(AlbertForSentenceEmbeddings),
/// T5 for Sentence Embeddings
T5(T5ForSentenceEmbeddings),
}
impl SentenceEmbeddingsOption {
/// Instantiate a new sentence embeddings transformer of the supplied type.
///
/// # Arguments
///
/// * `transformer_type` - `ModelType` indicating the transformer model type to load (must match with the actual data to be loaded)
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. rust_model.ot)
/// * `config` - A configuration (the transformer model type of the configuration must be compatible with the value for `transformer_type`)
pub fn new<'p, P>(
transformer_type: ModelType,
p: P,
config: &ConfigOption,
) -> Result<Self, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
use SentenceEmbeddingsOption::*;
let option = match transformer_type {
ModelType::Bert => Bert(BertForSentenceEmbeddings::new(p, &(config.try_into()?))),
ModelType::DistilBert => DistilBert(DistilBertForSentenceEmbeddings::new(
p,
&(config.try_into()?),
)),
ModelType::Roberta => Roberta(RobertaForSentenceEmbeddings::new_with_optional_pooler(
p,
&(config.try_into()?),
false,
)),
ModelType::Albert => Albert(AlbertForSentenceEmbeddings::new(p, &(config.try_into()?))),
ModelType::T5 => T5(T5ForSentenceEmbeddings::new(p, &(config.try_into()?))),
_ => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Unsupported transformer model {:?} for Sentence Embeddings",
transformer_type
)));
}
};
Ok(option)
}
/// Runs the wrapped transformer and returns its hidden state together with the
/// attention weights, when the model outputs them.
///
/// For Albert, the attention tensors of each entry are averaged so that a single
/// tensor is returned per entry, like for the other models.
pub fn forward(
    &self,
    tokens_ids: &Tensor,
    tokens_masks: &Tensor,
) -> Result<(Tensor, Option<Vec<Tensor>>), RustBertError> {
    match self {
        Self::Bert(transformer) => {
            let output = transformer.forward_t(
                Some(tokens_ids),
                Some(tokens_masks),
                None,
                None,
                None,
                None,
                None,
                false,
            )?;
            Ok((output.hidden_state, output.all_attentions))
        }
        Self::DistilBert(transformer) => {
            let output =
                transformer.forward_t(Some(tokens_ids), Some(tokens_masks), None, false)?;
            Ok((output.hidden_state, output.all_attentions))
        }
        Self::Roberta(transformer) => {
            let output = transformer.forward_t(
                Some(tokens_ids),
                Some(tokens_masks),
                None,
                None,
                None,
                None,
                None,
                false,
            )?;
            Ok((output.hidden_state, output.all_attentions))
        }
        Self::Albert(transformer) => {
            let output = transformer.forward_t(
                Some(tokens_ids),
                Some(tokens_masks),
                None,
                None,
                None,
                false,
            )?;
            // Average the tensors of each inner list into a single tensor.
            let all_attentions = output.all_attentions.map(|groups| {
                groups
                    .into_iter()
                    .map(|tensors| {
                        let num_inner_groups = tensors.len() as f64;
                        tensors.into_iter().sum::<Tensor>() / num_inner_groups
                    })
                    .collect()
            });
            Ok((output.hidden_state, all_attentions))
        }
        Self::T5(transformer) => transformer.forward(tokens_ids, tokens_masks),
    }
}
}
/// # SentenceEmbeddingsModel to perform sentence embeddings
///
/// It is made of the following blocks:
/// - `transformer`: Base transformer model
/// - `pooling`: Pooling layer
/// - `dense` _(optional)_: Linear (feed forward) layer
/// - `normalization` _(optional)_: Embeddings normalization
pub struct SentenceEmbeddingsModel {
// Sentence-Transformers parameters; `max_seq_length` bounds the tokenized inputs.
sentence_bert_config: SentenceEmbeddingsSentenceBertConfig,
// Tokenizer matching the transformer type.
tokenizer: TokenizerOption,
// Truncation strategy handed to the tokenizer (`LongestFirst` unless changed).
tokenizer_truncation_strategy: TruncationStrategy,
// Variable store of the transformer; its device is the one inputs are moved to.
var_store: nn::VarStore,
// Transformer producing the token embeddings.
transformer: SentenceEmbeddingsOption,
// Transformer configuration, used to read the number of layers and heads.
transformer_config: ConfigOption,
// Pooling applied on the token embeddings.
pooling_layer: Pooling,
// Dense layer, present only when the modules configuration declares one.
dense_layer: Option<Dense>,
// Whether the final embeddings are L2-normalized.
normalize_embeddings: bool,
}
impl SentenceEmbeddingsModel {
/// Build a new `SentenceEmbeddingsModel`
///
/// # Arguments
///
/// * `config` - `SentenceEmbeddingsConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
pub fn new(config: SentenceEmbeddingsConfig) -> Result<Self, RustBertError> {
let SentenceEmbeddingsConfig {
modules_config_resource,
sentence_bert_config_resource,
tokenizer_config_resource,
tokenizer_vocab_resource,
tokenizer_merges_resource,
transformer_type,
transformer_config_resource,
transformer_weights_resource,
pooling_config_resource,
dense_config_resource,
dense_weights_resource,
device,
} = config;
let modules =
SentenceEmbeddingsModulesConfig::from_file(modules_config_resource.get_local_path()?)
.validate()?;
// Setup tokenizer
let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
tokenizer_config_resource.get_local_path()?,
);
let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
sentence_bert_config_resource.get_local_path()?,
);
let tokenizer = TokenizerOption::from_file(
transformer_type,
tokenizer_vocab_resource
.get_local_path()?
.to_string_lossy()
.as_ref(),
tokenizer_merges_resource
.as_ref()
.map(|resource| resource.get_local_path())
.transpose()?
.map(|path| path.to_string_lossy().into_owned())
.as_deref(),
sentence_bert_config.do_lower_case,
tokenizer_config.strip_accents,
tokenizer_config.add_prefix_space,
)?;
// Setup transformer
let mut var_store = nn::VarStore::new(device);
let transformer_config = ConfigOption::from_file(
transformer_type,
transformer_config_resource.get_local_path()?,
);
let transformer = SentenceEmbeddingsOption::new(
transformer_type,
&var_store.root(),
&transformer_config,
)?;
var_store.load(transformer_weights_resource.get_local_path()?)?;
// Setup pooling layer
let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?);
let pooling_layer = Pooling::new(pooling_config);
// Setup dense layer
let dense_layer = if modules.dense_module().is_some() {
let dense_config =
DenseConfig::from_file(dense_config_resource.unwrap().get_local_path()?);
Some(Dense::new(
dense_config,
dense_weights_resource.unwrap().get_local_path()?,
device,
)?)
} else {
None
};
let normalize_embeddings = modules.has_normalization();
Ok(Self {
tokenizer,
sentence_bert_config,
tokenizer_truncation_strategy: TruncationStrategy::LongestFirst,
var_store,
transformer,
transformer_config,
pooling_layer,
dense_layer,
normalize_embeddings,
})
}
/// Sets the tokenizer's truncation strategy
///
/// The strategy is stored on the model and used by subsequent calls to `tokenize`.
pub fn set_tokenizer_truncation(&mut self, truncation_strategy: TruncationStrategy) {
self.tokenizer_truncation_strategy = truncation_strategy;
}
/// Tokenizes the inputs
///
/// Every input is encoded with the model's tokenizer (bounded by `max_seq_length`),
/// then padded on the right up to the length of the longest encoded input of the
/// batch. One tensor of token ids and one attention mask tensor are returned per
/// input.
///
/// The attention mask is derived from the unpadded length of each input (`1` for
/// tokens produced by the tokenizer, `0` for padding) rather than by comparing ids
/// with the padding id, so that a genuine token sharing the padding id (or the
/// fallback id `0` used when the tokenizer defines no padding id) is never masked.
pub fn tokenize<S>(&self, inputs: &[S]) -> SentenceEmbeddingsTokenizerOuput
where
    S: AsRef<str> + Sync,
{
    let tokenized_input = self.tokenizer.encode_list(
        inputs,
        self.sentence_bert_config.max_seq_length,
        &self.tokenizer_truncation_strategy,
        0,
    );

    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap_or(0);

    let pad_token_id = self.tokenizer.get_pad_id().unwrap_or(0);

    let mut tokens_ids = Vec::with_capacity(tokenized_input.len());
    let mut tokens_masks = Vec::with_capacity(tokenized_input.len());
    for input in tokenized_input {
        let mut token_ids = input.token_ids;

        // Real tokens first, padding afterwards.
        let mut mask = vec![1_i64; token_ids.len()];
        mask.resize(max_len, 0_i64);
        token_ids.resize(max_len, pad_token_id);

        tokens_ids.push(Tensor::of_slice(&token_ids));
        tokens_masks.push(Tensor::of_slice(&mask));
    }

    SentenceEmbeddingsTokenizerOuput {
        tokens_ids,
        tokens_masks,
    }
}
/// Computes sentence embeddings, outputs `Tensor`.
///
/// The inputs are tokenized, run through the transformer, pooled, optionally passed
/// through the dense layer and optionally L2-normalized. The attention weights
/// returned by the transformer, if any, are passed along unchanged.
pub fn encode_as_tensor<S>(
    &self,
    inputs: &[S],
) -> Result<SentenceEmbeddingsModelOuput, RustBertError>
where
    S: AsRef<str> + Sync,
{
    let device = self.var_store.device();
    let tokenized = self.tokenize(inputs);
    let tokens_ids = Tensor::stack(&tokenized.tokens_ids, 0).to(device);
    let tokens_masks = Tensor::stack(&tokenized.tokens_masks, 0).to(device);

    let (tokens_embeddings, all_attentions) =
        tch::no_grad(|| self.transformer.forward(&tokens_ids, &tokens_masks))?;

    let pooled = tch::no_grad(|| self.pooling_layer.forward(tokens_embeddings, &tokens_masks));

    let projected = match &self.dense_layer {
        Some(dense_layer) => tch::no_grad(|| dense_layer.forward(&pooled)),
        None => pooled,
    };

    let embeddings = if self.normalize_embeddings {
        // L2 norm along dimension 1, clamped to avoid dividing by zero.
        let norm = projected
            .norm_scalaropt_dim(2, &[1], true)
            .clamp_min(1e-12)
            .expand_as(&projected);
        projected / &norm
    } else {
        projected
    };

    Ok(SentenceEmbeddingsModelOuput {
        embeddings,
        all_attentions,
    })
}
/// Computes sentence embeddings.
pub fn encode<S>(&self, inputs: &[S]) -> Result<Vec<Embedding>, RustBertError>
where
S: AsRef<str> + Sync,
{
let SentenceEmbeddingsModelOuput { embeddings, .. } = self.encode_as_tensor(inputs)?;
Ok(Vec::from(embeddings))
}
fn nb_layers(&self) -> usize {
use SentenceEmbeddingsOption::*;
match (&self.transformer, &self.transformer_config) {
(Bert(_), ConfigOption::Bert(conf)) => conf.num_hidden_layers as usize,
(Bert(_), _) => unreachable!(),
(DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_layers as usize,
(DistilBert(_), _) => unreachable!(),
(Roberta(_), ConfigOption::Bert(conf)) => conf.num_hidden_layers as usize,
(Roberta(_), _) => unreachable!(),
(Albert(_), ConfigOption::Albert(conf)) => conf.num_hidden_layers as usize,
(Albert(_), _) => unreachable!(),
(T5(_), ConfigOption::T5(conf)) => conf.num_layers as usize,
(T5(_), _) => unreachable!(),
}
}
/// Number of attention heads declared by the transformer configuration.
///
/// The configuration variant is expected to match the transformer variant; any other
/// combination is treated as unreachable.
fn nb_heads(&self) -> usize {
    use SentenceEmbeddingsOption::*;
    match (&self.transformer, &self.transformer_config) {
        (Bert(_), ConfigOption::Bert(conf)) => conf.num_attention_heads as usize,
        (Bert(_), _) => unreachable!(),
        (DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_heads as usize,
        (DistilBert(_), _) => unreachable!(),
        // Roberta is paired with the BERT configuration variant, exactly as in
        // `nb_layers`; matching any other variant here would send every Roberta
        // model to the `unreachable!()` arm below.
        (Roberta(_), ConfigOption::Bert(conf)) => conf.num_attention_heads as usize,
        (Roberta(_), _) => unreachable!(),
        (Albert(_), ConfigOption::Albert(conf)) => conf.num_attention_heads as usize,
        (Albert(_), _) => unreachable!(),
        (T5(_), ConfigOption::T5(conf)) => conf.num_heads as usize,
        (T5(_), _) => unreachable!(),
    }
}
/// Computes sentence embeddings, also outputs `AttentionOutput`s.
///
/// One `AttentionOutput` is returned per input; it holds, for every attention tensor
/// returned by the transformer, the attention weights of each head for that input.
///
/// # Errors
///
/// Propagates any error returned by `encode_as_tensor`, and returns
/// `InvalidConfigurationError` when the transformer did not output attention weights.
pub fn encode_with_attention<S>(
    &self,
    inputs: &[S],
) -> Result<(Vec<Embedding>, Vec<AttentionOutput>), RustBertError>
where
    S: AsRef<str> + Sync,
{
    let SentenceEmbeddingsModelOuput {
        embeddings,
        all_attentions,
    } = self.encode_as_tensor(inputs)?;

    let embeddings = Vec::from(embeddings);
    let all_attentions = all_attentions.ok_or_else(|| {
        RustBertError::InvalidConfigurationError("No attention outputted".into())
    })?;

    // Loop invariants, resolved once instead of once per input/layer/head.
    let nb_layers = self.nb_layers();
    let nb_heads = self.nb_heads();

    let attention_outputs = (0..inputs.len() as i64)
        .map(|i| {
            let mut attention_output = AttentionOutput::with_capacity(nb_layers);
            for layer in all_attentions.iter() {
                // Index the first dimension (the input) and then the head explicitly:
                // an unqualified `squeeze()` would also drop the two sequence
                // dimensions when the sequence length is 1, which breaks the
                // conversion to nested vectors.
                let input_attention = layer.select(0, i);
                let mut attention_layer = AttentionLayer::with_capacity(nb_heads);
                for head in 0..nb_heads {
                    let attention_slice = input_attention.select(0, head as i64);
                    let attention_head = AttentionHead::from(attention_slice);
                    attention_layer.push(attention_head);
                }
                attention_output.push(attention_layer);
            }
            attention_output
        })
        .collect::<Vec<AttentionOutput>>();

    Ok((embeddings, attention_outputs))
}
}
/// Container for the SentenceEmbeddings tokenizer output.
///
/// Both vectors hold one tensor per input, in input order.
pub struct SentenceEmbeddingsTokenizerOuput {
/// Token ids of each input, padded to a common length
pub tokens_ids: Vec<Tensor>,
/// Attention mask of each input (`1` for tokens, `0` for padding)
pub tokens_masks: Vec<Tensor>,
}
/// Container for the SentenceEmbeddings model output.
pub struct SentenceEmbeddingsModelOuput {
/// Sentence embeddings computed for the batch of inputs
pub embeddings: Tensor,
/// Attention weights returned by the transformer, `None` when it does not output them
pub all_attentions: Option<Vec<Tensor>>,
}

View File

@ -0,0 +1,184 @@
/// # Pretrained modules config files (`modules.json`) for sentence embeddings
pub struct SentenceEmbeddingsModulesConfigResources;
/// # Pretrained dense weights files for sentence embeddings
pub struct SentenceEmbeddingsDenseResources;
/// # Pretrained dense config files for sentence embeddings
pub struct SentenceEmbeddingsDenseConfigResources;
/// # Pretrained pooling config files for sentence embeddings
pub struct SentenceEmbeddingsPoolingConfigResources;
/// # Pretrained Sentence-BERT config files (`sentence_bert_config.json`) for sentence embeddings
pub struct SentenceEmbeddingsConfigResources;
/// # Pretrained tokenizer config files for sentence embeddings
pub struct SentenceEmbeddingsTokenizerConfigResources;
/// # Pretrained sentence embeddings models with remote resources declared in this module
pub enum SentenceEmbeddingsModelType {
/// `sentence-transformers/distiluse-base-multilingual-cased`
DistiluseBaseMultilingualCased,
/// `sentence-transformers/bert-base-nli-mean-tokens`
BertBaseNliMeanTokens,
/// `sentence-transformers/all-MiniLM-L12-v2`
AllMiniLmL12V2,
/// `sentence-transformers/all-distilroberta-v1`
AllDistilrobertaV1,
/// `sentence-transformers/paraphrase-albert-small-v2`
ParaphraseAlbertSmallV2,
/// `sentence-transformers/sentence-t5-base`
SentenceT5Base,
}
// Each constant pairs an identifier string (presumably the local cache name; confirm
// against `RemoteResource`) with the URL of the model's `modules.json` file.
impl SentenceEmbeddingsModulesConfigResources {
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
"distiluse-base-multilingual-cased/sbert-config",
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/modules.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
"bert-base-nli-mean-tokens/sbert-config",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/modules.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/sbert-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/modules.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/sbert-config",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/modules.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
"paraphrase-albert-small-v2/sbert-config",
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/modules.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
"sentence-t5-base/sbert-config",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/modules.json",
);
}
// Each constant pairs an identifier string (presumably the local cache name; confirm
// against `RemoteResource`) with the URL of the dense layer weights (`2_Dense/rust_model.ot`).
impl SentenceEmbeddingsDenseResources {
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
"distiluse-base-multilingual-cased/sbert-dense",
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/2_Dense/rust_model.ot",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
"sentence-t5-base/sbert-dense",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/2_Dense/rust_model.ot",
);
}
// Each constant pairs an identifier string (presumably the local cache name; confirm
// against `RemoteResource`) with the URL of the dense layer configuration (`2_Dense/config.json`).
impl SentenceEmbeddingsDenseConfigResources {
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
"distiluse-base-multilingual-cased/sbert-dense-config",
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/2_Dense/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
"sentence-t5-base/sbert-dense-config",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/2_Dense/config.json",
);
}
// Each constant pairs an identifier string (presumably the local cache name; confirm
// against `RemoteResource`) with the URL of the pooling configuration (`1_Pooling/config.json`).
impl SentenceEmbeddingsPoolingConfigResources {
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
"distiluse-base-multilingual-cased/sbert-pooling-config",
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/1_Pooling/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
"bert-base-nli-mean-tokens/sbert-pooling-config",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/1_Pooling/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/sbert-pooling-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/1_Pooling/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/sbert-pooling-config",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/1_Pooling/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
"paraphrase-albert-small-v2/sbert-pooling-config",
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/1_Pooling/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
"sentence-t5-base/sbert-pooling-config",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/1_Pooling/config.json",
);
}
// Each constant pairs an identifier string (presumably the local cache name; confirm
// against `RemoteResource`) with the URL of the model's `sentence_bert_config.json` file.
impl SentenceEmbeddingsConfigResources {
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
"distiluse-base-multilingual-cased/sbert-config",
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/sentence_bert_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
"bert-base-nli-mean-tokens/sbert-config",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/sentence_bert_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/sbert-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/sentence_bert_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/sbert-config",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/sentence_bert_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
"paraphrase-albert-small-v2/sbert-config",
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/sentence_bert_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
"sentence-t5-base/sbert-config",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/sentence_bert_config.json",
);
}
/// Remote `tokenizer_config.json` resources for the pretrained sentence embeddings models.
/// Each constant is a pair of strings whose second element is the download URL; the first element
/// is presumably the local cache name of the resource (confirm against the resource loader).
impl SentenceEmbeddingsTokenizerConfigResources {
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
"distiluse-base-multilingual-cased/tokenizer-config",
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/tokenizer_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
"bert-base-nli-mean-tokens/tokenizer-config",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/tokenizer_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/tokenizer-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/tokenizer_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/tokenizer-config",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/tokenizer_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
"paraphrase-albert-small-v2/tokenizer-config",
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/tokenizer_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
"sentence-t5-base/tokenizer-config",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/tokenizer_config.json",
);
}

View File

@ -67,6 +67,7 @@ mod roberta_model;
pub use embeddings::RobertaEmbeddings;
pub use roberta_model::{
RobertaConfig, RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
RobertaForQuestionAnswering, RobertaForSentenceEmbeddings, RobertaForSequenceClassification,
RobertaForTokenClassification, RobertaMergesResources, RobertaModelResources,
RobertaVocabResources,
};

View File

@ -68,6 +68,11 @@ impl RobertaModelResources {
"xlm-roberta-ner-es/model",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/model",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/rust_model.ot",
);
}
impl RobertaConfigResources {
@ -106,6 +111,11 @@ impl RobertaConfigResources {
"xlm-roberta-ner-es/config",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/config",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/config.json",
);
}
impl RobertaVocabResources {
@ -144,6 +154,11 @@ impl RobertaVocabResources {
"xlm-roberta-ner-es/spiece",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/vocab",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/vocab.json",
);
}
impl RobertaMergesResources {
@ -162,6 +177,11 @@ impl RobertaMergesResources {
"roberta-qa/merges",
"https://huggingface.co/deepset/roberta-base-squad2/resolve/main/merges.txt",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/merges",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/merges.txt",
);
}
pub struct RobertaLMHead {
@ -972,6 +992,10 @@ impl RobertaForQuestionAnswering {
}
}
/// # RoBERTa for sentence embeddings
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
///
/// This is a plain type alias for `BertModel` parameterized with `RobertaEmbeddings`;
/// it introduces no new type and no behavior of its own.
pub type RobertaForSentenceEmbeddings = BertModel<RobertaEmbeddings>;
/// Container for the RoBERTa masked LM model output.
pub struct RobertaMaskedLMOutput {
/// Logits for the vocabulary items at each sequence position

View File

@ -41,7 +41,7 @@
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer = T5Tokenizer::from_file(spiece_path.to_str().unwrap(), true);
//! let config = T5Config::from_file(config_path);
//! let t5_model = T5ForConditionalGeneration::new(&vs.root(), &config, false, false);
//! let t5_model = T5ForConditionalGeneration::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())
@ -55,6 +55,7 @@ mod t5_model;
pub use attention::LayerState;
pub use t5_model::{
T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Generator, T5Model, T5ModelOutput,
T5ModelResources, T5Prefix, T5SourceLanguages, T5TargetLanguages, T5VocabResources,
T5Config, T5ConfigResources, T5ForConditionalGeneration, T5ForSentenceEmbeddings, T5Generator,
T5Model, T5ModelOutput, T5ModelResources, T5Prefix, T5SourceLanguages, T5TargetLanguages,
T5VocabResources,
};

View File

@ -59,6 +59,11 @@ impl T5ModelResources {
"t5-base/model",
"https://huggingface.co/t5-base/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
"sentence-t5-base/model",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/rust_model.ot",
);
}
impl T5ConfigResources {
@ -72,6 +77,11 @@ impl T5ConfigResources {
"t5-base/config",
"https://huggingface.co/t5-base/resolve/main/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
"sentence-t5-base/config",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/config.json",
);
}
impl T5VocabResources {
@ -85,6 +95,11 @@ impl T5VocabResources {
"t5-base/spiece",
"https://huggingface.co/t5-base/resolve/main/spiece.model",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
"sentence-t5-base/spiece",
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model",
);
}
const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German];
@ -133,6 +148,8 @@ pub struct T5Config {
pub feed_forward_proj: Option<FeedForwardProj>,
pub tie_word_embeddings: Option<bool>,
task_specific_params: Option<TaskSpecificParams>,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
}
/// # T5 task-specific configurations
@ -209,6 +226,8 @@ impl Default for T5Config {
feed_forward_proj: Some(FeedForwardProj::Relu),
tie_word_embeddings: None,
task_specific_params: None,
output_attentions: None,
output_hidden_states: None,
}
}
}
@ -233,8 +252,6 @@ impl T5Model {
///
/// * `p` - Variable store path for the root of the T5 model
/// * `config` - `T5Config` object defining the model architecture
/// * `output_attention` - flag indicating if the model should output the attention weights of intermediate layers
/// * `output_hidden_states` - flag indicating if the model should output the hidden states weights of intermediate layers
///
/// # Example
///
@ -248,21 +265,9 @@ impl T5Model {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = T5Config::from_file(config_path);
/// let output_attentions = true;
/// let output_hidden_states = true;
/// let t5: T5Model = T5Model::new(
/// &p.root() / "t5",
/// &config,
/// output_attentions,
/// output_hidden_states,
/// );
/// let t5: T5Model = T5Model::new(&p.root() / "t5", &config);
/// ```
pub fn new<'p, P>(
p: P,
config: &T5Config,
output_attentions: bool,
output_hidden_states: bool,
) -> T5Model
pub fn new<'p, P>(p: P, config: &T5Config) -> T5Model
where
P: Borrow<nn::Path<'p>>,
{
@ -280,16 +285,16 @@ impl T5Model {
config,
false,
false,
output_attentions,
output_hidden_states,
config.output_attentions.unwrap_or(false),
config.output_hidden_states.unwrap_or(false),
);
let decoder = T5Stack::new(
p / "decoder",
config,
true,
true,
output_attentions,
output_hidden_states,
config.output_attentions.unwrap_or(false),
config.output_hidden_states.unwrap_or(false),
);
T5Model {
@ -338,7 +343,7 @@ impl T5Model {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = T5Config::from_file(config_path);
/// # let t5_model: T5Model = T5Model::new(&vs.root(), &config, false, false);
/// # let t5_model: T5Model = T5Model::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -450,8 +455,6 @@ impl T5ForConditionalGeneration {
///
/// * `p` - Variable store path for the root of the T5 model
/// * `config` - `T5Config` object defining the model architecture
/// * `output_attention` - flag indicating if the model should output the attention weights of intermediate layers
/// * `output_hidden_states` - flag indicating if the model should output the hidden states weights of intermediate layers
///
/// # Example
///
@ -465,27 +468,15 @@ impl T5ForConditionalGeneration {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = T5Config::from_file(config_path);
/// let output_attentions = true;
/// let output_hidden_states = true;
/// let t5 = T5ForConditionalGeneration::new(
/// &p.root() / "t5",
/// &config,
/// output_attentions,
/// output_hidden_states,
/// );
/// let t5 = T5ForConditionalGeneration::new(&p.root() / "t5", &config);
/// ```
pub fn new<'p, P>(
p: P,
config: &T5Config,
output_attentions: bool,
output_hidden_states: bool,
) -> T5ForConditionalGeneration
pub fn new<'p, P>(p: P, config: &T5Config) -> T5ForConditionalGeneration
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let base_model = T5Model::new(p, config, output_attentions, output_hidden_states);
let base_model = T5Model::new(p, config);
let tie_word_embeddings = config.tie_word_embeddings.unwrap_or(true);
let lm_head = if !tie_word_embeddings {
@ -549,7 +540,7 @@ impl T5ForConditionalGeneration {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = T5Config::from_file(config_path);
/// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config, false, false);
/// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -666,7 +657,7 @@ impl LMHeadModel for T5ForConditionalGeneration {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = T5Config::from_file(config_path);
/// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config, false, false);
/// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -749,6 +740,84 @@ impl LMHeadModel for T5ForConditionalGeneration {
}
}
/// # T5 for sentence embeddings
/// Encoder-only T5 transformer usable in
/// [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
///
/// The model holds the shared token embeddings and a single encoder stack;
/// no decoder is instantiated.
pub struct T5ForSentenceEmbeddings {
    embeddings: nn::Embedding,
    encoder: T5Stack,
}

impl T5ForSentenceEmbeddings {
    /// Build a new `T5ForSentenceEmbeddings`
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the T5 model. The embeddings are
    ///   registered under `shared` and the encoder under `encoder`.
    /// * `config` - `T5Config` object defining the model architecture. Its
    ///   `output_attentions` and `output_hidden_states` fields are forwarded to the
    ///   encoder stack and default to `false` when unset.
    ///
    /// It consists of only an encoder (there is no decoder).
    pub fn new<'p, P>(p: P, config: &T5Config) -> Self
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);

        let embeddings = embedding(
            root / "shared",
            config.vocab_size,
            config.d_model,
            Default::default(),
        );

        // Both boolean flags are `false`, the same values `T5Model::new` passes for its
        // encoder stack (its decoder stack passes `true, true`).
        let encoder = T5Stack::new(
            root / "encoder",
            config,
            false,
            false,
            output_attentions,
            output_hidden_states,
        );

        Self {
            embeddings,
            encoder,
        }
    }

    /// Forward pass through the encoder
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Input of shape (*batch size*, *source_sequence_length*).
    /// * `mask` - Attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
    ///
    /// # Returns
    ///
    /// * Tuple containing:
    ///   - `Tensor` holding the activations of the last encoder hidden state
    ///     (expected shape: (*batch size*, *source_sequence_length*, *hidden_size*))
    ///   - `Option<Vec<Tensor>>` holding the attention weights collected by the encoder
    ///     (presumably `None` unless `output_attentions` is enabled in the configuration;
    ///     confirm against `T5Stack`)
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the encoder stack.
    pub fn forward(
        &self,
        input_ids: &Tensor,
        mask: &Tensor,
    ) -> Result<(Tensor, Option<Vec<Tensor>>), RustBertError> {
        // Only token ids and their attention mask are supplied; every optional input of
        // the stack is left as `None`. The trailing `false` is presumably the `train`
        // flag (inference mode); confirm against `T5Stack::forward_t`.
        let encoder_output = self.encoder.forward_t(
            Some(input_ids),
            Some(mask),
            None,
            None,
            None,
            &self.embeddings,
            None,
            false,
        )?;

        let last_hidden_state = encoder_output.hidden_state;
        let attentions = encoder_output.all_attentions;
        Ok((last_hidden_state, attentions))
    }
}
/// Container holding a T5 model output. The decoder output may hold the hidden state of
/// the last layer of the decoder, or may hold logits for a custom head module after the
/// decoder (e.g. for language modeling tasks)
@ -812,7 +881,7 @@ impl T5Generator {
let mut var_store = nn::VarStore::new(device);
let config = T5Config::from_file(config_path);
let model = T5ForConditionalGeneration::new(&var_store.root(), &config, false, false);
let model = T5ForConditionalGeneration::new(&var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(-1));

View File

@ -0,0 +1,157 @@
use rust_bert::pipelines::sentence_embeddings::{
SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
};
/// Checks the `DistiluseBaseMultilingualCased` remote model against hard-coded
/// reference values for two sentences.
#[test]
fn sbert_distilbert() -> anyhow::Result<()> {
    let model = SentenceEmbeddingsBuilder::remote(
        SentenceEmbeddingsModelType::DistiluseBaseMultilingualCased,
    )
    .create_model()?;

    let sentences = ["This is an example sentence", "Each sentence is converted"];
    let embeddings = model.encode(&sentences)?;

    // (sentence index, embedding component, reference value).
    // Components 509-511 are presumably the tail of a 512-dimensional embedding.
    let expected: [(usize, usize, f64); 12] = [
        (0, 0, -0.03479306),
        (0, 1, 0.02635195),
        (0, 2, -0.04427199),
        (0, 509, 0.01743882),
        (0, 510, -0.01952395),
        (0, 511, -0.00118101),
        (1, 0, 0.02096637),
        (1, 1, -0.00401743),
        (1, 2, -0.05093712),
        (1, 509, 0.03618195),
        (1, 510, 0.0294408),
        (1, 511, -0.04497765),
    ];

    for &(row, col, want) in expected.iter() {
        let got = embeddings[row][col] as f64;
        // Report the offending component and value instead of a bare expression dump.
        assert!(
            (got - want).abs() < 1e-4,
            "embeddings[{}][{}]: expected {}, got {}",
            row,
            col,
            want,
            got
        );
    }

    Ok(())
}
/// Checks the `BertBaseNliMeanTokens` remote model against hard-coded
/// reference values for two sentences.
#[test]
fn sbert_bert() -> anyhow::Result<()> {
    let model =
        SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::BertBaseNliMeanTokens)
            .create_model()?;

    let sentences = ["this is an example sentence", "each sentence is converted"];
    let embeddings = model.encode(&sentences)?;

    // (sentence index, embedding component, reference value).
    // Components 765-767 are presumably the tail of a 768-dimensional embedding.
    let expected: [(usize, usize, f64); 12] = [
        (0, 0, -0.393099815),
        (0, 1, 0.0388629436),
        (0, 2, 1.98742473),
        (0, 765, -0.609367728),
        (0, 766, -1.09462142),
        (0, 767, 0.326490253),
        (1, 0, 0.0615336187),
        (1, 1, 0.32736221),
        (1, 2, 1.8332324),
        (1, 765, -0.129853949),
        (1, 766, 0.460893631),
        (1, 767, 0.240354523),
    ];

    for &(row, col, want) in expected.iter() {
        let got = embeddings[row][col] as f64;
        // Report the offending component and value instead of a bare expression dump.
        assert!(
            (got - want).abs() < 1e-4,
            "embeddings[{}][{}]: expected {}, got {}",
            row,
            col,
            want,
            got
        );
    }

    Ok(())
}
/// Checks the `AllMiniLmL12V2` remote model against hard-coded
/// reference values for two sentences.
#[test]
fn sbert_bert_small() -> anyhow::Result<()> {
    let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2)
        .create_model()?;

    let sentences = ["this is an example sentence", "each sentence is converted"];
    let embeddings = model.encode(&sentences)?;

    // (sentence index, embedding component, reference value).
    // Components 381-383 are presumably the tail of a 384-dimensional embedding.
    let expected: [(usize, usize, f64); 12] = [
        (0, 0, -2.02682902e-04),
        (0, 1, 8.14802647e-02),
        (0, 2, 3.13617811e-02),
        (0, 381, 6.20930083e-02),
        (0, 382, 4.91031967e-02),
        (0, 383, -2.90199649e-04),
        (1, 0, 6.47571534e-02),
        (1, 1, 4.85198125e-02),
        (1, 2, -1.78603437e-02),
        (1, 381, 3.37569155e-02),
        (1, 382, 8.43371451e-03),
        (1, 383, -6.00359812e-02),
    ];

    for &(row, col, want) in expected.iter() {
        let got = embeddings[row][col] as f64;
        // Report the offending component and value instead of a bare expression dump.
        assert!(
            (got - want).abs() < 1e-4,
            "embeddings[{}][{}]: expected {}, got {}",
            row,
            col,
            want,
            got
        );
    }

    Ok(())
}
/// Checks the `AllDistilrobertaV1` remote model against hard-coded
/// reference values for two sentences.
#[test]
fn sbert_distilroberta() -> anyhow::Result<()> {
    let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllDistilrobertaV1)
        .create_model()?;

    let sentences = ["This is an example sentence", "Each sentence is converted"];
    let embeddings = model.encode(&sentences)?;

    // (sentence index, embedding component, reference value).
    // Components 765-767 are presumably the tail of a 768-dimensional embedding.
    let expected: [(usize, usize, f64); 12] = [
        (0, 0, -0.03375624),
        (0, 1, -0.06316338),
        (0, 2, -0.0316612),
        (0, 765, 0.03684864),
        (0, 766, -0.02036646),
        (0, 767, -0.01574),
        (1, 0, -0.01409588),
        (1, 1, 0.00091114),
        (1, 2, -0.00096315),
        (1, 765, -0.02571585),
        (1, 766, -0.00289072),
        (1, 767, -0.00579975),
    ];

    for &(row, col, want) in expected.iter() {
        let got = embeddings[row][col] as f64;
        // Report the offending component and value instead of a bare expression dump.
        assert!(
            (got - want).abs() < 1e-4,
            "embeddings[{}][{}]: expected {}, got {}",
            row,
            col,
            want,
            got
        );
    }

    Ok(())
}
/// Checks the `ParaphraseAlbertSmallV2` remote model against hard-coded
/// reference values for two sentences.
#[test]
fn sbert_albert() -> anyhow::Result<()> {
    let model =
        SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2)
            .create_model()?;

    let sentences = ["this is an example sentence", "each sentence is converted"];
    let embeddings = model.encode(&sentences)?;

    // (sentence index, embedding component, reference value).
    // Components 765-767 are presumably the tail of a 768-dimensional embedding.
    let expected: [(usize, usize, f64); 12] = [
        (0, 0, 0.20412037),
        (0, 1, 0.48823047),
        (0, 2, 0.5664698),
        (0, 765, -0.37474486),
        (0, 766, 0.0254627),
        (0, 767, -0.6846024),
        (1, 0, 0.25720373),
        (1, 1, 0.24648172),
        (1, 2, -0.2521183),
        (1, 765, 0.4667896),
        (1, 766, 0.14219822),
        (1, 767, 0.3986863),
    ];

    for &(row, col, want) in expected.iter() {
        let got = embeddings[row][col] as f64;
        // Report the offending component and value instead of a bare expression dump.
        assert!(
            (got - want).abs() < 1e-4,
            "embeddings[{}][{}]: expected {}, got {}",
            row,
            col,
            want,
            got
        );
    }

    Ok(())
}
/// Checks the `SentenceT5Base` remote model against hard-coded
/// reference values for two sentences.
#[test]
fn sbert_t5() -> anyhow::Result<()> {
    let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::SentenceT5Base)
        .create_model()?;

    let sentences = ["This is an example sentence", "Each sentence is converted"];
    let embeddings = model.encode(&sentences)?;

    // (sentence index, embedding component, reference value).
    // Components 765-767 are presumably the tail of a 768-dimensional embedding.
    let expected: [(usize, usize, f64); 12] = [
        (0, 0, -0.00904849),
        (0, 1, 0.0191336),
        (0, 2, 0.02657794),
        (0, 765, -0.00876413),
        (0, 766, -0.05602207),
        (0, 767, -0.02163094),
        (1, 0, -0.00785422),
        (1, 1, 0.03018173),
        (1, 2, 0.03129675),
        (1, 765, -0.01246878),
        (1, 766, -0.06240674),
        (1, 767, -0.00590969),
    ];

    for &(row, col, want) in expected.iter() {
        let got = embeddings[row][col] as f64;
        // Report the offending component and value instead of a bare expression dump.
        assert!(
            (got - want).abs() < 1e-4,
            "embeddings[{}][{}]: expected {}, got {}",
            row,
            col,
            want,
            got
        );
    }

    Ok(())
}

View File

@ -11,6 +11,8 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert")
parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings / language model head")
parser.add_argument("--prefix", help="Add a prefix on weight names")
parser.add_argument("--suffix", action="store_true", help="Split weight names on '.' and keep only last part")
args = parser.parse_args()
source_file = Path(args.source_file)
@ -24,6 +26,10 @@ if __name__ == "__main__":
if args.skip_embeddings:
if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}:
continue
if args.prefix:
k = args.prefix + k
if args.suffix:
k = k.split('.')[-1]
if isinstance(v, Tensor):
nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
print(f'converted {k} - {str(sys.getsizeof(nps[k]))} bytes')