mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Add sbert implementation for inference (#250)
* Add sbert implementation for inference * Fix clippy warnings * Refactor sentence embeddings into a dedicated pipeline * Add output_attentions and output_hidden_states to T5Config * Add sbert implementation for inference * Fix clippy warnings * Refactor sentence embeddings into a dedicated pipeline * Add output_attentions and output_hidden_states to T5Config * Improve sentence_embeddings implementation * Dedicated tokenizer config for strip_accents and add_prefix_space * Rename forward to encode_as_tensor * Remove _conf from Dense layer * Add sentence embeddings docs * Addition of remote resources and tests update * Merge feature branch and fix doctests * Add SentenceEmbeddingsBuilder<Remote> and improve remote resources * Use tch::no_grad in sentence embeddings * Updated changelog, registration of sentence embeddings integration tests Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
This commit is contained in:
parent
6b20da41de
commit
4d8a298586
16
.github/workflows/continuous-integration.yml
vendored
16
.github/workflows/continuous-integration.yml
vendored
@ -100,6 +100,22 @@ jobs:
|
||||
--test pegasus
|
||||
--test gpt_neo
|
||||
|
||||
test-batch-2:
|
||||
name: Integration tests (batch 2)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --package rust-bert
|
||||
--test sentence_embeddings
|
||||
|
||||
convert-model:
|
||||
name: Model conversion test
|
||||
runs-on: ubuntu-latest
|
||||
|
@ -2,6 +2,8 @@
|
||||
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
|
||||
|
||||
## [Unreleased]
|
||||
## Added
|
||||
- Support for sentence embeddings models and pipelines, based on [SentenceTransformers](https://www.sbert.net).
|
||||
|
||||
## [0.18.0] - 2022-05-29
|
||||
## Added
|
||||
|
17
examples/sentence_embeddings.rs
Normal file
17
examples/sentence_embeddings.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use rust_bert::pipelines::sentence_embeddings::{
|
||||
SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
|
||||
};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Set-up sentence embeddings model
|
||||
let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2)
|
||||
.create_model()?;
|
||||
|
||||
// Define input
|
||||
let sentences = ["this is an example sentence", "each sentence is converted"];
|
||||
|
||||
// Generate Embeddings
|
||||
let embeddings = model.encode(&sentences)?;
|
||||
println!("{:?}", embeddings);
|
||||
Ok(())
|
||||
}
|
25
examples/sentence_embeddings_local.rs
Normal file
25
examples/sentence_embeddings_local.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;
|
||||
|
||||
/// Download model:
|
||||
/// ```sh
|
||||
/// git lfs install
|
||||
/// git -C resources clone https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2
|
||||
/// ```
|
||||
/// Prepare model:
|
||||
/// ```sh
|
||||
/// python ./utils/convert_model.py resources/all-MiniLM-L12-v2/pytorch_model.bin
|
||||
/// ```
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Set-up sentence embeddings model
|
||||
let model = SentenceEmbeddingsBuilder::local("resources/all-MiniLM-L12-v2")
|
||||
.with_device(tch::Device::cuda_if_available())
|
||||
.create_model()?;
|
||||
|
||||
// Define input
|
||||
let sentences = ["this is an example sentence", "each sentence is converted"];
|
||||
|
||||
// Generate Embeddings
|
||||
let embeddings = model.encode(&sentences)?;
|
||||
println!("{:?}", embeddings);
|
||||
Ok(())
|
||||
}
|
@ -37,6 +37,11 @@ impl AlbertModelResources {
|
||||
"albert-base-v2/model",
|
||||
"https://huggingface.co/albert-base-v2/resolve/main/rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
|
||||
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
|
||||
"paraphrase-albert-small-v2/model",
|
||||
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl AlbertConfigResources {
|
||||
@ -45,6 +50,11 @@ impl AlbertConfigResources {
|
||||
"albert-base-v2/config",
|
||||
"https://huggingface.co/albert-base-v2/resolve/main/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
|
||||
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
|
||||
"paraphrase-albert-small-v2/config",
|
||||
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl AlbertVocabResources {
|
||||
@ -53,6 +63,11 @@ impl AlbertVocabResources {
|
||||
"albert-base-v2/spiece",
|
||||
"https://huggingface.co/albert-base-v2/resolve/main/spiece.model",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
|
||||
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
|
||||
"paraphrase-albert-small-v2/spiece",
|
||||
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/spiece.model",
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
@ -1048,6 +1063,10 @@ impl AlbertForMultipleChoice {
|
||||
}
|
||||
}
|
||||
|
||||
/// # ALBERT for sentence embeddings
|
||||
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
|
||||
pub type AlbertForSentenceEmbeddings = AlbertModel;
|
||||
|
||||
/// Container for the ALBERT model output.
|
||||
pub struct AlbertOutput {
|
||||
/// Last hidden states from the model
|
||||
|
@ -59,8 +59,8 @@ mod encoder;
|
||||
|
||||
pub use albert_model::{
|
||||
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
|
||||
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
|
||||
AlbertMaskedLMOutput, AlbertModel, AlbertModelResources, AlbertOutput,
|
||||
AlbertQuestionAnsweringOutput, AlbertSequenceClassificationOutput,
|
||||
AlbertForQuestionAnswering, AlbertForSentenceEmbeddings, AlbertForSequenceClassification,
|
||||
AlbertForTokenClassification, AlbertMaskedLMOutput, AlbertModel, AlbertModelResources,
|
||||
AlbertOutput, AlbertQuestionAnsweringOutput, AlbertSequenceClassificationOutput,
|
||||
AlbertTokenClassificationOutput, AlbertVocabResources,
|
||||
};
|
||||
|
@ -52,6 +52,16 @@ impl BertModelResources {
|
||||
"bert-qa/model",
|
||||
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
|
||||
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
|
||||
"bert-base-nli-mean-tokens/model",
|
||||
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
|
||||
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
|
||||
"all-mini-lm-l12-v2/model",
|
||||
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl BertConfigResources {
|
||||
@ -70,6 +80,16 @@ impl BertConfigResources {
|
||||
"bert-qa/config",
|
||||
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
|
||||
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
|
||||
"bert-base-nli-mean-tokens/config",
|
||||
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
|
||||
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
|
||||
"all-mini-lm-l12-v2/config",
|
||||
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl BertVocabResources {
|
||||
@ -88,6 +108,16 @@ impl BertVocabResources {
|
||||
"bert-qa/vocab",
|
||||
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
|
||||
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
|
||||
"bert-base-nli-mean-tokens/vocab",
|
||||
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/vocab.txt",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
|
||||
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
|
||||
"all-mini-lm-l12-v2/vocab",
|
||||
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/vocab.txt",
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
@ -1179,6 +1209,10 @@ impl BertForQuestionAnswering {
|
||||
}
|
||||
}
|
||||
|
||||
/// # BERT for sentence embeddings
|
||||
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
|
||||
pub type BertForSentenceEmbeddings = BertModel<BertEmbeddings>;
|
||||
|
||||
/// Container for the BERT model output.
|
||||
pub struct BertModelOutput {
|
||||
/// Last hidden states from the model
|
||||
|
@ -59,8 +59,8 @@ pub(crate) mod encoder;
|
||||
|
||||
pub use bert_model::{
|
||||
BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
|
||||
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
|
||||
BertMaskedLMOutput, BertModel, BertModelOutput, BertModelResources,
|
||||
BertForQuestionAnswering, BertForSentenceEmbeddings, BertForSequenceClassification,
|
||||
BertForTokenClassification, BertMaskedLMOutput, BertModel, BertModelOutput, BertModelResources,
|
||||
BertQuestionAnsweringOutput, BertSequenceClassificationOutput, BertTokenClassificationOutput,
|
||||
BertVocabResources,
|
||||
};
|
||||
|
@ -26,6 +26,10 @@ pub fn _tanh(x: &Tensor) -> Tensor {
|
||||
x.tanh()
|
||||
}
|
||||
|
||||
pub fn _identity(x: &Tensor) -> Tensor {
|
||||
x.shallow_clone()
|
||||
}
|
||||
|
||||
/// Newtype around a boxed function pointer that maps a `&Tensor` to a new `Tensor`.
pub struct TensorFunction(Box<fn(&Tensor) -> Tensor>);
|
||||
|
||||
impl TensorFunction {
|
||||
@ -58,6 +62,8 @@ pub enum Activation {
|
||||
gelu_new,
|
||||
/// Tanh
|
||||
tanh,
|
||||
/// Identity
|
||||
identity,
|
||||
}
|
||||
|
||||
impl Activation {
|
||||
@ -69,6 +75,7 @@ impl Activation {
|
||||
Activation::gelu_new => _gelu_new,
|
||||
Activation::mish => _mish,
|
||||
Activation::tanh => _tanh,
|
||||
Activation::identity => _identity,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
@ -30,3 +30,15 @@ impl ResourceProvider for LocalResource {
|
||||
Ok(self.local_path.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PathBuf> for LocalResource {
|
||||
fn from(local_path: PathBuf) -> Self {
|
||||
Self { local_path }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PathBuf> for Box<dyn ResourceProvider + Send> {
|
||||
fn from(local_path: PathBuf) -> Self {
|
||||
Box::new(LocalResource { local_path })
|
||||
}
|
||||
}
|
||||
|
@ -46,6 +46,11 @@ impl DistilBertModelResources {
|
||||
"distilbert-qa/model",
|
||||
"https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
|
||||
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
|
||||
"distiluse-base-multilingual-cased/model",
|
||||
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl DistilBertConfigResources {
|
||||
@ -64,6 +69,11 @@ impl DistilBertConfigResources {
|
||||
"distilbert-qa/config",
|
||||
"https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
|
||||
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
|
||||
"distiluse-base-multilingual-cased/config",
|
||||
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl DistilBertVocabResources {
|
||||
@ -82,6 +92,11 @@ impl DistilBertVocabResources {
|
||||
"distilbert-qa/vocab",
|
||||
"https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
|
||||
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
|
||||
"distiluse-base-multilingual-cased/vocab",
|
||||
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/vocab.txt",
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
@ -755,6 +770,10 @@ impl DistilBertForTokenClassification {
|
||||
}
|
||||
}
|
||||
|
||||
/// # DistilBERT for sentence embeddings
|
||||
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
|
||||
pub type DistilBertForSentenceEmbeddings = DistilBertModel;
|
||||
|
||||
/// Container for the DistilBERT masked LM model output.
|
||||
pub struct DistilBertMaskedLMOutput {
|
||||
/// Logits for the vocabulary items at each sequence position
|
||||
|
@ -60,8 +60,8 @@ mod transformer;
|
||||
|
||||
pub use distilbert_model::{
|
||||
DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
|
||||
DistilBertForTokenClassification, DistilBertMaskedLMOutput, DistilBertModel,
|
||||
DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertModelResources,
|
||||
DistilBertForSentenceEmbeddings, DistilBertForTokenClassification, DistilBertMaskedLMOutput,
|
||||
DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertModelResources,
|
||||
DistilBertQuestionAnsweringOutput, DistilBertSequenceClassificationOutput,
|
||||
DistilBertTokenClassificationOutput, DistilBertVocabResources,
|
||||
};
|
||||
|
16
src/lib.rs
16
src/lib.rs
@ -146,7 +146,7 @@
|
||||
//! ```
|
||||
//!
|
||||
//! </details>
|
||||
//!
|
||||
//!
|
||||
//! <details>
|
||||
//! <summary> <b>2. Translation </b> </summary>
|
||||
//!
|
||||
@ -198,7 +198,7 @@
|
||||
//! ```
|
||||
//!
|
||||
//! </details>
|
||||
//!
|
||||
//!
|
||||
//! <details>
|
||||
//! <summary> <b>3. Summarization </b> </summary>
|
||||
//!
|
||||
@ -249,7 +249,7 @@
|
||||
//! ```
|
||||
//!
|
||||
//! </details>
|
||||
//!
|
||||
//!
|
||||
//! <details>
|
||||
//! <summary> <b>4. Dialogue Model </b> </summary>
|
||||
//!
|
||||
@ -281,7 +281,7 @@
|
||||
//! ```
|
||||
//!
|
||||
//! </details>
|
||||
//!
|
||||
//!
|
||||
//! <details>
|
||||
//! <summary> <b>5. Natural Language Generation </b> </summary>
|
||||
//!
|
||||
@ -319,7 +319,7 @@
|
||||
//! ```
|
||||
//!
|
||||
//! </details>
|
||||
//!
|
||||
//!
|
||||
//! <details>
|
||||
//! <summary> <b>6. Zero-shot classification </b> </summary>
|
||||
//!
|
||||
@ -402,7 +402,7 @@
|
||||
//! ```
|
||||
//!
|
||||
//! </details>
|
||||
//!
|
||||
//!
|
||||
//! <details>
|
||||
//! <summary> <b>7. Sentiment analysis </b> </summary>
|
||||
//!
|
||||
@ -445,7 +445,7 @@
|
||||
//! ```
|
||||
//!
|
||||
//! </details>
|
||||
//!
|
||||
//!
|
||||
//! <details>
|
||||
//! <summary> <b>8. Named Entity Recognition </b> </summary>
|
||||
//!
|
||||
@ -502,7 +502,7 @@
|
||||
//! ```
|
||||
//!
|
||||
//! </details>
|
||||
//!
|
||||
//!
|
||||
//! <details>
|
||||
//! <summary> <b>9. Part of Speech tagging </b> </summary>
|
||||
//!
|
||||
|
@ -54,22 +54,28 @@ use rust_tokenizers::vocab::{
|
||||
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryFrom;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, PartialEq)]
|
||||
/// # Identifies the type of model
|
||||
pub enum ModelType {
|
||||
Bart,
|
||||
#[serde(alias = "bert")]
|
||||
Bert,
|
||||
#[serde(alias = "distilbert")]
|
||||
DistilBert,
|
||||
Deberta,
|
||||
DebertaV2,
|
||||
#[serde(alias = "roberta")]
|
||||
Roberta,
|
||||
XLMRoberta,
|
||||
Electra,
|
||||
Marian,
|
||||
MobileBert,
|
||||
#[serde(alias = "t5")]
|
||||
T5,
|
||||
#[serde(alias = "albert")]
|
||||
Albert,
|
||||
XLNet,
|
||||
GPT2,
|
||||
@ -309,6 +315,62 @@ impl ConfigOption {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&ConfigOption> for BertConfig {
|
||||
type Error = RustBertError;
|
||||
|
||||
fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
|
||||
match config {
|
||||
ConfigOption::Bert(config) | ConfigOption::Roberta(config) => Ok(config.clone()),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(
|
||||
"You can only supply a BertConfig for Bert or a RobertaConfig for Roberta!"
|
||||
.to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&ConfigOption> for DistilBertConfig {
|
||||
type Error = RustBertError;
|
||||
|
||||
fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
|
||||
if let ConfigOption::DistilBert(config) = config {
|
||||
Ok(config.clone())
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
"You can only supply a DistilBertConfig for DistilBert!".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&ConfigOption> for AlbertConfig {
|
||||
type Error = RustBertError;
|
||||
|
||||
fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
|
||||
if let ConfigOption::Albert(config) = config {
|
||||
Ok(config.clone())
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
"You can only supply an AlbertConfig for Albert!".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&ConfigOption> for T5Config {
|
||||
type Error = RustBertError;
|
||||
|
||||
fn try_from(config: &ConfigOption) -> Result<Self, Self::Error> {
|
||||
if let ConfigOption::T5(config) = config {
|
||||
Ok(config.clone())
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
"You can only supply a T5Config for T5!".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenizerOption {
|
||||
/// Interface method to load a tokenizer from file
|
||||
pub fn from_file(
|
||||
|
@ -416,6 +416,7 @@ pub mod generation_utils;
|
||||
pub mod ner;
|
||||
pub mod pos_tagging;
|
||||
pub mod question_answering;
|
||||
pub mod sentence_embeddings;
|
||||
pub mod sentiment;
|
||||
pub mod sequence_classification;
|
||||
pub mod summarization;
|
||||
|
185
src/pipelines/sentence_embeddings/builder.rs
Normal file
185
src/pipelines/sentence_embeddings/builder.rs
Normal file
@ -0,0 +1,185 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use serde::Deserialize;
|
||||
use tch::Device;
|
||||
|
||||
use crate::pipelines::common::ModelType;
|
||||
use crate::pipelines::sentence_embeddings::{
|
||||
SentenceEmbeddingsConfig, SentenceEmbeddingsModel, SentenceEmbeddingsModulesConfig,
|
||||
};
|
||||
use crate::{Config, RustBertError};
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
pipelines::sentence_embeddings::resources::SentenceEmbeddingsModelType,
|
||||
resources::RemoteResource,
|
||||
};
|
||||
|
||||
/// # SentenceEmbeddings Model Builder
|
||||
///
|
||||
/// Allows the user to build a model from standard Sentence-Transformer files
|
||||
/// (configuration and weights).
|
||||
pub struct SentenceEmbeddingsBuilder<T> {
|
||||
device: Device,
|
||||
inner: T,
|
||||
}
|
||||
|
||||
impl<T> SentenceEmbeddingsBuilder<T> {
|
||||
pub fn with_device(mut self, device: Device) -> Self {
|
||||
self.device = device;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder state for a model whose files all live in one local directory.
pub struct Local {
    /// Root directory joined with file names such as `modules.json` and
    /// `config.json` when the model is created.
    model_dir: PathBuf,
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ModelConfig {
|
||||
model_type: ModelType,
|
||||
}
|
||||
|
||||
impl Config for ModelConfig {}
|
||||
|
||||
impl SentenceEmbeddingsBuilder<Local> {
|
||||
pub fn local<P: Into<PathBuf>>(model_dir: P) -> Self {
|
||||
Self {
|
||||
device: Device::cuda_if_available(),
|
||||
inner: Local {
|
||||
model_dir: model_dir.into(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_model(self) -> Result<SentenceEmbeddingsModel, RustBertError> {
|
||||
let model_dir = self.inner.model_dir;
|
||||
|
||||
let modules_config = model_dir.join("modules.json");
|
||||
let modules = SentenceEmbeddingsModulesConfig::from_file(&modules_config).validate()?;
|
||||
|
||||
let transformer_config = model_dir.join("config.json");
|
||||
let transformer_type = ModelConfig::from_file(&transformer_config).model_type;
|
||||
let transformer_weights = model_dir.join("rust_model.ot");
|
||||
|
||||
let pooling_config = model_dir
|
||||
.join(&modules.pooling_module().path)
|
||||
.join("config.json");
|
||||
|
||||
let (dense_config, dense_weights) = modules
|
||||
.dense_module()
|
||||
.map(|m| {
|
||||
(
|
||||
Some(model_dir.join(&m.path).join("config.json")),
|
||||
Some(model_dir.join(&m.path).join("rust_model.ot")),
|
||||
)
|
||||
})
|
||||
.unwrap_or((None, None));
|
||||
|
||||
let tokenizer_config = model_dir.join("tokenizer_config.json");
|
||||
let sentence_bert_config = model_dir.join("sentence_bert_config.json");
|
||||
let (tokenizer_vocab, tokenizer_merges) = match transformer_type {
|
||||
ModelType::Bert | ModelType::DistilBert => (model_dir.join("vocab.txt"), None),
|
||||
ModelType::Roberta => (
|
||||
model_dir.join("vocab.json"),
|
||||
Some(model_dir.join("merges.txt")),
|
||||
),
|
||||
ModelType::Albert => (model_dir.join("spiece.model"), None),
|
||||
ModelType::T5 => (model_dir.join("spiece.model"), None),
|
||||
_ => {
|
||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Unsupported transformer model {:?} for Sentence Embeddings",
|
||||
transformer_type
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let config = SentenceEmbeddingsConfig {
|
||||
modules_config_resource: modules_config.into(),
|
||||
transformer_type,
|
||||
transformer_config_resource: transformer_config.into(),
|
||||
transformer_weights_resource: transformer_weights.into(),
|
||||
pooling_config_resource: pooling_config.into(),
|
||||
dense_config_resource: dense_config.map(|r| r.into()),
|
||||
dense_weights_resource: dense_weights.map(|r| r.into()),
|
||||
sentence_bert_config_resource: sentence_bert_config.into(),
|
||||
tokenizer_config_resource: tokenizer_config.into(),
|
||||
tokenizer_vocab_resource: tokenizer_vocab.into(),
|
||||
tokenizer_merges_resource: tokenizer_merges.map(|r| r.into()),
|
||||
device: self.device,
|
||||
};
|
||||
|
||||
SentenceEmbeddingsModel::new(config)
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder state for a model assembled from remote resources: a complete
/// configuration whose individual resources can be overridden before the model
/// is created.
#[cfg(feature = "remote")]
pub struct Remote {
    /// Configuration eventually handed to `SentenceEmbeddingsModel::new`.
    config: SentenceEmbeddingsConfig,
}
|
||||
|
||||
#[cfg(feature = "remote")]
impl SentenceEmbeddingsBuilder<Remote> {
    /// Starts a builder from the predefined configuration of `model_type`.
    ///
    /// The device defaults to `Device::cuda_if_available()`.
    pub fn remote(model_type: SentenceEmbeddingsModelType) -> Self {
        Self {
            device: Device::cuda_if_available(),
            inner: Remote {
                config: SentenceEmbeddingsConfig::from(model_type),
            },
        }
    }

    /// Overrides the modules configuration resource.
    pub fn modules_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.modules_config_resource = Box::new(resource);
        self
    }

    /// Overrides the transformer configuration resource.
    pub fn transformer_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.transformer_config_resource = Box::new(resource);
        self
    }

    /// Overrides the transformer weights resource.
    pub fn transformer_weights(mut self, resource: RemoteResource) -> Self {
        self.inner.config.transformer_weights_resource = Box::new(resource);
        self
    }

    /// Overrides the pooling layer configuration resource.
    pub fn pooling_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.pooling_config_resource = Box::new(resource);
        self
    }

    /// Sets the (optional) dense layer configuration resource.
    pub fn dense_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.dense_config_resource = Some(Box::new(resource));
        self
    }

    /// Sets the (optional) dense layer weights resource.
    pub fn dense_weights(mut self, resource: RemoteResource) -> Self {
        self.inner.config.dense_weights_resource = Some(Box::new(resource));
        self
    }

    /// Overrides the Sentence BERT specific configuration resource.
    pub fn sentence_bert_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.sentence_bert_config_resource = Box::new(resource);
        self
    }

    /// Overrides the tokenizer configuration resource.
    pub fn tokenizer_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.tokenizer_config_resource = Box::new(resource);
        self
    }

    /// Overrides the tokenizer vocabulary resource.
    pub fn tokenizer_vocab(mut self, resource: RemoteResource) -> Self {
        self.inner.config.tokenizer_vocab_resource = Box::new(resource);
        self
    }

    /// Sets the (optional) tokenizer merges resource.
    pub fn tokenizer_merges(mut self, resource: RemoteResource) -> Self {
        self.inner.config.tokenizer_merges_resource = Some(Box::new(resource));
        self
    }

    /// Builds the model from the accumulated configuration.
    ///
    /// The device stored in the builder (the default, or the one given to
    /// `with_device`) is written into the configuration first, so a
    /// `with_device` call is honoured for remote builders just as it is for
    /// local ones.
    ///
    /// # Errors
    ///
    /// Returns the error from `SentenceEmbeddingsModel::new`.
    pub fn create_model(self) -> Result<SentenceEmbeddingsModel, RustBertError> {
        let mut config = self.inner.config;
        // Without this the device chosen on the builder would be silently dropped
        // and the configuration's own default device used instead.
        config.device = self.device;
        SentenceEmbeddingsModel::new(config)
    }
}
|
429
src/pipelines/sentence_embeddings/config.rs
Normal file
429
src/pipelines/sentence_embeddings/config.rs
Normal file
@ -0,0 +1,429 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tch::Device;
|
||||
|
||||
use crate::pipelines::common::ModelType;
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::{Config, RustBertError};
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources},
|
||||
bert::{BertConfigResources, BertModelResources, BertVocabResources},
|
||||
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
|
||||
pipelines::sentence_embeddings::resources::{
|
||||
SentenceEmbeddingsConfigResources, SentenceEmbeddingsModelType,
|
||||
SentenceEmbeddingsModulesConfigResources, SentenceEmbeddingsPoolingConfigResources,
|
||||
SentenceEmbeddingsTokenizerConfigResources,
|
||||
},
|
||||
pipelines::sentence_embeddings::{
|
||||
SentenceEmbeddingsDenseConfigResources, SentenceEmbeddingsDenseResources,
|
||||
},
|
||||
resources::RemoteResource,
|
||||
roberta::{
|
||||
RobertaConfigResources, RobertaMergesResources, RobertaModelResources,
|
||||
RobertaVocabResources,
|
||||
},
|
||||
t5::{T5ConfigResources, T5ModelResources, T5VocabResources},
|
||||
};
|
||||
|
||||
/// # Configuration for sentence embeddings
|
||||
///
|
||||
/// Contains information regarding the transformer model to load, the optional extra
|
||||
/// layers, and device to place the model on.
|
||||
pub struct SentenceEmbeddingsConfig {
|
||||
/// Modules configuration resource, contains layers definition
|
||||
pub modules_config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Transformer model type
|
||||
pub transformer_type: ModelType,
|
||||
/// Transformer model configuration resource
|
||||
pub transformer_config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Transformer weights resource
|
||||
pub transformer_weights_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Pooling layer configuration resource
|
||||
pub pooling_config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Optional dense layer configuration resource
|
||||
pub dense_config_resource: Option<Box<dyn ResourceProvider + Send>>,
|
||||
/// Optional dense layer weights resource
|
||||
pub dense_weights_resource: Option<Box<dyn ResourceProvider + Send>>,
|
||||
/// Sentence BERT specific configuration resource
|
||||
pub sentence_bert_config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Transformer's tokenizer configuration resource
|
||||
pub tokenizer_config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Transformer's tokenizer vocab resource
|
||||
pub tokenizer_vocab_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Optional transformer's tokenizer merges resource
|
||||
pub tokenizer_merges_resource: Option<Box<dyn ResourceProvider + Send>>,
|
||||
/// Device to place the transformer model on
|
||||
pub device: Device,
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
    /// Builds the configuration of a pretrained sentence embeddings model, with every
    /// resource pointing at its remote pretrained location and the device chosen by
    /// `Device::cuda_if_available()`.
    fn from(model_type: SentenceEmbeddingsModelType) -> Self {
        // Wraps a pretrained resource constant into a boxed remote resource provider.
        let remote = |name_url| -> Box<dyn ResourceProvider + Send> {
            Box::new(RemoteResource::from_pretrained(name_url))
        };
        let device = Device::cuda_if_available();

        match model_type {
            SentenceEmbeddingsModelType::DistiluseBaseMultilingualCased => {
                SentenceEmbeddingsConfig {
                    modules_config_resource: remote(
                        SentenceEmbeddingsModulesConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                    ),
                    transformer_type: ModelType::DistilBert,
                    transformer_config_resource: remote(
                        DistilBertConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                    ),
                    transformer_weights_resource: remote(
                        DistilBertModelResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                    ),
                    pooling_config_resource: remote(
                        SentenceEmbeddingsPoolingConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                    ),
                    dense_config_resource: Some(remote(
                        SentenceEmbeddingsDenseConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                    )),
                    dense_weights_resource: Some(remote(
                        SentenceEmbeddingsDenseResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                    )),
                    sentence_bert_config_resource: remote(
                        SentenceEmbeddingsConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                    ),
                    tokenizer_config_resource: remote(
                        SentenceEmbeddingsTokenizerConfigResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                    ),
                    tokenizer_vocab_resource: remote(
                        DistilBertVocabResources::DISTILUSE_BASE_MULTILINGUAL_CASED,
                    ),
                    tokenizer_merges_resource: None,
                    device,
                }
            }

            SentenceEmbeddingsModelType::BertBaseNliMeanTokens => SentenceEmbeddingsConfig {
                modules_config_resource: remote(
                    SentenceEmbeddingsModulesConfigResources::BERT_BASE_NLI_MEAN_TOKENS,
                ),
                transformer_type: ModelType::Bert,
                transformer_config_resource: remote(BertConfigResources::BERT_BASE_NLI_MEAN_TOKENS),
                transformer_weights_resource: remote(BertModelResources::BERT_BASE_NLI_MEAN_TOKENS),
                pooling_config_resource: remote(
                    SentenceEmbeddingsPoolingConfigResources::BERT_BASE_NLI_MEAN_TOKENS,
                ),
                dense_config_resource: None,
                dense_weights_resource: None,
                sentence_bert_config_resource: remote(
                    SentenceEmbeddingsConfigResources::BERT_BASE_NLI_MEAN_TOKENS,
                ),
                tokenizer_config_resource: remote(
                    SentenceEmbeddingsTokenizerConfigResources::BERT_BASE_NLI_MEAN_TOKENS,
                ),
                tokenizer_vocab_resource: remote(BertVocabResources::BERT_BASE_NLI_MEAN_TOKENS),
                tokenizer_merges_resource: None,
                device,
            },

            SentenceEmbeddingsModelType::AllMiniLmL12V2 => SentenceEmbeddingsConfig {
                modules_config_resource: remote(
                    SentenceEmbeddingsModulesConfigResources::ALL_MINI_LM_L12_V2,
                ),
                transformer_type: ModelType::Bert,
                transformer_config_resource: remote(BertConfigResources::ALL_MINI_LM_L12_V2),
                transformer_weights_resource: remote(BertModelResources::ALL_MINI_LM_L12_V2),
                pooling_config_resource: remote(
                    SentenceEmbeddingsPoolingConfigResources::ALL_MINI_LM_L12_V2,
                ),
                dense_config_resource: None,
                dense_weights_resource: None,
                sentence_bert_config_resource: remote(
                    SentenceEmbeddingsConfigResources::ALL_MINI_LM_L12_V2,
                ),
                tokenizer_config_resource: remote(
                    SentenceEmbeddingsTokenizerConfigResources::ALL_MINI_LM_L12_V2,
                ),
                tokenizer_vocab_resource: remote(BertVocabResources::ALL_MINI_LM_L12_V2),
                tokenizer_merges_resource: None,
                device,
            },

            SentenceEmbeddingsModelType::AllDistilrobertaV1 => SentenceEmbeddingsConfig {
                modules_config_resource: remote(
                    SentenceEmbeddingsModulesConfigResources::ALL_DISTILROBERTA_V1,
                ),
                transformer_type: ModelType::Roberta,
                transformer_config_resource: remote(RobertaConfigResources::ALL_DISTILROBERTA_V1),
                transformer_weights_resource: remote(RobertaModelResources::ALL_DISTILROBERTA_V1),
                pooling_config_resource: remote(
                    SentenceEmbeddingsPoolingConfigResources::ALL_DISTILROBERTA_V1,
                ),
                dense_config_resource: None,
                dense_weights_resource: None,
                sentence_bert_config_resource: remote(
                    SentenceEmbeddingsConfigResources::ALL_DISTILROBERTA_V1,
                ),
                tokenizer_config_resource: remote(
                    SentenceEmbeddingsTokenizerConfigResources::ALL_DISTILROBERTA_V1,
                ),
                tokenizer_vocab_resource: remote(RobertaVocabResources::ALL_DISTILROBERTA_V1),
                tokenizer_merges_resource: Some(remote(
                    RobertaMergesResources::ALL_DISTILROBERTA_V1,
                )),
                device,
            },

            SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2 => SentenceEmbeddingsConfig {
                modules_config_resource: remote(
                    SentenceEmbeddingsModulesConfigResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                transformer_type: ModelType::Albert,
                transformer_config_resource: remote(
                    AlbertConfigResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                transformer_weights_resource: remote(
                    AlbertModelResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                pooling_config_resource: remote(
                    SentenceEmbeddingsPoolingConfigResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                dense_config_resource: None,
                dense_weights_resource: None,
                sentence_bert_config_resource: remote(
                    SentenceEmbeddingsConfigResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                tokenizer_config_resource: remote(
                    SentenceEmbeddingsTokenizerConfigResources::PARAPHRASE_ALBERT_SMALL_V2,
                ),
                tokenizer_vocab_resource: remote(AlbertVocabResources::PARAPHRASE_ALBERT_SMALL_V2),
                tokenizer_merges_resource: None,
                device,
            },

            SentenceEmbeddingsModelType::SentenceT5Base => SentenceEmbeddingsConfig {
                modules_config_resource: remote(
                    SentenceEmbeddingsModulesConfigResources::SENTENCE_T5_BASE,
                ),
                transformer_type: ModelType::T5,
                transformer_config_resource: remote(T5ConfigResources::SENTENCE_T5_BASE),
                transformer_weights_resource: remote(T5ModelResources::SENTENCE_T5_BASE),
                pooling_config_resource: remote(
                    SentenceEmbeddingsPoolingConfigResources::SENTENCE_T5_BASE,
                ),
                dense_config_resource: Some(remote(
                    SentenceEmbeddingsDenseConfigResources::SENTENCE_T5_BASE,
                )),
                dense_weights_resource: Some(remote(
                    SentenceEmbeddingsDenseResources::SENTENCE_T5_BASE,
                )),
                sentence_bert_config_resource: remote(
                    SentenceEmbeddingsConfigResources::SENTENCE_T5_BASE,
                ),
                tokenizer_config_resource: remote(
                    SentenceEmbeddingsTokenizerConfigResources::SENTENCE_T5_BASE,
                ),
                tokenizer_vocab_resource: remote(T5VocabResources::SENTENCE_T5_BASE),
                tokenizer_merges_resource: None,
                device,
            },
        }
    }
}
|
||||
|
||||
/// Configuration for the modules that define the model's layers
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SentenceEmbeddingsModulesConfig(pub Vec<SentenceEmbeddingsModuleConfig>);
|
||||
|
||||
impl std::ops::Deref for SentenceEmbeddingsModulesConfig {
|
||||
type Target = Vec<SentenceEmbeddingsModuleConfig>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::DerefMut for SentenceEmbeddingsModulesConfig {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<SentenceEmbeddingsModuleConfig>> for SentenceEmbeddingsModulesConfig {
|
||||
fn from(source: Vec<SentenceEmbeddingsModuleConfig>) -> Self {
|
||||
Self(source)
|
||||
}
|
||||
}
|
||||
|
||||
impl Config for SentenceEmbeddingsModulesConfig {}
|
||||
|
||||
impl SentenceEmbeddingsModulesConfig {
|
||||
pub fn validate(self) -> Result<Self, RustBertError> {
|
||||
match self.get(0) {
|
||||
Some(SentenceEmbeddingsModuleConfig {
|
||||
module_type: SentenceEmbeddingsModuleType::Transformer,
|
||||
..
|
||||
}) => (),
|
||||
Some(_) => {
|
||||
return Err(RustBertError::InvalidConfigurationError(
|
||||
"First module defined in modules.json must be a Transformer".to_string(),
|
||||
));
|
||||
}
|
||||
None => {
|
||||
return Err(RustBertError::InvalidConfigurationError(
|
||||
"No modules found in modules.json".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
match self.get(1) {
|
||||
Some(SentenceEmbeddingsModuleConfig {
|
||||
module_type: SentenceEmbeddingsModuleType::Pooling,
|
||||
..
|
||||
}) => (),
|
||||
Some(_) => {
|
||||
return Err(RustBertError::InvalidConfigurationError(
|
||||
"Second module defined in modules.json must be a Pooling".to_string(),
|
||||
));
|
||||
}
|
||||
None => {
|
||||
return Err(RustBertError::InvalidConfigurationError(
|
||||
"Pooling module not found in second position in modules.json".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn transformer_module(&self) -> &SentenceEmbeddingsModuleConfig {
|
||||
self.get(0).as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub fn pooling_module(&self) -> &SentenceEmbeddingsModuleConfig {
|
||||
self.get(1).as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub fn dense_module(&self) -> Option<&SentenceEmbeddingsModuleConfig> {
|
||||
for i in 2..=3 {
|
||||
if let Some(SentenceEmbeddingsModuleConfig {
|
||||
module_type: SentenceEmbeddingsModuleType::Dense,
|
||||
..
|
||||
}) = self.get(i)
|
||||
{
|
||||
return self.get(i);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn has_normalization(&self) -> bool {
|
||||
for i in 2..=3 {
|
||||
if let Some(SentenceEmbeddingsModuleConfig {
|
||||
module_type: SentenceEmbeddingsModuleType::Normalize,
|
||||
..
|
||||
}) = self.get(i)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration defining a single module (model's layer)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SentenceEmbeddingsModuleConfig {
|
||||
pub idx: usize,
|
||||
pub name: String,
|
||||
pub path: String,
|
||||
#[serde(rename = "type")]
|
||||
#[serde(with = "serde_sentence_embeddings_module_type")]
|
||||
pub module_type: SentenceEmbeddingsModuleType,
|
||||
}
|
||||
|
||||
/// Available module types, based on Sentence-Transformers
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum SentenceEmbeddingsModuleType {
|
||||
Transformer,
|
||||
Pooling,
|
||||
Dense,
|
||||
Normalize,
|
||||
}
|
||||
|
||||
mod serde_sentence_embeddings_module_type {
|
||||
use super::SentenceEmbeddingsModuleType;
|
||||
use serde::{de, Deserializer, Serializer};
|
||||
|
||||
pub fn serialize<S>(
|
||||
module_type: &SentenceEmbeddingsModuleType,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(&format!("sentence_transformers.models.{:?}", module_type))
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<SentenceEmbeddingsModuleType, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct SentenceEmbeddingsModuleTypeVisitor;
|
||||
|
||||
impl de::Visitor<'_> for SentenceEmbeddingsModuleTypeVisitor {
|
||||
type Value = SentenceEmbeddingsModuleType;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a sentence embeddings module type")
|
||||
}
|
||||
|
||||
fn visit_str<E: de::Error>(self, s: &str) -> Result<Self::Value, E> {
|
||||
s.split('.')
|
||||
.last()
|
||||
.map(|s| serde_json::from_value(serde_json::Value::String(s.to_string())))
|
||||
.transpose()
|
||||
.map_err(de::Error::custom)?
|
||||
.ok_or_else(|| format!("Invalid SentenceEmbeddingsModuleType: {}", s))
|
||||
.map_err(de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(SentenceEmbeddingsModuleTypeVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for Sentence-Transformers specific parameters
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SentenceEmbeddingsSentenceBertConfig {
|
||||
pub max_seq_length: usize,
|
||||
pub do_lower_case: bool,
|
||||
}
|
||||
|
||||
impl Config for SentenceEmbeddingsSentenceBertConfig {}
|
||||
|
||||
/// Configuration for transformer's tokenizer
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SentenceEmbeddingsTokenizerConfig {
|
||||
pub add_prefix_space: Option<bool>,
|
||||
pub strip_accents: Option<bool>,
|
||||
}
|
||||
|
||||
impl Config for SentenceEmbeddingsTokenizerConfig {}
|
151
src/pipelines/sentence_embeddings/layers.rs
Normal file
151
src/pipelines/sentence_embeddings/layers.rs
Normal file
@ -0,0 +1,151 @@
|
||||
use std::path::Path;
|
||||
|
||||
use serde::{de, Deserialize, Deserializer};
|
||||
use tch::{nn, Device, Kind, Tensor};
|
||||
|
||||
use crate::common::activations::{Activation, TensorFunction};
|
||||
use crate::{Config, RustBertError};
|
||||
|
||||
/// Configuration for [`Pooling`](Pooling) layer.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct PoolingConfig {
|
||||
/// Dimensions for the word embeddings
|
||||
pub word_embedding_dimension: i64,
|
||||
/// Use the first token (CLS token) as text representations
|
||||
pub pooling_mode_cls_token: bool,
|
||||
/// Use max in each dimension over all tokens
|
||||
pub pooling_mode_max_tokens: bool,
|
||||
/// Perform mean-pooling
|
||||
pub pooling_mode_mean_tokens: bool,
|
||||
/// Perform mean-pooling, but divide by sqrt(input_length)
|
||||
pub pooling_mode_mean_sqrt_len_tokens: bool,
|
||||
}
|
||||
|
||||
impl Config for PoolingConfig {}
|
||||
|
||||
/// Performs pooling (max or mean) on the token embeddings.
|
||||
///
|
||||
/// Using pooling, it generates from a variable sized sentence a fixed sized sentence
|
||||
/// embedding. You can concatenate multiple poolings together.
|
||||
pub struct Pooling {
|
||||
conf: PoolingConfig,
|
||||
}
|
||||
|
||||
impl Pooling {
|
||||
pub fn new(conf: PoolingConfig) -> Pooling {
|
||||
Pooling { conf }
|
||||
}
|
||||
|
||||
pub fn forward(&self, mut token_embeddings: Tensor, attention_mask: &Tensor) -> Tensor {
|
||||
let mut output_vectors = Vec::new();
|
||||
|
||||
if self.conf.pooling_mode_cls_token {
|
||||
let cls_token = token_embeddings.select(1, 0); // Take first token by default
|
||||
output_vectors.push(cls_token);
|
||||
}
|
||||
|
||||
if self.conf.pooling_mode_max_tokens {
|
||||
let input_mask_expanded = attention_mask.unsqueeze(-1).expand_as(&token_embeddings);
|
||||
// Set padding tokens to large negative value
|
||||
token_embeddings = token_embeddings.masked_fill_(&input_mask_expanded.eq(0), -1e9);
|
||||
let max_over_time = token_embeddings.max_dim(1, true).0;
|
||||
output_vectors.push(max_over_time);
|
||||
}
|
||||
|
||||
if self.conf.pooling_mode_mean_tokens || self.conf.pooling_mode_mean_sqrt_len_tokens {
|
||||
let input_mask_expanded = attention_mask.unsqueeze(-1).expand_as(&token_embeddings);
|
||||
let sum_embeddings =
|
||||
(token_embeddings * &input_mask_expanded).sum_dim_intlist(&[1], false, Kind::Float);
|
||||
|
||||
let sum_mask = input_mask_expanded.sum_dim_intlist(&[1], false, Kind::Float);
|
||||
let sum_mask = sum_mask.clamp_min(10e-9);
|
||||
|
||||
if self.conf.pooling_mode_mean_tokens {
|
||||
output_vectors.push(&sum_embeddings / &sum_mask);
|
||||
}
|
||||
if self.conf.pooling_mode_mean_sqrt_len_tokens {
|
||||
output_vectors.push(sum_embeddings / sum_mask.sqrt());
|
||||
}
|
||||
}
|
||||
|
||||
Tensor::cat(&output_vectors, 1)
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for [`Dense`](Dense) layer.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DenseConfig {
|
||||
/// Size of the input dimension
|
||||
pub in_features: i64,
|
||||
/// Output size
|
||||
pub out_features: i64,
|
||||
/// Add a bias vector
|
||||
pub bias: bool,
|
||||
/// Activation function applied on output
|
||||
#[serde(deserialize_with = "last_part")]
|
||||
pub activation_function: Activation,
|
||||
}
|
||||
|
||||
impl Config for DenseConfig {}
|
||||
|
||||
/// Split the given string on `.` and try to construct an `Activation` from the last part
|
||||
fn last_part<'de, D>(deserializer: D) -> Result<Activation, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let activation = String::deserialize(deserializer)?;
|
||||
activation
|
||||
.split('.')
|
||||
.last()
|
||||
.map(|s| serde_json::from_value(serde_json::Value::String(s.to_lowercase())))
|
||||
.transpose()
|
||||
.map_err(de::Error::custom)?
|
||||
.ok_or_else(|| format!("Invalid Activation: {}", activation))
|
||||
.map_err(de::Error::custom)
|
||||
}
|
||||
|
||||
/// Feed-forward function with activation function.
|
||||
///
|
||||
/// This layer takes a fixed-sized sentence embedding and passes it through a
|
||||
/// feed-forward layer. Can be used to generate deep averaging networks (DAN).
|
||||
pub struct Dense {
|
||||
linear: nn::Linear,
|
||||
activation: TensorFunction,
|
||||
_var_store: nn::VarStore,
|
||||
}
|
||||
|
||||
impl Dense {
|
||||
pub fn new<P: AsRef<Path>>(
|
||||
dense_conf: DenseConfig,
|
||||
dense_weights: P,
|
||||
device: Device,
|
||||
) -> Result<Dense, RustBertError> {
|
||||
let mut vs_dense = nn::VarStore::new(device);
|
||||
|
||||
let linear_conf = nn::LinearConfig {
|
||||
ws_init: nn::Init::Const(0.),
|
||||
bs_init: Some(nn::Init::Const(0.)),
|
||||
bias: dense_conf.bias,
|
||||
};
|
||||
let linear = nn::linear(
|
||||
&vs_dense.root(),
|
||||
dense_conf.in_features,
|
||||
dense_conf.out_features,
|
||||
linear_conf,
|
||||
);
|
||||
|
||||
let activation = dense_conf.activation_function.get_function();
|
||||
|
||||
vs_dense.load(dense_weights)?;
|
||||
|
||||
Ok(Dense {
|
||||
linear,
|
||||
activation,
|
||||
_var_store: vs_dense,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Tensor {
|
||||
self.activation.get_fn()(&x.apply(&self.linear))
|
||||
}
|
||||
}
|
64
src/pipelines/sentence_embeddings/mod.rs
Normal file
64
src/pipelines/sentence_embeddings/mod.rs
Normal file
@ -0,0 +1,64 @@
|
||||
//! # Sentence Embeddings pipeline
|
||||
//!
|
||||
//! Compute sentence/text embeddings that can be compared (e.g. with
|
||||
//! cosine-similarity) to find sentences with a similar meaning. This can be useful for
|
||||
//! semantic textual similarity, semantic search, or paraphrase mining.
|
||||
//!
|
||||
//! The implementation is based on [Sentence-Transformers][sbert] and pretrained models
|
||||
//! available on [Hugging Face Hub][sbert-hub] can be used. It's however necessary to
|
||||
//! convert them using the script `utils/convert_model.py` beforehand, see
|
||||
//! `tests/sentence_embeddings.rs` for such examples.
|
||||
//!
|
||||
//! [sbert]: https://sbert.net/
|
||||
//! [sbert-hub]: https://huggingface.co/sentence-transformers/
|
||||
//!
|
||||
//! Basic usage is as follows:
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;
|
||||
//!
|
||||
//! # fn main() -> anyhow::Result<()> {
|
||||
//! let model = SentenceEmbeddingsBuilder::local("local/path/to/distiluse-base-multilingual-cased")
|
||||
//! .with_device(tch::Device::cuda_if_available())
|
||||
//! .create_model()?;
|
||||
//!
|
||||
//! let sentences = ["This is an example sentence", "Each sentence is converted"];
|
||||
//! let embeddings = model.encode(&sentences)?;
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
pub mod builder;
|
||||
mod config;
|
||||
pub mod layers;
|
||||
mod pipeline;
|
||||
mod resources;
|
||||
|
||||
pub use builder::SentenceEmbeddingsBuilder;
|
||||
pub use config::{
|
||||
SentenceEmbeddingsConfig, SentenceEmbeddingsModuleConfig, SentenceEmbeddingsModuleType,
|
||||
SentenceEmbeddingsModulesConfig, SentenceEmbeddingsSentenceBertConfig,
|
||||
SentenceEmbeddingsTokenizerConfig,
|
||||
};
|
||||
pub use pipeline::{
|
||||
SentenceEmbeddingsModel, SentenceEmbeddingsModelOuput, SentenceEmbeddingsOption,
|
||||
SentenceEmbeddingsTokenizerOuput,
|
||||
};
|
||||
|
||||
pub use resources::{
|
||||
SentenceEmbeddingsConfigResources, SentenceEmbeddingsDenseConfigResources,
|
||||
SentenceEmbeddingsDenseResources, SentenceEmbeddingsModelType,
|
||||
SentenceEmbeddingsModulesConfigResources, SentenceEmbeddingsPoolingConfigResources,
|
||||
SentenceEmbeddingsTokenizerConfigResources,
|
||||
};
|
||||
|
||||
/// Length = sequence length
|
||||
pub type Attention = Vec<f32>;
|
||||
/// Length = sequence length
|
||||
pub type AttentionHead = Vec<Attention>;
|
||||
/// Length = number of heads per attention layer
|
||||
pub type AttentionLayer = Vec<AttentionHead>;
|
||||
/// Length = number of attention layers
|
||||
pub type AttentionOutput = Vec<AttentionLayer>;
|
||||
|
||||
pub type Embedding = Vec<f32>;
|
461
src/pipelines/sentence_embeddings/pipeline.rs
Normal file
461
src/pipelines/sentence_embeddings/pipeline.rs
Normal file
@ -0,0 +1,461 @@
|
||||
use std::borrow::Borrow;
|
||||
use std::convert::TryInto;
|
||||
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
use crate::albert::AlbertForSentenceEmbeddings;
|
||||
use crate::bert::BertForSentenceEmbeddings;
|
||||
use crate::distilbert::DistilBertForSentenceEmbeddings;
|
||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||
use crate::pipelines::sentence_embeddings::layers::{Dense, DenseConfig, Pooling, PoolingConfig};
|
||||
use crate::pipelines::sentence_embeddings::{
|
||||
AttentionHead, AttentionLayer, AttentionOutput, Embedding, SentenceEmbeddingsConfig,
|
||||
SentenceEmbeddingsModulesConfig, SentenceEmbeddingsSentenceBertConfig,
|
||||
SentenceEmbeddingsTokenizerConfig,
|
||||
};
|
||||
use crate::roberta::RobertaForSentenceEmbeddings;
|
||||
use crate::t5::T5ForSentenceEmbeddings;
|
||||
use crate::{Config, RustBertError};
|
||||
|
||||
/// # Abstraction that holds one particular sentence embeddings model, for any of the supported models
|
||||
pub enum SentenceEmbeddingsOption {
|
||||
/// Bert for Sentence Embeddings
|
||||
Bert(BertForSentenceEmbeddings),
|
||||
/// DistilBert for Sentence Embeddings
|
||||
DistilBert(DistilBertForSentenceEmbeddings),
|
||||
/// Roberta for Sentence Embeddings
|
||||
Roberta(RobertaForSentenceEmbeddings),
|
||||
/// Albert for Sentence Embeddings
|
||||
Albert(AlbertForSentenceEmbeddings),
|
||||
/// T5 for Sentence Embeddings
|
||||
T5(T5ForSentenceEmbeddings),
|
||||
}
|
||||
|
||||
impl SentenceEmbeddingsOption {
|
||||
/// Instantiate a new sentence embeddings transformer of the supplied type.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `transformer_type` - `ModelType` indicating the transformer model type to load (must match with the actual data to be loaded)
|
||||
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. rust_model.ot)
|
||||
/// * `config` - A configuration (the transformer model type of the configuration must be compatible with the value for `transformer_type`)
|
||||
pub fn new<'p, P>(
|
||||
transformer_type: ModelType,
|
||||
p: P,
|
||||
config: &ConfigOption,
|
||||
) -> Result<Self, RustBertError>
|
||||
where
|
||||
P: Borrow<nn::Path<'p>>,
|
||||
{
|
||||
use SentenceEmbeddingsOption::*;
|
||||
|
||||
let option = match transformer_type {
|
||||
ModelType::Bert => Bert(BertForSentenceEmbeddings::new(p, &(config.try_into()?))),
|
||||
ModelType::DistilBert => DistilBert(DistilBertForSentenceEmbeddings::new(
|
||||
p,
|
||||
&(config.try_into()?),
|
||||
)),
|
||||
ModelType::Roberta => Roberta(RobertaForSentenceEmbeddings::new_with_optional_pooler(
|
||||
p,
|
||||
&(config.try_into()?),
|
||||
false,
|
||||
)),
|
||||
ModelType::Albert => Albert(AlbertForSentenceEmbeddings::new(p, &(config.try_into()?))),
|
||||
ModelType::T5 => T5(T5ForSentenceEmbeddings::new(p, &(config.try_into()?))),
|
||||
_ => {
|
||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Unsupported transformer model {:?} for Sentence Embeddings",
|
||||
transformer_type
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(option)
|
||||
}
|
||||
|
||||
/// Interface method to forward() of the particular transformer models.
|
||||
pub fn forward(
|
||||
&self,
|
||||
tokens_ids: &Tensor,
|
||||
tokens_masks: &Tensor,
|
||||
) -> Result<(Tensor, Option<Vec<Tensor>>), RustBertError> {
|
||||
match self {
|
||||
Self::Bert(transformer) => transformer
|
||||
.forward_t(
|
||||
Some(tokens_ids),
|
||||
Some(tokens_masks),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.map(|transformer_output| {
|
||||
(
|
||||
transformer_output.hidden_state,
|
||||
transformer_output.all_attentions,
|
||||
)
|
||||
}),
|
||||
Self::DistilBert(transformer) => transformer
|
||||
.forward_t(Some(tokens_ids), Some(tokens_masks), None, false)
|
||||
.map(|transformer_output| {
|
||||
(
|
||||
transformer_output.hidden_state,
|
||||
transformer_output.all_attentions,
|
||||
)
|
||||
}),
|
||||
Self::Roberta(transformer) => transformer
|
||||
.forward_t(
|
||||
Some(tokens_ids),
|
||||
Some(tokens_masks),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.map(|transformer_output| {
|
||||
(
|
||||
transformer_output.hidden_state,
|
||||
transformer_output.all_attentions,
|
||||
)
|
||||
}),
|
||||
Self::Albert(transformer) => transformer
|
||||
.forward_t(
|
||||
Some(tokens_ids),
|
||||
Some(tokens_masks),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.map(|transformer_output| {
|
||||
(
|
||||
transformer_output.hidden_state,
|
||||
transformer_output.all_attentions.map(|attentions| {
|
||||
attentions
|
||||
.into_iter()
|
||||
.map(|tensors| {
|
||||
let num_inner_groups = tensors.len() as f64;
|
||||
tensors.into_iter().sum::<Tensor>() / num_inner_groups
|
||||
})
|
||||
.collect()
|
||||
}),
|
||||
)
|
||||
}),
|
||||
Self::T5(transformer) => transformer.forward(tokens_ids, tokens_masks),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # SentenceEmbeddingsModel to perform sentence embeddings
|
||||
///
|
||||
/// It is made of the following blocks:
|
||||
/// - `transformer`: Base transformer model
|
||||
/// - `pooling`: Pooling layer
|
||||
/// - `dense` _(optional)_: Linear (feed forward) layer
|
||||
/// - `normalization` _(optional)_: Embeddings normalization
|
||||
pub struct SentenceEmbeddingsModel {
|
||||
sentence_bert_config: SentenceEmbeddingsSentenceBertConfig,
|
||||
tokenizer: TokenizerOption,
|
||||
tokenizer_truncation_strategy: TruncationStrategy,
|
||||
var_store: nn::VarStore,
|
||||
transformer: SentenceEmbeddingsOption,
|
||||
transformer_config: ConfigOption,
|
||||
pooling_layer: Pooling,
|
||||
dense_layer: Option<Dense>,
|
||||
normalize_embeddings: bool,
|
||||
}
|
||||
|
||||
impl SentenceEmbeddingsModel {
|
||||
/// Build a new `SentenceEmbeddingsModel`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - `SentenceEmbeddingsConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
|
||||
pub fn new(config: SentenceEmbeddingsConfig) -> Result<Self, RustBertError> {
|
||||
let SentenceEmbeddingsConfig {
|
||||
modules_config_resource,
|
||||
sentence_bert_config_resource,
|
||||
tokenizer_config_resource,
|
||||
tokenizer_vocab_resource,
|
||||
tokenizer_merges_resource,
|
||||
transformer_type,
|
||||
transformer_config_resource,
|
||||
transformer_weights_resource,
|
||||
pooling_config_resource,
|
||||
dense_config_resource,
|
||||
dense_weights_resource,
|
||||
device,
|
||||
} = config;
|
||||
|
||||
let modules =
|
||||
SentenceEmbeddingsModulesConfig::from_file(modules_config_resource.get_local_path()?)
|
||||
.validate()?;
|
||||
|
||||
// Setup tokenizer
|
||||
|
||||
let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
|
||||
tokenizer_config_resource.get_local_path()?,
|
||||
);
|
||||
let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
|
||||
sentence_bert_config_resource.get_local_path()?,
|
||||
);
|
||||
let tokenizer = TokenizerOption::from_file(
|
||||
transformer_type,
|
||||
tokenizer_vocab_resource
|
||||
.get_local_path()?
|
||||
.to_string_lossy()
|
||||
.as_ref(),
|
||||
tokenizer_merges_resource
|
||||
.as_ref()
|
||||
.map(|resource| resource.get_local_path())
|
||||
.transpose()?
|
||||
.map(|path| path.to_string_lossy().into_owned())
|
||||
.as_deref(),
|
||||
sentence_bert_config.do_lower_case,
|
||||
tokenizer_config.strip_accents,
|
||||
tokenizer_config.add_prefix_space,
|
||||
)?;
|
||||
|
||||
// Setup transformer
|
||||
|
||||
let mut var_store = nn::VarStore::new(device);
|
||||
let transformer_config = ConfigOption::from_file(
|
||||
transformer_type,
|
||||
transformer_config_resource.get_local_path()?,
|
||||
);
|
||||
let transformer = SentenceEmbeddingsOption::new(
|
||||
transformer_type,
|
||||
&var_store.root(),
|
||||
&transformer_config,
|
||||
)?;
|
||||
var_store.load(transformer_weights_resource.get_local_path()?)?;
|
||||
|
||||
// Setup pooling layer
|
||||
|
||||
let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?);
|
||||
let pooling_layer = Pooling::new(pooling_config);
|
||||
|
||||
// Setup dense layer
|
||||
|
||||
let dense_layer = if modules.dense_module().is_some() {
|
||||
let dense_config =
|
||||
DenseConfig::from_file(dense_config_resource.unwrap().get_local_path()?);
|
||||
Some(Dense::new(
|
||||
dense_config,
|
||||
dense_weights_resource.unwrap().get_local_path()?,
|
||||
device,
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let normalize_embeddings = modules.has_normalization();
|
||||
|
||||
Ok(Self {
|
||||
tokenizer,
|
||||
sentence_bert_config,
|
||||
tokenizer_truncation_strategy: TruncationStrategy::LongestFirst,
|
||||
var_store,
|
||||
transformer,
|
||||
transformer_config,
|
||||
pooling_layer,
|
||||
dense_layer,
|
||||
normalize_embeddings,
|
||||
})
|
||||
}
|
||||
|
||||
/// Sets the tokenizer's truncation strategy
|
||||
pub fn set_tokenizer_truncation(&mut self, truncation_strategy: TruncationStrategy) {
|
||||
self.tokenizer_truncation_strategy = truncation_strategy;
|
||||
}
|
||||
|
||||
/// Tokenizes the inputs
|
||||
pub fn tokenize<S>(&self, inputs: &[S]) -> SentenceEmbeddingsTokenizerOuput
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokenized_input = self.tokenizer.encode_list(
|
||||
inputs,
|
||||
self.sentence_bert_config.max_seq_length,
|
||||
&self.tokenizer_truncation_strategy,
|
||||
0,
|
||||
);
|
||||
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
let pad_token_id = self.tokenizer.get_pad_id().unwrap_or(0);
|
||||
let tokens_ids = tokenized_input
|
||||
.into_iter()
|
||||
.map(|input| {
|
||||
let mut token_ids = input.token_ids;
|
||||
token_ids.extend(vec![pad_token_id; max_len - token_ids.len()]);
|
||||
token_ids
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let tokens_masks = tokens_ids
|
||||
.iter()
|
||||
.map(|input| {
|
||||
Tensor::of_slice(
|
||||
&input
|
||||
.iter()
|
||||
.map(|&e| if e == pad_token_id { 0_i64 } else { 1_i64 })
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let tokens_ids = tokens_ids
|
||||
.into_iter()
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
SentenceEmbeddingsTokenizerOuput {
|
||||
tokens_ids,
|
||||
tokens_masks,
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes sentence embeddings, outputs `Tensor`.
|
||||
pub fn encode_as_tensor<S>(
|
||||
&self,
|
||||
inputs: &[S],
|
||||
) -> Result<SentenceEmbeddingsModelOuput, RustBertError>
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let SentenceEmbeddingsTokenizerOuput {
|
||||
tokens_ids,
|
||||
tokens_masks,
|
||||
} = self.tokenize(inputs);
|
||||
let tokens_ids = Tensor::stack(&tokens_ids, 0).to(self.var_store.device());
|
||||
let tokens_masks = Tensor::stack(&tokens_masks, 0).to(self.var_store.device());
|
||||
|
||||
let (tokens_embeddings, all_attentions) =
|
||||
tch::no_grad(|| self.transformer.forward(&tokens_ids, &tokens_masks))?;
|
||||
|
||||
let mean_pool =
|
||||
tch::no_grad(|| self.pooling_layer.forward(tokens_embeddings, &tokens_masks));
|
||||
let maybe_linear = if let Some(dense_layer) = &self.dense_layer {
|
||||
tch::no_grad(|| dense_layer.forward(&mean_pool))
|
||||
} else {
|
||||
mean_pool
|
||||
};
|
||||
let maybe_normalized = if self.normalize_embeddings {
|
||||
let norm = &maybe_linear
|
||||
.norm_scalaropt_dim(2, &[1], true)
|
||||
.clamp_min(1e-12)
|
||||
.expand_as(&maybe_linear);
|
||||
maybe_linear / norm
|
||||
} else {
|
||||
maybe_linear
|
||||
};
|
||||
|
||||
Ok(SentenceEmbeddingsModelOuput {
|
||||
embeddings: maybe_normalized,
|
||||
all_attentions,
|
||||
})
|
||||
}
|
||||
|
||||
/// Computes sentence embeddings.
|
||||
pub fn encode<S>(&self, inputs: &[S]) -> Result<Vec<Embedding>, RustBertError>
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let SentenceEmbeddingsModelOuput { embeddings, .. } = self.encode_as_tensor(inputs)?;
|
||||
Ok(Vec::from(embeddings))
|
||||
}
|
||||
|
||||
fn nb_layers(&self) -> usize {
|
||||
use SentenceEmbeddingsOption::*;
|
||||
match (&self.transformer, &self.transformer_config) {
|
||||
(Bert(_), ConfigOption::Bert(conf)) => conf.num_hidden_layers as usize,
|
||||
(Bert(_), _) => unreachable!(),
|
||||
(DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_layers as usize,
|
||||
(DistilBert(_), _) => unreachable!(),
|
||||
(Roberta(_), ConfigOption::Bert(conf)) => conf.num_hidden_layers as usize,
|
||||
(Roberta(_), _) => unreachable!(),
|
||||
(Albert(_), ConfigOption::Albert(conf)) => conf.num_hidden_layers as usize,
|
||||
(Albert(_), _) => unreachable!(),
|
||||
(T5(_), ConfigOption::T5(conf)) => conf.num_layers as usize,
|
||||
(T5(_), _) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn nb_heads(&self) -> usize {
|
||||
use SentenceEmbeddingsOption::*;
|
||||
match (&self.transformer, &self.transformer_config) {
|
||||
(Bert(_), ConfigOption::Bert(conf)) => conf.num_attention_heads as usize,
|
||||
(Bert(_), _) => unreachable!(),
|
||||
(DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_heads as usize,
|
||||
(DistilBert(_), _) => unreachable!(),
|
||||
(Roberta(_), ConfigOption::Roberta(conf)) => conf.num_attention_heads as usize,
|
||||
(Roberta(_), _) => unreachable!(),
|
||||
(Albert(_), ConfigOption::Albert(conf)) => conf.num_attention_heads as usize,
|
||||
(Albert(_), _) => unreachable!(),
|
||||
(T5(_), ConfigOption::T5(conf)) => conf.num_heads as usize,
|
||||
(T5(_), _) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes sentence embeddings, also outputs `AttentionOutput`s.
|
||||
pub fn encode_with_attention<S>(
|
||||
&self,
|
||||
inputs: &[S],
|
||||
) -> Result<(Vec<Embedding>, Vec<AttentionOutput>), RustBertError>
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let SentenceEmbeddingsModelOuput {
|
||||
embeddings,
|
||||
all_attentions,
|
||||
} = self.encode_as_tensor(inputs)?;
|
||||
|
||||
let embeddings = Vec::from(embeddings);
|
||||
let all_attentions = all_attentions.ok_or_else(|| {
|
||||
RustBertError::InvalidConfigurationError("No attention outputted".into())
|
||||
})?;
|
||||
|
||||
let attention_outputs = (0..inputs.len() as i64)
|
||||
.map(|i| {
|
||||
let mut attention_output = AttentionOutput::with_capacity(self.nb_layers());
|
||||
for layer in all_attentions.iter() {
|
||||
let mut attention_layer = AttentionLayer::with_capacity(self.nb_heads());
|
||||
for head in 0..self.nb_heads() {
|
||||
let attention_slice = layer
|
||||
.slice(0, i, i + 1, 1)
|
||||
.slice(1, head as i64, head as i64 + 1, 1)
|
||||
.squeeze();
|
||||
let attention_head = AttentionHead::from(attention_slice);
|
||||
attention_layer.push(attention_head);
|
||||
}
|
||||
attention_output.push(attention_layer);
|
||||
}
|
||||
attention_output
|
||||
})
|
||||
.collect::<Vec<AttentionOutput>>();
|
||||
|
||||
Ok((embeddings, attention_outputs))
|
||||
}
|
||||
}
|
||||
|
||||
/// Container for the SentenceEmbeddings tokenizer output.
|
||||
pub struct SentenceEmbeddingsTokenizerOuput {
|
||||
pub tokens_ids: Vec<Tensor>,
|
||||
pub tokens_masks: Vec<Tensor>,
|
||||
}
|
||||
|
||||
/// Container for the SentenceEmbeddings model output.
|
||||
pub struct SentenceEmbeddingsModelOuput {
|
||||
pub embeddings: Tensor,
|
||||
pub all_attentions: Option<Vec<Tensor>>,
|
||||
}
|
184
src/pipelines/sentence_embeddings/resources.rs
Normal file
184
src/pipelines/sentence_embeddings/resources.rs
Normal file
@ -0,0 +1,184 @@
|
||||
/// # Pretrained modules config files for sentence embeddings
|
||||
pub struct SentenceEmbeddingsModulesConfigResources;
|
||||
|
||||
/// # Pretrained dense weights files for sentence embeddings
|
||||
pub struct SentenceEmbeddingsDenseResources;
|
||||
|
||||
/// # Pretrained dense config files for sentence embeddings
|
||||
pub struct SentenceEmbeddingsDenseConfigResources;
|
||||
|
||||
/// # Pretrained pooling config files for sentence embeddings
|
||||
pub struct SentenceEmbeddingsPoolingConfigResources;
|
||||
|
||||
/// # Pretrained config files for sentence embeddings
|
||||
pub struct SentenceEmbeddingsConfigResources;
|
||||
|
||||
/// # Pretrained tokenizer config files for sentence embeddings
|
||||
pub struct SentenceEmbeddingsTokenizerConfigResources;
|
||||
|
||||
pub enum SentenceEmbeddingsModelType {
|
||||
DistiluseBaseMultilingualCased,
|
||||
BertBaseNliMeanTokens,
|
||||
AllMiniLmL12V2,
|
||||
AllDistilrobertaV1,
|
||||
ParaphraseAlbertSmallV2,
|
||||
SentenceT5Base,
|
||||
}
|
||||
|
||||
impl SentenceEmbeddingsModulesConfigResources {
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
|
||||
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
|
||||
"distiluse-base-multilingual-cased/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/modules.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
|
||||
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
|
||||
"bert-base-nli-mean-tokens/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/modules.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
|
||||
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
|
||||
"all-mini-lm-l12-v2/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/modules.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
|
||||
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
|
||||
"all-distilroberta-v1/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/modules.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
|
||||
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
|
||||
"paraphrase-albert-small-v2/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/modules.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
|
||||
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
|
||||
"sentence-t5-base/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/modules.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl SentenceEmbeddingsDenseResources {
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
|
||||
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
|
||||
"distiluse-base-multilingual-cased/sbert-dense",
|
||||
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/2_Dense/rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
|
||||
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
|
||||
"sentence-t5-base/sbert-dense",
|
||||
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/2_Dense/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl SentenceEmbeddingsDenseConfigResources {
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
|
||||
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
|
||||
"distiluse-base-multilingual-cased/sbert-dense-config",
|
||||
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/2_Dense/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
|
||||
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
|
||||
"sentence-t5-base/sbert-dense-config",
|
||||
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/2_Dense/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl SentenceEmbeddingsPoolingConfigResources {
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
|
||||
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
|
||||
"distiluse-base-multilingual-cased/sbert-pooling-config",
|
||||
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/1_Pooling/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
|
||||
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
|
||||
"bert-base-nli-mean-tokens/sbert-pooling-config",
|
||||
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/1_Pooling/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
|
||||
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
|
||||
"all-mini-lm-l12-v2/sbert-pooling-config",
|
||||
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/1_Pooling/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
|
||||
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
|
||||
"all-distilroberta-v1/sbert-pooling-config",
|
||||
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/1_Pooling/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
|
||||
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
|
||||
"paraphrase-albert-small-v2/sbert-pooling-config",
|
||||
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/1_Pooling/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
|
||||
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
|
||||
"sentence-t5-base/sbert-pooling-config",
|
||||
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/1_Pooling/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl SentenceEmbeddingsConfigResources {
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
|
||||
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
|
||||
"distiluse-base-multilingual-cased/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/sentence_bert_config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
|
||||
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
|
||||
"bert-base-nli-mean-tokens/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/sentence_bert_config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
|
||||
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
|
||||
"all-mini-lm-l12-v2/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/sentence_bert_config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
|
||||
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
|
||||
"all-distilroberta-v1/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/sentence_bert_config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
|
||||
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
|
||||
"paraphrase-albert-small-v2/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/sentence_bert_config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
|
||||
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
|
||||
"sentence-t5-base/sbert-config",
|
||||
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/sentence_bert_config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl SentenceEmbeddingsTokenizerConfigResources {
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased>. Modified with conversion to C-array format.
|
||||
pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
|
||||
"distiluse-base-multilingual-cased/tokenizer-config",
|
||||
"https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/tokenizer_config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
|
||||
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
|
||||
"bert-base-nli-mean-tokens/tokenizer-config",
|
||||
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/tokenizer_config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
|
||||
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
|
||||
"all-mini-lm-l12-v2/tokenizer-config",
|
||||
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/tokenizer_config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
|
||||
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
|
||||
"all-distilroberta-v1/tokenizer-config",
|
||||
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/tokenizer_config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2>. Modified with conversion to C-array format.
|
||||
pub const PARAPHRASE_ALBERT_SMALL_V2: (&'static str, &'static str) = (
|
||||
"paraphrase-albert-small-v2/tokenizer-config",
|
||||
"https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2/resolve/main/tokenizer_config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
|
||||
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
|
||||
"sentence-t5-base/tokenizer-config",
|
||||
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/tokenizer_config.json",
|
||||
);
|
||||
}
|
@ -67,6 +67,7 @@ mod roberta_model;
|
||||
pub use embeddings::RobertaEmbeddings;
|
||||
pub use roberta_model::{
|
||||
RobertaConfig, RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
|
||||
RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
|
||||
RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
|
||||
RobertaForQuestionAnswering, RobertaForSentenceEmbeddings, RobertaForSequenceClassification,
|
||||
RobertaForTokenClassification, RobertaMergesResources, RobertaModelResources,
|
||||
RobertaVocabResources,
|
||||
};
|
||||
|
@ -68,6 +68,11 @@ impl RobertaModelResources {
|
||||
"xlm-roberta-ner-es/model",
|
||||
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
|
||||
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
|
||||
"all-distilroberta-v1/model",
|
||||
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl RobertaConfigResources {
|
||||
@ -106,6 +111,11 @@ impl RobertaConfigResources {
|
||||
"xlm-roberta-ner-es/config",
|
||||
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
|
||||
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
|
||||
"all-distilroberta-v1/config",
|
||||
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl RobertaVocabResources {
|
||||
@ -144,6 +154,11 @@ impl RobertaVocabResources {
|
||||
"xlm-roberta-ner-es/spiece",
|
||||
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
|
||||
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
|
||||
"all-distilroberta-v1/vocab",
|
||||
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/vocab.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl RobertaMergesResources {
|
||||
@ -162,6 +177,11 @@ impl RobertaMergesResources {
|
||||
"roberta-qa/merges",
|
||||
"https://huggingface.co/deepset/roberta-base-squad2/resolve/main/merges.txt",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
|
||||
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
|
||||
"all-distilroberta-v1/merges",
|
||||
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/merges.txt",
|
||||
);
|
||||
}
|
||||
|
||||
pub struct RobertaLMHead {
|
||||
@ -972,6 +992,10 @@ impl RobertaForQuestionAnswering {
|
||||
}
|
||||
}
|
||||
|
||||
/// # RoBERTa for sentence embeddings
|
||||
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
|
||||
pub type RobertaForSentenceEmbeddings = BertModel<RobertaEmbeddings>;
|
||||
|
||||
/// Container for the RoBERTa masked LM model output.
|
||||
pub struct RobertaMaskedLMOutput {
|
||||
/// Logits for the vocabulary items at each sequence position
|
||||
|
@ -41,7 +41,7 @@
|
||||
//! let mut vs = nn::VarStore::new(device);
|
||||
//! let tokenizer = T5Tokenizer::from_file(spiece_path.to_str().unwrap(), true);
|
||||
//! let config = T5Config::from_file(config_path);
|
||||
//! let t5_model = T5ForConditionalGeneration::new(&vs.root(), &config, false, false);
|
||||
//! let t5_model = T5ForConditionalGeneration::new(&vs.root(), &config);
|
||||
//! vs.load(weights_path)?;
|
||||
//!
|
||||
//! # Ok(())
|
||||
@ -55,6 +55,7 @@ mod t5_model;
|
||||
|
||||
pub use attention::LayerState;
|
||||
pub use t5_model::{
|
||||
T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Generator, T5Model, T5ModelOutput,
|
||||
T5ModelResources, T5Prefix, T5SourceLanguages, T5TargetLanguages, T5VocabResources,
|
||||
T5Config, T5ConfigResources, T5ForConditionalGeneration, T5ForSentenceEmbeddings, T5Generator,
|
||||
T5Model, T5ModelOutput, T5ModelResources, T5Prefix, T5SourceLanguages, T5TargetLanguages,
|
||||
T5VocabResources,
|
||||
};
|
||||
|
@ -59,6 +59,11 @@ impl T5ModelResources {
|
||||
"t5-base/model",
|
||||
"https://huggingface.co/t5-base/resolve/main/rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
|
||||
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
|
||||
"sentence-t5-base/model",
|
||||
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl T5ConfigResources {
|
||||
@ -72,6 +77,11 @@ impl T5ConfigResources {
|
||||
"t5-base/config",
|
||||
"https://huggingface.co/t5-base/resolve/main/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
|
||||
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
|
||||
"sentence-t5-base/config",
|
||||
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl T5VocabResources {
|
||||
@ -85,6 +95,11 @@ impl T5VocabResources {
|
||||
"t5-base/spiece",
|
||||
"https://huggingface.co/t5-base/resolve/main/spiece.model",
|
||||
);
|
||||
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/sentence-t5-base>. Modified with conversion to C-array format.
|
||||
pub const SENTENCE_T5_BASE: (&'static str, &'static str) = (
|
||||
"sentence-t5-base/spiece",
|
||||
"https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model",
|
||||
);
|
||||
}
|
||||
|
||||
const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German];
|
||||
@ -133,6 +148,8 @@ pub struct T5Config {
|
||||
pub feed_forward_proj: Option<FeedForwardProj>,
|
||||
pub tie_word_embeddings: Option<bool>,
|
||||
task_specific_params: Option<TaskSpecificParams>,
|
||||
pub output_attentions: Option<bool>,
|
||||
pub output_hidden_states: Option<bool>,
|
||||
}
|
||||
|
||||
/// # T5 task-specific configurations
|
||||
@ -209,6 +226,8 @@ impl Default for T5Config {
|
||||
feed_forward_proj: Some(FeedForwardProj::Relu),
|
||||
tie_word_embeddings: None,
|
||||
task_specific_params: None,
|
||||
output_attentions: None,
|
||||
output_hidden_states: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -233,8 +252,6 @@ impl T5Model {
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the T5 model
|
||||
/// * `config` - `T5Config` object defining the model architecture
|
||||
/// * `output_attention` - flag indicating if the model should output the attention weights of intermediate layers
|
||||
/// * `output_hidden_states` - flag indicating if the model should output the hidden states weights of intermediate layers
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -248,21 +265,9 @@ impl T5Model {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = T5Config::from_file(config_path);
|
||||
/// let output_attentions = true;
|
||||
/// let output_hidden_states = true;
|
||||
/// let t5: T5Model = T5Model::new(
|
||||
/// &p.root() / "t5",
|
||||
/// &config,
|
||||
/// output_attentions,
|
||||
/// output_hidden_states,
|
||||
/// );
|
||||
/// let t5: T5Model = T5Model::new(&p.root() / "t5", &config);
|
||||
/// ```
|
||||
pub fn new<'p, P>(
|
||||
p: P,
|
||||
config: &T5Config,
|
||||
output_attentions: bool,
|
||||
output_hidden_states: bool,
|
||||
) -> T5Model
|
||||
pub fn new<'p, P>(p: P, config: &T5Config) -> T5Model
|
||||
where
|
||||
P: Borrow<nn::Path<'p>>,
|
||||
{
|
||||
@ -280,16 +285,16 @@ impl T5Model {
|
||||
config,
|
||||
false,
|
||||
false,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
config.output_attentions.unwrap_or(false),
|
||||
config.output_hidden_states.unwrap_or(false),
|
||||
);
|
||||
let decoder = T5Stack::new(
|
||||
p / "decoder",
|
||||
config,
|
||||
true,
|
||||
true,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
config.output_attentions.unwrap_or(false),
|
||||
config.output_hidden_states.unwrap_or(false),
|
||||
);
|
||||
|
||||
T5Model {
|
||||
@ -338,7 +343,7 @@ impl T5Model {
|
||||
/// # let device = Device::Cpu;
|
||||
/// # let vs = nn::VarStore::new(device);
|
||||
/// # let config = T5Config::from_file(config_path);
|
||||
/// # let t5_model: T5Model = T5Model::new(&vs.root(), &config, false, false);
|
||||
/// # let t5_model: T5Model = T5Model::new(&vs.root(), &config);
|
||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||
@ -450,8 +455,6 @@ impl T5ForConditionalGeneration {
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the T5 model
|
||||
/// * `config` - `T5Config` object defining the model architecture
|
||||
/// * `output_attention` - flag indicating if the model should output the attention weights of intermediate layers
|
||||
/// * `output_hidden_states` - flag indicating if the model should output the hidden states weights of intermediate layers
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -465,27 +468,15 @@ impl T5ForConditionalGeneration {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = T5Config::from_file(config_path);
|
||||
/// let output_attentions = true;
|
||||
/// let output_hidden_states = true;
|
||||
/// let t5 = T5ForConditionalGeneration::new(
|
||||
/// &p.root() / "t5",
|
||||
/// &config,
|
||||
/// output_attentions,
|
||||
/// output_hidden_states,
|
||||
/// );
|
||||
/// let t5 = T5ForConditionalGeneration::new(&p.root() / "t5", &config);
|
||||
/// ```
|
||||
pub fn new<'p, P>(
|
||||
p: P,
|
||||
config: &T5Config,
|
||||
output_attentions: bool,
|
||||
output_hidden_states: bool,
|
||||
) -> T5ForConditionalGeneration
|
||||
pub fn new<'p, P>(p: P, config: &T5Config) -> T5ForConditionalGeneration
|
||||
where
|
||||
P: Borrow<nn::Path<'p>>,
|
||||
{
|
||||
let p = p.borrow();
|
||||
|
||||
let base_model = T5Model::new(p, config, output_attentions, output_hidden_states);
|
||||
let base_model = T5Model::new(p, config);
|
||||
let tie_word_embeddings = config.tie_word_embeddings.unwrap_or(true);
|
||||
|
||||
let lm_head = if !tie_word_embeddings {
|
||||
@ -549,7 +540,7 @@ impl T5ForConditionalGeneration {
|
||||
/// # let device = Device::Cpu;
|
||||
/// # let vs = nn::VarStore::new(device);
|
||||
/// # let config = T5Config::from_file(config_path);
|
||||
/// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config, false, false);
|
||||
/// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config);
|
||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||
@ -666,7 +657,7 @@ impl LMHeadModel for T5ForConditionalGeneration {
|
||||
/// # let device = Device::Cpu;
|
||||
/// # let vs = nn::VarStore::new(device);
|
||||
/// # let config = T5Config::from_file(config_path);
|
||||
/// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config, false, false);
|
||||
/// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config);
|
||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||
@ -749,6 +740,84 @@ impl LMHeadModel for T5ForConditionalGeneration {
|
||||
}
|
||||
}
|
||||
|
||||
/// # T5 for sentence embeddings
|
||||
/// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel).
|
||||
pub struct T5ForSentenceEmbeddings {
|
||||
embeddings: nn::Embedding,
|
||||
encoder: T5Stack,
|
||||
}
|
||||
|
||||
impl T5ForSentenceEmbeddings {
|
||||
/// Build a new `T5ForSentenceEmbeddings`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the T5 model
|
||||
/// * `config` - `T5Config` object defining the model architecture
|
||||
///
|
||||
/// It consists of only an encoder (there is no decoder).
|
||||
pub fn new<'p, P>(p: P, config: &T5Config) -> Self
|
||||
where
|
||||
P: Borrow<nn::Path<'p>>,
|
||||
{
|
||||
let p = p.borrow();
|
||||
|
||||
let embeddings: nn::Embedding = embedding(
|
||||
p / "shared",
|
||||
config.vocab_size,
|
||||
config.d_model,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let encoder = T5Stack::new(
|
||||
p / "encoder",
|
||||
config,
|
||||
false,
|
||||
false,
|
||||
config.output_attentions.unwrap_or(false),
|
||||
config.output_hidden_states.unwrap_or(false),
|
||||
);
|
||||
|
||||
Self {
|
||||
embeddings,
|
||||
encoder,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input_ids` - Input of shape (*batch size*, *source_sequence_length*).
|
||||
/// * `mask` - Attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * Tuple containing:
|
||||
/// - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
|
||||
/// - `Option<Vec<Tensor>>` of length *num_encoder_layers* of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing attention weights for all layers of the encoder
|
||||
pub fn forward(
|
||||
&self,
|
||||
input_ids: &Tensor,
|
||||
mask: &Tensor,
|
||||
) -> Result<(Tensor, Option<Vec<Tensor>>), RustBertError> {
|
||||
let transformer_output = self.encoder.forward_t(
|
||||
Some(input_ids),
|
||||
Some(mask),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&self.embeddings,
|
||||
None,
|
||||
false,
|
||||
)?;
|
||||
Ok((
|
||||
transformer_output.hidden_state,
|
||||
transformer_output.all_attentions,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Container holding a T5 model output. The decoder output may hold the hidden state of
|
||||
/// the last layer of the decoder, or may hold logits for a custom head module after the
|
||||
/// decoder (e.g. for language modeling tasks)
|
||||
@ -812,7 +881,7 @@ impl T5Generator {
|
||||
let mut var_store = nn::VarStore::new(device);
|
||||
|
||||
let config = T5Config::from_file(config_path);
|
||||
let model = T5ForConditionalGeneration::new(&var_store.root(), &config, false, false);
|
||||
let model = T5ForConditionalGeneration::new(&var_store.root(), &config);
|
||||
var_store.load(weights_path)?;
|
||||
|
||||
let bos_token_id = Some(config.bos_token_id.unwrap_or(-1));
|
||||
|
157
tests/sentence_embeddings.rs
Normal file
157
tests/sentence_embeddings.rs
Normal file
@ -0,0 +1,157 @@
|
||||
use rust_bert::pipelines::sentence_embeddings::{
|
||||
SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn sbert_distilbert() -> anyhow::Result<()> {
|
||||
let model = SentenceEmbeddingsBuilder::remote(
|
||||
SentenceEmbeddingsModelType::DistiluseBaseMultilingualCased,
|
||||
)
|
||||
.create_model()?;
|
||||
|
||||
let sentences = ["This is an example sentence", "Each sentence is converted"];
|
||||
let embeddings = model.encode(&sentences)?;
|
||||
|
||||
assert!((embeddings[0][0] as f64 - -0.03479306).abs() < 1e-4);
|
||||
assert!((embeddings[0][1] as f64 - 0.02635195).abs() < 1e-4);
|
||||
assert!((embeddings[0][2] as f64 - -0.04427199).abs() < 1e-4);
|
||||
assert!((embeddings[0][509] as f64 - 0.01743882).abs() < 1e-4);
|
||||
assert!((embeddings[0][510] as f64 - -0.01952395).abs() < 1e-4);
|
||||
assert!((embeddings[0][511] as f64 - -0.00118101).abs() < 1e-4);
|
||||
|
||||
assert!((embeddings[1][0] as f64 - 0.02096637).abs() < 1e-4);
|
||||
assert!((embeddings[1][1] as f64 - -0.00401743).abs() < 1e-4);
|
||||
assert!((embeddings[1][2] as f64 - -0.05093712).abs() < 1e-4);
|
||||
assert!((embeddings[1][509] as f64 - 0.03618195).abs() < 1e-4);
|
||||
assert!((embeddings[1][510] as f64 - 0.0294408).abs() < 1e-4);
|
||||
assert!((embeddings[1][511] as f64 - -0.04497765).abs() < 1e-4);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbert_bert() -> anyhow::Result<()> {
|
||||
let model =
|
||||
SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::BertBaseNliMeanTokens)
|
||||
.create_model()?;
|
||||
|
||||
let sentences = ["this is an example sentence", "each sentence is converted"];
|
||||
let embeddings = model.encode(&sentences)?;
|
||||
|
||||
assert!((embeddings[0][0] as f64 - -0.393099815).abs() < 1e-4);
|
||||
assert!((embeddings[0][1] as f64 - 0.0388629436).abs() < 1e-4);
|
||||
assert!((embeddings[0][2] as f64 - 1.98742473).abs() < 1e-4);
|
||||
assert!((embeddings[0][765] as f64 - -0.609367728).abs() < 1e-4);
|
||||
assert!((embeddings[0][766] as f64 - -1.09462142).abs() < 1e-4);
|
||||
assert!((embeddings[0][767] as f64 - 0.326490253).abs() < 1e-4);
|
||||
|
||||
assert!((embeddings[1][0] as f64 - 0.0615336187).abs() < 1e-4);
|
||||
assert!((embeddings[1][1] as f64 - 0.32736221).abs() < 1e-4);
|
||||
assert!((embeddings[1][2] as f64 - 1.8332324).abs() < 1e-4);
|
||||
assert!((embeddings[1][765] as f64 - -0.129853949).abs() < 1e-4);
|
||||
assert!((embeddings[1][766] as f64 - 0.460893631).abs() < 1e-4);
|
||||
assert!((embeddings[1][767] as f64 - 0.240354523).abs() < 1e-4);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbert_bert_small() -> anyhow::Result<()> {
|
||||
let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2)
|
||||
.create_model()?;
|
||||
|
||||
let sentences = ["this is an example sentence", "each sentence is converted"];
|
||||
let embeddings = model.encode(&sentences)?;
|
||||
|
||||
assert!((embeddings[0][0] as f64 - -2.02682902e-04).abs() < 1e-4);
|
||||
assert!((embeddings[0][1] as f64 - 8.14802647e-02).abs() < 1e-4);
|
||||
assert!((embeddings[0][2] as f64 - 3.13617811e-02).abs() < 1e-4);
|
||||
assert!((embeddings[0][381] as f64 - 6.20930083e-02).abs() < 1e-4);
|
||||
assert!((embeddings[0][382] as f64 - 4.91031967e-02).abs() < 1e-4);
|
||||
assert!((embeddings[0][383] as f64 - -2.90199649e-04).abs() < 1e-4);
|
||||
|
||||
assert!((embeddings[1][0] as f64 - 6.47571534e-02).abs() < 1e-4);
|
||||
assert!((embeddings[1][1] as f64 - 4.85198125e-02).abs() < 1e-4);
|
||||
assert!((embeddings[1][2] as f64 - -1.78603437e-02).abs() < 1e-4);
|
||||
assert!((embeddings[1][381] as f64 - 3.37569155e-02).abs() < 1e-4);
|
||||
assert!((embeddings[1][382] as f64 - 8.43371451e-03).abs() < 1e-4);
|
||||
assert!((embeddings[1][383] as f64 - -6.00359812e-02).abs() < 1e-4);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbert_distilroberta() -> anyhow::Result<()> {
|
||||
let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllDistilrobertaV1)
|
||||
.create_model()?;
|
||||
|
||||
let sentences = ["This is an example sentence", "Each sentence is converted"];
|
||||
let embeddings = model.encode(&sentences)?;
|
||||
|
||||
assert!((embeddings[0][0] as f64 - -0.03375624).abs() < 1e-4);
|
||||
assert!((embeddings[0][1] as f64 - -0.06316338).abs() < 1e-4);
|
||||
assert!((embeddings[0][2] as f64 - -0.0316612).abs() < 1e-4);
|
||||
assert!((embeddings[0][765] as f64 - 0.03684864).abs() < 1e-4);
|
||||
assert!((embeddings[0][766] as f64 - -0.02036646).abs() < 1e-4);
|
||||
assert!((embeddings[0][767] as f64 - -0.01574).abs() < 1e-4);
|
||||
|
||||
assert!((embeddings[1][0] as f64 - -0.01409588).abs() < 1e-4);
|
||||
assert!((embeddings[1][1] as f64 - 0.00091114).abs() < 1e-4);
|
||||
assert!((embeddings[1][2] as f64 - -0.00096315).abs() < 1e-4);
|
||||
assert!((embeddings[1][765] as f64 - -0.02571585).abs() < 1e-4);
|
||||
assert!((embeddings[1][766] as f64 - -0.00289072).abs() < 1e-4);
|
||||
assert!((embeddings[1][767] as f64 - -0.00579975).abs() < 1e-4);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbert_albert() -> anyhow::Result<()> {
|
||||
let model =
|
||||
SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::ParaphraseAlbertSmallV2)
|
||||
.create_model()?;
|
||||
|
||||
let sentences = ["this is an example sentence", "each sentence is converted"];
|
||||
let embeddings = model.encode(&sentences)?;
|
||||
|
||||
assert!((embeddings[0][0] as f64 - 0.20412037).abs() < 1e-4);
|
||||
assert!((embeddings[0][1] as f64 - 0.48823047).abs() < 1e-4);
|
||||
assert!((embeddings[0][2] as f64 - 0.5664698).abs() < 1e-4);
|
||||
assert!((embeddings[0][765] as f64 - -0.37474486).abs() < 1e-4);
|
||||
assert!((embeddings[0][766] as f64 - 0.0254627).abs() < 1e-4);
|
||||
assert!((embeddings[0][767] as f64 - -0.6846024).abs() < 1e-4);
|
||||
|
||||
assert!((embeddings[1][0] as f64 - 0.25720373).abs() < 1e-4);
|
||||
assert!((embeddings[1][1] as f64 - 0.24648172).abs() < 1e-4);
|
||||
assert!((embeddings[1][2] as f64 - -0.2521183).abs() < 1e-4);
|
||||
assert!((embeddings[1][765] as f64 - 0.4667896).abs() < 1e-4);
|
||||
assert!((embeddings[1][766] as f64 - 0.14219822).abs() < 1e-4);
|
||||
assert!((embeddings[1][767] as f64 - 0.3986863).abs() < 1e-4);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbert_t5() -> anyhow::Result<()> {
|
||||
let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::SentenceT5Base)
|
||||
.create_model()?;
|
||||
|
||||
let sentences = ["This is an example sentence", "Each sentence is converted"];
|
||||
let embeddings = model.encode(&sentences)?;
|
||||
|
||||
assert!((embeddings[0][0] as f64 - -0.00904849).abs() < 1e-4);
|
||||
assert!((embeddings[0][1] as f64 - 0.0191336).abs() < 1e-4);
|
||||
assert!((embeddings[0][2] as f64 - 0.02657794).abs() < 1e-4);
|
||||
assert!((embeddings[0][765] as f64 - -0.00876413).abs() < 1e-4);
|
||||
assert!((embeddings[0][766] as f64 - -0.05602207).abs() < 1e-4);
|
||||
assert!((embeddings[0][767] as f64 - -0.02163094).abs() < 1e-4);
|
||||
|
||||
assert!((embeddings[1][0] as f64 - -0.00785422).abs() < 1e-4);
|
||||
assert!((embeddings[1][1] as f64 - 0.03018173).abs() < 1e-4);
|
||||
assert!((embeddings[1][2] as f64 - 0.03129675).abs() < 1e-4);
|
||||
assert!((embeddings[1][765] as f64 - -0.01246878).abs() < 1e-4);
|
||||
assert!((embeddings[1][766] as f64 - -0.06240674).abs() < 1e-4);
|
||||
assert!((embeddings[1][767] as f64 - -0.00590969).abs() < 1e-4);
|
||||
|
||||
Ok(())
|
||||
}
|
@ -11,6 +11,8 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert")
|
||||
parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings / language model head")
|
||||
parser.add_argument("--prefix", help="Add a prefix on weight names")
|
||||
parser.add_argument("--suffix", action="store_true", help="Split weight names on '.' and keep only last part")
|
||||
args = parser.parse_args()
|
||||
|
||||
source_file = Path(args.source_file)
|
||||
@ -24,6 +26,10 @@ if __name__ == "__main__":
|
||||
if args.skip_embeddings:
|
||||
if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}:
|
||||
continue
|
||||
if args.prefix:
|
||||
k = args.prefix + k
|
||||
if args.suffix:
|
||||
k = k.split('.')[-1]
|
||||
if isinstance(v, Tensor):
|
||||
nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
|
||||
print(f'converted {k} - {str(sys.getsizeof(nps[k]))} bytes')
|
||||
|
Loading…
Reference in New Issue
Block a user