Merge remote-tracking branch 'origin/master' into entity_consolidation

# Conflicts:
#	src/pipelines/ner.rs
This commit is contained in:
Guillaume Becquin 2021-11-20 11:03:05 +01:00
commit d84b2819d9
218 changed files with 39123 additions and 9572 deletions

View File

@ -0,0 +1,148 @@
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
name: Build
jobs:
build:
name: Build Linux
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: build
build-windows:
name: Build Windows
runs-on: windows-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: build
build-mac-os:
name: Build macOS
runs-on: macos-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: build
test-batch-0:
name: Integration tests (batch 0)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: test
args: --package rust-bert
--test albert
--test bart
--test bert
--test distilbert
--test distilgpt2
--test electra
--test gpt2
--test marian
--test fnet
test-batch-1:
name: Integration tests (batch 1)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: test
args: --package rust-bert
--test mobilebert
--test openai_gpt
--test prophetnet
--test reformer
--test roberta
--test t5
--test xlnet
--test longformer
--test pegasus
--test gpt_neo
convert-model:
name: Model conversion test
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions/setup-python@v2
with:
python-version: '3.7'
- run: |
pip install -r requirements.txt --progress-bar off
python ./utils/download-dependencies_distilbert.py
fmt:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add rustfmt
- uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
clippy:
name: Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add clippy
- uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-targets --all-features -- -D warnings -A clippy::assign_op_pattern -A clippy::upper-case-acronyms

358
CHANGELOG.md Normal file
View File

