Prepare for 0.19 release (#272)

This commit is contained in:
guillaume-be 2022-07-25 06:36:02 +01:00 committed by GitHub
parent 66d596a2bf
commit cce1e2707d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 150 additions and 64 deletions

View File

@ -1,13 +1,16 @@
# Changelog
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased]
## [0.19.0] - 2022-07-24
## Added
- Support for sentence embeddings models and pipelines, based on [SentenceTransformers](https://www.sbert.net).
## Changed
- Upgraded to `torch` 1.12 (via `tch` 0.8.0)
## Fixed
- Allow empty slices or slices of empty prompts for text generation.
## [0.18.0] - 2022-05-29
## Added
- Addition of the DeBERTa language model and support for question answering, sequence and token classification

View File

@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.18.0"
version = "0.19.0"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
@ -65,22 +65,22 @@ features = ["doc-only"]
[dependencies]
rust_tokenizers = "~7.0.2"
tch = "~0.8.0"
serde_json = "1.0.81"
serde = { version = "1.0.137", features = ["derive"] }
serde_json = "1.0.82"
serde = { version = "1.0.140", features = ["derive"] }
ordered-float = "3.0.0"
uuid = { version = "1.1.0", features = ["v4"] }
uuid = { version = "1.1.2", features = ["v4"] }
thiserror = "1.0.31"
half = "1.8.2"
half = "2.1.0"
cached-path = { version = "0.5.3", optional = true }
dirs = { version = "4.0.0", optional = true }
lazy_static = { version = "1.4.0", optional = true }
[dev-dependencies]
anyhow = "1.0.57"
anyhow = "1.0.58"
csv = "1.1.6"
criterion = "0.3.5"
tokio = { version = "1.18.2", features = ["sync", "rt-multi-thread", "macros"] }
criterion = "0.3.6"
tokio = { version = "1.20.0", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "~0.8.0"
tempfile = "3.3.0"
itertools = "0.10.3"

View File

@ -32,35 +32,36 @@ The tasks currently supported include:
- Named Entity Recognition
- Part of Speech tagging
- Question-Answering
- Language Generation.
- Language Generation
- Sentence Embeddings
<details>
<summary> <b>Expand to display the supported models/tasks matrix </b> </summary>
| |**Sequence classification**|**Token classification**|**Question answering**|**Text Generation**|**Summarization**|**Translation**|**Masked LM**|
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:
DistilBERT|✅|✅|✅| | | |✅|
MobileBERT|✅|✅|✅| | | |✅|
DeBERTa|✅|✅|✅| | | |✅|
DeBERTa (v2)|✅|✅|✅| | | |✅|
FNet|✅|✅|✅| | | |✅|
BERT|✅|✅|✅| | | |✅|
RoBERTa|✅|✅|✅| | | |✅|
GPT| | | |✅ | | | |
GPT2| | | |✅ | | | |
GPT-Neo| | | |✅ | | | |
BART|✅| | |✅ |✅| | |
Marian| | | | | |✅| |
MBart|✅| | |✅ | | | |
M2M100| | | |✅ | | | |
Electra | |✅| | | | |✅|
ALBERT |✅|✅|✅| | | |✅|
T5 | | | |✅ |✅|✅| |
XLNet|✅|✅|✅|✅ | | |✅|
Reformer|✅| |✅|✅ | | |✅|
ProphetNet| | | |✅ |✅ | | |
Longformer|✅|✅|✅| | | |✅|
Pegasus| | | | |✅| | |
| |**Sequence classification**|**Token classification**|**Question answering**|**Text Generation**|**Summarization**|**Translation**|**Masked LM**|**Sentence Embeddings**|
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----:
DistilBERT|✅|✅|✅| | | |✅| ✅|
MobileBERT|✅|✅|✅| | | |✅| |
DeBERTa|✅|✅|✅| | | |✅| |
DeBERTa (v2)|✅|✅|✅| | | |✅| |
FNet|✅|✅|✅| | | |✅| |
BERT|✅|✅|✅| | | |✅| ✅|
RoBERTa|✅|✅|✅| | | |✅| ✅|
GPT| | | |✅ | | | | |
GPT2| | | |✅ | | | | |
GPT-Neo| | | |✅ | | | | |
BART|✅| | |✅ |✅| | | |
Marian| | | | | |✅| | |
MBart|✅| | |✅ | | | | |
M2M100| | | |✅ | | | | |
Electra | |✅| | | | |✅| |
ALBERT |✅|✅|✅| | | |✅| ✅ |
T5 | | | |✅ |✅|✅| | ✅ |
XLNet|✅|✅|✅|✅ | | |✅| |
Reformer|✅| |✅|✅ | | |✅| |
ProphetNet| | | |✅ |✅ | | | |
Longformer|✅|✅|✅| | | |✅| |
Pegasus| | | | |✅| | | |
</details>
## Getting started
@ -379,6 +380,31 @@ Output:
]
```
</details>
&nbsp;
<details>
<summary> <b>10. Sentence embeddings </b> </summary>
Generate sentence embeddings (vector representation). These can be used for applications including dense information retrieval.
```rust
let model = SentenceEmbeddingsBuilder::remote(
SentenceEmbeddingsModelType::AllMiniLmL12V2
).create_model()?;
let sentences = [
"this is an example sentence",
"each sentence is converted"
];
let output = model.predict(&sentences);
```
Output:
```
[
[-0.000202666, 0.08148022, 0.03136178, 0.002920636 ...],
[0.064757116, 0.048519745, -0.01786038, -0.0479775 ...]
]
```
</details>
## Benchmarks

View File

@ -21,7 +21,6 @@ use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;
/// ```sh
/// python ../utils/convert_model.py resources/path/to/2_Dense/pytorch_model.bin --suffix
/// ```
///
fn main() -> anyhow::Result<()> {
// Set-up sentence embeddings model
let model = SentenceEmbeddingsBuilder::local("resources/all-MiniLM-L12-v2")

View File

@ -39,7 +39,8 @@
//! - Named Entity Recognition
//! - Part of Speech tagging
//! - Question-Answering
//! - Language Generation.
//! - Language Generation
//! - Sentence Embeddings
//!
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust
@ -47,30 +48,30 @@
//! <details>
//! <summary> <b> Click to expand to display the supported models/tasks matrix </b> </summary>
//!
//! | |**Sequence classification**|**Token classification**|**Question answering**|**Text Generation**|**Summarization**|**Translation**|**Masked LM**|
//! :-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:
//! DistilBERT|✅|✅|✅| | | |✅|
//! MobileBERT|✅|✅|✅| | | |✅|
//! DeBERTa|✅|✅|✅| | | |✅|
//! DeBERTa (v2)|✅|✅|✅| | | |✅|
//! FNet|✅|✅|✅| | | |✅|
//! BERT|✅|✅|✅| | | |✅|
//! RoBERTa|✅|✅|✅| | | |✅|
//! GPT| | | |✅ | | | |
//! GPT2| | | |✅ | | | |
//! GPT-Neo| | | |✅ | | | |
//! BART|✅| | |✅ |✅| | |
//! Marian| | | | | |✅| |
//! MBart|✅| | |✅ | | | |
//! M2M100| | | |✅ | | | |
//! Electra | |✅| | | | |✅|
//! ALBERT |✅|✅|✅| | | |✅|
//! T5 | | | |✅ |✅|✅| |
//! XLNet|✅|✅|✅|✅ | | |✅|
//! Reformer|✅| |✅|✅ | | |✅|
//! ProphetNet| | | |✅ |✅ | | |
//! Longformer|✅|✅|✅| | | |✅|
//! Pegasus| | | | |✅| | |
//!| |**Sequence classification**|**Token classification**|**Question answering**|**Text Generation**|**Summarization**|**Translation**|**Masked LM**|**Sentence Embeddings**|
//!:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----:
//!DistilBERT|✅|✅|✅| | | |✅| ✅|
//!MobileBERT|✅|✅|✅| | | |✅| |
//!DeBERTa|✅|✅|✅| | | |✅| |
//!DeBERTa (v2)|✅|✅|✅| | | |✅| |
//!FNet|✅|✅|✅| | | |✅| |
//!BERT|✅|✅|✅| | | |✅| ✅|
//!RoBERTa|✅|✅|✅| | | |✅| ✅|
//!GPT| | | |✅ | | | | |
//!GPT2| | | |✅ | | | | |
//!GPT-Neo| | | |✅ | | | | |
//!BART|✅| | |✅ |✅| | | |
//!Marian| | | | | |✅| | |
//!MBart|✅| | |✅ | | | | |
//!M2M100| | | |✅ | | | | |
//!Electra | |✅| | | | |✅| |
//!ALBERT |✅|✅|✅| | | |✅| ✅ |
//!T5 | | | |✅ |✅|✅| | ✅ |
//!XLNet|✅|✅|✅|✅ | | |✅| |
//!Reformer|✅| |✅|✅ | | |✅| |
//!ProphetNet| | | |✅ |✅ | | | |
//!Longformer|✅|✅|✅| | | |✅| |
//!Pegasus| | | | |✅| | | |
//! </details>
//!
//! # Getting started
@ -84,8 +85,8 @@
//!
//! ### Manual installation (recommended)
//!
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.10.0`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.10.0%2Bcu111.zip` for a Linux version with CUDA11.
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.12.0`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu116/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu116.zip` for a Linux version with CUDA 11.6.
//! 2. Extract the library to a location of your choice
//! 3. Set the following environment variables
//! ##### Linux:
@ -103,7 +104,7 @@
//! ### Automatic installation
//!
//! Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu111`.
//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu116`.
//! Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
//!
//! # Ready-to-use pipelines
@ -544,7 +545,36 @@
//! ]
//! # ;
//! ```
//! </details>
//! &nbsp;
//! <details>
//! <summary> <b>10. Sentence embeddings </b> </summary>
//!
//! Generate sentence embeddings (vector representation). These can be used for applications including dense information retrieval.
//!```no_run
//! # use rust_bert::pipelines::sentence_embeddings::{SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType};
//! # fn main() -> anyhow::Result<()> {
//! let model = SentenceEmbeddingsBuilder::remote(
//! SentenceEmbeddingsModelType::AllMiniLmL12V2
//! ).create_model()?;
//!
//! let sentences = [
//! "this is an example sentence",
//! "each sentence is converted"
//! ];
//!
//! let output = model.predict(&sentences);
//! # Ok(())
//! # }
//! ```
//! Output:
//! ```no_run
//! # let output =
//! [
//! [-0.000202666, 0.08148022, 0.03136178, 0.002920636],
//! [0.064757116, 0.048519745, -0.01786038, -0.0479775],
//! ]
//! # ;
//! ```
//! </details>
//!
//! ## Benchmarks

View File

@ -409,6 +409,34 @@
//! ]
//! # ;
//! ```
//!
//! #### 10. Sentence embeddings
//!
//! Generate sentence embeddings (vector representation). These can be used for applications including dense information retrieval.
//!```no_run
//! # use rust_bert::pipelines::sentence_embeddings::{SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType};
//! # fn main() -> anyhow::Result<()> {
//! let model = SentenceEmbeddingsBuilder::remote(
//! SentenceEmbeddingsModelType::AllMiniLmL12V2
//! ).create_model()?;
//!
//! let sentences = [
//! "this is an example sentence",
//! "each sentence is converted"
//! ];
//!
//! let output = model.predict(&sentences);
//! # Ok(())
//! # }
//! ```
//! Output:
//! ```no_run
//! # let output =
//! [
//! [-0.000202666, 0.08148022, 0.03136178, 0.002920636],
//! [0.064757116, 0.048519745, -0.01786038, -0.0479775],
//! ]
//! # ;
//! ```
pub mod common;
pub mod conversation;

View File

@ -309,7 +309,7 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> {
}
#[test]
fn mobilebert_for_token_classification() -> anyhow::Result<()> {
fn longformer_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource =
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_4096);