Merge branch 'main' into patch-1

guillaume-be 2024-08-18 12:07:33 +01:00 committed by GitHub
commit 8b545d31dd
228 changed files with 9910 additions and 4143 deletions

View File

@ -1,8 +1,8 @@
on:
push:
branches: [ master ]
branches: [ main ]
pull_request:
branches: [ master ]
branches: [ main ]
name: Build
@ -20,6 +20,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: build
args: --features download-libtorch
build-no-defaults:
name: Build no defaults
@ -34,7 +35,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: build
args: --no-default-features
args: --no-default-features --features download-libtorch
build-windows:
name: Build Windows
@ -49,6 +50,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: build
args: --features download-libtorch
build-mac-os:
name: Build macOS
@ -63,6 +65,7 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: build
args: --features download-libtorch
test-batch-0:
name: Integration tests (batch 0)
@ -89,6 +92,7 @@ jobs:
--test fnet
--test deberta
--test deberta_v2
--features download-libtorch
test-batch-1:
name: Integration tests (batch 1)
@ -114,6 +118,7 @@ jobs:
--test longformer
--test pegasus
--test gpt_neo
--features download-libtorch
test-batch-2:
name: Integration tests (batch 2)
@ -132,6 +137,28 @@ jobs:
--test sentence_embeddings
--test longt5
--test gpt_j
--test nllb
--features download-libtorch
test-opt-features:
name: Integration tests (Optional features)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: test
args: --package rust-bert
--features onnx
--features hf-tokenizers
--test onnx
--test hf_tokenizers
--features download-libtorch
convert-model:
name: Model conversion test
@ -147,7 +174,7 @@ jobs:
with:
python-version: '3.10'
- run: |
pip install -r requirements.txt --progress-bar off
pip install -r ./utils/requirements.txt --progress-bar off
python ./utils/download-dependencies_distilbert.py
fmt:

4
.gitignore vendored
View File

@ -17,4 +17,6 @@ Cargo.lock
/target
#**/*.rs.bk
/resources/
/models/
/.venv/
convert_model.log

View File

@ -2,14 +2,54 @@
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased]
## Added
- Addition of the [LongT5](https://arxiv.org/abs/2112.07916) model architecture and pretrained weights.
## Changed
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
- (BREAKING) Upgraded to `torch` 2.2 (via `tch` 0.15.0).
## [0.22.0] - 2024-01-20
## Added
- Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines.
- Support for [Tokenizers](https://github.com/huggingface/tokenizers) in pipelines, allowing loading `tokenizer.json` and `special_token_map.json` tokenizer files.
- (BREAKING) Most model configurations can now take an optional `kind` parameter to specify the model weight precision. If not provided, it will default to full precision on CPU, or to the precision of the serialized weights otherwise.
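A minimal sketch of the new option (the `kind` field below follows the pipeline configurations used elsewhere in this release; half precision generally requires a CUDA device):
```rust
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use tch::{Device, Kind};

fn main() -> anyhow::Result<()> {
    let config = TextGenerationConfig {
        // Load the weights in half precision; leaving `kind` as None keeps the
        // default behaviour described above.
        kind: Some(Kind::Half),
        device: Device::cuda_if_available(),
        ..Default::default()
    };
    let _model = TextGenerationModel::new(config)?;
    Ok(())
}
```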
## Fixed
- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Added a new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering n-grams spanning multiple sentences).
- Improved MPS device compatibility by setting the `sparse_grad` flag to false for `gather` operations
- Updated ONNX runtime backend version to 1.15.x
- Issue with incorrect results for QA models with a tokenizer not using segment ids
- Issue with GPT-J that was incorrectly tracking the gradients for the attention bias
## Changed
- (BREAKING) Upgraded to `torch` 2.1 (via `tch` 0.14.0).
- (BREAKING) Text generation traits and pipelines (including conversation, summarization and translation) now return a `Result` for improved error handling
## [0.21.0] - 2023-06-03
## Added
- Addition of the [LongT5](https://arxiv.org/abs/2112.07916) model architecture and pretrained weights.
- Addition of `add_tokens` and `add_extra_ids` interface methods to the `TokenizerOption`. Allows building most pipelines with a custom tokenizer via `new_with_tokenizer`.
- Addition of `get_tokenizer` and `get_tokenizer_mut` methods to all pipelines, allowing access to a (mutable) reference to the pipeline tokenizer.
- Addition of a `get_embedding_dim` method to get the dimension of the embeddings for sentence embeddings pipelines
- `get_vocab_size`, `get_decoder_start_token_id` and `get_prefix_and_forced_bos_id` for the `TokenizerOption` in pipelines
- Addition of the [GPT-J](https://www.eleuther.ai/artifacts/gpt-j) model architecture
- Addition of the [NLLB](https://arxiv.org/abs/2207.04672) model architecture and pretrained weights
- Addition of support for ONNX models (encoder, decoders, encoder-decoders) via the [ort](https://github.com/pykeio/ort) onnxruntime bindings
- Integration of ONNX models to the sequence classification, token classification, question answering, zero-shot classification, text generation, summarization and translation pipelines
## Changed
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer
- (BREAKING) Simplified the generation traits (removal of LMHeadModel and elimination of unnecessary specification for LanguageGenerator)
- (BREAKING) Upgraded to `torch` 2.0 (via `tch` 0.13.0). The process to automatically download the dependencies has changed; it must now be enabled via the `download-libtorch` feature flag.
- Read the `decoder_start_token_id` from the provided configuration rather than using a hard-coded default value
- (BREAKING) Changed the return type of the `LanguageGenerator` and pipelines functions `float`, `half`, `set_device` to `Result<(), RustBertError>` as these become fallible for ONNX models
- (BREAKING) Wrapped the model resources specification for the pipeline `Config` objects into an `Enum` to allow handling both torch-based and ONNX models.
The `model_resources` field now needs to be wrapped in the corresponding enum variant, e.g. `model_resources: ModelResources::TORCH(model_resource)` for Torch-based models
- (BREAKING) Added the `forced_bos_token_id` and `forced_eos_token_id` fields to text generation models.
If these are not None, this will trigger a forced BOS/EOS token generation at the first of `max_length` positions (aligns with the Pytorch Transformers library)
- Project structure refactoring (torch-based models moved under common module). Non-breaking change via re-exports.
## Fixed
- MIN/MAX computation for float-like (was set to infinity instead of min/max)
- Remove the (unused) pooler from the set of weights for BERT Masked LM architecture
## [0.20.0] - 2023-01-21
## Added
@ -412,4 +452,4 @@ All notable changes to this project will be documented in this file. The format
- Tensor conversion tools from Pytorch to Libtorch format
- DistilBERT model architecture
- Ready-to-use `SentimentClassifier` using a DistilBERT model fine-tuned on SST2

View File

@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.20.1-alpha"
version = "0.22.0"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and language models"
@ -8,6 +8,7 @@ repository = "https://github.com/guillaume-be/rust-bert"
documentation = "https://docs.rs/rust-bert"
license = "Apache-2.0"
readme = "README.md"
build = "build.rs"
keywords = [
"nlp",
"deep-learning",
@ -60,34 +61,78 @@ harness = false
opt-level = 3
[features]
default = ["remote"]
default = ["remote", "default-tls"]
doc-only = ["tch/doc-only"]
all-tests = []
remote = ["cached-path", "dirs", "lazy_static"]
download-libtorch = ["tch/download-libtorch"]
onnx = ["ort", "ndarray"]
rustls-tls = ["cached-path/rustls-tls"]
default-tls = ["cached-path/default-tls"]
hf-tokenizers = ["tokenizers"]
[package.metadata.docs.rs]
features = ["doc-only"]
[dependencies]
rust_tokenizers = "8.0.0"
tch = "~0.10.1"
rust_tokenizers = "8.1.1"
tch = { version = "0.16.0", features = ["download-libtorch"] }
serde_json = "1"
serde = { version = "1", features = ["derive"] }
ordered-float = "3"
ordered-float = "4.2.0"
uuid = { version = "1", features = ["v4"] }
thiserror = "1"
half = "2"
regex = "1.6"
cached-path = { version = "0.6", optional = true }
dirs = { version = "4", optional = true }
cached-path = { version = "0.6", default-features = false, optional = true }
dirs = { version = "5", optional = true }
lazy_static = { version = "1", optional = true }
ort = { version = "1.16.3", optional = true, default-features = false, features = [
"half",
] }
ndarray = { version = "0.15", optional = true }
tokenizers = { version = "0.19.1", optional = true, default-features = false, features = [
"onig",
] }
[dev-dependencies]
anyhow = "1"
csv = "1"
criterion = "0.4"
tokio = { version = "1.24", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "=0.10.0"
criterion = "0.5"
tokio = { version = "1.35", features = ["sync", "rt-multi-thread", "macros"] }
tempfile = "3"
itertools = "0.10"
itertools = "0.13.0"
tracing-subscriber = { version = "0.3", default-features = false, features = [
"env-filter",
"fmt",
] }
ort = { version = "1.16.3", features = ["load-dynamic"] }
[[example]]
name = "onnx-masked-lm"
required-features = ["onnx"]
[[example]]
name = "onnx-question-answering"
required-features = ["onnx"]
[[example]]
name = "onnx-sequence-classification"
required-features = ["onnx"]
[[example]]
name = "onnx-text-generation"
required-features = ["onnx"]
[[example]]
name = "onnx-token-classification"
required-features = ["onnx"]
[[example]]
name = "onnx-translation"
required-features = ["onnx"]
[[example]]
name = "generation_gpt2_hf_tokenizers"
required-features = ["hf-tokenizers"]

436
README.md
View File

@ -5,10 +5,21 @@
[![Documentation](https://docs.rs/rust-bert/badge.svg)](https://docs.rs/rust-bert)
![License](https://img.shields.io/crates/l/rust_bert.svg)
Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using the [tch-rs](https://github.com/LaurentMazare/tch-rs) crate and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
This repository exposes the model base architecture, task-specific heads (see below) and [ready-to-use pipelines](#ready-to-use-pipelines). [Benchmarks](#benchmarks) are available at the end of this document.
Rust-native state-of-the-art Natural Language Processing models and pipelines.
Port of Hugging Face's
[Transformers library](https://github.com/huggingface/transformers), using
[tch-rs](https://github.com/LaurentMazare/tch-rs) or
[onnxruntime bindings](https://github.com/pykeio/ort) and pre-processing from
[rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports
multi-threaded tokenization and GPU inference. This repository exposes the model
base architecture, task-specific heads (see below) and
[ready-to-use pipelines](#ready-to-use-pipelines). [Benchmarks](#benchmarks) are
available at the end of this document.
Get started with tasks including question answering, named entity recognition,
translation, summarization, text generation, conversational agents and more in
just a few lines of code:
Get started with tasks including question answering, named entity recognition, translation, summarization, text generation, conversational agents and more in just a few lines of code:
```rust
let qa_model = QuestionAnsweringModel::new(Default::default())?;
@ -19,97 +30,201 @@ Get started with tasks including question answering, named entity recognition, t
```
Output:
```
[Answer { score: 0.9976, start: 13, end: 21, answer: "Amsterdam" }]
```
The tasks currently supported include:
- Translation
- Summarization
- Multi-turn dialogue
- Zero-shot classification
- Sentiment Analysis
- Named Entity Recognition
- Part of Speech tagging
- Question-Answering
- Language Generation
- Masked Language Model
- Sentence Embeddings
- Translation
- Summarization
- Multi-turn dialogue
- Zero-shot classification
- Sentiment Analysis
- Named Entity Recognition
- Part of Speech tagging
- Question-Answering
- Language Generation
- Masked Language Model
- Sentence Embeddings
- Keywords extraction
<details>
<summary> <b>Expand to display the supported models/tasks matrix </b> </summary>
| |**Sequence classification**|**Token classification**|**Question answering**|**Text Generation**|**Summarization**|**Translation**|**Masked LM**|**Sentence Embeddings**|
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----:
DistilBERT|✅|✅|✅| | | |✅| ✅|
MobileBERT|✅|✅|✅| | | |✅| |
DeBERTa|✅|✅|✅| | | |✅| |
DeBERTa (v2)|✅|✅|✅| | | |✅| |
FNet|✅|✅|✅| | | |✅| |
BERT|✅|✅|✅| | | |✅| ✅|
RoBERTa|✅|✅|✅| | | |✅| ✅|
GPT| | | |✅ | | | | |
GPT2| | | |✅ | | | | |
GPT-Neo| | | |✅ | | | | |
BART|✅| | |✅ |✅| | | |
Marian| | | | | |✅| | |
MBart|✅| | |✅ | | | | |
M2M100| | | |✅ | | | | |
Electra | |✅| | | | |✅| |
ALBERT |✅|✅|✅| | | |✅| ✅ |
T5 | | | |✅ |✅|✅| | ✅ |
LongT5 | | | |✅ |✅|| | |
XLNet|✅|✅|✅|✅ | | |✅| |
Reformer|✅| |✅|✅ | | |✅| |
ProphetNet| | | |✅ |✅ | | | |
Longformer|✅|✅|✅| | | |✅| |
Pegasus| | | | |✅| | | |
| | **Sequence classification** | **Token classification** | **Question answering** | **Text Generation** | **Summarization** | **Translation** | **Masked LM** | **Sentence Embeddings** |
| :----------: | :-------------------------: | :----------------------: | :--------------------: | :-----------------: | :---------------: | :-------------: | :-----------: | :---------------------: |
| DistilBERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ |
| MobileBERT | ✅ | ✅ | ✅ | | | | ✅ | |
| DeBERTa | ✅ | ✅ | ✅ | | | | ✅ | |
| DeBERTa (v2) | ✅ | ✅ | ✅ | | | | ✅ | |
| FNet | ✅ | ✅ | ✅ | | | | ✅ | |
| BERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ |
| RoBERTa | ✅ | ✅ | ✅ | | | | ✅ | ✅ |
| GPT | | | | ✅ | | | | |
| GPT2 | | | | ✅ | | | | |
| GPT-Neo | | | | ✅ | | | | |
| GPT-J | | | | ✅ | | | | |
| BART | ✅ | | | ✅ | ✅ | | | |
| Marian | | | | | | ✅ | | |
| MBart | ✅ | | | ✅ | | | | |
| M2M100 | | | | ✅ | | | | |
| NLLB | | | | ✅ | | | | |
| Electra | | ✅ | | | | | ✅ | |
| ALBERT | ✅ | ✅ | ✅ | | | | ✅ | ✅ |
| T5 | | | | ✅ | ✅ | ✅ | | ✅ |
| LongT5 | | | | ✅ | ✅ | | | |
| XLNet | ✅ | ✅ | ✅ | ✅ | | | ✅ | |
| Reformer | ✅ | | ✅ | ✅ | | | ✅ | |
| ProphetNet | | | | ✅ | ✅ | | | |
| Longformer | ✅ | ✅ | ✅ | | | | ✅ | |
| Pegasus | | | | | ✅ | | | |
</details>
## Getting started
This library relies on the [tch](https://github.com/LaurentMazare/tch-rs) crate for bindings to the C++ Libtorch API.
The libtorch library is required can be downloaded either automatically or manually. The following provides a reference on how to set-up your environment
to use these bindings, please refer to the [tch](https://github.com/LaurentMazare/tch-rs) for detailed information or support.
This library relies on the [tch](https://github.com/LaurentMazare/tch-rs) crate
for bindings to the C++ Libtorch API. The libtorch library is required and can be
downloaded either automatically or manually. The following provides a reference
on how to set up your environment to use these bindings; please refer to the
[tch](https://github.com/LaurentMazare/tch-rs) crate for detailed information or
support.
Furthermore, this library relies on a cache folder for downloading pre-trained models.
This cache location defaults to `~/.cache/.rustbert`, but can be changed by setting the `RUSTBERT_CACHE` environment variable. Note that the language models used by this library are in the order of the 100s of MBs to GBs.
Furthermore, this library relies on a cache folder for downloading pre-trained
models. This cache location defaults to `~/.cache/.rustbert`, but can be changed
by setting the `RUSTBERT_CACHE` environment variable. Note that the language
models used by this library are in the order of the 100s of MBs to GBs.
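For example, a minimal sketch that overrides the cache location programmatically,
assuming the `RUSTBERT_CACHE` variable is read before the first resource is
downloaded (setting it in the shell achieves the same effect):

```rust
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};

fn main() -> anyhow::Result<()> {
    // Redirect the model cache before any pretrained resource is fetched.
    std::env::set_var("RUSTBERT_CACHE", "/tmp/rustbert_cache");

    let qa_model = QuestionAnsweringModel::new(Default::default())?;
    let answers = qa_model.predict(
        &[QaInput {
            question: String::from("Where does Amy live ?"),
            context: String::from("Amy lives in Amsterdam"),
        }],
        1,  // top_k answers per question
        32, // batch size
    );
    println!("{answers:?}");
    Ok(())
}
```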
### Manual installation (recommended)
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.13.1`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.1%2Bcu117.zip` for a Linux version with CUDA11. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This
package requires `v2.2`: if this version is no longer available on the "get
started" page, the file should be accessible by modifying the target link,
for example
`https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip`
for a Linux version with CUDA12. **NOTE:** When using `rust-bert` as
dependency from [crates.io](https://crates.io), please check the required
`LIBTORCH` on the published package
[readme](https://crates.io/crates/rust-bert) as it may differ from the
version documented here (applying to the current repository version).
2. Extract the library to a location of your choice
3. Set the following environment variables
##### Linux:
```bash
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
```
##### Windows
```powershell
$Env:LIBTORCH = "X:\path\to\libtorch"
$Env:Path += ";X:\path\to\libtorch\lib"
```
#### macOS + Homebrew
```bash
brew install pytorch jq
export LIBTORCH=$(brew --cellar pytorch)/$(brew info --json pytorch | jq -r '.[0].installed[0].version')
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
```
### Automatic installation
Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu117`.
Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
Alternatively, you can let the `build` script automatically download the
`libtorch` library for you. The `download-libtorch` feature flag needs to be
enabled. The CPU version of libtorch will be downloaded by default. To download
a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to
`cu118`. Note that the libtorch library is large (order of several GBs for the
CUDA-enabled version) and the first build may therefore take several minutes to
complete.
### Verifying installation
Verify your installation (and linking with libtorch) by adding the `rust-bert`
dependency to your `Cargo.toml` or by cloning the rust-bert source and running
an example:
```bash
git clone git@github.com:guillaume-be/rust-bert.git
cd rust-bert
cargo run --example sentence_embeddings
```
## ONNX Support (Optional)
ONNX support can be enabled via the optional `onnx` feature. This crate then
leverages the [ort](https://github.com/pykeio/ort) crate with bindings to the
onnxruntime C++ library. We refer the user to the ort project page for further
installation instructions/support.
1. Enable the optional `onnx` feature. The `rust-bert` crate does not include
   any optional dependencies for `ort`; the end user should select the set of
   features that is adequate for pulling the required `onnxruntime` C++ library.
2. The current recommended installation is to use dynamic linking by pointing to
   an existing library location. Use the `load-dynamic` cargo feature for `ort`.
3. Set the `ORT_DYLIB_PATH` environment variable to point to the location of the
   downloaded onnxruntime library (`onnxruntime.dll`/`libonnxruntime.so`/`libonnxruntime.dylib`
   depending on the operating system). These can be downloaded from the
   [release page](https://github.com/microsoft/onnxruntime/releases) of the
   onnxruntime project.
Most architectures (including encoders, decoders and encoder-decoders) are
supported. The library aims at keeping compatibility with models exported using
the [Optimum](https://github.com/huggingface/optimum) library. A detailed guide
on how to export a Transformer model to ONNX using Optimum is available at
https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model
The resources used to create ONNX models are similar to those based on Pytorch,
replacing the Pytorch weights with the ONNX model file. Since ONNX models are less
flexible than their Pytorch counterparts in the handling of optional arguments,
exporting a decoder or encoder-decoder model to ONNX will usually result in
multiple files. These files are expected (but not all are necessary) for use in
this library as per the table below:
| Architecture | Encoder file | Decoder without past file | Decoder with past file |
| --------------------------- | ------------ | ------------------------- | ---------------------- |
| Encoder (e.g. BERT) | required | not used | not used |
| Decoder (e.g. GPT2) | not used | required | optional |
| Encoder-decoder (e.g. BART) | required | required | optional |
Note that the computational efficiency will drop when the `decoder with past`
file is optional but not provided, since the model will not use cached past keys
and values for the attention mechanism, leading to a high number of redundant
computations. The Optimum library offers export options to ensure such a
`decoder with past` model file is created. The base encoder and decoder model
architectures are available (and exposed for convenience) in the `encoder` and
`decoder` modules, respectively.
Generation models (pure decoder or encoder/decoder architectures) are available
in the `models` module. Most pipelines are available for ONNX model checkpoints,
including sequence classification, zero-shot classification, token
classification (including named entity recognition and part-of-speech tagging),
question answering, text generation, summarization and translation. These models
use the same configuration and tokenizer files as their Pytorch counterparts
when used in a pipeline. Examples leveraging ONNX models are given in the
`./examples` directory.
## Ready-to-use pipelines
Based on Hugging Face's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
**Disclaimer**
The contributors of this repository are not responsible for any generation from the 3rd party utilization of the pretrained systems proposed herein.
Based on Hugging Face's pipelines, ready to use end-to-end NLP pipelines are
available as part of this crate. The following capabilities are currently
available:
**Disclaimer** The contributors of this repository are not responsible for any
generation from the 3rd party utilization of the pretrained systems proposed
herein.
<details>
<summary> <b>1. Question Answering</b> </summary>
Extractive question answering from a given question and context. DistilBERT model fine-tuned on SQuAD (Stanford Question Answering Dataset)
Extractive question answering from a given question and context. DistilBERT
model fine-tuned on SQuAD (Stanford Question Answering Dataset)
```rust
let qa_model = QuestionAnsweringModel::new(Default::default())?;
@ -121,20 +236,27 @@ Extractive question answering from a given question and context. DistilBERT mode
```
Output:
```
[Answer { score: 0.9976, start: 13, end: 21, answer: "Amsterdam" }]
```
</details>
&nbsp;
&nbsp;
<details>
<summary> <b>2. Translation </b> </summary>
Translation pipeline supporting a broad range of source and target languages. Leverages two main architectures for translation tasks:
- Marian-based models, for specific source/target combinations
- M2M100 models allowing for direct translation between 100 languages (at a higher computational cost and lower performance for some selected languages)
Translation pipeline supporting a broad range of source and target languages.
Leverages two main architectures for translation tasks:
- Marian-based models, for specific source/target combinations
- M2M100 models allowing for direct translation between 100 languages (at a
higher computational cost and lower performance for some selected languages)
Marian-based pretrained models for the following language pairs are readily
available in the library - but the user can import any Pytorch-based model for
predictions
Marian-based pretrained models for the following language pairs are readily available in the library - but the user can import any Pytorch-based
model for predictions
- English <-> French
- English <-> Spanish
- English <-> Portuguese
@ -150,30 +272,36 @@ model for predictions
- English <-> Hindi
- French <-> German
For languages not supported by the proposed pretrained Marian models, the user can leverage a M2M100 model supporting direct translation between 100 languages (without intermediate English translation)
The full list of supported languages is available in the [crate documentation](https://docs.rs/rust-bert/latest/rust_bert/pipelines/translation/enum.Language.html)
For languages not supported by the proposed pretrained Marian models, the user
can leverage a M2M100 model supporting direct translation between 100 languages
(without intermediate English translation). The full list of supported languages
is available in the
[crate documentation](https://docs.rs/rust-bert/latest/rust_bert/pipelines/translation/enum.Language.html)
```rust
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
fn main() -> anyhow::Result<()> {
let model = TranslationModelBuilder::new()
.with_source_languages(vec![Language::English])
.with_target_languages(vec![Language::Spanish, Language::French, Language::Italian])
.create_model()?;
let input_text = "This is a sentence to be translated";
let output = model.translate(&[input_text], None, Language::French)?;
for sentence in output {
println!("{}", sentence);
}
Ok(())
}
```
```rust
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
fn main() -> anyhow::Result<()> {
let model = TranslationModelBuilder::new()
.with_source_languages(vec![Language::English])
.with_target_languages(vec![Language::Spanish, Language::French, Language::Italian])
.create_model()?;
let input_text = "This is a sentence to be translated";
let output = model.translate(&[input_text], None, Language::Spanish)?;
for sentence in output {
println!("{}", sentence);
}
Ok(())
}
```
Output:
```
Il s'agit d'une phrase à traduire
```
</details>
&nbsp;
&nbsp;
<details>
<summary> <b>3. Summarization </b> </summary>
@ -206,26 +334,35 @@ about exoplanets like K2-18b."];
let output = summarization_model.summarize(&input);
```
(example from: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
(example from:
[WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
Output:
```
"Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
This is the first such discovery in a planet in its star's habitable zone.
The planet is not too hot and not too cold for liquid water to exist."
```
</details>
&nbsp;
&nbsp;
<details>
<summary> <b>4. Dialogue Model </b> </summary>
Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
This pipeline allows the generation of single or multi-turn conversations between a human and a model.
Conversation model based on Microsoft's
[DialoGPT](https://github.com/microsoft/DialoGPT). This pipeline allows the
generation of single or multi-turn conversations between a human and a model.
DialoGPT's page states that
> The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
> under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
The model uses a `ConversationManager` to keep track of active conversations and generate responses to them.
> The human evaluation results indicate that the response generated from
> DialoGPT is comparable to human response quality under a single-turn
> conversation Turing test.
> ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
The model uses a `ConversationManager` to keep track of active conversations and
generate responses to them.
```rust
use rust_bert::pipelines::conversation::{ConversationModel, ConversationManager};
@ -236,19 +373,24 @@ let mut conversation_manager = ConversationManager::new();
let conversation_id = conversation_manager.create("Going to the movies tonight - any suggestions?");
let output = conversation_model.generate_responses(&mut conversation_manager);
```
Example output:
```
"The Big Lebowski."
```
</details>
&nbsp;
&nbsp;
<details>
<summary> <b>5. Natural Language Generation </b> </summary>
Generate language based on a prompt. GPT2 and GPT are available as base models.
Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, the unknown token otherwise.
This may impact the results, it is recommended to submit prompts of similar length for best results
Includes techniques such as beam search, top-k and nucleus sampling, temperature
setting and repetition penalty. Supports batch generation of sentences from
several prompts. Sequences will be left-padded with the model's padding token if
present, or with the unknown token otherwise. This may impact the results; it is
recommended to submit prompts of similar length for best results.
```rust
let model = GPT2Generator::new(Default::default())?;
@ -263,7 +405,9 @@ This may impact the results, it is recommended to submit prompts of similar leng
let output = model.generate(Some(&[input_context_1, input_context_2]), generate_options);
```
Example output:
```
[
"The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year"
@ -274,12 +418,15 @@ Example output:
"The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
]
```
</details>
&nbsp;
&nbsp;
<details>
<summary> <b>6. Zero-shot classification </b> </summary>
Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference.
Performs zero-shot classification on input sentences with provided labels using
a model fine-tuned for Natural Language Inference.
```rust
let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
@ -296,18 +443,22 @@ Performs zero-shot classification on input sentences with provided labels using
```
Output:
```
[
[ Label { "politics", score: 0.972 }, Label { "public health", score: 0.032 }, Label {"economics", score: 0.006 }, Label {"sports", score: 0.004 } ],
[ Label { "politics", score: 0.975 }, Label { "public health", score: 0.0818 }, Label {"economics", score: 0.852 }, Label {"sports", score: 0.001 } ],
]
```
</details>
&nbsp;
&nbsp;
<details>
<summary> <b>7. Sentiment analysis </b> </summary>
Predicts the binary sentiment for a sentence. DistilBERT model fine-tuned on SST-2.
Predicts the binary sentiment for a sentence. DistilBERT model fine-tuned on
SST-2.
```rust
let sentiment_classifier = SentimentModel::new(Default::default())?;
@ -319,9 +470,11 @@ Predicts the binary sentiment for a sentence. DistilBERT model fine-tuned on SST
let output = sentiment_classifier.predict(&input);
```
(Example courtesy of [IMDb](http://www.imdb.com))
Output:
```
[
Sentiment { polarity: Positive, score: 0.9981985493795946 },
@ -329,13 +482,17 @@ Output:
Sentiment { polarity: Positive, score: 0.9997248985164333 }
]
```
</details>
&nbsp;
&nbsp;
<details>
<summary> <b>8. Named Entity Recognition </b> </summary>
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model fine-tuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
Extracts entities (Person, Location, Organization, Miscellaneous) from text.
BERT cased large model fine-tuned on CoNLL03, contributed by the
[MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
Models are currently available for English, German, Spanish and Dutch.
```rust
let ner_model = NERModel::new(Default::default())?;
@ -346,7 +503,9 @@ Models are currently available for English, German, Spanish and Dutch.
let output = ner_model.predict(&input);
```
Output:
```
[
[
@ -359,8 +518,9 @@ Output:
]
]
```
</details>
&nbsp;
&nbsp;
<details>
<summary> <b>9. Keywords/keyphrases extraction</b> </summary>
@ -381,7 +541,9 @@ fn main() -> anyhow::Result<()> {
let output = keyword_extraction_model.predict(&[input])?;
}
```
Output:
```
"rust" - 0.50910604
"programming" - 0.35731024
@ -389,12 +551,14 @@ Output:
"concurrent" - 0.31229728
"program" - 0.29115444
```
</details>
&nbsp;
&nbsp;
<details>
<summary> <b>10. Part of Speech tagging </b> </summary>
Extracts Part of Speech tags (Noun, Verb, Adjective...) from text.
```rust
let pos_model = POSModel::new(Default::default())?;
@ -402,7 +566,9 @@ Extracts Part of Speech tags (Noun, Verb, Adjective...) from text.
let output = pos_model.predict(&input);
```
Output:
```
[
Entity { word: "My", score: 0.1560, label: "PRP" }
@ -411,12 +577,15 @@ Output:
Entity { word: "Bob", score: 0.7460, label: "NNP" }
]
```
</details>
&nbsp;
&nbsp;
<details>
<summary> <b>11. Sentence embeddings </b> </summary>
Generate sentence embeddings (vector representation). These can be used for applications including dense information retrieval.
Generate sentence embeddings (vector representation). These can be used for
applications including dense information retrieval.
```rust
let model = SentenceEmbeddingsBuilder::remote(
SentenceEmbeddingsModelType::AllMiniLmL12V2
@ -427,21 +596,25 @@ Generate sentence embeddings (vector representation). These can be used for appl
"each sentence is converted"
];
let output = model.predict(&sentences);
let output = model.encode(&sentences)?;
```
Output:
```
[
[-0.000202666, 0.08148022, 0.03136178, 0.002920636 ...],
[0.064757116, 0.048519745, -0.01786038, -0.0479775 ...]
]
```
</details>
&nbsp;
&nbsp;
<details>
<summary> <b>12. Masked Language Model </b> </summary>
Predict masked words in input sentences.
```rust
let model = MaskedLanguageModel::new(Default::default())?;
@ -452,7 +625,9 @@ Predict masked words in input sentences.
let output = model.predict(&sentences);
```
Output:
```
[
[MaskedToken { text: "college", id: 2267, score: 8.091}],
@ -462,29 +637,61 @@ Output:
]
]
```
</details>
## Benchmarks
For simple pipelines (sequence classification, tokens classification, question answering) the performance between Python and Rust is expected to be comparable. This is because the most expensive part of these pipeline is the language model itself, sharing a common implementation in the Torch backend. The [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/) provides a benchmarks section covering all pipelines.
For simple pipelines (sequence classification, token classification, question
answering) the performance between Python and Rust is expected to be comparable.
This is because the most expensive part of these pipelines is the language model
itself, sharing a common implementation in the Torch backend. The paper
[End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/)
provides a benchmarks section covering all pipelines.
For text generation tasks (summarization, translation, conversation, free text generation), significant benefits can be expected (up to 2 to 4 times faster processing depending on the input and application). The article [Accelerating text generation with Rust](https://guillaume-be.github.io/2020-11-21/generation_benchmarks) focuses on these text generation applications and provides more details on the performance comparison to Python.
For text generation tasks (summarization, translation, conversation, free text
generation), significant benefits can be expected (up to 2 to 4 times faster
processing depending on the input and application). The article
[Accelerating text generation with Rust](https://guillaume-be.github.io/2020-11-21/generation_benchmarks)
focuses on these text generation applications and provides more details on the
performance comparison to Python.
## Loading pretrained and custom model weights
The base model and task-specific heads are also available for users looking to expose their own transformer based models.
Examples on how to prepare the date using a native tokenizers Rust library are available in `./examples` for BERT, DistilBERT, RoBERTa, GPT, GPT2 and BART.
Note that when importing models from Pytorch, the convention for parameters naming needs to be aligned with the Rust schema. Loading of the pre-trained weights will fail if any of the model parameters weights cannot be found in the weight files.
If this quality check is to be skipped, an alternative method `load_partial` can be invoked from the variables store.
The base model and task-specific heads are also available for users looking to
expose their own transformer-based models. Examples on how to prepare the data
using the native Rust tokenizers library are available in `./examples` for BERT,
DistilBERT, RoBERTa, GPT, GPT2 and BART. Note that when importing models from
Pytorch, the convention for parameter naming needs to be aligned with the Rust
schema. Loading of the pre-trained weights will fail if any of the model
parameter weights cannot be found in the weight files. If this quality check is
to be skipped, an alternative method `load_partial` can be invoked from the
variables store.
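A minimal sketch of this relaxed loading path, assuming the weights are read
through a `tch` `VarStore` (model construction is elided, and the exact return
value of `load_partial` is assumed to be the list of missing variable names):

```rust
use tch::{nn, Device, TchError};

fn main() -> Result<(), TchError> {
    let mut vs = nn::VarStore::new(Device::Cpu);

    // The rust-bert model (e.g. a BertModel) would normally be built against
    // `vs.root()` here, registering all of the parameters it expects.
    let _placeholder = vs.root().zeros("placeholder_weight", &[768]);

    // `vs.load(...)` fails if any registered parameter is missing from the
    // weight file; `load_partial` skips that check and reports what was missing.
    let missing = vs.load_partial("path/to/rust_model.ot")?;
    println!("variables not found in the weight file: {missing:?}");
    Ok(())
}
```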
Pretrained models are available on Hugging face's [model hub](https://huggingface.co/models?filter=rust) and can be loaded using `RemoteResources` defined in this library.
A conversion utility script is included in `./utils` to convert Pytorch weights to a set of weights compatible with this library. This script requires Python and `torch` to be set-up, and can be used as follows:
`python ./utils/convert_model.py path/to/pytorch_model.bin` where `path/to/pytorch_model.bin` is the location of the original Pytorch weights.
Pretrained models are available on Hugging Face's
[model hub](https://huggingface.co/models?filter=rust) and can be loaded using
`RemoteResources` defined in this library.
A conversion utility script is included in `./utils` to convert Pytorch weights
to a set of weights compatible with this library. This script requires Python
and `torch` to be set up, and can be used as follows:
`python ./utils/convert_model.py path/to/pytorch_model.bin`, where
`path/to/pytorch_model.bin` is the location of the original Pytorch weights.
```bash
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
python utils/convert_model.py path/to/pytorch_model.bin
```
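Once converted, the resulting weights can be loaded through `LocalResource`; the
sketch below assumes placeholder paths for the converted `rust_model.ot` and the
matching config, vocab and merges files:

```rust
use std::path::PathBuf;

use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::LocalResource;

fn main() -> anyhow::Result<()> {
    // Placeholder paths pointing to the converted weights and tokenizer files.
    let weights = LocalResource { local_path: PathBuf::from("path/to/rust_model.ot") };
    let config = LocalResource { local_path: PathBuf::from("path/to/config.json") };
    let vocab = LocalResource { local_path: PathBuf::from("path/to/vocab.json") };
    let merges = LocalResource { local_path: PathBuf::from("path/to/merges.txt") };

    let generate_config = TextGenerationConfig {
        model_type: ModelType::GPT2,
        model_resource: ModelResource::Torch(Box::new(weights)),
        config_resource: Box::new(config),
        vocab_resource: Box::new(vocab),
        merges_resource: Some(Box::new(merges)),
        max_length: Some(30),
        ..Default::default()
    };
    let model = TextGenerationModel::new(generate_config)?;
    let output = model.generate(&["The dog"], None)?;
    for sentence in output {
        println!("{sentence}");
    }
    Ok(())
}
```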
## Citation
If you use `rust-bert` for your work, please cite [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/):
If you use `rust-bert` for your work, please cite
[End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/):
```bibtex
@inproceedings{becquin-2020-end,
title = "End-to-end {NLP} Pipelines in Rust",
@ -499,6 +706,7 @@ If you use `rust-bert` for your work, please cite [End-to-end NLP Pipelines in R
## Acknowledgements
Thank you to [Hugging Face](https://huggingface.co) for hosting a set of weights compatible with this Rust library.
The list of ready-to-use pretrained models is listed at [https://huggingface.co/models?filter=rust](https://huggingface.co/models?filter=rust).
Thank you to [Hugging Face](https://huggingface.co) for hosting a set of weights
compatible with this Rust library. The list of ready-to-use pretrained models is
listed at
[https://huggingface.co/models?filter=rust](https://huggingface.co/models?filter=rust).

View File

@ -5,7 +5,7 @@ use criterion::{black_box, Criterion};
use rust_bert::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;
use std::time::{Duration, Instant};
@ -14,7 +14,9 @@ use tch::Device;
fn create_text_generation_model() -> TextGenerationModel {
let config = TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
Gpt2ModelResources::GPT2,
))),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
@ -35,6 +37,7 @@ fn create_text_generation_model() -> TextGenerationModel {
diversity_penalty: None,
num_return_sequences: 5,
device: Device::cuda_if_available(),
kind: None,
};
TextGenerationModel::new(config).unwrap()
}
@ -50,10 +53,6 @@ fn generation_forward_pass(iters: u64, model: &TextGenerationModel, data: &[&str
}
fn bench_generation(c: &mut Criterion) {
// Set-up summarization model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let model = create_text_generation_model();
// Define input

View File

@ -3,7 +3,7 @@ extern crate criterion;
use criterion::{black_box, Criterion};
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::question_answering::{
squad_processor, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
@ -17,7 +17,9 @@ static BATCH_SIZE: usize = 64;
fn create_qa_model() -> QuestionAnsweringModel {
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT_QA,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
None, //merges resource only relevant with ModelType::Roberta
@ -52,7 +54,9 @@ fn qa_load_model(iters: u64) -> Duration {
let start = Instant::now();
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT_QA,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
None, //merges resource only relevant with ModelType::Roberta
@ -69,9 +73,7 @@ fn qa_load_model(iters: u64) -> Duration {
fn bench_squad(c: &mut Criterion) {
// Set-up QA model
let model = create_qa_model();
unsafe {
torch_sys::dummy_cuda_dependency();
}
// Define input
let mut squad_path = PathBuf::from(env::var("squad_dataset")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));

View File

@ -79,9 +79,7 @@ fn sst2_load_model(iters: u64) -> Duration {
fn bench_sst2(c: &mut Criterion) {
// Set-up classifier
let model = create_sentiment_model();
unsafe {
torch_sys::dummy_cuda_dependency();
}
// Define input
let mut sst2_path = PathBuf::from(env::var("SST2_PATH").expect(
"Please set the \"SST2_PATH\" environment variable pointing to the SST2 dataset folder",

View File

@ -40,9 +40,6 @@ fn summarization_load_model(iters: u64) -> Duration {
fn bench_squad(c: &mut Criterion) {
// Set-up summarization model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let model = create_summarization_model();
// Define input

View File

@ -17,12 +17,8 @@ fn matrix_multiply(iters: u64, input: &Tensor, weights: &Tensor) -> Duration {
}
fn bench_tensor_ops(c: &mut Criterion) {
// Set-up summarization model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let input = Tensor::rand(&[32, 128, 512], (Kind::Float, Device::cuda_if_available()));
let weights = Tensor::rand(&[512, 512], (Kind::Float, Device::cuda_if_available()));
let input = Tensor::rand([32, 128, 512], (Kind::Float, Device::cuda_if_available()));
let weights = Tensor::rand([512, 512], (Kind::Float, Device::cuda_if_available()));
let _ = &input.matmul(&weights);
c.bench_function("Matrix multiply ", |b| {

View File

@ -14,9 +14,6 @@ fn create_model() -> TokenClassificationModel {
fn bench_token_classification_predict(c: &mut Criterion) {
// Set-up model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let model = create_model();
// Define input

View File

@ -73,9 +73,6 @@ fn translation_load_model(iters: u64) -> Duration {
fn bench_squad(c: &mut Criterion) {
// Set-up translation model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let model = create_translation_model();
// Define input

29
build.rs Normal file
View File

@ -0,0 +1,29 @@
// Copyright 2023 Laurent Mazare
// https://github.com/LaurentMazare/diffusers-rs/blob/main/build.rs
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
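// This build script propagates the libtorch library location exposed by the
// `tch`/`torch-sys` build (the `DEP_TCH_LIBTORCH_LIB` variable) as an rpath,
// and forces the torch shared library to be linked even when no symbol is
// referenced directly, so that binaries built from this crate can locate
// libtorch at runtime on Linux and Windows.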
fn main() {
let os = std::env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
match os.as_str() {
"linux" | "windows" => {
if let Some(lib_path) = std::env::var_os("DEP_TCH_LIBTORCH_LIB") {
println!(
"cargo:rustc-link-arg=-Wl,-rpath={}",
lib_path.to_string_lossy()
);
}
println!("cargo:rustc-link-arg=-Wl,--no-as-needed");
println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries");
println!("cargo:rustc-link-arg=-ltorch");
}
_ => {}
}
}

View File

@ -0,0 +1,98 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use std::sync::{Arc, RwLock};
use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::pipelines::common::ModelResource;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{BufferResource, RemoteResource, ResourceProvider};
use tch::Device;
fn main() -> anyhow::Result<()> {
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \
habitable zone not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \
used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \
passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, \
weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \
contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \
and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \
but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \
\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \
said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
said Ryan Cloutier of the Harvard-Smithsonian Center for Astrophysics, who was not one of either study's authors. \
\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
a potentially habitable planet, but further observations will be required to say for sure. \" \
K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
let weights = Arc::new(RwLock::new(get_weights()?));
let summarization_model = SummarizationModel::new(config(Device::Cpu, weights.clone()))?;
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input)?;
for sentence in output {
println!("{sentence}");
}
let summarization_model =
SummarizationModel::new(config(Device::cuda_if_available(), weights))?;
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input)?;
for sentence in output {
println!("{sentence}");
}
Ok(())
}
fn get_weights() -> anyhow::Result<Vec<u8>, anyhow::Error> {
let model_resource = RemoteResource::from_pretrained(BartModelResources::DISTILBART_CNN_6_6);
Ok(std::fs::read(model_resource.get_local_path()?)?)
}
fn config(device: Device, model_data: Arc<RwLock<Vec<u8>>>) -> SummarizationConfig {
let config_resource = Box::new(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6,
));
let vocab_resource = Box::new(RemoteResource::from_pretrained(
BartVocabResources::DISTILBART_CNN_6_6,
));
let merges_resource = Box::new(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6,
));
let model_resource = ModelResource::Torch(Box::new(BufferResource { data: model_data }));
SummarizationConfig {
model_resource,
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
num_beams: 1,
length_penalty: 1.0,
min_length: 56,
max_length: Some(142),
device,
..Default::default()
}
}

View File

@ -12,7 +12,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::pipelines::sequence_classification::{
SequenceClassificationConfig, SequenceClassificationModel,
@ -26,7 +26,9 @@ fn main() -> anyhow::Result<()> {
// Language identification
let sequence_classification_config = SequenceClassificationConfig::new(
ModelType::Roberta,
RemoteResource::from_pretrained(RobertaModelResources::CODEBERTA_LANGUAGE_ID),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
RobertaModelResources::CODEBERTA_LANGUAGE_ID,
))),
RemoteResource::from_pretrained(RobertaConfigResources::CODEBERTA_LANGUAGE_ID),
RemoteResource::from_pretrained(RobertaVocabResources::CODEBERTA_LANGUAGE_ID),
Some(RemoteResource::from_pretrained(
@ -56,7 +58,9 @@ fn main() -> anyhow::Result<()> {
// Masked language model
let config = MaskedLanguageConfig::new(
ModelType::Roberta,
RemoteResource::from_pretrained(RobertaModelResources::CODEBERT_MLM),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
RobertaModelResources::CODEBERT_MLM,
))),
RemoteResource::from_pretrained(RobertaConfigResources::CODEBERT_MLM),
RemoteResource::from_pretrained(RobertaVocabResources::CODEBERT_MLM),
Some(RemoteResource::from_pretrained(

View File

@ -30,7 +30,7 @@ fn main() -> anyhow::Result<()> {
let input_context = "The dog";
// let second_input_context = "The cat was";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;
for sentence in output {
println!("{sentence:?}");

View File

@ -0,0 +1,61 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pipelines::common::{ModelType, TokenizerOption};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use std::fs::File;
use std::io::Write;
use tempfile::TempDir;
fn main() -> anyhow::Result<()> {
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
max_length: Some(30),
do_sample: false,
num_beams: 1,
temperature: 1.0,
num_return_sequences: 1,
..Default::default()
};
// Create tokenizer
let tmp_dir = TempDir::new()?;
let special_token_map_path = tmp_dir.path().join("special_token_map.json");
let mut tmp_file = File::create(&special_token_map_path)?;
writeln!(
tmp_file,
r#"{{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}}"#
)?;
let tokenizer_path = RemoteResource::from_pretrained((
"gpt2/tokenizer",
"https://huggingface.co/gpt2/resolve/main/tokenizer.json",
))
.get_local_path()?;
let tokenizer =
TokenizerOption::from_hf_tokenizer_file(tokenizer_path, special_token_map_path)?;
let model = TextGenerationModel::new_with_tokenizer(generate_config, tokenizer)?;
let input_context = "The dog";
// let second_input_context = "The cat was";
let output = model.generate(&[input_context], None)?;
for sentence in output {
println!("{sentence:?}");
}
Ok(())
}

View File

@ -17,7 +17,7 @@ extern crate anyhow;
use rust_bert::gpt_neo::{
GptNeoConfigResources, GptNeoMergesResources, GptNeoModelResources, GptNeoVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
));
let generate_config = TextGenerationConfig {
model_type: ModelType::GPTNeo,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -53,11 +53,11 @@ fn main() -> anyhow::Result<()> {
};
let mut model = TextGenerationModel::new(generate_config)?;
model.set_device(Device::cuda_if_available());
model.set_device(Device::cuda_if_available())?;
let input_context_1 = "It was a very nice and sunny";
let input_context_2 = "It was a gloom winter night, and";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;
for sentence in output {
println!("{sentence}");

View File

@ -1,7 +1,7 @@
use std::path::PathBuf;
use rust_bert::gpt_j::{GptJConfigResources, GptJMergesResources, GptJVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{LocalResource, RemoteResource};
use tch::Device;
@ -44,7 +44,6 @@ use tch::Device;
/// ```
///
/// [gpt-j-6B-float16]: https://huggingface.co/EleutherAI/gpt-j-6B/tree/float16
///
fn main() -> anyhow::Result<()> {
// Resources paths
@ -68,7 +67,7 @@ fn main() -> anyhow::Result<()> {
let generation_config = TextGenerationConfig {
model_type: ModelType::GPTJ,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -90,7 +89,7 @@ fn main() -> anyhow::Result<()> {
"It was a very nice and sunny",
"It was a gloom winter night, and",
];
let output = model.generate(&prompts, None);
let output = model.generate(&prompts, None)?;
assert_eq!(output.len(), 2);
assert_eq!(output[0], "It was a very nice and sunny day, and I was sitting in the garden of my house, enjoying the sun and the fresh air. I was thinking");

View File

@ -14,7 +14,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::reformer::{
ReformerConfigResources, ReformerModelResources, ReformerVocabResources,
@ -35,7 +35,7 @@ fn main() -> anyhow::Result<()> {
));
let generate_config = TextGenerationConfig {
model_type: ModelType::Reformer,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: None,
@ -52,7 +52,7 @@ fn main() -> anyhow::Result<()> {
let input_context_1 = "The really great men must, I think,";
let input_context_2 = "It was a gloom winter night, and";
let output = model.generate(&[input_context_1, input_context_2], None);
let output = model.generate(&[input_context_1, input_context_2], None)?;
for sentence in output {
println!("{sentence}");

View File

@ -14,7 +14,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
let generate_config = TextGenerationConfig {
model_type: ModelType::XLNet,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: None,
@ -47,7 +47,7 @@ fn main() -> anyhow::Result<()> {
let model = TextGenerationModel::new(generate_config)?;
let input_context = "Once upon a time,";
let output = model.generate(&[input_context], None);
let output = model.generate(&[input_context], None)?;
for sentence in output {
println!("{sentence}");

View File

@ -12,14 +12,16 @@
extern crate anyhow;
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
// Set-up model
let config = MaskedLanguageConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT),
RemoteResource::from_pretrained(BertVocabResources::BERT),
None,

View File

@ -4,7 +4,7 @@ use rust_bert::deberta::{
DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::resources::{load_weights, RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Kind, Tensor};
@ -27,7 +27,6 @@ fn main() -> anyhow::Result<()> {
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
@ -39,7 +38,7 @@ fn main() -> anyhow::Result<()> {
)?;
let config = DebertaConfig::from_file(config_path);
let model = DebertaForSequenceClassification::new(vs.root(), &config)?;
vs.load(weights_path)?;
load_weights(&model_resource, &mut vs, None, device)?;
// Define input
let input = [("I love you.", "I like you.")];
@ -63,7 +62,7 @@ fn main() -> anyhow::Result<()> {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

View File

@ -0,0 +1,36 @@
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let masked_lm = MaskedLanguageModel::new(MaskedLanguageConfig::new(
ModelType::Bert,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/model.onnx",
"onnx-bert-base-uncased-for-masked-lm",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/config.json",
"onnx-bert-base-uncased-for-masked-lm",
),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/vocab.txt",
"onnx-bert-base-uncased-for-masked-lm",
),
None,
false,
None,
None,
Some(String::from("<mask>")),
))?;
let input = [
"Hello I am a <mask> student",
"Paris is the <mask> of France. It is <mask> in Europe.",
];
let output = masked_lm.predict(input)?;
println!("{:?}", output);
Ok(())
}

View File

@ -0,0 +1,40 @@
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let qa_model = QuestionAnsweringModel::new(QuestionAnsweringConfig::new(
ModelType::Roberta,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/model.onnx",
"onnx-roberta-base-squad2",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/config.json",
"onnx-roberta-base-squad2",
),
RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/vocab.json",
"onnx-roberta-base-squad2",
),
Some(RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/merges.txt",
"onnx-roberta-base-squad2",
)),
false,
None,
None,
))?;
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let output = qa_model.predict(&[qa_input], 1, 32);
println!("{:?}", output);
Ok(())
}

View File

@ -0,0 +1,37 @@
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::sentiment::SentimentModel;
use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let classification_model = SentimentModel::new(SequenceClassificationConfig::new(
ModelType::DistilBert,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/model.onnx",
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json",
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
),
RemoteResource::new(
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/vocab.txt",
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
),
None,
true,
None,
None,
))?;
let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
let output = classification_model.predict(input);
println!("{:?}", output);
Ok(())
}

View File

@ -0,0 +1,42 @@
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let text_generation_model = TextGenerationModel::new(TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource: ModelResource::ONNX(ONNXModelResources {
encoder_resource: None,
decoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/gpt2/resolve/main/decoder_model.onnx",
"onnx-gpt2",
))),
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/gpt2/resolve/main/decoder_with_past_model.onnx",
"onnx-gpt2",
))),
}),
config_resource: Box::new(RemoteResource::new(
"https://huggingface.co/optimum/gpt2/resolve/main/config.json",
"onnx-gpt2",
)),
vocab_resource: Box::new(RemoteResource::new(
"https://huggingface.co/gpt2/resolve/main/vocab.json",
"onnx-gpt2",
)),
merges_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/gpt2/resolve/main/merges.txt",
"onnx-gpt2",
))),
max_length: Some(30),
do_sample: false,
num_beams: 1,
temperature: 1.0,
num_return_sequences: 1,
..Default::default()
})?;
let prompts = ["It was a very nice and sunny"];
let output = text_generation_model.generate(&prompts, None);
println!("{:?}", output);
Ok(())
}

View File

@ -0,0 +1,36 @@
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig,
};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let token_classification_model = NERModel::new(TokenClassificationConfig::new(
ModelType::Bert,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/bert-base-NER/resolve/main/model.onnx",
"onnx-bert-base-NER",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-NER/resolve/main/config.json",
"onnx-bert-base-NER",
),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-NER/resolve/main/vocab.txt",
"onnx-bert-base-NER",
),
None,
false,
None,
None,
LabelAggregationOption::First,
))?;
let input = ["Asked John Smith about Acme Corp", "Let's go to New York!"];
let output = token_classification_model.predict_full_entities(&input);
println!("{:?}", output);
Ok(())
}

View File

@ -0,0 +1,63 @@
use rust_bert::m2m_100::{M2M100SourceLanguages, M2M100TargetLanguages};
use tch::Device;
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let translation_model = TranslationModel::new(TranslationConfig::new(
ModelType::M2M100,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/encoder_model.onnx",
"onnx-m2m100_418M",
))),
decoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_model.onnx",
"onnx-m2m100_418M",
))),
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_with_past_model.onnx",
"onnx-m2m100_418M",
))),
}),
RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/config.json",
"onnx-m2m100_418M",
),
RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/vocab.json",
"onnx-m2m100_418M",
),
Some(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/sentencepiece.bpe.model",
"onnx-m2m100_418M",
)),
M2M100SourceLanguages::M2M100_418M,
M2M100TargetLanguages::M2M100_418M,
Device::cuda_if_available(),
))?;
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(translation_model.translate(
&[source_sentence],
Language::English,
Language::French,
)?);
outputs.extend(translation_model.translate(
&[source_sentence],
Language::English,
Language::Spanish,
)?);
outputs.extend(translation_model.translate(
&[source_sentence],
Language::English,
Language::Hindi,
)?);
println!("{:?}", outputs);
Ok(())
}

View File

@ -13,7 +13,7 @@
extern crate anyhow;
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
@ -23,7 +23,9 @@ fn main() -> anyhow::Result<()> {
// Set-up Question Answering model
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT_QA,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
None, //merges resource only relevant with ModelType::Roberta

View File

@ -16,7 +16,7 @@ use rust_bert::longformer::{
LongformerConfigResources, LongformerMergesResources, LongformerModelResources,
LongformerVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
@ -26,7 +26,9 @@ fn main() -> anyhow::Result<()> {
// Set-up Question Answering model
let config = QuestionAnsweringConfig::new(
ModelType::Longformer,
RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_SQUAD1),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
LongformerModelResources::LONGFORMER_BASE_SQUAD1,
))),
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_SQUAD1),
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_SQUAD1),
Some(RemoteResource::from_pretrained(

View File

@ -13,7 +13,7 @@
extern crate anyhow;
use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel};
use rust_bert::resources::RemoteResource;
@ -25,9 +25,9 @@ fn main() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained(
FNetVocabResources::BASE_SST2,
));
let model_resource = Box::new(RemoteResource::from_pretrained(
let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
FNetModelResources::BASE_SST2,
));
)));
let sentiment_config = SentimentConfig {
model_type: ModelType::FNet,

View File

@ -15,6 +15,7 @@ extern crate anyhow;
use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::pipelines::common::ModelResource;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -34,7 +35,7 @@ fn main() -> anyhow::Result<()> {
));
let summarization_config = SummarizationConfig {
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -71,7 +72,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}

View File

@ -13,7 +13,7 @@
extern crate anyhow;
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -31,7 +31,7 @@ fn main() -> anyhow::Result<()> {
let summarization_config = SummarizationConfig {
model_type: ModelType::Pegasus,
model_resource: weights_resource,
model_resource: ModelResource::Torch(weights_resource),
config_resource,
vocab_resource,
merges_resource: None,
@ -66,7 +66,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}

View File

@ -12,7 +12,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::prophetnet::{
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
let summarization_config = SummarizationConfig {
model_type: ModelType::ProphetNet,
model_resource: weights_resource,
model_resource: ModelResource::Torch(weights_resource),
config_resource,
vocab_resource,
merges_resource: None,
@ -68,7 +68,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}

View File

@ -12,7 +12,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::RemoteResource;
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
@ -24,7 +24,7 @@ fn main() -> anyhow::Result<()> {
let summarization_config = SummarizationConfig::new(
ModelType::T5,
weights_resource,
ModelResource::Torch(Box::new(weights_resource)),
config_resource,
vocab_resource,
None,
@ -54,7 +54,7 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
let _output = summarization_model.summarize(&input)?;
for sentence in _output {
println!("{sentence}");
}

View File

@ -11,7 +11,7 @@
// limitations under the License.
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig,
@ -22,7 +22,9 @@ fn main() -> anyhow::Result<()> {
// Load a configuration
let config = TokenClassificationConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT_NER),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
None, //merges resource only relevant with ModelType::Roberta

View File

@ -16,7 +16,7 @@ use rust_bert::m2m_100::{
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages,
M2M100TargetLanguages, M2M100VocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -32,7 +32,7 @@ fn main() -> anyhow::Result<()> {
let translation_config = TranslationConfig::new(
ModelType::M2M100,
model_resource,
ModelResource::Torch(Box::new(model_resource)),
config_resource,
vocab_resource,
Some(merges_resource),

View File

@ -17,7 +17,7 @@ use rust_bert::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
let translation_config = TranslationConfig::new(
ModelType::Marian,
model_resource,
ModelResource::Torch(Box::new(model_resource)),
config_resource,
vocab_resource,
Some(merges_resource),

View File

@ -16,7 +16,7 @@ use rust_bert::mbart::{
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
MBartVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -32,7 +32,7 @@ fn main() -> anyhow::Result<()> {
let translation_config = TranslationConfig::new(
ModelType::MBart,
model_resource,
ModelResource::Torch(Box::new(model_resource)),
config_resource,
vocab_resource,
None,

View File

@ -12,7 +12,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
let translation_config = TranslationConfig::new(
ModelType::T5,
model_resource,
ModelResource::Torch(Box::new(model_resource)),
config_resource,
vocab_resource,
None,

View File

@ -1,3 +0,0 @@
torch == 1.13.1
requests == 2.25.1
numpy == 1.23.4

View File

@ -43,7 +43,7 @@ impl XDropout {
impl ModuleT for XDropout {
fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
if train {
let mask = (Tensor::ones(&[1], (input.kind(), input.device()))
let mask = (Tensor::ones([1], (input.kind(), input.device()))
- input
.empty_like()
.bernoulli_float_(1_f64 - self.dropout_prob))
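
The hunk above is one instance of an API migration that recurs throughout this commit: with the bump to `tch` 0.15 (torch 2.2), shape arguments are passed by value instead of as slices, `Tensor::of_slice` becomes `Tensor::from_slice`, and `einsum` takes an explicit optional path argument. A minimal stand-alone sketch of the pattern (illustrative only, not part of the diff):

```rust
use tch::{Device, Kind, Tensor};

fn tch_migration_sketch() {
    // Old style (pre-bump): Tensor::ones(&[2, 3], ...), Tensor::of_slice(&data).
    // New style used throughout this commit: shapes passed by value.
    let ones = Tensor::ones([2i64, 3], (Kind::Float, Device::Cpu));
    let data = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
    // Shape-taking methods such as reshape/expand/permute follow the same rule.
    let _reshaped = ones.reshape([3i64, 2]);
    // einsum now requires the optional contraction path to be spelled out.
    let _dot = Tensor::einsum("i,i->", &[&data, &data], None::<i64>);
}
```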

View File

@ -1,3 +1,7 @@
#[cfg(feature = "onnx")]
use ndarray::ShapeError;
#[cfg(feature = "onnx")]
use ort::OrtError;
use rust_tokenizers::error::TokenizerError;
use tch::TchError;
use thiserror::Error;
@ -22,6 +26,17 @@ pub enum RustBertError {
#[error("Value error: {0}")]
ValueError(String),
#[error("Value error: {0}")]
#[cfg(feature = "onnx")]
OrtError(String),
#[error("Value error: {0}")]
#[cfg(feature = "onnx")]
NdArrayError(String),
#[error("Unsupported operation")]
UnsupportedError,
}
impl From<std::io::Error> for RustBertError {
@ -41,3 +56,16 @@ impl From<TchError> for RustBertError {
RustBertError::TchError(error.to_string())
}
}
#[cfg(feature = "onnx")]
impl From<OrtError> for RustBertError {
fn from(error: OrtError) -> Self {
RustBertError::OrtError(error.to_string())
}
}
#[cfg(feature = "onnx")]
impl From<ShapeError> for RustBertError {
fn from(error: ShapeError) -> Self {
RustBertError::NdArrayError(error.to_string())
}
}
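
A brief, hedged illustration of what these new `From` conversions enable: with `rust-bert` built with the `onnx` feature (so the conversions above are compiled in) and `ndarray` as a direct dependency, `ndarray` shape errors can be propagated with `?` from functions returning `RustBertError`. The function below is hypothetical, not part of the crate:

```rust
use ndarray::Array2;
use rust_bert::RustBertError;

// Requires building `rust-bert` with the `onnx` feature so that the
// `From<ShapeError>` conversion added above exists.
fn reshape_logits(values: Vec<f32>) -> Result<Array2<f32>, RustBertError> {
    // A ShapeError (e.g. if `values.len()` is odd) is converted into
    // `RustBertError::NdArrayError` by the `?` operator.
    let array = Array2::from_shape_vec((2, values.len() / 2), values)?;
    Ok(array)
}
```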

View File

@ -0,0 +1,72 @@
use crate::common::error::RustBertError;
use crate::resources::{Resource, ResourceProvider};
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
/// # In-memory raw buffer resource
#[derive(Debug)]
pub struct BufferResource {
/// The data representing the underlying resource
pub data: Arc<RwLock<Vec<u8>>>,
}
impl ResourceProvider for BufferResource {
/// Not implemented for this resource type
///
/// # Returns
///
/// * `RustBertError::UnsupportedError`
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
Err(RustBertError::UnsupportedError)
}
/// Gets a wrapper referring to the in-memory resource.
///
/// # Returns
///
/// * `Resource` referring to the resource data
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{BufferResource, ResourceProvider};
/// let data = std::fs::read("path/to/rust_model.ot").unwrap();
/// let weights_resource = BufferResource::from(data);
/// let weights = weights_resource.get_resource();
/// ```
fn get_resource(&self) -> Result<Resource, RustBertError> {
Ok(Resource::Buffer(self.data.write().unwrap()))
}
}
impl From<Vec<u8>> for BufferResource {
fn from(data: Vec<u8>) -> Self {
Self {
data: Arc::new(RwLock::new(data)),
}
}
}
impl From<Vec<u8>> for Box<dyn ResourceProvider> {
fn from(data: Vec<u8>) -> Self {
Box::new(BufferResource {
data: Arc::new(RwLock::new(data)),
})
}
}
impl From<RwLock<Vec<u8>>> for BufferResource {
fn from(lock: RwLock<Vec<u8>>) -> Self {
Self {
data: Arc::new(lock),
}
}
}
impl From<RwLock<Vec<u8>>> for Box<dyn ResourceProvider> {
fn from(lock: RwLock<Vec<u8>>) -> Self {
Box::new(BufferResource {
data: Arc::new(lock),
})
}
}

View File

@ -1,9 +1,9 @@
use crate::common::error::RustBertError;
use crate::resources::ResourceProvider;
use crate::resources::{Resource, ResourceProvider};
use std::path::PathBuf;
/// # Local resource
#[derive(PartialEq, Eq, Clone)]
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct LocalResource {
/// Local path for the resource
pub local_path: PathBuf,
@ -29,6 +29,26 @@ impl ResourceProvider for LocalResource {
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
Ok(self.local_path.clone())
}
/// Gets a wrapper around the path for a local resource.
///
/// # Returns
///
/// * `Resource` wrapping a `PathBuf` pointing to the resource file
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{LocalResource, ResourceProvider};
/// use std::path::PathBuf;
/// let config_resource = LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// };
/// let config_path = config_resource.get_resource();
/// ```
fn get_resource(&self) -> Result<Resource, RustBertError> {
Ok(Resource::PathBuf(self.local_path.clone()))
}
}
impl From<PathBuf> for LocalResource {

View File

@ -1,6 +1,6 @@
//! # Resource definitions for model weights, vocabularies and configuration files
//!
//! This crate relies on the concept of Resources to access the files used by the models.
//! This crate relies on the concept of Resources to access the data used by the models.
//! This includes:
//! - model weights
//! - configuration files
@ -11,20 +11,35 @@
//! resource location. Two types of resources are pre-defined:
//! - LocalResource: points to a local file
//! - RemoteResource: points to a remote file via a URL
//! - BufferResource: refers to a buffer that contains file contents for a resource (currently only
//! usable for weights)
//!
//! For both types of resources, the local location of the file can be retrieved using
//! For `LocalResource` and `RemoteResource`, the local location of the file can be retrieved using
//! `get_local_path`, allowing to reference the resource file location regardless if it is a remote
//! or local resource. Default implementations for a number of `RemoteResources` are available as
//! pre-trained models in each model module.
mod buffer;
mod local;
use crate::common::error::RustBertError;
pub use buffer::BufferResource;
pub use local::LocalResource;
use std::fmt::Debug;
use std::ops::DerefMut;
use std::path::PathBuf;
use std::sync::RwLockWriteGuard;
use tch::nn::VarStore;
use tch::{Device, Kind};
/// # Resource Trait that can provide the location of the model, configuration or vocabulary resources
pub trait ResourceProvider {
pub enum Resource<'a> {
PathBuf(PathBuf),
Buffer(RwLockWriteGuard<'a, Vec<u8>>),
}
/// # Resource Trait that can provide the location or data for the model, and location of
/// configuration or vocabulary resources
pub trait ResourceProvider: Debug + Send + Sync {
/// Provides the local path for a resource.
///
/// # Returns
@ -42,9 +57,47 @@ pub trait ResourceProvider {
/// let config_path = config_resource.get_local_path();
/// ```
fn get_local_path(&self) -> Result<PathBuf, RustBertError>;
/// Provides access to an underlying resource.
///
/// # Returns
///
/// * `Resource` wrapping a representation of a resource.
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{BufferResource, LocalResource, ResourceProvider};
/// ```
fn get_resource(&self) -> Result<Resource, RustBertError>;
}
impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T> {
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
T::get_local_path(self)
}
fn get_resource(&self) -> Result<Resource, RustBertError> {
T::get_resource(self)
}
}
/// Load the provided `VarStore` with model weights from the provided `ResourceProvider`
pub fn load_weights(
rp: &(impl ResourceProvider + ?Sized),
vs: &mut VarStore,
kind: Option<Kind>,
device: Device,
) -> Result<(), RustBertError> {
match rp.get_resource()? {
Resource::Buffer(mut data) => vs.load_from_stream(std::io::Cursor::new(data.deref_mut())),
Resource::PathBuf(path) => vs.load(path),
}?;
cast_var_store(vs, kind, device);
Ok(())
}
#[cfg(feature = "remote")]
mod remote;
use crate::pipelines::common::cast_var_store;
#[cfg(feature = "remote")]
pub use remote::RemoteResource;
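
To make the additions in this file concrete, a hedged sketch (paths and the choice of model head are placeholders) combining the `BufferResource` introduced earlier in this commit with the `load_weights` helper defined above; the DeBERTa head mirrors the updated sentiment example higher up in the diff:

```rust
use rust_bert::deberta::{DebertaConfig, DebertaForSequenceClassification};
use rust_bert::resources::{load_weights, BufferResource};
use rust_bert::Config;
use tch::{nn, Device};

fn load_from_memory(weights: Vec<u8>) -> anyhow::Result<()> {
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    // "path/to/config.json" is a placeholder path.
    let config = DebertaConfig::from_file("path/to/config.json");
    let _model = DebertaForSequenceClassification::new(vs.root(), &config)?;
    // The buffer holds the same bytes as a serialized `rust_model.ot` file;
    // passing `None` for the kind keeps the serialized precision.
    let weights_resource = BufferResource::from(weights);
    load_weights(&weights_resource, &mut vs, None, device)?;
    Ok(())
}
```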

View File

@ -6,7 +6,7 @@ use lazy_static::lazy_static;
use std::path::PathBuf;
/// # Remote resource that will be downloaded and cached locally on demand
#[derive(PartialEq, Eq, Clone)]
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct RemoteResource {
/// Remote path/url for the resource
pub url: String,
@ -31,7 +31,7 @@ impl RemoteResource {
///
/// ```no_run
/// use rust_bert::resources::RemoteResource;
/// let config_resource = RemoteResource::new("configs", "http://config_json_location");
/// let config_resource = RemoteResource::new("http://config_json_location", "configs");
/// ```
pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
RemoteResource {
@ -93,6 +93,23 @@ impl ResourceProvider for RemoteResource {
.cached_path_with_options(&self.url, &Options::default().subdir(&self.cache_subdir))?;
Ok(cached_path)
}
/// Gets a wrapper around the local path for a remote resource.
///
/// # Returns
///
/// * `Resource` wrapping a `PathBuf` pointing to the resource file
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{RemoteResource, ResourceProvider};
/// let config_resource = RemoteResource::new("http://config_json_location", "configs");
/// let config_path = config_resource.get_resource();
/// ```
fn get_resource(&self) -> Result<Resource, RustBertError> {
Ok(Resource::PathBuf(self.get_local_path()?))
}
}
lazy_static! {

View File

@ -1,6 +1,6 @@
//! # Ready-to-use NLP pipelines and Transformer-based models
//!
//! Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using the [tch-rs](https://github.com/LaurentMazare/tch-rs) crate and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
//! Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using [tch-rs](https://github.com/LaurentMazare/tch-rs) or [onnxruntime bindings](https://github.com/pykeio/ort) and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
//! This repository exposes the model base architecture, task-specific heads (see below) and [ready-to-use pipelines](#ready-to-use-pipelines). [Benchmarks](#benchmarks) are available at the end of this document.
//!
//! Get started with tasks including question answering, named entity recognition, translation, summarization, text generation, conversational agents and more in just a few lines of code:
@ -42,6 +42,7 @@
//! - Language Generation
//! - Sentence Embeddings
//! - Masked Language Model
//! - Keywords extraction
//!
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust
@ -61,10 +62,12 @@
//!GPT| | | |✅ | | | | |
//!GPT2| | | |✅ | | | | |
//!GPT-Neo| | | |✅ | | | | |
//!GPT-J| | | |✅ | | | | |
//!BART|✅| | |✅ |✅| | | |
//!Marian| | | | | |✅| | |
//!MBart|✅| | |✅ | | | | |
//!M2M100| | | |✅ | | | | |
//!NLLB| | | |✅ | | | | |
//!Electra | |✅| | | | |✅| |
//!ALBERT |✅|✅|✅| | | |✅| ✅ |
//!T5 | | | |✅ |✅|✅| | ✅ |
@ -87,8 +90,8 @@
//!
//! ### Manual installation (recommended)
//!
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.13.0`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip` for a Linux version with CUDA11.
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.2`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip` for a Linux version with CUDA12.
//! 2. Extract the library to a location of your choice
//! 3. Set the following environment variables
//! ##### Linux:
@ -105,10 +108,36 @@
//!
//! ### Automatic installation
//!
//! Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu117`.
//! Alternatively, you can let the `build` script automatically download the `libtorch` library for you. The `download-libtorch` feature flag needs to be enabled.
//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu118`.
//! Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
//!
//! ## ONNX Support (Optional)
//!
//! ONNX support can be enabled via the optional `onnx` feature. This crate then leverages the [ort](https://github.com/pykeio/ort) crate with bindings to the onnxruntime C++ library. We refer the user to this project's page for further installation instructions/support.
//! 1. Enable the optional `onnx` feature. The `rust-bert` crate does not include any optional dependencies for `ort`; the end user should select the set of features that is adequate for pulling the required `onnxruntime` C++ library.
//! 2. The current recommended installation is to use dynamic linking by pointing to an existing library location. Use the `load-dynamic` cargo feature for `ort`.
//! 3. Set the `ORT_DYLIB_PATH` environment variable to point to the location of the downloaded onnxruntime library (`onnxruntime.dll`/`libonnxruntime.so`/`libonnxruntime.dylib` depending on the operating system). These can be downloaded from the [release page](https://github.com/microsoft/onnxruntime/releases) of the onnxruntime project.
//!
//! Most architectures (including encoders, decoders and encoder-decoders) are supported. The library aims at keeping compatibility with models exported using the [optimum](https://github.com/huggingface/optimum) library. A detailed guide on how to export a Transformer model to ONNX using optimum is available at <https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model>
//! The resources used to create ONNX models are similar to those based on Pytorch, replacing the Pytorch weights with the ONNX model file(s). Since ONNX models are less flexible than their Pytorch counterparts in the handling of optional arguments, exporting a decoder or encoder-decoder model to ONNX will usually result in multiple files. These files are expected (but not all are necessary) for use in this library as per the table below:
//!
//! | Architecture | Encoder file | Decoder without past file | Decoder with past file |
//! |-----------------------------|---------------|---------------------------|-------------------------|
//! | Encoder (e.g. BERT) | required | not used | not used |
//! | Decoder (e.g. GPT2) | not used | required | optional |
//! | Encoder-decoder (e.g. BART) | required | required | optional |
//!
//! Note that the computational efficiency will drop when the `decoder with past` file is optional but not provided
//! since the model will not use cached past keys and values for the attention mechanism, leading to a high number of
//! redundant computations. The Optimum library offers export options to ensure such a `decoder with past` model file is created.
//! The base encoder and decoder model architectures are available (and exposed for convenience) in the `encoder` and `decoder` modules, respectively.
//!
//! Generation models (pure decoder or encoder/decoder architectures) are available in the `models` module.
//! Most pipelines are available for ONNX model checkpoints, including sequence classification, zero-shot classification,
//! token classification (including named entity recognition and part-of-speech tagging), question answering, text generation, summarization and translation.
//! These models use the same configuration and tokenizer files as their Pytorch counterparts when used in a pipeline. Examples leveraging ONNX models are given in the `./examples` directory. More information on these can be found in the [`onnx` module](./pipelines/onnx/index.html)
//!
//! # Ready-to-use pipelines
//!
//! Based on Hugging Face's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. More information on these can be found in the [`pipelines` module](./pipelines/index.html)
@ -331,15 +360,15 @@
//! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
//! # fn main() -> anyhow::Result<()> {
//! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
//! let input_sentence = "Who are you voting for in 2020?";
//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
//! let candidate_labels = &["politics", "public health", "economics", "sports"];
//! let output = sequence_classification_model.predict_multilabel(
//! &[input_sentence, input_sequence_2],
//! candidate_labels,
//! None,
//! 128,
//! );
//! let input_sentence = "Who are you voting for in 2020?";
//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
//! let candidate_labels = &["politics", "public health", "economics", "sports"];
//! let output = sequence_classification_model.predict_multilabel(
//! &[input_sentence, input_sequence_2],
//! candidate_labels,
//! None,
//! 128,
//! );
//! # Ok(())
//! # }
//! ```
@ -698,33 +727,15 @@
extern crate core;
pub mod albert;
pub mod bart;
pub mod bert;
mod common;
pub mod deberta;
pub mod deberta_v2;
pub mod distilbert;
pub mod electra;
pub mod fnet;
pub mod gpt2;
pub mod gpt_j;
pub mod gpt_neo;
pub mod longformer;
pub mod longt5;
pub mod m2m_100;
pub mod marian;
pub mod mbart;
pub mod mobilebert;
pub mod openai_gpt;
pub mod pegasus;
pub mod models;
pub mod pipelines;
pub mod prophetnet;
pub mod reformer;
pub mod roberta;
pub mod t5;
pub mod xlnet;
pub use common::error::RustBertError;
pub use common::resources;
pub use common::{Activation, Config};
pub use models::{
albert, bart, bert, deberta, deberta_v2, distilbert, electra, fnet, gpt2, gpt_j, gpt_neo,
longformer, longt5, m2m_100, marian, mbart, mobilebert, nllb, openai_gpt, pegasus, prophetnet,
reformer, roberta, t5, xlnet,
};
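
As a complement to the ONNX table in the module documentation above, a hedged sketch of how `ONNXModelResources` maps onto the three architecture rows (requires the `onnx` feature; the URLs are placeholders, not real checkpoints):

```rust
use rust_bert::pipelines::common::ONNXModelResources;
use rust_bert::resources::RemoteResource;

fn onnx_resource_layouts() -> (ONNXModelResources, ONNXModelResources, ONNXModelResources) {
    // Encoder-only (e.g. BERT): only the encoder file is used.
    let encoder_only = ONNXModelResources {
        encoder_resource: Some(Box::new(RemoteResource::new(
            "https://example.com/model.onnx", // placeholder URL
            "onnx-encoder-only",
        ))),
        ..Default::default()
    };
    // Decoder-only (e.g. GPT-2): decoder required, decoder-with-past optional.
    // Omitting the decoder-with-past file still works, but past keys/values
    // are recomputed at every generation step.
    let decoder_only = ONNXModelResources {
        decoder_resource: Some(Box::new(RemoteResource::new(
            "https://example.com/decoder_model.onnx", // placeholder URL
            "onnx-decoder-only",
        ))),
        decoder_with_past_resource: None,
        ..Default::default()
    };
    // Encoder-decoder (e.g. BART): encoder and decoder required,
    // decoder-with-past optional.
    let encoder_decoder = ONNXModelResources {
        encoder_resource: Some(Box::new(RemoteResource::new(
            "https://example.com/encoder_model.onnx", // placeholder URL
            "onnx-encoder-decoder",
        ))),
        decoder_resource: Some(Box::new(RemoteResource::new(
            "https://example.com/decoder_model.onnx", // placeholder URL
            "onnx-encoder-decoder",
        ))),
        ..Default::default()
    };
    (encoder_only, decoder_only, encoder_decoder)
}
```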

View File

@ -257,7 +257,7 @@ impl AlbertModel {
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
let calc_mask = if mask.is_none() {
Some(Tensor::ones(&input_shape, (Kind::Int64, device)))
Some(Tensor::ones(input_shape, (Kind::Int64, device)))
} else {
None
};

View File

@ -130,8 +130,8 @@ impl AlbertSelfAttention {
self.hidden_size,
));
let context: Tensor =
Tensor::einsum("bfnd,ndh->bfh", &[context, w], None) + self.dense.bs.as_ref().unwrap();
let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w], None::<i64>)
+ self.dense.bs.as_ref().unwrap();
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
if !self.output_attentions {

View File

@ -16,6 +16,7 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `BertTokenizer` using a `vocab.txt` vocabulary
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run

View File

@ -176,7 +176,7 @@ impl BartAttention {
.bmm(&value_states)
.view([bs, self.num_heads, target_length, self.head_dim])
.transpose(1, 2)
.reshape(&[bs, target_length, embed_dim])
.reshape([bs, target_length, embed_dim])
.apply(&self.out_proj);
(attention_output, saved_attention_weights, new_layer_state)

View File

@ -21,12 +21,9 @@ use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::{RobertaTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::RobertaVocab;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
@ -102,12 +99,12 @@ impl BartConfigResources {
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
"distilbart-cnn-6-6/config",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/config.json",
"https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
"distilbart-cnn-12-6/config",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/config.json",
"https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/config.json",
);
}
@ -135,12 +132,12 @@ impl BartVocabResources {
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
"distilbart-cnn-6-6/vocab",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/vocab.json",
"https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/vocab.json",
);
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
"distilbart-cnn-12-6/vocab",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/vocab.json",
"https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/vocab.json",
);
}
@ -168,12 +165,12 @@ impl BartMergesResources {
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
"distilbart-cnn-6-6/merges",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/merges.txt",
"https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/merges.txt",
);
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
"distilbart-cnn-12-6/merges",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/merges.txt",
"https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/merges.txt",
);
}
@ -199,6 +196,8 @@ pub struct BartConfig {
pub encoder_layers: i64,
pub bos_token_id: Option<i64>,
pub eos_token_id: Option<i64>,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
pub pad_token_id: Option<i64>,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
@ -242,6 +241,8 @@ impl Default for BartConfig {
bos_token_id: Some(0),
eos_token_id: Some(2),
pad_token_id: Some(1),
forced_bos_token_id: Some(0),
forced_eos_token_id: Some(2),
id2label: None,
label2id: None,
init_std: 0.02,
@ -272,7 +273,7 @@ pub(crate) fn _make_causal_mask(
let target_length = input_ids_shape[1];
let mut mask = Tensor::full(
&[target_length, target_length],
[target_length, target_length],
get_min(dtype).unwrap(),
(dtype, device),
);
@ -285,14 +286,14 @@ pub(crate) fn _make_causal_mask(
if past_key_values_length > 0 {
mask = Tensor::cat(
&[
Tensor::zeros(&[target_length, past_key_values_length], (dtype, device)),
Tensor::zeros([target_length, past_key_values_length], (dtype, device)),
mask,
],
-1,
);
}
mask.unsqueeze(0).unsqueeze(0).expand(
&[
[
batch_size,
1,
target_length,
@ -308,7 +309,7 @@ pub(crate) fn _expand_mask(mask: &Tensor, target_length: Option<i64>, dtype: Kin
let expanded_mask = mask
.unsqueeze(1)
.unsqueeze(1)
.expand(&[batch_size, 1, target_length, source_length], true)
.expand([batch_size, 1, target_length, source_length], true)
.totype(dtype);
let inverted_mask: Tensor = 1 - expanded_mask;
inverted_mask.masked_fill(&inverted_mask.to_kind(Kind::Bool), get_min(dtype).unwrap())
@ -356,7 +357,7 @@ fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
let output = input_ids.empty_like().to_kind(Kind::Int64);
output
.select(1, 0)
.copy_(&input_ids.gather(1, &index_eos, true).squeeze());
.copy_(&input_ids.gather(1, &index_eos, false).squeeze());
output
.slice(1, 1, *output.size().last().unwrap(), 1)
.copy_(&input_ids.slice(1, 0, *output.size().last().unwrap() - 1, 1));
@ -368,7 +369,7 @@ fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
/// It is made of the following blocks:
/// - `encoder`: `BartEncoder` (transformer) made of a vector of encoding layers
/// - `decoder`: `BartDecoder` (transformer) made of a vector of decoding layers with self attention and encoder cross-attention.
/// caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values)
/// caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values)
/// - `pad_token_id`: padding token id
pub struct BartModel {
pub(crate) encoder: BartEncoder,
@ -436,7 +437,7 @@ impl BartModel {
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
@ -596,7 +597,7 @@ impl BartForConditionalGeneration {
/// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -797,7 +798,7 @@ impl BartForSequenceClassification {
/// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -826,7 +827,7 @@ impl BartForSequenceClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BartConfig::from_file(config_path);
/// # let bart_model: BartForSequenceClassification = BartForSequenceClassification::new(&vs.root(), &config).unwrap();;
/// # let bart_model: BartForSequenceClassification = BartForSequenceClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -865,7 +866,7 @@ impl BartForSequenceClassification {
let reshape = eos_mask.sum_dim_intlist([1].as_slice(), true, input_ids.kind());
let sentence_representation = base_model_output
.decoder_output
.permute(&[2, 0, 1])
.permute([2, 0, 1])
.masked_select(&eos_mask)
.view((-1, reshape.size()[0] * reshape.int64_value(&[0, 0])))
.transpose(0, 1)
@ -891,110 +892,6 @@ impl BartForSequenceClassification {
}
}
impl LMHeadModel for BartForConditionalGeneration {
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `input_embeds` - Unused for BART
/// * `token_type_ids` - Unused for BART
/// * `position_ids` - Unused for BART
/// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
///
/// # Returns
///
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `cache` - `BartCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::pipelines::generation_utils::LMHeadModel;
/// use rust_bert::bart::{BartForConditionalGeneration, BartConfig};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BartConfig::from_file(config_path);
/// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
///
/// let model_output = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
/// ```
fn forward_t(
&self,
input_ids: Option<&Tensor>,
cache: Cache,
attention_mask: Option<&Tensor>,
_token_type_ids: Option<&Tensor>,
_position_ids: Option<&Tensor>,
_input_embeds: Option<&Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
None,
cached_layer_states,
train,
),
Cache::None => self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
None,
None,
train,
),
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with BART Model".into(),
));
}
};
let lm_logits = base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok(LMModelOutput {
lm_logits,
cache: Cache::BARTCache(base_model_output.cache),
})
}
}
/// Container holding a BART model output. The decoder output may hold the hidden state of
/// the last layer of the decoder, or may hold logits for a custom head module after the
/// decoder (e.g. for classification or language modeling tasks)
@ -1024,6 +921,8 @@ pub struct BartGenerator {
generate_config: GenerateConfig,
bos_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
forced_bos_token_id: Option<i64>,
forced_eos_token_id: Option<i64>,
pad_token_id: Option<i64>,
is_encoder_decoder: bool,
vocab_size: i64,
@ -1099,24 +998,30 @@ impl BartGenerator {
tokenizer: TokenizerOption,
) -> Result<BartGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = BartConfig::from_file(config_path);
let model = BartForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = Some(match config.eos_token_id {
Some(value) => vec![value],
None => vec![2],
});
let forced_bos_token_id = config.forced_bos_token_id;
let forced_eos_token_id = config.forced_eos_token_id;
let pad_token_id = Some(config.pad_token_id.unwrap_or(1));
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = Some(2);
let decoder_start_id = config.decoder_start_token_id;
let max_position_embeddings = config.max_position_embeddings;
Ok(BartGenerator {
@ -1126,6 +1031,8 @@ impl BartGenerator {
generate_config,
bos_token_id,
eos_token_ids,
forced_bos_token_id,
forced_eos_token_id,
pad_token_id,
is_encoder_decoder,
vocab_size,
@ -1133,30 +1040,20 @@ impl BartGenerator {
max_position_embeddings,
})
}
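The constructor hunk above replaces the direct `var_store.load(weights_path)?` call with the shared `crate::resources::load_weights` helper, which also receives an optional weight kind (precision) and the target device. A minimal plain-Rust sketch of that pattern is shown below; the types and helper are illustrative stand-ins, not the crate's actual signatures.

```rust
// Sketch of "load weights through a shared helper"; stand-in types only.
#[derive(Clone, Copy, Debug)]
enum Kind {
    Float,
    Half,
}

#[derive(Clone, Copy, Debug)]
enum Device {
    Cpu,
    Cuda(usize),
}

struct VarStore {
    device: Device,
    kind: Kind,
}

fn load_weights(
    path: &str,
    store: &mut VarStore,
    kind: Option<Kind>,
    device: Device,
) -> Result<(), String> {
    // Load serialized weights (elided), then optionally cast and move them.
    println!("loading weights from {path}");
    if let Some(kind) = kind {
        store.kind = kind; // cast to the requested precision
    }
    store.device = device; // move tensors to the target device
    Ok(())
}

fn main() -> Result<(), String> {
    let mut store = VarStore {
        device: Device::Cpu,
        kind: Kind::Float,
    };
    load_weights("model.ot", &mut store, Some(Kind::Half), Device::Cuda(0))?;
    println!("{:?} on {:?}", store.kind, store.device);
    Ok(())
}
```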
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
}
}
impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer>
for BartGenerator
{
fn get_model(&self) -> &BartForConditionalGeneration {
&self.model
}
impl PrivateLanguageGenerator for BartGenerator {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -1167,6 +1064,12 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
fn get_eos_ids(&self) -> Option<&Vec<i64>> {
self.eos_token_ids.as_ref()
}
fn get_forced_bos_token_id(&self) -> Option<i64> {
self.forced_bos_token_id
}
fn get_forced_eos_token_id(&self) -> Option<i64> {
self.forced_eos_token_id
}
fn get_pad_id(&self) -> Option<i64> {
self.pad_token_id
}
@ -1179,31 +1082,57 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn prepare_scores_for_generation(
fn forward_t(
&self,
scores: &mut Tensor,
current_length: i64,
max_length: Option<i64>,
forced_bos_token_id: Option<i64>,
) {
if current_length == 1 {
self.force_token_id_generation(
scores,
&[forced_bos_token_id.unwrap_or_else(|| self.get_bos_id().unwrap())],
);
} else if let Some(max_length) = max_length {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
input_ids: Option<&Tensor>,
cache: Cache,
attention_mask: Option<&Tensor>,
_token_type_ids: Option<&Tensor>,
_position_ids: Option<&Tensor>,
_input_embeds: Option<&Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match cache {
Cache::BARTCache(cached_layer_states) => self.model.forward_t(
input_ids,
attention_mask,
encoder_outputs,
decoder_input_ids,
None,
cached_layer_states,
train,
),
Cache::None => self.model.forward_t(
input_ids,
attention_mask,
encoder_outputs,
decoder_input_ids,
None,
None,
train,
),
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with BART Model".into(),
));
}
}
};
Ok(LMModelOutput {
lm_logits: base_model_output.decoder_output,
cache: Cache::BARTCache(base_model_output.cache),
})
}
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
Some(self.get_model().encode(input_ids, attention_mask))
Some(self.model.encode(input_ids, attention_mask))
}
fn prepare_inputs_for_generation<'a>(
@ -1234,48 +1163,6 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
}
}
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(
&self,
past: &mut Cache,
@ -1312,10 +1199,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
}
}
impl LanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer>
for BartGenerator
{
}
impl LanguageGenerator for BartGenerator {}
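The hunks above remove the `<Model, Vocab, Tokenizer>` type parameters from `PrivateLanguageGenerator`/`LanguageGenerator` and move the forward pass into the generator itself. A rough plain-Rust sketch of that design direction, with illustrative names only (not rust-bert's real traits):

```rust
// Stand-in types; the point is the non-generic trait owning the forward pass.
struct TokenizerOption;
struct LmOutput {
    logits: Vec<f32>,
}

trait PrivateGenerator {
    fn tokenizer(&self) -> &TokenizerOption;
    fn forward(&self, input_ids: &[i64]) -> LmOutput;
}

struct DummyGenerator {
    tokenizer: TokenizerOption,
}

impl PrivateGenerator for DummyGenerator {
    fn tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn forward(&self, input_ids: &[i64]) -> LmOutput {
        // A real generator would run the model here; this just echoes shapes.
        LmOutput {
            logits: vec![0.0; input_ids.len()],
        }
    }
}

fn main() {
    let generator = DummyGenerator {
        tokenizer: TokenizerOption,
    };
    let output = generator.forward(&[0, 1, 2]);
    let _ = generator.tokenizer();
    println!("{} logits", output.logits.len());
}
```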
#[cfg(test)]
mod test {

View File

@ -340,6 +340,7 @@ impl BartDecoder {
}
}
#[allow(dead_code)]
/// Container holding a BART decoder output
pub struct BartDecoderOutput {
/// last decoder layer hidden state

View File

@ -2,7 +2,7 @@
//!
//! Implementation of the BART language model ([BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) Lewis, Liu, Goyal, Ghazvininejad, Mohamed, Levy, Stoyanov, Zettlemoyer, 2019).
//! The base model is implemented in the `bart_model::BartModel` struct. The model also includes a language model head: `bart_model::BartForConditionalGeneration`
//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information).
//!
//! # Model set-up and pre-trained weights loading
//!
@ -11,6 +11,7 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `RobertaTokenizer` using a `vocab.txt` vocabulary and `merges.txt` 2-gram merges
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run

View File

@ -42,6 +42,11 @@ impl BertModelResources {
"bert/model",
"https://huggingface.co/bert-base-uncased/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT_LARGE: (&'static str, &'static str) = (
"bert-large/model",
"https://huggingface.co/bert-large-uncased/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/model",
@ -75,6 +80,11 @@ impl BertConfigResources {
"bert/config",
"https://huggingface.co/bert-base-uncased/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT_LARGE: (&'static str, &'static str) = (
"bert-large/config",
"https://huggingface.co/bert-large-uncased/resolve/main/config.json",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/config",
@ -108,6 +118,11 @@ impl BertVocabResources {
"bert/vocab",
"https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT_LARGE: (&'static str, &'static str) = (
"bert-large/vocab",
"https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
);
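The three hunks above register `bert-large-uncased` weight, config and vocabulary resources. Assuming the usual `RemoteResource::from_pretrained` pattern used elsewhere in the crate (and an `anyhow` error type for brevity), resolving them could look like the sketch below.

```rust
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::resources::{RemoteResource, ResourceProvider};

fn main() -> anyhow::Result<()> {
    // Resolve the newly added bert-large-uncased resources to locally cached files.
    let config_resource = RemoteResource::from_pretrained(BertConfigResources::BERT_LARGE);
    let vocab_resource = RemoteResource::from_pretrained(BertVocabResources::BERT_LARGE);
    let weights_resource = RemoteResource::from_pretrained(BertModelResources::BERT_LARGE);

    let config_path = config_resource.get_local_path()?;
    let vocab_path = vocab_resource.get_local_path()?;
    let weights_path = weights_resource.get_local_path()?;
    println!("{config_path:?} {vocab_path:?} {weights_path:?}");
    Ok(())
}
```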
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/vocab",
@ -370,7 +385,7 @@ impl<T: BertEmbedding> BertModel<T> {
2 => {
if self.is_decoder {
let seq_ids = Tensor::arange(input_shape[1], (Kind::Int8, device));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat([
input_shape[0],
input_shape[1],
1,
@ -396,9 +411,11 @@ impl<T: BertEmbedding> BertModel<T> {
train,
)?;
let extended_attention_mask: Tensor =
((extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0)
.to_kind(embedding_output.kind());
let extended_attention_mask: Tensor = ((extended_attention_mask
.ones_like()
.bitwise_xor_tensor(&extended_attention_mask))
* -10000.0)
.to_kind(embedding_output.kind());
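The hunk above replaces `(ones - mask) * -10000` with `ones.bitwise_xor_tensor(&mask) * -10000`; for a {0, 1} integer mask the two are equivalent, since `1 ^ 1 = 0` and `1 ^ 0 = 1`, so masked positions still receive a large negative additive bias. A tiny scalar check of that identity (plain Rust, no tensors):

```rust
fn main() {
    let mask = [1u8, 1, 0, 1, 0]; // 1 = keep, 0 = masked
    for m in mask {
        let by_subtraction = (1 - m as i32) * -10000;
        let by_xor = ((1u8 ^ m) as i32) * -10000;
        assert_eq!(by_subtraction, by_xor);
        println!("mask={m} -> additive bias {by_xor}");
    }
}
```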
let encoder_extended_attention_mask: Option<Tensor> =
if self.is_decoder & encoder_hidden_states.is_some() {
@ -407,7 +424,7 @@ impl<T: BertEmbedding> BertModel<T> {
let encoder_mask = match encoder_mask {
Some(value) => value.copy(),
None => Tensor::ones(
&[
[
encoder_hidden_states_shape[0],
encoder_hidden_states_shape[1],
],
@ -559,7 +576,7 @@ impl BertForMaskedLM {
{
let p = p.borrow();
let bert = BertModel::new(p / "bert", config);
let bert = BertModel::new_with_optional_pooler(p / "bert", config, false);
let cls = BertLMPredictionHead::new(p / "cls", config);
BertForMaskedLM { bert, cls }
@ -989,7 +1006,7 @@ impl BertForTokenClassification {
{
let p = p.borrow();
let bert = BertModel::new(p / "bert", config);
let bert = BertModel::new_with_optional_pooler(p / "bert", config, false);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config
.id2label

View File

@ -16,6 +16,7 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `BertTokenizer` using a `vocab.txt` vocabulary
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run

View File

@ -37,7 +37,7 @@ pub trait DisentangledSelfAttention {
pub fn build_relative_position(query_size: i64, key_size: i64, device: Device) -> Tensor {
let q_ids = Tensor::arange(query_size, (Kind::Int64, device));
let k_ids = Tensor::arange(key_size, (Kind::Int64, device));
let rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.view([1, -1]).repeat(&[query_size, 1]);
let rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.view([1, -1]).repeat([query_size, 1]);
rel_pos_ids.slice(0, 0, query_size, 1).unsqueeze(0)
}
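`build_relative_position` above materialises a *query_size* x *key_size* matrix of relative offsets `q - k` through broadcasting. A plain-Rust sketch of the same computation, handy for checking expected values:

```rust
fn build_relative_position(query_size: i64, key_size: i64) -> Vec<Vec<i64>> {
    // rel_pos[q][k] = q - k, matching the broadcasted subtraction above.
    (0..query_size)
        .map(|q| (0..key_size).map(|k| q - k).collect())
        .collect()
}

fn main() {
    // For query_size = key_size = 3 this yields:
    // [ 0, -1, -2]
    // [ 1,  0, -1]
    // [ 2,  1,  0]
    for row in build_relative_position(3, 3) {
        println!("{row:?}");
    }
}
```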
@ -62,7 +62,7 @@ impl DebertaDisentangledSelfAttention {
let mut new_shape = x.size();
let _ = new_shape.pop();
new_shape.extend_from_slice(&[self.num_attention_heads, -1]);
x.view(new_shape.as_slice()).permute(&[0, 2, 1, 3])
x.view(new_shape.as_slice()).permute([0, 2, 1, 3])
}
fn linear(&self, weights: &Tensor, bias: Option<&Tensor>, x: &Tensor) -> Tensor {
@ -81,7 +81,7 @@ impl DebertaDisentangledSelfAttention {
) -> Tensor {
let query_layer_size = query_layer.size();
c2p_pos.expand(
&[
[
query_layer_size[0],
query_layer_size[1],
query_layer_size[2],
@ -101,7 +101,7 @@ impl DebertaDisentangledSelfAttention {
let mut key_layer_size = key_layer.size();
key_layer_size.reverse();
c2p_pos.expand(
&[
[
query_layer_size[0],
query_layer_size[1],
key_layer_size[1],
@ -182,7 +182,7 @@ impl DebertaDisentangledSelfAttention {
)
.unsqueeze(0);
let mut score = Tensor::zeros(&[1], (query_layer.kind(), key_layer.device()));
let mut score = Tensor::zeros([1], (query_layer.kind(), key_layer.device()));
// content -> position
if let Some(pos_proj) = &self.pos_proj {
@ -192,7 +192,7 @@ impl DebertaDisentangledSelfAttention {
let c2p_att = c2p_att.gather(
-1,
&self.c2p_dynamic_expand(&c2p_pos, query_layer, &relative_pos),
true,
false,
);
score = score + c2p_att;
}
@ -213,7 +213,7 @@ impl DebertaDisentangledSelfAttention {
.gather(
-1,
&self.p2c_dynamic_expand(&p2c_pos, query_layer, key_layer),
true,
false,
)
.transpose(-1, -2);
if query_layer_size[1] != key_layer_size[1] {
@ -221,7 +221,7 @@ impl DebertaDisentangledSelfAttention {
p2c_att = p2c_att.gather(
-2,
&self.pos_dynamic_expand(&pos_index, &p2c_att, key_layer),
true,
false,
);
}
score = score + p2c_att;
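The `true` to `false` flips above all touch the last argument of `Tensor::gather`, which in tch is the `sparse_grad` flag; the gathered values themselves are unchanged. A small standalone sketch of `gather` along the last dimension (assumes the tch crate and a working libtorch):

```rust
use tch::Tensor;

fn main() {
    // 2 x 3 input
    let input = Tensor::from_slice(&[10i64, 11, 12, 20, 21, 22]).view([2, 3]);
    // For each row, pick columns [2, 0] and [1, 1] respectively.
    let index = Tensor::from_slice(&[2i64, 0, 1, 1]).view([2, 2]);
    // The third argument is the sparse_grad flag; it does not change the values.
    let gathered = input.gather(1, &index, false);
    gathered.print(); // expected: [[12, 10], [21, 21]]
}
```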
@ -410,9 +410,9 @@ impl DisentangledSelfAttention for DebertaDisentangledSelfAttention {
if let Some(head_logits_proj) = &self.head_logits_proj {
attention_scores = attention_scores
.permute(&[0, 2, 3, 1])
.permute([0, 2, 3, 1])
.apply(head_logits_proj)
.permute(&[0, 3, 1, 2]);
.permute([0, 3, 1, 2]);
}
let mut attention_probs =
@ -420,14 +420,14 @@ impl DisentangledSelfAttention for DebertaDisentangledSelfAttention {
if let Some(head_weights_proj) = &self.head_weights_proj {
attention_probs = attention_probs
.permute(&[0, 2, 3, 1])
.permute([0, 2, 3, 1])
.apply(head_weights_proj)
.permute(&[0, 3, 1, 2]);
.permute([0, 3, 1, 2]);
}
let context_layer = attention_probs
.matmul(&value_layer)
.permute(&[0, 2, 1, 3])
.permute([0, 2, 1, 3])
.contiguous();
let mut new_context_layer_shape = context_layer.size();

View File

@ -127,7 +127,7 @@ where
let calc_position_ids = if position_ids.is_none() {
Some(
Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.expand(&[1, -1], true),
.expand([1, -1], true),
)
} else {
None
@ -135,7 +135,7 @@ where
let calc_token_type_ids = if token_type_ids.is_none() {
Some(Tensor::zeros(
&input_shape,
input_shape,
(Kind::Int64, input_embeddings.device()),
))
} else {

View File

@ -12,6 +12,7 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `DebertaTokenizer` using a `vocab.json` vocabulary and `merges.txt` merges file
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run

View File

@ -51,7 +51,7 @@ pub fn build_relative_position(
) -> Tensor {
let q_ids = Tensor::arange(query_size, (Kind::Int64, device));
let k_ids = Tensor::arange(key_size, (Kind::Int64, device));
let mut rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.tile(&[q_ids.size()[0], 1]);
let mut rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.tile([q_ids.size()[0], 1]);
if (bucket_size > 0) & (max_position > 0) {
rel_pos_ids = make_log_bucket_position(&rel_pos_ids, bucket_size, max_position);
}
@ -80,7 +80,7 @@ impl DebertaV2DisentangledSelfAttention {
let _ = new_shape.pop();
new_shape.extend_from_slice(&[self.num_attention_heads, -1]);
let x = x.view(new_shape.as_slice());
x.permute(&[0, 2, 1, 3])
x.permute([0, 2, 1, 3])
.contiguous()
.view([-1, x.size()[1], *x.size().last().unwrap()])
}
@ -133,12 +133,12 @@ impl DebertaV2DisentangledSelfAttention {
let pos_query_layer = self
.transpose_for_scores(&relative_embeddings.apply(query_proj))
.repeat(&[query_layer.size()[0] / self.num_attention_heads, 1, 1]);
.repeat([query_layer.size()[0] / self.num_attention_heads, 1, 1]);
let pos_key_layer = self
.transpose_for_scores(&relative_embeddings.apply(key_proj))
.repeat(&[query_layer.size()[0] / self.num_attention_heads, 1, 1]);
.repeat([query_layer.size()[0] / self.num_attention_heads, 1, 1]);
let mut score = Tensor::zeros(&[1], (query_layer.kind(), query_layer.device()));
let mut score = Tensor::zeros([1], (query_layer.kind(), query_layer.device()));
let c2p_pos = if self.pos_att_type.has_type(PositionAttentionType::c2p)
| self.pos_att_type.has_type(PositionAttentionType::p2p)
@ -149,14 +149,14 @@ impl DebertaV2DisentangledSelfAttention {
let c2p_att = c2p_att.gather(
-1,
&c2p_pos.squeeze_dim(0).expand(
&[
[
query_layer.size()[0],
query_layer.size()[1],
*relative_pos.size().last().unwrap(),
],
true,
),
true,
false,
);
score = score + c2p_att / scale;
Some(c2p_pos)
@ -186,10 +186,10 @@ impl DebertaV2DisentangledSelfAttention {
.gather(
-1,
&p2c_pos.squeeze_dim(0).expand(
&[query_layer.size()[0], key_layer_size[1], key_layer_size[1]],
[query_layer.size()[0], key_layer_size[1], key_layer_size[1]],
true,
),
true,
false,
)
.transpose(-1, -2);
score = score + p2c_att / scale;
@ -203,7 +203,7 @@ impl DebertaV2DisentangledSelfAttention {
let p2p_att = p2p_att.gather(
-1,
&c2p_pos.unwrap().expand(
&[
[
query_layer.size()[0],
query_layer.size()[1],
query_layer.size()[2],
@ -211,7 +211,7 @@ impl DebertaV2DisentangledSelfAttention {
],
true,
),
true,
false,
);
score = score + p2p_att;
}
@ -402,7 +402,7 @@ impl DisentangledSelfAttention for DebertaV2DisentangledSelfAttention {
reverse_context_layer_size[1],
reverse_context_layer_size[0],
])
.permute(&[0, 2, 1, 3])
.permute([0, 2, 1, 3])
.contiguous();
let mut new_context_layer_shape = context_layer.size();

View File

@ -78,10 +78,10 @@ impl ConvLayer {
train: bool,
) -> Tensor {
let out = hidden_states
.permute(&[0, 2, 1])
.permute([0, 2, 1])
.contiguous()
.apply(&self.conv)
.permute(&[0, 2, 1])
.permute([0, 2, 1])
.contiguous();
let reverse_mask: Tensor = 1 - input_mask;
let out = out.masked_fill(
@ -235,7 +235,7 @@ impl DebertaV2Encoder {
.unsqueeze(-1)
.to_kind(Kind::Uint8)
}
value if value == 3 => attention_mask.unsqueeze(1),
3 => attention_mask.unsqueeze(1),
_ => attention_mask.shallow_clone(),
}
}

View File

@ -12,6 +12,7 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `DebertaV2Tokenizer` using a `spiece.model` SentencePiece model file
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run

View File

@ -192,8 +192,8 @@ impl DistilBertModel {
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow() / "distilbert";
let embeddings = DistilBertEmbedding::new(p.borrow() / "embeddings", config);
let transformer = Transformer::new(p.borrow() / "transformer", config);
let embeddings = DistilBertEmbedding::new(&p / "embeddings", config);
let transformer = Transformer::new(p / "transformer", config);
DistilBertModel {
embeddings,
transformer,

View File

@ -42,7 +42,7 @@ where
);
}
}
let temp_vec = Tensor::of_slice(&temp_vec);
let temp_vec = Tensor::from_slice(&temp_vec);
sinusoidal_embedding.push(temp_vec);
}
let sinusoidal_embedding = Tensor::stack(&sinusoidal_embedding, 0)

View File

@ -14,6 +14,7 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `BertTokenizer` using a `vocab.txt` vocabulary
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run

View File

@ -266,7 +266,7 @@ impl ElectraModel {
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
let calc_mask = if mask.is_none() {
Some(Tensor::ones(&input_shape, (Kind::Int64, device)))
Some(Tensor::ones(input_shape, (Kind::Int64, device)))
} else {
None
};

View File

@ -19,6 +19,7 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `BertTokenizer` using a `vocab.txt` vocabulary
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run

View File

@ -42,7 +42,9 @@ impl FNetFourierTransform {
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
let self_outputs = hidden_states.fft_fft2(None, &[1, 2], "backward").real();
let self_outputs = hidden_states
.fft_fft2(None::<i64>, [1, 2], "backward")
.real();
(self_outputs + hidden_states).apply(&self.layer_norm)
}
}
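The change above spells the `None` passed to `fft_fft2` as `None::<i64>`. Whatever the exact tch signature, the underlying Rust issue is that a bare `None` leaves a generic type parameter unconstrained; the same error can be reproduced with any generic function taking an `Option`:

```rust
use std::fmt::Debug;

fn describe<T: Debug>(value: Option<T>) {
    match value {
        Some(v) => println!("got {v:?}"),
        None => println!("got nothing"),
    }
}

fn main() {
    // `describe(None)` would not compile (E0282: type annotations needed),
    // because `T` cannot be inferred from a bare `None`.
    describe(None::<i64>);
    describe(Some(42));
}
```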

View File

@ -88,6 +88,7 @@ pub struct FNetConfig {
pub pad_token_id: Option<i64>,
pub bos_token_id: Option<i64>,
pub eos_token_id: Option<i64>,
pub decoder_start_token_id: Option<i64>,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
pub output_attentions: Option<bool>,
@ -112,6 +113,7 @@ impl Default for FNetConfig {
pad_token_id: Some(3),
bos_token_id: Some(1),
eos_token_id: Some(2),
decoder_start_token_id: None,
id2label: None,
label2id: None,
output_attentions: None,

View File

@ -14,6 +14,7 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `FNetTokenizer` using a `spiece.model` SentencePiece (BPE) model file
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run

View File

@ -71,7 +71,7 @@ impl Attention {
{
let p = p.borrow();
let bias = Tensor::ones(&[config.n_ctx, config.n_ctx], (Float, p.device()))
let bias = Tensor::ones([config.n_ctx, config.n_ctx], (Float, p.device()))
.tril(0)
.view((1, 1, config.n_ctx, config.n_ctx));
@ -111,9 +111,9 @@ impl Attention {
fn split_heads(&self, x: &Tensor, k: bool) -> Tensor {
let x = x.view((x.size()[0], -1, self.n_head, self.dim_per_head));
if k {
x.permute(&[0, 2, 3, 1])
x.permute([0, 2, 3, 1])
} else {
x.permute(&[0, 2, 1, 3])
x.permute([0, 2, 1, 3])
}
}

View File

@ -20,17 +20,13 @@ use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::Gpt2Tokenizer;
use rust_tokenizers::vocab::Gpt2Vocab;
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Kind, Tensor};
use tch::{nn, Device, Kind, Tensor};
/// # GPT2 Pretrained model weight files
pub struct Gpt2ModelResources;
@ -198,6 +194,9 @@ pub struct Gpt2Config {
pub output_hidden_states: Option<bool>,
pub resid_pdrop: Option<f64>,
pub vocab_size: i64,
pub decoder_start_token_id: Option<i64>,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
}
impl Config for Gpt2Config {}
@ -222,6 +221,9 @@ impl Default for Gpt2Config {
output_hidden_states: None,
resid_pdrop: Some(0.1),
vocab_size: 50257,
decoder_start_token_id: None,
forced_bos_token_id: None,
forced_eos_token_id: None,
}
}
}
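The hunks above add `decoder_start_token_id`, `forced_bos_token_id` and `forced_eos_token_id` to `Gpt2Config` and its `Default` impl. Since the fields are public and the struct implements `Default`, overriding only the new fields could use struct-update syntax, as in the sketch below (not taken from the crate's tests):

```rust
use rust_bert::gpt2::Gpt2Config;

fn main() {
    // Only the new fields are set explicitly; everything else comes from Default.
    let config = Gpt2Config {
        forced_bos_token_id: Some(50256),
        forced_eos_token_id: Some(50256),
        decoder_start_token_id: None,
        ..Default::default()
    };
    assert_eq!(config.vocab_size, 50257);
    println!("{:?}", config.forced_bos_token_id);
}
```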
@ -529,118 +531,26 @@ impl GPT2LMHeadModel {
GPT2LMHeadModel { transformer }
}
}
impl LMHeadModel for GPT2LMHeadModel {
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `layer_past` - Optional vector of size *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values.
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input.
/// * `_encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*). Unused for GPT2
/// * `_decoder_input_ids` - Optional tensor of shape (*batch size*, *target_sequence_length*). Unused for GPT2
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
///
/// # Returns
///
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `cache` - `Gpt2Cache` made of `Option<Vec<Tensor>>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*)
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
/// use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = Gpt2Config::from_file(config_path);
/// # let mut gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(
/// &[
/// 2,
/// batch_size,
/// config.n_head,
/// past_sequence_length,
/// config.n_embd / config.n_head,
/// ],
/// (Double, device),
/// ))
/// }
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let model_output = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// Some(&input_tensor),
/// Cache::GPT2Cache(Some(past)),
/// Some(&attention_mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// None,
/// None,
/// false,
/// )
/// .unwrap()
/// });
/// ```
fn forward_t(
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
layer_past: Cache,
layer_past: Option<&Vec<Tensor>>,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match layer_past {
Cache::GPT2Cache(layer_past) => self.transformer.forward_t(
input_ids,
layer_past.as_ref(),
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
Cache::None => self.transformer.forward_t(
input_ids,
None,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with GPT2 Model".into(),
));
}
}?;
let base_model_output = self.transformer.forward_t(
input_ids,
layer_past,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
)?;
let lm_logits = base_model_output
.output
@ -735,7 +645,6 @@ impl GPT2Generator {
tokenizer: TokenizerOption,
) -> Result<GPT2Generator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@ -743,7 +652,12 @@ impl GPT2Generator {
let config = Gpt2Config::from_file(config_path);
let model = GPT2LMHeadModel::new(var_store.root(), &config);
var_store.load(weights_path)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = tokenizer.get_bos_id();
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
@ -751,7 +665,7 @@ impl GPT2Generator {
let max_position_embeddings = config.n_positions;
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = None;
let decoder_start_id = config.decoder_start_token_id;
Ok(GPT2Generator {
model,
@ -769,18 +683,18 @@ impl GPT2Generator {
}
}
impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT2Generator {
fn get_model(&self) -> &GPT2LMHeadModel {
&self.model
}
impl PrivateLanguageGenerator for GPT2Generator {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -803,8 +717,45 @@ impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(
&self,
input_ids: Option<&Tensor>,
layer_past: Cache,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
match layer_past {
Cache::GPT2Cache(layer_past) => self.model.forward_t(
input_ids,
layer_past.as_ref(),
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
Cache::None => self.model.forward_t(
input_ids,
None,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
_ => Err(RustBertError::ValueError(
"Cache not compatible with GPT2 Model".into(),
)),
}
}
fn prepare_inputs_for_generation<'a>(
@ -875,4 +826,4 @@ impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT
}
}
impl LanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT2Generator {}
impl LanguageGenerator for GPT2Generator {}

View File

@ -2,7 +2,7 @@
//!
//! Implementation of the GPT2 language model ([Language Models are Unsupervised Multitask Learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) Radford, Wu, Child, Luan, Amodei, Sutskever 2019).
//! The base model is implemented in the `gpt2_model::Gpt2Model` struct. The model also includes a language model head: `gpt2_model::GPT2LMHeadModel`
//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information).
//!
//! # Model set-up and pre-trained weights loading
//!
@ -11,6 +11,7 @@
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `Gpt2Tokenizer` using a `vocab.txt` vocabulary and `merges.txt` 2-gram merges
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run

View File

@ -68,11 +68,16 @@ impl GptJAttention {
let p = p.borrow();
let max_positions = config.n_positions;
let bias = Tensor::ones(&[max_positions, max_positions], (Kind::Uint8, p.device()))
let bias_value = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device()))
.tril(0)
.view([1, 1, max_positions, max_positions])
.requires_grad_(false);
let bias = p.var_copy("bias", &bias);
let mut bias = p
.f_ones_no_train("bias", &[1, 1, max_positions, max_positions])
.unwrap()
.to_kind(Kind::Uint8)
.to_device(p.device());
bias.copy_(&bias_value);
let attn_pdrop = config.attn_pdrop.unwrap_or(0.1);
let resid_pdrop = config.resid_pdrop.unwrap_or(0.1);
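The hunk above changes how the non-trainable lower-triangular `bias` buffer is registered in the var store; the buffer itself still encodes the causal mask (position *i* may attend to positions *j <= i*). A conceptual plain-Rust sketch of that mask:

```rust
fn causal_mask(n: usize) -> Vec<Vec<u8>> {
    // Lower-triangular matrix: 1 where attention is allowed, 0 where it is masked.
    (0..n)
        .map(|i| (0..n).map(|j| u8::from(j <= i)).collect())
        .collect()
}

fn main() {
    for row in causal_mask(4) {
        println!("{row:?}");
    }
    // [1, 0, 0, 0]
    // [1, 1, 0, 0]
    // [1, 1, 1, 0]
    // [1, 1, 1, 1]
}
```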
@ -95,21 +100,9 @@ impl GptJAttention {
..Default::default()
};
let k_proj = nn::linear(p / "k_proj", config.n_embd, config.n_embd, linear_config);
if config.use_float16 {
(p / "k_proj").half();
}
let v_proj = nn::linear(p / "v_proj", config.n_embd, config.n_embd, linear_config);
if config.use_float16 {
(p / "v_proj").half();
}
let q_proj = nn::linear(p / "q_proj", config.n_embd, config.n_embd, linear_config);
if config.use_float16 {
(p / "q_proj").half();
}
let out_proj = nn::linear(p / "out_proj", config.n_embd, config.n_embd, linear_config);
if config.use_float16 {
(p / "out_proj").half();
}
GptJAttention {
bias,
@ -142,9 +135,9 @@ impl GptJAttention {
if rotary {
tensor
} else if tensor.size().len() == 5 {
tensor.permute(&[0, 1, 3, 2, 4]) // (batch, blocks, head, block_length, head_features)
tensor.permute([0, 1, 3, 2, 4]) // (batch, blocks, head, block_length, head_features)
} else if tensor.size().len() == 4 {
tensor.permute(&[0, 2, 1, 3]) // (batch, head, seq_length, head_features)
tensor.permute([0, 2, 1, 3]) // (batch, head, seq_length, head_features)
} else {
panic!(
"Input tensor should either be a rotary head, or its rank be one of [4, 5] but is: {}",
@ -155,9 +148,9 @@ impl GptJAttention {
fn merge_heads(tensor: &Tensor, num_heads: i64, attention_head_size: i64) -> Tensor {
let tensor = if tensor.size().len() == 5 {
tensor.permute(&[0, 1, 3, 2, 4]).contiguous()
tensor.permute([0, 1, 3, 2, 4]).contiguous()
} else if tensor.size().len() == 4 {
tensor.permute(&[0, 2, 1, 3]).contiguous()
tensor.permute([0, 2, 1, 3]).contiguous()
} else {
panic!(
"Input tensor rank should be one of [4, 5], but is: {}",
@ -197,7 +190,7 @@ impl GptJAttention {
let mask_value = get_min(attention_weights.kind()).unwrap();
let mask_value = Tensor::full(
&attention_weights.size(),
attention_weights.size(),
mask_value,
(attention_weights.kind(), attention_weights.device()),
);
@ -261,8 +254,8 @@ impl GptJAttention {
query = apply_rotary_pos_emb(&query, &sincos, offset);
}
key = key.permute(&[0, 2, 1, 3]);
query = query.permute(&[0, 2, 1, 3]);
key = key.permute([0, 2, 1, 3]);
query = query.permute([0, 2, 1, 3]);
if let Some(layer_past) = layer_past {
key = Tensor::cat(&[&layer_past.prev_key, &key], -2);
@ -297,7 +290,7 @@ fn fixed_pos_embedding(x: &Tensor, seq_len: i64) -> (Tensor, Tensor) {
let sinusoid_inp = Tensor::einsum(
"i , j -> i j",
&[Tensor::arange(seq_len, (x.kind(), x.device())), inv_freq],
None,
None::<i64>,
);
(sinusoid_inp.sin(), sinusoid_inp.cos())
}
@ -312,7 +305,7 @@ fn apply_rotary_pos_emb(x: &Tensor, (sin, cos): &(Tensor, Tensor), offset: i64)
fn duplicate_interleave(m: &Tensor) -> Tensor {
let dim0 = m.size()[0];
m.view([-1, 1]) // flatten the matrix
.repeat(&[1, 2]) // repeat all elements into the 2nd dimension
.repeat([1, 2]) // repeat all elements into the 2nd dimension
.view([dim0, -1]) // reshape into a matrix, interleaving the copy
}
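`duplicate_interleave` above turns `[a, b, c]` into `[a, a, b, b, c, c]` via a reshape/repeat/reshape trick, expanding the rotary sin/cos tables to pairs of feature dimensions. The same result with plain iterators, for reference:

```rust
fn duplicate_interleave(values: &[f32]) -> Vec<f32> {
    // Each element is emitted twice, preserving order.
    values.iter().flat_map(|&v| [v, v]).collect()
}

fn main() {
    let out = duplicate_interleave(&[0.1, 0.2, 0.3]);
    assert_eq!(out, vec![0.1, 0.1, 0.2, 0.2, 0.3, 0.3]);
    println!("{out:?}");
}
```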

View File

@ -20,12 +20,8 @@ use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::Gpt2Tokenizer;
use rust_tokenizers::vocab::Gpt2Vocab;
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::nn::{embedding, Linear};
@ -46,7 +42,7 @@ pub struct GptJMergesResources;
/// Model weights for Rust are not available out of the box for GPT-J but can be created
/// simply with the following command:
///
/// ```
/// ```ignore
/// python utils/convert_model.py path/to/gpt_j/pytorch_model.bin
/// ```
///
@ -57,7 +53,6 @@ pub struct GptJMergesResources;
///
/// [gpt-j-6B]: https://huggingface.co/EleutherAI/gpt-j-6B/tree/main
/// [gpt-j-6B-float16]:https://huggingface.co/EleutherAI/gpt-j-6B/tree/float16
///
impl GptJModelResources {
pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = (
"gpt-j-tiny-random/model",
@ -136,10 +131,11 @@ pub struct GptJConfig {
pub rotary_dim: Option<i64>,
pub vocab_size: i64,
pub scale_attn_weights: Option<bool>,
#[serde(default = "default_use_float16")]
pub use_float16: bool,
#[serde(default = "default_preload_on_cpu")]
pub preload_on_cpu: bool,
pub decoder_start_token_id: Option<i64>,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
}
impl Config for GptJConfig {}
@ -166,16 +162,14 @@ impl Default for GptJConfig {
rotary_dim: Some(64),
vocab_size: 50400,
scale_attn_weights: Some(true),
use_float16: default_use_float16(),
preload_on_cpu: default_preload_on_cpu(),
decoder_start_token_id: None,
forced_bos_token_id: None,
forced_eos_token_id: None,
}
}
}
fn default_use_float16() -> bool {
true
}
fn default_preload_on_cpu() -> bool {
true
}
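The hunk above drops `use_float16` and keeps `preload_on_cpu` with a function-backed serde default, so older config JSON files missing the field still deserialize. A standalone sketch of that serde pattern (hypothetical struct; requires `serde` with the derive feature and `serde_json`):

```rust
use serde::Deserialize;

fn default_preload_on_cpu() -> bool {
    true
}

#[derive(Debug, Deserialize)]
struct SketchConfig {
    n_embd: i64,
    // Falls back to the function's return value when the key is absent.
    #[serde(default = "default_preload_on_cpu")]
    preload_on_cpu: bool,
}

fn main() -> Result<(), serde_json::Error> {
    // Field absent in the JSON: the function-backed default kicks in.
    let config: SketchConfig = serde_json::from_str(r#"{ "n_embd": 4096 }"#)?;
    assert!(config.preload_on_cpu);
    println!("{config:?}");
    Ok(())
}
```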
@ -232,9 +226,6 @@ impl GptJModel {
config.n_embd,
Default::default(),
);
if config.use_float16 {
(&(&p / "wte") / "weight").half()
};
let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
let drop = Dropout::new(embd_pdrop);
@ -244,9 +235,6 @@ impl GptJModel {
..Default::default()
};
let ln_f = nn::layer_norm(&p / "ln_f", vec![config.n_embd], layer_norm_config);
if config.use_float16 {
(&p / "ln_f").half()
};
let mut h: Vec<GptJBlock> = vec![];
let h_path = &p / "h";
@ -335,7 +323,7 @@ impl GptJModel {
/// gpt_j_model
/// .forward_t(
/// Some(&input_tensor),
/// Some(&past),
/// Some(past),
/// Some(&attention_mask),
/// Some(&token_type_ids),
/// None,
@ -450,7 +438,7 @@ impl GptJLMHeadModel {
/// # Example
///
/// ```no_run
/// use rust_bert::gpt_j::{GptJLMHeadModel, GptJConfig};
/// use rust_bert::gpt_j::{GptJConfig, GptJLMHeadModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
@ -474,91 +462,14 @@ impl GptJLMHeadModel {
config.vocab_size,
Default::default(),
);
if config.use_float16 {
(p / "lm_head").half();
}
GptJLMHeadModel {
transformer,
lm_head,
}
}
}
impl LMHeadModel for GptJLMHeadModel {
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `layer_past` - Optional vector of size *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values.
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
/// * `_position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input.
/// * `_encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*). Unused for GPT-J
/// * `_decoder_input_ids` - Optional tensor of shape (*batch size*, *target_sequence_length*). Unused for GPT-J
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
///
/// # Returns
///
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `cache` - `GptJCache` made of `Option<Vec<Tensor>>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*)
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt_j::{GptJLMHeadModel, GptJConfig};
/// use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = GptJConfig::from_file(config_path);
/// # let mut gpt_j_model: GptJLMHeadModel = GptJLMHeadModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(
/// &[
/// 2,
/// batch_size,
/// config.n_head,
/// past_sequence_length,
/// config.n_embd / config.n_head,
/// ],
/// (Double, device),
/// ))
/// }
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let model_output = no_grad(|| {
/// gpt_j_model
/// .forward_t(
/// Some(&input_tensor),
/// Cache::GPTJCache(Some(past)),
/// Some(&attention_mask),
/// Some(&token_type_ids),
/// None,
/// None,
/// None,
/// None,
/// false,
/// )
/// .unwrap()
/// });
/// ```
fn forward_t(
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
layer_past: Cache,
@ -648,7 +559,7 @@ impl GptJGenerator {
/// use rust_bert::pipelines::generation_utils::GenerateConfig;
///
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
@ -688,7 +599,6 @@ impl GptJGenerator {
tokenizer: TokenizerOption,
) -> Result<GptJGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
@ -699,7 +609,12 @@ impl GptJGenerator {
if config.preload_on_cpu && device != Device::Cpu {
var_store.set_device(Device::Cpu);
}
var_store.load(weights_path)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
if device != Device::Cpu {
var_store.set_device(device);
}
@ -710,7 +625,7 @@ impl GptJGenerator {
let max_position_embeddings = config.n_positions;
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = None;
let decoder_start_id = config.decoder_start_token_id;
Ok(GptJGenerator {
model,
@ -728,18 +643,18 @@ impl GptJGenerator {
}
}
impl PrivateLanguageGenerator<GptJLMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GptJGenerator {
fn get_model(&self) -> &GptJLMHeadModel {
&self.model
}
impl PrivateLanguageGenerator for GptJGenerator {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -762,8 +677,54 @@ impl PrivateLanguageGenerator<GptJLMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for Gpt
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(
&self,
input_ids: Option<&Tensor>,
layer_past: Cache,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match layer_past {
Cache::GPTJCache(layer_past) => self.model.transformer.forward_t(
input_ids,
layer_past,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
Cache::None => self.model.transformer.forward_t(
input_ids,
None,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with GPT-J Model".into(),
));
}
}?;
let lm_logits = base_model_output.output.apply(&self.model.lm_head);
Ok(LMModelOutput {
lm_logits,
cache: Cache::GPTJCache(base_model_output.cache),
})
}
fn prepare_inputs_for_generation<'a>(
@ -833,4 +794,4 @@ impl PrivateLanguageGenerator<GptJLMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for Gpt
}
}
impl LanguageGenerator<GptJLMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GptJGenerator {}
impl LanguageGenerator for GptJGenerator {}
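Elsewhere in this diff the `GenerateConfig` doc example changes `max_length: 30` to `max_length: Some(30)`, i.e. the length cap becomes optional. Assuming the fields shown in that example and a `Default` impl for `GenerateConfig`, a configuration might be built as sketched below.

```rust
use rust_bert::pipelines::generation_utils::GenerateConfig;

fn main() {
    // Only a few fields are overridden; the rest come from the Default impl.
    let generate_config = GenerateConfig {
        max_length: Some(30),
        do_sample: true,
        num_beams: 5,
        temperature: 1.1,
        ..Default::default()
    };
    println!("max_length = {:?}", generate_config.max_length);
}
```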

View File

@ -9,7 +9,7 @@
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::gpt_j::{GptJLMHeadModel, GptJConfig};
//! use rust_bert::gpt_j::{GptJConfig, GptJLMHeadModel};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::Gpt2Tokenizer;

View File

@ -43,18 +43,12 @@ impl GptJMLP {
intermediate_size,
Default::default(),
);
if config.use_float16 {
(p / "fc_in").half()
};
let fc_out = nn::linear(
p / "fc_out",
intermediate_size,
config.n_embd,
Default::default(),
);
if config.use_float16 {
(p / "fc_out").half()
};
let activation = match &config.afn {
Some(activation_enum) => match activation_enum {
@ -100,9 +94,6 @@ impl GptJBlock {
..Default::default()
};
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
if config.use_float16 {
(p / "ln_1").half()
};
let attn = GptJAttention::new(p / "attn", config);
let mlp = GptJMLP::new(p / "mlp", config);

View File

@ -70,7 +70,7 @@ impl GptNeoSelfAttention {
let p = p.borrow();
let max_positions = config.max_position_embeddings;
let mut bias = Tensor::ones(&[max_positions, max_positions], (Kind::Uint8, p.device()))
let mut bias = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device()))
.tril(0)
.view([1, 1, max_positions, max_positions])
.requires_grad_(false);
@ -135,11 +135,11 @@ impl GptNeoSelfAttention {
let _ = new_shape.pop();
new_shape.extend_from_slice(&[num_heads, attention_head_size]);
let reshaped_tensor = input_tensor.view(new_shape.as_slice());
reshaped_tensor.permute(&[0, 2, 1, 3])
reshaped_tensor.permute([0, 2, 1, 3])
}
fn merge_heads(input_tensor: &Tensor, num_heads: i64, attention_head_size: i64) -> Tensor {
let output_tensor = input_tensor.permute(&[0, 2, 1, 3]).contiguous();
let output_tensor = input_tensor.permute([0, 2, 1, 3]).contiguous();
let mut new_shape = output_tensor.size();
new_shape.truncate(new_shape.len() - 2);
new_shape.push(num_heads * attention_head_size);
@ -173,7 +173,7 @@ impl GptNeoSelfAttention {
let mut attention_weights = attention_weights.where_self(
causal_mask,
&Tensor::of_slice(&[-1e9f32]).to_device(attention_weights.device()),
&Tensor::from_slice(&[-1e9f32]).to_device(attention_weights.device()),
);
if let Some(attention_mask_value) = attention_mask {
attention_weights = attention_weights + attention_mask_value;

View File

@ -18,15 +18,11 @@ use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Activation, Config, RustBertError};
use rust_tokenizers::tokenizer::Gpt2Tokenizer;
use rust_tokenizers::vocab::Gpt2Vocab;
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Kind, Tensor};
use tch::{nn, Device, Kind, Tensor};
/// # GPT-Neo Pretrained model weight files
pub struct GptNeoModelResources;
@ -131,6 +127,8 @@ pub struct GptNeoConfig {
pub intermediate_size: Option<i64>,
pub bos_token_id: i64,
pub eos_token_id: i64,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
pub vocab_size: i64,
pub num_layers: i64,
pub num_heads: i64,
@ -144,6 +142,7 @@ pub struct GptNeoConfig {
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub resid_dropout: f64,
pub decoder_start_token_id: Option<i64>,
}
impl Config for GptNeoConfig {}
@ -153,7 +152,7 @@ impl Default for GptNeoConfig {
GptNeoConfig {
activation_function: Activation::gelu_new,
attention_dropout: 0.0,
attention_layers: vec![AttentionLayerType::Global, AttentionLayerType::Local]
attention_layers: [AttentionLayerType::Global, AttentionLayerType::Local]
.iter()
.cycle()
.take(24)
@ -166,6 +165,8 @@ impl Default for GptNeoConfig {
intermediate_size: None,
bos_token_id: 50256,
eos_token_id: 50256,
forced_bos_token_id: None,
forced_eos_token_id: None,
vocab_size: 50257,
num_layers: 24,
num_heads: 16,
@ -179,6 +180,7 @@ impl Default for GptNeoConfig {
output_attentions: None,
output_hidden_states: None,
resid_dropout: 0.0,
decoder_start_token_id: None,
}
}
}
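The default `attention_layers` above alternates `Global` and `Local` attention across 24 layers by cycling a two-element pattern; the hunk only changes the seed from a `vec![...]` to a fixed-size array. The pattern in isolation (a sketch using `copied()`, whichever adaptor the crate itself uses):

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum AttentionLayerType {
    Global,
    Local,
}

fn main() {
    // Cycle [Global, Local] and take the first 24 entries.
    let attention_layers: Vec<AttentionLayerType> =
        [AttentionLayerType::Global, AttentionLayerType::Local]
            .iter()
            .cycle()
            .take(24)
            .copied()
            .collect();
    assert_eq!(attention_layers.len(), 24);
    assert_eq!(attention_layers[0], AttentionLayerType::Global);
    assert_eq!(attention_layers[1], AttentionLayerType::Local);
    println!("{attention_layers:?}");
}
```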
@ -570,52 +572,6 @@ impl GptNeoForCausalLM {
}
}
impl LMHeadModel for GptNeoForCausalLM {
fn forward_t(
&self,
input_ids: Option<&Tensor>,
layer_past: Cache,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match layer_past {
Cache::GPTNeoCache(layer_past) => self.forward_t(
input_ids,
input_embeds,
token_type_ids,
position_ids,
layer_past,
attention_mask,
train,
),
Cache::None => self.forward_t(
input_ids,
input_embeds,
token_type_ids,
position_ids,
None,
attention_mask,
train,
),
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with GPT-Neo Model".into(),
));
}
}?;
Ok(LMModelOutput {
lm_logits: base_model_output.lm_logits,
cache: Cache::GPTNeoCache(base_model_output.next_cache),
})
}
}
/// Container for the GPT-Neo model output.
pub struct GptNeoModelOutput {
/// Last hidden states from the model
@ -710,21 +666,25 @@ impl GptNeoGenerator {
tokenizer: TokenizerOption,
) -> Result<GptNeoGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = GptNeoConfig::from_file(config_path);
let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
var_store.load(weights_path)?;
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = tokenizer.get_bos_id();
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
let pad_token_id = tokenizer.get_pad_id();
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = None;
let decoder_start_id = config.decoder_start_token_id;
let max_position_embeddings = config.max_position_embeddings;
Ok(GptNeoGenerator {
@ -743,18 +703,18 @@ impl GptNeoGenerator {
}
}
impl PrivateLanguageGenerator<GptNeoForCausalLM, Gpt2Vocab, Gpt2Tokenizer> for GptNeoGenerator {
fn get_model(&self) -> &GptNeoForCausalLM {
&self.model
}
impl PrivateLanguageGenerator for GptNeoGenerator {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -777,10 +737,54 @@ impl PrivateLanguageGenerator<GptNeoForCausalLM, Gpt2Vocab, Gpt2Tokenizer> for G
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(
&self,
input_ids: Option<&Tensor>,
layer_past: Cache,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match layer_past {
Cache::GPTNeoCache(layer_past) => self.model.forward_t(
input_ids,
input_embeds,
token_type_ids,
position_ids,
layer_past,
attention_mask,
train,
),
Cache::None => self.model.forward_t(
input_ids,
input_embeds,
token_type_ids,
position_ids,
None,
attention_mask,
train,
),
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with GPT-Neo Model".into(),
));
}
}?;
Ok(LMModelOutput {
lm_logits: base_model_output.lm_logits,
cache: Cache::GPTNeoCache(base_model_output.next_cache),
})
}
fn prepare_inputs_for_generation<'a>(
&self,
input_ids: Tensor,
@ -851,4 +855,4 @@ impl PrivateLanguageGenerator<GptNeoForCausalLM, Gpt2Vocab, Gpt2Tokenizer> for G
}
}
impl LanguageGenerator<GptNeoForCausalLM, Gpt2Vocab, Gpt2Tokenizer> for GptNeoGenerator {}
impl LanguageGenerator for GptNeoGenerator {}

View File

@ -26,6 +26,7 @@
//! use tch::Device;
//!
//! fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::common::ModelResource;
//! let config_resource = Box::new(RemoteResource::from_pretrained(
//! GptNeoConfigResources::GPT_NEO_1_3B,
//! ));
@ -41,7 +42,7 @@
//!
//! let text_generation_config = TextGenerationConfig {
//! model_type: ModelType::GPTNeo,
//! model_resource,
//! model_resource: ModelResource::Torch(model_resource),
//! config_resource,
//! vocab_resource,
//! merges_resource: Some(merges_resource),
@ -54,7 +55,7 @@
//!
//! let input_context_1 = "It was a very nice and sunny";
//! let input_context_2 = "It was a gloom winter night, and";
//! let output = model.generate(&[input_context_1, input_context_2], None);
//! let output = model.generate(&[input_context_1, input_context_2], None)?;
//!
//! for sentence in output {
//! println!("{}", sentence);
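Taken together, the doc-comment changes above amount to the following usage pattern. This is a sketch, not a verbatim copy of the crate's documentation: it assumes rust-bert >= 0.22, and the vocab, merges, and model resource constants are assumed to mirror the GPT_NEO_1_3B config resource shown in the hunk.

use rust_bert::gpt_neo::{
    GptNeoConfigResources, GptNeoMergesResources, GptNeoModelResources, GptNeoVocabResources,
};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;

fn main() -> anyhow::Result<()> {
    let config_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoConfigResources::GPT_NEO_1_3B,
    ));
    let vocab_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoVocabResources::GPT_NEO_1_3B,
    ));
    let merges_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoMergesResources::GPT_NEO_1_3B,
    ));
    let model_resource = Box::new(RemoteResource::from_pretrained(
        GptNeoModelResources::GPT_NEO_1_3B,
    ));

    let text_generation_config = TextGenerationConfig {
        model_type: ModelType::GPTNeo,
        // Torch weights are now wrapped explicitly in a ModelResource variant.
        model_resource: ModelResource::Torch(model_resource),
        config_resource,
        vocab_resource,
        merges_resource: Some(merges_resource),
        ..Default::default()
    };
    let model = TextGenerationModel::new(text_generation_config)?;

    // `generate` now returns a Result, propagated here with `?`.
    let output = model.generate(&["It was a very nice and sunny"], None)?;
    for sentence in output {
        println!("{}", sentence);
    }
    Ok(())
}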

Some files were not shown because too many files have changed in this diff.