@ -0,0 +1,358 @@
# Changelog
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased]
## Changed
- Updated to `tch` 1.6.0 (libtorch 1.10)
- (BREAKING) Simplified the generics for multiple library traits taking as a rule `&[AsRef<str>]` or `&str` as inputs (no longer accepts owned types `Vec` and `String`)
## Added
- (BREAKING) Support for `bad_word_ids` generation, allowing to ban a set of word ids for all model supporting text generation
- Support for half-precision mode for all models (reducing memory footprint). A model can be converted to half-precision by calling the `half()` method on the `VarStore` is it currently stored in. Half-precision Torch kernels are not available for CPU (limited to CUDA devices)
- (BREAKING) Extension of the generation options that can be provided at runtime (after a model has been instantiated with a `GenerateConfig`), allowing to update the generation options from one text generation to another with the same model. This feature is implemented at the `LanguageGenerator` trait level, the high-level `TextGeneration` pipeline API remains unchanged.
- Addition of the FNet language model and support for sequence, token and multiple choice classification, question answering
## [0.16.0] - 2021-08-24
## Added
- (BREAKING) Support for `prefix_allowed_tokens_fn` argument for generation, allowing users to control the generation via custom functions
- (BREAKING) Support for `forced_bos_token_id` argument for generation, allowing users to force a given BOS token for generation (useful for MBart/M2M-class models)
- (BREAKING) Support for `output_scores` boolean argument for generation, allowing users to output the log-probability scores of generated sequences. Updated the return type of low-level generate API to `GeneratedTextOutput` and `GeneratedIndicesOutput` containing optional scores along with the generated output.
- Addition of the MBart Language model and support for text generation / direct translation between 50 language
- Addition of the M2M100 Language model and support for text generation / direct translation between 100 language
## Changed
- Updated GPT2 architecture to re-use embeddings for the output projection layer (resulting in smaller model weights files and memory footprint)
- Upgraded `tch` version to 0.5.0 (using `libtorch` 1.9.0)
- Changed default value of `no_repeat_ngram_size` for text generation from 3 to 0, aligning with [Python's Transformers](https://huggingface.co/transformers/main_classes/model.html?highlight=no_repeat_ngram_size#transformers.generation_utils.GenerationMixin.generate)
- Added the possibility to handle long inputs for token classification tasks (exceeding the model maximum length) using sliding windows over the input
- (BREAKING) Generalized borrowing of Tensors as input for models
- Aligned the optional `all_hidden_states` output for all models
## Fixed
- Updated T5 Decoder cross-attention to no longer use relative position bias (aligned with [Python reference update](https://github.com/huggingface/transformers/pull/8518))
- Removed hardcoded maximum length for sequence and token classification tasks, now using the model maximum position embeddings instead
## [0.15.1] - 2021-06-01
### Fixed
- Fixed conversation model panic for user inputs exceeding the maximum model length (1000 tokens)
- Fixed translation model panic for user inputs exceeding the maximum number of position embeddings
## [0.15.0] - 2021-05-16
### Added
- Addition of translation language pairs:
- English <-> Chinese (Simplified)
- English <-> Chinese (Traditional)
- English <-> Dutch
- English <-> Swedish
- English <-> Arabic
- English <-> Hebrew
- English <-> Hindi
- Addition of a Part of Speech pipeline. This pipeline allows predicting the POS tag (e.g. Noun, Adjective, Verb) of words in input sentences.
- Addition of a lightweight English Part of Speech tagging pretrained MobileBERT model
- Addition of the Pegasus language model and support for conditional generation
- Addition of a model for Pegasus summarization pretrained on the CNN-DM dataset
- Addition of the GPT-Neo language model and pretrained snapshots (125M, 1.3B and 2.7B parameters). Registration of GPT-Neo as an option for `TextGenerationPipeline`.
### Changed
- (BREAKING) Changed `classif_dropout` in `BartConfig` to be an optional field. This affects dependencies instantiating `BartConfig` from scratch, or using `classif_config` for custom model heads.
- (BREAKING) Changed token classification pipelines to return a Vec<Vec<Token>> instead of a Vec<Token>. The token-level predictions are now returned in separate vectors for each input sequence provided as an input (they were previously returned in a flattened vector)
- Simplification of the BART language model code base (also used for Marian and Pegasus language models)
- (BREAKING) Updated to `tch 0.4.1` (based on `libtorch 1.8.1`)
### Fixed
- Fixed character indexing error for Question Answering pipeline answers
### Removed
- Dependency to `itertools` crate
## [0.14.0] - 2021-02-22
### Added
- Addition of the Longformer language model, task-specific heads and registration in relevant pipelines
### Changed
- (BREAKING) Exposed additional settings for the Question Answering pipeline related to the maximum question, context and answer length. This is not backward compatible if the question answering configuration was created without using the `new` creator.
- Simplified the Question answering pipeline to rely on the offsets calculated by the tokenizers instead of a manual alignment. This results in moderate execution speed improvements for this pipeline.
- Updated the padding strategy for the Question answering pipeline. While before all sequences were padded to a fixed `max_length` (defaulting to 384), the padding is now done dynamically based on the length of the inputs. This results in a significant speed improvement for this pipeline.
### Fixed
- Fixed a bug for Question Answering for models that were not based on Wordpiece tokenization (including BPE and unigram based tokenizers). The issue was caused by the pre-tokenization step that was stripping the leading whitespace for all tokens. The performance of these models for QA should improve significantly.
## [0.13.0] - 2021-02-03
### Added
- Addition of the ProphetNet language model, task-specific heads and registration in relevant pipelines
- (BREAKING) Implementation of [Diverse Beam Search](https://arxiv.org/abs/1610.02424). This allows the generation of more diverse sequences within the number of beams. Addition of 2 new fields to the `GenerateConfig` that are propagated through all text generation configs (e.g. `TranslationConfig`):
- `num_beam_groups` (`Option<i64>`), indicating the number of sub-beam groups. This must be a divisor of the number of beams.
- `diversity_penalty` (`Option<f64>`), indicating by which amount to penalize common words between beam groups. This will default to 5.5 if not provided. The impact of this diverse beam search is illustrated in the GPT2 integration tests.
### Changed
- (BREAKING) Simplified the input and output of encoder/decoder models to avoid needing to take ownership of the possibly cached encoder hidden state, offering a minor performance improvement for text generation tasks. The model output field for encoder hidden states are now optional, and only returned if the encoder hidden states were not provided for the given forward path. This may be a breaking change for low-level dependencies that manipulate directly the encoder/decoder model outputs.
- (BREAKING) Moved the language models implementation of the `PrivateLanguageGenerator` and `LanguageGenerator` traits (needed to generate text) to the model modules, cleaning up the generation_utils module.
- Updated download utilities crate, now leveraging Tokio 1.0 runtimes.
### Fixed
- Updated padding information and addition of position ids for batched GPT2 generation. Prior to this change, inputs that required padding had a lower quality for the text generated.
## [0.12.1] - 2021-01-04
### Added
- Addition of the MobileBERT language model, task-specific heads and registration in relevant pipelines
### Changed
- Made all model configurations `Clone`
- Made several base modules of the BERT language model public, and added model output `Struct` for the new publicly exposed, complex types
## [0.12.0] - 2020-11-29
### Added
- Addition of the Reformer language model, task-specific heads and registration in relevant pipelines
- Pre-trained models for DistilRoBERTa, used as a default for integration tests
### Changed
- Updated endpoint of the model resources reflecting changes to the Hugging Face's model hub
- Early stopping turned by default on for translation and summarization
## [0.11.0] - 2020-11-02
### Added
- Support for additional models for the conversational pipeline
### Changed
- Updated the version of Tokenizer crate with consistent visibility
- (BREAKING) move of teh text generation pipeline to its owned pipeline. Shared generation utilities are moved to `generation_utils`
- All models, tokenizers and pipelines are now `Send`
## [0.10.0] - 2020-10-04
### Added
- Benchmark scripts for all pipelines
- Addition of the XLNet model and task-specific heads
### Changed
- (BREAKING) Changed the download method for resources now a method of the resource itself, and leveraging the cached-path crate.
- (BREAKING) Changed the return type of models to be output `Struct` instead of long tuples.
- (BREAKING) Changed the naming of the model main modules from `modelname` to `model_modelname` to avoid confusion with the top level module name
- Extended the range of allowed types for pipelines input, allowing both owned `Vec` and slices, and both `String` and sting slice.
- Handling of all activations functions is mow made from a common module and `Struct`
## [0.9.0] - 2020-09-06
### Added
- Zero-shot classification pipeline using a natural language inference model
### Changed
- (BREAKING) Updated version of tokenizers crate with added options for lower casing, accent stripping and prefix addition
- Updated BART classification model to allow running their `forward` method without being mutable.
## [0.8.0] - 2020-08-25
### Added
- (BREAKING) Improved error handling via the addition of `RustBertError` and error propagation throughout the crate.
### Changed
- Updated version of tokenizers crate with improved error handling
## [0.7.12] - 2020-08-12
### Added
- Addition of the reformer language model and its integration for language generation
### Changed
- Changed model resources endpoints to leverage updated Hugging Face's model hub
- Updated the beam search processing to use vectorized operations
## [0.7.11] - 2020-07-26
### Changed
- Generalization of the accepted input for several pipelines to accept both `Vec` and slices, and to accept both `String` and `&str`
## [0.7.10] - 2020-07-08
### Added
- Addition of the ALBERT language model and task-specific heads
- Addition of German - English translation models
- Addition of the T5 language model and integration in supported pipelines (translation and summarization)
### Changed
- Updated the modules throughout the crate to accept both owned and references to varstore paths.
## [0.7.9] - 2020-06-28
### Added
- Addition of a multi-turn conversational pipeline based on DialoGPT.
## [0.7.8] - 2020-06-23
### Fixed
- Code formatting using `rustfmt`
## [0.7.7] - 2020-06-06
### Changed
- Removed the requirement for generation models to be mutable. Models are now all stateless, and no longer store an internal cache (now provided as an input).
- Updated BART model to take past layer states as an input instead of storing in internally.
### Fixed
- Fixed sequence classification model logits squeeze causing it to crash for batched inputs.
## [0.7.6] - 2020-05-27
### Added
- Addition of translation between Russian and English
### Fixed
- Fixed a bug causing downloads to be incomplete, and removes the creation of a tokio runtime for the download of resources.
## [0.7.5] - 2020-05-25
### Added
- Addition of the Marian model, leveraging a shared language model implementation with the BART model.
- Addition of translation capabilities. Supports translation between English and French, Spanish, Portuguese, Italian, Catalan and German, and between German and French.
## [0.7.4] - 2020-05-25
### Added
- Addition of multi-label classification capabilities for sequence classification via the `predict_mutilabel` function.
## [0.7.3] - 2020-05-19
### Added
- Generalization of pipelines to allow leveraging multiple model architectures. Leveraging `Enum` unpacking, introduces `ConfigOption`, `TokenizerOption` and pipeline-specific Options.
- Addition of generic `SentenceClassificationModel` pipeline. The `SentimentModel` now leverages shared implementation for sentence classification.
- Addition of `TokenClassificationModel` pipeline. The `NERModel`now leverages shared implementation for token classification.
### Changed
- Major rework of tokenization crate, alignment with updated API
## [0.7.2] - 2020-05-03
### Fixed
- Minor bug fixes for tokenization
## [0.7.1] - 2020-05-03
### Added
- Implementation of the Electra model (generator, discriminator, task-specific heads)
- GPT2-medium and GPT2-large models
## [0.7.0] - 2020-04-26
### Added
- Addition of Resources for handling file dependencies (e.g. vocabularies, model weights, configurations). Resources may be `LocalResources` (pointing to a filesystem location) or `RemoteResources` (pointing to a remote endpoint). These resources can be passed to a `download_resource` method that returns the location in the local filesystem for both types of resources, downloading them if necessary.
- Resources specifications for all existing architectures, pointing to model files hosted on Hugging Face's model hub.
### Changed
- (BREAKING) moved the resources' specification to the `GenerateConfig` for `GPT2Generator`.
- (BREAKING) creation of pipeline configurations to contain the resources required to build the pipeline, used as an input rather than paths to local files.
- Updated the configuration for the number of target labels to use the `id2label` field instead of `num_labels` (aligning with changes in standard configuration in the Transformers library). Removed `num_labels` from configurations.
- Made the `output_attentions`, `output_hidden_states` and `torchscript` fields for DistilBERT configuration optional
- Fixed the device placement for sinusoidal embeddings for DistilBERT model.
## [0.6.2] - 2020-04-07
### Changed
- Optimization of the BART model avoiding unnecessary tensor copies for cache manipulation and residual connections.
- Optimization of DistilBERT model when embeddings are provided as an input
## [0.6.1] - 2020-04-06
### Changed
- Minor optimizations to question answering and sentiment analysis pipelines
- Addition of a cache reset for text generation routines
- Implementation of cache reset for BART language model
## [0.6.0] - 2020-04-05
### Added
- BART language model
- Implementation of `LanguageModel` and `PrivateLanguageModel` for BART
- Summarization capabilities
- Tanh activation
### Changed
- (BREAKING) Moved the `LMHeadModel` Trait from GPT2 module to the pipelines module
- Updated the `LMHeadModel` inputs to include `encoder_outputs` and `decoder_input_ids` to support causal language model (e.g. BART)
- (BREAKING) Added methods to the `PrivateLanguageGenerator` to support encoder-decoder models
- (BREAKING) changed the type of `Generator` language model to require mutability (BART caching mechanism stores the cache in the model requiring the entire model mutability - changed at a later point)
- Optimization of the `get_banned_token` method
### Fixed
- Updated the device location of the token update when EOS is not allowed because the minimum sequence length was not reached
- No longer process a given beam hypothesis if it is marked as done
- No longer add beams to a hypothesis if the rank is lower than the number of beams
- Updated final beam update to skip completed hypotheses
## [0.5.3] - 2020-03-27
### Added
- Documentation throughout the crate
- Creation of a `GenerateConfig` configuration structure to hold generation options
### Changed
- Visibility of low-level utilities in the crate
- Updated the generation options to be passed at the text generation model instantiation, rather than at every call to the `generate` method
- Updated visibility of generation routines into a public API and private lower level methods
## [0.5.2] - 2020-03-17
### Changed
- Text generation now takes a `Option<Vec<&str>>` instead of a `Option<&str>`. Shorter sequences are left-padded with `pad` if available, otherwise with `eos`.
- Turned-off gradient calculations for generation process
## [0.5.1] - 2020-03-16
### Fixed
- Beam search completion validation
- Padding sequence for sentences shorter than the maximum length moved to correct device
## [0.5.0] - 2020-03-16
### Added
- DistilGPT2 pretrained weights for GPT2
- `LMHeadModel` trait for model supporting text generation, offering an interface between the model specific input/output, and the generic set of inputs/outputs expected for model supporting text generation
- Implementation of `LMHeadModel` for GPT2 and GPT
- Text generation pipeline, supporting beam search, top-k/top-p decoding, repeated tokens banning, repetition and length penalties as `LanguageGenerator` Trait
- Implementation of `LanguageGenerator` for GPT and GPT2
- Examples and tests for language generation
### Fixed
- Fixed concatenation dimension for GPT2 past
## [0.4.5] - 2020-03-07
### Changed
- Updated input type for `QuestionAnsweringModel`'s `predict` to be `&[QaInput]` instead of a pair of question and context strings. QuestionAnsweringModel now works with a list of inputs and returns a list of predictions, processing inputs as batches.
## [0.4.4] - 2020-03-01
### Added
- Swish and gelu_new activation functions
- GPT2 language model
- GPT language model
## [0.4.3] - 2020-02-25
### Added
- Addition of a NER pipeline
- Addition of a QuestionAnswering pipeline
### Changed
- Moved `SentimentClassifier` from DistilBERT module to the newly created pipelines
- Changed precision of id to label mapping of BERT config from `i32` to `i64`
- Simplified calculation of sinusoidal embeddings for DistilBERT
## [0.4.1] - 2020-02-21
### Added
- Addition of RoBERTa language model
- Addition of `BertEmbedding` trait for BERT-like models
### Changed
- Updated `BertEmbeddings` to implement the newly created `BertEmbedding` Trait
- Updated `BertModel`'s embeddings to be of type `impl BertEmbedding` rather than specific embeddings, allowing to re-use the BERT structure for other models, only replacing the embeddings layer.
### Fixed
- Fixed the variable path for BERT models with task-specific heads to allow loading a snapshot from models trained on Transformers.
## [0.4.0] - 2020-02-18
### Added
- BERT Model and examples
- Addition of `DistilBertForTokenClassification` and `DistilBertForQuestionAnswering` model heads
- Collection of activation functions (gelu, relu, mish)
- Dropout module
- Custom Linear layer, allowing a creation without bias
- Config trait allowing to deserialize from `json` files
### Changed
- (BREAKING) Updated `DistilBertConfig` to use the newly created `Config` Trait
## [0.3.1] - 2020-02-16
### Added
- Integration tests
### Changed
- Migrated from `rust_transformers` v0.2.0 (deprecated) to `rust_tokenizers v1.0.0
## [0.3.0] - 2020-02-13
### Added
- Example for DistilBERT masked language modeling
- Download utilities script for DistilBERT (base and SST2)
### Changed
- made `label2id`, `id2label`, `is_decoder`, `output_past` and `use_bfloat` configuration fields optional for DistilBertConfig
## [0.2.0] - 2020-02-11
### Initial release
- Tensor conversion tools from Pytorch to Libtorch format
- DistilBERT model architecture
- Ready-to-use `SentimentClassifier` using a DistilBERT model fine-tuned on SST2

View File

@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.10.0"
version = "0.16.0"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
@ -22,6 +22,33 @@ name = "convert-tensor"
path = "src/convert-tensor.rs"
doc = false
[[bench]]
name = "sst2_benchmark"
harness = false
[[bench]]
name = "squad_benchmark"
harness = false
[[bench]]
name = "summarization_benchmark"
harness = false
[[bench]]
name = "translation_benchmark"
harness = false
[[bench]]
name = "generation_benchmark"
harness = false
[[bench]]
name = "tensor_operations_benchmark"
harness = false
[profile.bench]
opt-level = 3
[features]
doc-only = ["tch/doc-only"]
all-tests = []
@ -30,18 +57,21 @@ all-tests = []
features = ["doc-only"]
[dependencies]
rust_tokenizers = "~5.0.0"
tch = "~0.2.0"
serde_json = "1.0.56"
serde = { version = "1.0.114", features = ["derive"] }
dirs = "3.0.1"
itertools = "0.9.0"
ordered-float = "2.0.0"
cached-path = "0.4.5"
rust_tokenizers = "~7.0.0"
tch = "~0.6.1"
serde_json = "1.0.68"
serde = { version = "1.0.130", features = ["derive"] }
dirs = "4.0.0"
ordered-float = "2.8.0"
cached-path = "0.5.1"
lazy_static = "1.4.0"
uuid = { version = "0.8.1", features = ["v4"] }
thiserror = "1.0.20"
uuid = { version = "0.8.2", features = ["v4"] }
thiserror = "1.0.30"
half = "1.7.1"
[dev-dependencies]
anyhow = "1.0.32"
csv = "1.1.3"
anyhow = "1.0.44"
csv = "1.1.6"
criterion = "0.3.5"
torch-sys = "~0.6.1"
tempfile = "3.2.0"

280
README.md
View File

@ -1,37 +1,110 @@
# rust-bert
[![Build Status](https://travis-ci.com/guillaume-be/rust-bert.svg?branch=master)](https://travis-ci.com/guillaume-be/rust-bert)
[![Build Status](https://github.com/guillaume-be/rust-bert/workflows/Build/badge.svg?event=push)](https://github.com/guillaume-be/rust-bert/actions)
[![Latest version](https://img.shields.io/crates/v/rust_bert.svg)](https://crates.io/crates/rust_bert)
[![Documentation](https://docs.rs/rust-bert/badge.svg)](https://docs.rs/rust-bert)
![License](https://img.shields.io/crates/l/rust_bert.svg)
Rust native Transformer-based models implementation. Port of Huggingface's [Transformers library](https://github.com/huggingface/transformers), using the [tch-rs](https://github.com/LaurentMazare/tch-rs) crate and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multithreaded tokenization and GPU inference.
This repository exposes the model base architecture, task-specific heads (see below) and ready-to-use pipelines.
Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using the [tch-rs](https://github.com/LaurentMazare/tch-rs) crate and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
This repository exposes the model base architecture, task-specific heads (see below) and [ready-to-use pipelines](#ready-to-use-pipelines). [Benchmarks](#benchmarks) are available at the end of this document.
The following models are currently implemented:
Get started with tasks including question answering, named entity recognition, translation, summarization, text generation, conversational agents and more in just a few lines of code:
```rust
let qa_model = QuestionAnsweringModel::new(Default::default())?;
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
| |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**|
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----:|:----:|:----:
Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |
Sequence classification|✅ |✅ |✅| | |✅ | | |✅ | |
Token classification|✅ |✅ | ✅| | | |✅| |✅ | |
Question answering|✅ |✅ |✅| | | | | |✅ | |
Multiple choices| |✅ |✅| | | | | |✅ | |
Next token prediction| | | |✅|✅|✅| | | | |
Natural Language Generation| | | |✅|✅|✅| | | | |
Summarization | | | | | |✅| | | | |
Translation | | | | | |✅| |✅ | |✅|
let answers = qa_model.predict(&[QaInput { question, context }], 1, 32);
```
Output:
```
[Answer { score: 0.9976, start: 13, end: 21, answer: "Amsterdam" }]
```
The tasks currently supported include:
- Translation
- Summarization
- Multi-turn dialogue
- Zero-shot classification
- Sentiment Analysis
- Named Entity Recognition
- Part of Speech tagging
- Question-Answering
- Language Generation.
<details>
<summary> <b>Expand to display the supported models/tasks matrix </b> </summary>
| |**Sequence classification**|**Token classification**|**Question answering**|**Text Generation**|**Summarization**|**Translation**|**Masked LM**|
:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:
DistilBERT|✅|✅|✅| | | |✅|
MobileBERT|✅|✅|✅| | | |✅|
FNet|✅|✅|✅| | | |✅|
BERT|✅|✅|✅| | | |✅|
RoBERTa|✅|✅|✅| | | |✅|
GPT| | | |✅ | | | |
GPT2| | | |✅ | | | |
GPT-Neo| | | |✅ | | | |
BART|✅| | |✅ |✅| | |
Marian| | | | | |✅| |
MBart|✅| | |✅ | | | |
M2M100| | | |✅ | | | |
Electra | |✅| | | | |✅|
ALBERT |✅|✅|✅| | | |✅|
T5 | | | |✅ |✅|✅| |
XLNet|✅|✅|✅|✅ | | |✅|
Reformer|✅| |✅|✅ | | |✅|
ProphetNet| | | |✅ |✅ | | |
Longformer|✅|✅|✅| | | |✅|
Pegasus| | | | |✅| | |
</details>
## Getting started
This library relies on the [tch](https://github.com/LaurentMazare/tch-rs) crate for bindings to the C++ Libtorch API.
The libtorch library is required can be downloaded either automatically or manually. The following provides a reference on how to set-up your environment
to use these bindings, please refer to the [tch](https://github.com/LaurentMazare/tch-rs) for detailed information or support.
Furthermore, this library relies on a cache folder for downloading pre-trained models.
This cache location defaults to `~/.cache/.rustbert`, but can be changed by setting the `RUSTBERT_CACHE` environment variable. Note that the language models used by this library are in the order of the 100s of MBs to GBs.
### Manual installation (recommended)
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.10.0`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.10.0%2Bcu111.zip` for a Linux version with CUDA11.
2. Extract the library to a location of your choice
3. Set the following environment variables
##### Linux:
```bash
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
```
##### Windows
```powershell
$Env:LIBTORCH = "X:\path\to\libtorch"
$Env:Path += ";X:\path\to\libtorch\lib"
```
### Automatic installation
Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu111`.
Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
## Ready-to-use pipelines
Based on Huggingface's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
Based on Hugging Face's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
**Disclaimer**
The contributors of this repository are not responsible for any generation from the 3rd party utilization of the pretrained systems proposed herein.
<details>
<summary> <b>1. Question Answering</b> </summary>
#### 1. Question Answering
Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
Extractive question answering from a given question and context. DistilBERT model fine-tuned on SQuAD (Stanford Question Answering Dataset)
```rust
let qa_model = QuestionAnsweringModel::new(Default::default())?;
@ -39,40 +112,66 @@ Extractive question answering from a given question and context. DistilBERT mode
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
let answers = qa_model.predict(&[QaInput { question, context }], 1, 32);
```
Output:
```
[Answer { score: 0.9976814985275269, start: 13, end: 21, answer: "Amsterdam" }]
[Answer { score: 0.9976, start: 13, end: 21, answer: "Amsterdam" }]
```
</details>
&nbsp;
<details>
<summary> <b>2. Translation </b> </summary>
#### 2. Translation
Translation using the MarianMT architecture and pre-trained models from the Opus-MT team from Language Technology at the University of Helsinki.
Currently supported languages are :
- English <-> French
- English <-> Spanish
- English <-> Portuguese
- English <-> Italian
- English <-> Catalan
- English <-> German
- English <-> Russian
- French <-> German
Translation pipeline supporting a broad range of source and target languages. Leverages two main architectures for translation tasks:
- Marian-based models, for specific source/target combinations
- M2M100 models allowing for direct translation between 100 languages (at a higher computational cost and lower performance for some selected languages)
```rust
let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
let mut model = TranslationModel::new(translation_config)?;
let input = ["This is a sentence to be translated"];
let output = model.translate(&input);
```
Marian-based pretrained models for the following language pairs are readily available in the library - but the user can import any Pytorch-based
model for predictions
- English <-> French
- English <-> Spanish
- English <-> Portuguese
- English <-> Italian
- English <-> Catalan
- English <-> German
- English <-> Russian
- English <-> Chinese
- English <-> Dutch
- English <-> Swedish
- English <-> Arabic
- English <-> Hebrew
- English <-> Hindi
- French <-> German
For languages not supported by the proposed pretrained Marian models, the user can leverage a M2M100 model supporting direct translation between 100 languages (without intermediate English translation)
The full list of supported languages is available in the [crate documentation](docs.rs/rust-bert/0.15.1/rust_bert/pipelines/translation/enum.Language.html)
```rust
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
fn main() -> anyhow::Result<()> {
let model = TranslationModelBuilder::new()
.with_source_languages(vec![Language::English])
.with_target_languages(vec![Language::Spanish, Language::French, Language::Italian])
.create_model()?;
let input_text = "This is a sentence to be translated";
let output = model.translate(&[input_text], None, Language::Spanish)?;
for sentence in output {
println!("{}", sentence);
}
Ok(())
}
```
Output:
```
Il s'agit d'une phrase à traduire
```
</details>
&nbsp;
<details>
<summary> <b>3. Summarization </b> </summary>
#### 3. Summarization
Abstractive summarization using a pretrained BART model.
```rust
@ -110,8 +209,11 @@ Output:
This is the first such discovery in a planet in its star's habitable zone.
The planet is not too hot and not too cold for liquid water to exist."
```
</details>
&nbsp;
<details>
<summary> <b>4. Dialogue Model </b> </summary>
#### 4. Dialogue Model
Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
This pipeline allows the generation of single or multi-turn conversations between a human and a model.
The DialoGPT's page states that
@ -133,12 +235,15 @@ Example output:
```
"The Big Lebowski."
```
</details>
&nbsp;
<details>
<summary> <b>5. Natural Language Generation </b> </summary>
#### 5. Natural Language Generation
Generate language based on a prompt. GPT2 and GPT available as base models.
Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, the unknown token otherwise.
This may impact the results and it is recommended to submit prompts of similar length for best results
This may impact the results, it is recommended to submit prompts of similar length for best results
```rust
let model = GPT2Generator::new(Default::default())?;
@ -146,8 +251,12 @@ This may impact the results and it is recommended to submit prompts of similar l
let input_context_1 = "The dog";
let input_context_2 = "The cat was";
let output = model.generate(Some(vec!(input_context_1, input_context_2)), 0, 30, true, false,
5, 1.2, 0, 0.9, 1.0, 1.0, 3, 3, None);
let generate_options = GenerateOptions {
max_length: 30,
..Default::default()
};
let output = model.generate(Some(&[input_context_1, input_context_2]), generate_options);
```
Example output:
```
@ -160,8 +269,11 @@ Example output:
"The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
]
```
</details>
&nbsp;
<details>
<summary> <b>6. Zero-shot classification </b> </summary>
#### 6. Zero-shot classification
Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference.
```rust
let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
@ -185,9 +297,12 @@ Output:
[ Label { "politics", score: 0.975 }, Label { "public health", score: 0.0818 }, Label {"economics", score: 0.852 }, Label {"sports", score: 0.001 } ],
]
```
</details>
&nbsp;
<details>
<summary> <b>7. Sentiment analysis </b> </summary>
#### 7. Sentiment analysis
Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
Predicts the binary sentiment for a sentence. DistilBERT model fine-tuned on SST-2.
```rust
let sentiment_classifier = SentimentModel::new(Default::default())?;
@ -209,9 +324,12 @@ Output:
Sentiment { polarity: Positive, score: 0.9997248985164333 }
]
```
</details>
&nbsp;
<details>
<summary> <b>8. Named Entity Recognition </b> </summary>
#### 8. Named Entity Recognition
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model fine-tuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
Models are currently available for English, German, Spanish and Dutch.
```rust
let ner_model = NERModel::new(default::default())?;
@ -226,36 +344,72 @@ Models are currently available for English, German, Spanish and Dutch.
Output:
```
[
[
Entity { word: "Amy", score: 0.9986, label: "I-PER" }
Entity { word: "Paris", score: 0.9985, label: "I-LOC" }
],
[
Entity { word: "Paris", score: 0.9988, label: "I-LOC" }
Entity { word: "France", score: 0.9993, label: "I-LOC" }
]
]
```
</details>
&nbsp;
<details>
<summary> <b>9. Part of Speech tagging </b> </summary>
## Base models
Extracts Part of Speech tags (Noun, Verb, Adjective...) from text.
```rust
let pos_model = POSModel::new(default::default())?;
let input = ["My name is Bob"];
let output = pos_model.predict(&input);
```
Output:
```
[
Entity { word: "My", score: 0.1560, label: "PRP" }
Entity { word: "name", score: 0.6565, label: "NN" }
Entity { word: "is", score: 0.3697, label: "VBZ" }
Entity { word: "Bob", score: 0.7460, label: "NNP" }
]
```
</details>
## Benchmarks
For simple pipelines (sequence classification, tokens classification, question answering) the performance between Python and Rust is expected to be comparable. This is because the most expensive part of these pipeline is the language model itself, sharing a common implementation in the Torch backend. The [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/) provides a benchmarks section covering all pipelines.
For text generation tasks (summarization, translation, conversation, free text generation), significant benefits can be expected (up to 2 to 4 times faster processing depending on the input and application). The article [Accelerating text generation with Rust](https://guillaume-be.github.io/2020-11-21/generation_benchmarks) focuses on these text generation applications and provides more details on the performance comparison to Python.
## Loading pretrained and custom model weights
The base model and task-specific heads are also available for users looking to expose their own transformer based models.
Examples on how to prepare the date using a native tokenizers Rust library are available in `./examples` for BERT, DistilBERT, RoBERTa, GPT, GPT2 and BART.
Note that when importing models from Pytorch, the convention for parameters naming needs to be aligned with the Rust schema. Loading of the pre-trained weights will fail if any of the model parameters weights cannot be found in the weight files.
If this quality check is to be skipped, an alternative method `load_partial` can be invoked from the variables store.
## Setup
Pretrained models are available on Hugging face's [model hub](https://huggingface.co/models?filter=rust) and can be loaded using `RemoteResources` defined in this library.
A conversion utility script is included in `./utils` to convert Pytorch weights to a set of weights compatible with this library. This script requires Python and `torch` to be set-up, and can be used as follows:
`python ./utils/convert_model.py path/to/pytorch_model.bin` where `path/to/pytorch_model.bin` is the location of the original Pytorch weights.
A number of pretrained model configuration, weights and vocabulary are downloaded directly from [Huggingface's model repository](https://huggingface.co/models).
The list of models available with Rust-compatible weights is available at [https://huggingface.co/models?filter=rust](https://huggingface.co/models?filter=rust).
The models will be downloaded to the environment variable `RUSTBERT_CACHE` if it exists, otherwise to `~/.cache/.rustbert`.
Additional models can be added if of interest, please raise an issue.
In order to load custom weights to the library, these need to be converter to a binary format that can be read by Libtorch (the original `.bin` files are pickles and cannot be used directly).
Several Python scripts to load Pytorch weights and convert them to the appropriate format are provided and can be adapted based on the model needs.
1. Compile the package: `cargo build`
2. Download the model files & perform necessary conversions
- Set-up a virtual environment and install dependencies
- run the conversion script `python /utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py`. The dependencies will be downloaded to the user's home directory, under `~/rustbert/{}`.
Alternatively you may load local weight files and run the conversion directly.
## Citation
If you use `rust-bert` for your work, please cite [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/):
```bibtex
@inproceedings{becquin-2020-end,
title = "End-to-end {NLP} Pipelines in Rust",
author = "Becquin, Guillaume",
booktitle = "Proceedings of Second Workshop for NLP Open Source Software (NLP-OSS)",
year = "2020",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/2020.nlposs-1.4",
pages = "20--25",
}
```
## Acknowledgements

View File

@ -0,0 +1,74 @@
#[macro_use]
extern crate criterion;
use criterion::{black_box, Criterion};
use rust_bert::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource};
use std::time::{Duration, Instant};
use tch::Device;
fn create_text_generation_model() -> TextGenerationModel {
let config = TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::GPT2,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
)),
min_length: 0,
max_length: 30,
do_sample: true,
early_stopping: false,
num_beams: 5,
temperature: 1.0,
top_k: 0,
top_p: 0.9,
repetition_penalty: 1.0,
length_penalty: 1.0,
no_repeat_ngram_size: 3,
num_beam_groups: None,
diversity_penalty: None,
num_return_sequences: 5,
device: Device::cuda_if_available(),
};
TextGenerationModel::new(config).unwrap()
}
fn generation_forward_pass(iters: u64, model: &TextGenerationModel, data: &[&str]) -> Duration {
let mut duration = Duration::new(0, 0);
for _i in 0..iters {
let start = Instant::now();
let _ = model.generate(data, None);
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
}
fn bench_generation(c: &mut Criterion) {
// Set-up summarization model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let model = create_text_generation_model();
// Define input
let input = ["Hello, I'm a language model,"];
c.bench_function("Generation", |b| {
b.iter_custom(|iters| black_box(generation_forward_pass(iters, &model, &input)))
});
}
criterion_group! {
name = benches;
config = Criterion::default().sample_size(10);
targets = bench_generation
}
criterion_main!(benches);

101
benches/squad_benchmark.rs Normal file
View File

@ -0,0 +1,101 @@
#[macro_use]
extern crate criterion;
use criterion::{black_box, Criterion};
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{
squad_processor, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{RemoteResource, Resource};
use std::env;
use std::path::PathBuf;
use std::time::{Duration, Instant};
static BATCH_SIZE: usize = 64;
fn create_qa_model() -> QuestionAnsweringModel {
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_QA,
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
false,
None,
);
QuestionAnsweringModel::new(config).unwrap()
}
fn squad_forward_pass(
iters: u64,
model: &QuestionAnsweringModel,
squad_data: &[QaInput],
) -> Duration {
let mut duration = Duration::new(0, 0);
let batch_size = BATCH_SIZE;
let mut output = vec![];
for _i in 0..iters {
let start = Instant::now();
for batch in squad_data.chunks(batch_size) {
output.push(model.predict(batch, 1, 64));
}
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
}
fn qa_load_model(iters: u64) -> Duration {
let mut duration = Duration::new(0, 0);
for _i in 0..iters {
let start = Instant::now();
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_QA,
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
false,
None,
);
let _ = QuestionAnsweringModel::new(config).unwrap();
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
}
fn bench_squad(c: &mut Criterion) {
// Set-up QA model
let model = create_qa_model();
unsafe {
torch_sys::dummy_cuda_dependency();
}
// Define input
let mut squad_path = PathBuf::from(env::var("squad_dataset")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
squad_path.push("dev-v2.0.json");
let mut qa_inputs = squad_processor(squad_path);
qa_inputs.truncate(1000);
c.bench_function("SQuAD forward pass", |b| {
b.iter_custom(|iters| black_box(squad_forward_pass(iters, &model, &qa_inputs)))
});
c.bench_function("Load model", |b| {
b.iter_custom(|iters| black_box(qa_load_model(iters)))
});
}
criterion_group! {
name = benches;
config = Criterion::default().sample_size(10);
targets = bench_squad
}
criterion_main!(benches);

106
benches/sst2_benchmark.rs Normal file
View File

@ -0,0 +1,106 @@
#[macro_use]
extern crate criterion;
use criterion::Criterion;
use rust_bert::pipelines::sentiment::SentimentModel;
use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
use serde::Deserialize;
use std::error::Error;
use std::path::PathBuf;
use std::time::{Duration, Instant};
use std::{env, fs};
use tch::Device;
static BATCH_SIZE: usize = 64;
fn create_sentiment_model() -> SentimentModel {
let config = SequenceClassificationConfig {
device: Device::cuda_if_available(),
..Default::default()
};
SentimentModel::new(config).unwrap()
}
fn sst2_forward_pass(iters: u64, model: &SentimentModel, sst2_data: &[String]) -> Duration {
let mut duration = Duration::new(0, 0);
let batch_size = BATCH_SIZE;
let mut output = vec![];
for _i in 0..iters {
let start = Instant::now();
for batch in sst2_data.chunks(batch_size) {
output.push(
model.predict(
batch
.iter()
.map(|v| v.as_str())
.collect::<Vec<&str>>()
.as_slice(),
),
);
}
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
}
#[derive(Debug, Deserialize)]
struct Record {
sentence: String,
label: i8,
}
fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>> {
let file = fs::File::open(file_path).expect("unable to open file");
let mut csv = csv::ReaderBuilder::new()
.has_headers(true)
.delimiter(b'\t')
.from_reader(file);
let mut records = Vec::new();
for result in csv.deserialize() {
let record: Record = result?;
records.push(record.sentence);
}
Ok(records)
}
fn sst2_load_model(iters: u64) -> Duration {
let mut duration = Duration::new(0, 0);
for _i in 0..iters {
let start = Instant::now();
let config = SequenceClassificationConfig {
device: Device::cuda_if_available(),
..Default::default()
};
let _ = SentimentModel::new(config).unwrap();
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
}
fn bench_sst2(c: &mut Criterion) {
// Set-up classifier
let model = create_sentiment_model();
unsafe {
torch_sys::dummy_cuda_dependency();
}
// Define input
let mut sst2_path = PathBuf::from(env::var("SST2_PATH")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
sst2_path.push("train.tsv");
let mut inputs = ss2_processor(sst2_path).unwrap();
inputs.truncate(2000);
c.bench_function("SST2 forward pass", |b| {
b.iter_custom(|iters| sst2_forward_pass(iters, &model, &inputs))
});
c.bench_function("Load model", |b| b.iter_custom(sst2_load_model));
}
criterion_group! {
name = benches;
config = Criterion::default().sample_size(10);
targets = bench_sst2
}
criterion_main!(benches);

View File

@ -0,0 +1,86 @@
#[macro_use]
extern crate criterion;
use criterion::{black_box, Criterion};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use std::time::{Duration, Instant};
use tch::Device;
fn create_summarization_model() -> SummarizationModel {
let config = SummarizationConfig {
device: Device::cuda_if_available(),
..Default::default()
};
SummarizationModel::new(config).unwrap()
}
fn summarization_forward_pass(iters: u64, model: &SummarizationModel, data: &[&str]) -> Duration {
let mut duration = Duration::new(0, 0);
for _i in 0..iters {
let start = Instant::now();
let _ = model.summarize(data);
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
}
fn summarization_load_model(iters: u64) -> Duration {
let mut duration = Duration::new(0, 0);
for _i in 0..iters {
let start = Instant::now();
let config = SummarizationConfig {
device: Device::cuda_if_available(),
..Default::default()
};
let _ = SummarizationModel::new(config).unwrap();
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
}
fn bench_squad(c: &mut Criterion) {
// Set-up summarization model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let model = create_summarization_model();
// Define input
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \
habitable zone not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \
used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \
passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, \
weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \
contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \
and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \
but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \
\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \
said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors. \
\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
a potentially habitable planet, but further observations will be required to say for sure. \" \
K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
c.bench_function("Summarization forward pass", |b| {
b.iter_custom(|iters| black_box(summarization_forward_pass(iters, &model, &input)))
});
c.bench_function("Load model", |b| {
b.iter_custom(|iters| black_box(summarization_load_model(iters)))
});
}
criterion_group! {
name = benches;
config = Criterion::default().sample_size(10);
targets = bench_squad
}
criterion_main!(benches);

View File

@ -0,0 +1,39 @@
#[macro_use]
extern crate criterion;
use criterion::{black_box, Criterion};
use std::time::{Duration, Instant};
use tch::kind::Kind;
use tch::{Device, Tensor};
fn matrix_multiply(iters: u64, input: &Tensor, weights: &Tensor) -> Duration {
let mut duration = Duration::new(0, 0);
for _i in 0..iters {
let start = Instant::now();
let _ = input.matmul(weights);
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
}
fn bench_tensor_ops(c: &mut Criterion) {
// Set-up summarization model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let input = Tensor::rand(&[32, 128, 512], (Kind::Float, Device::cuda_if_available()));
let weights = Tensor::rand(&[512, 512], (Kind::Float, Device::cuda_if_available()));
let _ = &input.matmul(&weights);
c.bench_function("Matrix multiply ", |b| {
b.iter_custom(|iters| black_box(matrix_multiply(iters, &input, &weights)))
});
}
criterion_group! {
name = benches;
config = Criterion::default().sample_size(100);
targets = bench_tensor_ops
}
criterion_main!(benches);

View File

@ -0,0 +1,110 @@
#[macro_use]
extern crate criterion;
use criterion::{black_box, Criterion};
// use rust_bert::pipelines::common::ModelType;
// use rust_bert::pipelines::translation::TranslationOption::{Marian, T5};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationModel, TranslationModelBuilder};
// use rust_bert::resources::{LocalResource, Resource};
use std::time::{Duration, Instant};
use tch::Device;
fn create_translation_model() -> TranslationModel {
let model = TranslationModelBuilder::new()
.with_device(Device::cuda_if_available())
.with_model_type(ModelType::Marian)
// .with_model_type(ModelType::T5)
.with_source_languages(vec![Language::English])
.with_target_languages(vec![Language::French])
.create_model()
.unwrap();
// let model_resource = Resource::Local(LocalResource {
// local_path: "E:/Coding/cache/rustbert/marian-mt-en-es/model.ot".into(),
// });
// let config_resource = Resource::Local(LocalResource {
// local_path: "E:/Coding/cache/rustbert/marian-mt-en-es/config.json".into(),
// });
// let vocab_resource = Resource::Local(LocalResource {
// local_path: "E:/Coding/cache/rustbert/marian-mt-en-es/vocab.json".into(),
// });
// let merges_resource = Resource::Local(LocalResource {
// local_path: "E:/Coding/cache/rustbert/marian-mt-en-es/spiece.model".into(),
// });
//
// let source_languages = [Language::English];
// let target_languages = [Language::Spanish];
//
// let translation_config = TranslationConfig::new(
// ModelType::Marian,
// model_resource,
// config_resource,
// vocab_resource,
// merges_resource,
// source_languages,
// target_languages,
// Device::cuda_if_available(),
// );
// let model = TranslationModel::new(translation_config).unwrap();
model
}
fn translation_forward_pass(iters: u64, model: &TranslationModel, data: &[&str]) -> Duration {
let mut duration = Duration::new(0, 0);
for _i in 0..iters {
let start = Instant::now();
let _ = model.translate(data, None, Language::French).unwrap();
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
}
fn translation_load_model(iters: u64) -> Duration {
let mut duration = Duration::new(0, 0);
for _i in 0..iters {
let start = Instant::now();
let _ = create_translation_model();
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
}
fn bench_squad(c: &mut Criterion) {
// Set-up translation model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let model = create_translation_model();
// Define input
let input = [
"In findings published Tuesday in Cornell University's arXiv by a team of scientists from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, a planet circling a star in the constellation Leo.",
"This is the first such discovery in a planet in its star's habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, used data from the NASA\'s Hubble telescope to assess changes in the light coming from K2-18b's star as the planet passed between it and Earth.",
"They found that certain wavelengths of light, which are usually absorbed by water, weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere contains water in vapour form.",
"The team from UCL then analyzed the Montreal team's data using their own software and confirmed their conclusion.",
"This was not the first time scientists have found signs of water on an exoplanet, but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.",
"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" said UCL astronomer Angelos Tsiaras.",
"It's the best candidate for habitability right now.\" \"It's a good sign\", said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors.",
"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being a potentially habitable planet, but further observations will be required to say for sure. \"",
"K2-18b was first identified in 2015 by the Kepler space telescope.",
"It is about 110 light-years from Earth and larger but less dense.",
];
// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
c.bench_function("Translation forward pass", |b| {
b.iter_custom(|iters| black_box(translation_forward_pass(iters, &model, &input)))
});
c.bench_function("Load model", |b| {
b.iter_custom(|iters| black_box(translation_load_model(iters)))
});
}
criterion_group! {
name = benches;
config = Criterion::default().sample_size(10);
targets = bench_squad
}
criterion_main!(benches);

View File

@ -1 +1 @@
too-many-arguments-threshold = 10
too-many-arguments-threshold = 12

View File

@ -1,97 +0,0 @@
// Copyright 2018 Google AI and Google Brain team.
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::albert::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources,
AlbertVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?;
let config = AlbertConfig::from_file(config_path);
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
println!(
"{:?}",
model_output.prediction_scores.double_value(&[0, 0, 0])
);
// Print masked tokens
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{} - {}", &index_1.int64_value(&[]), word_1); // Outputs "_them" : "Looks like one [them] is missing"
println!("{} - {}", &index_2.int64_value(&[]), word_2); // Outputs "_enjoyable" : "It was a very nice and [enjoyable] day"
Ok(())
}

View File

@ -1,84 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::bart::{
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
false,
)?;
let config = BartConfig::from_file(config_path);
let bart_model = BartModel::new(&vs.root(), &config, false);
vs.load(weights_path)?;
// Define input
let input = ["One two three four"];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));
// Print masked tokens
println!("{:?}", model_output.encoder_hidden_state);
println!("{:?}", model_output.decoder_output);
println!("{:?}", model_output.decoder_output.double_value(&[0, 0, 0]));
Ok(())
}

View File

@ -12,10 +12,17 @@
extern crate anyhow;
use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
use rust_bert::pipelines::conversation::{
ConversationConfig, ConversationManager, ConversationModel,
};
fn main() -> anyhow::Result<()> {
let conversation_model = ConversationModel::new(Default::default())?;
let config = ConversationConfig {
do_sample: false,
num_beams: 3,
..Default::default()
};
let conversation_model = ConversationModel::new(config)?;
let mut conversation_manager = ConversationManager::new();
let conversation_1_id =

View File

@ -1,104 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::distilbert::{
DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM, DistilBertModelResources,
DistilBertVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer =
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = DistilBertConfig::from_file(config_path);
let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][6] = 103;
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
// Print masked tokens
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(6)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
println!("{}", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples"
Ok(())
}

View File

@ -1,389 +0,0 @@
extern crate anyhow;
use rust_bert::albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources};
use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::distilbert::{
DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources,
};
use rust_bert::electra::{ElectraConfigResources, ElectraModelResources, ElectraVocabResources};
use rust_bert::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use rust_bert::openai_gpt::{
OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModelResources,
OpenAiGptVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
/// This example downloads and caches all dependencies used in model tests. This allows for safe
/// multi threaded testing (two test using the same resource would otherwise download the file to
/// the same location).
fn download_distil_gpt2() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DISTIL_GPT2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::DISTIL_GPT2,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::DISTIL_GPT2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DISTIL_GPT2,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_distilbert_sst2() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SST2,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SST2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SST2,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_distilbert_qa() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SQUAD,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_distilbert() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_gpt2() -> anyhow::Result<()> {
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format.
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_gpt() -> anyhow::Result<()> {
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_roberta() -> anyhow::Result<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_bert() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_bert_ner() -> anyhow::Result<()> {
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_bart() -> anyhow::Result<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_bart_cnn() -> anyhow::Result<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::BART_CNN,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
BartVocabResources::BART_CNN,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
BartMergesResources::BART_CNN,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
BartModelResources::BART_CNN,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_electra_generator() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_electra_discriminator() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_albert_base_v2() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn _download_dialogpt() -> anyhow::Result<()> {
// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DIALOGPT_MEDIUM,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::DIALOGPT_MEDIUM,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::DIALOGPT_MEDIUM,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DIALOGPT_MEDIUM,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_t5_small() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_roberta_qa() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA_QA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA_QA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA_QA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA_QA,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_bert_qa() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_QA,
));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_xlm_roberta_ner_german() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::XLM_ROBERTA_NER_DE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::XLM_ROBERTA_NER_DE,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::XLM_ROBERTA_NER_DE,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn main() -> anyhow::Result<()> {
let _ = download_distil_gpt2();
let _ = download_distilbert_sst2();
let _ = download_distilbert_qa();
let _ = download_distilbert();
let _ = download_gpt2();
let _ = download_gpt();
let _ = download_roberta();
let _ = download_bert();
let _ = download_bert_ner();
let _ = download_bart();
let _ = download_bart_cnn();
let _ = download_electra_generator();
let _ = download_electra_discriminator();
let _ = download_albert_base_v2();
let _ = download_t5_small();
let _ = download_roberta_qa();
let _ = download_bert_qa();
let _ = download_xlm_roberta_ner_german();
Ok(())
}

View File

@ -1,89 +0,0 @@
// Copyright 2020 The Google Research Authors.
// Copyright 2019-present, the HuggingFace Inc. team
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraModelResources,
ElectraVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer =
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = ElectraConfig::from_file(config_path);
let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["One Two Three Ten Five Six Seven Eight"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let encoded_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print model predictions
for (position, token) in tokenized_input[0].token_ids.iter().enumerate() {
let probability = model_output.probabilities.double_value(&[position as i64]);
let generated = if probability > 0.5 {
"generated"
} else {
"original"
};
println!(
"{:?}: {} ({:.1}%)",
tokenizer.decode([*token].to_vec(), false, false),
generated,
100f64 * probability
)
}
Ok(())
}

View File

@ -1,93 +0,0 @@
// Copyright 2020 The Google Research Authors.
// Copyright 2019-present, the HuggingFace Inc. team
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraForMaskedLM, ElectraModelResources,
ElectraVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer =
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = ElectraConfig::from_file(config_path);
let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print masked tokens
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
println!("{}", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"
Ok(())
}

View File

@ -12,23 +12,25 @@
extern crate anyhow;
use rust_bert::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenerator};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
fn main() -> anyhow::Result<()> {
// Set-up masked LM model
let generate_config = GenerateConfig {
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
max_length: 30,
do_sample: true,
num_beams: 5,
temperature: 1.1,
num_return_sequences: 3,
do_sample: false,
num_beams: 1,
temperature: 1.0,
num_return_sequences: 1,
..Default::default()
};
let model = GPT2Generator::new(generate_config)?;
let model = TextGenerationModel::new(generate_config)?;
let input_context = "The dog";
let second_input_context = "The cat was";
let output = model.generate(Some(vec![input_context, second_input_context]), None);
// let second_input_context = "The cat was";
let output = model.generate(&[input_context], None);
for sentence in output {
println!("{:?}", sentence);

View File

@ -0,0 +1,66 @@
// Copyright 2018 Google AI and Google Brain team.
// Copyright 2018 Carnegie Mellon University Authors.
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::gpt_neo::{
GptNeoConfigResources, GptNeoMergesResources, GptNeoModelResources, GptNeoVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;
fn main() -> anyhow::Result<()> {
// Set-up model resources
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
GptNeoConfigResources::GPT_NEO_125M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
GptNeoVocabResources::GPT_NEO_125M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
GptNeoMergesResources::GPT_NEO_125M,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
GptNeoModelResources::GPT_NEO_125M,
));
let generate_config = TextGenerationConfig {
model_type: ModelType::GPTNeo,
model_resource,
config_resource,
vocab_resource,
merges_resource,
min_length: 10,
max_length: 32,
do_sample: false,
early_stopping: true,
num_beams: 4,
num_return_sequences: 1,
device: Device::Cpu,
..Default::default()
};
let mut model = TextGenerationModel::new(generate_config)?;
model.set_device(Device::cuda_if_available());
let input_context_1 = "It was a very nice and sunny";
let input_context_2 = "It was a gloom winter night, and";
let output = model.generate(&[input_context_1, input_context_2], None);
for sentence in output {
println!("{}", sentence);
}
Ok(())
}

View File

@ -0,0 +1,64 @@
// Copyright 2018 Google AI and Google Brain team.
// Copyright 2018 Carnegie Mellon University Authors.
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::reformer::{
ReformerConfigResources, ReformerModelResources, ReformerVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
fn main() -> anyhow::Result<()> {
// Set-up model
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT,
));
let generate_config = TextGenerationConfig {
model_type: ModelType::Reformer,
model_resource,
config_resource,
vocab_resource,
merges_resource,
min_length: 100,
max_length: 100,
do_sample: true,
early_stopping: false,
num_beams: 3,
num_return_sequences: 1,
..Default::default()
};
let model = TextGenerationModel::new(generate_config)?;
let input_context_1 = "The really great men must, I think,";
let input_context_2 = "It was a gloom winter night, and";
let output = model.generate(&[input_context_1, input_context_2], None);
for sentence in output {
println!("{}", sentence);
}
Ok(())
}

View File

@ -0,0 +1,59 @@
// Copyright 2018 Google AI and Google Brain team.
// Copyright 2018 Carnegie Mellon University Authors.
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED,
));
let generate_config = TextGenerationConfig {
model_type: ModelType::XLNet,
model_resource,
config_resource,
vocab_resource,
merges_resource,
max_length: 32,
do_sample: false,
num_beams: 3,
temperature: 1.0,
num_return_sequences: 1,
..Default::default()
};
let model = TextGenerationModel::new(generate_config)?;
let input_context = "Once upon a time,";
let output = model.generate(&[input_context], None);
for sentence in output {
println!("{}", sentence);
}
Ok(())
}

View File

@ -1,98 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources set-up
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
)?;
let config = Gpt2Config::from_file(config_path);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["One two three four five six seven eight nine ten eleven"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false,
)
.unwrap();
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);
Ok(())
}

View File

@ -17,7 +17,8 @@ use rust_bert::bert::{
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
@ -46,8 +47,7 @@ fn main() -> anyhow::Result<()> {
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -67,13 +67,13 @@ fn main() -> anyhow::Result<()> {
// Forward pass
let model_output = no_grad(|| {
bert_model.forward_t(
Some(input_tensor),
Some(&input_tensor),
None,
None,
None,
None,
None,
None,
&None,
&None,
false,
)
});

View File

@ -1,103 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
)?;
let config = Gpt2Config::from_file(config_path);
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Wondering what the next word will"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = openai_gpt
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false,
)
.unwrap();
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);
Ok(())
}

View File

@ -0,0 +1,31 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pipelines::pos_tagging::POSModel;
fn main() -> anyhow::Result<()> {
// Set-up model
let pos_model = POSModel::new(Default::default())?;
// Define input
let input = ["My name is Bob"];
// Run model
let output = pos_model.predict(&input);
for (pos, pos_tag) in output[0].iter().enumerate() {
println!("{} - {:?}", pos, pos_tag);
}
Ok(())
}

View File

@ -28,8 +28,8 @@ fn main() -> anyhow::Result<()> {
BertConfigResources::BERT_QA,
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
None, //merges resource only relevant with ModelType::Roberta
false,
false,
None,
);

View File

@ -0,0 +1,66 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::longformer::{
LongformerConfigResources, LongformerMergesResources, LongformerModelResources,
LongformerVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{RemoteResource, Resource};
fn main() -> anyhow::Result<()> {
// Set-up Question Answering model
let config = QuestionAnsweringConfig::new(
ModelType::Longformer,
Resource::Remote(RemoteResource::from_pretrained(
LongformerModelResources::LONGFORMER_BASE_SQUAD1,
)),
Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
)),
Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
)),
Some(Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
))),
false,
None,
false,
);
let qa_model = QuestionAnsweringModel::new(config)?;
// Define input
let question_1 = String::from("Where does Amy live ?");
let context_1 = String::from("Amy lives in Amsterdam");
let question_2 = String::from("Where does Eric live");
let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
let qa_input_1 = QaInput {
question: question_1,
context: context_1,
};
let qa_input_2 = QaInput {
question: question_2,
context: context_2,
};
// Get answer
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers);
Ok(())
}

View File

@ -1,119 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::bert::BertConfig;
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaMergesResources, RobertaModelResources,
RobertaVocabResources,
};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
false,
)?;
let config = BertConfig::from_file(config_path);
let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"<pad> Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][5] = 103;
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = no_grad(|| {
bert_model.forward_t(
Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false,
)
});
// Print masked tokens
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(5)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "some" : "Looks like [some] thing is missing"
println!("{}", word_2); // Outputs "apple" : "It\'s like comparing [apple] to apples"
Ok(())
}

View File

@ -0,0 +1,56 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel};
use rust_bert::resources::{RemoteResource, Resource};
fn main() -> anyhow::Result<()> {
// Set-up classifier
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
FNetConfigResources::BASE_SST2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
FNetVocabResources::BASE_SST2,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
FNetModelResources::BASE_SST2,
));
let sentiment_config = SentimentConfig {
model_type: ModelType::FNet,
model_resource,
config_resource,
vocab_resource,
..Default::default()
};
let sentiment_classifier = SentimentModel::new(sentiment_config)?;
// Define input
let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
// Run model
let output = sentiment_classifier.predict(&input);
for sentiment in output {
println!("{:?}", sentiment);
}
Ok(())
}

View File

@ -0,0 +1,80 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;
fn main() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
BartVocabResources::DISTILBART_CNN_6_6,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
BartModelResources::DISTILBART_CNN_6_6,
));
let summarization_config = SummarizationConfig {
model_resource,
config_resource,
vocab_resource,
merges_resource,
num_beams: 1,
length_penalty: 1.0,
min_length: 56,
max_length: 142,
device: Device::Cpu,
..Default::default()
};
let summarization_model = SummarizationModel::new(summarization_config)?;
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \
habitable zone not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \
used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \
passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, \
weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \
contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \
and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \
but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \
\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \
said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors. \
\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
a potentially habitable planet, but further observations will be required to say for sure. \" \
K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
for sentence in _output {
println!("{}", sentence);
}
Ok(())
}

View File

@ -0,0 +1,75 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;
fn main() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
PegasusConfigResources::CNN_DAILYMAIL,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
PegasusVocabResources::CNN_DAILYMAIL,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
PegasusModelResources::CNN_DAILYMAIL,
));
let summarization_config = SummarizationConfig {
model_type: ModelType::Pegasus,
model_resource: weights_resource,
config_resource,
vocab_resource: vocab_resource.clone(),
merges_resource: vocab_resource,
length_penalty: 1.0,
num_beams: 4,
no_repeat_ngram_size: 3,
device: Device::cuda_if_available(),
..Default::default()
};
let summarization_model = SummarizationModel::new(summarization_config)?;
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \
habitable zone not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \
used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \
passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, \
weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \
contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \
and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \
but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \
\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \
said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors. \
\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
a potentially habitable planet, but further observations will be required to say for sure. \" \
K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
for sentence in _output {
println!("{}", sentence);
}
Ok(())
}

View File

@ -0,0 +1,77 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::prophetnet::{
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;
fn main() -> anyhow::Result<()> {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
));
let summarization_config = SummarizationConfig {
model_type: ModelType::ProphetNet,
model_resource: weights_resource,
config_resource,
vocab_resource: vocab_resource.clone(),
merges_resource: vocab_resource,
length_penalty: 1.2,
num_beams: 4,
no_repeat_ngram_size: 3,
device: Device::cuda_if_available(),
..Default::default()
};
let summarization_model = SummarizationModel::new(summarization_config)?;
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \
habitable zone not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \
used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \
passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, \
weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \
contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \
and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \
but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \
\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \
said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors. \
\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
a potentially habitable planet, but further observations will be required to say for sure. \" \
K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
for sentence in _output {
println!("{}", sentence);
}
Ok(())
}

View File

@ -12,10 +12,26 @@
extern crate anyhow;
use rust_bert::pipelines::summarization::SummarizationModel;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
fn main() -> anyhow::Result<()> {
let summarization_model = SummarizationModel::new(Default::default())?;
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
let summarization_config = SummarizationConfig::new(
ModelType::T5,
weights_resource,
config_resource,
vocab_resource.clone(),
vocab_resource,
);
let summarization_model = SummarizationModel::new(summarization_config)?;
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
@ -42,7 +58,7 @@ about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
for sentence in _output {
println!("{:?}", sentence);
println!("{}", sentence);
}
Ok(())

View File

@ -1,48 +0,0 @@
// Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pipelines::generation::{GenerateConfig, LanguageGenerator, T5Generator};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_BASE));
let generate_config = GenerateConfig {
model_resource: weights_resource,
vocab_resource,
config_resource,
max_length: 40,
do_sample: false,
num_beams: 4,
..Default::default()
};
// Set-up masked LM model
let t5_model = T5Generator::new(generate_config)?;
// Define input
let input = ["translate English to German: This sentence will get translated to German"];
let output = t5_model.generate(Some(input.to_vec()), None);
println!("{:?}", output);
Ok(())
}

View File

@ -12,8 +12,9 @@
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
LabelAggregationOption, TokenClassificationConfig,
};
use rust_bert::resources::{RemoteResource, Resource};
@ -38,12 +39,12 @@ fn main() -> anyhow::Result<()> {
);
// Create the model
let token_classification_model = TokenClassificationModel::new(config)?;
let token_classification_model = NERModel::new(config)?;
let input = [
"My name is Amélie. I live in Москва.",
"Chongqing is a city in China.",
];
let token_outputs = token_classification_model.predict(&input, true, false); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
let token_outputs = token_classification_model.predict(&input);
for token in token_outputs {
println!("{:?}", token);

View File

@ -13,18 +13,23 @@
extern crate anyhow;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
use tch::Device;
fn main() -> anyhow::Result<()> {
let translation_config =
TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
let model = TranslationModel::new(translation_config)?;
let model = TranslationModelBuilder::new()
.with_device(Device::cuda_if_available())
.with_model_type(ModelType::Marian)
// .with_large_model()
.with_source_languages(vec![Language::English])
.with_target_languages(vec![Language::Spanish])
.create_model()?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";
let input_context_2 = "The dog did not wake up";
let input_context_1 = "This is a sentence to be translated";
let input_context_2 = "The dog did not wake up.";
let output = model.translate(&[input_context_1, input_context_2]);
let output = model.translate(&[input_context_1, input_context_2], None, Language::Spanish)?;
for sentence in output {
println!("{}", sentence);

View File

@ -0,0 +1,64 @@
// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::m2m_100::{
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages,
M2M100TargetLanguages, M2M100VocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;
fn main() -> anyhow::Result<()> {
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
));
let source_languages = M2M100SourceLanguages::M2M100_418M;
let target_languages = M2M100TargetLanguages::M2M100_418M;
let translation_config = TranslationConfig::new(
ModelType::M2M100,
model_resource,
config_resource,
vocab_resource,
merges_resource,
source_languages,
target_languages,
Device::cuda_if_available(),
);
let model = TranslationModel::new(translation_config)?;
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
for sentence in outputs {
println!("{}", sentence);
}
Ok(())
}

View File

@ -0,0 +1,63 @@
// Copyright 2018-2020 The HuggingFace Inc. team.
// Copyright 2020 Marian Team Authors
// Copyright 2019-2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;
fn main() -> anyhow::Result<()> {
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianModelResources::ENGLISH2CHINESE,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianConfigResources::ENGLISH2CHINESE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianVocabResources::ENGLISH2CHINESE,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
MarianSpmResources::ENGLISH2CHINESE,
));
let source_languages = MarianSourceLanguages::ENGLISH2CHINESE;
let target_languages = MarianTargetLanguages::ENGLISH2CHINESE;
let translation_config = TranslationConfig::new(
ModelType::Marian,
model_resource,
config_resource,
vocab_resource,
merges_resource,
source_languages,
target_languages,
Device::cuda_if_available(),
);
let model = TranslationModel::new(translation_config)?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";
let input_context_2 = "The dog did not wake up";
let output = model.translate(&[input_context_1, input_context_2], None, None)?;
for sentence in output {
println!("{}", sentence);
}
Ok(())
}

View File

@ -0,0 +1,64 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::mbart::{
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
MBartVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;
fn main() -> anyhow::Result<()> {
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
));
let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY;
let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY;
let translation_config = TranslationConfig::new(
ModelType::MBart,
model_resource,
config_resource,
vocab_resource,
merges_resource,
source_languages,
target_languages,
Device::cuda_if_available(),
);
let model = TranslationModel::new(translation_config)?;
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
for sentence in outputs {
println!("{}", sentence);
}
Ok(())
}

View File

@ -0,0 +1,67 @@
// Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
use tch::Device;
fn main() -> anyhow::Result<()> {
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_BASE));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_BASE));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_BASE));
let source_languages = [
Language::English,
Language::French,
Language::German,
Language::Romanian,
];
let target_languages = [
Language::English,
Language::French,
Language::German,
Language::Romanian,
];
let translation_config = TranslationConfig::new(
ModelType::T5,
model_resource,
config_resource,
vocab_resource,
merges_resource,
source_languages,
target_languages,
Device::cuda_if_available(),
);
let model = TranslationModel::new(translation_config)?;
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::German)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?);
for sentence in outputs {
println!("{}", sentence);
}
Ok(())
}

View File

@ -1,2 +1,2 @@
torch == 1.5.0
transformers == 2.10.0
torch == 1.8.1
requests == 2.25.1

View File

@ -11,10 +11,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::albert::embeddings::AlbertEmbeddings;
use crate::albert::encoder::AlbertTransformer;
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu, _tanh};
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
use crate::{albert::embeddings::AlbertEmbeddings, common::activations::TensorFunction};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::{borrow::Borrow, collections::HashMap};
@ -31,44 +32,30 @@ pub struct AlbertConfigResources;
pub struct AlbertVocabResources;
impl AlbertModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/ALBERT>. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/model",
"https://cdn.huggingface.co/albert-base-v2/rust_model.ot",
"https://huggingface.co/albert-base-v2/resolve/main/rust_model.ot",
);
}
impl AlbertConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/ALBERT>. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/config",
"https://cdn.huggingface.co/albert-base-v2-config.json",
"https://huggingface.co/albert-base-v2/resolve/main/config.json",
);
}
impl AlbertVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/ALBERT>. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/spiece",
"https://cdn.huggingface.co/albert-base-v2-spiece.model",
"https://huggingface.co/albert-base-v2/resolve/main/spiece.model",
);
}
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize)]
/// # Activation function used in the attention layer and masked language model head
pub enum Activation {
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
gelu_new,
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
gelu,
/// Rectified Linear Unit
relu,
/// Mish ([Misra, 2019](https://arxiv.org/abs/1908.08681))
mish,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # ALBERT model configuration
/// Defines the ALBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct AlbertConfig {
@ -102,20 +89,20 @@ pub struct AlbertConfig {
pub label2id: Option<HashMap<String, i64>>,
}
impl Config<AlbertConfig> for AlbertConfig {}
impl Config for AlbertConfig {}
/// # ALBERT Base model
/// Base architecture for ALBERT models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `embeddings`: `token`, `position` and `segment_id` embeddings
/// - `encoder`: Encoder (transformer) made of a vector of layers. Each layer is made of a self-attention layer, an intermediate (linear) and output (linear + layer norm) layers. Note that the weights are shared across layers, allowing for a reduction in the model memory footprint.
/// - `pooler`: linear layer applied to the first element of the sequence (*[MASK]* token)
/// - `pooler`: linear layer applied to the first element of the sequence (*MASK* token)
/// - `pooler_activation`: Tanh activation function for the pooling layer
pub struct AlbertModel {
embeddings: AlbertEmbeddings,
encoder: AlbertTransformer,
pooler: nn::Linear,
pooler_activation: Box<dyn Fn(&Tensor) -> Tensor>,
pooler_activation: TensorFunction,
}
impl AlbertModel {
@ -154,7 +141,7 @@ impl AlbertModel {
config.hidden_size,
Default::default(),
);
let pooler_activation = Box::new(_tanh);
let pooler_activation = Activation::tanh.get_function();
AlbertModel {
embeddings,
@ -170,7 +157,7 @@ impl AlbertModel {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -205,10 +192,10 @@ impl AlbertModel {
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
@ -217,53 +204,35 @@ impl AlbertModel {
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<AlbertOutput, RustBertError> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => {
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds {
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => {
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
let (input_shape, device) =
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
let calc_mask = if mask.is_none() {
Some(Tensor::ones(&input_shape, (Kind::Int64, device)))
} else {
None
};
let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
let mask = match mask {
Some(value) => value,
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
};
let extended_attention_mask = mask.unsqueeze(1).unsqueeze(2);
let extended_attention_mask: Tensor =
(extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
let embedding_output = match self.embeddings.forward_t(
let embedding_output = self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
) {
Ok(value) => value,
Err(e) => {
return Err(e);
}
};
)?;
let extended_attention_mask = mask.unsqueeze(1).unsqueeze(2);
let extended_attention_mask: Tensor =
((extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0)
.to_kind(embedding_output.kind());
let transformer_output =
self.encoder
@ -272,7 +241,7 @@ impl AlbertModel {
let pooled_output = self
.pooler
.forward(&transformer_output.hidden_state.select(1, 0));
let pooled_output = (self.pooler_activation)(&pooled_output);
let pooled_output = (self.pooler_activation.get_fn())(&pooled_output);
Ok(AlbertOutput {
hidden_state: transformer_output.hidden_state,
@ -287,7 +256,7 @@ pub struct AlbertMLMHead {
layer_norm: nn::LayerNorm,
dense: nn::Linear,
decoder: nn::Linear,
activation: Box<dyn Fn(&Tensor) -> Tensor>,
activation: TensorFunction,
}
impl AlbertMLMHead {
@ -297,10 +266,7 @@ impl AlbertMLMHead {
{
let p = p.borrow();
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12,
};
let layer_norm_eps = config.layer_norm_eps.unwrap_or(1e-12);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
@ -323,12 +289,7 @@ impl AlbertMLMHead {
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu_new => _gelu_new,
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish,
});
let activation = config.hidden_act.get_function();
AlbertMLMHead {
layer_norm,
@ -339,7 +300,7 @@ impl AlbertMLMHead {
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
let output: Tensor = (self.activation)(&hidden_states.apply(&self.dense));
let output: Tensor = (self.activation.get_fn())(&hidden_states.apply(&self.dense));
output.apply(&self.layer_norm).apply(&self.decoder)
}
}
@ -397,7 +358,7 @@ impl AlbertForMaskedLM {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -431,10 +392,10 @@ impl AlbertForMaskedLM {
///
/// let masked_lm_output = no_grad(|| {
/// albert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
@ -442,11 +403,11 @@ impl AlbertForMaskedLM {
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> AlbertMaskedLMOutput {
let base_model_output = self
@ -511,10 +472,7 @@ impl AlbertForSequenceClassification {
let p = p.borrow();
let albert = AlbertModel::new(p / "albert", config);
let classifier_dropout_prob = match config.classifier_dropout_prob {
Some(value) => value,
None => 0.1,
};
let classifier_dropout_prob = config.classifier_dropout_prob.unwrap_or(0.1);
let dropout = Dropout::new(classifier_dropout_prob);
let num_labels = config
.id2label
@ -541,7 +499,7 @@ impl AlbertForSequenceClassification {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -574,21 +532,21 @@ impl AlbertForSequenceClassification {
///
/// let classification_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false)
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> AlbertSequenceClassificationOutput {
let base_model_output = self
@ -683,7 +641,7 @@ impl AlbertForTokenClassification {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -716,21 +674,21 @@ impl AlbertForTokenClassification {
///
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false)
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> AlbertTokenClassificationOutput {
let base_model_output = self
@ -814,7 +772,7 @@ impl AlbertForQuestionAnswering {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -848,21 +806,21 @@ impl AlbertForQuestionAnswering {
///
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false)
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> AlbertQuestionAnsweringOutput {
let base_model_output = self
@ -881,8 +839,8 @@ impl AlbertForQuestionAnswering {
.apply(&self.qa_outputs)
.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
let start_logits = start_logits.squeeze_dim(-1);
let end_logits = end_logits.squeeze_dim(-1);
AlbertQuestionAnsweringOutput {
start_logits,
@ -958,7 +916,7 @@ impl AlbertForMultipleChoice {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -991,21 +949,21 @@ impl AlbertForMultipleChoice {
///
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false).unwrap()
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<AlbertSequenceClassificationOutput, RustBertError> {
let (input_ids, input_embeds, num_choices) = match &input_ids {
@ -1035,30 +993,20 @@ impl AlbertForMultipleChoice {
},
};
let mask = match mask {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None,
};
let token_type_ids = match token_type_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None,
};
let position_ids = match position_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None,
};
let mask = mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
let token_type_ids =
token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
let position_ids =
position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
let base_model_output = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let base_model_output = self.albert.forward_t(
input_ids.as_ref(),
mask.as_ref(),
token_type_ids.as_ref(),
position_ids.as_ref(),
input_embeds.as_ref(),
train,
)?;
let logits = base_model_output
.pooled_output
.apply_t(&self.dropout, train)

View File

@ -14,7 +14,6 @@
use crate::albert::AlbertConfig;
use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)]
@ -69,14 +68,8 @@ impl AlbertSelfAttention {
);
let dropout = Dropout::new(config.attention_probs_dropout_prob);
let attention_head_size = config.hidden_size / config.num_attention_heads;
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false,
};
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12,
};
let output_attentions = config.output_attentions.unwrap_or(false);
let layer_norm_eps = config.layer_norm_eps.unwrap_or(1e-12);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
@ -106,7 +99,7 @@ impl AlbertSelfAttention {
pub fn forward_t(
&self,
input_ids: &Tensor,
mask: &Option<Tensor>,
mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let bs = *input_ids.size().first().unwrap();
@ -125,7 +118,10 @@ impl AlbertSelfAttention {
query_layer.matmul(&key_layer.transpose(-1, -2))
};
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
let weights = scores
.softmax(-1, scores.kind())
.apply_t(&self.dropout, train);
let context = weights.matmul(&value_layer).transpose(1, 2).contiguous();
let w = self.dense.ws.transpose(0, 1).view((
@ -134,7 +130,8 @@ impl AlbertSelfAttention {
self.hidden_size,
));
let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs;
let context: Tensor =
Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + self.dense.bs.as_ref().unwrap();
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
if !self.output_attentions {

View File

@ -13,6 +13,7 @@
use crate::albert::AlbertConfig;
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig};
@ -61,10 +62,7 @@ impl AlbertEmbeddings {
Default::default(),
);
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12,
};
let layer_norm_eps = config.layer_norm_eps.unwrap_or(1e-12);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
@ -86,50 +84,39 @@ impl AlbertEmbeddings {
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<Tensor, RustBertError> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => {
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds {
Some(embeds) => {
let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size)
}
None => {
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
let (calc_input_embeddings, input_shape, _) =
process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
let input_embeddings =
input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
let position_ids = match position_ids {
Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0)
.expand(&input_shape, true),
let calc_position_ids = if position_ids.is_none() {
Some(
Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0)
.expand(&input_shape, true),
)
} else {
None
};
let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
let calc_token_type_ids = if token_type_ids.is_none() {
Some(Tensor::zeros(
&input_shape,
(Kind::Int64, input_embeddings.device()),
))
} else {
None
};
let token_type_ids =
token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);

View File

@ -11,10 +11,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::albert::albert_model::Activation;
use crate::albert::attention::AlbertSelfAttention;
use crate::albert::AlbertConfig;
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu};
use crate::{albert::attention::AlbertSelfAttention, common::activations::TensorFunction};
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Tensor};
@ -23,7 +21,7 @@ pub struct AlbertLayer {
full_layer_layer_norm: nn::LayerNorm,
ffn: nn::Linear,
ffn_output: nn::Linear,
activation: Box<dyn Fn(&Tensor) -> Tensor>,
activation: TensorFunction,
}
impl AlbertLayer {
@ -33,12 +31,9 @@ impl AlbertLayer {
{
let p = p.borrow();
let attention = AlbertSelfAttention::new(p / "attention", &config);
let attention = AlbertSelfAttention::new(p / "attention", config);
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12,
};
let layer_norm_eps = config.layer_norm_eps.unwrap_or(1e-12);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
@ -62,12 +57,7 @@ impl AlbertLayer {
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu_new => _gelu_new,
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish,
});
let activation = config.hidden_act.get_function();
AlbertLayer {
attention,
@ -81,13 +71,13 @@ impl AlbertLayer {
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (attention_output, attention_weights) =
self.attention.forward_t(hidden_states, mask, train);
let ffn_output = attention_output.apply(&self.ffn);
let ffn_output: Tensor = (self.activation)(&ffn_output);
let ffn_output: Tensor = (self.activation.get_fn())(&ffn_output);
let ffn_output = ffn_output.apply(&self.ffn_output);
let ffn_output = (ffn_output + attention_output).apply(&self.full_layer_layer_norm);
@ -108,15 +98,8 @@ impl AlbertLayerGroup {
{
let p = p.borrow() / "albert_layers";
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false,
};
let output_attentions = config.output_attentions.unwrap_or(false);
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
let mut layers: Vec<AlbertLayer> = vec![];
for layer_index in 0..config.inner_group_num {
@ -133,7 +116,7 @@ impl AlbertLayerGroup {
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
@ -151,16 +134,15 @@ impl AlbertLayerGroup {
let mut attention_weights: Option<Tensor>;
for layer in &self.layers {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &mask, train);
let temp = layer.forward_t(&hidden_state, mask, train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
}
(hidden_state, all_hidden_states, all_attentions)
@ -184,15 +166,8 @@ impl AlbertTransformer {
let p = p.borrow();
let p_layers = p / "albert_layer_groups";
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false,
};
let output_attentions = config.output_attentions.unwrap_or(false);
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
let embedding_hidden_mapping_in = nn::linear(
p / "embedding_hidden_mapping_in",
@ -243,7 +218,7 @@ impl AlbertTransformer {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &mask, train);
let temp = layer.forward_t(&hidden_state, mask.as_ref(), train);
hidden_state = temp.0;
let attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {

View File

@ -2,16 +2,15 @@
//!
//! Implementation of the ALBERT language model ([https://arxiv.org/abs/1909.11942](https://arxiv.org/abs/1909.11942) Lan, Chen, Goodman, Gimpel, Sharma, Soricut, 2019).
//! This model offers a greatly reduced memory footprint for similar effective size (number and size of layers). The computational cost remains however similar to the original BERT model.
//! The base model is implemented in the `albert::AlbertModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `albert::AlbertForMaskedLM`
//! - Multiple choices: `albert:AlbertForMultipleChoice`
//! - Question answering: `albert::AlbertForQuestionAnswering`
//! - Sequence classification: `albert::AlbertForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `albert::AlbertForTokenClassification`
//! The base model is implemented in the `albert_model::AlbertModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `albert_model::AlbertForMaskedLM`
//! - Multiple choices: `albert_model:AlbertForMultipleChoice`
//! - Question answering: `albert_model::AlbertForQuestionAnswering`
//! - Sequence classification: `albert_model::AlbertForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `albert_model::AlbertForTokenClassification`
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/albert`, run with `cargo run --example albert`.
//! The example below illustrate a Masked language model example, the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@ -22,12 +21,12 @@
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use rust_tokenizers::AlbertTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::AlbertTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),

View File

@ -13,7 +13,6 @@
use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)]
@ -24,20 +23,13 @@ pub struct LayerState {
pub prev_key: Tensor,
/// Cached values
pub prev_value: Tensor,
/// Cached keys padding mask
pub prev_key_padding_mask: Option<Tensor>,
}
impl Clone for LayerState {
fn clone(&self) -> Self {
let prev_key_padding_mask = match &self.prev_key_padding_mask {
Some(key_padding_mask) => Some(key_padding_mask.copy()),
None => None,
};
LayerState {
prev_key: self.prev_key.copy(),
prev_value: self.prev_value.copy(),
prev_key_padding_mask,
}
}
}
@ -46,19 +38,11 @@ impl LayerState {
pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
self.prev_key = self.prev_key.index_select(0, new_indices);
self.prev_value = self.prev_value.index_select(0, new_indices);
if self.prev_key_padding_mask.is_some() {
self.prev_key_padding_mask = Some(
self.prev_key_padding_mask
.as_ref()
.unwrap()
.index_select(0, new_indices),
);
}
}
}
#[derive(Debug)]
pub struct SelfAttention {
pub struct BartAttention {
num_heads: i64,
head_dim: i64,
dropout: Dropout,
@ -72,7 +56,7 @@ pub struct SelfAttention {
store_cache: bool,
}
impl SelfAttention {
impl BartAttention {
pub fn new<'p, P>(
p: P,
embed_dim: i64,
@ -81,7 +65,7 @@ impl SelfAttention {
encoder_decoder_attention: bool,
store_cache: bool,
output_attentions: bool,
) -> SelfAttention
) -> BartAttention
where
P: Borrow<nn::Path<'p>>,
{
@ -96,7 +80,7 @@ impl SelfAttention {
let scaling = (head_dim as f64).powf(-0.5);
let dropout = Dropout::new(dropout);
SelfAttention {
BartAttention {
num_heads,
head_dim,
dropout,
@ -111,213 +95,90 @@ impl SelfAttention {
}
}
fn flatten(&self, x: Tensor, dim_0: i64, bs: i64) -> Tensor {
x.contiguous()
.view((dim_0, bs * self.num_heads, self.head_dim))
.transpose(0, 1)
fn _shape(&self, x: Tensor, sequence_length: i64, batch_size: i64) -> Tensor {
x.view((batch_size, sequence_length, self.num_heads, self.head_dim))
.transpose(1, 2)
.contiguous()
}
pub fn forward_t(
&self,
query: &Tensor,
key: Option<&Tensor>,
key_padding_mask: Option<&Tensor>,
hidden_states: &Tensor,
key_value_states: Option<&Tensor>,
attention_mask: Option<&Tensor>,
mut layer_state: Option<LayerState>,
layer_state: Option<LayerState>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<LayerState>) {
let query_size = query.size();
let (target_sequence_length, bs) = (query_size[0], query_size[1]);
let q: Tensor = self.flatten(
query.as_ref().apply(&self.q_proj) * self.scaling,
target_sequence_length,
bs,
);
let key = match &layer_state {
Some(_) => {
if self.encoder_decoder_attention {
None
} else {
key
}
}
None => key,
};
let (bs, target_length, embed_dim) = hidden_states.size3().unwrap();
let (k, v) = if self.encoder_decoder_attention {
match key {
Some(key) => (
Some(self.flatten(key.apply(&self.k_proj), -1, bs)),
Some(self.flatten(key.apply(&self.v_proj), -1, bs)),
),
None => (None, None),
let query_states = hidden_states.apply(&self.q_proj) * self.scaling;
let (key_states, value_states) = if self.encoder_decoder_attention {
if let Some(layer_state_value) = layer_state {
(layer_state_value.prev_key, layer_state_value.prev_value)
} else {
(
self._shape(key_value_states.unwrap().apply(&self.k_proj), -1, bs),
self._shape(key_value_states.unwrap().apply(&self.v_proj), -1, bs),
)
}
} else if let Some(layer_state_value) = layer_state {
let key_states = self._shape(hidden_states.apply(&self.k_proj), -1, bs);
let value_states = self._shape(hidden_states.apply(&self.v_proj), -1, bs);
(
Tensor::cat(&[layer_state_value.prev_key, key_states], 2),
Tensor::cat(&[layer_state_value.prev_value, value_states], 2),
)
} else {
(
Some(self.flatten(query.apply(&self.k_proj), -1, bs)),
Some(self.flatten(query.apply(&self.v_proj), -1, bs)),
self._shape(hidden_states.apply(&self.k_proj), -1, bs),
self._shape(hidden_states.apply(&self.v_proj), -1, bs),
)
};
let (k, v, key_padding_mask) =
self.use_saved_state(&layer_state, k, v, key_padding_mask, bs);
let source_sequence_length = k.size()[1];
let attention_weights = q.bmm(&k.transpose(1, 2));
let attention_weights = match attention_mask {
Some(mask) => {
let attention_weights = attention_weights.view((
bs,
self.num_heads,
target_sequence_length,
source_sequence_length,
)) + mask;
attention_weights.view((
bs * self.num_heads,
target_sequence_length,
source_sequence_length,
))
}
None => attention_weights,
};
let attention_weights = match key_padding_mask.as_ref() {
Some(mask) => attention_weights
.view((
bs,
self.num_heads,
target_sequence_length,
source_sequence_length,
))
.masked_fill(&mask.unsqueeze(1).unsqueeze(2), std::f64::NEG_INFINITY)
.view((
bs * self.num_heads,
target_sequence_length,
source_sequence_length,
)),
None => attention_weights,
};
let attention_weights = attention_weights.softmax(-1, Float);
let attention_probabilities = attention_weights.apply_t(&self.dropout, train);
let output = attention_probabilities
.bmm(&v)
.transpose(0, 1)
.contiguous()
.view((target_sequence_length, bs, self.num_heads * self.head_dim))
.apply(&self.out_proj);
let attention_weights = if self.output_attentions {
Some(attention_weights.view((
bs,
self.num_heads,
target_sequence_length,
source_sequence_length,
)))
let new_layer_state = if self.store_cache {
Some(LayerState {
prev_key: key_states.copy(),
prev_value: value_states.copy(),
})
} else {
None
};
if self.store_cache {
if layer_state.is_some() {
layer_state.as_mut().unwrap().prev_key =
k.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_value =
v.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_key_padding_mask = match key_padding_mask {
Some(tensor) => Some(tensor),
None => None,
};
} else {
layer_state = Some(LayerState {
prev_key: k.view((bs, self.num_heads, -1, self.head_dim)),
prev_value: v.view((bs, self.num_heads, -1, self.head_dim)),
prev_key_padding_mask: match key_padding_mask {
Some(tensor) => Some(tensor),
None => None,
},
})
};
let proj_shape = [bs * self.num_heads, -1, self.head_dim];
let query_states = self
._shape(query_states, target_length, bs)
.view(proj_shape);
let key_states = key_states.view(proj_shape);
let value_states = value_states.view(proj_shape);
let source_length = key_states.size()[1];
let mut attention_weights = query_states.bmm(&key_states.transpose(1, 2));
if let Some(attention_mask_value) = attention_mask {
attention_weights =
attention_weights.view([bs, self.num_heads, target_length, source_length])
+ attention_mask_value;
attention_weights =
attention_weights.view([bs * self.num_heads, target_length, source_length]);
};
(output, attention_weights, layer_state)
}
attention_weights = attention_weights.softmax(-1, attention_weights.kind());
fn use_saved_state(
&self,
layer_state: &Option<LayerState>,
k: Option<Tensor>,
v: Option<Tensor>,
key_padding_mask: Option<&Tensor>,
bs: i64,
) -> (Tensor, Tensor, Option<Tensor>) {
match &layer_state {
Some(prev_state) => {
let prev_key = prev_state
.prev_key
.view((bs * self.num_heads, -1, self.head_dim));
let prev_value =
prev_state
.prev_value
.view((bs * self.num_heads, -1, self.head_dim));
let k = if self.encoder_decoder_attention {
prev_key
} else {
Tensor::cat(&[prev_key, k.unwrap()], 1)
};
let v = if self.encoder_decoder_attention {
prev_value
} else {
Tensor::cat(&[prev_value, v.unwrap()], 1)
};
let key_padding_mask = self.use_saved_key_padding_mask(
key_padding_mask,
&prev_state.prev_key_padding_mask,
bs,
k.size()[1],
);
(k, v, key_padding_mask)
}
None => {
let key_padding_mask = match key_padding_mask {
Some(value) => Some(value.copy()),
None => None,
};
(k.unwrap(), v.unwrap(), key_padding_mask)
}
}
}
fn use_saved_key_padding_mask(
&self,
key_padding_mask: Option<&Tensor>,
prev_key_padding_mask: &Option<Tensor>,
bs: i64,
sequence_length: i64,
) -> Option<Tensor> {
if prev_key_padding_mask.is_some() {
if self.encoder_decoder_attention {
Some(prev_key_padding_mask.as_ref().unwrap().copy())
} else {
Some(Tensor::cat(
&[
prev_key_padding_mask.as_ref().unwrap(),
key_padding_mask.as_ref().unwrap(),
],
1,
))
}
let saved_attention_weights = if self.output_attentions {
Some(attention_weights.view((bs, self.num_heads, target_length, source_length)))
} else {
match key_padding_mask {
Some(key_padding_mask) => {
let filler = Tensor::zeros(
&[bs, sequence_length - key_padding_mask.size()[1]],
(key_padding_mask.kind(), key_padding_mask.device()),
);
Some(Tensor::cat(&[filler, key_padding_mask.copy()], 1))
}
None => None,
}
}
None
};
let attention_probas = attention_weights.apply_t(&self.dropout, train);
let attention_output = attention_probas
.bmm(&value_states)
.view([bs, self.num_heads, target_length, self.head_dim])
.transpose(1, 2)
.reshape(&[bs, target_length, embed_dim])
.apply(&self.out_proj);
(attention_output, saved_attention_weights, new_layer_state)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -11,26 +11,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::attention::{LayerState, SelfAttention};
use crate::bart::bart_model::Activation;
use crate::bart::bart_model::{_expand_mask, _prepare_decoder_attention_mask};
use crate::bart::embeddings::{
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
};
use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::{
bart::attention::{BartAttention, LayerState},
common::activations::TensorFunction,
};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Bool;
use tch::{nn, Tensor};
pub struct DecoderLayer {
self_attention: SelfAttention,
encoder_attention: SelfAttention,
self_attention: BartAttention,
encoder_attention: BartAttention,
self_attention_layer_norm: nn::LayerNorm,
encoder_attention_layer_norm: nn::LayerNorm,
dropout: Dropout,
activation_dropout: Dropout,
activation: Box<dyn Fn(&Tensor) -> Tensor>,
activation: TensorFunction,
fc1: nn::Linear,
fc2: nn::Linear,
final_layer_norm: nn::LayerNorm,
@ -47,11 +49,8 @@ impl DecoderLayer {
eps: 1e-5,
..Default::default()
};
let output_attention = match config.output_attentions {
Some(value) => value,
None => false,
};
let self_attention = SelfAttention::new(
let output_attention = config.output_attentions.unwrap_or(false);
let self_attention = BartAttention::new(
p / "self_attn",
config.d_model,
config.decoder_attention_heads,
@ -60,7 +59,7 @@ impl DecoderLayer {
true,
output_attention,
);
let encoder_attention = SelfAttention::new(
let encoder_attention = BartAttention::new(
p / "encoder_attn",
config.d_model,
config.decoder_attention_heads,
@ -82,17 +81,8 @@ impl DecoderLayer {
let dropout = Dropout::new(config.dropout);
let activation_dropout = Dropout::new(config.activation_dropout);
let activation_function = match &config.activation_function {
Some(act_function) => act_function,
None => &Activation::gelu,
};
let activation = Box::new(match activation_function {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::swish => _swish,
Activation::gelu_new => _gelu_new,
Activation::tanh => _tanh,
});
let activation_function = config.activation_function.unwrap_or(Activation::gelu);
let activation = activation_function.get_function();
let fc1 = nn::linear(
p / "fc1",
config.d_model,
@ -130,9 +120,8 @@ impl DecoderLayer {
&self,
x: &Tensor,
encoder_hidden_states: &Tensor,
encoder_attn_mask: Option<&Tensor>,
causal_mask: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
layer_states: (Option<LayerState>, Option<LayerState>),
train: bool,
) -> (
@ -140,27 +129,22 @@ impl DecoderLayer {
Option<Tensor>,
(Option<LayerState>, Option<LayerState>),
) {
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(
x,
Some(x),
decoder_padding_mask,
causal_mask,
layer_states.0,
train,
);
let (output, attention_weights, new_self_layer_states) =
self.self_attention
.forward_t(x, None, decoder_attention_mask, layer_states.0, train);
let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm);
let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(
&output,
Some(encoder_hidden_states),
encoder_attn_mask,
None,
encoder_attention_mask,
layer_states.1,
train,
);
let output1: Tensor = output1.apply_t(&self.dropout, train) + output;
let output1 = output1.apply(&self.encoder_attention_layer_norm);
let output2 = (self.activation)(&output1.apply(&self.fc1));
let output2 = (self.activation.get_fn())(&output1.apply(&self.fc1));
let output2 = output2
.apply_t(&self.activation_dropout, train)
.apply(&self.fc2)
@ -182,36 +166,20 @@ pub struct BartDecoder {
output_attentions: bool,
output_hidden_states: bool,
output_past: bool,
generation_mode: bool,
scale_embedding: f64,
}
impl BartDecoder {
pub fn new<'p, P>(p: P, config: &BartConfig, generation_mode: bool) -> BartDecoder
pub fn new<'p, P>(p: P, config: &BartConfig) -> BartDecoder
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let output_past = match config.output_past {
Some(value) => value,
None => true,
};
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false,
};
let normalize_embedding = match config.normalize_embedding {
Some(value) => value,
None => true,
};
let static_position_embeddings = match config.static_position_embeddings {
Some(value) => value,
None => false,
};
let output_past = config.output_past.unwrap_or(true);
let output_attentions = config.output_attentions.unwrap_or(false);
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
let normalize_embedding = config.normalize_embedding.unwrap_or(true);
let static_position_embeddings = config.static_position_embeddings.unwrap_or(false);
let scale_embedding = match config.scale_embedding {
Some(value) => {
if value {
@ -239,11 +207,6 @@ impl BartDecoder {
None
};
let pad_token_id = match config.pad_token_id {
Some(value) => value,
None => 1,
};
let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
p / "embed_positions",
@ -255,7 +218,6 @@ impl BartDecoder {
p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id,
))
};
@ -273,7 +235,6 @@ impl BartDecoder {
output_attentions,
output_hidden_states,
output_past,
generation_mode,
scale_embedding,
}
}
@ -282,36 +243,44 @@ impl BartDecoder {
&self,
input_ids: &Tensor,
encoder_hidden_states: &Tensor,
encoder_padding_mask: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>,
decoder_causal_mask: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> BartDecoderOutput {
let encoder_padding_mask = match encoder_padding_mask {
Some(mask) => Some(mask.eq(0).to_kind(Bool)),
None => None,
let past_key_values_length = if let Some(old_layer_states_values) = &old_layer_states {
if let Some(old_value_state) = &old_layer_states_values[0].0 {
old_value_state.prev_key.size()[2]
} else {
0
}
} else {
0
};
let positions = self
.embed_positions
.forward(input_ids, self.generation_mode);
let x: Tensor = if self.generation_mode {
let end_inputs = input_ids.size()[1];
let end_positions = positions.size()[1];
input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding
+ positions.narrow(1, end_positions - 1, 1)
} else {
input_ids.apply(embeddings) * self.scale_embedding + positions
};
.forward(input_ids, past_key_values_length);
let x: Tensor = input_ids.apply(embeddings) * self.scale_embedding + positions;
let decoder_attention_mask = _prepare_decoder_attention_mask(
decoder_attention_mask,
input_ids.size().as_slice(),
&x,
past_key_values_length,
);
let encoder_attention_mask = encoder_attention_mask
.map(|mask| _expand_mask(mask, Some(*input_ids.size().last().unwrap()), x.kind()));
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
x.apply(layer_norm_embedding)
} else {
x
};
let mut hidden_state = x.apply_t(&self.dropout, train).transpose(0, 1);
let mut hidden_state = x.apply_t(&self.dropout, train);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(Vec::with_capacity(self.layers.len()))
} else {
@ -332,7 +301,7 @@ impl BartDecoder {
} else {
None
};
let encoder_hidden_states = encoder_hidden_states.transpose(0, 1);
let mut attention_weights: Option<Tensor>;
for (layer_idx, layer) in self.layers.iter().enumerate() {
@ -342,17 +311,16 @@ impl BartDecoder {
};
let temp = layer.forward_t(
&hidden_state,
&encoder_hidden_states,
encoder_padding_mask.as_ref(),
decoder_causal_mask,
decoder_padding_mask,
encoder_hidden_states,
encoder_attention_mask.as_ref(),
decoder_attention_mask.as_ref(),
layer_state,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
hidden_states.push(hidden_state.as_ref().copy());
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
@ -363,8 +331,8 @@ impl BartDecoder {
}
BartDecoderOutput {
hidden_state: hidden_state.transpose(0, 1),
encoder_padding_mask,
hidden_state,
encoder_attention_mask,
next_decoder_cache,
all_hidden_states,
all_attentions,
@ -377,7 +345,7 @@ pub struct BartDecoderOutput {
/// last decoder layer hidden state
pub hidden_state: Tensor,
/// Padding mask for the encoder positions to attend to
pub encoder_padding_mask: Option<Tensor>,
pub encoder_attention_mask: Option<Tensor>,
/// Cached outputs of the model (attention layers keys and values) if the model is used for generation
pub next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
/// Hidden states for all intermediate layers

View File

@ -12,9 +12,8 @@
// limitations under the License.
use std::borrow::Borrow;
use tch::kind::Kind::Int64;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Tensor};
use tch::nn::embedding;
use tch::{nn, Kind, Tensor};
/// # Abstraction that holds a embeddings configuration
pub enum EmbeddingOption {
@ -25,13 +24,13 @@ pub enum EmbeddingOption {
impl EmbeddingOption {
/// Interface method to forward_t() of the particular models.
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
match *self {
Self::LearnedPositionalEmbedding(ref embeddings) => {
embeddings.forward(input, generation_mode)
embeddings.forward(input, past_key_values_length)
}
Self::SinusoidalPositionalEmbedding(ref embeddings) => {
embeddings.forward(input, generation_mode)
embeddings.forward(input, past_key_values_length)
}
}
}
@ -40,48 +39,37 @@ impl EmbeddingOption {
#[derive(Debug)]
pub struct LearnedPositionalEmbedding {
embedding: nn::Embedding,
padding_index: i64,
offset: i64,
}
impl LearnedPositionalEmbedding {
pub fn new<'p, P>(
p: P,
num_embeddings: i64,
embedding_dim: i64,
padding_index: i64,
) -> LearnedPositionalEmbedding
pub fn new<'p, P>(p: P, num_embeddings: i64, embedding_dim: i64) -> LearnedPositionalEmbedding
where
P: Borrow<nn::Path<'p>>,
{
let embedding_config = EmbeddingConfig {
padding_idx: padding_index,
..Default::default()
};
let num_embeddings = num_embeddings + padding_index + 1;
let offset = 2;
let embedding: nn::Embedding =
embedding(p.borrow(), num_embeddings, embedding_dim, embedding_config);
LearnedPositionalEmbedding {
embedding,
padding_index,
}
let num_embeddings = num_embeddings + offset;
let embedding: nn::Embedding = embedding(
p.borrow(),
num_embeddings,
embedding_dim,
Default::default(),
);
LearnedPositionalEmbedding { embedding, offset }
}
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
let positions = if generation_mode {
let positions = self.padding_index + input.size()[1];
input.new_full(&[1, 1], positions, (Int64, input.device()))
} else {
self.create_position_ids_from_input_ids(input, self.padding_index)
};
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
let input_shape = input.size();
let (_, sequence_length) = (input_shape[0], input_shape[1]);
let positions = Tensor::arange_start(
past_key_values_length,
past_key_values_length + sequence_length,
(Kind::Int64, input.device()),
) + self.offset;
positions.apply(&self.embedding)
}
fn create_position_ids_from_input_ids(&self, input_ids: &Tensor, padding_index: i64) -> Tensor {
let mask = input_ids.ne(padding_index).to_kind(Int64);
let position_ids: Tensor = mask.cumsum(1, Int64) * mask + padding_index;
position_ids
}
}
#[derive(Debug)]
@ -107,12 +95,14 @@ impl SinusoidalPositionalEmbedding {
SinusoidalPositionalEmbedding { embedding }
}
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
let positions = if generation_mode {
Tensor::full(&[1, 1], input.size()[1] - 1, (Int64, input.device()))
} else {
Tensor::arange(input.size()[1], (Int64, input.device()))
};
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
let input_shape = input.size();
let (_, sequence_length) = (input_shape[0], input_shape[1]);
let positions = Tensor::arange_start(
past_key_values_length,
past_key_values_length + sequence_length,
(Kind::Int64, input.device()),
);
positions.apply(&self.embedding)
}
}

View File

@ -11,24 +11,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::attention::SelfAttention;
use crate::bart::bart_model::Activation;
use crate::bart::attention::BartAttention;
use crate::bart::bart_model::_expand_mask;
use crate::bart::embeddings::{
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
};
use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::activations::{Activation, TensorFunction};
use crate::common::dropout::Dropout;
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Bool;
use tch::{nn, Tensor};
pub struct EncoderLayer {
self_attention: SelfAttention,
self_attention: BartAttention,
self_attention_layer_norm: nn::LayerNorm,
dropout: Dropout,
activation_dropout: Dropout,
activation: Box<dyn Fn(&Tensor) -> Tensor>,
activation: TensorFunction,
fc1: nn::Linear,
fc2: nn::Linear,
final_layer_norm: nn::LayerNorm,
@ -45,11 +44,8 @@ impl EncoderLayer {
eps: 1e-5,
..Default::default()
};
let output_attention = match config.output_attentions {
Some(value) => value,
None => false,
};
let self_attention = SelfAttention::new(
let output_attention = config.output_attentions.unwrap_or(false);
let self_attention = BartAttention::new(
p / "self_attn",
config.d_model,
config.encoder_attention_heads,
@ -69,13 +65,7 @@ impl EncoderLayer {
Some(act_function) => act_function,
None => &Activation::gelu,
};
let activation = Box::new(match activation_function {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::swish => _swish,
Activation::gelu_new => _gelu_new,
Activation::tanh => _tanh,
});
let activation = activation_function.get_function();
let fc1 = nn::linear(
p / "fc1",
config.d_model,
@ -110,17 +100,17 @@ impl EncoderLayer {
pub fn forward_t(
&self,
x: &Tensor,
encoder_padding_mask: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (output, attention_weights, _) =
self.self_attention
.forward_t(x, None, encoder_padding_mask, None, None, train);
.forward_t(x, None, encoder_attention_mask, None, train);
let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm);
let residual = output.copy();
let output = (self.activation)(&output.apply(&self.fc1));
let output = (self.activation.get_fn())(&output.apply(&self.fc1));
let output = output
.apply_t(&self.activation_dropout, train)
.apply(&self.fc2)
@ -146,22 +136,10 @@ impl BartEncoder {
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false,
};
let normalize_embedding = match config.normalize_embedding {
Some(value) => value,
None => true,
};
let static_position_embeddings = match config.static_position_embeddings {
Some(value) => value,
None => false,
};
let output_attentions = config.output_attentions.unwrap_or(false);
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
let normalize_embedding = config.normalize_embedding.unwrap_or(true);
let static_position_embeddings = config.static_position_embeddings.unwrap_or(false);
let scale_embedding = match config.scale_embedding {
Some(value) => {
if value {
@ -189,11 +167,6 @@ impl BartEncoder {
None
};
let pad_token_id = match config.pad_token_id {
Some(value) => value,
None => 1,
};
let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
p / "embed_positions",
@ -205,7 +178,6 @@ impl BartEncoder {
p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id,
))
};
@ -233,19 +205,15 @@ impl BartEncoder {
embeddings: &nn::Embedding,
train: bool,
) -> BartEncoderOutput {
let attention_mask = match attention_mask {
Some(mask) => Some(mask.eq(0).to_kind(Bool)),
None => None,
};
let x = input_ids.apply(embeddings) * self.scale_embedding;
let x: Tensor = x + &self.embed_positions.forward(input_ids, false);
let x: Tensor = x + &self.embed_positions.forward(input_ids, 0);
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
x.apply(layer_norm_embedding)
} else {
x
};
let x = x.apply_t(&self.dropout, train).transpose(0, 1);
let attention_mask = attention_mask.map(|mask| _expand_mask(mask, None, x.kind()));
let mut hidden_state = x.apply_t(&self.dropout, train);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
@ -258,28 +226,22 @@ impl BartEncoder {
None
};
let mut hidden_state = x.copy();
let mut attention_weights: Option<Tensor>;
for layer in &self.layers {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
}
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
BartEncoderOutput {
hidden_state: hidden_state.transpose(0, 1),
hidden_state,
all_hidden_states,
all_attentions,
}

View File

@ -1,13 +1,12 @@
//! # BART (Lewis et al.)
//!
//! Implementation of the BART language model ([BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) Lewis, Liu, Goyal, Ghazvininejad, Mohamed, Levy, Stoyanov, Zettlemoyer, 2019).
//! The base model is implemented in the `bart::BartModel` struct. The model also includes a language model head: `bart::BartForConditionalGeneration`
//! implementing the common `generation::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
//! The base model is implemented in the `bart_model::BartModel` struct. The model also includes a language model head: `bart_model::BartForConditionalGeneration`
//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/bart`, run with `cargo run --example bart`.
//! Alternatively, the summarization capabilities are illustrated in `examples/summarization.rs`, run with `cargo run --example summarization`.
//! The summarization capabilities are illustrated in `examples/summarization_bart`, run with `cargo run --example summarization_bart`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
@ -17,12 +16,12 @@
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use rust_tokenizers::RobertaTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::RobertaTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
@ -50,7 +49,7 @@
//! false,
//! )?;
//! let config = BartConfig::from_file(config_path);
//! let bart_model = BartModel::new(&vs.root(), &config, false);
//! let bart_model = BartModel::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())
@ -65,9 +64,13 @@ mod encoder;
pub use attention::LayerState;
pub use bart_model::{
Activation, BartConfig, BartConfigResources, BartForConditionalGeneration,
BartForSequenceClassification, BartMergesResources, BartModel, BartModelOutput,
BartModelResources, BartVocabResources,
BartConfig, BartConfigResources, BartForConditionalGeneration, BartForSequenceClassification,
BartGenerator, BartMergesResources, BartModel, BartModelOutput, BartModelResources,
BartVocabResources,
};
pub(crate) use attention::BartAttention;
pub(crate) use bart_model::{_expand_mask, _make_causal_mask, _prepare_decoder_attention_mask};
pub(crate) use decoder::BartDecoderOutput;
pub(crate) use embeddings::LearnedPositionalEmbedding;
pub(crate) use encoder::BartEncoderOutput;

View File

@ -11,11 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bert::bert_model::{Activation, BertConfig};
use crate::common::activations::{_gelu, _mish, _relu};
use crate::bert::bert_model::BertConfig;
use crate::common::activations::TensorFunction;
use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)]
@ -62,10 +61,7 @@ impl BertSelfAttention {
let dropout = Dropout::new(config.attention_probs_dropout_prob);
let attention_head_size = config.hidden_size / config.num_attention_heads;
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false,
};
let output_attentions = config.output_attentions.unwrap_or(false);
BertSelfAttention {
num_attention_heads: config.num_attention_heads,
@ -92,9 +88,9 @@ impl BertSelfAttention {
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
mask: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (key_layer, value_layer, mask) = match encoder_hidden_states {
@ -127,7 +123,9 @@ impl BertSelfAttention {
query_layer.matmul(&key_layer.transpose(-1, -2))
};
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
let weights = scores
.softmax(-1, scores.kind())
.apply_t(&self.dropout, train);
let context = self.flatten(weights.matmul(&value_layer), bs, self.attention_head_size);
if !self.output_attentions {
@ -203,9 +201,9 @@ impl BertAttention {
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
mask: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (self_output, attention_weights) = self._self.forward_t(
@ -223,7 +221,7 @@ impl BertAttention {
pub struct BertIntermediate {
lin: nn::Linear,
activation: Box<dyn Fn(&Tensor) -> Tensor>,
activation: TensorFunction,
}
impl BertIntermediate {
@ -239,16 +237,12 @@ impl BertIntermediate {
config.intermediate_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish,
});
let activation = config.hidden_act.get_function();
BertIntermediate { lin, activation }
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
(self.activation)(&hidden_states.apply(&self.lin))
(self.activation.get_fn())(&hidden_states.apply(&self.lin))
}
}

View File

@ -11,16 +11,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bert::embeddings::{BertEmbedding, BertEmbeddings};
use crate::bert::encoder::{BertEncoder, BertPooler};
use crate::common::activations::{_gelu, _mish, _relu};
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::{
bert::embeddings::{BertEmbedding, BertEmbeddings},
common::activations::TensorFunction,
};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::kind::Kind::Float;
use tch::nn::Init;
use tch::{nn, Kind, Tensor};
@ -34,72 +37,60 @@ pub struct BertConfigResources;
pub struct BertVocabResources;
impl BertModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = (
"bert/model",
"https://cdn.huggingface.co/bert-base-uncased-rust_model.ot",
"https://huggingface.co/bert-base-uncased/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/model",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/rust_model.ot",
"https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/model",
"https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-rust_model.ot",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/rust_model.ot",
);
}
impl BertConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = (
"bert/config",
"https://cdn.huggingface.co/bert-base-uncased-config.json",
"https://huggingface.co/bert-base-uncased/resolve/main/config.json",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/config",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/config.json",
"https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/config",
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
);
}
impl BertVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/bert>. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = (
"bert/vocab",
"https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
"https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at <https://github.com/dbmdz/berts>. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/vocab",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/vocab.txt",
"https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/vocab",
"https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
);
}
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize)]
/// # Activation function used in the attention layer and masked language model head
pub enum Activation {
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
gelu,
/// Rectified Linear Unit
relu,
/// Mish ([Misra, 2019](https://arxiv.org/abs/1908.08681))
mish,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # BERT model configuration
/// Defines the BERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct BertConfig {
@ -121,19 +112,19 @@ pub struct BertConfig {
pub label2id: Option<HashMap<String, i64>>,
}
impl Config<BertConfig> for BertConfig {}
impl Config for BertConfig {}
/// # BERT Base model
/// Base architecture for BERT models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `embeddings`: `token`, `position` and `segment_id` embeddings
/// - `encoder`: Encoder (transformer) made of a vector of layers. Each layer is made of a self-attention layer, an intermediate (linear) and output (linear + layer norm) layers
/// - `pooler`: linear layer applied to the first element of the sequence (*[MASK]* token)
/// - `pooler`: linear layer applied to the first element of the sequence (*MASK* token)
/// - `is_decoder`: Flag indicating if the model is used as a decoder. If set to true, a causal mask will be applied to hide future positions that should not be attended to.
pub struct BertModel<T: BertEmbedding> {
embeddings: T,
encoder: BertEncoder,
pooler: BertPooler,
pooler: Option<BertPooler>,
is_decoder: bool,
}
@ -168,13 +159,63 @@ impl<T: BertEmbedding> BertModel<T> {
{
let p = p.borrow();
let is_decoder = match config.is_decoder {
Some(value) => value,
None => false,
};
let is_decoder = config.is_decoder.unwrap_or(false);
let embeddings = T::new(p / "embeddings", config);
let encoder = BertEncoder::new(p / "encoder", config);
let pooler = BertPooler::new(p / "pooler", config);
let pooler = Some(BertPooler::new(p / "pooler", config));
BertModel {
embeddings,
encoder,
pooler,
is_decoder,
}
}
/// Build a new `BertModel` with an optional Pooling layer
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the BERT model
/// * `config` - `BertConfig` object defining the model architecture and decoder status
/// * `add_pooling_layer` - Enable/Disable an optional pooling layer at the end of the model
///
/// # Example
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertEmbeddings, BertModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert: BertModel<BertEmbeddings> =
/// BertModel::new_with_optional_pooler(&p.root() / "bert", &config, false);
/// ```
pub fn new_with_optional_pooler<'p, P>(
p: P,
config: &BertConfig,
add_pooling_layer: bool,
) -> BertModel<T>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let is_decoder = config.is_decoder.unwrap_or(false);
let embeddings = T::new(p / "embeddings", config);
let encoder = BertEncoder::new(p / "encoder", config);
let pooler = {
if add_pooling_layer {
Some(BertPooler::new(p / "pooler", config))
} else {
None
}
};
BertModel {
embeddings,
@ -190,7 +231,7 @@ impl<T: BertEmbedding> BertModel<T> {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used in the cross-attention layer as keys and values (query from the decoder).
@ -209,32 +250,31 @@ impl<T: BertEmbedding> BertModel<T> {
///
/// ```no_run
/// # use rust_bert::bert::{BertModel, BertConfig, BertEmbeddings};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let bert_model: BertModel<BertEmbeddings> = BertModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let model_output = no_grad(|| {
/// bert_model
/// .forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// None,
/// None,
/// &None,
/// &None,
/// false,
/// )
/// .unwrap()
@ -242,50 +282,32 @@ impl<T: BertEmbedding> BertModel<T> {
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_mask: Option<&Tensor>,
train: bool,
) -> Result<BertModelOutput, RustBertError> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => {
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds {
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => {
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
let (input_shape, device) =
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
let mask = match mask {
Some(value) => value,
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
};
let calc_mask = Tensor::ones(&input_shape, (Kind::Int8, device));
let mask = mask.unwrap_or(&calc_mask);
let extended_attention_mask = match mask.dim() {
3 => mask.unsqueeze(1),
2 => {
if self.is_decoder {
let seq_ids = Tensor::arange(input_shape[1], (Float, device));
let seq_ids = Tensor::arange(input_shape[1], (Kind::Int8, device));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
input_shape[0],
input_shape[1],
1,
]);
let causal_mask = causal_mask.le1(&seq_ids.unsqueeze(0).unsqueeze(-1));
let causal_mask = causal_mask.le_tensor(&seq_ids.unsqueeze(0).unsqueeze(-1));
causal_mask * mask.unsqueeze(1).unsqueeze(1)
} else {
mask.unsqueeze(1).unsqueeze(1)
@ -298,8 +320,17 @@ impl<T: BertEmbedding> BertModel<T> {
}
};
let embedding_output = self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
)?;
let extended_attention_mask: Tensor =
(extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
((extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0)
.to_kind(embedding_output.kind());
let encoder_extended_attention_mask: Option<Tensor> =
if self.is_decoder & encoder_hidden_states.is_some() {
@ -312,7 +343,7 @@ impl<T: BertEmbedding> BertModel<T> {
encoder_hidden_states_shape[0],
encoder_hidden_states_shape[1],
],
(Kind::Int64, device),
(Kind::Int8, device),
),
};
match encoder_mask.dim() {
@ -328,41 +359,31 @@ impl<T: BertEmbedding> BertModel<T> {
None
};
let embedding_output = match self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
) {
Ok(value) => value,
Err(e) => {
return Err(e);
}
};
let (hidden_state, all_hidden_states, all_attentions) = self.encoder.forward_t(
let encoder_output = self.encoder.forward_t(
&embedding_output,
&Some(extended_attention_mask),
Some(&extended_attention_mask),
encoder_hidden_states,
&encoder_extended_attention_mask,
encoder_extended_attention_mask.as_ref(),
train,
);
let pooled_output = self.pooler.forward(&hidden_state);
let pooled_output = self
.pooler
.as_ref()
.map(|pooler| pooler.forward(&encoder_output.hidden_state));
Ok(BertModelOutput {
hidden_state,
hidden_state: encoder_output.hidden_state,
pooled_output,
all_hidden_states,
all_attentions,
all_hidden_states: encoder_output.all_hidden_states,
all_attentions: encoder_output.all_attentions,
})
}
}
pub struct BertPredictionHeadTransform {
dense: nn::Linear,
activation: Box<dyn Fn(&Tensor) -> Tensor>,
activation: TensorFunction,
layer_norm: nn::LayerNorm,
}
@ -379,11 +400,7 @@ impl BertPredictionHeadTransform {
config.hidden_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish,
});
let activation = config.hidden_act.get_function();
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
@ -399,7 +416,7 @@ impl BertPredictionHeadTransform {
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
((&self.activation)(&hidden_states.apply(&self.dense))).apply(&self.layer_norm)
((&self.activation.get_fn())(&hidden_states.apply(&self.dense))).apply(&self.layer_norm)
}
}
@ -432,7 +449,7 @@ impl BertLMPredictionHead {
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
self.transform.forward(&hidden_states).apply(&self.decoder) + &self.bias
self.transform.forward(hidden_states).apply(&self.decoder) + &self.bias
}
}
@ -486,7 +503,7 @@ impl BertForMaskedLM {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
/// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used in the cross-attention layer as keys and values (query from the decoder).
@ -504,44 +521,43 @@ impl BertForMaskedLM {
///
/// ```no_run
/// # use rust_bert::bert::{BertForMaskedLM, BertConfig};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let bert_model = BertForMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// None,
/// None,
/// &None,
/// &None,
/// false,
/// )
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_mask: Option<&Tensor>,
train: bool,
) -> BertMaskedLMOutput {
let base_model_output = self
@ -633,7 +649,7 @@ impl BertForSequenceClassification {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -649,28 +665,27 @@ impl BertForSequenceClassification {
///
/// ```no_run
/// # use rust_bert::bert::{BertForSequenceClassification, BertConfig};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let bert_model = BertForSequenceClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
@ -678,11 +693,11 @@ impl BertForSequenceClassification {
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> BertSequenceClassificationOutput {
let base_model_output = self
@ -693,14 +708,15 @@ impl BertForSequenceClassification {
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
None,
None,
train,
)
.unwrap();
let logits = base_model_output
.pooled_output
.unwrap()
.apply_t(&self.dropout, train)
.apply(&self.classifier);
BertSequenceClassificationOutput {
@ -769,7 +785,7 @@ impl BertForMultipleChoice {
///
/// * `input_ids` - Input tensor of shape (*batch size*, *sequence_length*).
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
@ -802,54 +818,48 @@ impl BertForMultipleChoice {
///
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// input_tensor,
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// &input_tensor,
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// false,
/// )
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Tensor,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_ids: &Tensor,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
train: bool,
) -> BertSequenceClassificationOutput {
let num_choices = input_ids.size()[1];
let input_ids = input_ids.view((-1, *input_ids.size().last().unwrap()));
let mask = match mask {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None,
};
let token_type_ids = match token_type_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None,
};
let position_ids = match position_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None,
};
let mask = mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
let token_type_ids =
token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
let position_ids =
position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
let base_model_output = self
.bert
.forward_t(
Some(input_ids),
mask,
token_type_ids,
position_ids,
Some(&input_ids),
mask.as_ref(),
token_type_ids.as_ref(),
position_ids.as_ref(),
None,
None,
None,
&None,
&None,
train,
)
.unwrap();
let logits = base_model_output
.pooled_output
.unwrap()
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
@ -928,7 +938,7 @@ impl BertForTokenClassification {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -962,10 +972,10 @@ impl BertForTokenClassification {
///
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
@ -973,11 +983,11 @@ impl BertForTokenClassification {
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> BertTokenClassificationOutput {
let base_model_output = self
@ -988,8 +998,8 @@ impl BertForTokenClassification {
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
None,
None,
train,
)
.unwrap();
@ -1064,7 +1074,7 @@ impl BertForQuestionAnswering {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -1099,10 +1109,10 @@ impl BertForQuestionAnswering {
///
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
@ -1110,11 +1120,11 @@ impl BertForQuestionAnswering {
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> BertQuestionAnsweringOutput {
let base_model_output = self
@ -1125,8 +1135,8 @@ impl BertForQuestionAnswering {
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
None,
None,
train,
)
.unwrap();
@ -1134,8 +1144,8 @@ impl BertForQuestionAnswering {
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
let start_logits = start_logits.squeeze_dim(-1);
let end_logits = end_logits.squeeze_dim(-1);
BertQuestionAnsweringOutput {
start_logits,
@ -1151,7 +1161,7 @@ pub struct BertModelOutput {
/// Last hidden states from the model
pub hidden_state: Tensor,
/// Pooled output (hidden state for the first token)
pub pooled_output: Tensor,
pub pooled_output: Option<Tensor>,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
@ -1199,3 +1209,30 @@ pub struct BertQuestionAnsweringOutput {
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
#[cfg(test)]
mod test {
use tch::Device;
use crate::{
resources::{RemoteResource, Resource},
Config,
};
use super::*;
#[test]
#[ignore] // compilation is enough, no need to run
fn bert_model_send() {
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let config_path = config_resource.get_local_path().expect("");
// Set-up masked LM model
let device = Device::cuda_if_available();
let vs = tch::nn::VarStore::new(device);
let config = BertConfig::from_file(config_path);
let _: Box<dyn Send> = Box::new(BertModel::<BertEmbeddings>::new(&vs.root(), &config));
}
}

View File

@ -13,6 +13,7 @@
use crate::bert::bert_model::BertConfig;
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig};
@ -27,10 +28,10 @@ pub trait BertEmbedding {
fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<Tensor, RustBertError>;
}
@ -121,7 +122,7 @@ impl BertEmbedding for BertEmbeddings {
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` -Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -153,9 +154,9 @@ impl BertEmbedding for BertEmbeddings {
/// let embedded_output = no_grad(|| {
/// bert_embeddings
/// .forward_t(
/// Some(input_tensor),
/// Some(token_type_ids),
/// Some(position_ids),
/// Some(&input_tensor),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
@ -164,50 +165,41 @@ impl BertEmbedding for BertEmbeddings {
/// ```
fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<Tensor, RustBertError> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => {
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds {
Some(embeds) => {
let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size)
}
None => {
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
let (calc_input_embeddings, input_shape, _) =
process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
let input_embeddings =
input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let seq_length = input_embeddings.size()[1];
let calc_position_ids = if position_ids.is_none() {
Some(
Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0)
.expand(&input_shape, true),
)
} else {
None
};
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
let position_ids = match position_ids {
Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0)
.expand(&input_shape, true),
let calc_token_type_ids = if token_type_ids.is_none() {
Some(Tensor::zeros(
&input_shape,
(Kind::Int64, input_embeddings.device()),
))
} else {
None
};
let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
};
let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
let token_type_ids =
token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);

View File

@ -16,6 +16,14 @@ use crate::bert::bert_model::BertConfig;
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Tensor};
/// # BERT Layer
/// Layer used in BERT encoders.
/// It is made of the following blocks:
/// - `attention`: self-attention `BertAttention` layer
/// - `cross_attention`: (optional) cross-attention `BertAttention` layer (if the model is used as a decoder)
/// - `is_decoder`: flag indicating if the model is used as a decoder
/// - `intermediate`: `BertIntermediate` intermediate layer
/// - `output`: `BertOutput` output layer
pub struct BertLayer {
attention: BertAttention,
is_decoder: bool,
@ -25,19 +33,40 @@ pub struct BertLayer {
}
impl BertLayer {
/// Build a new `BertLayer`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the BERT model
/// * `config` - `BertConfig` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertLayer};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let layer: BertLayer = BertLayer::new(&p.root(), &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertLayer
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let attention = BertAttention::new(p / "attention", &config);
let attention = BertAttention::new(p / "attention", config);
let (is_decoder, cross_attention) = match config.is_decoder {
Some(value) => {
if value {
(
value,
Some(BertAttention::new(p / "cross_attention", &config)),
Some(BertAttention::new(p / "cross_attention", config)),
)
} else {
(value, None)
@ -46,8 +75,8 @@ impl BertLayer {
None => (false, None),
};
let intermediate = BertIntermediate::new(p / "intermediate", &config);
let output = BertOutput::new(p / "output", &config);
let intermediate = BertIntermediate::new(p / "intermediate", config);
let output = BertOutput::new(p / "output", config);
BertLayer {
attention,
@ -58,19 +87,55 @@ impl BertLayer {
}
}
/// Forward pass through the layer
///
/// # Arguments
///
/// * `hidden_states` - input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used in the cross-attention layer as keys and values (query from the decoder).
/// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used to mask encoder values. Positions with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `BertLayerOutput` containing:
/// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `attention_scores` - `Option<Tensor>` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `cross_attention_scores` - `Option<Tensor>` of shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
/// # use rust_bert::bert::{BertConfig, BertLayer};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// let layer: BertLayer = BertLayer::new(&vs.root(), &config);
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
///
/// let layer_output = no_grad(|| layer.forward_t(&input_tensor, Some(&mask), None, None, false));
/// ```
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
mask: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let (attention_output, attention_weights, cross_attention_weights) =
) -> BertLayerOutput {
let (attention_output, attention_weights) =
self.attention
.forward_t(hidden_states, mask, None, None, train);
let (attention_output, attention_scores, cross_attention_scores) =
if self.is_decoder & encoder_hidden_states.is_some() {
let (attention_output, attention_weights) =
self.attention
.forward_t(hidden_states, mask, &None, &None, train);
let (attention_output, cross_attention_weights) =
self.cross_attention.as_ref().unwrap().forward_t(
&attention_output,
@ -81,19 +146,24 @@ impl BertLayer {
);
(attention_output, attention_weights, cross_attention_weights)
} else {
let (attention_output, attention_weights) =
self.attention
.forward_t(hidden_states, mask, &None, &None, train);
(attention_output, attention_weights, None)
};
let output = self.intermediate.forward(&attention_output);
let output = self.output.forward_t(&output, &attention_output, train);
(output, attention_weights, cross_attention_weights)
BertLayerOutput {
hidden_state: output,
attention_weights: attention_scores,
cross_attention_weights: cross_attention_scores,
}
}
}
/// # BERT Encoder
/// Encoder used in BERT models.
/// It is made of a Vector of `BertLayer` through which hidden states will be passed. The encoder can also be
/// used as a decoder (with cross-attention) if `encoder_hidden_states` are provided.
pub struct BertEncoder {
output_attentions: bool,
output_hidden_states: bool,
@ -101,21 +171,34 @@ pub struct BertEncoder {
}
impl BertEncoder {
/// Build a new `BertEncoder`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the BERT model
/// * `config` - `BertConfig` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertEncoder};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let encoder: BertEncoder = BertEncoder::new(&p.root(), &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertEncoder
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow() / "layer";
let output_attentions = if let Some(value) = config.output_attentions {
value
} else {
false
};
let output_hidden_states = if let Some(value) = config.output_hidden_states {
value
} else {
false
};
let output_attentions = config.output_attentions.unwrap_or(false);
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
let mut layers: Vec<BertLayer> = vec![];
for layer_index in 0..config.num_hidden_layers {
@ -129,14 +212,50 @@ impl BertEncoder {
}
}
/// Forward pass through the encoder
///
/// # Arguments
///
/// * `hidden_states` - input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used in the cross-attention layer as keys and values (query from the decoder).
/// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used to mask encoder values. Positions with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `BertEncoderOutput` containing:
/// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
/// # use rust_bert::bert::{BertConfig, BertEncoder};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// let encoder: BertEncoder = BertEncoder::new(&vs.root(), &config);
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int8, device));
///
/// let encoder_output =
/// no_grad(|| encoder.forward_t(&input_tensor, Some(&mask), None, None, false));
/// ```
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
input: &Tensor,
mask: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
) -> BertEncoderOutput {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
@ -148,37 +267,69 @@ impl BertEncoder {
None
};
let mut hidden_state = hidden_states.copy();
let mut hidden_state = None::<Tensor>;
let mut attention_weights: Option<Tensor>;
for layer in &self.layers {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
let layer_output = if let Some(hidden_state) = &hidden_state {
layer.forward_t(
hidden_state,
mask,
encoder_hidden_states,
encoder_mask,
train,
)
} else {
layer.forward_t(input, mask, encoder_hidden_states, encoder_mask, train)
};
let temp = layer.forward_t(
&hidden_state,
&mask,
encoder_hidden_states,
encoder_mask,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
hidden_state = Some(layer_output.hidden_state);
attention_weights = layer_output.attention_weights;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().unwrap().copy());
};
}
(hidden_state, all_hidden_states, all_attentions)
BertEncoderOutput {
hidden_state: hidden_state.unwrap(),
all_hidden_states,
all_attentions,
}
}
}
/// # BERT Pooler
/// Pooler used in BERT models.
/// It is made of a fully connected layer which is applied to the first sequence element.
pub struct BertPooler {
lin: nn::Linear,
}
impl BertPooler {
/// Build a new `BertPooler`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the BERT model
/// * `config` - `BertConfig` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertPooler};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let pooler: BertPooler = BertPooler::new(&p.root(), &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertPooler
where
P: Borrow<nn::Path<'p>>,
@ -194,7 +345,54 @@ impl BertPooler {
BertPooler { lin }
}
/// Forward pass through the pooler
///
/// # Arguments
///
/// * `hidden_states` - input tensor of shape (*batch size*, *sequence_length*, *hidden_size*).
///
/// # Returns
///
/// * `Tensor` of shape (*batch size*, *hidden_size*)
///
/// # Example
///
/// ```no_run
/// # use rust_bert::bert::{BertConfig, BertPooler};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// let pooler: BertPooler = BertPooler::new(&vs.root(), &config);
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
///
/// let pooler_output = no_grad(|| pooler.forward(&input_tensor));
/// ```
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
hidden_states.select(1, 0).apply(&self.lin).tanh()
}
}
/// Container for the BERT layer output.
pub struct BertLayerOutput {
/// Hidden states
pub hidden_state: Tensor,
/// Self attention scores
pub attention_weights: Option<Tensor>,
/// Cross attention scores
pub cross_attention_weights: Option<Tensor>,
}
/// Container for the BERT encoder output.
pub struct BertEncoderOutput {
/// Last hidden states from the model
pub hidden_state: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}

View File

@ -1,16 +1,16 @@
//! # BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al.)
//!
//! Implementation of the BERT language model ([https://arxiv.org/abs/1810.04805](https://arxiv.org/abs/1810.04805) Devlin, Chang, Lee, Toutanova, 2018).
//! The base model is implemented in the `bert::BertModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `bert::BertForMaskedLM`
//! - Multiple choices: `bert:BertForMultipleChoice`
//! - Question answering: `bert::BertForQuestionAnswering`
//! - Sequence classification: `bert::BertForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `bert::BertForTokenClassification`
//! The base model is implemented in the `bert_model::BertModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `bert_model::BertForMaskedLM`
//! - Multiple choices: `bert_model:BertForMultipleChoice`
//! - Question answering: `bert_model::BertForQuestionAnswering`
//! - Sequence classification: `bert_model::BertForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `bert_model::BertForTokenClassification`
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/bert`, run with `cargo run --example bert`.
//! A full working example is provided in `examples/masked_language_model_bert`, run with `cargo run --example masked_language_model_bert`.
//! The example below illustrate a Masked language model example, the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@ -21,12 +21,12 @@
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::bert::{BertConfig, BertForMaskedLM};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::BertTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
@ -58,10 +58,11 @@ mod embeddings;
pub(crate) mod encoder;
pub use bert_model::{
Activation, BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
BertMaskedLMOutput, BertModel, BertModelOutput, BertModelResources,
BertQuestionAnsweringOutput, BertSequenceClassificationOutput, BertTokenClassificationOutput,
BertVocabResources,
};
pub use embeddings::{BertEmbedding, BertEmbeddings};
pub use encoder::{BertEncoder, BertEncoderOutput, BertLayer, BertLayerOutput, BertPooler};

View File

@ -1,8 +1,9 @@
use serde::{Deserialize, Serialize};
use std::f64::consts::PI;
use tch::Tensor;
pub fn _gelu(x: &Tensor) -> Tensor {
x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())).erf())
x * 0.5 * (1.0 + (x / ((2.0_f64).sqrt())).erf())
}
pub fn _relu(x: &Tensor) -> Tensor {
@ -18,9 +19,66 @@ pub fn _mish(x: &Tensor) -> Tensor {
}
pub fn _gelu_new(x: &Tensor) -> Tensor {
x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
x * 0.5 * (((x.pow_tensor_scalar(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
}
pub fn _tanh(x: &Tensor) -> Tensor {
x.tanh()
}
pub struct TensorFunction(Box<fn(&Tensor) -> Tensor>);
impl TensorFunction {
pub fn new(fun: Box<fn(&Tensor) -> Tensor>) -> Self {
Self(fun)
}
pub fn get_fn(&self) -> &fn(&Tensor) -> Tensor {
&self.0
}
}
impl std::fmt::Debug for TensorFunction {
fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
write!(f, "TensorFunction")
}
}
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
/// # Activation function used in the attention layer and masked language model head
pub enum Activation {
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
gelu,
/// Rectified Linear Unit
relu,
/// Swish ([Ramachandran, 2017](https://arxiv.org/abs/1710.05941))
swish,
/// Mish ([Misra, 2019](https://arxiv.org/abs/1908.08681))
mish,
/// Gaussian Error Linear Unit (New) ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
gelu_new,
/// Tanh
tanh,
}
impl Activation {
pub fn get_function(&self) -> TensorFunction {
TensorFunction::new(Box::new(match self {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::swish => _swish,
Activation::gelu_new => _gelu_new,
Activation::mish => _mish,
Activation::tanh => _tanh,
}))
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
#[ignore]
fn tensorfunction_send() {
let _: Box<dyn Send> = Box::new(Activation::gelu.get_function());
}
}

View File

@ -15,9 +15,9 @@ use std::io::BufReader;
use std::path::Path;
/// # Utility to deserialize JSON config files
pub trait Config<T>
pub trait Config
where
for<'de> T: Deserialize<'de>,
for<'de> Self: Deserialize<'de>,
{
/// Loads a `Config` object from a JSON file. The format is expected to be aligned with the [Transformers library](https://github.com/huggingface/transformers) configuration files for each model.
/// The parsing will fail if non-optional keys expected by the model are missing.
@ -36,10 +36,10 @@ where
/// let config_path = Path::new("path/to/config.json");
/// let config = Gpt2Config::from_file(config_path);
/// ```
fn from_file<P: AsRef<Path>>(path: P) -> T {
fn from_file<P: AsRef<Path>>(path: P) -> Self {
let f = File::open(path).expect("Could not open configuration file.");
let br = BufReader::new(f);
let config: T = serde_json::from_reader(br).expect("could not parse configuration");
let config: Self = serde_json::from_reader(br).expect("could not parse configuration");
config
}
}

54
src/common/embeddings.rs Normal file
View File

@ -0,0 +1,54 @@
use crate::RustBertError;
use tch::nn::Embedding;
use tch::{Device, Tensor};
pub fn process_ids_embeddings_pair(
input_ids: Option<&Tensor>,
input_embeddings: Option<&Tensor>,
embeddings_matrix: &Embedding,
) -> Result<(Option<Tensor>, Vec<i64>, Device), RustBertError> {
Ok(match (input_ids, input_embeddings) {
(Some(_), Some(_)) => {
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
(Some(input_value), None) => (
Some(input_value.apply(embeddings_matrix)),
input_value.size(),
input_value.device(),
),
(None, Some(embeds)) => {
let size = vec![embeds.size()[0], embeds.size()[1]];
(None, size, embeds.device())
}
(None, None) => {
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
})
}
pub fn get_shape_and_device_from_ids_embeddings_pair(
input_ids: Option<&Tensor>,
input_embeddings: Option<&Tensor>,
) -> Result<(Vec<i64>, Device), RustBertError> {
Ok(match (input_ids, input_embeddings) {
(Some(_), Some(_)) => {
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
(Some(input_value), None) => (input_value.size(), input_value.device()),
(None, Some(embeds)) => {
let size = vec![embeds.size()[0], embeds.size()[1]];
(size, embeds.device())
}
(None, None) => {
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
})
}

View File

@ -1,4 +1,4 @@
use rust_tokenizers::preprocessing::error::TokenizerError;
use rust_tokenizers::error::TokenizerError;
use tch::TchError;
use thiserror::Error;

42
src/common/kind.rs Normal file
View File

@ -0,0 +1,42 @@
use crate::RustBertError;
use tch::{Kind, Scalar};
pub(crate) fn get_positive_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
Ok(match kind {
Kind::Uint8 => Scalar::int(u8::MAX.into()),
Kind::Int8 => Scalar::int(i8::MAX.into()),
Kind::Int16 => Scalar::int(i16::MAX.into()),
Kind::Int => Scalar::int(i32::MAX.into()),
Kind::Int64 => Scalar::int(i64::MAX),
Kind::Half => Scalar::float(half::f16::INFINITY.into()),
Kind::Float => Scalar::float(f32::INFINITY.into()),
Kind::BFloat16 => Scalar::float(half::bf16::INFINITY.into()),
Kind::Double => Scalar::float(f64::INFINITY),
_ => {
return Err(RustBertError::ValueError(format!(
"Type not supported: attempted to get positive infinity for {:?}",
kind
)))
}
})
}
pub(crate) fn get_negative_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
Ok(match kind {
Kind::Uint8 => Scalar::int(u8::MIN.into()),
Kind::Int8 => Scalar::int(i8::MIN.into()),
Kind::Int16 => Scalar::int(i16::MIN.into()),
Kind::Int => Scalar::int(i32::MIN.into()),
Kind::Int64 => Scalar::int(i64::MIN),
Kind::Half => Scalar::float(half::f16::NEG_INFINITY.into()),
Kind::Float => Scalar::float(f32::NEG_INFINITY.into()),
Kind::BFloat16 => Scalar::float(half::bf16::NEG_INFINITY.into()),
Kind::Double => Scalar::float(f64::NEG_INFINITY),
_ => {
return Err(RustBertError::ValueError(format!(
"Type not supported: attempted to get negative infinity for {:?}",
kind
)))
}
})
}

View File

@ -1,8 +1,12 @@
pub(crate) mod activations;
pub mod config;
pub(crate) mod dropout;
pub(crate) mod embeddings;
pub mod error;
pub(crate) mod kind;
pub(crate) mod linear;
pub mod resources;
pub(crate) mod summary;
pub use activations::Activation;
pub use config::Config;

View File

@ -103,8 +103,8 @@ impl RemoteResource {
/// ```no_run
/// use rust_bert::resources::{RemoteResource, Resource};
/// let config_resource = Resource::Remote(RemoteResource::new(
/// "http://config_json_location",
/// "configs",
/// "http://config_json_location",
/// ));
/// ```
pub fn new(url: &str, cache_subdir: &str) -> RemoteResource {
@ -115,7 +115,7 @@ impl RemoteResource {
}
/// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
/// ~/.cache/.rusbert/model_name. Note that this does not download the resource (only declares
/// ~/.cache/.rustbert/model_name. Note that this does not download the resource (only declares
/// the remote and local locations)
///
/// # Arguments
@ -132,16 +132,13 @@ impl RemoteResource {
/// use rust_bert::resources::{RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
/// )));
/// ```
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
let name = name_url_tuple.0.to_string();
let cache_subdir = name_url_tuple.0.to_string();
let url = name_url_tuple.1.to_string();
RemoteResource {
url,
cache_subdir: name,
}
RemoteResource { url, cache_subdir }
}
}
@ -191,7 +188,7 @@ fn _get_cache_directory() -> PathBuf {
/// use rust_bert::resources::{RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
/// "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
/// )));
/// let local_path = model_resource.get_local_path();
/// ```

172
src/common/summary.rs Normal file
View File

@ -0,0 +1,172 @@
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::activations::{Activation, TensorFunction};
use crate::common::dropout::Dropout;
use crate::xlnet::XLNetConfig;
use crate::RustBertError;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use tch::{nn, Tensor};
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
/// # Summary type for the model when used for summarization
pub enum SummaryType {
/// Hidden state stored in the last token
last,
/// Hidden state stored in the first token
first,
/// Mean of all token hidden states
mean,
/// Hidden state stored in the CLS token
cls_index,
}
pub struct SummaryConfig {
pub summary_type: Option<SummaryType>,
pub summary_use_proj: Option<bool>,
pub summary_activation: Option<Activation>,
pub summary_proj_to_labels: Option<bool>,
pub summary_first_dropout: Option<f64>,
pub summary_last_dropout: Option<f64>,
pub num_labels: Option<i64>,
pub hidden_size: i64,
}
impl From<&XLNetConfig> for SummaryConfig {
fn from(config: &XLNetConfig) -> Self {
let num_labels = config
.id2label
.as_ref()
.map(|id2label| id2label.len() as i64);
SummaryConfig {
summary_type: config.summary_type,
summary_use_proj: config.summary_use_proj,
summary_activation: config.summary_activation,
summary_proj_to_labels: config.summary_proj_to_labels,
summary_first_dropout: config.summary_first_dropout,
summary_last_dropout: config.summary_last_dropout,
num_labels,
hidden_size: config.d_model,
}
}
}
pub struct SequenceSummary {
summary: Option<nn::Linear>,
summary_type: SummaryType,
activation: Option<TensorFunction>,
first_dropout: Option<Dropout>,
last_dropout: Option<Dropout>,
}
impl SequenceSummary {
pub fn new<'p, P>(p: P, config: &SummaryConfig) -> Result<SequenceSummary, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let summary_type = config.summary_type.unwrap_or(SummaryType::last);
let summary = if let Some(summary_use_proj) = config.summary_use_proj {
let num_classes = match (config.summary_proj_to_labels, config.num_labels) {
(Some(summary_proj_to_labels), Some(num_labels))
if (num_labels > 0) & summary_proj_to_labels & summary_use_proj =>
{
num_labels
}
_ => config.hidden_size,
};
Some(nn::linear(
p / "summary",
config.hidden_size,
num_classes,
Default::default(),
))
} else {
None
};
let activation = if config.summary_activation.is_some() {
Some(config.summary_activation.as_ref().unwrap().get_function())
} else {
None
};
let first_dropout = match config.summary_first_dropout {
Some(dropout) if dropout > 0.0 => Some(Dropout::new(dropout)),
_ => None,
};
let last_dropout = match config.summary_last_dropout {
Some(dropout) if dropout > 0.0 => Some(Dropout::new(dropout)),
_ => None,
};
Ok(SequenceSummary {
summary,
summary_type,
activation,
first_dropout,
last_dropout,
})
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
cls_index: Option<&Tensor>,
train: bool,
) -> Tensor {
let mut output = match self.summary_type {
SummaryType::last => hidden_states.select(1, -1),
SummaryType::first => hidden_states.select(1, 0),
SummaryType::mean => hidden_states.mean_dim(&[1], false, hidden_states.kind()),
SummaryType::cls_index => {
let cls_index = if let Some(cls_index_value) = cls_index {
let mut expand_dim = vec![-1i64; cls_index_value.dim() - 1];
expand_dim.push(*hidden_states.size().last().unwrap());
cls_index_value
.unsqueeze(-1)
.unsqueeze(-1)
.expand(expand_dim.as_slice(), true)
} else {
let mut fill_value = hidden_states.size();
fill_value.reverse();
let fill_value = fill_value[2];
hidden_states.select(-2, 0).full_like(fill_value)
};
hidden_states.gather(-2, &cls_index, false).squeeze_dim(-2)
}
};
if let Some(first_dropout) = &self.first_dropout {
output = output.apply_t(first_dropout, train)
};
if let Some(summary) = &self.summary {
output = output.apply(summary)
};
if let Some(activation_fn) = &self.activation {
output = activation_fn.get_fn()(&output)
};
if let Some(last_dropout) = &self.last_dropout {
output = output.apply_t(last_dropout, train)
};
output
}
}

View File

@ -13,7 +13,6 @@
use crate::common::dropout::Dropout;
use crate::distilbert::distilbert_model::DistilBertConfig;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)]
@ -40,11 +39,7 @@ impl MultiHeadSelfAttention {
let out_lin = nn::linear(p / "out_lin", config.dim, config.dim, Default::default());
let dropout = Dropout::new(config.attention_dropout);
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false,
};
let output_attentions = config.output_attentions.unwrap_or(false);
MultiHeadSelfAttention {
n_heads: config.n_heads,
@ -73,7 +68,7 @@ impl MultiHeadSelfAttention {
query: &Tensor,
key: &Tensor,
value: &Tensor,
mask: &Option<Tensor>,
mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let bs = query.size()[0];
@ -87,15 +82,17 @@ impl MultiHeadSelfAttention {
let scores = if let Some(mask) = mask {
let unmasked_scores = q.matmul(&k.transpose(2, 3));
let mask = mask
.le1(&(mask.zeros_like() + 0.1))
.le_tensor(&(mask.zeros_like() + 0.1))
.view((bs, 1i64, 1i64, k_length))
.expand_as(&unmasked_scores);
unmasked_scores.masked_fill(&mask, std::f64::NEG_INFINITY)
unmasked_scores.masked_fill(&mask, f64::NEG_INFINITY)
} else {
q.matmul(&k.transpose(2, 3))
};
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
let weights = scores
.softmax(-1, scores.kind())
.apply_t(&self.dropout, train);
let context = self
.flatten(weights.matmul(&v), bs, self.dim_per_head)
.apply(&self.out_lin);

View File

@ -13,6 +13,7 @@
extern crate tch;
use self::tch::{nn, Tensor};
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::distilbert::embeddings::DistilBertEmbedding;
use crate::distilbert::transformer::{DistilBertTransformerOutput, Transformer};
@ -30,70 +31,60 @@ pub struct DistilBertConfigResources;
pub struct DistilBertVocabResources;
impl DistilBertModelResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/model",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
"https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/model",
"https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot",
"https://huggingface.co/distilbert-base-uncased/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/model",
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot",
"https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/rust_model.ot",
);
}
impl DistilBertConfigResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/config",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json",
"https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/config",
"https://cdn.huggingface.co/distilbert-base-uncased-config.json",
"https://huggingface.co/distilbert-base-uncased/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/config",
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json",
"https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/config.json",
);
}
impl DistilBertVocabResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/vocab",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt",
"https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/vocab",
"https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
"https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/vocab",
"https://cdn.huggingface.co/bert-large-cased-vocab.txt",
"https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
);
}
#[allow(non_camel_case_types)]
#[derive(Debug, Serialize, Deserialize)]
/// # Activation function used in the feed-forward layer in the transformer blocks
pub enum Activation {
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
gelu,
/// Rectified Linear Unit
relu,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # DistilBERT model configuration
/// Defines the DistilBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct DistilBertConfig {
@ -121,7 +112,7 @@ pub struct DistilBertConfig {
pub vocab_size: i64,
}
impl Config<DistilBertConfig> for DistilBertConfig {}
impl Config for DistilBertConfig {}
/// # DistilBERT Base model
/// Base architecture for DistilBERT models. Task-specific models will be built from this common base model
@ -205,36 +196,18 @@ impl DistilBertModel {
///
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .forward_t(Some(&input_tensor), Some(&mask), None, false)
/// .unwrap()
/// });
/// ```
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
input: Option<&Tensor>,
mask: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<DistilBertTransformerOutput, RustBertError> {
let input_embeddings = match input {
Some(input_value) => match input_embeds {
Some(_) => {
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => input_value.apply_t(&self.embeddings, train),
},
None => match input_embeds {
Some(embeds) => embeds,
None => {
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
let input_embeddings = self.embeddings.forward_t(input, input_embeds, train)?;
let transformer_output = (&self.transformer).forward_t(&input_embeddings, mask, train);
Ok(transformer_output)
}
@ -343,17 +316,17 @@ impl DistilBertModelClassifier {
///
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// None,
/// false).unwrap()
/// });
/// ```
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
input: Option<&Tensor>,
mask: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<DistilBertSequenceClassificationOutput, RustBertError> {
let base_model_output =
@ -482,15 +455,15 @@ impl DistilBertModelMaskedLM {
///
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .forward_t(Some(&input_tensor), Some(&mask), None, false)
/// .unwrap()
/// });
/// ```
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
input: Option<&Tensor>,
mask: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<DistilBertMaskedLMOutput, RustBertError> {
let base_model_output =
@ -601,15 +574,15 @@ impl DistilBertForQuestionAnswering {
///
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .forward_t(Some(&input_tensor), Some(&mask), None, false)
/// .unwrap()
/// });
/// ```
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
input: Option<&Tensor>,
mask: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<DistilBertQuestionAnsweringOutput, RustBertError> {
let base_model_output =
@ -623,8 +596,8 @@ impl DistilBertForQuestionAnswering {
let logits = output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
let start_logits = start_logits.squeeze_dim(-1);
let end_logits = end_logits.squeeze_dim(-1);
Ok(DistilBertQuestionAnsweringOutput {
start_logits,
@ -729,15 +702,15 @@ impl DistilBertForTokenClassification {
///
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .forward_t(Some(&input_tensor), Some(&mask), None, false)
/// .unwrap()
/// });
/// ```
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
input: Option<&Tensor>,
mask: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<DistilBertTokenClassificationOutput, RustBertError> {
let base_model_output =

View File

@ -11,13 +11,22 @@
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::distilbert::distilbert_model::DistilBertConfig;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::nn::{embedding, EmbeddingConfig, ModuleT};
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Device, Kind, Tensor};
fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn::Embedding {
fn create_sinusoidal_embeddings<'p, P>(
config: &DistilBertConfig,
p: P,
device: Device,
) -> nn::Embedding
where
P: Borrow<nn::Path<'p>>,
{
let mut sinusoidal_embedding: Vec<Tensor> =
Vec::with_capacity(config.max_position_embeddings as usize);
for pos in 0..config.max_position_embeddings {
@ -25,11 +34,11 @@ fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn
for j in 0..config.dim {
if j % 2 == 0 {
temp_vec.push(
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin(),
(pos as f64 / 10000_f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin(),
);
} else {
temp_vec.push(
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos(),
(pos as f64 / 10000_f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos(),
);
}
}
@ -45,7 +54,7 @@ fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn
..Default::default()
};
let mut embeddings = embedding(
&nn::VarStore::new(device).root(),
p.borrow(),
config.max_position_embeddings,
config.dim,
embedding_config,
@ -88,8 +97,7 @@ impl DistilBertEmbedding {
config.dim,
embedding_config,
),
true => create_sinusoidal_embeddings(&config, p.device()),
true => create_sinusoidal_embeddings(config, p / "position_embeddings", p.device()),
};
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
@ -113,20 +121,27 @@ impl DistilBertEmbedding {
pub fn _set_word_embeddings(&mut self, new_embeddings: nn::Embedding) {
self.word_embeddings = new_embeddings;
}
}
impl ModuleT for DistilBertEmbedding {
fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
let seq_length = (&input).size().last().unwrap().to_owned();
let position_ids = Tensor::arange(seq_length, (Kind::Int64, input.device()));
let position_ids = position_ids.unsqueeze(0).expand_as(input);
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<Tensor, RustBertError> {
let (calc_input_embeddings, input_size, device) =
process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
let word_embeds = input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let word_embed = input.apply(&self.word_embeddings);
let seq_length = input_size[1];
let position_ids = Tensor::arange(seq_length, (Kind::Int64, device));
let position_ids = position_ids
.unsqueeze(0)
.expand(input_size.as_slice(), true);
let position_embed = position_ids.apply(&self.position_embeddings);
let embeddings = word_embed + position_embed;
embeddings
let embeddings = word_embeds + position_embed;
Ok(embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train)
.apply_t(&self.dropout, train))
}
}

View File

@ -1,15 +1,14 @@
//! # DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter (Sanh et al.)
//!
//! Implementation of the DistilBERT language model ([https://arxiv.org/abs/1910.01108](https://arxiv.org/abs/1910.01108) Sanh, Debut, Chaumond, Wolf, 2019).
//! The base model is implemented in the `distilbert::DistilBertModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `distilbert::DistilBertForMaskedLM`
//! - Question answering: `distilbert::DistilBertForQuestionAnswering`
//! - Sequence classification: `distilbert::DistilBertForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `distilbert::DistilBertForTokenClassification`
//! The base model is implemented in the `distilbert_model::DistilBertModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `distilbert_model::DistilBertForMaskedLM`
//! - Question answering: `distilbert_model::DistilBertForQuestionAnswering`
//! - Sequence classification: `distilbert_model::DistilBertForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `distilbert_model::DistilBertForTokenClassification`
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/distilbert_masked_lm.rs`, run with `cargo run --example distilbert_masked_lm`.
//! The example below illustrate a DistilBERT Masked language model example, the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@ -20,7 +19,6 @@
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::distilbert::{
@ -29,6 +27,7 @@
//! };
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::BertTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
@ -60,7 +59,7 @@ mod embeddings;
mod transformer;
pub use distilbert_model::{
Activation, DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
DistilBertForTokenClassification, DistilBertMaskedLMOutput, DistilBertModel,
DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertModelResources,
DistilBertQuestionAnsweringOutput, DistilBertSequenceClassificationOutput,

View File

@ -10,10 +10,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::activations::{_gelu, _relu};
use crate::common::activations::TensorFunction;
use crate::common::dropout::Dropout;
use crate::distilbert::attention::MultiHeadSelfAttention;
use crate::distilbert::distilbert_model::{Activation, DistilBertConfig};
use crate::distilbert::distilbert_model::DistilBertConfig;
use std::borrow::{Borrow, BorrowMut};
use tch::nn::LayerNorm;
use tch::{nn, Tensor};
@ -22,7 +22,7 @@ pub struct FeedForwardNetwork {
lin1: nn::Linear,
lin2: nn::Linear,
dropout: Dropout,
activation: Box<dyn Fn(&Tensor) -> Tensor>,
activation: TensorFunction,
}
impl FeedForwardNetwork {
@ -44,10 +44,7 @@ impl FeedForwardNetwork {
Default::default(),
);
let dropout = Dropout::new(config.dropout);
let activation = Box::new(match &config.activation {
Activation::gelu => _gelu,
Activation::relu => _relu,
});
let activation = config.activation.get_function();
FeedForwardNetwork {
lin1,
lin2,
@ -57,7 +54,7 @@ impl FeedForwardNetwork {
}
pub fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
(self.activation)(&input.apply(&self.lin1))
(self.activation.get_fn())(&input.apply(&self.lin1))
.apply(&self.lin2)
.apply_t(&self.dropout, train)
}
@ -77,14 +74,14 @@ impl TransformerBlock {
{
let p = p.borrow();
let attention = MultiHeadSelfAttention::new(p / "attention", &config);
let attention = MultiHeadSelfAttention::new(p / "attention", config);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let sa_layer_norm =
nn::layer_norm(p / "sa_layer_norm", vec![config.dim], layer_norm_config);
let ffn = FeedForwardNetwork::new(p / "ffn", &config);
let ffn = FeedForwardNetwork::new(p / "ffn", config);
let output_layer_norm =
nn::layer_norm(p / "output_layer_norm", vec![config.dim], layer_norm_config);
@ -99,12 +96,10 @@ impl TransformerBlock {
pub fn forward_t(
&self,
input: &Tensor,
mask: &Option<Tensor>,
mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (output, sa_weights) = self
.attention
.forward_t(&input, &input, &input, mask, train);
let (output, sa_weights) = self.attention.forward_t(input, input, input, mask, train);
let output = (input + &output).apply(&self.sa_layer_norm);
let output = (&output + self.ffn.forward_t(&output, train)).apply(&self.output_layer_norm);
(output, sa_weights)
@ -123,14 +118,8 @@ impl Transformer {
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow() / "layer";
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false,
};
let output_attentions = config.output_attentions.unwrap_or(false);
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
let mut layers: Vec<TransformerBlock> = vec![];
for layer_index in 0..config.n_layers {
@ -147,7 +136,7 @@ impl Transformer {
pub fn forward_t(
&self,
input: &Tensor,
mask: Option<Tensor>,
mask: Option<&Tensor>,
train: bool,
) -> DistilBertTransformerOutput {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
@ -161,24 +150,29 @@ impl Transformer {
None
};
let mut hidden_state = input.copy();
// let mut hidden_state = input.copy();
let mut hidden_state: Option<Tensor> = None;
let mut attention_weights: Option<Tensor>;
for layer in &self.layers {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
let temp = if let Some(hidden_state) = &hidden_state {
layer.forward_t(hidden_state, mask, train)
} else {
layer.forward_t(input, mask, train)
};
let temp = layer.forward_t(&hidden_state, &mask, train);
hidden_state = temp.0;
hidden_state = Some(temp.0);
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().unwrap().copy());
};
}
DistilBertTransformerOutput {
hidden_state,
hidden_state: hidden_state.unwrap(),
all_hidden_states,
all_attentions,
}

View File

@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bert::encoder::BertEncoder;
use crate::bert::{Activation, BertConfig};
use crate::common::activations::{_gelu, _mish, _relu};
use crate::bert::BertConfig;
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
use crate::electra::embeddings::ElectraEmbeddings;
use crate::{bert::encoder::BertEncoder, common::activations::TensorFunction};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::{borrow::Borrow, collections::HashMap};
@ -32,45 +33,45 @@ pub struct ElectraConfigResources;
pub struct ElectraVocabResources;
impl ElectraModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/model",
"https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot",
"https://huggingface.co/google/electra-base-generator/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/model",
"https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot",
"https://huggingface.co/google/electra-base-discriminator/resolve/main/rust_model.ot",
);
}
impl ElectraConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/config",
"https://cdn.huggingface.co/google/electra-base-generator/config.json",
"https://huggingface.co/google/electra-base-generator/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/config",
"https://cdn.huggingface.co/google/electra-base-discriminator/config.json",
"https://huggingface.co/google/electra-base-discriminator/resolve/main/config.json",
);
}
impl ElectraVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/vocab",
"https://cdn.huggingface.co/google/electra-base-generator/vocab.txt",
"https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/electra>. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/vocab",
"https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt",
"https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt",
);
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # Electra model configuration
/// Defines the Electra model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct ElectraConfig {
@ -95,7 +96,7 @@ pub struct ElectraConfig {
pub label2id: Option<HashMap<String, i64>>,
}
impl Config<ElectraConfig> for ElectraConfig {}
impl Config for ElectraConfig {}
/// # Electra Base model
/// Base architecture for Electra models.
@ -150,7 +151,7 @@ impl ElectraModel {
None
};
let bert_config = BertConfig {
hidden_act: config.hidden_act.clone(),
hidden_act: config.hidden_act,
attention_probs_dropout_prob: config.attention_probs_dropout_prob,
hidden_dropout_prob: config.hidden_dropout_prob,
hidden_size: config.hidden_size,
@ -181,7 +182,7 @@ impl ElectraModel {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -216,10 +217,10 @@ impl ElectraModel {
/// let model_output = no_grad(|| {
/// electra_model
/// .forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
@ -228,36 +229,22 @@ impl ElectraModel {
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<ElectraModelOutput, RustBertError> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => {
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds {
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => {
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
let (input_shape, device) =
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
let mask = match mask {
Some(value) => value,
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
let calc_mask = if mask.is_none() {
Some(Tensor::ones(&input_shape, (Kind::Int64, device)))
} else {
None
};
let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
let extended_attention_mask = match mask.dim() {
3 => mask.unsqueeze(1),
@ -269,36 +256,31 @@ impl ElectraModel {
}
};
let hidden_states = match self.embeddings.forward_t(
let hidden_states = self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
) {
Ok(value) => value,
Err(e) => {
return Err(e);
}
};
)?;
let hidden_states = match &self.embeddings_project {
Some(layer) => hidden_states.apply(layer),
None => hidden_states,
};
let (hidden_state, all_hidden_states, all_attentions) = self.encoder.forward_t(
let encoder_output = self.encoder.forward_t(
&hidden_states,
&Some(extended_attention_mask),
&None,
&None,
Some(&extended_attention_mask),
None,
None,
train,
);
Ok(ElectraModelOutput {
hidden_state,
all_hidden_states,
all_attentions,
hidden_state: encoder_output.hidden_state,
all_hidden_states: encoder_output.all_hidden_states,
all_attentions: encoder_output.all_attentions,
})
}
}
@ -312,7 +294,7 @@ impl ElectraModel {
pub struct ElectraDiscriminatorHead {
dense: nn::Linear,
dense_prediction: nn::Linear,
activation: Box<dyn Fn(&Tensor) -> Tensor>,
activation: TensorFunction,
}
/// Defines the implementation of the ElectraDiscriminatorHead.
@ -356,11 +338,7 @@ impl ElectraDiscriminatorHead {
1,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish,
});
let activation = config.hidden_act.get_function();
ElectraDiscriminatorHead {
dense,
dense_prediction,
@ -401,7 +379,7 @@ impl ElectraDiscriminatorHead {
/// ```
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
let output = encoder_hidden_states.apply(&self.dense);
let output = (self.activation)(&output);
let output = (self.activation.get_fn())(&output);
output.apply(&self.dense_prediction).squeeze()
}
}
@ -415,7 +393,7 @@ impl ElectraDiscriminatorHead {
pub struct ElectraGeneratorHead {
dense: nn::Linear,
layer_norm: nn::LayerNorm,
activation: Box<dyn Fn(&Tensor) -> Tensor>,
activation: TensorFunction,
}
/// Defines the implementation of the ElectraGeneratorHead.
@ -458,11 +436,11 @@ impl ElectraGeneratorHead {
config.embedding_size,
Default::default(),
);
let activation = Box::new(_gelu);
let activation = Activation::gelu.get_function();
ElectraGeneratorHead {
layer_norm,
dense,
layer_norm,
activation,
}
}
@ -500,7 +478,7 @@ impl ElectraGeneratorHead {
/// ```
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
let output = encoder_hidden_states.apply(&self.dense);
let output = (self.activation)(&output);
let output = (self.activation.get_fn())(&output);
output.apply(&self.layer_norm)
}
}
@ -568,7 +546,7 @@ impl ElectraForMaskedLM {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -602,10 +580,10 @@ impl ElectraForMaskedLM {
///
/// let model_output = no_grad(|| {
/// electra_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
@ -613,11 +591,11 @@ impl ElectraForMaskedLM {
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> ElectraMaskedLMOutput {
let base_model_output = self
@ -696,7 +674,7 @@ impl ElectraDiscriminator {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -729,21 +707,21 @@ impl ElectraDiscriminator {
///
/// let model_output = no_grad(|| {
/// electra_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false)
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> ElectraDiscriminatorOutput {
let base_model_output = self
@ -837,7 +815,7 @@ impl ElectraForTokenClassification {
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None set to 0.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
@ -870,21 +848,21 @@ impl ElectraForTokenClassification {
///
/// let model_output = no_grad(|| {
/// electra_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// .forward_t(Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false)
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> ElectraTokenClassificationOutput {
let base_model_output = self

View File

@ -13,6 +13,7 @@
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::electra::electra_model::ElectraConfig;
use crate::RustBertError;
use std::borrow::Borrow;
@ -62,10 +63,7 @@ impl ElectraEmbeddings {
Default::default(),
);
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12,
};
let layer_norm_eps = config.layer_norm_eps.unwrap_or(1e-12);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
@ -87,50 +85,40 @@ impl ElectraEmbeddings {
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
input_ids: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<Tensor, RustBertError> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => {
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds {
Some(embeds) => {
let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size)
}
None => {
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
let (calc_input_embeddings, input_shape, _) =
process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
let input_embeddings =
input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let seq_length = input_embeddings.size()[1].to_owned();
let position_ids = match position_ids {
Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0)
.expand(&input_shape, true),
let calc_position_ids = if position_ids.is_none() {
Some(
Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0)
.expand(&input_shape, true),
)
} else {
None
};
let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
let calc_token_type_ids = if token_type_ids.is_none() {
Some(Tensor::zeros(
&input_shape,
(Kind::Int64, input_embeddings.device()),
))
} else {
None
};
let token_type_ids =
token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);

View File

@ -1,20 +1,19 @@
//! # Electra: Pre-training Text Encoders as Discriminators Rather Than Generators (Clark et al.)
//!
//! Implementation of the Electra language model ([https://openreview.net/pdf?id=r1xMH1BtvB](https://openreview.net/pdf?id=r1xMH1BtvB) Clark, Luong, Le, Manning, 2020).
//! The base model is implemented in the `electra::ElectraModel` struct. Both generator and discriminator are available via specialized heads:
//! - Generator head: `electra::ElectraGeneratorHead`
//! - Discriminator head: `electra::ElectraDiscriminatorHead`
//! The base model is implemented in the `electra_model::ElectraModel` struct. Both generator and discriminator are available via specialized heads:
//! - Generator head: `electra_model::ElectraGeneratorHead`
//! - Discriminator head: `electra_model::ElectraDiscriminatorHead`
//!
//! The generator and discriminator models are built from these:
//! - Generator (masked language model): `electra::ElectraForMaskedLM`
//! - Discriminator: `electra::ElectraDiscriminator`
//! - Generator (masked language model): `electra_model::ElectraForMaskedLM`
//! - Discriminator: `electra_model::ElectraDiscriminator`
//!
//! An additional sequence token classification model is available for reference
//! - Token classification (e.g. NER, POS tagging): `electra::ElectraForTokenClassification`
//! - Token classification (e.g. NER, POS tagging): `electra_model::ElectraForTokenClassification`
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/electra_masked_lm.rs`, run with `cargo run --example electra_masked_lm`.
//! The example below illustrate a Masked language model example, the structure is similar for other models (e.g. discriminator).
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@ -25,12 +24,12 @@
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::BertTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),

156
src/fnet/attention.rs Normal file
View File

@ -0,0 +1,156 @@
// Copyright 2021 Google Research
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::activations::TensorFunction;
use crate::common::dropout::Dropout;
use crate::fnet::FNetConfig;
use std::borrow::Borrow;
use tch::nn::LayerNormConfig;
use tch::{nn, Tensor};
pub struct FNetFourierTransform {
layer_norm: nn::LayerNorm,
}
impl FNetFourierTransform {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetFourierTransform
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let layer_norm_config = LayerNormConfig {
eps: config.layer_norm_eps.unwrap_or(1e-12),
..Default::default()
};
let layer_norm = nn::layer_norm(
p.sub("output").sub("LayerNorm"),
vec![config.hidden_size],
layer_norm_config,
);
FNetFourierTransform { layer_norm }
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
let self_outputs = hidden_states.fft_fft2(None, &[1, 2], "backward").real();
(self_outputs + hidden_states).apply(&self.layer_norm)
}
}
pub struct FNetIntermediate {
dense: nn::Linear,
intermediate_activation_function: TensorFunction,
}
impl FNetIntermediate {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetIntermediate
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.intermediate_size,
Default::default(),
);
let intermediate_activation_function = config.hidden_act.get_function();
FNetIntermediate {
dense,
intermediate_activation_function,
}
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
self.intermediate_activation_function.get_fn()(&hidden_states.apply(&self.dense))
}
}
pub struct FNetOutput {
dense: nn::Linear,
layer_norm: nn::LayerNorm,
dropout: Dropout,
}
impl FNetOutput {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetOutput
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let dense = nn::linear(
p / "dense",
config.intermediate_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = LayerNormConfig {
eps: config.layer_norm_eps.unwrap_or(1e-12),
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout = Dropout::new(config.hidden_dropout_prob);
FNetOutput {
dense,
layer_norm,
dropout,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
let hidden_states = hidden_states
.apply(&self.dense)
.apply_t(&self.dropout, train);
(input_tensor + hidden_states).apply(&self.layer_norm)
}
}
pub struct FNetLayer {
fourier: FNetFourierTransform,
intermediate: FNetIntermediate,
output: FNetOutput,
}
impl FNetLayer {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetLayer
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let fourier = FNetFourierTransform::new(p / "fourier", config);
let intermediate = FNetIntermediate::new(p / "intermediate", config);
let output = FNetOutput::new(p / "output", config);
FNetLayer {
fourier,
intermediate,
output,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
let fourier_outputs = self.fourier.forward(hidden_states);
let intermediate_output = self.intermediate.forward(&fourier_outputs);
self.output
.forward_t(&intermediate_output, &fourier_outputs, train)
}
}

131
src/fnet/embeddings.rs Normal file
View File

@ -0,0 +1,131 @@
// Copyright 2021 Google Research
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::fnet::FNetConfig;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::{EmbeddingConfig, LayerNormConfig};
use tch::{nn, Kind, Tensor};
pub struct FNetEmbeddings {
word_embeddings: nn::Embedding,
position_embeddings: nn::Embedding,
token_type_embeddings: nn::Embedding,
projection: nn::Linear,
layer_norm: nn::LayerNorm,
dropout: Dropout,
}
impl FNetEmbeddings {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetEmbeddings
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let word_embeddings_config = EmbeddingConfig {
padding_idx: config.pad_token_id.unwrap_or(3),
..Default::default()
};
let word_embeddings = nn::embedding(
p / "word_embeddings",
config.vocab_size,
config.hidden_size,
word_embeddings_config,
);
let position_embeddings = nn::embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.hidden_size,
Default::default(),
);
let token_type_embeddings = nn::embedding(
p / "token_type_embeddings",
config.type_vocab_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = LayerNormConfig {
eps: config.layer_norm_eps.unwrap_or(1e-12),
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let projection = nn::linear(
p / "projection",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dropout = Dropout::new(config.hidden_dropout_prob);
FNetEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
projection,
layer_norm,
dropout,
}
}
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeddings: Option<&Tensor>,
train: bool,
) -> Result<Tensor, RustBertError> {
let (calc_input_embeddings, input_shape, _) =
process_ids_embeddings_pair(input_ids, input_embeddings, &self.word_embeddings)?;
let input_embeddings =
input_embeddings.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let calc_token_type_ids = if token_type_ids.is_none() {
Some(Tensor::zeros(
input_shape.as_slice(),
(Kind::Int64, input_embeddings.device()),
))
} else {
None
};
let token_type_embeddings = token_type_ids
.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap())
.apply(&self.token_type_embeddings);
let calc_position_ids = if position_ids.is_none() {
Some(Tensor::arange(
input_shape[1],
(Kind::Int64, input_embeddings.device()),
))
} else {
None
};
let position_embeddings = position_ids
.unwrap_or_else(|| calc_position_ids.as_ref().unwrap())
.apply(&self.position_embeddings);
let embeddings = input_embeddings + token_type_embeddings + position_embeddings;
Ok(embeddings
.apply(&self.layer_norm)
.apply(&self.projection)
.apply_t(&self.dropout, train))
}
}

81
src/fnet/encoder.rs Normal file
View File

@ -0,0 +1,81 @@
// Copyright 2021 Google Research
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::fnet::attention::FNetLayer;
use crate::fnet::FNetConfig;
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Tensor};
pub struct FNetEncoder {
layers: Vec<FNetLayer>,
output_hidden_states: bool,
}
impl FNetEncoder {
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetEncoder
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let p_layers = p / "layer";
let mut layers: Vec<FNetLayer> = Vec::with_capacity(config.num_hidden_layers as usize);
for layer_index in 0..config.num_hidden_layers {
layers.push(FNetLayer::new(&p_layers / layer_index, config));
}
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
FNetEncoder {
layers,
output_hidden_states,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> FNetEncoderOutput {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut x: Option<Tensor> = None;
for layer in &self.layers {
let temp = if let Some(x_value) = &x {
layer.forward_t(x_value, train)
} else {
layer.forward_t(hidden_states, train)
};
x = Some(temp);
if let Some(all_hidden_states) = all_hidden_states.borrow_mut() {
all_hidden_states.push(x.as_ref().unwrap().copy());
};
}
FNetEncoderOutput {
hidden_states: x.unwrap(),
all_hidden_states,
}
}
}
/// Container for the FNet encoder output.
pub struct FNetEncoderOutput {
/// Last hidden states from the model
pub hidden_states: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
}

1052
src/fnet/fnet_model.rs Normal file

File diff suppressed because it is too large Load Diff

64
src/fnet/mod.rs Normal file
View File

@ -0,0 +1,64 @@
//! # FNet, Mixing Tokens with Fourier Transforms (Lee-Thorp et al.)
//!
//! Implementation of the FNet language model ([https://arxiv.org/abs/2105.03824](https://arxiv.org/abs/2105.03824) Lee-Thorp, Ainslie, Eckstein, Ontanon, 2021).
//! The base model is implemented in the `fnet_model::FNetModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `fnet_model::FNetForMaskedLM`
//! - Question answering: `fnet_model::FNetForQuestionAnswering`
//! - Sequence classification: `fnet_model::FNetForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `fnet_model::FNetForTokenClassification`
//!
//! # Model set-up and pre-trained weights loading
//!
//! The example below illustrate a FNet Masked language model example, the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `FNetTokenizer` using a `spiece.model` SentencePiece (BPE) model file
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
//! use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::{BertTokenizer, FNetTokenizer};
//!
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: FNetTokenizer =
//! FNetTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
//! let config = FNetConfig::from_file(config_path);
//! let bert_model = FNetForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())
//! # }
//! ```
mod attention;
mod embeddings;
mod encoder;
mod fnet_model;
pub use fnet_model::{
FNetConfig, FNetConfigResources, FNetForMaskedLM, FNetForMultipleChoice,
FNetForQuestionAnswering, FNetForSequenceClassification, FNetForTokenClassification,
FNetMaskedLMOutput, FNetModel, FNetModelOutput, FNetModelResources,
FNetQuestionAnsweringOutput, FNetSequenceClassificationOutput, FNetTokenClassificationOutput,
FNetVocabResources,
};

View File

@ -74,23 +74,15 @@ impl Attention {
let bias = Tensor::ones(&[config.n_ctx, config.n_ctx], (Float, p.device()))
.tril(0)
.view((1, 1, config.n_ctx, config.n_ctx));
let bias = p.var_copy("bias", &bias);
let c_attn = GPTConv1D::new(p / "c_attn", config.n_embd * 3, config.n_embd);
let c_proj = GPTConv1D::new(p / "c_proj", config.n_embd, config.n_embd);
let attn_pdrop = match config.attn_pdrop {
Some(value) => value,
None => 0.1,
};
let resid_pdrop = match config.resid_pdrop {
Some(value) => value,
None => 0.1,
};
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false,
};
let attn_pdrop = config.attn_pdrop.unwrap_or(0.1);
let resid_pdrop = config.resid_pdrop.unwrap_or(0.1);
let output_attentions = config.output_attentions.unwrap_or(false);
let attn_dropout = Dropout::new(attn_pdrop);
let resid_dropout = Dropout::new(resid_pdrop);
@ -136,23 +128,23 @@ impl Attention {
query: &Tensor,
key: &Tensor,
value: &Tensor,
attention_mask: &Option<Tensor>,
attention_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let mut w = query.matmul(&key);
let mut w = query.matmul(key);
if self.scale {
w = w / (*value.size().last().unwrap() as f64).sqrt();
}
let (nd, ns) = (w.size()[2], w.size()[3]);
let b = self.bias.narrow(2, ns - nd, nd).narrow(3, 0, ns);
let mut w: Tensor = w * &b + 1e4 * (&b - 1);
if let Some(mask) = attention_mask {
w = w + mask;
}
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
let output = w.matmul(&value);
w = w.softmax(-1, w.kind()).apply_t(&self.attn_dropout, train);
let output = w.matmul(value);
if self.output_attentions {
(output, Some(w))
@ -164,8 +156,8 @@ impl Attention {
pub fn forward_t(
&self,
x: &Tensor,
layer_past: &Option<Tensor>,
attention_mask: &Option<Tensor>,
layer_past: Option<&Tensor>,
attention_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Tensor>) {
let x = x.apply(&self.c_attn).split(self.n_state, 2);
@ -184,7 +176,7 @@ impl Attention {
None => (key, value),
};
let present = Tensor::stack(&[key.transpose(-2, -1), value.copy()], 0);
let (a, attentions) = self.attention(&query, &key, &value, &attention_mask, train);
let (a, attentions) = self.attention(&query, &key, &value, attention_mask, train);
let a = self
.flatten(a)

View File

@ -12,16 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::gpt2::transformer::Block;
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::Gpt2Tokenizer;
use rust_tokenizers::vocab::Gpt2Vocab;
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Tensor};
use tch::{nn, Kind, Tensor};
/// # GPT2 Pretrained model weight files
pub struct Gpt2ModelResources;
@ -36,144 +45,138 @@ pub struct Gpt2VocabResources;
pub struct Gpt2MergesResources;
impl Gpt2ModelResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/model",
"https://cdn.huggingface.co/gpt2-rust_model.ot",
"https://huggingface.co/gpt2/resolve/main/rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/model",
"https://cdn.huggingface.co/gpt2-medium-rust_model.ot",
"https://huggingface.co/gpt2-medium/resolve/main/rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/model",
"https://cdn.huggingface.co/gpt2-large-rust_model.ot",
"https://huggingface.co/gpt2-large/resolve/main/rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/model",
"https://cdn.huggingface.co/gpt2-xl-rust_model.ot",
"https://huggingface.co/gpt2-xl/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/model",
"https://cdn.huggingface.co/distilgpt2-rust_model.ot",
"https://huggingface.co/distilgpt2/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/model",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/rust_model.ot",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/rust_model.ot",
);
}
impl Gpt2ConfigResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) =
("gpt2/config", "https://cdn.huggingface.co/gpt2-config.json");
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/config",
"https://huggingface.co/gpt2/resolve/main/config.json",
);
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/config",
"https://cdn.huggingface.co/gpt2-medium-config.json",
"https://huggingface.co/gpt2-medium/resolve/main/config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/config",
"https://cdn.huggingface.co/gpt2-large-config.json",
"https://huggingface.co/gpt2-large/resolve/main/config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/config",
"https://cdn.huggingface.co/gpt2-xl-config.json",
"https://huggingface.co/gpt2-xl/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/config",
"https://cdn.huggingface.co/distilgpt2-config.json",
"https://huggingface.co/distilgpt2/resolve/main/config.json",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/config",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/config.json",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/config.json",
);
}
impl Gpt2VocabResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) =
("gpt2/vocab", "https://cdn.huggingface.co/gpt2-vocab.json");
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/vocab",
"https://huggingface.co/gpt2/resolve/main/vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/vocab",
"https://cdn.huggingface.co/gpt2-medium-vocab.json",
"https://huggingface.co/gpt2-medium/resolve/main/vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/vocab",
"https://cdn.huggingface.co/gpt2-large-vocab.json",
"https://huggingface.co/gpt2-large/resolve/main/vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/vocab",
"https://cdn.huggingface.co/gpt2-xl-vocab.json",
"https://huggingface.co/gpt2-xl/resolve/main/vocab.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/vocab",
"https://cdn.huggingface.co/distilgpt2-vocab.json",
"https://huggingface.co/distilgpt2/resolve/main/vocab.json",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/vocab",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/vocab.json",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/vocab.json",
);
}
impl Gpt2MergesResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) =
("gpt2/merges", "https://cdn.huggingface.co/gpt2-merges.txt");
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = (
"gpt2/merges",
"https://huggingface.co/gpt2/resolve/main/merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/merges",
"https://cdn.huggingface.co/gpt2-medium-merges.txt",
"https://huggingface.co/gpt2-medium/resolve/main/merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/merges",
"https://cdn.huggingface.co/gpt2-large-merges.txt",
"https://huggingface.co/gpt2-large/resolve/main/merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/merges",
"https://cdn.huggingface.co/gpt2-xl-merges.txt",
"https://huggingface.co/gpt2-xl/resolve/main/merges.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/models>. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/merges",
"https://cdn.huggingface.co/distilgpt2-merges.txt",
"https://huggingface.co/distilgpt2/resolve/main/merges.txt",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
/// Shared under MIT license by the Microsoft team at <https://huggingface.co/microsoft/DialoGPT-medium>. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/merges",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/merges.txt",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/merges.txt",
);
}
#[allow(non_camel_case_types)]
#[derive(Debug, Serialize, Deserialize)]
/// # Activation function used in the fully connected layers of the transformer block
pub enum GptActivation {
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
gelu,
/// Rectified Linear Unit
relu,
/// Swish: a Self-Gated Activation Function ([Ramachandran et al., 2017](https://arxiv.org/pdf/1710.05941v1.pdf))
swish,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # GPT2 model configuration
/// Defines the GPT2 model architecture (e.g. number of layers, hidden layer size, vocab size...).
/// Shared between GPT and GPT2 models
@ -181,7 +184,7 @@ pub struct Gpt2Config {
pub attn_pdrop: Option<f64>,
pub embd_pdrop: Option<f64>,
pub hidden_dropout_prob: Option<f64>,
pub afn: Option<GptActivation>,
pub afn: Option<Activation>,
pub initializer_range: f64,
pub layer_norm_epsilon: f64,
pub n_ctx: i64,
@ -197,7 +200,7 @@ pub struct Gpt2Config {
pub vocab_size: i64,
}
impl Config<Gpt2Config> for Gpt2Config {}
impl Config for Gpt2Config {}
/// # GPT2 Base model
/// Base architecture for GPT2 model. Usually complemented with a task-specific head, such as a language model head.
@ -260,10 +263,7 @@ impl Gpt2Model {
Default::default(),
);
let embd_pdrop = match config.embd_pdrop {
Some(value) => value,
None => 0.1,
};
let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
let drop = Dropout::new(embd_pdrop);
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
@ -275,18 +275,10 @@ impl Gpt2Model {
for layer_index in 0..config.n_layer {
h.push(Block::new(&h_path / layer_index, config, true));
}
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false,
};
let output_past = match config.output_past {
Some(value) => value,
None => true,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false,
};
let output_attentions = config.output_attentions.unwrap_or(false);
let output_past = config.output_past.unwrap_or(true);
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
Gpt2Model {
wte,
wpe,
@ -356,12 +348,12 @@ impl Gpt2Model {
/// let model_output = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
/// &Some(past),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// Some(&input_tensor),
/// Some(&past),
/// Some(&attention_mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// false,
/// )
/// .unwrap()
@ -369,35 +361,20 @@ impl Gpt2Model {
/// ```
pub fn forward_t(
&self,
input_ids: &Option<Tensor>,
layer_past: &Option<Vec<Tensor>>,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
input_ids: Option<&Tensor>,
layer_past: Option<&Vec<Tensor>>,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<Gpt2ModelOutput, RustBertError> {
let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => {
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
input_value.apply(&self.wte),
*input_value.size().last().unwrap(),
),
},
None => match input_embeds {
Some(embeds) => (embeds.copy(), embeds.size()[1]),
None => {
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
let (calc_input_embeddings, input_size, _) =
process_ids_embeddings_pair(input_ids, input_embeds, &self.wte)?;
let input_embeddings =
input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let seq_length = input_size[1];
let (layer_past, layer_past_length) = match layer_past {
Some(value) => {
@ -423,7 +400,7 @@ impl Gpt2Model {
let position_ids = match position_ids {
Some(value) => value.copy(),
None => Tensor::arange1(
None => Tensor::arange_start(
layer_past_length,
seq_length + layer_past_length,
(Int64, input_embeddings.device()),
@ -431,17 +408,16 @@ impl Gpt2Model {
.unsqueeze(0),
};
let attention_mask: Option<Tensor> = match attention_mask {
Some(value) => Some(
(value
.view((input_embeddings.size()[0], -1))
.unsqueeze(1)
.unsqueeze(2)
- 1.0)
* 10000.0,
),
None => None,
};
let attention_mask: Option<Tensor> = attention_mask.map(|value| {
let attention_mask = value
.view((input_embeddings.size()[0], -1))
.unsqueeze(1)
.unsqueeze(2)
.to_kind(input_embeddings.kind());
let attention_mask: Tensor = (1.0 - attention_mask) * (-10000.0);
attention_mask.to_kind(input_embeddings.kind())
});
let position_embeds = position_ids.apply(&self.wpe);
let token_type_embeds = match token_type_ids {
@ -466,11 +442,8 @@ impl Gpt2Model {
let layer_iter = self.h.iter().zip(layer_past);
for layer_values in layer_iter {
let (layer, past) = layer_values;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &past, &attention_mask, train);
let temp =
layer.forward_t(&hidden_state, past.as_ref(), attention_mask.as_ref(), train);
hidden_state = temp.0;
if let Some(presents) = all_presents.borrow_mut() {
presents.push(temp.1.as_ref().copy());
@ -478,6 +451,9 @@ impl Gpt2Model {
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(temp.2.as_ref().unwrap().copy());
};
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
}
Ok(Gpt2ModelOutput {
@ -493,10 +469,8 @@ impl Gpt2Model {
/// GPT2 model with a decoding head (linear layer without bias). The weights of the linear layer are tied to the word embeddings
/// It is made of the following blocks:
/// - `transformer`: Base Gpt2Model
/// - `lm_head`: Linear layer without bias tied to the weights of the token id embeddings
pub struct GPT2LMHeadModel {
transformer: Gpt2Model,
lm_head: LinearNoBias,
}
impl GPT2LMHeadModel {
@ -528,16 +502,8 @@ impl GPT2LMHeadModel {
let p = p.borrow();
let transformer = Gpt2Model::new(p, config);
let lm_head = linear_no_bias(
p / "lm_head",
config.n_embd,
config.vocab_size,
Default::default(),
);
GPT2LMHeadModel {
transformer,
lm_head,
}
GPT2LMHeadModel { transformer }
}
}
@ -562,9 +528,6 @@ impl LMHeadModel for GPT2LMHeadModel {
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `cache` - `Gpt2Cache` made of `Option<Vec<Tensor>>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*)
/// - `encoder_hidden_states` - None
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -574,7 +537,7 @@ impl LMHeadModel for GPT2LMHeadModel {
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
/// use rust_bert::pipelines::generation::{Cache, LMHeadModel};
/// use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
@ -604,14 +567,14 @@ impl LMHeadModel for GPT2LMHeadModel {
/// let model_output = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
/// Some(&input_tensor),
/// Cache::GPT2Cache(Some(past)),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// Some(&attention_mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// None,
/// None,
/// &None,
/// false,
/// )
/// .unwrap()
@ -619,20 +582,20 @@ impl LMHeadModel for GPT2LMHeadModel {
/// ```
fn forward_t(
&self,
input_ids: &Option<Tensor>,
input_ids: Option<&Tensor>,
layer_past: Cache,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
_decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match layer_past {
Cache::GPT2Cache(layer_past) => self.transformer.forward_t(
input_ids,
&layer_past,
layer_past.as_ref(),
attention_mask,
token_type_ids,
position_ids,
@ -641,7 +604,7 @@ impl LMHeadModel for GPT2LMHeadModel {
),
Cache::None => self.transformer.forward_t(
input_ids,
&None,
None,
attention_mask,
token_type_ids,
position_ids,
@ -655,13 +618,12 @@ impl LMHeadModel for GPT2LMHeadModel {
}
}?;
let lm_logits = base_model_output.output.apply(&self.lm_head);
let lm_logits = base_model_output
.output
.linear::<Tensor>(&self.transformer.wte.ws, None);
Ok(LMModelOutput {
lm_logits,
encoder_hidden_state: None,
cache: Cache::GPT2Cache(base_model_output.cache),
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
})
}
}
@ -678,3 +640,197 @@ pub struct Gpt2ModelOutput {
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// # Language generation model based on the GPT2 architecture
pub struct GPT2Generator {
model: GPT2LMHeadModel,
tokenizer: TokenizerOption,
var_store: nn::VarStore,
generate_config: GenerateConfig,
bos_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
pad_token_id: Option<i64>,
is_encoder_decoder: bool,
vocab_size: i64,
decoder_start_id: Option<i64>,
max_position_embeddings: i64,
}
impl GPT2Generator {
/// Build a new `GPT2Generator`
///
/// # Arguments
///
/// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::gpt2::GPT2Generator;
/// use rust_bert::pipelines::generation_utils::GenerateConfig;
///
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
/// # Ok(())
/// # }
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<GPT2Generator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let tokenizer = TokenizerOption::from_file(
ModelType::GPT2,
vocab_path.to_str().unwrap(),
Some(merges_path.to_str().unwrap()),
false,
None,
None,
)?;
let config = Gpt2Config::from_file(config_path);
let model = GPT2LMHeadModel::new(&var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(tokenizer.convert_tokens_to_ids(&[Gpt2Vocab::bos_value()])[0]);
let eos_token_ids = Some(tokenizer.convert_tokens_to_ids(&[Gpt2Vocab::eos_value()]));
let pad_token_id = Some(tokenizer.convert_tokens_to_ids(&[Gpt2Vocab::eos_value()])[0]);
let max_position_embeddings = config.n_positions;
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = None;
Ok(GPT2Generator {
model,
tokenizer,
var_store,
generate_config,
bos_token_id,
eos_token_ids,
pad_token_id,
is_encoder_decoder,
vocab_size,
decoder_start_id,
max_position_embeddings,
})
}
}
impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT2Generator {
fn get_model(&self) -> &GPT2LMHeadModel {
&self.model
}
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
fn get_bos_id(&self) -> &Option<i64> {
&self.bos_token_id
}
fn get_eos_ids(&self) -> &Option<Vec<i64>> {
&self.eos_token_ids
}
fn get_pad_id(&self) -> &Option<i64> {
&self.pad_token_id
}
fn is_encoder_decoder(&self) -> bool {
self.is_encoder_decoder
}
fn get_vocab_size(&self) -> i64 {
self.vocab_size
}
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
}
fn prepare_inputs_for_generation<'a>(
&self,
input_ids: Tensor,
_encoder_outputs: Option<&'a Tensor>,
past: Cache,
attention_mask: Tensor,
) -> PreparedInput<'a> {
let position_ids = (attention_mask.totype(Kind::Int64).cumsum(-1, Kind::Int64) - 1)
.masked_fill(&attention_mask.eq(0), 1);
match past {
Cache::GPT2Cache(past) => {
if past.is_some() {
PreparedInput {
prepared_input: Some(input_ids.select(1, -1).unsqueeze(-1)),
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: None,
prepared_decoder_input: None,
prepared_position_ids: Some(position_ids.select(1, -1).unsqueeze(-1)),
prepared_past: Cache::GPT2Cache(past),
}
} else {
PreparedInput {
prepared_input: Some(input_ids),
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: None,
prepared_decoder_input: None,
prepared_position_ids: Some(position_ids),
prepared_past: Cache::GPT2Cache(None),
}
}
}
Cache::None => PreparedInput {
prepared_input: Some(input_ids),
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: None,
prepared_decoder_input: None,
prepared_position_ids: Some(position_ids),
prepared_past: Cache::GPT2Cache(None),
},
_ => panic!("Cache type incompatible with GPT2"),
}
}
fn reorder_cache(
&self,
past: &mut Cache,
_encoder_outputs: Option<Tensor>,
beam_indices: &Tensor,
) -> Option<Tensor> {
match past {
Cache::GPT2Cache(cached_decoder_state) => match cached_decoder_state {
Some(value) => {
for layer_past in value.iter_mut() {
*layer_past = layer_past.index_select(1, beam_indices);
}
None
}
None => None,
},
Cache::None => None,
_ => {
panic!("Invalid cache for GPT2 model");
}
}
}
}
impl LanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT2Generator {}

View File

@ -1,12 +1,12 @@
//! # GPT2 (Radford et al.)
//!
//! Implementation of the GPT2 language model ([Language Models are Unsupervised Multitask Learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) Radford, Wu, Child, Luan, Amodei, Sutskever 2019).
//! The base model is implemented in the `gpt2::Gpt2Model` struct. The model also includes a language model head: `gpt2::GPT2LMHeadModel`
//! implementing the common `generation::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
//! The base model is implemented in the `gpt2_model::Gpt2Model` struct. The model also includes a language model head: `gpt2_model::GPT2LMHeadModel`
//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/summarization.rs`, run with `cargo run --example gpt2`.
//! A full working example is provided in `examples/generation_gpt2`, run with `cargo run --example generation_gpt2`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
@ -16,12 +16,12 @@
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use rust_tokenizers::Gpt2Tokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::Gpt2Tokenizer;
//!
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
@ -60,6 +60,6 @@ mod gpt2_model;
pub(crate) mod transformer;
pub use gpt2_model::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2Model,
Gpt2ModelOutput, Gpt2ModelResources, Gpt2VocabResources, GptActivation,
GPT2Generator, GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources,
Gpt2Model, Gpt2ModelOutput, Gpt2ModelResources, Gpt2VocabResources,
};

View File

@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::activations::{_gelu_new, _relu, _swish};
use crate::common::activations::{Activation, TensorFunction};
use crate::common::dropout::Dropout;
use crate::gpt2::attention::{Attention, GPTConv1D};
use crate::gpt2::gpt2_model::{Gpt2Config, GptActivation};
use crate::gpt2::gpt2_model::Gpt2Config;
use std::borrow::Borrow;
use tch::{nn, Tensor};
pub struct MLP {
c_fc: GPTConv1D,
c_proj: GPTConv1D,
activation: Box<dyn Fn(&Tensor) -> Tensor>,
activation: TensorFunction,
dropout: Dropout,
}
@ -35,18 +35,15 @@ impl MLP {
let c_fc = GPTConv1D::new(p / "c_fc", config.n_embd * 4, config.n_embd);
let c_proj = GPTConv1D::new(p / "c_proj", config.n_embd, config.n_embd * 4);
let activation = Box::new(match &config.afn {
let activation = match &config.afn {
Some(activation_enum) => match activation_enum {
GptActivation::gelu => _gelu_new,
GptActivation::relu => _relu,
GptActivation::swish => _swish,
Activation::gelu => &Activation::gelu_new,
default => default,
},
None => _gelu_new,
});
let resid_pdrop = match config.resid_pdrop {
Some(value) => value,
None => 0.1,
};
None => &Activation::gelu_new,
}
.get_function();
let resid_pdrop = config.resid_pdrop.unwrap_or(0.1);
let dropout = Dropout::new(resid_pdrop);
MLP {
c_fc,
@ -57,7 +54,7 @@ impl MLP {
}
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
let h = (self.activation)(&x.apply(&self.c_fc));
let h = (self.activation.get_fn())(&x.apply(&self.c_fc));
h.apply(&self.c_proj).apply_t(&self.dropout, train)
}
}
@ -96,8 +93,8 @@ impl Block {
pub fn forward_t(
&self,
x: &Tensor,
layer_past: &Option<Tensor>,
attention_mask: &Option<Tensor>,
layer_past: Option<&Tensor>,
attention_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Tensor>) {
let (output, present, attentions) =

603
src/gpt_neo/attention.rs Normal file
View File

@ -0,0 +1,603 @@
// Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::gpt_neo::gpt_neo_model::AttentionLayerType;
use crate::gpt_neo::GptNeoConfig;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::{nn, Device, Kind, Tensor};
#[derive(Debug)]
/// # Cache for GPT-Neo attention layers
/// Stores the cached value of key and value
pub struct LayerState {
/// Cached keys
pub prev_key: Tensor,
/// Cached values
pub prev_value: Option<Tensor>,
}
impl Clone for LayerState {
fn clone(&self) -> Self {
LayerState {
prev_key: self.prev_key.copy(),
prev_value: self.prev_value.as_ref().map(|value| value.copy()),
}
}
}
impl LayerState {
pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
self.prev_key = self.prev_key.index_select(0, new_indices);
self.prev_value = self
.prev_value
.as_ref()
.map(|value| value.index_select(0, new_indices));
}
}
pub(crate) trait GptNeoAttentionUtils {
fn get_block_length_and_num_blocks(sequence_length: i64, window_size: i64) -> (i64, i64) {
let mut block_length = window_size;
while sequence_length % block_length != 0 {
block_length -= 1;
}
let num_blocks = sequence_length / block_length;
(block_length, num_blocks)
}
fn look_back(
input_tensor: &Tensor,
block_length: i64,
window_size: i64,
pad_value: Option<i64>,
is_key_value: bool,
) -> Result<Tensor, RustBertError> {
let padding_size = match input_tensor.size().len() {
3 => Vec::from([0, 0, window_size, 0]),
2 => Vec::from([window_size, 0]),
_ => {
return Err(RustBertError::ValueError(format!(
"Invalid tensor rank, expected 2 or 3, got {}",
input_tensor.size().len()
)));
}
};
let mut padded_tensor = match pad_value {
None => input_tensor.constant_pad_nd(padding_size.as_slice()),
Some(value) => {
if value == 0 {
input_tensor.constant_pad_nd(padding_size.as_slice())
} else {
(input_tensor - value).constant_pad_nd(padding_size.as_slice()) + value
}
}
};
padded_tensor = padded_tensor.unfold(1, window_size + block_length, block_length);
if is_key_value {
padded_tensor = padded_tensor.transpose(-2, -1);
}
Ok(padded_tensor)
}
fn split_sequence_length_dim_to(
input_tensor: &Tensor,
dim_factor_1: i64,
dim_factor_2: i64,
) -> Result<Tensor, RustBertError> {
let batch_size = input_tensor.size()[0];
let mut split_dim_shape = Vec::from([batch_size, dim_factor_1, dim_factor_2]);
Ok(match input_tensor.size().len() {
3 => {
split_dim_shape.push(-1);
input_tensor.reshape(split_dim_shape.as_slice())
}
2 => input_tensor.reshape(split_dim_shape.as_slice()),
_ => {
return Err(RustBertError::ValueError(format!(
"Invalid tensor rank, expected 2 or 3, got {}",
input_tensor.size().len()
)));
}
})
}
fn create_local_attention_mask(
batch_size: i64,
sequence_length: i64,
window_size: i64,
device: Device,
attention_mask: Option<&Tensor>,
) -> Result<Tensor, RustBertError> {
let (block_length, num_blocks) =
Self::get_block_length_and_num_blocks(sequence_length, window_size);
let indices =
Tensor::arange(sequence_length, (Kind::Int64, device)).repeat(&[batch_size, 1]);
let query_indices = Self::split_sequence_length_dim_to(&indices, num_blocks, block_length)?;
let key_indices = Self::look_back(&indices, block_length, window_size, None, false)?;
let causal_mask = query_indices
.unsqueeze(-1)
.ge_tensor(&key_indices.unsqueeze(-2));
let calc_attention_mask = if attention_mask.is_none() {
Some(Tensor::ones(
&[batch_size, sequence_length],
(Kind::Int64, device),
))
} else {
None
};
let attention_mask =
attention_mask.unwrap_or_else(|| calc_attention_mask.as_ref().unwrap());
let attention_mask =
Self::look_back(attention_mask, block_length, window_size, None, false)?.unsqueeze(-2);
let causal_mask = causal_mask * attention_mask;
let relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1);
let visible = relative_position.gt(-window_size);
let causal_mask = causal_mask * visible;
Ok(causal_mask.unsqueeze(-3).to_kind(Kind::Bool))
}
fn split_heads(
input_tensor: &Tensor,
num_heads: i64,
attention_head_size: i64,
) -> Result<Tensor, RustBertError> {
let mut new_shape = input_tensor.size();
let _ = new_shape.pop();
new_shape.extend_from_slice(&[num_heads, attention_head_size]);
let reshaped_tensor = input_tensor.view(new_shape.as_slice());
Ok(match reshaped_tensor.size().len() {
5 => reshaped_tensor.permute(&[0, 1, 3, 2, 4]),
4 => reshaped_tensor.permute(&[0, 2, 1, 3]),
_ => {
return Err(RustBertError::ValueError(format!(
"Invalid tensor rank, expected 4 or 5, got {}",
input_tensor.size().len()
)));
}
})
}
fn merge_heads(
input_tensor: &Tensor,
num_heads: i64,
attention_head_size: i64,
) -> Result<Tensor, RustBertError> {
let output_tensor = match input_tensor.size().len() {
5 => input_tensor.permute(&[0, 1, 3, 2, 4]).contiguous(),
4 => input_tensor.permute(&[0, 2, 1, 3]).contiguous(),
_ => {
return Err(RustBertError::ValueError(format!(
"Invalid tensor rank, expected 4 or 5, got {}",
input_tensor.size().len()
)));
}
};
let mut new_shape = output_tensor.size();
new_shape.truncate(new_shape.len() - 2);
new_shape.push(num_heads * attention_head_size);
Ok(output_tensor.view(new_shape.as_slice()))
}
fn attend(
query: &Tensor,
key: &Tensor,
value: &Tensor,
causal_mask: &Tensor,
attention_dropout: &Dropout,
attention_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Tensor) {
let query = query.to_kind(Kind::Float);
let key = key.to_kind(Kind::Float);
let attention_weights = query.matmul(&key.transpose(-1, -2));
let mut attention_weights = attention_weights.where_self(
causal_mask,
&Tensor::of_slice(&[-1e9f32]).to_device(attention_weights.device()),
);
if let Some(attention_mask_value) = attention_mask {
attention_weights = attention_weights + attention_mask_value;
};
let attention_weights = attention_weights.softmax(-1, attention_weights.kind());
let attention_weights = attention_weights
.to_kind(value.kind())
.apply_t(attention_dropout, train);
let attention_output = attention_weights.matmul(value);
(attention_output, attention_weights)
}
}
pub struct GptNeoSelfAttention {
k_proj: nn::Linear,
v_proj: nn::Linear,
q_proj: nn::Linear,
out_proj: nn::Linear,
attention_dropout: Dropout,
resid_dropout: Dropout,
bias: Tensor,
num_heads: i64,
head_dim: i64,
output_attentions: bool,
}
impl GptNeoAttentionUtils for GptNeoSelfAttention {}
impl GptNeoSelfAttention {
pub fn new<'p, P>(p: P, config: &GptNeoConfig) -> GptNeoSelfAttention
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let max_positions = config.max_position_embeddings;
let bias_value = Tensor::ones(&[max_positions, max_positions], (Kind::Int8, p.device()))
.tril(0)
.view([1, 1, max_positions, max_positions])
.requires_grad_(false);
let bias = p.var_copy("bias", &bias_value);
let attention_dropout = Dropout::new(config.attention_dropout);
let resid_dropout = Dropout::new(config.resid_dropout);
let num_heads = config.num_heads;
let head_dim = config.hidden_size / config.num_heads;
let linear_config = nn::LinearConfig {
bias: false,
..Default::default()
};
let k_proj = nn::linear(
p / "k_proj",
config.hidden_size,
config.hidden_size,
linear_config,
);
let v_proj = nn::linear(
p / "v_proj",
config.hidden_size,
config.hidden_size,
linear_config,
);
let q_proj = nn::linear(
p / "q_proj",
config.hidden_size,
config.hidden_size,
linear_config,
);
let out_proj = nn::linear(
p / "out_proj",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let output_attentions = config.output_attentions.unwrap_or(false);
GptNeoSelfAttention {
k_proj,
v_proj,
q_proj,
out_proj,
attention_dropout,
resid_dropout,
bias,
num_heads,
head_dim,
output_attentions,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
layer_state: Option<&LayerState>,
attention_mask: Option<&Tensor>,
train: bool,
) -> Result<(Tensor, Option<Tensor>, Option<LayerState>), RustBertError> {
let query = hidden_states.apply(&self.q_proj);
let key = hidden_states.apply(&self.k_proj);
let value = hidden_states.apply(&self.v_proj);
let query = Self::split_heads(&query, self.num_heads, self.head_dim)?;
let mut key = Self::split_heads(&key, self.num_heads, self.head_dim)?;
let mut value = Self::split_heads(&value, self.num_heads, self.head_dim)?;
if let Some(layer_state_value) = &layer_state {
key = Tensor::cat(&[&layer_state_value.prev_key, &key], -2);
value = Tensor::cat(
&[layer_state_value.prev_value.as_ref().unwrap(), &value],
-2,
);
};
let layer_state = Some(LayerState {
prev_key: key.copy(),
prev_value: Some(value.copy()),
});
let query_dims = query.size();
let key_dims = key.size();
let query_length = query_dims[query_dims.len() - 2];
let key_length = key_dims[key_dims.len() - 2];
let causal_mask = self
.bias
.slice(2, key_length - query_length, key_length, 1)
.slice(3, 0, key_length, 1)
.to_kind(Kind::Bool);
let (attention_output, attention_weights) = Self::attend(
&query,
&key,
&value,
&causal_mask,
&self.attention_dropout,
attention_mask,
train,
);
let attention_output = Self::merge_heads(&attention_output, self.num_heads, self.head_dim)?
.apply(&self.out_proj)
.apply_t(&self.resid_dropout, train);
let attention_weights = if self.output_attentions {
Some(attention_weights)
} else {
None
};
Ok((attention_output, attention_weights, layer_state))
}
}
pub struct GptNeoLocalSelfAttention {
k_proj: nn::Linear,
v_proj: nn::Linear,
q_proj: nn::Linear,
out_proj: nn::Linear,
attention_dropout: Dropout,
resid_dropout: Dropout,
num_heads: i64,
head_dim: i64,
window_size: i64,
embed_dim: i64,
output_attentions: bool,
}
impl GptNeoAttentionUtils for GptNeoLocalSelfAttention {}
impl GptNeoLocalSelfAttention {
pub fn new<'p, P>(p: P, config: &GptNeoConfig) -> GptNeoLocalSelfAttention
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let attention_dropout = Dropout::new(config.attention_dropout);
let resid_dropout = Dropout::new(config.resid_dropout);
let num_heads = config.num_heads;
let head_dim = config.hidden_size / config.num_heads;
let linear_config = nn::LinearConfig {
bias: false,
..Default::default()
};
let k_proj = nn::linear(
p / "k_proj",
config.hidden_size,
config.hidden_size,
linear_config,
);
let v_proj = nn::linear(
p / "v_proj",
config.hidden_size,
config.hidden_size,
linear_config,
);
let q_proj = nn::linear(
p / "q_proj",
config.hidden_size,
config.hidden_size,
linear_config,
);
let out_proj = nn::linear(
p / "out_proj",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let window_size = config.window_size;
let embed_dim = config.hidden_size;
let output_attentions = config.output_attentions.unwrap_or(false);
GptNeoLocalSelfAttention {
k_proj,
v_proj,
q_proj,
out_proj,
attention_dropout,
resid_dropout,
num_heads,
head_dim,
window_size,
embed_dim,
output_attentions,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
layer_state: Option<&LayerState>,
attention_mask: &Tensor,
train: bool,
) -> Result<(Tensor, Option<Tensor>), RustBertError> {
let query = hidden_states.apply(&self.q_proj);
let (calc_key_value_hidden_states, past_length) =
if let Some(layer_state_value) = layer_state {
let key_value_hidden_states =
Tensor::cat(&[&layer_state_value.prev_key, hidden_states], 1);
(
Some(key_value_hidden_states),
layer_state_value.prev_key.size()[1],
)
} else {
(None, 0)
};
let key_value_hidden_states = calc_key_value_hidden_states
.as_ref()
.unwrap_or(hidden_states);
let key = key_value_hidden_states.apply(&self.k_proj);
let value = key_value_hidden_states.apply(&self.v_proj);
let hidden_states_shape = hidden_states.size();
let (batch_size, sequence_length) = (hidden_states_shape[0], hidden_states_shape[1]);
let full_sequence_length = sequence_length + past_length;
let (block_length, num_blocks) =
Self::get_block_length_and_num_blocks(full_sequence_length, self.window_size);
let query = if layer_state.is_some() {
Self::split_sequence_length_dim_to(&query, 1, 1)
} else {
Self::split_sequence_length_dim_to(&query, num_blocks, block_length)
}?;
let mut key = Self::look_back(&key, block_length, self.window_size, None, true)?;
let mut value = Self::look_back(&value, block_length, self.window_size, None, true)?;
if layer_state.is_some() {
key = key.narrow(1, -1, 1);
value = value.narrow(1, -1, 1);
}
let query = Self::split_heads(&query, self.num_heads, self.head_dim)?;
let key = Self::split_heads(&key, self.num_heads, self.head_dim)?;
let value = Self::split_heads(&value, self.num_heads, self.head_dim)?;
let calc_attention_mask = if layer_state.is_some() {
Some(attention_mask.narrow(3, -1, 1).narrow(1, -1, 1))
} else {
None
};
let attention_mask = calc_attention_mask.as_ref().unwrap_or(attention_mask);
let (attention_output, attention_weights) = Self::attend(
&query,
&key,
&value,
attention_mask,
&self.attention_dropout,
None,
train,
);
let attention_output = Self::merge_heads(&attention_output, self.num_heads, self.head_dim)?
.reshape(&[batch_size, sequence_length, self.embed_dim])
.apply(&self.out_proj)
.apply_t(&self.resid_dropout, train);
let attention_weights = if self.output_attentions {
Some(attention_weights)
} else {
None
};
Ok((attention_output, attention_weights))
}
}
pub enum GptNeoAttention {
SelfAttention(GptNeoSelfAttention),
LocalSelfAttention(GptNeoLocalSelfAttention),
}
impl GptNeoAttention {
pub fn new<'p, P>(p: P, config: &GptNeoConfig, layer_id: usize) -> Result<Self, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let attention_type = &config.attention_layers[layer_id];
Ok(match attention_type {
AttentionLayerType::Global => {
GptNeoAttention::SelfAttention(GptNeoSelfAttention::new(p / "attention", config))
}
AttentionLayerType::Local => GptNeoAttention::LocalSelfAttention(
GptNeoLocalSelfAttention::new(p / "attention", config),
),
})
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
layer_state: Option<&LayerState>,
attention_mask: Option<&Tensor>,
train: bool,
) -> Result<(Tensor, Option<Tensor>, Option<LayerState>), RustBertError> {
let layer_output = match self {
GptNeoAttention::SelfAttention(ref attention) => {
attention.forward_t(hidden_states, layer_state, attention_mask, train)?
}
GptNeoAttention::LocalSelfAttention(ref attention) => {
let output = attention.forward_t(
hidden_states,
layer_state,
attention_mask.ok_or_else(|| {
RustBertError::ValueError(
"Attention mask must be provided for Local self attention".to_string(),
)
})?,
train,
)?;
let new_layer_state = if let Some(old_layer_state) = layer_state {
LayerState {
prev_key: Tensor::cat(&[&old_layer_state.prev_key, hidden_states], 1),
prev_value: None,
}
} else {
LayerState {
prev_key: hidden_states.copy(),
prev_value: None,
}
};
(output.0, output.1, Some(new_layer_state))
}
};
Ok(layer_output)
}
}

133
src/gpt_neo/decoder.rs Normal file
View File

@ -0,0 +1,133 @@
// Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::activations::TensorFunction;
use crate::common::dropout::Dropout;
use crate::gpt_neo::attention::{GptNeoAttention, LayerState};
use crate::gpt_neo::GptNeoConfig;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::ModuleT;
use tch::{nn, Tensor};
#[derive(Debug)]
pub struct GptNeoMLP {
c_fc: nn::Linear,
c_proj: nn::Linear,
activation_function: TensorFunction,
dropout: Dropout,
}
impl GptNeoMLP {
pub fn new<'p, P>(p: P, intermediate_size: i64, config: &GptNeoConfig) -> GptNeoMLP
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let c_fc = nn::linear(
p / "c_fc",
config.hidden_size,
intermediate_size,
Default::default(),
);
let c_proj = nn::linear(
p / "c_proj",
intermediate_size,
config.hidden_size,
Default::default(),
);
let activation_function = config.activation_function.get_function();
let dropout = Dropout::new(config.resid_dropout);
GptNeoMLP {
c_fc,
c_proj,
activation_function,
dropout,
}
}
}
impl ModuleT for GptNeoMLP {
fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
let hidden_states = hidden_states.apply(&self.c_fc);
let hidden_states = self.activation_function.get_fn()(&hidden_states);
hidden_states
.apply(&self.c_proj)
.apply_t(&self.dropout, train)
}
}
pub struct GptNeoBlock {
ln_1: nn::LayerNorm,
ln_2: nn::LayerNorm,
attention: GptNeoAttention,
mlp: GptNeoMLP,
}
impl GptNeoBlock {
pub fn new<'p, P>(
p: P,
layer_id: usize,
config: &GptNeoConfig,
) -> Result<GptNeoBlock, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.hidden_size], layer_norm_config);
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.hidden_size], layer_norm_config);
let attention = GptNeoAttention::new(p / "attn", config, layer_id)?;
let inner_dim = config.intermediate_size.unwrap_or(4 * config.hidden_size);
let mlp = GptNeoMLP::new(p / "mlp", inner_dim, config);
Ok(GptNeoBlock {
ln_1,
ln_2,
attention,
mlp,
})
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
layer_state: Option<&LayerState>,
attention_mask: Option<&Tensor>,
train: bool,
) -> Result<(Tensor, Option<Tensor>, Option<LayerState>), RustBertError> {
let intermediate = hidden_states.apply(&self.ln_1);
let (intermediate, attention_weights, layer_state) =
self.attention
.forward_t(&intermediate, layer_state, attention_mask, train)?;
let hidden_states = hidden_states + intermediate;
let intermediate = hidden_states.apply(&self.ln_2).apply_t(&self.mlp, train);
let output = hidden_states + intermediate;
Ok((output, attention_weights, layer_state))
}
pub(crate) fn get_attention_type(&self) -> &GptNeoAttention {
&self.attention
}
}

View File

@ -0,0 +1,812 @@
// Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::gpt_neo::attention::{GptNeoAttention, GptNeoAttentionUtils};
use crate::gpt_neo::decoder::GptNeoBlock;
use crate::gpt_neo::LayerState;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::{Activation, Config, RustBertError};
use rust_tokenizers::tokenizer::Gpt2Tokenizer;
use rust_tokenizers::vocab::Gpt2Vocab;
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Kind, Tensor};
/// # GPT-Neo Pretrained model weight files
pub struct GptNeoModelResources;
/// # GPT-Neo Pretrained model config files
pub struct GptNeoConfigResources;
/// # GPT-Neo Pretrained model vocab files
pub struct GptNeoVocabResources;
/// # GPT-Neo Pretrained model merges files
pub struct GptNeoMergesResources;
impl GptNeoModelResources {
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_125M: (&'static str, &'static str) = (
"gpt-neo-125M/model",
"https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
"gpt-neo-1_3B/model",
"https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
"gpt-neo-2_7B/model",
"https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/rust_model.ot",
);
}
impl GptNeoConfigResources {
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_125M: (&'static str, &'static str) = (
"gpt-neo-125M/config",
"https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
"gpt-neo-1_3B/config",
"https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
"gpt-neo-2_7B/config",
"https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/config.json",
);
}
impl GptNeoVocabResources {
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT_NEO_125M: (&'static str, &'static str) = (
"gpt-neo-125M/vocab",
"https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
"gpt-neo-1_3B/vocab",
"https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at <https://github.com/openai/gpt-2/blob/master/LICENSE>. Modified with conversion to C-array format.
pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
"gpt-neo-2_7B/vocab",
"https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/vocab.json",
);
}
impl GptNeoMergesResources {
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_125M: (&'static str, &'static str) = (
"gpt-neo-125M/merges",
"https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/merges.txt",
);
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
"gpt-neo-1_3B/merges",
"https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/merges.txt",
);
/// Shared under Apache 2.0 license by the EleutherAI contributors at <https://www.eleuther.ai>. Modified with conversion to C-array format.
pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
"gpt-neo-2_7B/merges",
"https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/merges.txt",
);
}
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[serde(rename_all = "camelCase")]
/// #GPT-Neo attention layer type
pub enum AttentionLayerType {
Global,
Local,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # GPT-Neo model configuration
/// Defines the GPT-Neo model architecture (e.g. number of layers, hidden layer size, vocab size...).
pub struct GptNeoConfig {
pub activation_function: Activation,
pub attention_dropout: f64,
pub attention_layers: Vec<AttentionLayerType>,
pub attention_types: Vec<(Vec<AttentionLayerType>, i64)>,
pub intermediate_size: Option<i64>,
pub bos_token_id: i64,
pub eos_token_id: i64,
pub vocab_size: i64,
pub num_layers: i64,
pub num_heads: i64,
pub hidden_size: i64,
pub window_size: i64,
pub embed_dropout: f64,
pub initializer_range: f64,
pub layer_norm_epsilon: f64,
pub max_position_embeddings: i64,
pub output_past: Option<bool>,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub resid_dropout: f64,
}
impl Config for GptNeoConfig {}
/// # GPT-Neo Base model
/// Base architecture for GPT-Neo models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `word_embeddings`: Word embeddings
/// - `position_embeddings`: Position embeddings
/// - `layers`: Vector of `GptNeoBlock` (transformer part of the model)
pub struct GptNeoModel {
word_embeddings: nn::Embedding,
position_embeddings: nn::Embedding,
layers: Vec<GptNeoBlock>,
dropout: Dropout,
layer_norm: nn::LayerNorm,
window_size: i64,
output_attentions: bool,
output_hidden_states: bool,
}
impl GptNeoAttentionUtils for GptNeoModel {}
impl GptNeoModel {
/// Build a new `GptNeoModel`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the GPT-Neo model
/// * `config` - `GptNeoConfig` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use rust_bert::gpt_neo::{GptNeoConfig, GptNeoModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = GptNeoConfig::from_file(config_path);
/// let gpt_neo_model = GptNeoModel::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &GptNeoConfig) -> Result<GptNeoModel, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let word_embeddings = nn::embedding(
p / "wte",
config.vocab_size,
config.hidden_size,
Default::default(),
);
let position_embeddings = nn::embedding(
p / "wpe",
config.max_position_embeddings,
config.hidden_size,
Default::default(),
);
let dropout = Dropout::new(config.embed_dropout);
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let layer_norm = nn::layer_norm(p / "ln_f", vec![config.hidden_size], layer_norm_config);
let mut layers: Vec<GptNeoBlock> = Vec::with_capacity(config.num_layers as usize);
let p_layers = p / "h";
for layer_index in 0..config.num_layers {
layers.push(GptNeoBlock::new(
&p_layers / layer_index,
layer_index as usize,
config,
)?);
}
let window_size = config.window_size;
let output_attentions = config.output_attentions.unwrap_or(false);
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
Ok(GptNeoModel {
word_embeddings,
position_embeddings,
layers,
dropout,
layer_norm,
window_size,
output_attentions,
output_hidden_states,
})
}
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
/// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
/// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input.
/// * `layer_states` - Optional Vector `Option<Vec<Option<&LayerState>>>` of length *n_layer* containing tuples with the past keys and values for both the self attention of each layer.
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `Result<GptNeoModelOutput, RustBertError>` containing:
/// - `hidden_states` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
/// - `next_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past content for the the attention layers
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer + 1* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* containing the attention weights for each layer
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt_neo::{GptNeoConfig, GptNeoModel};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = GptNeoConfig::from_file(config_path);
/// # let gpt_neo_model = GptNeoModel::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
///
/// let model_output = no_grad(|| {
/// gpt_neo_model.forward_t(
/// Some(&input_tensor),
/// Some(&attention_mask),
/// None,
/// None,
/// None,
/// None,
/// false,
/// )
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
layer_states: Option<Vec<Option<LayerState>>>,
attention_mask: Option<&Tensor>,
train: bool,
) -> Result<GptNeoModelOutput, RustBertError> {
let (calc_input_embeddings, input_shape, device) =
process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
let (batch_size, current_sequence_length) = (input_shape[0], input_shape[1]);
let past_length = if let Some(past_state_value) = &layer_states {
if let Some(first_layer_state) = &past_state_value[0] {
let mut size_iter = first_layer_state.prev_key.size().into_iter().rev();
size_iter.next();
size_iter.next().unwrap()
} else {
0
}
} else {
0
};
let full_sequence_length = current_sequence_length + past_length;
let calc_position_ids = if position_ids.is_none() {
let position_ids =
Tensor::arange_start(past_length, full_sequence_length, (Kind::Int64, device));
Some(
position_ids
.unsqueeze(0)
.view([-1, current_sequence_length]),
)
} else {
None
};
let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
let local_attention_mask = GptNeoModel::create_local_attention_mask(
batch_size,
full_sequence_length,
self.window_size,
device,
attention_mask,
)?;
let input_embeds = input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let position_embeds = position_ids.apply(&self.position_embeddings);
let global_attention_mask = attention_mask.map(|attention_mask_value| {
let global_attention_mask = attention_mask_value
.view([batch_size, -1])
.unsqueeze(1)
.unsqueeze(1);
let global_attention_mask = global_attention_mask.to_kind(position_embeds.kind());
(1 - global_attention_mask) * -1e4
});
let mut hidden_state = input_embeds + position_embeds;
if let Some(token_type_ids) = token_type_ids {
hidden_state = hidden_state + token_type_ids.apply(&self.word_embeddings);
};
hidden_state = hidden_state.apply_t(&self.dropout, train);
let mut output_shape = input_shape;
output_shape.push(*hidden_state.size().last().unwrap());
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let old_cache = layer_states.unwrap_or_else(|| vec![None; self.layers.len()]);
let mut next_cache = vec![None; self.layers.len()];
let mut x: Option<Tensor> = None;
let mut attention_weights: Option<Tensor>;
for ((layer_idx, layer), layer_state) in
self.layers.iter().enumerate().zip(old_cache.into_iter())
{
let attention_mask = match layer.get_attention_type() {
GptNeoAttention::SelfAttention(_) => global_attention_mask.as_ref(),
GptNeoAttention::LocalSelfAttention(_) => Some(&local_attention_mask),
};
let temp = if let Some(x_value) = &x {
layer.forward_t(x_value, layer_state.as_ref(), attention_mask, train)?
} else {
layer.forward_t(&hidden_state, layer_state.as_ref(), attention_mask, train)?
};
x = Some(temp.0);
attention_weights = temp.1;
next_cache[layer_idx] = temp.2;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(x.as_ref().unwrap().copy());
};
}
let hidden_states = x
.unwrap()
.apply(&self.layer_norm)
.view(output_shape.as_slice());
Ok(GptNeoModelOutput {
hidden_states,
next_cache: Some(next_cache),
all_hidden_states,
all_attentions,
})
}
}
/// # GPT-Neo Model for causal language modeling
/// Gpt-Neo model with a vocabulary decoding head. The language model decoding head is tied to the word embedding matrix weights
/// It is made of the following blocks:
/// - `transformer`: `GptNeoModel` Base ProphetNet model
pub struct GptNeoForCausalLM {
transformer: GptNeoModel,
}
impl GptNeoForCausalLM {
/// Build a new `GptNeoForCausalLM`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the GPT-Neo model
/// * `config` - `GptNeoConfig` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use rust_bert::gpt_neo::{GptNeoConfig, GptNeoForCausalLM};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = GptNeoConfig::from_file(config_path);
/// let gpt_neo_model = GptNeoForCausalLM::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &GptNeoConfig) -> Result<GptNeoForCausalLM, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let transformer = GptNeoModel::new(p / "transformer", config)?;
Ok(GptNeoForCausalLM { transformer })
}
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
/// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
/// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input.
/// * `layer_states` - Optional Vector `Option<Vec<Option<&LayerState>>>` of length *n_layer* containing tuples with the past keys and values for both the self attention of each layer.
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `Result<GptNeoModelLMOutput, RustBertError>` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `next_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past content for the the attention layers
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *n_layer + 1* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *n_layer* containing the attention weights for each layer
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt_neo::{GptNeoConfig, GptNeoForCausalLM};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = GptNeoConfig::from_file(config_path);
/// # let gpt_neo_model = GptNeoForCausalLM::new(&vs.root(), &config).unwrap();
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
///
/// let model_output = no_grad(|| {
/// gpt_neo_model.forward_t(
/// Some(&input_tensor),
/// Some(&attention_mask),
/// None,
/// None,
/// None,
/// None,
/// false,
/// )
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
layer_states: Option<Vec<Option<LayerState>>>,
attention_mask: Option<&Tensor>,
train: bool,
) -> Result<GptNeoModelLMOutput, RustBertError> {
let base_model_output = self.transformer.forward_t(
input_ids,
input_embeds,
token_type_ids,
position_ids,
layer_states,
attention_mask,
train,
)?;
let lm_logits = base_model_output
.hidden_states
.linear::<Tensor>(&self.transformer.word_embeddings.ws, None);
Ok(GptNeoModelLMOutput {
lm_logits,
next_cache: base_model_output.next_cache,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
})
}
}
impl LMHeadModel for GptNeoForCausalLM {
fn forward_t(
&self,
input_ids: Option<&Tensor>,
layer_past: Cache,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match layer_past {
Cache::GPTNeoCache(layer_past) => self.forward_t(
input_ids,
input_embeds,
token_type_ids,
position_ids,
layer_past,
attention_mask,
train,
),
Cache::None => self.forward_t(
input_ids,
input_embeds,
token_type_ids,
position_ids,
None,
attention_mask,
train,
),
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with GPT-Neo Model".into(),
));
}
}?;
Ok(LMModelOutput {
lm_logits: base_model_output.lm_logits,
cache: Cache::GPTNeoCache(base_model_output.next_cache),
})
}
}
/// Container for the GPT-Neo model output.
pub struct GptNeoModelOutput {
/// Last hidden states from the model
pub hidden_states: Tensor,
/// Cached outputs of the model (attention layers keys and values) if the model is used for generation
pub next_cache: Option<Vec<Option<LayerState>>>,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
///Container holding a GPT-Neo model with LM head output
pub struct GptNeoModelLMOutput {
/// logits
pub lm_logits: Tensor,
/// Cached outputs of the model (attention layers keys and values) if the model is used for generation
pub next_cache: Option<Vec<Option<LayerState>>>,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// # Language generation model based on the GPT-Neo architecture
pub struct GptNeoGenerator {
model: GptNeoForCausalLM,
tokenizer: TokenizerOption,
var_store: nn::VarStore,
generate_config: GenerateConfig,
bos_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
pad_token_id: Option<i64>,
is_encoder_decoder: bool,
vocab_size: i64,
decoder_start_id: Option<i64>,
max_position_embeddings: i64,
}
impl GptNeoGenerator {
/// Build a new `GPTNeoGenerator`
///
/// # Arguments
///
/// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU)
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::gpt_neo::GptNeoGenerator;
/// use rust_bert::pipelines::generation_utils::GenerateConfig;
///
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let gpt_neo_generator = GptNeoGenerator::new(generate_config)?;
/// # Ok(())
/// # }
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<GptNeoGenerator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let tokenizer = TokenizerOption::from_file(
ModelType::GPTNeo,
vocab_path.to_str().unwrap(),
Some(merges_path.to_str().unwrap()),
false,
None,
None,
)?;
let config = GptNeoConfig::from_file(config_path);
let model = GptNeoForCausalLM::new(&var_store.root(), &config)?;
var_store.load(weights_path)?;
let bos_token_id = Some(tokenizer.convert_tokens_to_ids(&[Gpt2Vocab::bos_value()])[0]);
let eos_token_ids = Some(tokenizer.convert_tokens_to_ids(&[Gpt2Vocab::eos_value()]));
let pad_token_id = Some(tokenizer.convert_tokens_to_ids(&[Gpt2Vocab::eos_value()])[0]);
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = None;
let max_position_embeddings = config.max_position_embeddings;
Ok(GptNeoGenerator {
model,
tokenizer,
var_store,
generate_config,
bos_token_id,
eos_token_ids,
pad_token_id,
is_encoder_decoder,
vocab_size,
decoder_start_id,
max_position_embeddings,
})
}
}
impl PrivateLanguageGenerator<GptNeoForCausalLM, Gpt2Vocab, Gpt2Tokenizer> for GptNeoGenerator {
fn get_model(&self) -> &GptNeoForCausalLM {
&self.model
}
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
fn get_bos_id(&self) -> &Option<i64> {
&self.bos_token_id
}
fn get_eos_ids(&self) -> &Option<Vec<i64>> {
&self.eos_token_ids
}
fn get_pad_id(&self) -> &Option<i64> {
&self.pad_token_id
}
fn is_encoder_decoder(&self) -> bool {
self.is_encoder_decoder
}
fn get_vocab_size(&self) -> i64 {
self.vocab_size
}
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
}
fn prepare_inputs_for_generation<'a>(
&self,
input_ids: Tensor,
_encoder_outputs: Option<&'a Tensor>,
past: Cache,
attention_mask: Tensor,
) -> PreparedInput<'a> {
let position_ids = (attention_mask.totype(Kind::Int64).cumsum(-1, Kind::Int64) - 1)
.masked_fill(&attention_mask.eq(0), 1);
match past {
Cache::GPTNeoCache(past) => {
if past.is_some() {
PreparedInput {
prepared_input: Some(input_ids.select(1, -1).unsqueeze(-1)),
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: None,
prepared_decoder_input: None,
prepared_position_ids: Some(position_ids.select(1, -1).unsqueeze(-1)),
prepared_past: Cache::GPTNeoCache(past),
}
} else {
PreparedInput {
prepared_input: Some(input_ids),
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: None,
prepared_decoder_input: None,
prepared_position_ids: Some(position_ids),
prepared_past: Cache::GPTNeoCache(None),
}
}
}
Cache::None => PreparedInput {
prepared_input: Some(input_ids),
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: None,
prepared_decoder_input: None,
prepared_position_ids: Some(position_ids),
prepared_past: Cache::GPTNeoCache(None),
},
_ => panic!("Cache type incompatible with GPT-Neo"),
}
}
fn reorder_cache(
&self,
past: &mut Cache,
_encoder_outputs: Option<Tensor>,
beam_indices: &Tensor,
) -> Option<Tensor> {
match past {
Cache::GPTNeoCache(cached_decoder_state) => match cached_decoder_state {
Some(old_cache) => {
for layer_state in old_cache.iter_mut() {
if layer_state.is_some() {
layer_state.as_mut().unwrap().reorder_cache(beam_indices)
};
}
None
}
None => None,
},
Cache::None => None,
_ => {
panic!("Invalid cache for GPT-Neo model");
}
}
}
}
impl LanguageGenerator<GptNeoForCausalLM, Gpt2Vocab, Gpt2Tokenizer> for GptNeoGenerator {}

76
src/gpt_neo/mod.rs Normal file
View File

@ -0,0 +1,76 @@
//! # GPT-Neo
//!
//! Implementation of the GPT-Neo language model ([The Pile: An 800GB Dataset of Diverse Text for Language Modeling](https://arxiv.org/abs/2101.00027) Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and others, 2020).
//! The base model is implemented in the `gpt_neo_model::GptNeoModel` struct. A causal language modeling head is implemented in `gpt_neo_model::GptNeoForCausalLM`
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/generation_gpt_neo`, run with `cargo run --example generation_gpt_neo`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `GPT2Tokenizer` using a `vocab.json` vocabulary and a `merges.txt` merges file
//!
//! The following pre-trained checkpoints are readily available:
//! - 125M parameters model (GptNeoModelResources::GPT_NEO_125M)
//! - 1.3B parameters model (GptNeoModelResources::GPT_NEO_1_3B)
//! - 2.7B parameters model (GptNeoModelResources::GPT_NEO_2_7B)
//!
//! ```no_run
//! use rust_bert::gpt_neo::{
//! GptNeoConfigResources, GptNeoMergesResources, GptNeoModelResources, GptNeoVocabResources,
//! };
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
//! use rust_bert::resources::{RemoteResource, Resource};
//! use tch::Device;
//!
//! fn main() -> anyhow::Result<()> {
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
//! GptNeoConfigResources::GPT_NEO_1_3B,
//! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
//! GptNeoVocabResources::GPT_NEO_1_3B,
//! ));
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
//! GptNeoMergesResources::GPT_NEO_1_3B,
//! ));
//! let model_resource = Resource::Remote(RemoteResource::from_pretrained(
//! GptNeoModelResources::GPT_NEO_1_3B,
//! ));
//!
//! let text_generation_config = TextGenerationConfig {
//! model_type: ModelType::GPTNeo,
//! model_resource,
//! config_resource,
//! vocab_resource,
//! merges_resource,
//! num_beams: 4,
//! no_repeat_ngram_size: 3,
//! device: Device::cuda_if_available(),
//! ..Default::default()
//! };
//! let model = TextGenerationModel::new(text_generation_config)?;
//!
//! let input_context_1 = "It was a very nice and sunny";
//! let input_context_2 = "It was a gloom winter night, and";
//! let output = model.generate(&[input_context_1, input_context_2], None);
//!
//! for sentence in output {
//! println!("{}", sentence);
//! }
//!
//! Ok(())
//! }
//! ```
mod attention;
mod decoder;
mod gpt_neo_model;
pub use gpt_neo_model::{
GptNeoConfig, GptNeoConfigResources, GptNeoForCausalLM, GptNeoGenerator, GptNeoMergesResources,
GptNeoModel, GptNeoModelResources, GptNeoVocabResources,
};
pub use attention::LayerState;

View File

@ -1,22 +1,9 @@
//! # Ready-to-use NLP pipelines and Transformer-based models
//!
//! Rust native Transformer-based models implementation. Port of the [Transformers](https://github.com/huggingface/transformers) library, using the tch-rs crate and pre-processing from rust-tokenizers.
//! Supports multithreaded tokenization and GPU inference. This repository exposes the model base architecture, task-specific heads (see below) and ready-to-use pipelines.
//! Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using the [tch-rs](https://github.com/LaurentMazare/tch-rs) crate and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
//! This repository exposes the model base architecture, task-specific heads (see below) and [ready-to-use pipelines](#ready-to-use-pipelines). [Benchmarks](#benchmarks) are available at the end of this document.
//!
//! # Quick Start
//!
//! This crate can be used in two different ways:
//! - Ready-to-use NLP pipelines for:
//! - Translation
//! - Summarization
//! - Multi-turn dialogue
//! - Zero-shot classification
//! - Sentiment Analysis
//! - Named Entity Recognition
//! - Question-Answering
//! - Language Generation.
//!
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! Get started with tasks including question answering, named entity recognition, translation, summarization, text generation, conversational agents and more in just a few lines of code:
//! ```no_run
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
//!
@ -25,38 +12,571 @@
//!
//! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam");
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
//! let answers = qa_model.predict(&[QaInput { question, context }], 1, 32);
//! # Ok(())
//! # }
//! ```
//!
//! Output:
//! ```no_run
//! # use rust_bert::pipelines::question_answering::Answer;
//! # let output =
//! [Answer {
//! score: 0.9976,
//! start: 13,
//! end: 21,
//! answer: String::from("Amsterdam"),
//! }]
//! # ;
//! ```
//!
//! The tasks currently supported include:
//! - Translation
//! - Summarization
//! - Multi-turn dialogue
//! - Zero-shot classification
//! - Sentiment Analysis
//! - Named Entity Recognition
//! - Part of Speech tagging
//! - Question-Answering
//! - Language Generation.
//!
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust
//!
//! | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**
//! :-----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:
//! Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |
//! Sequence classification|✅ |✅ |✅| | |✅| | |✅ | |
//! Token classification|✅ |✅ | ✅| | | |✅| |✅ | |
//! Question answering|✅ |✅ |✅| | | | | |✅ | |
//! Multiple choices| |✅ |✅| | | | | |✅ | |
//! Next token prediction| | | |✅|✅| | | | | |
//! Natural Language Generation| | | |✅|✅| | | | | |
//! Summarization| | | | | |✅| | | | |
//! Translation| | | | | | | |✅| |✅|
//! <details>
//! <summary> <b> Click to expand to display the supported models/tasks matrix </b> </summary>
//!
//! # Loading pre-trained models
//! | |**Sequence classification**|**Token classification**|**Question answering**|**Text Generation**|**Summarization**|**Translation**|**Masked LM**|
//! :-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:
//! DistilBERT|✅|✅|✅| | | |✅|
//! MobileBERT|✅|✅|✅| | | |✅|
//! FNet|✅|✅|✅| | | |✅|
//! BERT|✅|✅|✅| | | |✅|
//! RoBERTa|✅|✅|✅| | | |✅|
//! GPT| | | |✅ | | | |
//! GPT2| | | |✅ | | | |
//! GPT-Neo| | | |✅ | | | |
//! BART|✅| | |✅ |✅| | |
//! Marian| | | | | |✅| |
//! MBart|✅| | |✅ | | | |
//! M2M100| | | |✅ | | | |
//! Electra | |✅| | | | |✅|
//! ALBERT |✅|✅|✅| | | |✅|
//! T5 | | | |✅ |✅|✅| |
//! XLNet|✅|✅|✅|✅ | | |✅|
//! Reformer|✅| |✅|✅ | | |✅|
//! ProphetNet| | | |✅ |✅ | | |
//! Longformer|✅|✅|✅| | | |✅|
//! Pegasus| | | | |✅| | |
//! </details>
//!
//! A number of pretrained model configuration, weights and vocabulary are downloaded directly from [Huggingface's model repository](https://huggingface.co/models).
//! The list of models available with Rust-compatible weights is available in the example ./examples/download_all_dependencies.rs. Additional models can be added if of interest, please raise an issue.
//! # Getting started
//!
//! In order to load custom weights to the library, these need to be converter to a binary format that can be read by Libtorch (the original `.bin` files are pickles and cannot be used directly).
//! Several Python scripts to load Pytorch weights and convert them to the appropriate format are provided and can be adapted based on the model needs.
//! This library relies on the [tch](https://github.com/LaurentMazare/tch-rs) crate for bindings to the C++ Libtorch API.
//! The libtorch library is required can be downloaded either automatically or manually. The following provides a reference on how to set-up yoru environment
//! to use these bindings, please refer to the [tch](https://github.com/LaurentMazare/tch-rs) for detailed information or support.
//!
//! The procedure for building custom weights or re-building pretrained weights is as follows:
//! 1. Compile the package: cargo build --release
//! 2. Download the model files & perform necessary conversions
//! - Set-up a virtual environment and install dependencies
//! - run the conversion script python /utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py. The dependencies will be downloaded to the user's home directory, under ~/rustbert/{}
//! 3. Run the example cargo run --release
//! Furthermore, this library relies on a cache folder for downloading pre-trained models.
//! This cache location defaults to `~/.cache/.rustbert`, but can be changed by setting the `RUSTBERT_CACHE` environment variable. Note that the language models used by this library are in the order of the 100s of MBs to GBs.
//!
//! ### Manual installation (recommended)
//!
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.10.0`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.10.0%2Bcu111.zip` for a Linux version with CUDA11.
//! 2. Extract the library to a location of your choice
//! 3. Set the following environment variables
//! ##### Linux:
//! ```bash
//! export LIBTORCH=/path/to/libtorch
//! export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
//! ```
//!
//! ##### Windows
//! ```powershell
//! $Env:LIBTORCH = "X:\path\to\libtorch"
//! $Env:Path += ";X:\path\to\libtorch\lib"
//! ```
//!
//! ### Automatic installation
//!
//! Alternatively, you can let the `build` script automatically download the `libtorch` library for you.
//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu111`.
//! Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
//!
//! # Ready-to-use pipelines
//!
//! Based on Hugging Face's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! The following capabilities are currently available:
//!
//! **Disclaimer**
//! The contributors of this repository are not responsible for any generation from the 3rd party utilization of the pretrained systems proposed herein.
//!
//! <details>
//! <summary> <b>1. Question Answering</b> </summary>
//!
//! Extractive question answering from a given question and context. DistilBERT model fine-tuned on SQuAD (Stanford Question Answering Dataset)
//!
//! ```no_run
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
//! # fn main() -> anyhow::Result<()> {
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
//!
//! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam");
//!
//! let answers = qa_model.predict(&[QaInput { question, context }], 1, 32);
//! # Ok(())
//! # }
//! ```
//!
//! Output: \
//! ```no_run
//! # use rust_bert::pipelines::question_answering::Answer;
//! # let output =
//! [Answer {
//! score: 0.9976,
//! start: 13,
//! end: 21,
//! answer: String::from("Amsterdam"),
//! }]
//! # ;
//! ```
//!
//! </details>
//! &nbsp;
//! <details>
//! <summary> <b>2. Translation </b> </summary>
//!
//! Translation pipeline supporting a broad range of source and target languages. Leverages two main architectures for translation tasks:
//! - Marian-based models, for specific source/target combinations
//! - M2M100 models allowing for direct translation between 100 languages (at a higher computational cost and lower performance for some selected languages)
//!
//! Marian-based pretrained models for the following language pairs are readily available in the library - but the user can import any Pytorch-based
//! model for predictions
//! - English <-> French
//! - English <-> Spanish
//! - English <-> Portuguese
//! - English <-> Italian
//! - English <-> Catalan
//! - English <-> German
//! - English <-> Russian
//! - English <-> Chinese
//! - English <-> Dutch
//! - English <-> Swedish
//! - English <-> Arabic
//! - English <-> Hebrew
//! - English <-> Hindi
//! - French <-> German
//!
//! For languages not supported by the proposed pretrained Marian models, the user can leverage a M2M100 model supporting direct translation between 100 languages (without intermediate English translation)
//! The full list of supported languages is available in the [`pipelines` module](./pipelines/translation/enum.Language.html)
//!
//!
//! ```no_run
//! use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
//! fn main() -> anyhow::Result<()> {
//! let model = TranslationModelBuilder::new()
//! .with_source_languages(vec![Language::English])
//! .with_target_languages(vec![Language::Spanish, Language::French, Language::Italian])
//! .create_model()?;
//! let input_text = "This is a sentence to be translated";
//! let output = model.translate(&[input_text], None, Language::Spanish)?;
//! for sentence in output {
//! println!("{}", sentence);
//! }
//! Ok(())
//! }
//! ```
//! Output: \
//! ```no_run
//! # let output =
//! " Il s'agit d'une phrase à traduire"
//! # ;
//! ```
//!
//! </details>
//! &nbsp;
//! <details>
//! <summary> <b>3. Summarization </b> </summary>
//!
//! Abstractive summarization using a pretrained BART model.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! # use rust_bert::pipelines::generation_utils::LanguageGenerator;
//! use rust_bert::pipelines::summarization::SummarizationModel;
//!
//! let mut model = SummarizationModel::new(Default::default())?;
//!
//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
//! from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
//! from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
//! a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
//! habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
//! used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
//! passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
//! weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
//! contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
//! and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
//! but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
//! \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
//! said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
//! said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors.
//! \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
//! a potentially habitable planet, but further observations will be required to say for sure. \"
//! K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
//! but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
//! on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
//! telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
//! about exoplanets like K2-18b."];
//!
//! let output = model.summarize(&input);
//! # Ok(())
//! # }
//! ```
//! (example from: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
//!
//! Example output: \
//! ```no_run
//! # let output =
//! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
//! This is the first such discovery in a planet in its star's habitable zone.
//! The planet is not too hot and not too cold for liquid water to exist."
//! # ;
//! ```
//!
//! </details>
//! &nbsp;
//! <details>
//! <summary> <b>4. Dialogue Model </b> </summary>
//!
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
//! The DialoGPT's page states that
//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
//! > under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
//!
//! The model uses a `ConversationManager` to keep track of active conversations and generate responses to them.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
//! let conversation_model = ConversationModel::new(Default::default())?;
//! let mut conversation_manager = ConversationManager::new();
//!
//! let conversation_id =
//! conversation_manager.create("Going to the movies tonight - any suggestions?");
//! let output = conversation_model.generate_responses(&mut conversation_manager);
//! # Ok(())
//! # }
//! ```
//! Example output: \
//! ```no_run
//! # let output =
//! "The Big Lebowski."
//! # ;
//! ```
//!
//! </details>
//! &nbsp;
//! <details>
//! <summary> <b>5. Natural Language Generation </b> </summary>
//!
//! Generate language based on a prompt. GPT2 and GPT available as base models.
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//! Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, the unknown token otherwise.
//! This may impact the results, it is recommended to submit prompts of similar length for best results
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::text_generation::TextGenerationModel;
//! use rust_bert::pipelines::common::ModelType;
//! let mut model = TextGenerationModel::new(Default::default())?;
//! let input_context_1 = "The dog";
//! let input_context_2 = "The cat was";
//!
//! let prefix = None; // Optional prefix to append prompts with, will be excluded from the generated output
//!
//! let output = model.generate(&[input_context_1, input_context_2], prefix);
//! # Ok(())
//! # }
//! ```
//! Example output: \
//! ```no_run
//! # let output =
//! [
//! "The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year",
//! "The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me",
//! "The dog has been able to stay in the home for more than three months now. \"It's a very good dog. She's",
//! "The cat was discovered earlier this month in the home of a relative of the deceased. The cat\'s owner, who wished to remain anonymous,",
//! "The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said",
//! "The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
//! ]
//! # ;
//! ```
//!
//! </details>
//! &nbsp;
//! <details>
//! <summary> <b>6. Zero-shot classification </b> </summary>
//!
//! Performs zero-shot classification on input sentences with provided labels using a model fine-tuned for Natural Language Inference.
//! ```no_run
//! # use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
//! # fn main() -> anyhow::Result<()> {
//! let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
//! let input_sentence = "Who are you voting for in 2020?";
//! let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
//! let candidate_labels = &["politics", "public health", "economics", "sports"];
//! let output = sequence_classification_model.predict_multilabel(
//! &[input_sentence, input_sequence_2],
//! candidate_labels,
//! None,
//! 128,
//! );
//! # Ok(())
//! # }
//! ```
//!
//! outputs:
//! ```no_run
//! # use rust_bert::pipelines::sequence_classification::Label;
//! let output = [
//! [
//! Label {
//! text: "politics".to_string(),
//! score: 0.972,
//! id: 0,
//! sentence: 0,
//! },
//! Label {
//! text: "public health".to_string(),
//! score: 0.032,
//! id: 1,
//! sentence: 0,
//! },
//! Label {
//! text: "economics".to_string(),
//! score: 0.006,
//! id: 2,
//! sentence: 0,
//! },
//! Label {
//! text: "sports".to_string(),
//! score: 0.004,
//! id: 3,
//! sentence: 0,
//! },
//! ],
//! [
//! Label {
//! text: "politics".to_string(),
//! score: 0.975,
//! id: 0,
//! sentence: 1,
//! },
//! Label {
//! text: "economics".to_string(),
//! score: 0.852,
//! id: 2,
//! sentence: 1,
//! },
//! Label {
//! text: "public health".to_string(),
//! score: 0.0818,
//! id: 1,
//! sentence: 1,
//! },
//! Label {
//! text: "sports".to_string(),
//! score: 0.001,
//! id: 3,
//! sentence: 1,
//! },
//! ],
//! ]
//! .to_vec();
//! ```
//!
//! </details>
//! &nbsp;
//! <details>
//! <summary> <b>7. Sentiment analysis </b> </summary>
//!
//! Predicts the binary sentiment for a sentence. DistilBERT model fine-tuned on SST-2.
//! ```no_run
//! use rust_bert::pipelines::sentiment::SentimentModel;
//! # fn main() -> anyhow::Result<()> {
//! let sentiment_model = SentimentModel::new(Default::default())?;
//! let input = [
//! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
//! "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
//! ];
//! let output = sentiment_model.predict(&input);
//! # Ok(())
//! # }
//! ```
//! (Example courtesy of [IMDb](http://www.imdb.com))
//!
//! Output: \
//! ```no_run
//! # use rust_bert::pipelines::sentiment::Sentiment;
//! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
//! # let output =
//! [
//! Sentiment {
//! polarity: Positive,
//! score: 0.998,
//! },
//! Sentiment {
//! polarity: Negative,
//! score: 0.992,
//! },
//! Sentiment {
//! polarity: Positive,
//! score: 0.999,
//! },
//! ]
//! # ;
//! ```
//!
//! </details>
//! &nbsp;
//! <details>
//! <summary> <b>8. Named Entity Recognition </b> </summary>
//!
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model fine-tuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
//! Models are currently available for English, German, Spanish and Dutch.
//! ```no_run
//! use rust_bert::pipelines::ner::NERModel;
//! # fn main() -> anyhow::Result<()> {
//! let ner_model = NERModel::new(Default::default())?;
//! let input = [
//! "My name is Amy. I live in Paris.",
//! "Paris is a city in France.",
//! ];
//! let output = ner_model.predict(&input);
//! # Ok(())
//! # }
//! ```
//! Output: \
//! ```no_run
//! # use rust_bert::pipelines::ner::Entity;
//! # let output =
//! [
//! [
//! Entity {
//! word: String::from("Amy"),
//! score: 0.9986,
//! label: String::from("I-PER"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9985,
//! label: String::from("I-LOC"),
//! },
//! ],
//! [
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9988,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("France"),
//! score: 0.9993,
//! label: String::from("I-LOC"),
//! },
//! ],
//! ]
//! # ;
//! ```
//!
//! </details>
//! &nbsp;
//! <details>
//! <summary> <b>9. Part of Speech tagging </b> </summary>
//!
//! Extracts Part of Speech tags (Noun, Verb, Adjective...) from text.
//! ```no_run
//! use rust_bert::pipelines::pos_tagging::POSModel;
//! # fn main() -> anyhow::Result<()> {
//! let pos_model = POSModel::new(Default::default())?;
//! let input = ["My name is Bob"];
//! let output = pos_model.predict(&input);
//! # Ok(())
//! # }
//! ```
//! Output: \
//! ```no_run
//! # use rust_bert::pipelines::pos_tagging::POSTag;
//! # let output =
//! [
//! POSTag {
//! word: String::from("My"),
//! score: 0.1560,
//! label: String::from("PRP"),
//! },
//! POSTag {
//! word: String::from("name"),
//! score: 0.6565,
//! label: String::from("NN"),
//! },
//! POSTag {
//! word: String::from("is"),
//! score: 0.3697,
//! label: String::from("VBZ"),
//! },
//! POSTag {
//! word: String::from("Bob"),
//! score: 0.7460,
//! label: String::from("NNP"),
//! },
//! ]
//! # ;
//! ```
//!
//! </details>
//!
//! ## Benchmarks
//!
//! For simple pipelines (sequence classification, tokens classification, question answering) the performance between Python and Rust is expected to be comparable. This is because the most expensive part of these pipeline is the language model itself, sharing a common implementation in the Torch backend. The [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/) provides a benchmarks section covering all pipelines.
//!
//! For text generation tasks (summarization, translation, conversation, free text generation), significant benefits can be expected (up to 2 to 4 times faster processing depending on the input and application). The article [Accelerating text generation with Rust](https://guillaume-be.github.io/2020-11-21/generation_benchmarks) focuses on these text generation applications and provides more details on the performance comparison to Python.
//!
//! ## Loading pretrained and custom model weights
//!
//! The base model and task-specific heads are also available for users looking to expose their own transformer based models.
//! Examples on how to prepare the date using a native tokenizers Rust library are available in `./examples` for BERT, DistilBERT, RoBERTa, GPT, GPT2 and BART.
//! Note that when importing models from Pytorch, the convention for parameters naming needs to be aligned with the Rust schema. Loading of the pre-trained weights will fail if any of the model parameters weights cannot be found in the weight files.
//! If this quality check is to be skipped, an alternative method `load_partial` can be invoked from the variables store.
//!
//! Pretrained models are available on Hugging face's [model hub](https://huggingface.co/models?filter=rust) and can be loaded using `RemoteResources` defined in this library.
//! A conversion utility script is included in `./utils` to convert Pytorch weights to a set of weights compatible with this library. This script requires Python and `torch` to be set-up, and can be used as follows:
//! `python ./utils/convert_model.py path/to/pytorch_model.bin` where `path/to/pytorch_model.bin` is the location of the original Pytorch weights.
//!
//!
//! ## Citation
//!
//! If you use `rust-bert` for your work, please cite [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/):
//! ```bibtex
//! @inproceedings{becquin-2020-end,
//! title = "End-to-end {NLP} Pipelines in Rust",
//! author = "Becquin, Guillaume",
//! booktitle = "Proceedings of Second Workshop for NLP Open Source Software (NLP-OSS)",
//! year = "2020",
//! publisher = "Association for Computational Linguistics",
//! url = "https://www.aclweb.org/anthology/2020.nlposs-1.4",
//! pages = "20--25",
//! }
//! ```
//!
//! ## Acknowledgements
//!
//! Thank you to [Hugging Face](https://huggingface.co) for hosting a set of weights compatible with this Rust library.
//! The list of ready-to-use pretrained models is listed at [https://huggingface.co/models?filter=rust](https://huggingface.co/models?filter=rust).
pub mod albert;
pub mod bart;
@ -64,13 +584,23 @@ pub mod bert;
mod common;
pub mod distilbert;
pub mod electra;
pub mod fnet;
pub mod gpt2;
pub mod gpt_neo;
pub mod longformer;
pub mod m2m_100;
pub mod marian;
pub mod mbart;
pub mod mobilebert;
pub mod openai_gpt;
pub mod pegasus;
pub mod pipelines;
pub mod prophetnet;
pub mod reformer;
pub mod roberta;
pub mod t5;
pub mod xlnet;
pub use common::error::RustBertError;
pub use common::resources;
pub use common::Config;
pub use common::{Activation, Config};

806
src/longformer/attention.rs Normal file
View File

@ -0,0 +1,806 @@
// Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::common::kind::get_negative_infinity;
use crate::longformer::LongformerConfig;
use std::borrow::Borrow;
use tch::{nn, Kind, Tensor};
pub struct LongformerSelfAttention {
query: nn::Linear,
key: nn::Linear,
value: nn::Linear,
query_global: nn::Linear,
key_global: nn::Linear,
value_global: nn::Linear,
dropout: Dropout,
one_sided_attention_window_size: i64,
num_heads: i64,
head_dim: i64,
embed_dim: i64,
output_attentions: bool,
}
impl LongformerSelfAttention {
pub fn new<'p, P>(p: P, config: &LongformerConfig, layer_id: i64) -> LongformerSelfAttention
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let num_heads = config.num_attention_heads;
let head_dim = config.hidden_size / num_heads;
let embed_dim = config.hidden_size;
let query = nn::linear(
p / "query",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let key = nn::linear(
p / "key",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let value = nn::linear(
p / "value",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let query_global = nn::linear(
p / "query_global",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let key_global = nn::linear(
p / "key_global",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let value_global = nn::linear(
p / "value_global",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dropout = Dropout::new(config.attention_probs_dropout_prob);
let one_sided_attention_window_size = config.attention_window[layer_id as usize] / 2;
let output_attentions = config.output_attentions.unwrap_or(false);
LongformerSelfAttention {
query,
key,
value,
query_global,
key_global,
value_global,
dropout,
one_sided_attention_window_size,
num_heads,
head_dim,
embed_dim,
output_attentions,
}
}
fn pad_and_transpose_last_two_dims(&self, hidden_states: &Tensor, padding: &[i64]) -> Tensor {
let output = hidden_states.constant_pad_nd(padding);
let mut output_shape = output.size();
let last_dim = output_shape.pop().unwrap();
let second_last_dim = output_shape.pop().unwrap();
output_shape.push(last_dim);
output_shape.push(second_last_dim);
output.view(output_shape.as_slice())
}
fn pad_and_diagonalize(&self, chunked_hidden_states: &Tensor) -> Tensor {
let chunked_hidden_states_shape = chunked_hidden_states.size();
let (total_num_heads, num_chunks, window_overlap, hidden_dim) = (
chunked_hidden_states_shape[0],
chunked_hidden_states_shape[1],
chunked_hidden_states_shape[2],
chunked_hidden_states_shape[3],
);
chunked_hidden_states
.constant_pad_nd(&[0, window_overlap + 1])
.view([total_num_heads, num_chunks, -1])
.slice(2, 0, -window_overlap, 1)
.view([
total_num_heads,
num_chunks,
window_overlap,
window_overlap + hidden_dim,
])
.slice(3, 0, -1, 1)
}
fn chunk(&self, hidden_states: &Tensor, window_overlap: i64) -> Tensor {
let hidden_states_shape = hidden_states.size();
let hidden_states = hidden_states.view([
hidden_states_shape[0],
hidden_states_shape[1] / (window_overlap * 2),
window_overlap * 2,
hidden_states_shape[2],
]);
let mut chunk_size = hidden_states.size();
chunk_size[1] = chunk_size[1] * 2 - 1;
let mut chunk_stride = hidden_states.stride();
chunk_stride[1] = chunk_stride[1] / 2;
hidden_states.as_strided(chunk_size.as_slice(), chunk_stride.as_slice(), None)
}
fn mask_invalid_locations(&self, input_tensor: &mut Tensor, affected_sequence_length: i64) {
let input_size = input_tensor.size();
let beginning_input_size = vec![
input_size[0],
affected_sequence_length,
input_size[2],
affected_sequence_length + 1,
];
let ending_input_size = vec![
input_size[0],
affected_sequence_length,
input_size[2],
affected_sequence_length + 1,
];
let beginning_mask = Tensor::ones(
&[affected_sequence_length, affected_sequence_length + 1],
(Kind::Int, input_tensor.device()),
)
.tril(0)
.flip(&[0])
.unsqueeze(0)
.unsqueeze(2);
let ending_mask = beginning_mask.flip(&[1, 3]);
let beginning_mask = beginning_mask
.expand(beginning_input_size.as_slice(), true)
.eq(1);
let ending_mask = ending_mask.expand(ending_input_size.as_slice(), true).eq(1);
let _ = input_tensor
.slice(1, 0, affected_sequence_length, 1)
.slice(3, 0, affected_sequence_length + 1, 1)
.masked_fill_(
&beginning_mask,
get_negative_infinity(input_tensor.kind()).unwrap(),
);
let _ = input_tensor
.narrow(1, -affected_sequence_length, affected_sequence_length)
.narrow(
3,
-(affected_sequence_length + 1),
affected_sequence_length + 1,
)
.masked_fill_(
&ending_mask,
get_negative_infinity(input_tensor.kind()).unwrap(),
);
}
fn sliding_chunks_query_key_matmul(
&self,
query: &Tensor,
key: &Tensor,
window_overlap: i64,
) -> Tensor {
let (batch_size, sequence_length, num_heads, head_dim) = query.size4().unwrap();
let chunks_count = sequence_length / window_overlap - 1;
let query =
query
.transpose(1, 2)
.reshape(&[batch_size * num_heads, sequence_length, head_dim]);
let key = key
.transpose(1, 2)
.reshape(&[batch_size * num_heads, sequence_length, head_dim]);
let query = self.chunk(&query, window_overlap);
let key = self.chunk(&key, window_overlap);
let diagonal_chunked_attention_scores = self.pad_and_transpose_last_two_dims(
&Tensor::einsum("bcxd,bcyd->bcxy", &[query, key]),
&[0, 0, 0, 1],
);
let diagonal_attention_scores = Tensor::empty(
&[
batch_size * num_heads,
chunks_count + 1,
window_overlap,
window_overlap * 2 + 1,
],
(
diagonal_chunked_attention_scores.kind(),
diagonal_chunked_attention_scores.device(),
),
);
let diagonal_attention_scores_size = diagonal_attention_scores.size();
let diagonal_chunked_attention_scores_size = diagonal_chunked_attention_scores.size();
diagonal_attention_scores
.slice(1, 0, -1, 1)
.slice(3, window_overlap, diagonal_attention_scores_size[3], 1)
.copy_(
&diagonal_chunked_attention_scores
.slice(2, 0, window_overlap, 1)
.slice(3, 0, window_overlap + 1, 1),
);
diagonal_attention_scores
.select(1, -1)
.slice(2, window_overlap, diagonal_attention_scores_size[3], 1)
.copy_(
&diagonal_chunked_attention_scores
.select(1, -1)
.slice(
1,
window_overlap,
diagonal_chunked_attention_scores_size[2],
1,
)
.slice(2, 0, window_overlap + 1, 1),
);
diagonal_attention_scores
.slice(1, 1, diagonal_attention_scores_size[1], 1)
.slice(3, 0, window_overlap, 1)
.copy_(
&diagonal_chunked_attention_scores
.slice(2, -(window_overlap + 1), -1, 1)
.slice(
3,
window_overlap + 1,
diagonal_chunked_attention_scores_size[3],
1,
),
);
diagonal_attention_scores
.select(1, 0)
.slice(1, 1, window_overlap, 1)
.slice(2, 1, window_overlap, 1)
.copy_(
&diagonal_chunked_attention_scores
.select(1, 0)
.slice(1, 0, window_overlap - 1, 1)
.slice(
2,
1 - window_overlap,
diagonal_chunked_attention_scores_size[3],
1,
),
);
let mut diagonal_attention_scores = diagonal_attention_scores
.view([
batch_size,
num_heads,
sequence_length,
2 * window_overlap + 1,
])
.transpose(2, 1);
self.mask_invalid_locations(&mut diagonal_attention_scores, window_overlap);
diagonal_attention_scores
}
fn sliding_chunks_matmul_attention_probas_value(
&self,
attention_probas: &Tensor,
value: &Tensor,
window_overlap: i64,
) -> Tensor {
let (batch_size, sequence_length, num_heads, head_dim) = value.size4().unwrap();
let chunk_counts = sequence_length / window_overlap - 1;
let chunked_attention_probas = attention_probas.transpose(1, 2).reshape(&[
batch_size * num_heads,
sequence_length / window_overlap,
window_overlap,
2 * window_overlap + 1,
]);
let value =
value
.transpose(1, 2)
.reshape(&[batch_size * num_heads, sequence_length, head_dim]);
let padded_value = (value + 1).constant_pad_nd(&[0, 0, window_overlap, window_overlap]) - 1;
let chunked_value_size = &[
batch_size * num_heads,
chunk_counts + 1,
3 * window_overlap,
head_dim,
];
let chunked_value_stride = padded_value.stride();
let chunked_value_stride = &[
chunked_value_stride[0],
window_overlap * chunked_value_stride[1],
chunked_value_stride[1],
chunked_value_stride[2],
];
let chunked_value = padded_value.as_strided(chunked_value_size, chunked_value_stride, None);
let chunked_attention_probas = self.pad_and_diagonalize(&chunked_attention_probas);
Tensor::einsum(
"bcwd,bcdh->bcwh",
&[chunked_attention_probas, chunked_value],
)
.view([batch_size, num_heads, sequence_length, head_dim])
.transpose(1, 2)
}
fn get_global_attention_indices(
&self,
is_index_global_attn: &Tensor,
) -> GlobalAttentionIndices {
let num_global_attention_indices =
is_index_global_attn.sum_dim_intlist(&[1], false, Kind::Int64);
let max_num_global_attention_indices = i64::from(num_global_attention_indices.max());
let is_index_global_attn_nonzero = is_index_global_attn
.nonzero_numpy()
.into_iter()
.map(Some)
.collect();
let is_local_index_global_attention = Tensor::arange(
max_num_global_attention_indices,
(Kind::Int64, is_index_global_attn.device()),
)
.lt_tensor(&num_global_attention_indices.unsqueeze(-1));
let is_local_index_global_attention_nonzero = is_local_index_global_attention
.nonzero_numpy()
.into_iter()
.map(Some)
.collect();
let is_local_index_no_global_attention_nonzero = is_local_index_global_attention
.eq(0)
.nonzero_numpy()
.into_iter()
.map(Some)
.collect();
GlobalAttentionIndices {
max_num_global_attention_indices,
is_index_global_attn_nonzero,
is_local_index_global_attention_nonzero,
is_local_index_no_global_attention_nonzero,
}
}
fn concat_with_global_key_attention_probas(
&self,
key_vectors: &Tensor,
query_vectors: &Tensor,
max_num_global_attention_indices: i64,
is_index_global_attn_nonzero: &[Option<Tensor>],
is_local_index_global_attention_nonzero: &[Option<Tensor>],
is_local_index_no_global_attention_nonzero: &[Option<Tensor>],
) -> Tensor {
let batch_size = key_vectors.size()[0];
let mut key_vectors_only_global = Tensor::zeros(
&[
batch_size,
max_num_global_attention_indices,
self.num_heads,
self.head_dim,
],
(key_vectors.kind(), key_vectors.device()),
);
let _ = key_vectors_only_global.index_put_(
is_local_index_global_attention_nonzero,
&key_vectors.index(is_index_global_attn_nonzero),
false,
);
let attention_probas_from_global_key = Tensor::einsum(
"blhd,bshd->blhs",
&[query_vectors, &key_vectors_only_global],
);
let _ = attention_probas_from_global_key
.index_select(
0,
is_local_index_no_global_attention_nonzero[0]
.as_ref()
.unwrap(),
)
.index_select(
3,
is_local_index_no_global_attention_nonzero[1]
.as_ref()
.unwrap(),
)
.fill_(-10000f64);
attention_probas_from_global_key
}
fn compute_attention_output_with_global_indices(
&self,
value_vectors: &Tensor,
attention_probas: &Tensor,
max_num_global_attention_indices: i64,
is_index_global_attn_nonzero: &[Option<Tensor>],
is_local_index_global_attention_nonzero: &[Option<Tensor>],
) -> Tensor {
let batch_size = attention_probas.size()[0];
let attention_probas_only_global =
attention_probas.narrow(-1, 0, max_num_global_attention_indices);
let mut value_vectors_only_global = Tensor::zeros(
&[
batch_size,
max_num_global_attention_indices,
self.num_heads,
self.head_dim,
],
(value_vectors.kind(), value_vectors.device()),
);
let _ = value_vectors_only_global.index_put_(
is_local_index_global_attention_nonzero,
&value_vectors.index(is_index_global_attn_nonzero),
false,
);
let attention_output_only_global = attention_probas_only_global
.transpose(1, 2)
.matmul(&value_vectors_only_global.transpose(1, 2))
.transpose(1, 2);
let attention_probas_without_global = attention_probas
.narrow(
-1,
max_num_global_attention_indices,
*attention_probas.size().last().unwrap() - max_num_global_attention_indices,
)
.contiguous();
let attn_output_without_global = self.sliding_chunks_matmul_attention_probas_value(
&attention_probas_without_global,
value_vectors,
self.one_sided_attention_window_size,
);
attention_output_only_global + attn_output_without_global
}
fn compute_global_attention_output_from_hidden(
&self,
hidden_states: &Tensor,
max_num_global_attention_indices: i64,
is_index_global_attn_nonzero: &[Option<Tensor>],
is_local_index_global_attention_nonzero: &[Option<Tensor>],
is_local_index_no_global_attention_nonzero: &[Option<Tensor>],
is_index_masked: &Tensor,
train: bool,
) -> (Tensor, Tensor) {
let hidden_states_shape = hidden_states.size();
let (sequence_length, batch_size) = (hidden_states_shape[0], hidden_states_shape[1]);
let mut global_attention_hidden_states = Tensor::zeros(
&[max_num_global_attention_indices, batch_size, self.embed_dim],
(hidden_states.kind(), hidden_states.device()),
);
let _ = global_attention_hidden_states.index_put_(
is_local_index_global_attention_nonzero
.iter()
.rev()
.map(|o| o.as_ref())
.collect::<Vec<Option<&Tensor>>>()
.as_slice(),
&hidden_states.index(
is_index_global_attn_nonzero
.iter()
.rev()
.map(|o| o.as_ref())
.collect::<Vec<Option<&Tensor>>>()
.as_slice(),
),
false,
);
let global_query_vectors_only_global = (global_attention_hidden_states
.apply(&self.query_global)
/ (self.head_dim as f64).sqrt())
.contiguous()
.view([
max_num_global_attention_indices,
batch_size * self.num_heads,
self.head_dim,
])
.transpose(0, 1);
let global_key_vectors = hidden_states
.apply(&self.key_global)
.contiguous()
.view([-1, batch_size * self.num_heads, self.head_dim])
.transpose(0, 1);
let global_value_vectors = hidden_states
.apply(&self.value_global)
.contiguous()
.view([-1, batch_size * self.num_heads, self.head_dim])
.transpose(0, 1);
let global_attention_scores = global_query_vectors_only_global
.bmm(&global_key_vectors.transpose(1, 2))
.view([
batch_size,
self.num_heads,
max_num_global_attention_indices,
sequence_length,
]);
let _ = global_attention_scores
.index_select(
0,
is_local_index_no_global_attention_nonzero[0]
.as_ref()
.unwrap(),
)
.index_select(
2,
is_local_index_no_global_attention_nonzero[1]
.as_ref()
.unwrap(),
)
.fill_(-10000_f64);
let global_attention_scores = global_attention_scores
.masked_fill(&is_index_masked.unsqueeze(1).unsqueeze(1), -10000_f64)
.view([
batch_size * self.num_heads,
max_num_global_attention_indices,
sequence_length,
]);
let global_attention_probas = global_attention_scores
.softmax(-1, global_attention_scores.kind())
.apply_t(&self.dropout, train);
let global_attention_output = global_attention_probas.bmm(&global_value_vectors);
let global_attention_probas = global_attention_probas.view([
batch_size,
self.num_heads,
max_num_global_attention_indices,
sequence_length,
]);
let global_attention_output = global_attention_output.view([
batch_size,
self.num_heads,
max_num_global_attention_indices,
self.head_dim,
]);
(global_attention_output, global_attention_probas)
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
is_index_masked: &Tensor,
is_index_global_attention: &Tensor,
is_global_attention: bool,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let hidden_states = hidden_states.transpose(0, 1);
let query_vectors = hidden_states.apply(&self.query) / (self.head_dim as f64).sqrt();
let key_vectors = hidden_states.apply(&self.key);
let value_vectors = hidden_states.apply(&self.value);
let (sequence_length, batch_size, embed_dim) = hidden_states.size3().unwrap();
let query_vectors = query_vectors
.view([sequence_length, batch_size, self.num_heads, self.head_dim])
.transpose(0, 1);
let key_vectors = key_vectors
.view([sequence_length, batch_size, self.num_heads, self.head_dim])
.transpose(0, 1);
let mut attention_scores = self.sliding_chunks_query_key_matmul(
&query_vectors,
&key_vectors,
self.one_sided_attention_window_size,
);
let remove_from_windowed_attention_mask = attention_mask.ne(0).unsqueeze(-1).unsqueeze(-1);
let float_mask = remove_from_windowed_attention_mask
.totype(attention_scores.kind())
.masked_fill(&remove_from_windowed_attention_mask, -10000.0);
let diagonal_mask = self.sliding_chunks_query_key_matmul(
&Tensor::ones(
float_mask.size().as_slice(),
(float_mask.kind(), float_mask.device()),
),
&float_mask,
self.one_sided_attention_window_size,
);
attention_scores = attention_scores + &diagonal_mask;
let (
max_num_global_attention_indices,
is_index_global_attn_nonzero,
is_local_index_global_attention_nonzero,
is_local_index_no_global_attention_nonzero,
) = if is_global_attention {
let global_attention_indices =
self.get_global_attention_indices(is_index_global_attention);
let global_key_attention_scores = self.concat_with_global_key_attention_probas(
&key_vectors,
&query_vectors,
global_attention_indices.max_num_global_attention_indices,
global_attention_indices
.is_index_global_attn_nonzero
.as_slice(),
global_attention_indices
.is_local_index_global_attention_nonzero
.as_slice(),
global_attention_indices
.is_local_index_no_global_attention_nonzero
.as_slice(),
);
attention_scores = Tensor::cat(&[&global_key_attention_scores, &attention_scores], -1);
(
Some(global_attention_indices.max_num_global_attention_indices),
Some(global_attention_indices.is_index_global_attn_nonzero),
Some(global_attention_indices.is_local_index_global_attention_nonzero),
Some(global_attention_indices.is_local_index_no_global_attention_nonzero),
)
} else {
(None, None, None, None)
};
let mut attention_probas = attention_scores
.softmax(-1, attention_scores.kind())
.masked_fill(&is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0)
.apply_t(&self.dropout, train);
let value_vectors = value_vectors
.view([sequence_length, batch_size, self.num_heads, self.head_dim])
.transpose(0, 1);
let attention_output = if is_global_attention {
self.compute_attention_output_with_global_indices(
&value_vectors,
&attention_probas,
max_num_global_attention_indices.unwrap(),
is_index_global_attn_nonzero.as_ref().unwrap(),
is_local_index_global_attention_nonzero.as_ref().unwrap(),
)
} else {
self.sliding_chunks_matmul_attention_probas_value(
&attention_probas,
&value_vectors,
self.one_sided_attention_window_size,
)
};
let mut attention_output =
attention_output
.transpose(0, 1)
.reshape(&[sequence_length, batch_size, embed_dim]);
let global_attention_probas = if is_global_attention {
let (global_attention_output, global_attention_probas) = self
.compute_global_attention_output_from_hidden(
&hidden_states,
max_num_global_attention_indices.unwrap(),
is_index_global_attn_nonzero.as_ref().unwrap(),
is_local_index_global_attention_nonzero.as_ref().unwrap(),
is_local_index_no_global_attention_nonzero.as_ref().unwrap(),
is_index_masked,
train,
);
let nonzero_global_attention_output = global_attention_output.transpose(1, 2).index(&[
Some(
is_local_index_global_attention_nonzero.as_ref().unwrap()[0]
.as_ref()
.unwrap(),
),
Some(
is_local_index_global_attention_nonzero.as_ref().unwrap()[1]
.as_ref()
.unwrap(),
),
]);
let _ = attention_output.index_put_(
is_index_global_attn_nonzero
.as_ref()
.unwrap()
.iter()
.rev()
.map(|o| o.as_ref())
.collect::<Vec<Option<&Tensor>>>()
.as_slice(),
&nonzero_global_attention_output.view([
is_local_index_global_attention_nonzero.as_ref().unwrap()[0]
.as_ref()
.unwrap()
.size()[0],
-1,
]),
false,
);
let _ = attention_probas.index_put_(
is_index_global_attn_nonzero.as_ref().unwrap(),
&Tensor::zeros(
attention_probas
.index(is_index_global_attn_nonzero.as_ref().unwrap())
.size()
.as_slice(),
(attention_output.kind(), attention_output.device()),
),
false,
);
Some(global_attention_probas)
} else {
None
};
let attention_probas = if self.output_attentions {
Some(attention_probas)
} else {
None
};
let global_attention_probas = if self.output_attentions {
global_attention_probas
} else {
None
};
(
attention_output.transpose(0, 1),
attention_probas,
global_attention_probas,
)
}
}
struct GlobalAttentionIndices {
max_num_global_attention_indices: i64,
is_index_global_attn_nonzero: Vec<Option<Tensor>>,
is_local_index_global_attention_nonzero: Vec<Option<Tensor>>,
is_local_index_no_global_attention_nonzero: Vec<Option<Tensor>>,
}

View File

@ -0,0 +1,139 @@
// Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::longformer::LongformerConfig;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::EmbeddingConfig;
use tch::{nn, Kind, Tensor};
pub struct LongformerEmbeddings {
word_embeddings: nn::Embedding,
position_embeddings: nn::Embedding,
token_type_embeddings: nn::Embedding,
layer_norm: nn::LayerNorm,
dropout: Dropout,
pad_token_id: i64,
}
impl LongformerEmbeddings {
pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerEmbeddings
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let pad_token_id = config.pad_token_id.unwrap_or(1);
let embeddings_config = EmbeddingConfig {
padding_idx: pad_token_id,
..Default::default()
};
let word_embeddings = nn::embedding(
p / "word_embeddings",
config.vocab_size,
config.hidden_size,
embeddings_config,
);
let position_embeddings = nn::embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.hidden_size,
embeddings_config,
);
let token_type_embeddings = nn::embedding(
p / "token_type_embeddings",
config.type_vocab_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_eps.unwrap_or(1e-12),
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout = Dropout::new(config.hidden_dropout_prob);
LongformerEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
pad_token_id,
}
}
fn create_position_ids_from_input_ids(&self, input_ids: &Tensor) -> Tensor {
let mask = input_ids.ne(self.pad_token_id);
mask.cumsum(1, Kind::Int64) * mask + self.pad_token_id
}
fn create_position_ids_from_input_embeds(&self, inputs_embeds: &Tensor) -> Tensor {
let input_shape = inputs_embeds.size();
let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
Tensor::arange_start(
self.pad_token_id + 1,
sequence_length + self.pad_token_id + 1,
(Kind::Int64, inputs_embeds.device()),
)
.unsqueeze(0)
.expand(&[batch_size, sequence_length], true)
}
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<Tensor, RustBertError> {
let (calc_input_embeddings, input_shape, _) =
process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
let input_embeds = input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let calc_position_ids = if position_ids.is_none() {
if let Some(input_ids) = input_ids {
Some(self.create_position_ids_from_input_ids(input_ids))
} else {
Some(self.create_position_ids_from_input_embeds(input_embeds))
}
} else {
None
};
let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
let calc_token_type_ids = if token_type_ids.is_none() {
Some(Tensor::zeros(
input_shape.as_slice(),
(Kind::Int64, input_embeds.device()),
))
} else {
None
};
let token_type_ids =
token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
Ok((input_embeds + position_embeddings + token_type_embeddings)
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
}
}

360
src/longformer/encoder.rs Normal file
View File

@ -0,0 +1,360 @@
// Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
// Copyright 2021 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::activations::TensorFunction;
use crate::common::dropout::Dropout;
use crate::longformer::attention::LongformerSelfAttention;
use crate::longformer::LongformerConfig;
use std::borrow::{Borrow, BorrowMut};
use tch::nn::Module;
use tch::{nn, Tensor};
pub struct LongformerSelfOutput {
dense: nn::Linear,
layer_norm: nn::LayerNorm,
dropout: Dropout,
}
impl LongformerSelfOutput {
pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerSelfOutput
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_eps.unwrap_or(1e-12),
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout = Dropout::new(config.hidden_dropout_prob);
LongformerSelfOutput {
dense,
layer_norm,
dropout,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
let hidden_states = hidden_states
.apply(&self.dense)
.apply_t(&self.dropout, train);
(hidden_states + input_tensor).apply(&self.layer_norm)
}
}
pub struct LongformerAttention {
self_attention: LongformerSelfAttention,
output: LongformerSelfOutput,
}
impl LongformerAttention {
pub fn new<'p, P>(p: P, config: &LongformerConfig, layer_id: i64) -> LongformerAttention
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let self_attention = LongformerSelfAttention::new(p / "self", config, layer_id);
let output = LongformerSelfOutput::new(p / "output", config);
LongformerAttention {
self_attention,
output,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
is_index_masked: &Tensor,
is_index_global_attention: &Tensor,
is_global_attention: bool,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let (attention_outputs, attention_scores, global_attention_scores) =
self.self_attention.forward_t(
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attention,
is_global_attention,
train,
);
let attention_outputs = self
.output
.forward_t(&attention_outputs, hidden_states, train);
(attention_outputs, attention_scores, global_attention_scores)
}
}
#[derive(Debug)]
pub struct LongformerIntermediate {
dense: nn::Linear,
activation_function: TensorFunction,
}
impl LongformerIntermediate {
pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerIntermediate
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.intermediate_size,
Default::default(),
);
let activation_function = config.hidden_act.get_function();
LongformerIntermediate {
dense,
activation_function,
}
}
}
impl Module for LongformerIntermediate {
fn forward(&self, hidden_states: &Tensor) -> Tensor {
self.activation_function.get_fn()(&hidden_states.apply(&self.dense))
}
}
pub struct LongformerOutput {
dense: nn::Linear,
layer_norm: nn::LayerNorm,
dropout: Dropout,
}
impl LongformerOutput {
pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerOutput
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let dense = nn::linear(
p / "dense",
config.intermediate_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_eps.unwrap_or(1e-12),
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout = Dropout::new(config.hidden_dropout_prob);
LongformerOutput {
dense,
layer_norm,
dropout,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
let hidden_states = hidden_states
.apply(&self.dense)
.apply_t(&self.dropout, train);
(hidden_states + input_tensor).apply(&self.layer_norm)
}
}
pub struct LongformerLayer {
attention: LongformerAttention,
intermediate: LongformerIntermediate,
output: LongformerOutput,
}
impl LongformerLayer {
pub fn new<'p, P>(p: P, config: &LongformerConfig, layer_id: i64) -> LongformerLayer
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let attention = LongformerAttention::new(p / "attention", config, layer_id);
let intermediate = LongformerIntermediate::new(p / "intermediate", config);
let output = LongformerOutput::new(p / "output", config);
LongformerLayer {
attention,
intermediate,
output,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
is_index_masked: &Tensor,
is_index_global_attention: &Tensor,
is_global_attention: bool,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let (attention_outputs, attention_scores, global_attention_scores) =
self.attention.forward_t(
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attention,
is_global_attention,
train,
);
let intermediate_output = attention_outputs.apply(&self.intermediate);
let attention_outputs =
self.output
.forward_t(&intermediate_output, &attention_outputs, train);
(attention_outputs, attention_scores, global_attention_scores)
}
}
/// Container for the Longformer encoder output.
pub struct LongformerEncoderOutput {
/// Last hidden states from the model
pub hidden_states: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
/// Global attention weights for all intermediate layers
pub all_global_attentions: Option<Vec<Tensor>>,
}
pub struct LongformerEncoder {
layers: Vec<LongformerLayer>,
output_attentions: bool,
output_hidden_states: bool,
}
impl LongformerEncoder {
pub fn new<'p, P>(p: P, config: &LongformerConfig) -> LongformerEncoder
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let p_layers = p / "layer";
let mut layers: Vec<LongformerLayer> =
Vec::with_capacity(config.num_hidden_layers as usize);
for layer_index in 0..config.num_hidden_layers {
layers.push(LongformerLayer::new(
&p_layers / layer_index,
config,
layer_index,
));
}
let output_attentions = config.output_attentions.unwrap_or(false);
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
LongformerEncoder {
layers,
output_attentions,
output_hidden_states,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
train: bool,
) -> LongformerEncoderOutput {
let is_index_masked = attention_mask.lt(0);
let is_index_global_attention = attention_mask.gt(0);
let is_global_attention = bool::from(is_index_global_attention.any());
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut all_global_attentions: Option<Vec<Tensor>> =
if self.output_attentions & is_global_attention {
Some(vec![])
} else {
None
};
let mut x: Option<Tensor> = None;
let mut attention_weights: Option<Tensor>;
let mut global_attention_weights: Option<Tensor>;
for layer in &self.layers {
let temp = if let Some(x_value) = &x {
layer.forward_t(
x_value,
attention_mask,
&is_index_masked,
&is_index_global_attention,
is_global_attention,
train,
)
} else {
layer.forward_t(
hidden_states,
attention_mask,
&is_index_masked,
&is_index_global_attention,
is_global_attention,
train,
)
};
x = Some(temp.0);
attention_weights = temp.1;
global_attention_weights = temp.2;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().transpose(1, 2));
};
if let Some(global_attentions) = all_global_attentions.borrow_mut() {
global_attentions.push(global_attention_weights.as_ref().unwrap().transpose(2, 3));
};
if let Some(all_hidden_states) = all_hidden_states.borrow_mut() {
all_hidden_states.push(x.as_ref().unwrap().copy());
};
}
LongformerEncoderOutput {
hidden_states: x.unwrap(),
all_hidden_states,
all_attentions,
all_global_attentions,
}
}
}

File diff suppressed because it is too large Load Diff

87
src/longformer/mod.rs Normal file
View File

@ -0,0 +1,87 @@
//! # Longformer: The Long-Document Transformer (Betalgy et al.)
//!
//! Implementation of the Longformer language model ([Longformer: The Long-Document Transformer](https://arxiv.org/abs/2001.04063) Betalgy, Peters, Cohan, 2020).
//! The base model is implemented in the `longformer_model::LongformerModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `longformer_model::LongformerForMaskedLM`
//! - Multiple choices: `longformer_model:LongformerForMultipleChoice`
//! - Question answering: `longformer_model::LongformerForQuestionAnswering`
//! - Sequence classification: `longformer_model::LongformerForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `longformer_model::LongformerForTokenClassification`
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example (question answering) is provided in `examples/question_answering_longformer`, run with `cargo run --example question_answering_longformer`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `RobertaTokenizer` using a `vocab.json` vocabulary and `merges.txt` byte pair encoding merges
//!
//! # Question answering example below:
//!
//! ```no_run
//! use rust_bert::longformer::{
//! LongformerConfigResources, LongformerMergesResources, LongformerModelResources,
//! LongformerVocabResources,
//! };
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::question_answering::{
//! QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
//! };
//! use rust_bert::resources::{RemoteResource, Resource};
//!
//! fn main() -> anyhow::Result<()> {
//! // Set-up Question Answering model
//! let config = QuestionAnsweringConfig::new(
//! ModelType::Longformer,
//! Resource::Remote(RemoteResource::from_pretrained(
//! LongformerModelResources::LONGFORMER_BASE_SQUAD1,
//! )),
//! Resource::Remote(RemoteResource::from_pretrained(
//! LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
//! )),
//! Resource::Remote(RemoteResource::from_pretrained(
//! LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
//! )),
//! Some(Resource::Remote(RemoteResource::from_pretrained(
//! LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
//! ))),
//! false,
//! None,
//! false,
//! );
//!
//! let qa_model = QuestionAnsweringModel::new(config)?;
//!
//! // Define input
//! let question_1 = String::from("Where does Amy live ?");
//! let context_1 = String::from("Amy lives in Amsterdam");
//! let question_2 = String::from("Where does Eric live");
//! let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
//! let qa_input_1 = QaInput {
//! question: question_1,
//! context: context_1,
//! };
//! let qa_input_2 = QaInput {
//! question: question_2,
//! context: context_2,
//! };
//!
//! // Get answer
//! let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
//! println!("{:?}", answers);
//! Ok(())
//! }
//! ```
mod attention;
mod embeddings;
mod encoder;
mod longformer_model;
pub use longformer_model::{
LongformerConfig, LongformerConfigResources, LongformerForMaskedLM,
LongformerForMultipleChoice, LongformerForQuestionAnswering,
LongformerForSequenceClassification, LongformerForTokenClassification,
LongformerMergesResources, LongformerModel, LongformerModelResources,
LongformerTokenClassificationOutput, LongformerVocabResources,
};

15
src/m2m_100/attention.rs Normal file
View File

@ -0,0 +1,15 @@
// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::LayerState as BartLayerState;
pub type LayerState = BartLayerState;

195
src/m2m_100/decoder.rs Normal file
View File

@ -0,0 +1,195 @@
// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::{BartDecoderOutput, _expand_mask, _make_causal_mask};
use crate::common::dropout::Dropout;
use crate::m2m_100::embeddings::SinusoidalPositionalEmbedding;
use crate::m2m_100::{LayerState, M2M100Config};
use crate::mbart::MBartDecoderLayer;
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Tensor};
pub type M2M100DecoderLayer = MBartDecoderLayer;
pub struct M2M100Decoder {
dropout: Dropout,
layer_norm: nn::LayerNorm,
layers: Vec<M2M100DecoderLayer>,
embed_positions: SinusoidalPositionalEmbedding,
output_attentions: bool,
output_hidden_states: bool,
output_past: bool,
scale_embedding: f64,
}
impl M2M100Decoder {
pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100Decoder
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let output_past = config.output_past.unwrap_or(true);
let output_attentions = config.output_attentions.unwrap_or(false);
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
let scale_embedding = if let Some(scale_embeddings) = config.scale_embedding {
if scale_embeddings {
(config.d_model as f64).sqrt()
} else {
1.0
}
} else {
1.0
};
let dropout = Dropout::new(config.dropout);
let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.d_model], Default::default());
let embed_positions = SinusoidalPositionalEmbedding::new(
p / "embed_positions",
config.max_position_embeddings,
config.d_model,
config.pad_token_id.unwrap_or(1),
);
let mut layers: Vec<M2M100DecoderLayer> = vec![];
let p_layers = p / "layers";
for layer_index in 0..config.decoder_layers {
layers.push(M2M100DecoderLayer::new(&p_layers / layer_index, config));
}
M2M100Decoder {
dropout,
layer_norm,
layers,
embed_positions,
output_attentions,
output_hidden_states,
output_past,
scale_embedding,
}
}
pub fn forward_t(
&self,
input_ids: &Tensor,
encoder_hidden_states: &Tensor,
encoder_attention_mask: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> M2M100DecoderOutput {
let past_key_values_length = if let Some(old_layer_states_values) = &old_layer_states {
if let Some(old_value_state) = &old_layer_states_values[0].0 {
old_value_state.prev_key.size()[2]
} else {
0
}
} else {
0
};
let input_shape = input_ids.size();
let sequence_length = input_shape[1];
let x = input_ids.apply(embeddings) * self.scale_embedding;
let positions = self
.embed_positions
.forward(input_ids, past_key_values_length, x.kind());
let x = x + positions;
let causal_mask = if sequence_length > 1 {
Some(_make_causal_mask(
input_ids.size().as_slice(),
x.kind(),
x.device(),
past_key_values_length,
))
} else {
None
};
let decoder_attention_mask = decoder_attention_mask.map(|attention_mask| {
if let Some(causal_mask) = causal_mask {
causal_mask + _expand_mask(attention_mask, Some(sequence_length), x.kind())
} else {
_expand_mask(attention_mask, Some(sequence_length), x.kind())
}
});
let encoder_attention_mask = encoder_attention_mask
.map(|mask| _expand_mask(mask, Some(*input_ids.size().last().unwrap()), x.kind()));
let mut hidden_state = x.apply_t(&self.dropout, train);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(Vec::with_capacity(self.layers.len()))
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(Vec::with_capacity(self.layers.len()))
} else {
None
};
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
if self.output_past {
if old_layer_states.is_some() {
old_layer_states
} else {
Some(vec![(None, None); self.layers.len()])
}
} else {
None
};
let mut attention_weights: Option<Tensor>;
for (layer_idx, layer) in self.layers.iter().enumerate() {
let layer_state = match &next_decoder_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None),
};
let temp = layer.forward_t(
&hidden_state,
encoder_hidden_states,
encoder_attention_mask.as_ref(),
decoder_attention_mask.as_ref(),
layer_state,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(value) = &mut next_decoder_cache {
value[layer_idx] = temp.2
};
}
M2M100DecoderOutput {
hidden_state: hidden_state.apply(&self.layer_norm),
encoder_attention_mask,
next_decoder_cache,
all_hidden_states,
all_attentions,
}
}
}
/// Container holding a M2M100 decoder output
pub type M2M100DecoderOutput = BartDecoderOutput;

131
src/m2m_100/embeddings.rs Normal file
View File

@ -0,0 +1,131 @@
// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Borrow;
use std::ops::Deref;
use std::sync::RwLock;
use tch::nn::embedding;
use tch::{nn, Device, Kind, Tensor};
#[derive(Debug)]
pub struct SinusoidalPositionalEmbedding {
embedding: RwLock<nn::Embedding>,
embedding_dim: i64,
padding_idx: i64,
offset: i64,
}
impl SinusoidalPositionalEmbedding {
pub fn new<'p, P>(
p: P,
num_embeddings: i64,
embedding_dim: i64,
padding_idx: i64,
) -> SinusoidalPositionalEmbedding
where
P: Borrow<nn::Path<'p>>,
{
let device = p.borrow().device();
let mut local_varstore = nn::VarStore::new(device);
let offset = 2;
let mut embedding = embedding(
local_varstore.root(),
num_embeddings + offset,
embedding_dim,
Default::default(),
);
embedding
.ws
.set_data(&SinusoidalPositionalEmbedding::build_positional_embeddings(
num_embeddings + offset,
embedding_dim,
padding_idx,
device,
));
local_varstore.freeze();
SinusoidalPositionalEmbedding {
embedding: RwLock::new(embedding),
embedding_dim,
padding_idx,
offset,
}
}
fn build_positional_embeddings(
num_embeddings: i64,
embedding_dim: i64,
padding_idx: i64,
device: Device,
) -> Tensor {
let half_dim = embedding_dim / 2;
let emb = -(10000f64.ln() as f64) / ((half_dim - 1) as f64);
let emb = (Tensor::arange(half_dim, (Kind::Float, device)) * emb).exp();
let emb =
Tensor::arange(num_embeddings, (Kind::Float, device)).unsqueeze(1) * emb.unsqueeze(0);
let mut sinusoidal_embedding =
Tensor::cat(&[&emb.sin(), &emb.cos()], 1).view([num_embeddings, -1]);
if embedding_dim % 2 == 1 {
sinusoidal_embedding = Tensor::cat(
&[
sinusoidal_embedding,
Tensor::zeros(&[num_embeddings, 1], (Kind::Float, device)),
],
1,
);
}
let _ = sinusoidal_embedding.select(0, padding_idx).fill_(0);
let _ = sinusoidal_embedding.requires_grad_(false);
sinusoidal_embedding
}
fn create_position_ids_from_input_ids(
&self,
input_ids: &Tensor,
past_key_values_length: i64,
) -> Tensor {
let mask = input_ids.ne(self.padding_idx).to_kind(Kind::Int64);
let incremental_indices = (mask.cumsum(1, Kind::Int64) + past_key_values_length) * mask;
incremental_indices + self.padding_idx
}
pub fn forward(&self, input_ids: &Tensor, past_key_values_length: i64, kind: Kind) -> Tensor {
let position_ids =
self.create_position_ids_from_input_ids(input_ids, past_key_values_length);
let input_size = input_ids.size();
let seq_length = input_size[1];
let max_pos = self.padding_idx + 1 + seq_length;
let current_size = self.embedding.read().unwrap().ws.size()[0];
if max_pos > current_size {
self.embedding.write().unwrap().ws.set_data(
&SinusoidalPositionalEmbedding::build_positional_embeddings(
max_pos + self.offset,
self.embedding_dim,
self.padding_idx,
input_ids.device(),
),
);
}
let current_kind = self.embedding.read().unwrap().ws.kind();
if current_kind != kind {
let new_embeddings = &self.embedding.read().unwrap().ws.to_kind(kind);
self.embedding.write().unwrap().ws.set_data(new_embeddings);
}
position_ids.apply(self.embedding.read().unwrap().deref())
}
}

Some files were not shown because too many files have changed in this diff Show More