diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f8dfda..92f059b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,16 @@ # Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). -## [Unreleased] +## [0.19.0] - 2022-07-24 ## Added - Support for sentence embeddings models and pipelines, based on [SentenceTransformers](https://www.sbert.net). ## Changed - Upgraded to `torch` 1.12 (via `tch` 0.8.0) +## Fixed +- Allow empty slices or slices of empty prompts for text generation. + ## [0.18.0] - 2022-05-29 ## Added - Addition of the DeBERTa language model and support for question answering, sequence and token classification diff --git a/Cargo.toml b/Cargo.toml index 1ddc8e2..b94031b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-bert" -version = "0.18.0" +version = "0.19.0" authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"] edition = "2018" description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)" @@ -65,22 +65,22 @@ features = ["doc-only"] [dependencies] rust_tokenizers = "~7.0.2" tch = "~0.8.0" -serde_json = "1.0.81" -serde = { version = "1.0.137", features = ["derive"] } +serde_json = "1.0.82" +serde = { version = "1.0.140", features = ["derive"] } ordered-float = "3.0.0" -uuid = { version = "1.1.0", features = ["v4"] } +uuid = { version = "1.1.2", features = ["v4"] } thiserror = "1.0.31" -half = "1.8.2" +half = "2.1.0" cached-path = { version = "0.5.3", optional = true } dirs = { version = "4.0.0", optional = true } lazy_static = { version = "1.4.0", optional = true } [dev-dependencies] -anyhow = "1.0.57" +anyhow = "1.0.58" csv = "1.1.6" -criterion = "0.3.5" -tokio = { version = "1.18.2", features = ["sync", "rt-multi-thread", "macros"] } +criterion = "0.3.6" +tokio = { version = "1.20.0", features = ["sync", "rt-multi-thread", "macros"] } torch-sys = "~0.8.0" tempfile = "3.3.0" itertools = "0.10.3" 
diff --git a/README.md b/README.md index edfbfd0..9839cbe 100644 --- a/README.md +++ b/README.md @@ -32,35 +32,36 @@ The tasks currently supported include: - Named Entity Recognition - Part of Speech tagging - Question-Answering - - Language Generation. + - Language Generation + - Sentence Embeddings
Expand to display the supported models/tasks matrix -| |**Sequence classification**|**Token classification**|**Question answering**|**Text Generation**|**Summarization**|**Translation**|**Masked LM**| -:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----: -DistilBERT|✅|✅|✅| | | |✅| -MobileBERT|✅|✅|✅| | | |✅| -DeBERTa|✅|✅|✅| | | |✅| -DeBERTa (v2)|✅|✅|✅| | | |✅| -FNet|✅|✅|✅| | | |✅| -BERT|✅|✅|✅| | | |✅| -RoBERTa|✅|✅|✅| | | |✅| -GPT| | | |✅ | | | | -GPT2| | | |✅ | | | | -GPT-Neo| | | |✅ | | | | -BART|✅| | |✅ |✅| | | -Marian| | | | | |✅| | -MBart|✅| | |✅ | | | | -M2M100| | | |✅ | | | | -Electra | |✅| | | | |✅| -ALBERT |✅|✅|✅| | | |✅| -T5 | | | |✅ |✅|✅| | -XLNet|✅|✅|✅|✅ | | |✅| -Reformer|✅| |✅|✅ | | |✅| -ProphetNet| | | |✅ |✅ | | | -Longformer|✅|✅|✅| | | |✅| -Pegasus| | | | |✅| | | +| |**Sequence classification**|**Token classification**|**Question answering**|**Text Generation**|**Summarization**|**Translation**|**Masked LM**|**Sentence Embeddings**| +:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----: +DistilBERT|✅|✅|✅| | | |✅| ✅| +MobileBERT|✅|✅|✅| | | |✅| | +DeBERTa|✅|✅|✅| | | |✅| | +DeBERTa (v2)|✅|✅|✅| | | |✅| | +FNet|✅|✅|✅| | | |✅| | +BERT|✅|✅|✅| | | |✅| ✅| +RoBERTa|✅|✅|✅| | | |✅| ✅| +GPT| | | |✅ | | | | | +GPT2| | | |✅ | | | | | +GPT-Neo| | | |✅ | | | | | +BART|✅| | |✅ |✅| | | | +Marian| | | | | |✅| | | +MBart|✅| | |✅ | | | | | +M2M100| | | |✅ | | | | | +Electra | |✅| | | | |✅| | +ALBERT |✅|✅|✅| | | |✅| ✅ | +T5 | | | |✅ |✅|✅| | ✅ | +XLNet|✅|✅|✅|✅ | | |✅| | +Reformer|✅| |✅|✅ | | |✅| | +ProphetNet| | | |✅ |✅ | | | | +Longformer|✅|✅|✅| | | |✅| | +Pegasus| | | | |✅| | | |
## Getting started @@ -379,6 +380,31 @@ Output: ] ``` +  +
+ 10. Sentence embeddings + +Generate sentence embeddings (vector representation). These can be used for applications including dense information retrieval. +```rust + let model = SentenceEmbeddingsBuilder::remote( + SentenceEmbeddingsModelType::AllMiniLmL12V2 + ).create_model()?; + + let sentences = [ + "this is an example sentence", + "each sentence is converted" + ]; + + let output = model.predict(&sentences); +``` +Output: +``` +[ + [-0.000202666, 0.08148022, 0.03136178, 0.002920636 ...], + [0.064757116, 0.048519745, -0.01786038, -0.0479775 ...] +] +``` +
## Benchmarks diff --git a/examples/sentence_embeddings_local.rs b/examples/sentence_embeddings_local.rs index f148f94..f4d43d0 100644 --- a/examples/sentence_embeddings_local.rs +++ b/examples/sentence_embeddings_local.rs @@ -21,7 +21,6 @@ use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder; /// ```sh /// python ../utils/convert_model.py resources/path/to/2_Dense/pytorch_model.bin --suffix /// ``` -/// fn main() -> anyhow::Result<()> { // Set-up sentence embeddings model let model = SentenceEmbeddingsBuilder::local("resources/all-MiniLM-L12-v2") diff --git a/src/lib.rs b/src/lib.rs index f7b1938..5fee123 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,7 +39,8 @@ //! - Named Entity Recognition //! - Part of Speech tagging //! - Question-Answering -//! - Language Generation. +//! - Language Generation +//! - Sentence Embeddings //! //! More information on these can be found in the [`pipelines` module](./pipelines/index.html) //! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust @@ -47,30 +48,30 @@ //!
//! Click to expand to display the supported models/tasks matrix //! -//! | |**Sequence classification**|**Token classification**|**Question answering**|**Text Generation**|**Summarization**|**Translation**|**Masked LM**| -//! :-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----: -//! DistilBERT|✅|✅|✅| | | |✅| -//! MobileBERT|✅|✅|✅| | | |✅| -//! DeBERTa|✅|✅|✅| | | |✅| -//! DeBERTa (v2)|✅|✅|✅| | | |✅| -//! FNet|✅|✅|✅| | | |✅| -//! BERT|✅|✅|✅| | | |✅| -//! RoBERTa|✅|✅|✅| | | |✅| -//! GPT| | | |✅ | | | | -//! GPT2| | | |✅ | | | | -//! GPT-Neo| | | |✅ | | | | -//! BART|✅| | |✅ |✅| | | -//! Marian| | | | | |✅| | -//! MBart|✅| | |✅ | | | | -//! M2M100| | | |✅ | | | | -//! Electra | |✅| | | | |✅| -//! ALBERT |✅|✅|✅| | | |✅| -//! T5 | | | |✅ |✅|✅| | -//! XLNet|✅|✅|✅|✅ | | |✅| -//! Reformer|✅| |✅|✅ | | |✅| -//! ProphetNet| | | |✅ |✅ | | | -//! Longformer|✅|✅|✅| | | |✅| -//! Pegasus| | | | |✅| | | +//!| |**Sequence classification**|**Token classification**|**Question answering**|**Text Generation**|**Summarization**|**Translation**|**Masked LM**|**Sentence Embeddings**| +//!:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----: +//!DistilBERT|✅|✅|✅| | | |✅| ✅| +//!MobileBERT|✅|✅|✅| | | |✅| | +//!DeBERTa|✅|✅|✅| | | |✅| | +//!DeBERTa (v2)|✅|✅|✅| | | |✅| | +//!FNet|✅|✅|✅| | | |✅| | +//!BERT|✅|✅|✅| | | |✅| ✅| +//!RoBERTa|✅|✅|✅| | | |✅| ✅| +//!GPT| | | |✅ | | | | | +//!GPT2| | | |✅ | | | | | +//!GPT-Neo| | | |✅ | | | | | +//!BART|✅| | |✅ |✅| | | | +//!Marian| | | | | |✅| | | +//!MBart|✅| | |✅ | | | | | +//!M2M100| | | |✅ | | | | | +//!Electra | |✅| | | | |✅| | +//!ALBERT |✅|✅|✅| | | |✅| ✅ | +//!T5 | | | |✅ |✅|✅| | ✅ | +//!XLNet|✅|✅|✅|✅ | | |✅| | +//!Reformer|✅| |✅|✅ | | |✅| | +//!ProphetNet| | | |✅ |✅ | | | | +//!Longformer|✅|✅|✅| | | |✅| | +//!Pegasus| | | | |✅| | | | //!
//! //! # Getting started @@ -84,8 +85,8 @@ //! //! ### Manual installation (recommended) //! -//! 1. Download `libtorch` from . This package requires `v1.10.0`: if this version is no longer available on the "get started" page, -//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.10.0%2Bcu111.zip` for a Linux version with CUDA11. +//! 1. Download `libtorch` from . This package requires `v1.12.0`: if this version is no longer available on the "get started" page, +//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu116/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu116.zip` for a Linux version with CUDA11. //! 2. Extract the library to a location of your choice //! 3. Set the following environment variables //! ##### Linux: @@ -103,7 +104,7 @@ //! ### Automatic installation //! //! Alternatively, you can let the `build` script automatically download the `libtorch` library for you. -//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu111`. +//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu116`. //! Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete. //! //! # Ready-to-use pipelines @@ -544,7 +545,36 @@ //! ] //! # ; //! ``` +//! +//!   +//!
+//! 10. Sentence embeddings //! +//! Generate sentence embeddings (vector representation). These can be used for applications including dense information retrieval. +//!```no_run +//! # use rust_bert::pipelines::sentence_embeddings::{SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType}; +//! # fn main() -> anyhow::Result<()> { +//! let model = SentenceEmbeddingsBuilder::remote( +//! SentenceEmbeddingsModelType::AllMiniLmL12V2 +//! ).create_model()?; +//! +//! let sentences = [ +//! "this is an example sentence", +//! "each sentence is converted" +//! ]; +//! +//! let output = model.predict(&sentences); +//! # Ok(()) } +//! ``` +//! Output: +//! ```no_run +//! # let output = +//! [ +//! [-0.000202666, 0.08148022, 0.03136178, 0.002920636], +//! [0.064757116, 0.048519745, -0.01786038, -0.0479775], +//! ] +//! # ; +//! ``` //!
//! //! ## Benchmarks diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index 002453f..864a711 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -409,6 +409,34 @@ //! ] //! # ; //! ``` +//! +//! #### 10. Sentence embeddings +//! +//! Generate sentence embeddings (vector representation). These can be used for applications including dense information retrieval. +//!```no_run +//! # use rust_bert::pipelines::sentence_embeddings::{SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType}; +//! # fn main() -> anyhow::Result<()> { +//! let model = SentenceEmbeddingsBuilder::remote( +//! SentenceEmbeddingsModelType::AllMiniLmL12V2 +//! ).create_model()?; +//! +//! let sentences = [ +//! "this is an example sentence", +//! "each sentence is converted" +//! ]; +//! +//! let output = model.predict(&sentences); +//! # Ok(()) } +//! ``` +//! Output: +//! ```no_run +//! # let output = +//! [ +//! [-0.000202666, 0.08148022, 0.03136178, 0.002920636], +//! [0.064757116, 0.048519745, -0.01786038, -0.0479775], +//! ] +//! # ; +//! ``` pub mod common; pub mod conversation; diff --git a/tests/longformer.rs b/tests/longformer.rs index 0a61bfe..dc235a1 100644 --- a/tests/longformer.rs +++ b/tests/longformer.rs @@ -309,7 +309,7 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> { } #[test] -fn mobilebert_for_token_classification() -> anyhow::Result<()> { +fn longformer_for_token_classification() -> anyhow::Result<()> { // Resources paths let config_resource = RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);