ONNX Support (#346)

* Fixed Clippy warnings

* Revert "Shallow clone optimization (#243)"

This reverts commit ba584653bc.

* updated dependencies

* tryouts

* GPT2 tryouts

* WIP GPT2

* input mapping

* Cache storage

* Initial GPT2 prototype

* Initial ONNX Config and decoder implementation

* ONNXDecoder first draft

* Use Decoders in example

* Automated tch-ort conversion, decoder implementation

* ONNXCausalDecoder implementation

* Refactored _get_var_store to be optional, added get_device to gen trait

* updated example

* Added decoder_start_token_id to ConfigOption

* Addition of ONNXModelConfig, make max_position_embeddigs optional

* Addition of forward pass function for ONNXModel

* Working ONNX causal decoder

* Simplify tensor conversion

* refactor translation to facilitate ONNX integration

* Implementation of ONNXEncoder

* Implementation of ONNXConditionalGenerator

* working ONNXCausalGenerator

* - Reworked model resources type for pipelines and generators

* Aligned ONNXConditionalGenerator with other generators to use GenerateConfig for creation

* Moved force_token_id_generation to common utils function, fixed tests, Translation implementation

* generalized forced_bos and forced_eos tokens generation

* Aligned the `encode_prompt_text` method across language models

* Fix prompt encoding for causal generation

* Fix prompt encoding for causal generation

* Support for ONNX models for SequenceClassification

* Support for ONNX models for TokenClassification

* Support for ONNX models for POS and NER pipelines

* Support for ONNX models for ZeroShotClassification pipeline

* Support for ONNX models for QuestionAnswering pipeline

* Support for ONNX models for MaskedLM pipeline

* Added token_type_ids , updated layer cache i/o parsing for ONNX pipelines

* Support for ONNX models for TextGenerationPipeline, updated examples for remote resources

* Remove ONNX zero-shot classification example (lack of correct pretrained model)

* Addition of tests for ONNX pipelines support

* Made onnx feature optional

* Fix device lookup with onnx feature enabled

* Updates from main branch

* Flexible tokenizer creation for M2M100 (NLLB support), make NLLB test optional du to their size

* Fixed Clippy warnings

* Addition of documentation for ONNX

* Added documentation for ONNX support

* upcoming tch 1.12 fixes

* Fix merge conflicts

* Fix merge conflicts (2)

* Add download libtorch feature to ONNX tests

* Add download-onnx feature

* attempt to enable onnx download

* add remote resources feature

* onnx download

* pin ort version

* Update ort version
This commit is contained in:
guillaume-be 2023-05-30 07:20:25 +01:00 committed by GitHub
parent 81cde55b25
commit 540c9268e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
100 changed files with 4342 additions and 1537 deletions

View File

@ -140,6 +140,24 @@ jobs:
--test nllb
--features download-libtorch
test-onnx:
name: Integration tests (ONNX models)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- uses: actions-rs/cargo@v1
with:
command: test
args: --package rust-bert
--features onnx
--test onnx
--features download-libtorch
convert-model:
name: Model conversion test
runs-on: ubuntu-latest

View File

@ -7,11 +7,22 @@ All notable changes to this project will be documented in this file. The format
- Addition of `add_tokens` and `add_extra_ids` interface methods to the `TokenizerOption`. Allow building most pipeline with custom tokenizer via `new_with_tokenizer`.
- Addition of `get_tokenizer` and `get_tokenizer_mut` methods to all pipelines allowing to get a (mutable) reference to the pipeline tokenizer.
- Addition of a `get_embedding_dim` method to get the dimension of the embeddings for sentence embeddings pipelines
- `get_vocab_size`, `get_decoder_start_token_id` and `get_prefix_and_forced_bos_id` for the `TokenizerOption` in pipelines
- Addition of the [GPT-J](https://www.eleuther.ai/artifacts/gpt-j) model architecture
- Addition of the [NLLB](https://arxiv.org/abs/2207.04672) model architecture and pretrained weights
- Addition of support for ONNX models (encoder, decoders, encoder-decoders) via the [ort](https://github.com/pykeio/ort) onnxruntime bindings
- Integration of ONNX models to the sequence classification, token classification, question answering, zero-shot classification, text generation, summarization and translation pipelines
## Changed
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer
- (BREAKING) Simplified the generation traits (removal of LMHeadModel and elimination of unnecessary specification for LanguageGenerator)
- (BREAKING) Upgraded to `torch` 2.0 (via `tch` 0.13.0). The process to automatically download the dependencies have changed, it must now be enabled via the `download-libtorch` feature flag.
- Read the `decoder_start_token_id` from the provided configuration rather than using a hard-coded default value
- (BREAKING) Changed the return type of the `LanguageGenerator` and pipelines functions `float`, `half`, `set_device` to `Result<(), RustBertError>` as these become fallible for ONNX models
- (BREAKING) Wrapped the model resources specification for the pipeline `Config` objects into an `Enum` to allow handling both torch-based and ONNX models.
The `model_resources` field now needs to be wrapped in the corresponding enum variant, e.g. `model_resources: ModelResources::TORCH(model_resource)` for Torch-based models
- (BREAKING) Added the `forced_bos_token_id` and `forced_eos_token_id` fields to text generation models.
If these are not None, this will trigger a forced BOS/EOS token generation at the first of `max_length` positions (aligns with the Pytorch Transformers library)
## Fixed
- MIN/MAX computation for float-like (was set to infinity instead of min/max)

View File

@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.20.1-alpha"
version = "0.21.0-alpha"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and language models"
@ -66,6 +66,7 @@ doc-only = ["tch/doc-only"]
all-tests = []
remote = ["cached-path", "dirs", "lazy_static"]
download-libtorch = ["tch/download-libtorch"]
onnx = ["ort", "ndarray"]
[package.metadata.docs.rs]
features = ["doc-only"]
@ -84,6 +85,8 @@ regex = "1.6"
cached-path = { version = "0.6", optional = true }
dirs = { version = "4", optional = true }
lazy_static = { version = "1", optional = true }
ort = {version="1.14.8", optional = true, default-features = false, features = ["half"]}
ndarray = {version="0.15", optional = true}
[dev-dependencies]
anyhow = "1"
@ -93,3 +96,5 @@ tokio = { version = "1.24", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "0.13.0"
tempfile = "3"
itertools = "0.10"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
ort = {version="1.14.8", features = ["load-dynamic"]}

View File

@ -5,7 +5,7 @@
[![Documentation](https://docs.rs/rust-bert/badge.svg)](https://docs.rs/rust-bert)
![License](https://img.shields.io/crates/l/rust_bert.svg)
Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using the [tch-rs](https://github.com/LaurentMazare/tch-rs) crate and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using [tch-rs](https://github.com/LaurentMazare/tch-rs) or [onnxruntime bindings](https://github.com/pykeio/ort) and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
This repository exposes the model base architecture, task-specific heads (see below) and [ready-to-use pipelines](#ready-to-use-pipelines). [Benchmarks](#benchmarks) are available at the end of this document.
Get started with tasks including question answering, named entity recognition, translation, summarization, text generation, conversational agents and more in just a few lines of code:
@ -35,6 +35,7 @@ The tasks currently supported include:
- Language Generation
- Masked Language Model
- Sentence Embeddings
- Keywords extraction
<details>
<summary> <b>Expand to display the supported models/tasks matrix </b> </summary>
@ -51,10 +52,12 @@ RoBERTa|✅|✅|✅| | | |✅| ✅|
GPT| | | |✅ | | | | |
GPT2| | | |✅ | | | | |
GPT-Neo| | | |✅ | | | | |
GPT-J| | | |✅ | | | | |
BART|✅| | |✅ |✅| | | |
Marian| | | | | |✅| | |
MBart|✅| | |✅ | | | | |
M2M100| | | |✅ | | | | |
NLLB| | | |✅ | | | | |
Electra | |✅| | | | |✅| |
ALBERT |✅|✅|✅| | | |✅| ✅ |
T5 | | | |✅ |✅|✅| | ✅ |
@ -116,6 +119,32 @@ cd rust-bert
cargo run --example sentence_embeddings
```
## ONNX Support (Optional)
ONNX support can be enabled via the optional `onnx` feature. This crate then leverages the [ort](https://github.com/pykeio/ort) crate with bindings to the onnxruntime C++ library. We refer the user to this page project for further installation instructions/support.
1. Enable the optional `onnx` feature. The `rust-bert` crate does not include any optional dependencies for `ort`, the end user should select the set of features that would be adequate for pulling the required `onnxruntime` C++ library.
2. The current recommended installation is to use dynamic linking by pointing to an existing library location. Use the `load-dynamic` cargo feature for `ort`.
3. set the `ORT_DYLIB_PATH` to point to the location of downloaded onnxruntime library (`onnxruntime.dll`/`libonnxruntime.so`/`libonnxruntime.dylib` depending on the operating system). These can be downloaded from the [release page](https://github.com/microsoft/onnxruntime/releases) of the onnxruntime project
Most architectures (including encoders, decoders and encoder-decoders) are supported. the library aims at keeping compatibility with models exported using the [optimum](https://github.com/huggingface/optimum) library. A detailed guide on how to export a Transformer model to ONNX using optimum is available at https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model
The resources used to create ONNX models are similar to those based on Pytorch, replacing the pytorch by the ONNX model. Since ONNX models are less flexible than their Pytorch counterparts in the handling of optional arguments, exporting a decoder or encoder-decoder model to ONNX will usually result in multiple files. These files are expected (but not all are necessary) for use in this library as per the table below:
| Architecture | Encoder file | Decoder without past file | Decoder with past file |
|-----------------------------|---------------|---------------------------|-------------------------|
| Encoder (e.g. BERT) | required | not used | not used |
| Decoder (e.g. GPT2) | not used | required | optional |
| Encoder-decoder (e.g. BART) | required | required | optional |
Note that the computational efficiency will drop when the `decoder with past` file is optional but not provided
since the model will not used cached past keys and values for the attention mechanism, leading to a high number of
redundant computations. The Optimum library offers export options to ensure such a `decoder with past` model file is created.
he base encoder and decoder model architecture are available (and exposed for convenience) in the `encoder` and `decoder` modules, respectively.
Generation models (pure decoder or encoder/decoder architectures) are available in the `models` module.
ost pipelines are available for ONNX model checkpoints, including sequence classification, zero-shot classification,
token classification (including named entity recognition and part-of-speech tagging), question answering, text generation, summarization and translation.
These models use the same configuration and tokenizer files as their Pytorch counterparts when used in a pipeline. Examples leveraging ONNX models are given in the `./examples` directory
## Ready-to-use pipelines
Based on Hugging Face's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:

View File

@ -5,7 +5,7 @@ use criterion::{black_box, Criterion};
use rust_bert::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;
use std::time::{Duration, Instant};
@ -14,7 +14,9 @@ use tch::Device;
fn create_text_generation_model() -> TextGenerationModel {
let config = TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
Gpt2ModelResources::GPT2,
))),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Some(Box::new(RemoteResource::from_pretrained(

View File

@ -3,7 +3,7 @@ extern crate criterion;
use criterion::{black_box, Criterion};
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::question_answering::{
squad_processor, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
@ -17,7 +17,9 @@ static BATCH_SIZE: usize = 64;
fn create_qa_model() -> QuestionAnsweringModel {
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT_QA,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
None, //merges resource only relevant with ModelType::Roberta
@ -52,7 +54,9 @@ fn qa_load_model(iters: u64) -> Duration {
let start = Instant::now();
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT_QA,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
None, //merges resource only relevant with ModelType::Roberta

View File

@ -17,6 +17,7 @@ use std::sync::{Arc, RwLock};
use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::pipelines::common::ModelResource;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{BufferResource, RemoteResource, ResourceProvider};
use tch::Device;
@ -80,7 +81,7 @@ fn config(device: Device, model_data: Arc<RwLock<Vec<u8>>>) -> SummarizationConf
let merges_resource = Box::new(RemoteResource::from_pretrained(
BartMergesResources::DISTILBART_CNN_6_6,
));
let model_resource = Box::new(BufferResource { data: model_data });
let model_resource = ModelResource::Torch(Box::new(BufferResource { data: model_data }));
SummarizationConfig {
model_resource,

View File

@ -12,7 +12,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::pipelines::sequence_classification::{
SequenceClassificationConfig, SequenceClassificationModel,
@ -26,7 +26,9 @@ fn main() -> anyhow::Result<()> {
// Language identification
let sequence_classification_config = SequenceClassificationConfig::new(
ModelType::Roberta,
RemoteResource::from_pretrained(RobertaModelResources::CODEBERTA_LANGUAGE_ID),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
RobertaModelResources::CODEBERTA_LANGUAGE_ID,
))),
RemoteResource::from_pretrained(RobertaConfigResources::CODEBERTA_LANGUAGE_ID),
RemoteResource::from_pretrained(RobertaVocabResources::CODEBERTA_LANGUAGE_ID),
Some(RemoteResource::from_pretrained(
@ -56,7 +58,9 @@ fn main() -> anyhow::Result<()> {
// Masked language model
let config = MaskedLanguageConfig::new(
ModelType::Roberta,
RemoteResource::from_pretrained(RobertaModelResources::CODEBERT_MLM),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
RobertaModelResources::CODEBERT_MLM,
))),
RemoteResource::from_pretrained(RobertaConfigResources::CODEBERT_MLM),
RemoteResource::from_pretrained(RobertaVocabResources::CODEBERT_MLM),
Some(RemoteResource::from_pretrained(

View File

@ -17,7 +17,7 @@ extern crate anyhow;
use rust_bert::gpt_neo::{
GptNeoConfigResources, GptNeoMergesResources, GptNeoModelResources, GptNeoVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
));
let generate_config = TextGenerationConfig {
model_type: ModelType::GPTNeo,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -53,7 +53,7 @@ fn main() -> anyhow::Result<()> {
};
let mut model = TextGenerationModel::new(generate_config)?;
model.set_device(Device::cuda_if_available());
model.set_device(Device::cuda_if_available())?;
let input_context_1 = "It was a very nice and sunny";
let input_context_2 = "It was a gloom winter night, and";

View File

@ -1,7 +1,7 @@
use std::path::PathBuf;
use rust_bert::gpt_j::{GptJConfigResources, GptJMergesResources, GptJVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{LocalResource, RemoteResource};
use tch::Device;
@ -67,7 +67,7 @@ fn main() -> anyhow::Result<()> {
let generation_config = TextGenerationConfig {
model_type: ModelType::GPTJ,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),

View File

@ -14,7 +14,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::reformer::{
ReformerConfigResources, ReformerModelResources, ReformerVocabResources,
@ -35,7 +35,7 @@ fn main() -> anyhow::Result<()> {
));
let generate_config = TextGenerationConfig {
model_type: ModelType::Reformer,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: None,

View File

@ -14,7 +14,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
let generate_config = TextGenerationConfig {
model_type: ModelType::XLNet,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: None,

View File

@ -12,14 +12,16 @@
extern crate anyhow;
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
// Set-up model
let config = MaskedLanguageConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT),
RemoteResource::from_pretrained(BertVocabResources::BERT),
None,

View File

@ -0,0 +1,36 @@
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let masked_lm = MaskedLanguageModel::new(MaskedLanguageConfig::new(
ModelType::Bert,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/model.onnx",
"onnx-bert-base-uncased-for-masked-lm",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/config.json",
"onnx-bert-base-uncased-for-masked-lm",
),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/vocab.txt",
"onnx-bert-base-uncased-for-masked-lm",
),
None,
false,
None,
None,
Some(String::from("<mask>")),
))?;
let input = [
"Hello I am a <mask> student",
"Paris is the <mask> of France. It is <mask> in Europe.",
];
let output = masked_lm.predict(input)?;
println!("{:?}", output);
Ok(())
}

View File

@ -0,0 +1,40 @@
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let qa_model = QuestionAnsweringModel::new(QuestionAnsweringConfig::new(
ModelType::Roberta,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/model.onnx",
"onnx-roberta-base-squad2",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/config.json",
"onnx-roberta-base-squad2",
),
RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/vocab.json",
"onnx-roberta-base-squad2",
),
Some(RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/merges.txt",
"onnx-roberta-base-squad2",
)),
false,
None,
None,
))?;
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let output = qa_model.predict(&[qa_input], 1, 32);
println!("{:?}", output);
Ok(())
}

View File

@ -0,0 +1,37 @@
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::sentiment::SentimentModel;
use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let classification_model = SentimentModel::new(SequenceClassificationConfig::new(
ModelType::DistilBert,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/model.onnx",
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json",
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
),
RemoteResource::new(
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/vocab.txt",
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
),
None,
true,
None,
None,
))?;
let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
let output = classification_model.predict(input);
println!("{:?}", output);
Ok(())
}

View File

@ -0,0 +1,42 @@
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let text_generation_model = TextGenerationModel::new(TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource: ModelResource::ONNX(ONNXModelResources {
encoder_resource: None,
decoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/gpt2/resolve/main/decoder_model.onnx",
"onnx-gpt2",
))),
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/gpt2/resolve/main/decoder_with_past_model.onnx",
"onnx-gpt2",
))),
}),
config_resource: Box::new(RemoteResource::new(
"https://huggingface.co/optimum/gpt2/resolve/main/config.json",
"onnx-gpt2",
)),
vocab_resource: Box::new(RemoteResource::new(
"https://huggingface.co/gpt2/resolve/main/vocab.json",
"onnx-gpt2",
)),
merges_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/gpt2/resolve/main/merges.txt",
"onnx-gpt2",
))),
max_length: Some(30),
do_sample: false,
num_beams: 1,
temperature: 1.0,
num_return_sequences: 1,
..Default::default()
})?;
let prompts = ["It was a very nice and sunny"];
let output = text_generation_model.generate(&prompts, None);
println!("{:?}", output);
Ok(())
}

View File

@ -0,0 +1,36 @@
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig,
};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let token_classification_model = NERModel::new(TokenClassificationConfig::new(
ModelType::Bert,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/bert-base-NER/resolve/main/model.onnx",
"onnx-bert-base-NER",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-NER/resolve/main/config.json",
"onnx-bert-base-NER",
),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-NER/resolve/main/vocab.txt",
"onnx-bert-base-NER",
),
None,
false,
None,
None,
LabelAggregationOption::First,
))?;
let input = ["Asked John Smith about Acme Corp", "Let's go to New York!"];
let output = token_classification_model.predict_full_entities(&input);
println!("{:?}", output);
Ok(())
}

View File

@ -0,0 +1,63 @@
use rust_bert::m2m_100::{M2M100SourceLanguages, M2M100TargetLanguages};
use tch::Device;
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
let translation_model = TranslationModel::new(TranslationConfig::new(
ModelType::M2M100,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/encoder_model.onnx",
"onnx-m2m100_418M",
))),
decoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_model.onnx",
"onnx-m2m100_418M",
))),
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_with_past_model.onnx",
"onnx-m2m100_418M",
))),
}),
RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/config.json",
"onnx-m2m100_418M",
),
RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/vocab.json",
"onnx-m2m100_418M",
),
Some(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/sentencepiece.bpe.model",
"onnx-m2m100_418M",
)),
M2M100SourceLanguages::M2M100_418M,
M2M100TargetLanguages::M2M100_418M,
Device::cuda_if_available(),
))?;
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(translation_model.translate(
&[source_sentence],
Language::English,
Language::French,
)?);
outputs.extend(translation_model.translate(
&[source_sentence],
Language::English,
Language::Spanish,
)?);
outputs.extend(translation_model.translate(
&[source_sentence],
Language::English,
Language::Hindi,
)?);
println!("{:?}", outputs);
Ok(())
}

View File

@ -13,7 +13,7 @@
extern crate anyhow;
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
@ -23,7 +23,9 @@ fn main() -> anyhow::Result<()> {
// Set-up Question Answering model
let config = QuestionAnsweringConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT_QA,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
None, //merges resource only relevant with ModelType::Roberta

View File

@ -16,7 +16,7 @@ use rust_bert::longformer::{
LongformerConfigResources, LongformerMergesResources, LongformerModelResources,
LongformerVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
@ -26,7 +26,9 @@ fn main() -> anyhow::Result<()> {
// Set-up Question Answering model
let config = QuestionAnsweringConfig::new(
ModelType::Longformer,
RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_SQUAD1),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
LongformerModelResources::LONGFORMER_BASE_SQUAD1,
))),
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_SQUAD1),
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_SQUAD1),
Some(RemoteResource::from_pretrained(

View File

@ -13,7 +13,7 @@
extern crate anyhow;
use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel};
use rust_bert::resources::RemoteResource;
@ -25,9 +25,9 @@ fn main() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained(
FNetVocabResources::BASE_SST2,
));
let model_resource = Box::new(RemoteResource::from_pretrained(
let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
FNetModelResources::BASE_SST2,
));
)));
let sentiment_config = SentimentConfig {
model_type: ModelType::FNet,

View File

@ -15,6 +15,7 @@ extern crate anyhow;
use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::pipelines::common::ModelResource;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -34,7 +35,7 @@ fn main() -> anyhow::Result<()> {
));
let summarization_config = SummarizationConfig {
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),

View File

@ -13,7 +13,7 @@
extern crate anyhow;
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -31,7 +31,7 @@ fn main() -> anyhow::Result<()> {
let summarization_config = SummarizationConfig {
model_type: ModelType::Pegasus,
model_resource: weights_resource,
model_resource: ModelResource::Torch(weights_resource),
config_resource,
vocab_resource,
merges_resource: None,

View File

@ -12,7 +12,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::prophetnet::{
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
let summarization_config = SummarizationConfig {
model_type: ModelType::ProphetNet,
model_resource: weights_resource,
model_resource: ModelResource::Torch(weights_resource),
config_resource,
vocab_resource,
merges_resource: None,

View File

@ -12,7 +12,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::RemoteResource;
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
@ -24,7 +24,7 @@ fn main() -> anyhow::Result<()> {
let summarization_config = SummarizationConfig::new(
ModelType::T5,
weights_resource,
ModelResource::Torch(Box::new(weights_resource)),
config_resource,
vocab_resource,
None,

View File

@ -11,7 +11,7 @@
// limitations under the License.
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig,
@ -22,7 +22,9 @@ fn main() -> anyhow::Result<()> {
// Load a configuration
let config = TokenClassificationConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT_NER),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
None, //merges resource only relevant with ModelType::Roberta

View File

@ -16,7 +16,7 @@ use rust_bert::m2m_100::{
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages,
M2M100TargetLanguages, M2M100VocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -32,7 +32,7 @@ fn main() -> anyhow::Result<()> {
let translation_config = TranslationConfig::new(
ModelType::M2M100,
model_resource,
ModelResource::Torch(Box::new(model_resource)),
config_resource,
vocab_resource,
Some(merges_resource),

View File

@ -17,7 +17,7 @@ use rust_bert::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
let translation_config = TranslationConfig::new(
ModelType::Marian,
model_resource,
ModelResource::Torch(Box::new(model_resource)),
config_resource,
vocab_resource,
Some(merges_resource),

View File

@ -16,7 +16,7 @@ use rust_bert::mbart::{
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
MBartVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -32,7 +32,7 @@ fn main() -> anyhow::Result<()> {
let translation_config = TranslationConfig::new(
ModelType::MBart,
model_resource,
ModelResource::Torch(Box::new(model_resource)),
config_resource,
vocab_resource,
None,

View File

@ -12,7 +12,7 @@
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
let translation_config = TranslationConfig::new(
ModelType::T5,
model_resource,
ModelResource::Torch(Box::new(model_resource)),
config_resource,
vocab_resource,
None,

View File

@ -23,7 +23,6 @@ use crate::pipelines::generation_utils::private_generation_utils::{
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::TruncationStrategy;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
@ -100,12 +99,12 @@ impl BartConfigResources {
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
"distilbart-cnn-6-6/config",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/config.json",
"https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
"distilbart-cnn-12-6/config",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/config.json",
"https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/config.json",
);
}
@ -133,12 +132,12 @@ impl BartVocabResources {
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
"distilbart-cnn-6-6/vocab",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/vocab.json",
"https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/vocab.json",
);
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
"distilbart-cnn-12-6/vocab",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/vocab.json",
"https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/vocab.json",
);
}
@ -166,12 +165,12 @@ impl BartMergesResources {
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
"distilbart-cnn-6-6/merges",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/merges.txt",
"https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/merges.txt",
);
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
"distilbart-cnn-12-6/merges",
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/merges.txt",
"https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/merges.txt",
);
}
@ -197,6 +196,8 @@ pub struct BartConfig {
pub encoder_layers: i64,
pub bos_token_id: Option<i64>,
pub eos_token_id: Option<i64>,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
pub pad_token_id: Option<i64>,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
@ -240,6 +241,8 @@ impl Default for BartConfig {
bos_token_id: Some(0),
eos_token_id: Some(2),
pad_token_id: Some(1),
forced_bos_token_id: Some(0),
forced_eos_token_id: Some(2),
id2label: None,
label2id: None,
init_std: 0.02,
@ -918,6 +921,8 @@ pub struct BartGenerator {
generate_config: GenerateConfig,
bos_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
forced_bos_token_id: Option<i64>,
forced_eos_token_id: Option<i64>,
pad_token_id: Option<i64>,
is_encoder_decoder: bool,
vocab_size: i64,
@ -1006,10 +1011,12 @@ impl BartGenerator {
Some(value) => vec![value],
None => vec![2],
});
let forced_bos_token_id = config.forced_bos_token_id;
let forced_eos_token_id = config.forced_eos_token_id;
let pad_token_id = Some(config.pad_token_id.unwrap_or(1));
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = Some(2);
let decoder_start_id = config.decoder_start_token_id;
let max_position_embeddings = config.max_position_embeddings;
Ok(BartGenerator {
@ -1019,6 +1026,8 @@ impl BartGenerator {
generate_config,
bos_token_id,
eos_token_ids,
forced_bos_token_id,
forced_eos_token_id,
pad_token_id,
is_encoder_decoder,
vocab_size,
@ -1026,14 +1035,6 @@ impl BartGenerator {
max_position_embeddings,
})
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
}
}
impl PrivateLanguageGenerator for BartGenerator {
@ -1043,11 +1044,11 @@ impl PrivateLanguageGenerator for BartGenerator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -1058,6 +1059,12 @@ impl PrivateLanguageGenerator for BartGenerator {
fn get_eos_ids(&self) -> Option<&Vec<i64>> {
self.eos_token_ids.as_ref()
}
fn get_forced_bos_token_id(&self) -> Option<i64> {
self.forced_bos_token_id
}
fn get_forced_eos_token_id(&self) -> Option<i64> {
self.forced_eos_token_id
}
fn get_pad_id(&self) -> Option<i64> {
self.pad_token_id
}
@ -1070,8 +1077,8 @@ impl PrivateLanguageGenerator for BartGenerator {
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(
@ -1119,25 +1126,6 @@ impl PrivateLanguageGenerator for BartGenerator {
})
}
fn prepare_scores_for_generation(
&self,
scores: &mut Tensor,
current_length: i64,
max_length: Option<i64>,
forced_bos_token_id: Option<i64>,
) {
if current_length == 1 {
self.force_token_id_generation(
scores,
&[forced_bos_token_id.unwrap_or_else(|| self.get_bos_id().unwrap())],
);
} else if let Some(max_length) = max_length {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
}
}
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
Some(self.model.encode(input_ids, attention_mask))
}
@ -1170,48 +1158,6 @@ impl PrivateLanguageGenerator for BartGenerator {
}
}
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(
&self,
past: &mut Cache,

View File

@ -1,3 +1,7 @@
#[cfg(feature = "onnx")]
use ndarray::ShapeError;
#[cfg(feature = "onnx")]
use ort::OrtError;
use rust_tokenizers::error::TokenizerError;
use tch::TchError;
use thiserror::Error;
@ -23,6 +27,14 @@ pub enum RustBertError {
#[error("Value error: {0}")]
ValueError(String),
#[error("Value error: {0}")]
#[cfg(feature = "onnx")]
OrtError(String),
#[error("Value error: {0}")]
#[cfg(feature = "onnx")]
NdArrayError(String),
#[error("Unsupported operation")]
UnsupportedError,
}
@ -44,3 +56,16 @@ impl From<TchError> for RustBertError {
RustBertError::TchError(error.to_string())
}
}
#[cfg(feature = "onnx")]
impl From<OrtError> for RustBertError {
fn from(error: OrtError) -> Self {
RustBertError::OrtError(error.to_string())
}
}
#[cfg(feature = "onnx")]
impl From<ShapeError> for RustBertError {
fn from(error: ShapeError) -> Self {
RustBertError::NdArrayError(error.to_string())
}
}

View File

@ -4,6 +4,7 @@ use std::path::PathBuf;
use std::sync::{Arc, RwLock};
/// # In-memory raw buffer resource
#[derive(Debug)]
pub struct BufferResource {
/// The data representing the underlying resource
pub data: Arc<RwLock<Vec<u8>>>,

View File

@ -3,7 +3,7 @@ use crate::resources::{Resource, ResourceProvider};
use std::path::PathBuf;
/// # Local resource
#[derive(PartialEq, Eq, Clone)]
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct LocalResource {
/// Local path for the resource
pub local_path: PathBuf,

View File

@ -25,6 +25,7 @@ mod local;
use crate::common::error::RustBertError;
pub use buffer::BufferResource;
pub use local::LocalResource;
use std::fmt::Debug;
use std::ops::DerefMut;
use std::path::PathBuf;
use std::sync::RwLockWriteGuard;
@ -37,7 +38,7 @@ pub enum Resource<'a> {
/// # Resource Trait that can provide the location or data for the model, and location of
/// configuration or vocabulary resources
pub trait ResourceProvider: Send + Sync {
pub trait ResourceProvider: Debug + Send + Sync {
/// Provides the local path for a resource.
///
/// # Returns

View File

@ -6,7 +6,7 @@ use lazy_static::lazy_static;
use std::path::PathBuf;
/// # Remote resource that will be downloaded and cached locally on demand
#[derive(PartialEq, Eq, Clone)]
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct RemoteResource {
/// Remote path/url for the resource
pub url: String,

View File

@ -88,6 +88,7 @@ pub struct FNetConfig {
pub pad_token_id: Option<i64>,
pub bos_token_id: Option<i64>,
pub eos_token_id: Option<i64>,
pub decoder_start_token_id: Option<i64>,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
pub output_attentions: Option<bool>,
@ -112,6 +113,7 @@ impl Default for FNetConfig {
pad_token_id: Some(3),
bos_token_id: Some(1),
eos_token_id: Some(2),
decoder_start_token_id: None,
id2label: None,
label2id: None,
output_attentions: None,

View File

@ -26,7 +26,7 @@ use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Kind, Tensor};
use tch::{nn, Device, Kind, Tensor};
/// # GPT2 Pretrained model weight files
pub struct Gpt2ModelResources;
@ -194,6 +194,9 @@ pub struct Gpt2Config {
pub output_hidden_states: Option<bool>,
pub resid_pdrop: Option<f64>,
pub vocab_size: i64,
pub decoder_start_token_id: Option<i64>,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
}
impl Config for Gpt2Config {}
@ -218,6 +221,9 @@ impl Default for Gpt2Config {
output_hidden_states: None,
resid_pdrop: Some(0.1),
vocab_size: 50257,
decoder_start_token_id: None,
forced_bos_token_id: None,
forced_eos_token_id: None,
}
}
}
@ -654,7 +660,7 @@ impl GPT2Generator {
let max_position_embeddings = config.n_positions;
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = None;
let decoder_start_id = config.decoder_start_token_id;
Ok(GPT2Generator {
model,
@ -679,11 +685,11 @@ impl PrivateLanguageGenerator for GPT2Generator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -706,8 +712,8 @@ impl PrivateLanguageGenerator for GPT2Generator {
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(

View File

@ -135,6 +135,9 @@ pub struct GptJConfig {
pub use_float16: bool,
#[serde(default = "default_preload_on_cpu")]
pub preload_on_cpu: bool,
pub decoder_start_token_id: Option<i64>,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
}
impl Config for GptJConfig {}
@ -163,6 +166,9 @@ impl Default for GptJConfig {
scale_attn_weights: Some(true),
use_float16: default_use_float16(),
preload_on_cpu: default_preload_on_cpu(),
decoder_start_token_id: None,
forced_bos_token_id: None,
forced_eos_token_id: None,
}
}
}
@ -630,7 +636,7 @@ impl GptJGenerator {
let max_position_embeddings = config.n_positions;
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = None;
let decoder_start_id = config.decoder_start_token_id;
Ok(GptJGenerator {
model,
@ -655,11 +661,11 @@ impl PrivateLanguageGenerator for GptJGenerator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -682,8 +688,8 @@ impl PrivateLanguageGenerator for GptJGenerator {
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(

View File

@ -22,7 +22,7 @@ use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, L
use crate::{Activation, Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Kind, Tensor};
use tch::{nn, Device, Kind, Tensor};
/// # GPT-Neo Pretrained model weight files
pub struct GptNeoModelResources;
@ -127,6 +127,8 @@ pub struct GptNeoConfig {
pub intermediate_size: Option<i64>,
pub bos_token_id: i64,
pub eos_token_id: i64,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
pub vocab_size: i64,
pub num_layers: i64,
pub num_heads: i64,
@ -140,6 +142,7 @@ pub struct GptNeoConfig {
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub resid_dropout: f64,
pub decoder_start_token_id: Option<i64>,
}
impl Config for GptNeoConfig {}
@ -162,6 +165,8 @@ impl Default for GptNeoConfig {
intermediate_size: None,
bos_token_id: 50256,
eos_token_id: 50256,
forced_bos_token_id: None,
forced_eos_token_id: None,
vocab_size: 50257,
num_layers: 24,
num_heads: 16,
@ -175,6 +180,7 @@ impl Default for GptNeoConfig {
output_attentions: None,
output_hidden_states: None,
resid_dropout: 0.0,
decoder_start_token_id: None,
}
}
}
@ -673,7 +679,7 @@ impl GptNeoGenerator {
let pad_token_id = tokenizer.get_pad_id();
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = None;
let decoder_start_id = config.decoder_start_token_id;
let max_position_embeddings = config.max_position_embeddings;
Ok(GptNeoGenerator {
@ -699,11 +705,11 @@ impl PrivateLanguageGenerator for GptNeoGenerator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -727,8 +733,8 @@ impl PrivateLanguageGenerator for GptNeoGenerator {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(

View File

@ -26,6 +26,7 @@
//! use tch::Device;
//!
//! fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::common::ModelResource;
//! let config_resource = Box::new(RemoteResource::from_pretrained(
//! GptNeoConfigResources::GPT_NEO_1_3B,
//! ));
@ -41,7 +42,7 @@
//!
//! let text_generation_config = TextGenerationConfig {
//! model_type: ModelType::GPTNeo,
//! model_resource,
//! model_resource: ModelResource::Torch(model_resource),
//! config_resource,
//! vocab_resource,
//! merges_resource: Some(merges_resource),

View File

@ -1,6 +1,6 @@
//! # Ready-to-use NLP pipelines and Transformer-based models
//!
//! Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using the [tch-rs](https://github.com/LaurentMazare/tch-rs) crate and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
//! Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using [tch-rs](https://github.com/LaurentMazare/tch-rs) or [onnxruntime bindings](https://github.com/pykeio/ort) and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
//! This repository exposes the model base architecture, task-specific heads (see below) and [ready-to-use pipelines](#ready-to-use-pipelines). [Benchmarks](#benchmarks) are available at the end of this document.
//!
//! Get started with tasks including question answering, named entity recognition, translation, summarization, text generation, conversational agents and more in just a few lines of code:
@ -42,6 +42,7 @@
//! - Language Generation
//! - Sentence Embeddings
//! - Masked Language Model
//! - Keywords extraction
//!
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust
@ -61,10 +62,12 @@
//!GPT| | | |✅ | | | | |
//!GPT2| | | |✅ | | | | |
//!GPT-Neo| | | |✅ | | | | |
//!GPT-J| | | |✅ | | | | |
//!BART|✅| | |✅ |✅| | | |
//!Marian| | | | | |✅| | |
//!MBart|✅| | |✅ | | | | |
//!M2M100| | | |✅ | | | | |
//!NLLB| | | |✅ | | | | |
//!Electra | |✅| | | | |✅| |
//!ALBERT |✅|✅|✅| | | |✅| ✅ |
//!T5 | | | |✅ |✅|✅| | ✅ |
@ -109,6 +112,32 @@
//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu118`.
//! Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
//!
//! ## ONNX Support (Optional)
//!
//! ONNX support can be enabled via the optional `onnx` feature. This crate then leverages the [ort](https://github.com/pykeio/ort) crate with bindings to the onnxruntime C++ library. We refer the user to this page project for further installation instructions/support.
//! 1. Enable the optional `onnx` feature. The `rust-bert` crate does not include any optional dependencies for `ort`, the end user should select the set of features that would be adequate for pulling the required `onnxruntime` C++ library.
//! 2. The current recommended installation is to use dynamic linking by pointing to an existing library location. Use the `load-dynamic` cargo feature for `ort`.
//! 3. set the `ORT_DYLIB_PATH` to point to the location of downloaded onnxruntime library (`onnxruntime.dll`/`libonnxruntime.so`/`libonnxruntime.dylib` depending on the operating system). These can be downloaded from the [release page](https://github.com/microsoft/onnxruntime/releases) of the onnxruntime project
//!
//! Most architectures (including encoders, decoders and encoder-decoders) are supported. the library aims at keeping compatibility with models exported using the [optimum](https://github.com/huggingface/optimum) library. A detailed guide on how to export a Transformer model to ONNX using optimum is available at <https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model>
//! The resources used to create ONNX models are similar to those based on Pytorch, replacing the pytorch by the ONNX model. Since ONNX models are less flexible than their Pytorch counterparts in the handling of optional arguments, exporting a decoder or encoder-decoder model to ONNX will usually result in multiple files. These files are expected (but not all are necessary) for use in this library as per the table below:
//!
//! | Architecture | Encoder file | Decoder without past file | Decoder with past file |
//! -----------------------------|---------------|---------------------------|-------------------------
//! | Encoder (e.g. BERT) | required | not used | not used |
//! | Decoder (e.g. GPT2) | not used | required | optional |
//! | Encoder-decoder (e.g. BART) | required | required | optional |
//!
//! Note that the computational efficiency will drop when the `decoder with past` file is optional but not provided
//! since the model will not used cached past keys and values for the attention mechanism, leading to a high number of
//! redundant computations. The Optimum library offers export options to ensure such a `decoder with past` model file is created.
//! he base encoder and decoder model architecture are available (and exposed for convenience) in the `encoder` and `decoder` modules, respectively.
//!
//! Generation models (pure decoder or encoder/decoder architectures) are available in the `models` module.
//! ost pipelines are available for ONNX model checkpoints, including sequence classification, zero-shot classification,
//! token classification (including named entity recognition and part-of-speech tagging), question answering, text generation, summarization and translation.
//! These models use the same configuration and tokenizer files as their Pytorch counterparts when used in a pipeline. Examples leveraging ONNX models are given in the `./examples` directory. More information on these can be found in the [`onnx` module](./pipelines/onnx/index.html)
//!
//! # Ready-to-use pipelines
//!
//! Based on Hugging Face's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. More information on these can be found in the [`pipelines` module](./pipelines/index.html)

View File

@ -31,11 +31,12 @@
//!
//! fn main() -> anyhow::Result<()> {
//! // Set-up Question Answering model
//! let config = QuestionAnsweringConfig::new(
//! use rust_bert::pipelines::common::ModelResource;
//! let config = QuestionAnsweringConfig::new(
//! ModelType::Longformer,
//! RemoteResource::from_pretrained(
//! ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
//! LongformerModelResources::LONGFORMER_BASE_SQUAD1,
//! ),
//! ))),
//! RemoteResource::from_pretrained(
//! LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
//! ),

View File

@ -19,11 +19,10 @@ use crate::pipelines::generation_utils::private_generation_utils::{
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::t5::{FeedForwardProj, T5Config, T5ModelOutput, TaskSpecificParams};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::TruncationStrategy;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use tch::nn::{embedding, LinearConfig};
use tch::{nn, Tensor};
use tch::{nn, Device, Tensor};
/// # LongT5 Pretrained model weight files
pub struct LongT5ModelResources;
@ -79,6 +78,8 @@ pub struct LongT5Config {
pub decoder_start_token_id: Option<i64>,
pub bos_token_id: Option<i64>,
pub eos_token_id: Option<i64>,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
pub initializer_factor: f64,
pub is_encoder_decoder: Option<bool>,
pub layer_norm_epsilon: f64,
@ -112,6 +113,8 @@ impl Default for LongT5Config {
decoder_start_token_id: None,
bos_token_id: None,
eos_token_id: Some(1),
forced_bos_token_id: None,
forced_eos_token_id: None,
initializer_factor: 1.0,
is_encoder_decoder: None,
layer_norm_epsilon: 1e-6,
@ -145,6 +148,8 @@ impl From<&LongT5Config> for T5Config {
decoder_start_token_id: val.decoder_start_token_id,
bos_token_id: None,
eos_token_id: val.eos_token_id,
forced_bos_token_id: val.forced_bos_token_id,
forced_eos_token_id: val.forced_eos_token_id,
initializer_factor: val.initializer_factor,
is_encoder_decoder: val.is_encoder_decoder,
layer_norm_epsilon: val.layer_norm_epsilon,
@ -600,7 +605,7 @@ impl LongT5Generator {
let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = pad_token_id;
let decoder_start_id = config.decoder_start_token_id;
// longT5 do not have an embedding matrix for position IDs and relies on relative positions instead
let max_position_embeddings = i64::MAX;
@ -627,11 +632,11 @@ impl PrivateLanguageGenerator for LongT5Generator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -654,8 +659,8 @@ impl PrivateLanguageGenerator for LongT5Generator {
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(
@ -738,48 +743,6 @@ impl PrivateLanguageGenerator for LongT5Generator {
}
}
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(
&self,
past: &mut Cache,

View File

@ -14,17 +14,16 @@ use crate::m2m_100::decoder::M2M100Decoder;
use crate::m2m_100::encoder::M2M100Encoder;
use crate::m2m_100::LayerState;
use crate::mbart::{MBartConfig, MBartModelOutput};
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::common::TokenizerOption;
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::pipelines::translation::Language;
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::TruncationStrategy;
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
use tch::{nn, Device, Kind, Tensor};
/// # M2M100 Pretrained model weight files
pub struct M2M100ModelResources;
@ -522,7 +521,7 @@ impl M2M100Generator {
.get_local_path()?;
let tokenizer = TokenizerOption::from_file(
ModelType::M2M100,
generate_config.model_type,
vocab_path.to_str().unwrap(),
Some(merges_path.to_str().unwrap()),
false,
@ -555,7 +554,7 @@ impl M2M100Generator {
let pad_token_id = Some(config.pad_token_id.unwrap_or(1));
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = Some(2);
let decoder_start_id = config.decoder_start_token_id;
let max_position_embeddings = config.max_position_embeddings;
Ok(M2M100Generator {
@ -572,14 +571,6 @@ impl M2M100Generator {
max_position_embeddings,
})
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
}
}
impl PrivateLanguageGenerator for M2M100Generator {
@ -589,11 +580,11 @@ impl PrivateLanguageGenerator for M2M100Generator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -616,8 +607,8 @@ impl PrivateLanguageGenerator for M2M100Generator {
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(
@ -665,22 +656,6 @@ impl PrivateLanguageGenerator for M2M100Generator {
})
}
fn prepare_scores_for_generation(
&self,
scores: &mut Tensor,
current_length: i64,
max_length: Option<i64>,
forced_bos_token_id: Option<i64>,
) {
if current_length == 1 {
self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
} else if let Some(max_length) = max_length {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
}
}
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
Some(self.model.encode(input_ids, attention_mask))
}
@ -713,48 +688,6 @@ impl PrivateLanguageGenerator for M2M100Generator {
}
}
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(
&self,
past: &mut Cache,

View File

@ -14,15 +14,14 @@
use crate::bart::{BartConfig, BartModel, BartModelOutput, LayerState};
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
force_token_id_generation, PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::pipelines::translation::Language;
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::TruncationStrategy;
use std::borrow::Borrow;
use tch::nn::Init;
use tch::{nn, Kind, Tensor};
use tch::{nn, Device, Kind, Tensor};
/// # Marian Pretrained model weight files
pub struct MarianModelResources;
@ -773,10 +772,10 @@ impl MarianGenerator {
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id =
Some(tokenizer.get_pad_id().ok_or(RustBertError::TokenizerError(
"The tokenizer must contain a pad token ID to be used as BOS".to_string(),
))?);
let decoder_start_id = match config.decoder_start_token_id {
Some(start_token_id) => Some(start_token_id),
None => pad_token_id,
};
let max_position_embeddings = config.max_position_embeddings;
Ok(MarianGenerator {
@ -793,14 +792,6 @@ impl MarianGenerator {
max_position_embeddings,
})
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
}
}
impl PrivateLanguageGenerator for MarianGenerator {
@ -810,11 +801,11 @@ impl PrivateLanguageGenerator for MarianGenerator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -837,8 +828,8 @@ impl PrivateLanguageGenerator for MarianGenerator {
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(
@ -901,7 +892,11 @@ impl PrivateLanguageGenerator for MarianGenerator {
);
if let Some(max_length) = max_length {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
force_token_id_generation(
scores,
self.get_eos_ids().as_ref().unwrap(),
self.get_vocab_size(),
);
}
}
}
@ -938,48 +933,6 @@ impl PrivateLanguageGenerator for MarianGenerator {
}
}
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(
&self,
past: &mut Cache,

View File

@ -22,13 +22,12 @@ use crate::pipelines::generation_utils::private_generation_utils::{
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::pipelines::translation::Language;
use crate::{Activation, Config, RustBertError};
use rust_tokenizers::tokenizer::TruncationStrategy;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::kind::Kind::Int64;
use tch::nn::{embedding, EmbeddingConfig, Init};
use tch::{nn, Tensor};
use tch::{nn, Device, Tensor};
/// # MBART Pretrained model weight files
pub struct MBartModelResources;
@ -99,6 +98,7 @@ pub struct MBartConfig {
pub bos_token_id: Option<i64>,
pub eos_token_id: Option<i64>,
pub pad_token_id: Option<i64>,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
pub decoder_start_token_id: Option<i64>,
pub id2label: Option<HashMap<i64, String>>,
@ -138,6 +138,7 @@ impl Default for MBartConfig {
bos_token_id: Some(0),
eos_token_id: Some(2),
pad_token_id: Some(1),
forced_bos_token_id: None,
forced_eos_token_id: Some(2),
decoder_start_token_id: None,
id2label: None,
@ -725,6 +726,7 @@ pub struct MBartGenerator {
generate_config: GenerateConfig,
bos_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
forced_eos_token_id: Option<i64>,
pad_token_id: Option<i64>,
is_encoder_decoder: bool,
vocab_size: i64,
@ -805,10 +807,11 @@ impl MBartGenerator {
Some(value) => vec![value],
None => vec![2],
});
let forced_eos_token_id = config.forced_eos_token_id;
let pad_token_id = Some(config.pad_token_id.unwrap_or(1));
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = Some(2);
let decoder_start_id = config.decoder_start_token_id;
let max_position_embeddings = config.max_position_embeddings;
Ok(MBartGenerator {
@ -818,6 +821,7 @@ impl MBartGenerator {
generate_config,
bos_token_id,
eos_token_ids,
forced_eos_token_id,
pad_token_id,
is_encoder_decoder,
vocab_size,
@ -825,14 +829,6 @@ impl MBartGenerator {
max_position_embeddings,
})
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
}
}
impl PrivateLanguageGenerator for MBartGenerator {
@ -842,11 +838,11 @@ impl PrivateLanguageGenerator for MBartGenerator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -857,6 +853,9 @@ impl PrivateLanguageGenerator for MBartGenerator {
fn get_eos_ids(&self) -> Option<&Vec<i64>> {
self.eos_token_ids.as_ref()
}
fn get_forced_eos_token_id(&self) -> Option<i64> {
self.forced_eos_token_id
}
fn get_pad_id(&self) -> Option<i64> {
self.pad_token_id
}
@ -915,24 +914,8 @@ impl PrivateLanguageGenerator for MBartGenerator {
})
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
}
fn prepare_scores_for_generation(
&self,
scores: &mut Tensor,
current_length: i64,
max_length: Option<i64>,
forced_bos_token_id: Option<i64>,
) {
if current_length == 1 {
self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
} else if let Some(max_length) = max_length {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
}
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
@ -967,48 +950,6 @@ impl PrivateLanguageGenerator for MBartGenerator {
}
}
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(
&self,
past: &mut Cache,

View File

@ -24,7 +24,7 @@ use crate::{Config, RustBertError};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Tensor};
use tch::{nn, Device, Tensor};
/// # GPT Pretrained model weight files
pub struct OpenAiGptModelResources;
@ -505,7 +505,7 @@ impl OpenAIGenerator {
let pad_token_id = tokenizer.get_pad_id();
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = None;
let decoder_start_id = config.decoder_start_token_id;
let max_position_embeddings = config.n_positions;
Ok(OpenAIGenerator {
@ -531,11 +531,11 @@ impl PrivateLanguageGenerator for OpenAIGenerator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -558,8 +558,8 @@ impl PrivateLanguageGenerator for OpenAIGenerator {
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(

View File

@ -11,7 +11,6 @@
// limitations under the License.
use crate::bart::BartModelOutput;
use crate::common::kind::get_negative_infinity;
use crate::mbart::MBartConfig;
use crate::pegasus::decoder::PegasusDecoder;
use crate::pegasus::encoder::PegasusEncoder;
@ -22,10 +21,9 @@ use crate::pipelines::generation_utils::private_generation_utils::{
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::TruncationStrategy;
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig, Init};
use tch::{nn, Tensor};
use tch::{nn, Device, Tensor};
/// # Pegasus Pretrained model weight files
pub struct PegasusModelResources;
@ -510,14 +508,13 @@ impl PegasusConditionalGenerator {
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
let eos_token_ids = Some(match config.eos_token_id {
Some(value) => vec![value],
None => vec![1],
});
let eos_token_ids = config
.eos_token_id
.map_or(Some(vec![1]), |value| Some(vec![value]));
let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = Some(0);
let decoder_start_id = config.decoder_start_token_id.or(Some(0));
let max_position_embeddings = config.max_position_embeddings;
Ok(PegasusConditionalGenerator {
@ -534,18 +531,6 @@ impl PegasusConditionalGenerator {
max_position_embeddings,
})
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(
1,
&impossible_tokens,
get_negative_infinity(scores.kind()).unwrap(),
);
}
}
impl PrivateLanguageGenerator for PegasusConditionalGenerator {
@ -555,11 +540,11 @@ impl PrivateLanguageGenerator for PegasusConditionalGenerator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -582,8 +567,8 @@ impl PrivateLanguageGenerator for PegasusConditionalGenerator {
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(
@ -630,20 +615,6 @@ impl PrivateLanguageGenerator for PegasusConditionalGenerator {
})
}
fn prepare_scores_for_generation(
&self,
scores: &mut Tensor,
current_length: i64,
max_length: Option<i64>,
_forced_bos_token_id: Option<i64>,
) {
if let Some(max_length) = max_length {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
}
}
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
Some(self.model.encode(input_ids, attention_mask))
}
@ -676,51 +647,6 @@ impl PrivateLanguageGenerator for PegasusConditionalGenerator {
}
}
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self
._get_tokenizer()
.get_pad_id()
.expect("A padding token must be provided to encode prompt texts."),
};
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(
&self,
past: &mut Cache,

View File

@ -36,8 +36,10 @@ use crate::mbart::MBartConfig;
use crate::mobilebert::MobileBertConfig;
use crate::openai_gpt::OpenAiGptConfig;
use crate::pegasus::PegasusConfig;
use crate::pipelines::translation::Language;
use crate::prophetnet::ProphetNetConfig;
use crate::reformer::ReformerConfig;
use crate::resources::{Resource, ResourceProvider};
use crate::roberta::RobertaConfig;
use crate::t5::T5Config;
use crate::xlnet::XLNetConfig;
@ -52,9 +54,104 @@ use rust_tokenizers::tokenizer::{
use rust_tokenizers::vocab::Vocab;
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::path::Path;
use std::fmt::Debug;
use std::path::{Path, PathBuf};
use tch::{Device, Kind, Tensor};
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXModelConfig;
#[derive(Debug, Default)]
/// Container for ONNX model resources, containing 3 optional resources (Encoder, Decoder and Decoder with past)
pub struct ONNXModelResources {
/// Model encoder resource
pub encoder_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Model encoder resource
pub decoder_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Model encoder resource
pub decoder_with_past_resource: Option<Box<dyn ResourceProvider + Send>>,
}
#[derive(Debug)]
/// Variants to store either a Torch model resource or ONNX resources
pub enum ModelResource {
Torch(Box<dyn ResourceProvider + Send>),
#[cfg(feature = "onnx")]
ONNX(ONNXModelResources),
}
impl ResourceProvider for ModelResource {
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
match self {
ModelResource::Torch(ref resource) => resource.get_local_path(),
#[cfg(feature = "onnx")]
ModelResource::ONNX(_) => Err(RustBertError::UnsupportedError),
}
}
fn get_resource(&self) -> Result<Resource, RustBertError> {
match self {
ModelResource::Torch(ref resource) => resource.get_resource(),
#[cfg(feature = "onnx")]
ModelResource::ONNX(_) => Err(RustBertError::UnsupportedError),
}
}
}
pub struct ONNXLocalPaths {
pub encoder_path: Option<PathBuf>,
pub decoder_path: Option<PathBuf>,
pub decoder_with_past_path: Option<PathBuf>,
}
impl ModelResource {
/// Provides the torch resource local path.
/// Returns an error if the variant is not a `ModelResources::TORCH`
pub fn get_torch_local_path(&self) -> Result<PathBuf, RustBertError> {
match self {
ModelResource::Torch(torch_resource) => torch_resource.get_local_path(),
#[cfg(feature = "onnx")]
_ => Err(RustBertError::InvalidConfigurationError(format!("Attempting to get the Torch local path but other weights variants were given: {:?}", self)))
}
}
#[cfg(feature = "onnx")]
pub fn get_onnx_local_paths(&self) -> Result<ONNXLocalPaths, RustBertError> {
let (encoder_path, decoder_path, decoder_with_past_path) = match self {
ModelResource::ONNX(onnx_model_resources) => Ok((
onnx_model_resources
.encoder_resource.as_ref()
.map(|r| r.get_local_path()),
onnx_model_resources
.decoder_resource.as_ref()
.map(|r| r.get_local_path()),
onnx_model_resources
.decoder_with_past_resource.as_ref()
.map(|r| r.get_local_path()),
)),
_ => Err(RustBertError::InvalidConfigurationError(format!("Attempting to get the ONNX local paths but other weights variants were given: {:?}", self)))
}?;
Ok(ONNXLocalPaths {
encoder_path: encoder_path.transpose()?,
decoder_path: decoder_path.transpose()?,
decoder_with_past_path: decoder_with_past_path.transpose()?,
})
}
}
pub(crate) fn get_device(_model_resource: ModelResource, device: Device) -> Device {
#[cfg(feature = "onnx")]
let device = if let ModelResource::ONNX(_) = _model_resource {
Device::Cpu
} else {
device
};
#[cfg(not(feature = "onnx"))]
let device = device;
device
}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, PartialEq, Eq)]
/// # Identifies the type of model
@ -92,6 +189,8 @@ pub enum ModelType {
#[serde(alias = "m2m100")]
NLLB,
FNet,
#[cfg(feature = "onnx")]
ONNX,
}
/// # Abstraction that holds a model configuration, can be of any of the supported models
@ -144,6 +243,9 @@ pub enum ConfigOption {
M2M100(M2M100Config),
/// FNet configuration
FNet(FNetConfig),
/// ONNX Model configuration
#[cfg(feature = "onnx")]
ONNX(ONNXModelConfig),
}
/// # Abstraction that holds a particular tokenizer, can be of any of the supported models
@ -220,6 +322,8 @@ impl ConfigOption {
ConfigOption::M2M100(M2M100Config::from_file(path))
}
ModelType::FNet => ConfigOption::FNet(FNetConfig::from_file(path)),
#[cfg(feature = "onnx")]
ModelType::ONNX => ConfigOption::ONNX(ONNXModelConfig::from_file(path)),
}
}
@ -293,6 +397,11 @@ impl ConfigOption {
.id2label
.as_ref()
.expect("No label dictionary (id2label) provided in configuration file"),
#[cfg(feature = "onnx")]
Self::ONNX(config) => config
.id2label
.as_ref()
.expect("No label dictionary (id2label) provided in configuration file"),
Self::T5(_) => panic!("T5 does not use a label mapping"),
Self::LongT5(_) => panic!("LongT5 does not use a label mapping"),
Self::OpenAiGpt(_) => panic!("OpenAI GPT does not use a label mapping"),
@ -329,6 +438,132 @@ impl ConfigOption {
Self::M2M100(config) => Some(config.max_position_embeddings),
Self::FNet(config) => Some(config.max_position_embeddings),
Self::Roberta(config) => Some(config.max_position_embeddings),
#[cfg(feature = "onnx")]
Self::ONNX(config) => config.max_position_embeddings,
}
}
pub fn get_vocab_size(&self) -> i64 {
match self {
Self::Bart(config) => config.vocab_size,
Self::Bert(config) => config.vocab_size,
Self::Deberta(config) => config.vocab_size,
Self::DebertaV2(config) => config.vocab_size,
Self::DistilBert(config) => config.vocab_size,
Self::Electra(config) => config.vocab_size,
Self::Marian(config) => config.vocab_size,
Self::MobileBert(config) => config.vocab_size,
Self::T5(config) => config.vocab_size,
Self::LongT5(config) => config.vocab_size,
Self::Albert(config) => config.vocab_size,
Self::XLNet(config) => config.vocab_size,
Self::GPT2(config) => config.vocab_size,
Self::GPTJ(config) => config.vocab_size,
Self::Reformer(config) => config.vocab_size,
Self::ProphetNet(config) => config.vocab_size,
Self::Longformer(config) => config.vocab_size,
Self::Pegasus(config) => config.vocab_size,
Self::OpenAiGpt(config) => config.vocab_size,
Self::GPTNeo(config) => config.vocab_size,
Self::MBart(config) => config.vocab_size,
Self::M2M100(config) => config.vocab_size,
Self::FNet(config) => config.vocab_size,
Self::Roberta(config) => config.vocab_size,
#[cfg(feature = "onnx")]
Self::ONNX(config) => config.vocab_size,
}
}
pub fn get_decoder_start_token_id(&self) -> Option<i64> {
match self {
Self::Bart(config) => config.decoder_start_token_id,
Self::Bert(_) => None,
Self::Deberta(_) => None,
Self::DebertaV2(_) => None,
Self::DistilBert(_) => None,
Self::Electra(_) => None,
Self::Marian(config) => config.decoder_start_token_id,
Self::MobileBert(_) => None,
Self::T5(config) => config.decoder_start_token_id,
Self::LongT5(config) => config.decoder_start_token_id,
Self::Albert(_) => None,
Self::XLNet(_) => None,
Self::GPT2(config) => config.decoder_start_token_id,
Self::GPTJ(config) => config.decoder_start_token_id,
Self::Reformer(config) => config.decoder_start_token_id,
Self::ProphetNet(config) => config.decoder_start_token_id,
Self::Longformer(_) => None,
Self::Pegasus(config) => config.decoder_start_token_id,
Self::OpenAiGpt(config) => config.decoder_start_token_id,
Self::GPTNeo(config) => config.decoder_start_token_id,
Self::MBart(config) => config.decoder_start_token_id,
Self::M2M100(config) => config.decoder_start_token_id,
Self::FNet(config) => config.decoder_start_token_id,
Self::Roberta(_) => None,
#[cfg(feature = "onnx")]
Self::ONNX(config) => config.decoder_start_token_id,
}
}
pub fn get_forced_bos_token_id(&self) -> Option<i64> {
match self {
Self::Bart(config) => config.forced_bos_token_id,
Self::Bert(_) => None,
Self::Deberta(_) => None,
Self::DebertaV2(_) => None,
Self::DistilBert(_) => None,
Self::Electra(_) => None,
Self::Marian(config) => config.forced_bos_token_id,
Self::MobileBert(_) => None,
Self::T5(config) => config.forced_bos_token_id,
Self::LongT5(config) => config.forced_bos_token_id,
Self::Albert(_) => None,
Self::XLNet(_) => None,
Self::GPT2(config) => config.forced_bos_token_id,
Self::GPTJ(config) => config.forced_bos_token_id,
Self::Reformer(config) => config.forced_bos_token_id,
Self::ProphetNet(config) => config.forced_bos_token_id,
Self::Longformer(_) => None,
Self::Pegasus(config) => config.forced_bos_token_id,
Self::OpenAiGpt(config) => config.forced_bos_token_id,
Self::GPTNeo(config) => config.forced_bos_token_id,
Self::MBart(config) => config.forced_bos_token_id,
Self::M2M100(config) => config.forced_bos_token_id,
Self::FNet(_) => None,
Self::Roberta(_) => None,
#[cfg(feature = "onnx")]
Self::ONNX(config) => config.forced_bos_token_id,
}
}
pub fn get_forced_eos_token_id(&self) -> Option<i64> {
match self {
Self::Bart(config) => config.forced_eos_token_id,
Self::Bert(_) => None,
Self::Deberta(_) => None,
Self::DebertaV2(_) => None,
Self::DistilBert(_) => None,
Self::Electra(_) => None,
Self::Marian(config) => config.forced_eos_token_id,
Self::MobileBert(_) => None,
Self::T5(config) => config.forced_eos_token_id,
Self::LongT5(config) => config.forced_eos_token_id,
Self::Albert(_) => None,
Self::XLNet(_) => None,
Self::GPT2(config) => config.forced_eos_token_id,
Self::GPTJ(config) => config.forced_eos_token_id,
Self::Reformer(config) => config.forced_eos_token_id,
Self::ProphetNet(config) => config.forced_eos_token_id,
Self::Longformer(_) => None,
Self::Pegasus(config) => config.forced_eos_token_id,
Self::OpenAiGpt(config) => config.forced_eos_token_id,
Self::GPTNeo(config) => config.forced_eos_token_id,
Self::MBart(config) => config.forced_eos_token_id,
Self::M2M100(config) => config.forced_eos_token_id,
Self::FNet(_) => None,
Self::Roberta(_) => None,
#[cfg(feature = "onnx")]
Self::ONNX(config) => config.forced_eos_token_id,
}
}
}
@ -670,6 +905,10 @@ impl TokenizerOption {
lower_case,
strip_accents.unwrap_or(false),
)?),
#[cfg(feature = "onnx")]
ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
"Default Tokenizer not defined for generic ONNX models.".to_string(),
))?,
};
Ok(tokenizer)
}
@ -1311,6 +1550,167 @@ impl TokenizerOption {
}
}
/// Helper function to prepare the input for translation models
pub fn get_prefix_and_forced_bos_id(
&self,
source_language: Option<&Language>,
target_language: Option<&Language>,
supported_source_languages: &HashSet<Language>,
supported_target_languages: &HashSet<Language>,
) -> Result<(Option<String>, Option<i64>), RustBertError> {
if let Some(source_language) = source_language {
if !supported_source_languages.contains(source_language) {
return Err(RustBertError::ValueError(format!(
"{source_language} not in list of supported languages: {supported_source_languages:?}",
)));
}
}
if let Some(target_language) = target_language {
if !supported_target_languages.contains(target_language) {
return Err(RustBertError::ValueError(format!(
"{target_language} not in list of supported languages: {supported_target_languages:?}"
)));
}
}
Ok(match *self {
Self::Marian(_) => {
if supported_target_languages.len() > 1 {
(
Some(format!(
">>{}<< ",
target_language.and_then(|l| l.get_iso_639_1_code()).ok_or_else(|| RustBertError::ValueError(format!(
"Missing target language for Marian \
(multiple languages supported by model: {supported_target_languages:?}, \
need to specify target language)",
)))?
)),
None,
)
} else {
(None, None)
}
}
Self::T5(_) => (
Some(format!(
"translate {} to {}:",
source_language.ok_or_else(|| RustBertError::ValueError(
"Missing source language for T5".to_string(),
))?,
target_language.ok_or_else(|| RustBertError::ValueError(
"Missing target language for T5".to_string(),
))?,
)),
None,
),
Self::MBart50(_) => {
(
Some(format!(
">>{}<< ",
source_language.and_then(|l| l.get_iso_639_1_code()).ok_or_else(|| RustBertError::ValueError(format!(
"Missing source language for MBart\
(multiple languages supported by model: {supported_source_languages:?}, \
need to specify target language)"
)))?
)),
if let Some(target_language) = target_language {
Some(
self.convert_tokens_to_ids(&[format!(
">>{}<<",
target_language.get_iso_639_1_code().ok_or_else(|| {
RustBertError::ValueError(format!(
"This language has no ISO639-I code. Languages supported by model: {supported_source_languages:?}."
))
})?
)])[0],
)
} else {
return Err(RustBertError::ValueError(format!(
"Missing target language for MBart\
(multiple languages supported by model: {supported_target_languages:?}, \
need to specify target language)"
)));
},
)
}
Self::M2M100(_) => (
Some(match source_language {
Some(value) => {
let language_code = value.get_iso_639_1_code().ok_or_else(|| {
RustBertError::ValueError(format!(
"This language has no ISO639-I language code representation. \
languages supported by the model: {supported_target_languages:?}"
))
})?;
match language_code.len() {
2 => format!(">>{language_code}.<< "),
3 => format!(">>{language_code}<< "),
_ => {
return Err(RustBertError::ValueError(
"Invalid ISO 639-I code".to_string(),
));
}
}
}
None => {
return Err(RustBertError::ValueError(format!(
"Missing source language for M2M100 \
(multiple languages supported by model: {supported_source_languages:?}, \
need to specify target language)"
)));
}
}),
if let Some(target_language) = target_language {
let language_code = target_language.get_iso_639_1_code().ok_or_else(|| {
RustBertError::ValueError(format!(
"This language has no ISO639-I language code representation. \
languages supported by the model: {supported_target_languages:?}"
))
})?;
Some(
self.convert_tokens_to_ids(&[
match language_code.len() {
2 => format!(">>{language_code}.<<"),
3 => format!(">>{language_code}<<"),
_ => {
return Err(RustBertError::ValueError(
"Invalid ISO 639-3 code".to_string(),
));
}
},
])[0],
)
} else {
return Err(RustBertError::ValueError(format!(
"Missing target language for M2M100 \
(multiple languages supported by model: {supported_target_languages:?}, \
need to specify target language)",
)));
},
),
Self::NLLB(_) => {
let source_language = source_language
.and_then(Language::get_nllb_code)
.map(str::to_string)
.ok_or_else(|| RustBertError::ValueError(
format!("Missing source language for NLLB. Need to specify one from: {supported_source_languages:?}")
))?;
let target_language = target_language
.and_then(Language::get_nllb_code)
.map(str::to_string)
.map(|code| self.convert_tokens_to_ids(&[code])[0])
.ok_or_else(|| RustBertError::ValueError(
format!("Missing target language for NLLB. Need to specify one from: {supported_target_languages:?}")
))?;
(Some(source_language), Some(target_language))
}
_ => (None, None),
})
}
/// Interface method to convert tokens to ids
pub fn convert_tokens_to_ids<S>(&self, tokens: &[S]) -> Vec<i64>
where
@ -1793,6 +2193,55 @@ impl TokenizerOption {
}
}
pub fn tokenize_and_pad<'a, S>(
&self,
input: S,
max_length: usize,
device: Device,
) -> (Tensor, Tensor)
where
S: AsRef<[&'a str]>,
{
let mut tokenized_input: Vec<TokenizedInput> = self.encode_list(
input.as_ref(),
max_length,
&TruncationStrategy::LongestFirst,
0,
);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let pad_id = self
.get_pad_id()
.expect("The Tokenizer used for sequence classification should contain a PAD id");
let tokenized_input_tensors: Vec<Tensor> = tokenized_input
.iter_mut()
.map(|input| {
input.token_ids.resize(max_len, pad_id);
Tensor::from_slice(&(input.token_ids))
})
.collect::<Vec<_>>();
let token_type_ids: Vec<Tensor> = tokenized_input
.iter_mut()
.map(|input| {
input
.segment_ids
.resize(max_len, *input.segment_ids.last().unwrap_or(&0));
Tensor::from_slice(&(input.segment_ids))
})
.collect::<Vec<_>>();
(
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(device),
Tensor::stack(token_type_ids.as_slice(), 0)
.to(device)
.to_kind(Kind::Int64),
)
}
/// Interface method
pub fn add_extra_ids(&mut self, num_extra_ids: i64) {
match *self {

View File

@ -55,7 +55,7 @@
//! from the 3rd party utilization of the pretrained system.
use crate::common::error::RustBertError;
use crate::gpt2::GPT2Generator;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::resources::ResourceProvider;
@ -76,7 +76,7 @@ pub struct ConversationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: DialoGPT-medium)
pub model_resource: Box<dyn ResourceProvider + Send>,
pub model_resource: ModelResource,
/// Config resource (default: DialoGPT-medium)
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: DialoGPT-medium)
@ -122,9 +122,9 @@ impl Default for ConversationConfig {
fn default() -> ConversationConfig {
ConversationConfig {
model_type: ModelType::GPT2,
model_resource: Box::new(RemoteResource::from_pretrained(
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
Gpt2ModelResources::DIALOGPT_MEDIUM,
)),
))),
config_resource: Box::new(RemoteResource::from_pretrained(
Gpt2ConfigResources::DIALOGPT_MEDIUM,
)),
@ -157,6 +157,7 @@ impl Default for ConversationConfig {
impl From<ConversationConfig> for GenerateConfig {
fn from(config: ConversationConfig) -> GenerateConfig {
GenerateConfig {
model_type: config.model_type,
model_resource: config.model_resource,
config_resource: config.config_resource,
merges_resource: config.merges_resource,

View File

@ -82,20 +82,24 @@ use crate::t5::LayerState as T5LayerState;
use crate::xlnet::LayerState as XLNetLayerState;
use self::ordered_float::OrderedFloat;
use crate::pipelines::common::TokenizerOption;
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
extern crate ordered_float;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXLayerCache;
use crate::RustBertError;
#[cfg(feature = "remote")]
use crate::{
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
resources::RemoteResource,
};
extern crate ordered_float;
/// # Configuration for text generation
pub struct GenerateConfig {
/// Model type used for generation
pub model_type: ModelType,
/// Model weights resource (default: pretrained GPT2 model)
pub model_resource: Box<dyn ResourceProvider + Send>,
pub model_resource: ModelResource,
/// Config resource (default: pretrained GPT2 model)
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained GPT2 model)
@ -138,7 +142,10 @@ pub struct GenerateConfig {
impl Default for GenerateConfig {
fn default() -> GenerateConfig {
GenerateConfig {
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
model_type: ModelType::GPT2,
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
Gpt2ModelResources::GPT2,
))),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
@ -223,17 +230,19 @@ pub enum Cache {
ProphetNetCache(Option<Vec<(Option<ProphetNetLayerState>, Option<ProphetNetLayerState>)>>),
GPTNeoCache(Option<Vec<Option<GPTNeoLayerState>>>),
GPTJCache(Option<Vec<Option<GPTJLayerState>>>),
#[cfg(feature = "onnx")]
ONNXCache(ONNXLayerCache),
None,
}
pub(crate) mod private_generation_utils {
use rust_tokenizers::TokenIdsWithOffsets;
use std::cmp::{max, min};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::mem;
use rust_tokenizers::tokenizer::{truncate_sequences, TruncationStrategy};
use rust_tokenizers::TokenIdsWithOffsets;
use tch::{nn, Device, Kind, Tensor};
use crate::pipelines::common::TokenizerOption;
@ -242,7 +251,7 @@ pub(crate) mod private_generation_utils {
};
use super::ordered_float::OrderedFloat;
use crate::common::kind::get_positive_infinity;
use crate::common::kind::{get_negative_infinity, get_positive_infinity};
use crate::RustBertError;
pub struct InternalGenerateOptions<'a> {
@ -283,17 +292,23 @@ pub(crate) mod private_generation_utils {
pub trait PrivateLanguageGenerator {
fn _get_tokenizer(&self) -> &TokenizerOption;
fn get_device(&self) -> Device;
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError>;
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption;
fn get_var_store(&self) -> &nn::VarStore;
fn get_var_store_mut(&mut self) -> &mut nn::VarStore;
fn get_config(&self) -> &GenerateConfig;
fn get_bos_id(&self) -> Option<i64>;
fn get_eos_ids(&self) -> Option<&Vec<i64>>;
fn get_forced_bos_token_id(&self) -> Option<i64> {
None
}
fn get_forced_eos_token_id(&self) -> Option<i64> {
None
}
fn get_pad_id(&self) -> Option<i64>;
fn is_encoder_decoder(&self) -> bool;
fn get_vocab_size(&self) -> i64;
fn get_decoder_start_id(&self) -> Option<i64>;
fn get_max_positions_embeddings(&self) -> i64;
fn get_max_positions_embeddings(&self) -> Option<i64>;
fn forward_t(
&self,
@ -310,11 +325,32 @@ pub(crate) mod private_generation_utils {
fn prepare_scores_for_generation(
&self,
_scores: &mut Tensor,
_current_length: i64,
_max_length: Option<i64>,
_forced_bos_token_id: Option<i64>,
scores: &mut Tensor,
current_length: i64,
max_length: Option<i64>,
forced_bos_token_id: Option<i64>,
) {
if current_length == 1 {
if let Some(forced_bos_token_id) =
forced_bos_token_id.or(self.get_forced_bos_token_id())
{
force_token_id_generation(
scores,
&[forced_bos_token_id],
self.get_vocab_size(),
);
}
} else if let Some(max_length) = max_length {
if let Some(forced_eos_token_id) = self.get_forced_eos_token_id() {
if current_length == max_length - 1 {
force_token_id_generation(
scores,
&[forced_eos_token_id],
self.get_vocab_size(),
);
}
}
}
}
fn encode(&self, _input_ids: &Tensor, _attention_mask: Option<&Tensor>) -> Option<Tensor> {
@ -347,48 +383,66 @@ pub(crate) mod private_generation_utils {
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().tokenize_list(prompt_text);
let token_ids = tokens
.into_iter()
.map(|prompt_tokens| self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens))
.collect::<Vec<Vec<i64>>>();
let num_truncated_tokens = token_ids
.iter()
.map(|token_ids| {
let token_ids = if self.is_encoder_decoder() {
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len
.map(|max_len| {
if token_ids.len() > max_len as usize {
token_ids.len() - max_len as usize
} else {
0
}
})
.unwrap_or(0)
})
.collect::<Vec<usize>>();
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>()
} else {
// Special tokens (e.g. BOS) are not added at the end of the prompt for causal generation
let tokens = self._get_tokenizer().tokenize_list(prompt_text);
let token_ids = tokens
.into_iter()
.map(|prompt_tokens| {
self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens)
})
.collect::<Vec<Vec<i64>>>();
let token_ids = token_ids
.into_iter()
.zip(num_truncated_tokens)
.map(|(tokens, num_truncated_tokens)| {
truncate_sequences(
TokenIdsWithOffsets {
ids: tokens,
offsets: vec![],
reference_offsets: vec![],
masks: vec![],
},
None,
num_truncated_tokens,
&TruncationStrategy::LongestFirst,
0,
)
.unwrap()
.0
.ids
})
.collect::<Vec<Vec<i64>>>();
let num_truncated_tokens = token_ids
.iter()
.map(|token_ids| {
max_len
.map(|max_len| {
if token_ids.len() > max_len as usize {
token_ids.len() - max_len as usize
} else {
0
}
})
.unwrap_or(0)
})
.collect::<Vec<usize>>();
token_ids
.into_iter()
.zip(num_truncated_tokens)
.map(|(tokens, num_truncated_tokens)| {
truncate_sequences(
TokenIdsWithOffsets {
ids: tokens,
offsets: vec![],
reference_offsets: vec![],
masks: vec![],
},
None,
num_truncated_tokens,
&TruncationStrategy::LongestFirst,
0,
)
.unwrap()
.0
.ids
})
.collect::<Vec<Vec<i64>>>()
};
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
@ -399,13 +453,20 @@ pub(crate) mod private_generation_utils {
let token_ids = token_ids
.into_iter()
.map(|input| {
.map(|mut input| {
let mut temp = vec![pad_token; max_len - input.len()];
temp.extend(input);
temp
if self.is_encoder_decoder() {
input.extend(temp);
input
} else {
// Pad left for causal generation
temp.extend(input);
temp
}
})
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
@ -767,9 +828,9 @@ pub(crate) mod private_generation_utils {
output_scores: bool,
) -> GeneratedOutputWithScores {
let mut unfinished_sentences =
Tensor::ones([batch_size], (Kind::Int64, self.get_var_store().device()));
Tensor::ones([batch_size], (Kind::Int64, self.get_device()));
let mut sentence_lengths: Tensor =
Tensor::ones([batch_size], (Kind::Int64, self.get_var_store().device()));
Tensor::ones([batch_size], (Kind::Int64, self.get_device()));
let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) =
self.split_bad_word_ids(gen_opt.bad_word_ids);
let mut static_bad_words_mask: Option<Tensor> = None;
@ -1024,7 +1085,7 @@ pub(crate) mod private_generation_utils {
let vocab_size = self.get_vocab_size();
let beam_scores = Tensor::ones(
[batch_size, gen_opt.num_beams],
(Kind::Float, self.get_var_store().device()),
(Kind::Float, self.get_device()),
) * -1e9;
let _ = beam_scores
.slice(1, 0, *beam_scores.size().last().unwrap(), num_sub_beams)
@ -1033,11 +1094,11 @@ pub(crate) mod private_generation_utils {
let mut beam_scores = beam_scores.view_([-1]);
let mut beam_tokens = Tensor::zeros(
[batch_size * gen_opt.num_beams],
(Kind::Int64, self.get_var_store().device()),
(Kind::Int64, self.get_device()),
);
let mut beam_indices = Tensor::zeros(
[batch_size * gen_opt.num_beams],
(Kind::Int64, self.get_var_store().device()),
(Kind::Int64, self.get_device()),
);
let mut saved_beam_scores: Option<Vec<Tensor>> =
if output_scores { Some(vec![]) } else { None };
@ -1524,6 +1585,18 @@ pub(crate) mod private_generation_utils {
}
}
}
pub fn force_token_id_generation(scores: &mut Tensor, token_ids: &[i64], vocab_size: i64) {
let impossible_tokens: Vec<i64> = (0..vocab_size)
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(
1,
&impossible_tokens,
get_negative_infinity(scores.kind()).unwrap(),
);
}
}
#[derive(Debug, Clone)]
@ -1805,7 +1878,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
generate_options.max_length
});
let encoding_max_len = if self.is_encoder_decoder() {
Some(self.get_max_positions_embeddings())
self.get_max_positions_embeddings()
} else {
max_length
};
@ -1819,9 +1892,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
self.encode_prompt_text(prompts, encoding_max_len, pad_token_id)
}
None => match self.get_bos_id() {
Some(bos_id) => {
Tensor::ones([1, 1], (Int64, self.get_var_store().device())) * bos_id
}
Some(bos_id) => Tensor::ones([1, 1], (Int64, self.get_device())) * bos_id,
None => panic!(
"A model with a BOS token must be used to start generation with an empty input"
),
@ -2155,16 +2226,19 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
self._get_tokenizer_mut()
}
fn half(&mut self) {
self.get_var_store_mut().half();
fn half(&mut self) -> Result<(), RustBertError> {
self.get_var_store_mut()?.half();
Ok(())
}
fn float(&mut self) {
self.get_var_store_mut().float();
fn float(&mut self) -> Result<(), RustBertError> {
self.get_var_store_mut()?.float();
Ok(())
}
fn set_device(&mut self, device: Device) {
self.get_var_store_mut().set_device(device);
fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
self.get_var_store_mut()?.set_device(device);
Ok(())
}
}

View File

@ -22,9 +22,10 @@
//!use rust_bert::resources::RemoteResource;
//! fn main() -> anyhow::Result<()> {
//!
//! let config = MaskedLanguageConfig::new(
//! use rust_bert::pipelines::common::ModelResource;
//! let config = MaskedLanguageConfig::new(
//! ModelType::Bert,
//! RemoteResource::from_pretrained(BertModelResources::BERT),
//! ModelResource::Torch(Box::new(RemoteResource::from_pretrained(BertModelResources::BERT))),
//! RemoteResource::from_pretrained(BertConfigResources::BERT),
//! RemoteResource::from_pretrained(BertVocabResources::BERT),
//! None,
@ -50,20 +51,23 @@ use crate::common::error::RustBertError;
use crate::deberta::DebertaForMaskedLM;
use crate::deberta_v2::DebertaV2ForMaskedLM;
use crate::fnet::FNetForMaskedLM;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::common::{
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForMaskedLM;
use std::convert::TryFrom;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
#[cfg(feature = "remote")]
use crate::{
bert::{BertConfigResources, BertModelResources, BertVocabResources},
resources::RemoteResource,
};
use rust_tokenizers::tokenizer::TruncationStrategy;
use rust_tokenizers::TokenizedInput;
use std::borrow::Borrow;
use std::convert::TryFrom;
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Tensor};
use tch::{no_grad, Device, Tensor};
#[derive(Debug, Clone)]
/// Output container for masked language model pipeline.
@ -82,7 +86,7 @@ pub struct MaskedLanguageConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BERT model on CoNLL)
pub model_resource: Box<dyn ResourceProvider + Send>,
pub model_resource: ModelResource,
/// Config resource (default: pretrained BERT model on CoNLL)
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BERT model on CoNLL)
@ -113,9 +117,9 @@ impl MaskedLanguageConfig {
/// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
/// * mask_token - A token used for model to predict masking words..
pub fn new<RM, RC, RV>(
pub fn new<RC, RV>(
model_type: ModelType,
model_resource: RM,
model_resource: ModelResource,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
@ -125,13 +129,12 @@ impl MaskedLanguageConfig {
mask_token: impl Into<Option<String>>,
) -> MaskedLanguageConfig
where
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
MaskedLanguageConfig {
model_type,
model_resource: Box::new(model_resource),
model_resource,
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
@ -149,7 +152,9 @@ impl Default for MaskedLanguageConfig {
fn default() -> MaskedLanguageConfig {
MaskedLanguageConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT),
RemoteResource::from_pretrained(BertVocabResources::BERT),
None,
@ -176,28 +181,39 @@ pub enum MaskedLanguageOption {
XLMRoberta(RobertaForMaskedLM),
/// FNet for Masked Language
FNet(FNetForMaskedLM),
/// ONNX model for Masked Language
#[cfg(feature = "onnx")]
ONNX(ONNXEncoder),
}
impl MaskedLanguageOption {
/// Instantiate a new masked language model of the supplied type.
///
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
pub fn new<'p, P>(
model_type: ModelType,
p: P,
config: &ConfigOption,
) -> Result<Self, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
match model_type {
/// * `MaskedLanguageConfig` - Masked language model pipeline configuration. The type of model created will be inferred from the
/// `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
pub fn new(config: &MaskedLanguageConfig) -> Result<Self, RustBertError> {
match config.model_resource {
ModelResource::Torch(_) => Self::new_torch(config),
#[cfg(feature = "onnx")]
ModelResource::ONNX(_) => Self::new_onnx(config),
}
}
fn new_torch(config: &MaskedLanguageConfig) -> Result<Self, RustBertError> {
let device = config.device;
let weights_path = config.model_resource.get_torch_local_path()?;
let mut var_store = VarStore::new(device);
let model_config =
&ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
let model_type = config.model_type;
let model = match model_type {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
Ok(MaskedLanguageOption::Bert(BertForMaskedLM::new(p, config)))
if let ConfigOption::Bert(config) = model_config {
Ok(MaskedLanguageOption::Bert(BertForMaskedLM::new(
var_store.root(),
config,
)))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a BertConfig for Bert!".to_string(),
@ -205,9 +221,10 @@ impl MaskedLanguageOption {
}
}
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
if let ConfigOption::Deberta(config) = model_config {
Ok(MaskedLanguageOption::Deberta(DebertaForMaskedLM::new(
p, config,
var_store.root(),
config,
)))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -216,9 +233,10 @@ impl MaskedLanguageOption {
}
}
ModelType::DebertaV2 => {
if let ConfigOption::DebertaV2(config) = config {
if let ConfigOption::DebertaV2(config) = model_config {
Ok(MaskedLanguageOption::DebertaV2(DebertaV2ForMaskedLM::new(
p, config,
var_store.root(),
config,
)))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -227,9 +245,10 @@ impl MaskedLanguageOption {
}
}
ModelType::Roberta => {
if let ConfigOption::Roberta(config) = config {
if let ConfigOption::Roberta(config) = model_config {
Ok(MaskedLanguageOption::Roberta(RobertaForMaskedLM::new(
p, config,
var_store.root(),
config,
)))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -238,9 +257,10 @@ impl MaskedLanguageOption {
}
}
ModelType::XLMRoberta => {
if let ConfigOption::Bert(config) = config {
if let ConfigOption::Bert(config) = model_config {
Ok(MaskedLanguageOption::XLMRoberta(RobertaForMaskedLM::new(
p, config,
var_store.root(),
config,
)))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -249,8 +269,11 @@ impl MaskedLanguageOption {
}
}
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(MaskedLanguageOption::FNet(FNetForMaskedLM::new(p, config)))
if let ConfigOption::FNet(config) = model_config {
Ok(MaskedLanguageOption::FNet(FNetForMaskedLM::new(
var_store.root(),
config,
)))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a FNetConfig for FNet!".to_string(),
@ -260,9 +283,29 @@ impl MaskedLanguageOption {
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Masked Language is not implemented for {model_type:?}!",
))),
}
}?;
var_store.load(weights_path)?;
Ok(model)
}
#[cfg(feature = "onnx")]
pub fn new_onnx(config: &MaskedLanguageConfig) -> Result<Self, RustBertError> {
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
let environment = onnx_config.get_environment()?;
let encoder_file = config
.model_resource
.get_onnx_local_paths()?
.encoder_path
.ok_or(RustBertError::InvalidConfigurationError(
"An encoder file must be provided for masked language ONNX models.".to_string(),
))?;
Ok(Self::ONNX(ONNXEncoder::new(
encoder_file,
&environment,
&onnx_config,
)?))
}
/// Returns the `ModelType` for this MaskedLanguageOption
pub fn model_type(&self) -> ModelType {
match *self {
@ -272,6 +315,8 @@ impl MaskedLanguageOption {
Self::Roberta(_) => ModelType::Roberta,
Self::XLMRoberta(_) => ModelType::Roberta,
Self::FNet(_) => ModelType::FNet,
#[cfg(feature = "onnx")]
Self::ONNX(_) => ModelType::ONNX,
}
}
@ -350,6 +395,21 @@ impl MaskedLanguageOption {
.expect("Error in FNet forward pass.")
.prediction_scores
}
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => {
let attention_mask = input_ids.unwrap().ones_like();
model
.forward(
input_ids,
Some(&attention_mask),
token_type_ids,
position_ids,
input_embeds,
)
.expect("Error in ONNX forward pass.")
.logits
.unwrap()
}
}
}
}
@ -359,7 +419,7 @@ pub struct MaskedLanguageModel {
tokenizer: TokenizerOption,
language_encode: MaskedLanguageOption,
mask_token: Option<String>,
var_store: VarStore,
device: Device,
max_length: usize,
}
@ -428,25 +488,21 @@ impl MaskedLanguageModel {
config: MaskedLanguageConfig,
tokenizer: TokenizerOption,
) -> Result<MaskedLanguageModel, RustBertError> {
let language_encode = MaskedLanguageOption::new(&config)?;
let config_path = config.config_resource.get_local_path()?;
let device = config.device;
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let max_length = model_config
.get_max_len()
.map(|v| v as usize)
.unwrap_or(usize::MAX);
let language_encode =
MaskedLanguageOption::new(config.model_type, var_store.root(), &model_config)?;
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
let mask_token = config.mask_token;
let device = get_device(config.model_resource, config.device);
Ok(MaskedLanguageModel {
tokenizer,
language_encode,
mask_token,
var_store,
device,
max_length,
})
}
@ -481,33 +537,6 @@ impl MaskedLanguageModel {
Ok(output)
}
fn prepare_for_model<'a, S>(&self, input: S) -> Tensor
where
S: AsRef<[&'a str]>,
{
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(
input.as_ref(),
self.max_length,
&TruncationStrategy::LongestFirst,
0,
);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input_tensors = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::from_slice(&(input)))
.collect::<Vec<_>>();
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
}
/// Mask texts
///
/// # Arguments
@ -544,30 +573,22 @@ impl MaskedLanguageModel {
where
S: AsRef<[&'a str]>,
{
let input_tensor = if let Some(mask_token) = &self.mask_token {
let (input_ids, token_type_ids) = if let Some(mask_token) = &self.mask_token {
let input_with_replaced_mask = self.replace_mask_token(input.as_ref(), mask_token)?;
self.prepare_for_model(
self.tokenizer.tokenize_and_pad(
input_with_replaced_mask
.iter()
.map(|w| w.as_str())
.collect::<Vec<&str>>(),
.collect::<Vec<&str>>()
.as_slice(),
self.max_length,
self.device,
)
} else {
self.prepare_for_model(input.as_ref())
self.tokenizer
.tokenize_and_pad(input.as_ref(), self.max_length, self.device)
};
let output = no_grad(|| {
self.language_encode.forward_t(
Some(&input_tensor),
None,
None,
None,
None,
None,
None,
false,
)
});
// get the position of mask_token in input texts
let mask_token_id =
self.tokenizer
@ -575,7 +596,21 @@ impl MaskedLanguageModel {
.ok_or_else(|| RustBertError::InvalidConfigurationError(
"Tokenizer does not have a mask token id, Please use a tokenizer/model with a mask token.".into(),
))?;
let mask_token_mask = input_tensor.eq(mask_token_id);
let mask_token_mask = input_ids.eq(mask_token_id);
let output = no_grad(|| {
self.language_encode.forward_t(
Some(&input_ids),
None,
Some(&token_type_ids),
None,
None,
None,
None,
false,
)
});
let mut output_tokens = Vec::with_capacity(input.as_ref().len());
for input_id in 0..input.as_ref().len() as i64 {
let mut sequence_tokens = vec![];

View File

@ -490,3 +490,6 @@ pub mod text_generation;
pub mod token_classification;
pub mod translation;
pub mod zero_shot_classification;
#[cfg(feature = "onnx")]
pub mod onnx;

View File

@ -90,11 +90,12 @@
//! use tch::Device;
//!
//! # fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::common::ModelResource;
//! let ner_config = TokenClassificationConfig {
//! model_type: ModelType::XLMRoberta,
//! model_resource: Box::new(RemoteResource::from_pretrained(
//! model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
//! RobertaModelResources::XLM_ROBERTA_NER_DE,
//! )),
//! ))),
//! config_resource: Box::new(RemoteResource::from_pretrained(
//! RobertaConfigResources::XLM_ROBERTA_NER_DE,
//! )),

View File

@ -0,0 +1,44 @@
use ort::Session;
use std::collections::HashMap;
#[derive(Debug)]
pub(crate) struct InputOutputNameMapping {
pub(crate) input_names: Vec<String>,
pub(crate) output_names: HashMap<String, usize>,
pub(crate) key_value_output_names: HashMap<String, usize>,
}
pub(crate) fn get_input_output_mapping(session: &Session) -> InputOutputNameMapping {
let input_names = session
.inputs
.iter()
.map(|input| input.name.clone())
.collect::<Vec<String>>();
let output_names = session
.outputs
.iter()
.enumerate()
.map(|(pos, output)| (output.name.clone(), pos))
.collect::<HashMap<String, usize>>();
let mut key_value_output_names = output_names
.iter()
.filter(|(name, _)| name.contains(".key") | name.contains(".value"))
.map(|(name, pos)| (name.clone(), *pos))
.collect::<HashMap<String, usize>>();
if key_value_output_names.is_empty() {
key_value_output_names = output_names
.iter()
.filter(|(name, _)| name.contains("key_value"))
.map(|(name, pos)| (name.clone(), *pos))
.collect::<HashMap<String, usize>>();
}
InputOutputNameMapping {
input_names,
output_names,
key_value_output_names,
}
}

View File

@ -0,0 +1,105 @@
/// # Configuration for ONNX environment and sessions
use crate::RustBertError;
use ort::{
AllocatorType, Environment, ExecutionProvider, GraphOptimizationLevel, MemType, SessionBuilder,
};
use std::sync::Arc;
use tch::Device;
pub(crate) static INPUT_IDS_NAME: &str = "input_ids";
pub(crate) static ATTENTION_MASK_NAME: &str = "attention_mask";
pub(crate) static ENCODER_HIDDEN_STATES_NAME: &str = "encoder_hidden_states";
pub(crate) static ENCODER_ATTENTION_MASK_NAME: &str = "encoder_attention_mask";
pub(crate) static TOKEN_TYPE_IDS: &str = "token_type_ids";
pub(crate) static POSITION_IDS: &str = "position_ids";
pub(crate) static INPUT_EMBEDS: &str = "input_embeds";
pub(crate) static LAST_HIDDEN_STATE: &str = "last_hidden_state";
pub(crate) static LOGITS: &str = "logits";
pub(crate) static START_LOGITS: &str = "start_logits";
pub(crate) static END_LOGITS: &str = "end_logits";
#[derive(Default)]
/// # ONNX Environment configuration
/// See <https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions>
pub struct ONNXEnvironmentConfig {
pub optimization_level: Option<GraphOptimizationLevel>,
pub execution_providers: Option<Vec<ExecutionProvider>>,
pub num_intra_threads: Option<i16>,
pub num_inter_threads: Option<i16>,
pub parallel_execution: Option<bool>,
pub enable_memory_pattern: Option<bool>,
pub allocator: Option<AllocatorType>,
pub memory_type: Option<MemType>,
}
impl ONNXEnvironmentConfig {
/// Create a new `ONNXEnvironmentConfig` from a `tch::Device`.
/// This helper function maps torch device to ONNXRuntime execution providers
pub fn from_device(device: Device) -> Self {
let mut execution_providers = Vec::new();
if let Device::Cuda(_) = device {
execution_providers.push(ExecutionProvider::cuda());
};
execution_providers.push(ExecutionProvider::cpu());
ONNXEnvironmentConfig {
execution_providers: Some(execution_providers),
..Default::default()
}
}
///Build a session builder from an `ONNXEnvironmentConfig`.
pub fn get_session_builder(
&self,
environment: &Arc<Environment>,
) -> Result<SessionBuilder, RustBertError> {
let mut session_builder = SessionBuilder::new(environment)?;
match &self.optimization_level {
Some(GraphOptimizationLevel::Level3) | None => {}
Some(GraphOptimizationLevel::Level2) => {
session_builder =
session_builder.with_optimization_level(GraphOptimizationLevel::Level2)?
}
Some(GraphOptimizationLevel::Level1) => {
session_builder =
session_builder.with_optimization_level(GraphOptimizationLevel::Level1)?
}
Some(GraphOptimizationLevel::Disable) => {
session_builder =
session_builder.with_optimization_level(GraphOptimizationLevel::Disable)?
}
}
if let Some(num_intra_threads) = self.num_intra_threads {
session_builder = session_builder.with_intra_threads(num_intra_threads)?;
}
if let Some(num_inter_threads) = self.num_inter_threads {
session_builder = session_builder.with_inter_threads(num_inter_threads)?;
}
if let Some(parallel_execution) = self.parallel_execution {
session_builder = session_builder.with_parallel_execution(parallel_execution)?;
}
if let Some(enable_memory_pattern) = self.enable_memory_pattern {
session_builder = session_builder.with_memory_pattern(enable_memory_pattern)?;
}
if let Some(allocator) = &self.allocator {
session_builder = session_builder.with_allocator(allocator.clone())?;
}
if let Some(memory_type) = &self.memory_type {
session_builder = session_builder.with_memory_type(memory_type.clone())?;
}
Ok(session_builder)
}
///Build an ONNXEnvironment from an `ONNXEnvironmentConfig`.
pub fn get_environment(&self) -> Result<Arc<Environment>, RustBertError> {
Ok(Arc::new(
Environment::builder()
.with_name("Default environment")
.with_execution_providers(
self.execution_providers
.clone()
.unwrap_or(vec![ExecutionProvider::cpu()]),
)
.build()?,
))
}
}

View File

@ -0,0 +1,57 @@
use crate::RustBertError;
use ndarray::IxDyn;
use ort::tensor::{DynOrtTensor, FromArray, InputTensor};
use std::convert::{TryFrom, TryInto};
use tch::{Kind, Tensor};
pub(crate) fn ort_tensor_to_tch(ort_tensor: &DynOrtTensor<IxDyn>) -> Result<Tensor, RustBertError> {
let ort_tensor = ort_tensor.try_extract::<f32>()?.view().to_owned();
Ok(Tensor::try_from(ort_tensor)?)
}
pub(crate) fn tch_tensor_to_ort(tch_tensor: &Tensor) -> Result<InputTensor, RustBertError> {
let kind = tch_tensor.kind();
Ok(match kind{
Kind::Int64 => {
let array: ndarray::ArrayD<i64> = tch_tensor.try_into()?;
InputTensor::from_array(array)
}
Kind::Float => {
let array: ndarray::ArrayD<f32> = tch_tensor.try_into()?;
InputTensor::from_array(array)
}
Kind::Int => {
let array: ndarray::ArrayD<i32> = tch_tensor.try_into()?;
InputTensor::from_array(array)
}
Kind::Double => {
let array: ndarray::ArrayD<f64> = tch_tensor.try_into()?;
InputTensor::from_array(array)
}
Kind::Half => {
let array: ndarray::ArrayD<half::f16> = tch_tensor.try_into()?;
InputTensor::from_array(array)
}
Kind::Int16 => {
let array: ndarray::ArrayD<i16> = tch_tensor.try_into()?;
InputTensor::from_array(array)
}
Kind::Int8 => {
let array: ndarray::ArrayD<i8> = tch_tensor.try_into()?;
InputTensor::from_array(array)
}
Kind::Uint8 => {
let array: ndarray::ArrayD<u8> = tch_tensor.try_into()?;
InputTensor::from_array(array)
}
Kind::BFloat16 => {
let array: ndarray::ArrayD<half::bf16> = tch_tensor.try_into()?;
InputTensor::from_array(array)
}
_ => {
return Err(RustBertError::ValueError(format!(
"Type not supported: attempted to get convert torch tensor to ndarray infinity for {kind:?}",
)))
}
})
}

View File

@ -0,0 +1,113 @@
use crate::pipelines::generation_utils::{Cache, LMModelOutput};
use crate::pipelines::onnx::common::{get_input_output_mapping, InputOutputNameMapping};
use crate::pipelines::onnx::config::{
ONNXEnvironmentConfig, ATTENTION_MASK_NAME, ENCODER_ATTENTION_MASK_NAME,
ENCODER_HIDDEN_STATES_NAME, INPUT_IDS_NAME, POSITION_IDS,
};
use crate::pipelines::onnx::conversion::{ort_tensor_to_tch, tch_tensor_to_ort};
use crate::pipelines::onnx::models::ONNXLayerCache;
use crate::RustBertError;
use ort::{Environment, Session};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tch::Tensor;
pub struct ONNXDecoder {
session: Session,
name_mapping: InputOutputNameMapping,
use_cache: bool,
}
impl ONNXDecoder {
pub fn new(
model_file: PathBuf,
use_cache: bool,
environment: &Arc<Environment>,
onnx_config: &ONNXEnvironmentConfig,
) -> Result<Self, RustBertError> {
let session = onnx_config
.get_session_builder(environment)?
.with_model_from_file(model_file)?;
let name_mapping = get_input_output_mapping(&session);
Ok(Self {
session,
name_mapping,
use_cache,
})
}
pub fn forward(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
position_ids: Option<&Tensor>,
layer_states: Option<&ONNXLayerCache>,
) -> Result<LMModelOutput, RustBertError> {
let mut input_dict = HashMap::new();
if let Some(input_ids) = input_ids {
input_dict.insert(INPUT_IDS_NAME, input_ids);
}
if let Some(attention_mask) = attention_mask {
input_dict.insert(ATTENTION_MASK_NAME, attention_mask);
}
if let Some(encoder_hidden_states) = encoder_hidden_states {
input_dict.insert(ENCODER_HIDDEN_STATES_NAME, encoder_hidden_states);
}
if let Some(encoder_attention_mask) = encoder_attention_mask {
input_dict.insert(ENCODER_ATTENTION_MASK_NAME, encoder_attention_mask);
}
if let Some(position_ids) = position_ids {
input_dict.insert(POSITION_IDS, position_ids);
}
let inputs = self
.name_mapping
.input_names
.iter()
.map(|input_name| {
if let Some(tensor) = input_dict.remove(input_name.as_str()) {
Ok(tch_tensor_to_ort(tensor)?)
} else {
let layer_states = layer_states.ok_or_else(|| {
RustBertError::OrtError(format!(
"{input_name} not found and cache was not provided."
))
})?;
let input_pos = layer_states
.values
.get(&input_name.replace("past", "present"))
.or_else(|| {
layer_states
.values
.get(&input_name.replace("past_key_values", "present"))
})
.ok_or_else(|| {
let found_keys = layer_states.values.keys().collect::<Vec<&String>>();
RustBertError::OrtError(format!(
"{input_name} not found in cache ({found_keys:?})."
))
})?;
tch_tensor_to_ort(input_pos)
}
})
.collect::<Result<Vec<_>, RustBertError>>()?;
let outputs = self.session.run(inputs)?;
let lm_logits =
ort_tensor_to_tch(&outputs[*self.name_mapping.output_names.get("logits").unwrap()])?;
let cache = if self.use_cache {
Cache::ONNXCache(ONNXLayerCache::from_ort_output(
&outputs,
&self.name_mapping.key_value_output_names,
)?)
} else {
Cache::None
};
Ok(LMModelOutput { lm_logits, cache })
}
}

View File

@ -0,0 +1,228 @@
use crate::pipelines::onnx::common::{get_input_output_mapping, InputOutputNameMapping};
use crate::pipelines::onnx::config::{
ONNXEnvironmentConfig, ATTENTION_MASK_NAME, END_LOGITS, INPUT_EMBEDS, INPUT_IDS_NAME,
LAST_HIDDEN_STATE, LOGITS, POSITION_IDS, START_LOGITS, TOKEN_TYPE_IDS,
};
use crate::pipelines::onnx::conversion::{ort_tensor_to_tch, tch_tensor_to_ort};
use crate::RustBertError;
use ort::{Environment, Session};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tch::Tensor;
/// # ONNX Encoder model
/// Container for an ONNX encoder model and the corresponding session. Can be used individually for
/// pure-encoder models (e.g. BERT) or as part of encoder/decoder architectures.
pub struct ONNXEncoder {
session: Session,
name_mapping: InputOutputNameMapping,
}
impl ONNXEncoder {
/// Create a new `ONNXEncoder`. Requires a pointer to the model file for
/// the encoder, a reference to an environment and an ONNX environment configuration.
///
/// # Example
///
/// ```no_run
/// use ort::Environment;
/// use rust_bert::pipelines::onnx::config::ONNXEnvironmentConfig;
/// use rust_bert::pipelines::onnx::ONNXEncoder;
/// use std::path::PathBuf;
/// use std::sync::Arc;
/// let environment = Arc::new(Environment::default());
/// let onnx_config = ONNXEnvironmentConfig::default();
/// let model_file = PathBuf::from("path/to/model.onnx");
///
/// let encoder = ONNXEncoder::new(model_file, &environment, &onnx_config).unwrap();
/// ```
pub fn new(
model_file: PathBuf,
environment: &Arc<Environment>,
onnx_config: &ONNXEnvironmentConfig,
) -> Result<Self, RustBertError> {
let session = onnx_config
.get_session_builder(environment)?
.with_model_from_file(model_file)?;
let name_mapping = get_input_output_mapping(&session);
Ok(Self {
session,
name_mapping,
})
}
/// Forward pass through the model.
///
/// The outputs provided by the model depend on the underlying ONNX model and are all marked as optional to support a broad range of
/// encoder stacks for multiple stacks. The end-user should extract the required output that is provided by the model exported.
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
///
/// # Returns
///
/// * `ONNXEncoderModelOutput` containing:
/// - `last_hidden_state` - Optional `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `logits` - Optional `Tensor` of shape (*batch size*, *num_labels*)
/// - `start_logits` - Optional `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
/// - `end_logits` - Optional `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
/// - `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
/// # use rust_bert::bert::{BertModel, BertConfig, BertEmbeddings};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let bert_model: BertModel<BertEmbeddings> = BertModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let model_output = no_grad(|| {
/// bert_model
/// .forward_t(
/// Some(&input_tensor),
/// Some(&mask),
/// Some(&token_type_ids),
/// Some(&position_ids),
/// None,
/// None,
/// None,
/// false,
/// )
/// .unwrap()
/// });
/// ```
pub fn forward(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
) -> Result<ONNXEncoderModelOutput, RustBertError> {
let mut input_dict = HashMap::new();
if let Some(input_ids) = input_ids {
input_dict.insert(INPUT_IDS_NAME, input_ids);
}
if let Some(attention_mask) = attention_mask {
input_dict.insert(ATTENTION_MASK_NAME, attention_mask);
}
if let Some(token_type_ids) = token_type_ids {
input_dict.insert(TOKEN_TYPE_IDS, token_type_ids);
}
if let Some(position_ids) = position_ids {
input_dict.insert(POSITION_IDS, position_ids);
}
if let Some(input_embeds) = input_embeds {
input_dict.insert(INPUT_EMBEDS, input_embeds);
}
let inputs = self
.name_mapping
.input_names
.iter()
.map(|input_name| {
if let Some(tensor) = input_dict.remove(input_name.as_str()) {
tch_tensor_to_ort(tensor)
} else {
Err(RustBertError::OrtError(format!(
"{input_name} not found but expected by model."
)))
}
})
.collect::<Result<Vec<_>, RustBertError>>()?;
let outputs = self.session.run(inputs)?;
let last_hidden_state = self
.name_mapping
.output_names
.get(LAST_HIDDEN_STATE)
.map(|pos| ort_tensor_to_tch(&outputs[*pos]))
.transpose()?;
let logits = self
.name_mapping
.output_names
.get(LOGITS)
.map(|pos| ort_tensor_to_tch(&outputs[*pos]))
.transpose()?;
let start_logits = self
.name_mapping
.output_names
.get(START_LOGITS)
.map(|pos| ort_tensor_to_tch(&outputs[*pos]))
.transpose()?;
let end_logits = self
.name_mapping
.output_names
.get(END_LOGITS)
.map(|pos| ort_tensor_to_tch(&outputs[*pos]))
.transpose()?;
let (hidden_states, attentions) = if self.name_mapping.output_names.len() > 1 {
let hidden_states = self
.name_mapping
.output_names
.iter()
.filter(|(name, _)| name.contains("hidden_states"))
.map(|(_, position)| outputs.get(*position))
.map(|array| array.map(|array_value| ort_tensor_to_tch(array_value).unwrap()))
.collect::<Option<Vec<_>>>();
let attentions = self
.name_mapping
.output_names
.iter()
.filter(|(name, _)| name.contains("attentions"))
.map(|(_, position)| outputs.get(*position))
.map(|array| array.map(|array_value| ort_tensor_to_tch(array_value).unwrap()))
.collect::<Option<Vec<_>>>();
(hidden_states, attentions)
} else {
(None, None)
};
Ok(ONNXEncoderModelOutput {
last_hidden_state,
logits,
start_logits,
end_logits,
hidden_states,
attentions,
})
}
}
/// # ONNX encoder model output.
/// The outputs provided by the model depend on the underlying ONNX model and are all marked as optional to support a broad range of
/// encoder stacks for multiple stacks. The end-user should extract the required output that is provided by the model exported.
pub struct ONNXEncoderModelOutput {
/// Last hidden states, typically used by masked language model encoder models
pub last_hidden_state: Option<Tensor>,
/// logits, typically used by models with a sequence of classification head
pub logits: Option<Tensor>,
/// logits marking the start location of a span (e.g. for extractive question answering tasks)
pub start_logits: Option<Tensor>,
/// logits marking the end location of a span (e.g. for extractive question answering tasks)
pub end_logits: Option<Tensor>,
/// Hidden states for intermediate layers of the model
pub hidden_states: Option<Vec<Tensor>>,
/// Attention weights for intermediate layers of the model
pub attentions: Option<Vec<Tensor>>,
}

115
src/pipelines/onnx/mod.rs Normal file
View File

@ -0,0 +1,115 @@
//! # ONNX model support
//!
//! This crate allows running inference on models that were exported to ONNX via [onnxruntime](https://onnxruntime.ai/about.html)
//! [bindings](https://github.com/pykeio/ort). In order to use ONNX model the corresponding optional feature (`onnx`) should be turned on.
//! This will include the optional `ort` and `ndarray` dependencies. The `rust-bert` crate does not include any optional dependencies for `ort`,
//! the end user should select the set of features that would be adequate for pulling the required `onnxruntime` C++ library. The current recommended
//! installation is to use dynamic linking by pointing to an existing library location:
//! - Use the `load-dynamic` cargo feature for `ort`
//! - set the `ORT_DYLIB_PATH` to point to the location of downloaded onnxruntime library (`onnxruntime.dll`/`libonnxruntime.so`/`libonnxruntime.dylib`
//! depending on the operating system). These can be downloaded from the [release page](https://github.com/microsoft/onnxruntime/releases) of the onnxruntime project
//!
//! For troubleshooting issues when using an ONNX model, it is recommended to add the `tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }`
//! dependency, and use the `tracing_subscriber::fmt::init();` instruction in the `main` binary.
//!
//! Most architectures (including encoders, decoders and encoder-decoders) are supported.
//! the library aims at keeping compatibility with models exported using the [optimum](https://github.com/huggingface/optimum) library.
//! A detailed guide on how to export a Transformer model to ONNX using optimum is available at https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model
//!
//! The resources used to create ONNX models are similar to those based on Pytorch, replacing the pytorch by the ONNX model. Since ONNX models
//! are less flexible than their Pytorch counterparts in the handling of optional arguments, exporting a decoder or encoder-decoder model to ONNX will usually
//! result in multiple files. These files are expected (but not all are necessary) for use in this library as per the table below:
//!
//! | Architecture | Encoder file | Decoder without past file | Decoder with past file |
//! |----------------------|---------------|----------------------------|-------------------------|
//! | Encoder (e.g. BERT) | required | not used | not used |
//! | Decoder (e.g. GPT2) | not used | required | optional |
//! | Encoder-decoder (e.g. BART) | required | required | optional |
//!
//! Note that the computational efficiency will drop when the `decoder with past` file is optional but not provided
//! since the model will not used cached past keys and values for the attention mechanism, leading to a high number of
//! redundant computations. The Optimum library offers export options to ensure such a `decoder with past` model file is created.
//!
//! The base encoder and decoder model architecture are available (and exposed for convenience) in the `encoder` and `decoder` modules, respectively.
//! Generation models (pure decoder or encoder/decoder architectures) are available in the `models` module.
//!
//! Most pipelines are available for ONNX model checkpoints, including sequence classification, zero-shot classification,
//! token classification (including named entity recognition and part-of-speech tagging), question answering, text generation, summarization and translation.
//!
//! These models use the same configuration and tokenizer files as their Pytorch counterparts when used in a pipeline. The following is
//! an example of a translation model based on a ONNX export of M2M100:
//! ```no_run
//! use rust_bert::m2m_100::{M2M100SourceLanguages, M2M100TargetLanguages};
//! use tch::Device;
//!
//! use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
//! use rust_bert::resources::RemoteResource;
//!
//! fn main() -> anyhow::Result<()> {
//! let translation_model = TranslationModel::new(TranslationConfig::new(
//! ModelType::M2M100,
//! ModelResource::ONNX(ONNXModelResources {
//! encoder_resource: Some(Box::new(RemoteResource::new(
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/encoder_model.onnx",
//! "onnx-m2m100_418M",
//! ))),
//! decoder_resource: Some(Box::new(RemoteResource::new(
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_model.onnx",
//! "onnx-m2m100_418M",
//! ))),
//! decoder_with_past_resource: Some(Box::new(RemoteResource::new(
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_with_past_model.onnx",
//! "onnx-m2m100_418M",
//! ))),
//! }),
//! RemoteResource::new(
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/config.json",
//! "onnx-m2m100_418M",
//! ),
//! RemoteResource::new(
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/vocab.json",
//! "onnx-m2m100_418M",
//! ),
//! Some(RemoteResource::new(
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/sentencepiece.bpe.model",
//! "onnx-m2m100_418M",
//! )),
//! M2M100SourceLanguages::M2M100_418M,
//! M2M100TargetLanguages::M2M100_418M,
//! Device::cuda_if_available(),
//! ))?;
//!
//! let source_sentence = "This sentence will be translated in multiple languages.";
//!
//! let mut outputs = Vec::new();
//! outputs.extend(translation_model.translate(
//! &[source_sentence],
//! Language::English,
//! Language::French,
//! )?);
//! outputs.extend(translation_model.translate(
//! &[source_sentence],
//! Language::English,
//! Language::Spanish,
//! )?);
//! outputs.extend(translation_model.translate(
//! &[source_sentence],
//! Language::English,
//! Language::Hindi,
//! )?);
//!
//! println!("{:?}", outputs);
//! Ok(())
//! }
//! ```
mod common;
pub mod config;
mod conversion;
mod decoder;
mod encoder;
mod models;
pub use encoder::{ONNXEncoder, ONNXEncoderModelOutput};
pub use models::{ONNXCausalGenerator, ONNXConditionalGenerator, ONNXLayerCache, ONNXModelConfig};

1130
src/pipelines/onnx/models.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@ -92,7 +92,10 @@ use {
mobilebert::{
MobileBertConfigResources, MobileBertModelResources, MobileBertVocabResources,
},
pipelines::{common::ModelType, token_classification::LabelAggregationOption},
pipelines::{
common::{ModelResource, ModelType},
token_classification::LabelAggregationOption,
},
resources::RemoteResource,
},
tch::Device,
@ -121,9 +124,9 @@ impl Default for POSConfig {
POSConfig {
token_classification_config: TokenClassificationConfig {
model_type: ModelType::MobileBert,
model_resource: Box::new(RemoteResource::from_pretrained(
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
MobileBertModelResources::MOBILEBERT_ENGLISH_POS,
)),
))),
config_resource: Box::new(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_ENGLISH_POS,
)),
@ -142,6 +145,14 @@ impl Default for POSConfig {
}
}
impl From<TokenClassificationConfig> for POSConfig {
fn from(token_classification_config: TokenClassificationConfig) -> Self {
POSConfig {
token_classification_config,
}
}
}
impl From<POSConfig> for TokenClassificationConfig {
fn from(pos_config: POSConfig) -> Self {
pos_config.token_classification_config

View File

@ -51,23 +51,27 @@ use crate::distilbert::DistilBertForQuestionAnswering;
use crate::fnet::FNetForQuestionAnswering;
use crate::longformer::LongformerForQuestionAnswering;
use crate::mobilebert::MobileBertForQuestionAnswering;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::common::{
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::reformer::ReformerForQuestionAnswering;
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForQuestionAnswering;
use crate::xlnet::XLNetForQuestionAnswering;
use rust_tokenizers::{Offset, TokenIdsWithOffsets, TokenizedInput};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::cmp::min;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use tch::kind::Kind::Float;
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Tensor};
use tch::{no_grad, Device, Kind, Tensor};
use crate::deberta_v2::DebertaV2ForQuestionAnswering;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
#[cfg(feature = "remote")]
use crate::{
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
@ -88,6 +92,7 @@ pub struct QaInput {
struct QaFeature {
pub input_ids: Vec<i64>,
pub offsets: Vec<Option<Offset>>,
pub token_type_ids: Vec<i8>,
pub p_mask: Vec<i8>,
pub example_index: i64,
}
@ -128,7 +133,7 @@ fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
/// Contains information regarding the model to load and device to place the model on.
pub struct QuestionAnsweringConfig {
/// Model weights resource (default: pretrained DistilBERT model on SQuAD)
pub model_resource: Box<dyn ResourceProvider + Send>,
pub model_resource: ModelResource,
/// Config resource (default: pretrained DistilBERT model on SQuAD)
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained DistilBERT model on SQuAD)
@ -166,9 +171,9 @@ impl QuestionAnsweringConfig {
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<RM, RC, RV>(
pub fn new<RC, RV>(
model_type: ModelType,
model_resource: RM,
model_resource: ModelResource,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
@ -177,13 +182,12 @@ impl QuestionAnsweringConfig {
add_prefix_space: impl Into<Option<bool>>,
) -> QuestionAnsweringConfig
where
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
QuestionAnsweringConfig {
model_type,
model_resource: Box::new(model_resource),
model_resource,
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
@ -212,9 +216,9 @@ impl QuestionAnsweringConfig {
/// * max_query_length - Optional maximum question token length. Defaults to 64.
/// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128.
/// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15.
pub fn custom_new<RM, RC, RV>(
pub fn custom_new<RC, RV>(
model_type: ModelType,
model_resource: RM,
model_resource: ModelResource,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
@ -227,13 +231,12 @@ impl QuestionAnsweringConfig {
max_answer_length: impl Into<Option<usize>>,
) -> QuestionAnsweringConfig
where
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
QuestionAnsweringConfig {
model_type,
model_resource: Box::new(model_resource),
model_resource,
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
@ -253,9 +256,9 @@ impl QuestionAnsweringConfig {
impl Default for QuestionAnsweringConfig {
fn default() -> QuestionAnsweringConfig {
QuestionAnsweringConfig {
model_resource: Box::new(RemoteResource::from_pretrained(
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SQUAD,
)),
))),
config_resource: Box::new(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
)),
@ -303,6 +306,9 @@ pub enum QuestionAnsweringOption {
Longformer(LongformerForQuestionAnswering),
/// FNet for Question Answering
FNet(FNetForQuestionAnswering),
/// ONNX model for Question Answering
#[cfg(feature = "onnx")]
ONNX(ONNXEncoder),
}
impl QuestionAnsweringOption {
@ -310,23 +316,30 @@ impl QuestionAnsweringOption {
///
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
pub fn new<'p, P>(
model_type: ModelType,
p: P,
config: &ConfigOption,
) -> Result<Self, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
match model_type {
/// * `QuestionAnsweringConfig` - Question answering pipeline configuration. The type of model created will be inferred from the
/// `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
pub fn new(config: &QuestionAnsweringConfig) -> Result<Self, RustBertError> {
match config.model_resource {
ModelResource::Torch(_) => Self::new_torch(config),
#[cfg(feature = "onnx")]
ModelResource::ONNX(_) => Self::new_onnx(config),
}
}
fn new_torch(config: &QuestionAnsweringConfig) -> Result<Self, RustBertError> {
let device = config.device;
let weights_path = config.model_resource.get_torch_local_path()?;
let mut var_store = VarStore::new(device);
let model_config = &mut ConfigOption::from_file(
config.model_type,
config.config_resource.get_local_path()?,
);
let model_type = config.model_type;
let model = match model_type {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
if let ConfigOption::Bert(config) = model_config {
Ok(QuestionAnsweringOption::Bert(
BertForQuestionAnswering::new(p, config),
BertForQuestionAnswering::new(var_store.root(), config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -335,9 +348,9 @@ impl QuestionAnsweringOption {
}
}
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
if let ConfigOption::Deberta(config) = model_config {
Ok(QuestionAnsweringOption::Deberta(
DebertaForQuestionAnswering::new(p, config),
DebertaForQuestionAnswering::new(var_store.root(), config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -346,9 +359,9 @@ impl QuestionAnsweringOption {
}
}
ModelType::DebertaV2 => {
if let ConfigOption::DebertaV2(config) = config {
if let ConfigOption::DebertaV2(config) = model_config {
Ok(QuestionAnsweringOption::DebertaV2(
DebertaV2ForQuestionAnswering::new(p, config),
DebertaV2ForQuestionAnswering::new(var_store.root(), config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -357,9 +370,10 @@ impl QuestionAnsweringOption {
}
}
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
if let ConfigOption::DistilBert(ref mut config) = model_config {
config.sinusoidal_pos_embds = false;
Ok(QuestionAnsweringOption::DistilBert(
DistilBertForQuestionAnswering::new(p, config),
DistilBertForQuestionAnswering::new(var_store.root(), config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -368,9 +382,9 @@ impl QuestionAnsweringOption {
}
}
ModelType::MobileBert => {
if let ConfigOption::MobileBert(config) = config {
if let ConfigOption::MobileBert(config) = model_config {
Ok(QuestionAnsweringOption::MobileBert(
MobileBertForQuestionAnswering::new(p, config),
MobileBertForQuestionAnswering::new(var_store.root(), config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -379,9 +393,9 @@ impl QuestionAnsweringOption {
}
}
ModelType::Roberta => {
if let ConfigOption::Roberta(config) = config {
if let ConfigOption::Roberta(config) = model_config {
Ok(QuestionAnsweringOption::Roberta(
RobertaForQuestionAnswering::new(p, config),
RobertaForQuestionAnswering::new(var_store.root(), config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -390,9 +404,9 @@ impl QuestionAnsweringOption {
}
}
ModelType::XLMRoberta => {
if let ConfigOption::Bert(config) = config {
if let ConfigOption::Bert(config) = model_config {
Ok(QuestionAnsweringOption::XLMRoberta(
RobertaForQuestionAnswering::new(p, config),
RobertaForQuestionAnswering::new(var_store.root(), config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -401,9 +415,9 @@ impl QuestionAnsweringOption {
}
}
ModelType::Albert => {
if let ConfigOption::Albert(config) = config {
if let ConfigOption::Albert(config) = model_config {
Ok(QuestionAnsweringOption::Albert(
AlbertForQuestionAnswering::new(p, config),
AlbertForQuestionAnswering::new(var_store.root(), config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -412,9 +426,9 @@ impl QuestionAnsweringOption {
}
}
ModelType::XLNet => {
if let ConfigOption::XLNet(config) = config {
if let ConfigOption::XLNet(config) = model_config {
Ok(QuestionAnsweringOption::XLNet(
XLNetForQuestionAnswering::new(p, config)?,
XLNetForQuestionAnswering::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -423,9 +437,9 @@ impl QuestionAnsweringOption {
}
}
ModelType::Reformer => {
if let ConfigOption::Reformer(config) = config {
if let ConfigOption::Reformer(config) = model_config {
Ok(QuestionAnsweringOption::Reformer(
ReformerForQuestionAnswering::new(p, config)?,
ReformerForQuestionAnswering::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -434,9 +448,9 @@ impl QuestionAnsweringOption {
}
}
ModelType::Longformer => {
if let ConfigOption::Longformer(config) = config {
if let ConfigOption::Longformer(config) = model_config {
Ok(QuestionAnsweringOption::Longformer(
LongformerForQuestionAnswering::new(p, config),
LongformerForQuestionAnswering::new(var_store.root(), config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -445,9 +459,9 @@ impl QuestionAnsweringOption {
}
}
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
if let ConfigOption::FNet(config) = model_config {
Ok(QuestionAnsweringOption::FNet(
FNetForQuestionAnswering::new(p, config),
FNetForQuestionAnswering::new(var_store.root(), config),
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -458,7 +472,28 @@ impl QuestionAnsweringOption {
_ => Err(RustBertError::InvalidConfigurationError(format!(
"QuestionAnswering not implemented for {model_type:?}!",
))),
}
}?;
var_store.load(weights_path)?;
Ok(model)
}
#[cfg(feature = "onnx")]
pub fn new_onnx(config: &QuestionAnsweringConfig) -> Result<Self, RustBertError> {
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
let environment = onnx_config.get_environment()?;
let encoder_file = config
.model_resource
.get_onnx_local_paths()?
.encoder_path
.ok_or(RustBertError::InvalidConfigurationError(
"An encoder file must be provided for question answering ONNX models.".to_string(),
))?;
Ok(Self::ONNX(ONNXEncoder::new(
encoder_file,
&environment,
&onnx_config,
)?))
}
/// Returns the `ModelType` for this SequenceClassificationOption
@ -476,6 +511,8 @@ impl QuestionAnsweringOption {
Self::Reformer(_) => ModelType::Reformer,
Self::Longformer(_) => ModelType::Longformer,
Self::FNet(_) => ModelType::FNet,
#[cfg(feature = "onnx")]
Self::ONNX(_) => ModelType::ONNX,
}
}
@ -485,6 +522,7 @@ impl QuestionAnsweringOption {
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
input_embeds: Option<&Tensor>,
_token_type_ids: Option<&Tensor>,
train: bool,
) -> (Tensor, Tensor) {
match *self {
@ -547,6 +585,19 @@ impl QuestionAnsweringOption {
.expect("Error in fnet forward pass");
(outputs.start_logits, outputs.end_logits)
}
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => {
let outputs = model
.forward(
input_ids,
mask.map(|tensor| tensor.to_kind(Kind::Int64)).as_ref(),
_token_type_ids,
None,
input_embeds,
)
.expect("Error in ONNX forward pass.");
(outputs.start_logits.unwrap(), outputs.end_logits.unwrap())
}
}
}
}
@ -561,7 +612,7 @@ pub struct QuestionAnsweringModel {
max_query_length: usize,
max_answer_len: usize,
qa_model: QuestionAnsweringOption,
var_store: VarStore,
device: Device,
}
impl QuestionAnsweringModel {
@ -631,8 +682,7 @@ impl QuestionAnsweringModel {
question_answering_config: QuestionAnsweringConfig,
tokenizer: TokenizerOption,
) -> Result<QuestionAnsweringModel, RustBertError> {
let config_path = question_answering_config.config_resource.get_local_path()?;
let device = question_answering_config.device;
let qa_model = QuestionAnsweringOption::new(&question_answering_config)?;
let pad_idx = tokenizer
.get_pad_id()
@ -640,19 +690,6 @@ impl QuestionAnsweringModel {
let sep_idx = tokenizer
.get_sep_id()
.expect("The Tokenizer used for Question Answering should contain a SEP id");
let mut var_store = VarStore::new(device);
let mut model_config =
ConfigOption::from_file(question_answering_config.model_type, config_path);
if let ConfigOption::DistilBert(ref mut config) = model_config {
config.sinusoidal_pos_embds = false;
};
let qa_model = QuestionAnsweringOption::new(
question_answering_config.model_type,
var_store.root(),
&model_config,
)?;
if question_answering_config.max_seq_length
< (question_answering_config.max_query_length
@ -668,8 +705,10 @@ impl QuestionAnsweringModel {
question_answering_config.doc_stride
)));
}
crate::resources::load_weights(&question_answering_config.model_resource, &mut var_store)?;
let device = get_device(
question_answering_config.model_resource,
question_answering_config.device,
);
Ok(QuestionAnsweringModel {
tokenizer,
pad_idx,
@ -679,7 +718,7 @@ impl QuestionAnsweringModel {
max_query_length: question_answering_config.max_query_length,
max_answer_len: question_answering_config.max_answer_length,
qa_model,
var_store,
device,
})
}
@ -758,11 +797,16 @@ impl QuestionAnsweringModel {
let end = start + min(len_features - start, batch_size);
let batch_features = &mut features[start..end];
no_grad(|| {
let (input_ids, attention_masks) = self.pad_features(batch_features);
let (input_ids, attention_masks, token_type_ids) =
self.pad_features(batch_features);
let (start_logits, end_logits) =
self.qa_model
.forward_t(Some(&input_ids), Some(&attention_masks), None, false);
let (start_logits, end_logits) = self.qa_model.forward_t(
Some(&input_ids),
Some(&attention_masks),
None,
Some(&token_type_ids),
false,
);
let start_logits = start_logits.detach();
let end_logits = end_logits.detach();
@ -935,6 +979,7 @@ impl QuestionAnsweringModel {
let qa_feature = QaFeature {
input_ids: encoded_span.token_ids,
offsets: encoded_span.token_offsets,
token_type_ids: encoded_span.segment_ids,
p_mask,
example_index,
};
@ -947,7 +992,7 @@ impl QuestionAnsweringModel {
spans
}
fn pad_features(&self, features: &mut [QaFeature]) -> (Tensor, Tensor) {
fn pad_features(&self, features: &mut [QaFeature]) -> (Tensor, Tensor, Tensor) {
let max_len = features
.iter()
.map(|feature| feature.input_ids.len())
@ -970,6 +1015,9 @@ impl QuestionAnsweringModel {
feature.offsets.resize(max_len, None);
feature.p_mask.resize(max_len, 1);
feature.input_ids.resize(max_len, self.pad_idx);
feature
.token_type_ids
.resize(max_len, *feature.token_type_ids.last().unwrap_or(&0));
}
let padded_input_ids = features
@ -977,9 +1025,17 @@ impl QuestionAnsweringModel {
.map(|input| Tensor::from_slice(input.input_ids.as_slice()))
.collect::<Vec<_>>();
let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.var_store.device());
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
(input_ids, attention_masks)
let padded_token_type_ids = features
.iter_mut()
.map(|input| Tensor::from_slice(input.token_type_ids.as_slice()))
.collect::<Vec<_>>();
let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.device);
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.device);
let token_type_ids = Tensor::stack(&padded_token_type_ids, 0)
.to(self.device)
.to_kind(Kind::Int64);
(input_ids, attention_masks, token_type_ids)
}
fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> {

View File

@ -22,8 +22,9 @@
//! # fn main() -> anyhow::Result<()> {
//!
//! //Load a configuration
//! use rust_bert::pipelines::common::ModelResource;
//! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
//! RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2),
//! ModelResource::Torch(Box::new(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2))),
//! RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
//! RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
//! None, // Merge resources
@ -66,20 +67,21 @@ use crate::distilbert::DistilBertModelClassifier;
use crate::fnet::FNetForSequenceClassification;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::common::{
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::reformer::ReformerForSequenceClassification;
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForSequenceClassification;
use crate::xlnet::XLNetForSequenceClassification;
use rust_tokenizers::tokenizer::TruncationStrategy;
use rust_tokenizers::TokenizedInput;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Kind, Tensor};
use tch::{no_grad, Device, Kind, Tensor};
use crate::deberta_v2::DebertaV2ForSequenceClassification;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
#[cfg(feature = "remote")]
use crate::{
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
@ -106,7 +108,7 @@ pub struct SequenceClassificationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BERT model on CoNLL)
pub model_resource: Box<dyn ResourceProvider + Send>,
pub model_resource: ModelResource,
/// Config resource (default: pretrained BERT model on CoNLL)
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BERT model on CoNLL)
@ -134,9 +136,9 @@ impl SequenceClassificationConfig {
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<RM, RC, RV>(
pub fn new<RC, RV>(
model_type: ModelType,
model_resource: RM,
model_resource: ModelResource,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
@ -145,13 +147,12 @@ impl SequenceClassificationConfig {
add_prefix_space: impl Into<Option<bool>>,
) -> SequenceClassificationConfig
where
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
SequenceClassificationConfig {
model_type,
model_resource: Box::new(model_resource),
model_resource,
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
@ -169,7 +170,9 @@ impl Default for SequenceClassificationConfig {
fn default() -> SequenceClassificationConfig {
SequenceClassificationConfig::new(
ModelType::DistilBert,
RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SST2,
))),
RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
None,
@ -209,6 +212,9 @@ pub enum SequenceClassificationOption {
Longformer(LongformerForSequenceClassification),
/// FNet for Sequence Classification
FNet(FNetForSequenceClassification),
/// ONNX Model for Sequence Classification
#[cfg(feature = "onnx")]
ONNX(ONNXEncoder),
}
impl SequenceClassificationOption {
@ -216,23 +222,28 @@ impl SequenceClassificationOption {
///
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
pub fn new<'p, P>(
model_type: ModelType,
p: P,
config: &ConfigOption,
) -> Result<Self, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
match model_type {
/// * `SequenceClassificationConfig` - Sequence classification pipeline configuration. The type of model created will be inferred from the
/// `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
pub fn new(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
match config.model_resource {
ModelResource::Torch(_) => Self::new_torch(config),
#[cfg(feature = "onnx")]
ModelResource::ONNX(_) => Self::new_onnx(config),
}
}
fn new_torch(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
let device = config.device;
let weights_path = config.model_resource.get_torch_local_path()?;
let mut var_store = VarStore::new(device);
let model_config =
&ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
let model_type = config.model_type;
let model = match model_type {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
Ok(SequenceClassificationOption::Bert(
BertForSequenceClassification::new(p, config)?,
if let ConfigOption::Bert(config) = model_config {
Ok(Self::Bert(
BertForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -241,9 +252,9 @@ impl SequenceClassificationOption {
}
}
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
Ok(SequenceClassificationOption::Deberta(
DebertaForSequenceClassification::new(p, config)?,
if let ConfigOption::Deberta(config) = model_config {
Ok(Self::Deberta(
DebertaForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -252,9 +263,9 @@ impl SequenceClassificationOption {
}
}
ModelType::DebertaV2 => {
if let ConfigOption::DebertaV2(config) = config {
Ok(SequenceClassificationOption::DebertaV2(
DebertaV2ForSequenceClassification::new(p, config)?,
if let ConfigOption::DebertaV2(config) = model_config {
Ok(Self::DebertaV2(
DebertaV2ForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -263,9 +274,9 @@ impl SequenceClassificationOption {
}
}
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
Ok(SequenceClassificationOption::DistilBert(
DistilBertModelClassifier::new(p, config)?,
if let ConfigOption::DistilBert(config) = model_config {
Ok(Self::DistilBert(
DistilBertModelClassifier::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -274,9 +285,9 @@ impl SequenceClassificationOption {
}
}
ModelType::MobileBert => {
if let ConfigOption::MobileBert(config) = config {
Ok(SequenceClassificationOption::MobileBert(
MobileBertForSequenceClassification::new(p, config)?,
if let ConfigOption::MobileBert(config) = model_config {
Ok(Self::MobileBert(
MobileBertForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -285,9 +296,9 @@ impl SequenceClassificationOption {
}
}
ModelType::Roberta => {
if let ConfigOption::Roberta(config) = config {
Ok(SequenceClassificationOption::Roberta(
RobertaForSequenceClassification::new(p, config)?,
if let ConfigOption::Roberta(config) = model_config {
Ok(Self::Roberta(
RobertaForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -296,9 +307,9 @@ impl SequenceClassificationOption {
}
}
ModelType::XLMRoberta => {
if let ConfigOption::Roberta(config) = config {
Ok(SequenceClassificationOption::XLMRoberta(
RobertaForSequenceClassification::new(p, config)?,
if let ConfigOption::Roberta(config) = model_config {
Ok(Self::XLMRoberta(
RobertaForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -307,9 +318,9 @@ impl SequenceClassificationOption {
}
}
ModelType::Albert => {
if let ConfigOption::Albert(config) = config {
Ok(SequenceClassificationOption::Albert(
AlbertForSequenceClassification::new(p, config)?,
if let ConfigOption::Albert(config) = model_config {
Ok(Self::Albert(
AlbertForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -318,9 +329,9 @@ impl SequenceClassificationOption {
}
}
ModelType::XLNet => {
if let ConfigOption::XLNet(config) = config {
Ok(SequenceClassificationOption::XLNet(
XLNetForSequenceClassification::new(p, config)?,
if let ConfigOption::XLNet(config) = model_config {
Ok(Self::XLNet(
XLNetForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -329,9 +340,9 @@ impl SequenceClassificationOption {
}
}
ModelType::Bart => {
if let ConfigOption::Bart(config) = config {
Ok(SequenceClassificationOption::Bart(
BartForSequenceClassification::new(p, config)?,
if let ConfigOption::Bart(config) = model_config {
Ok(Self::Bart(
BartForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -340,9 +351,9 @@ impl SequenceClassificationOption {
}
}
ModelType::Reformer => {
if let ConfigOption::Reformer(config) = config {
Ok(SequenceClassificationOption::Reformer(
ReformerForSequenceClassification::new(p, config)?,
if let ConfigOption::Reformer(config) = model_config {
Ok(Self::Reformer(
ReformerForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -351,9 +362,9 @@ impl SequenceClassificationOption {
}
}
ModelType::Longformer => {
if let ConfigOption::Longformer(config) = config {
Ok(SequenceClassificationOption::Longformer(
LongformerForSequenceClassification::new(p, config)?,
if let ConfigOption::Longformer(config) = model_config {
Ok(Self::Longformer(
LongformerForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -362,9 +373,9 @@ impl SequenceClassificationOption {
}
}
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(SequenceClassificationOption::FNet(
FNetForSequenceClassification::new(p, config)?,
if let ConfigOption::FNet(config) = model_config {
Ok(Self::FNet(
FNetForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -372,10 +383,36 @@ impl SequenceClassificationOption {
))
}
}
#[cfg(feature = "onnx")]
ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
"A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Sequence Classification not implemented for {model_type:?}!",
))),
}
}?;
var_store.load(weights_path)?;
Ok(model)
}
#[cfg(feature = "onnx")]
pub fn new_onnx(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
let environment = onnx_config.get_environment()?;
let encoder_file = config
.model_resource
.get_onnx_local_paths()?
.encoder_path
.ok_or(RustBertError::InvalidConfigurationError(
"An encoder file must be provided for sequence classification ONNX models."
.to_string(),
))?;
Ok(Self::ONNX(ONNXEncoder::new(
encoder_file,
&environment,
&onnx_config,
)?))
}
/// Returns the `ModelType` for this SequenceClassificationOption
@ -394,6 +431,8 @@ impl SequenceClassificationOption {
Self::Reformer(_) => ModelType::Reformer,
Self::Longformer(_) => ModelType::Longformer,
Self::FNet(_) => ModelType::FNet,
#[cfg(feature = "onnx")]
Self::ONNX(_) => ModelType::ONNX,
}
}
@ -534,6 +573,21 @@ impl SequenceClassificationOption {
.expect("Error in FNet forward pass.")
.logits
}
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => {
let attention_mask = input_ids.unwrap().ones_like();
model
.forward(
input_ids,
Some(&attention_mask),
token_type_ids,
position_ids,
input_embeds,
)
.expect("Error in ONNX forward pass.")
.logits
.unwrap()
}
}
}
}
@ -543,7 +597,7 @@ pub struct SequenceClassificationModel {
tokenizer: TokenizerOption,
sequence_classifier: SequenceClassificationOption,
label_mapping: HashMap<i64, String>,
var_store: VarStore,
device: Device,
max_length: usize,
}
@ -615,23 +669,20 @@ impl SequenceClassificationModel {
tokenizer: TokenizerOption,
) -> Result<SequenceClassificationModel, RustBertError> {
let config_path = config.config_resource.get_local_path()?;
let device = config.device;
let sequence_classifier = SequenceClassificationOption::new(&config)?;
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let max_length = model_config
.get_max_len()
.map(|v| v as usize)
.unwrap_or(usize::MAX);
let sequence_classifier =
SequenceClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
let label_mapping = model_config.get_label_mapping().clone();
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
let device = get_device(config.model_resource, config.device);
Ok(SequenceClassificationModel {
tokenizer,
sequence_classifier,
label_mapping,
var_store,
device,
max_length,
})
}
@ -645,36 +696,6 @@ impl SequenceClassificationModel {
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn prepare_for_model<'a, S>(&self, input: S) -> Tensor
where
S: AsRef<[&'a str]>,
{
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(
input.as_ref(),
self.max_length,
&TruncationStrategy::LongestFirst,
0,
);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let pad_id = self
.tokenizer
.get_pad_id()
.expect("The Tokenizer used for sequence classification should contain a PAD id");
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
.into_iter()
.map(|mut input| {
input.token_ids.resize(max_len, pad_id);
Tensor::from_slice(&(input.token_ids))
})
.collect::<Vec<_>>();
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
}
/// Classify texts
///
/// # Arguments
@ -705,12 +726,14 @@ impl SequenceClassificationModel {
where
S: AsRef<[&'a str]>,
{
let input_tensor = self.prepare_for_model(input.as_ref());
let (input_ids, token_type_ids) =
self.tokenizer
.tokenize_and_pad(input.as_ref(), self.max_length, self.device);
let output = no_grad(|| {
let output = self.sequence_classifier.forward_t(
Some(&input_tensor),
None,
Some(&input_ids),
None,
Some(&token_type_ids),
None,
None,
false,
@ -774,12 +797,14 @@ impl SequenceClassificationModel {
input: &[&str],
threshold: f64,
) -> Result<Vec<Vec<Label>>, RustBertError> {
let input_tensor = self.prepare_for_model(input);
let (input_ids, token_type_ids) =
self.tokenizer
.tokenize_and_pad(input.as_ref(), self.max_length, self.device);
let output = no_grad(|| {
let output = self.sequence_classifier.forward_t(
Some(&input_tensor),
None,
Some(&input_ids),
None,
Some(&token_type_ids),
None,
None,
false,

View File

@ -67,14 +67,15 @@ use tch::Device;
use crate::bart::BartGenerator;
use crate::common::error::RustBertError;
use crate::pegasus::PegasusConditionalGenerator;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::prophetnet::ProphetNetConditionalGenerator;
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
use crate::longt5::LongT5Generator;
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXConditionalGenerator;
#[cfg(feature = "remote")]
use crate::{
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
@ -88,7 +89,7 @@ pub struct SummarizationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BART model on CNN-DM)
pub model_resource: Box<dyn ResourceProvider + Send>,
pub model_resource: ModelResource,
/// Config resource (default: pretrained BART model on CNN-DM)
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BART model on CNN-DM)
@ -133,25 +134,24 @@ impl SummarizationConfig {
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * model_resource - The `ModelResources` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new<RM, RC, RV>(
pub fn new<RC, RV>(
model_type: ModelType,
model_resource: RM,
model_resource: ModelResource,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
) -> SummarizationConfig
where
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
SummarizationConfig {
model_type,
model_resource: Box::new(model_resource),
model_resource,
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
@ -179,7 +179,9 @@ impl Default for SummarizationConfig {
fn default() -> SummarizationConfig {
SummarizationConfig::new(
ModelType::Bart,
RemoteResource::from_pretrained(BartModelResources::BART_CNN),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BartModelResources::BART_CNN,
))),
RemoteResource::from_pretrained(BartConfigResources::BART_CNN),
RemoteResource::from_pretrained(BartVocabResources::BART_CNN),
Some(RemoteResource::from_pretrained(
@ -192,6 +194,7 @@ impl Default for SummarizationConfig {
impl From<SummarizationConfig> for GenerateConfig {
fn from(config: SummarizationConfig) -> GenerateConfig {
GenerateConfig {
model_type: config.model_type,
model_resource: config.model_resource,
config_resource: config.config_resource,
merges_resource: config.merges_resource,
@ -227,22 +230,29 @@ pub enum SummarizationOption {
ProphetNet(ProphetNetConditionalGenerator),
/// Summarizer based on Pegasus model
Pegasus(PegasusConditionalGenerator),
/// Summarizer based on ONNX model
#[cfg(feature = "onnx")]
ONNX(ONNXConditionalGenerator),
}
impl SummarizationOption {
pub fn new(config: SummarizationConfig) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::Bart => Ok(SummarizationOption::Bart(BartGenerator::new(
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
ONNXConditionalGenerator::new(config.into(), None, None)?,
)),
(ModelType::Bart, _) => Ok(SummarizationOption::Bart(BartGenerator::new(
config.into(),
)?)),
ModelType::T5 => Ok(SummarizationOption::T5(T5Generator::new(config.into())?)),
ModelType::LongT5 => Ok(SummarizationOption::LongT5(LongT5Generator::new(
(ModelType::T5, _) => Ok(SummarizationOption::T5(T5Generator::new(config.into())?)),
(ModelType::LongT5, _) => Ok(SummarizationOption::LongT5(LongT5Generator::new(
config.into(),
)?)),
ModelType::ProphetNet => Ok(SummarizationOption::ProphetNet(
(ModelType::ProphetNet, _) => Ok(SummarizationOption::ProphetNet(
ProphetNetConditionalGenerator::new(config.into())?,
)),
ModelType::Pegasus => Ok(SummarizationOption::Pegasus(
(ModelType::Pegasus, _) => Ok(SummarizationOption::Pegasus(
PegasusConditionalGenerator::new(config.into())?,
)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
@ -256,21 +266,25 @@ impl SummarizationOption {
config: SummarizationConfig,
tokenizer: TokenizerOption,
) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::Bart => Ok(SummarizationOption::Bart(
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
)),
(ModelType::Bart, _) => Ok(SummarizationOption::Bart(
BartGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::T5 => Ok(SummarizationOption::T5(T5Generator::new_with_tokenizer(
(ModelType::T5, _) => Ok(SummarizationOption::T5(T5Generator::new_with_tokenizer(
config.into(),
tokenizer,
)?)),
ModelType::LongT5 => Ok(SummarizationOption::LongT5(
(ModelType::LongT5, _) => Ok(SummarizationOption::LongT5(
LongT5Generator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::ProphetNet => Ok(SummarizationOption::ProphetNet(
(ModelType::ProphetNet, _) => Ok(SummarizationOption::ProphetNet(
ProphetNetConditionalGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::Pegasus => Ok(SummarizationOption::Pegasus(
(ModelType::Pegasus, _) => Ok(SummarizationOption::Pegasus(
PegasusConditionalGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
@ -288,28 +302,34 @@ impl SummarizationOption {
Self::LongT5(_) => ModelType::LongT5,
Self::ProphetNet(_) => ModelType::ProphetNet,
Self::Pegasus(_) => ModelType::Pegasus,
#[cfg(feature = "onnx")]
Self::ONNX(_) => ModelType::ONNX,
}
}
/// Interface method to access tokenizer
pub fn get_tokenizer(&self) -> &TokenizerOption {
match self {
Self::Bart(model_ref) => model_ref._get_tokenizer(),
Self::T5(model_ref) => model_ref._get_tokenizer(),
Self::LongT5(model_ref) => model_ref._get_tokenizer(),
Self::ProphetNet(model_ref) => model_ref._get_tokenizer(),
Self::Pegasus(model_ref) => model_ref._get_tokenizer(),
Self::Bart(model_ref) => model_ref.get_tokenizer(),
Self::T5(model_ref) => model_ref.get_tokenizer(),
Self::LongT5(model_ref) => model_ref.get_tokenizer(),
Self::ProphetNet(model_ref) => model_ref.get_tokenizer(),
Self::Pegasus(model_ref) => model_ref.get_tokenizer(),
#[cfg(feature = "onnx")]
Self::ONNX(model_ref) => model_ref.get_tokenizer(),
}
}
/// Interface method to access tokenizer
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
match self {
Self::Bart(model_ref) => model_ref._get_tokenizer_mut(),
Self::T5(model_ref) => model_ref._get_tokenizer_mut(),
Self::LongT5(model_ref) => model_ref._get_tokenizer_mut(),
Self::ProphetNet(model_ref) => model_ref._get_tokenizer_mut(),
Self::Pegasus(model_ref) => model_ref._get_tokenizer_mut(),
Self::Bart(model_ref) => model_ref.get_tokenizer_mut(),
Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
Self::LongT5(model_ref) => model_ref.get_tokenizer_mut(),
Self::ProphetNet(model_ref) => model_ref.get_tokenizer_mut(),
Self::Pegasus(model_ref) => model_ref.get_tokenizer_mut(),
#[cfg(feature = "onnx")]
Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
}
}
@ -344,6 +364,12 @@ impl SummarizationOption {
.into_iter()
.map(|output| output.text)
.collect(),
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => model
.generate(prompt_texts, None)
.into_iter()
.map(|output| output.text)
.collect(),
}
}
}

View File

@ -38,14 +38,15 @@ use crate::gpt2::GPT2Generator;
use crate::gpt_j::GptJGenerator;
use crate::gpt_neo::GptNeoGenerator;
use crate::openai_gpt::OpenAIGenerator;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
use crate::reformer::ReformerGenerator;
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
use crate::xlnet::XLNetGenerator;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXCausalGenerator;
#[cfg(feature = "remote")]
use crate::{
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
@ -59,7 +60,7 @@ pub struct TextGenerationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BART model on CNN-DM)
pub model_resource: Box<dyn ResourceProvider + Send>,
pub model_resource: ModelResource,
/// Config resource (default: pretrained BART model on CNN-DM)
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BART model on CNN-DM)
@ -104,25 +105,24 @@ impl TextGenerationConfig {
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
/// * model_resource - The `ModelResources` pointing to the model to load (e.g. model.ot)
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new<RM, RC, RV>(
pub fn new<RC, RV>(
model_type: ModelType,
model_resource: RM,
model_resource: ModelResource,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
) -> TextGenerationConfig
where
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
TextGenerationConfig {
model_type,
model_resource: Box::new(model_resource),
model_resource,
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
@ -150,7 +150,9 @@ impl Default for TextGenerationConfig {
fn default() -> TextGenerationConfig {
TextGenerationConfig::new(
ModelType::GPT2,
RemoteResource::from_pretrained(Gpt2ModelResources::GPT2_MEDIUM),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
Gpt2ModelResources::GPT2_MEDIUM,
))),
RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM),
RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM),
Some(RemoteResource::from_pretrained(
@ -163,6 +165,7 @@ impl Default for TextGenerationConfig {
impl From<TextGenerationConfig> for GenerateConfig {
fn from(config: TextGenerationConfig) -> GenerateConfig {
GenerateConfig {
model_type: config.model_type,
model_resource: config.model_resource,
config_resource: config.config_resource,
merges_resource: config.merges_resource,
@ -202,30 +205,37 @@ pub enum TextGenerationOption {
Reformer(ReformerGenerator),
/// Text Generator based on T5 model
T5(T5Generator),
/// ONNX model for text generation
#[cfg(feature = "onnx")]
ONNX(ONNXCausalGenerator),
}
impl TextGenerationOption {
pub fn new(config: TextGenerationConfig) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::GPT2 => Ok(TextGenerationOption::GPT2(GPT2Generator::new(
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
ONNXCausalGenerator::new(config.into(), None, None)?,
)),
(ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(GPT2Generator::new(
config.into(),
)?)),
ModelType::OpenAiGpt => Ok(TextGenerationOption::GPT(OpenAIGenerator::new(
(ModelType::OpenAiGpt, _) => Ok(TextGenerationOption::GPT(OpenAIGenerator::new(
config.into(),
)?)),
ModelType::XLNet => Ok(TextGenerationOption::XLNet(XLNetGenerator::new(
(ModelType::XLNet, _) => Ok(TextGenerationOption::XLNet(XLNetGenerator::new(
config.into(),
)?)),
ModelType::Reformer => Ok(TextGenerationOption::Reformer(ReformerGenerator::new(
(ModelType::Reformer, _) => Ok(TextGenerationOption::Reformer(ReformerGenerator::new(
config.into(),
)?)),
ModelType::GPTNeo => Ok(TextGenerationOption::GPTNeo(GptNeoGenerator::new(
(ModelType::GPTNeo, _) => Ok(TextGenerationOption::GPTNeo(GptNeoGenerator::new(
config.into(),
)?)),
ModelType::GPTJ => Ok(TextGenerationOption::GPTJ(GptJGenerator::new(
(ModelType::GPTJ, _) => Ok(TextGenerationOption::GPTJ(GptJGenerator::new(
config.into(),
)?)),
ModelType::T5 => Ok(TextGenerationOption::T5(T5Generator::new(config.into())?)),
(ModelType::T5, _) => Ok(TextGenerationOption::T5(T5Generator::new(config.into())?)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Text generation not implemented for {:?}!",
config.model_type
@ -237,26 +247,30 @@ impl TextGenerationOption {
config: TextGenerationConfig,
tokenizer: TokenizerOption,
) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::GPT2 => Ok(TextGenerationOption::GPT2(
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
ONNXCausalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
)),
(ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(
GPT2Generator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::OpenAiGpt => Ok(TextGenerationOption::GPT(
(ModelType::OpenAiGpt, _) => Ok(TextGenerationOption::GPT(
OpenAIGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::XLNet => Ok(TextGenerationOption::XLNet(
(ModelType::XLNet, _) => Ok(TextGenerationOption::XLNet(
XLNetGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::Reformer => Ok(TextGenerationOption::Reformer(
(ModelType::Reformer, _) => Ok(TextGenerationOption::Reformer(
ReformerGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::GPTNeo => Ok(TextGenerationOption::GPTNeo(
(ModelType::GPTNeo, _) => Ok(TextGenerationOption::GPTNeo(
GptNeoGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::GPTJ => Ok(TextGenerationOption::GPTJ(
(ModelType::GPTJ, _) => Ok(TextGenerationOption::GPTJ(
GptJGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::T5 => Ok(TextGenerationOption::T5(T5Generator::new_with_tokenizer(
(ModelType::T5, _) => Ok(TextGenerationOption::T5(T5Generator::new_with_tokenizer(
config.into(),
tokenizer,
)?)),
@ -277,32 +291,37 @@ impl TextGenerationOption {
Self::XLNet(_) => ModelType::XLNet,
Self::Reformer(_) => ModelType::Reformer,
Self::T5(_) => ModelType::T5,
#[cfg(feature = "onnx")]
Self::ONNX(_) => ModelType::ONNX,
}
}
/// Interface method to access tokenizer
pub fn get_tokenizer(&self) -> &TokenizerOption {
match self {
Self::GPT(model_ref) => model_ref._get_tokenizer(),
Self::GPT2(model_ref) => model_ref._get_tokenizer(),
Self::GPTNeo(model_ref) => model_ref._get_tokenizer(),
Self::GPTJ(model_ref) => model_ref._get_tokenizer(),
Self::XLNet(model_ref) => model_ref._get_tokenizer(),
Self::Reformer(model_ref) => model_ref._get_tokenizer(),
Self::T5(model_ref) => model_ref._get_tokenizer(),
Self::GPT(model_ref) => model_ref.get_tokenizer(),
Self::GPT2(model_ref) => model_ref.get_tokenizer(),
Self::GPTNeo(model_ref) => model_ref.get_tokenizer(),
Self::GPTJ(model_ref) => model_ref.get_tokenizer(),
Self::XLNet(model_ref) => model_ref.get_tokenizer(),
Self::Reformer(model_ref) => model_ref.get_tokenizer(),
Self::T5(model_ref) => model_ref.get_tokenizer(),
#[cfg(feature = "onnx")]
Self::ONNX(model_ref) => model_ref.get_tokenizer(),
}
}
/// Interface method to access tokenizer
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
match self {
Self::GPT(model_ref) => model_ref._get_tokenizer_mut(),
Self::GPT2(model_ref) => model_ref._get_tokenizer_mut(),
Self::GPTNeo(model_ref) => model_ref._get_tokenizer_mut(),
Self::GPTJ(model_ref) => model_ref._get_tokenizer_mut(),
Self::XLNet(model_ref) => model_ref._get_tokenizer_mut(),
Self::Reformer(model_ref) => model_ref._get_tokenizer_mut(),
Self::T5(model_ref) => model_ref._get_tokenizer_mut(),
Self::GPT(model_ref) => model_ref.get_tokenizer_mut(),
Self::GPT2(model_ref) => model_ref.get_tokenizer_mut(),
Self::GPTNeo(model_ref) => model_ref.get_tokenizer_mut(),
Self::GPTJ(model_ref) => model_ref.get_tokenizer_mut(),
Self::XLNet(model_ref) => model_ref.get_tokenizer_mut(),
Self::Reformer(model_ref) => model_ref.get_tokenizer_mut(),
Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
#[cfg(feature = "onnx")]
Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
}
}
@ -357,10 +376,16 @@ impl TextGenerationOption {
.into_iter()
.map(|output| output.indices)
.collect(),
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => model
.generate_indices(prompt_texts, generate_options)
.into_iter()
.map(|output| output.indices)
.collect(),
}
}
pub fn half(&mut self) {
pub fn half(&mut self) -> Result<(), RustBertError> {
match self {
Self::GPT(model_ref) => model_ref.half(),
Self::GPT2(model_ref) => model_ref.half(),
@ -369,10 +394,14 @@ impl TextGenerationOption {
Self::XLNet(model_ref) => model_ref.half(),
Self::Reformer(model_ref) => model_ref.half(),
Self::T5(model_ref) => model_ref.half(),
#[cfg(feature = "onnx")]
Self::ONNX(_) => Err(RustBertError::OrtError(
"Type casting not supported for ONNX models.".to_string(),
)),
}
}
pub fn float(&mut self) {
pub fn float(&mut self) -> Result<(), RustBertError> {
match self {
Self::GPT(model_ref) => model_ref.float(),
Self::GPT2(model_ref) => model_ref.float(),
@ -381,10 +410,14 @@ impl TextGenerationOption {
Self::XLNet(model_ref) => model_ref.float(),
Self::Reformer(model_ref) => model_ref.float(),
Self::T5(model_ref) => model_ref.float(),
#[cfg(feature = "onnx")]
Self::ONNX(_) => Err(RustBertError::OrtError(
"Type casting not supported for ONNX models.".to_string(),
)),
}
}
pub fn set_device(&mut self, device: Device) {
pub fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
match self {
Self::GPT(model_ref) => model_ref.set_device(device),
Self::GPT2(model_ref) => model_ref.set_device(device),
@ -393,6 +426,10 @@ impl TextGenerationOption {
Self::XLNet(model_ref) => model_ref.set_device(device),
Self::Reformer(model_ref) => model_ref.set_device(device),
Self::T5(model_ref) => model_ref.set_device(device),
#[cfg(feature = "onnx")]
Self::ONNX(_) => Err(RustBertError::OrtError(
"Device assignment not supported for ONNX models.".to_string(),
)),
}
}
}
@ -520,16 +557,16 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
self.model.get_tokenizer_mut()
}
pub fn half(&mut self) {
self.model.half();
pub fn half(&mut self) -> Result<(), RustBertError> {
self.model.half()
}
pub fn float(&mut self) {
self.model.float();
pub fn float(&mut self) -> Result<(), RustBertError> {
self.model.float()
}
pub fn set_device(&mut self, device: Device) {
self.model.set_device(device);
pub fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
self.model.set_device(device)
}
/// Generate texts from provided prompts

View File

@ -21,11 +21,12 @@
//! use rust_bert::pipelines::common::ModelType;
//! # fn main() -> anyhow::Result<()> {
//!
//! use rust_bert::pipelines::common::ModelResource;
//! //Load a configuration
//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
//! let config = TokenClassificationConfig::new(
//! ModelType::Bert,
//! RemoteResource::from_pretrained(BertModelResources::BERT_NER),
//! ModelResource::Torch(Box::new(RemoteResource::from_pretrained(BertModelResources::BERT_NER))),
//! RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
//! RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
//! None, //merges resource only relevant with ModelType::Roberta
@ -120,7 +121,9 @@ use crate::electra::ElectraForTokenClassification;
use crate::fnet::FNetForTokenClassification;
use crate::longformer::LongformerForTokenClassification;
use crate::mobilebert::MobileBertForTokenClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::common::{
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForTokenClassification;
use crate::xlnet::XLNetForTokenClassification;
@ -131,13 +134,14 @@ use rust_tokenizers::{
TokenizedInput,
};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::cmp::min;
use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Kind, Tensor};
use tch::{no_grad, Device, Kind, Tensor};
use crate::deberta_v2::DebertaV2ForTokenClassification;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
#[cfg(feature = "remote")]
use crate::{
bert::{BertConfigResources, BertModelResources, BertVocabResources},
@ -195,6 +199,8 @@ struct InputFeature {
offsets: Vec<Option<Offset>>,
/// Token category (mask)
mask: Vec<Mask>,
/// Token type ids (mask)
token_type_ids: Vec<i64>,
/// per-token flag indicating if this feature carries the output label for this token
reference_feature: Vec<bool>,
/// Reference example index (long inputs may be broken into multiple input features)
@ -222,7 +228,7 @@ pub struct TokenClassificationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BERT model on CoNLL)
pub model_resource: Box<dyn ResourceProvider + Send>,
pub model_resource: ModelResource,
/// Config resource (default: pretrained BERT model on CoNLL)
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BERT model on CoNLL)
@ -254,9 +260,9 @@ impl TokenClassificationConfig {
/// * vocab - The `ResourceProvider` pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `ResourceProvider` pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<RM, RC, RV>(
pub fn new<RC, RV>(
model_type: ModelType,
model_resource: RM,
model_resource: ModelResource,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
@ -266,13 +272,12 @@ impl TokenClassificationConfig {
label_aggregation_function: LabelAggregationOption,
) -> TokenClassificationConfig
where
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
TokenClassificationConfig {
model_type,
model_resource: Box::new(model_resource),
model_resource,
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
@ -292,7 +297,9 @@ impl Default for TokenClassificationConfig {
fn default() -> TokenClassificationConfig {
TokenClassificationConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT_NER),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
None,
@ -331,30 +338,38 @@ pub enum TokenClassificationOption {
Longformer(LongformerForTokenClassification),
/// FNet for Token Classification
FNet(FNetForTokenClassification),
/// ONNX model for Token Classification
#[cfg(feature = "onnx")]
ONNX(ONNXEncoder),
}
impl TokenClassificationOption {
/// Instantiate a new token sequence classification model of the supplied type.
/// Instantiate a new sequence classification model of the supplied type.
///
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
pub fn new<'p, P>(
model_type: ModelType,
p: P,
config: &ConfigOption,
) -> Result<Self, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
match model_type {
/// * `TokenClassificationConfig` - Token classification pipeline configuration. The type of model created will be inferred from the
/// `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
pub fn new(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
match config.model_resource {
ModelResource::Torch(_) => Self::new_torch(config),
#[cfg(feature = "onnx")]
ModelResource::ONNX(_) => Self::new_onnx(config),
}
}
fn new_torch(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
let device = config.device;
let weights_path = config.model_resource.get_torch_local_path()?;
let mut var_store = VarStore::new(device);
let model_config =
&ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
let model_type = config.model_type;
let model = match model_type {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
Ok(TokenClassificationOption::Bert(
BertForTokenClassification::new(p, config)?,
if let ConfigOption::Bert(config) = model_config {
Ok(Self::Bert(
BertForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -363,9 +378,9 @@ impl TokenClassificationOption {
}
}
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
Ok(TokenClassificationOption::Deberta(
DebertaForTokenClassification::new(p, config)?,
if let ConfigOption::Deberta(config) = model_config {
Ok(Self::Deberta(
DebertaForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -374,9 +389,9 @@ impl TokenClassificationOption {
}
}
ModelType::DebertaV2 => {
if let ConfigOption::DebertaV2(config) = config {
Ok(TokenClassificationOption::DebertaV2(
DebertaV2ForTokenClassification::new(p, config)?,
if let ConfigOption::DebertaV2(config) = model_config {
Ok(Self::DebertaV2(
DebertaV2ForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -385,9 +400,9 @@ impl TokenClassificationOption {
}
}
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
Ok(TokenClassificationOption::DistilBert(
DistilBertForTokenClassification::new(p, config)?,
if let ConfigOption::DistilBert(config) = model_config {
Ok(Self::DistilBert(
DistilBertForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -396,9 +411,9 @@ impl TokenClassificationOption {
}
}
ModelType::MobileBert => {
if let ConfigOption::MobileBert(config) = config {
Ok(TokenClassificationOption::MobileBert(
MobileBertForTokenClassification::new(p, config)?,
if let ConfigOption::MobileBert(config) = model_config {
Ok(Self::MobileBert(
MobileBertForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -407,9 +422,9 @@ impl TokenClassificationOption {
}
}
ModelType::Roberta => {
if let ConfigOption::Roberta(config) = config {
Ok(TokenClassificationOption::Roberta(
RobertaForTokenClassification::new(p, config)?,
if let ConfigOption::Roberta(config) = model_config {
Ok(Self::Roberta(
RobertaForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -418,9 +433,9 @@ impl TokenClassificationOption {
}
}
ModelType::XLMRoberta => {
if let ConfigOption::Roberta(config) = config {
Ok(TokenClassificationOption::XLMRoberta(
RobertaForTokenClassification::new(p, config)?,
if let ConfigOption::Roberta(config) = model_config {
Ok(Self::XLMRoberta(
RobertaForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -429,9 +444,9 @@ impl TokenClassificationOption {
}
}
ModelType::Electra => {
if let ConfigOption::Electra(config) = config {
Ok(TokenClassificationOption::Electra(
ElectraForTokenClassification::new(p, config)?,
if let ConfigOption::Electra(config) = model_config {
Ok(Self::Electra(
ElectraForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -440,9 +455,9 @@ impl TokenClassificationOption {
}
}
ModelType::Albert => {
if let ConfigOption::Albert(config) = config {
Ok(TokenClassificationOption::Albert(
AlbertForTokenClassification::new(p, config)?,
if let ConfigOption::Albert(config) = model_config {
Ok(Self::Albert(
AlbertForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -451,9 +466,9 @@ impl TokenClassificationOption {
}
}
ModelType::XLNet => {
if let ConfigOption::XLNet(config) = config {
Ok(TokenClassificationOption::XLNet(
XLNetForTokenClassification::new(p, config)?,
if let ConfigOption::XLNet(config) = model_config {
Ok(Self::XLNet(
XLNetForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -462,9 +477,9 @@ impl TokenClassificationOption {
}
}
ModelType::Longformer => {
if let ConfigOption::Longformer(config) = config {
Ok(TokenClassificationOption::Longformer(
LongformerForTokenClassification::new(p, config)?,
if let ConfigOption::Longformer(config) = model_config {
Ok(Self::Longformer(
LongformerForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -473,9 +488,9 @@ impl TokenClassificationOption {
}
}
ModelType::FNet => {
if let ConfigOption::FNet(config) = config {
Ok(TokenClassificationOption::FNet(
FNetForTokenClassification::new(p, config)?,
if let ConfigOption::FNet(config) = model_config {
Ok(Self::FNet(
FNetForTokenClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -483,10 +498,36 @@ impl TokenClassificationOption {
))
}
}
#[cfg(feature = "onnx")]
ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
"A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Token classification not implemented for {model_type:?}!"
))),
}
}?;
var_store.load(weights_path)?;
Ok(model)
}
#[cfg(feature = "onnx")]
pub fn new_onnx(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
let environment = onnx_config.get_environment()?;
let encoder_file = config
.model_resource
.get_onnx_local_paths()?
.encoder_path
.ok_or(RustBertError::InvalidConfigurationError(
"An encoder file must be provided for token classification ONNX models."
.to_string(),
))?;
Ok(Self::ONNX(ONNXEncoder::new(
encoder_file,
&environment,
&onnx_config,
)?))
}
/// Returns the `ModelType` for this TokenClassificationOption
@ -504,6 +545,8 @@ impl TokenClassificationOption {
Self::XLNet(_) => ModelType::XLNet,
Self::Longformer(_) => ModelType::Longformer,
Self::FNet(_) => ModelType::FNet,
#[cfg(feature = "onnx")]
Self::ONNX(_) => ModelType::ONNX,
}
}
@ -637,6 +680,12 @@ impl TokenClassificationOption {
.expect("Error in fnet forward_t")
.logits
}
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => model
.forward(input_ids, mask, token_type_ids, position_ids, input_embeds)
.expect("Error in ONNX forward pass.")
.logits
.unwrap(),
}
}
}
@ -646,7 +695,7 @@ pub struct TokenClassificationModel {
tokenizer: TokenizerOption,
token_sequence_classifier: TokenClassificationOption,
label_mapping: HashMap<i64, String>,
var_store: VarStore,
device: Device,
label_aggregation_function: LabelAggregationOption,
max_length: usize,
batch_size: usize,
@ -720,25 +769,23 @@ impl TokenClassificationModel {
tokenizer: TokenizerOption,
) -> Result<TokenClassificationModel, RustBertError> {
let config_path = config.config_resource.get_local_path()?;
let device = config.device;
let token_sequence_classifier = TokenClassificationOption::new(&config)?;
let label_aggregation_function = config.label_aggregation_function;
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let max_length = model_config
.get_max_len()
.map(|v| v as usize)
.unwrap_or(usize::MAX);
let token_sequence_classifier =
TokenClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
let label_mapping = model_config.get_label_mapping().clone();
let batch_size = config.batch_size;
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
let device = get_device(config.model_resource, config.device);
Ok(TokenClassificationModel {
tokenizer,
token_sequence_classifier,
label_mapping,
var_store,
device,
label_aggregation_function,
max_length,
batch_size,
@ -815,6 +862,11 @@ impl TokenClassificationModel {
input_ids: encoded_span.token_ids,
offsets: encoded_span.token_offsets,
mask: encoded_span.mask,
token_type_ids: encoded_span
.segment_ids
.into_iter()
.map(|segment_id| segment_id as i64)
.collect(),
reference_feature,
example_index,
};
@ -923,11 +975,12 @@ impl TokenClassificationModel {
no_grad(|| {
let batch_features = &mut features[start..end];
let (input_ids, attention_masks) = self.pad_features(batch_features);
let (input_ids, attention_masks, token_type_ids) =
self.pad_features(batch_features);
let output = self.token_sequence_classifier.forward_t(
Some(&input_ids),
Some(&attention_masks),
None,
Some(&token_type_ids),
None,
None,
false,
@ -985,7 +1038,7 @@ impl TokenClassificationModel {
tokens
}
fn pad_features(&self, features: &mut [InputFeature]) -> (Tensor, Tensor) {
fn pad_features(&self, features: &mut [InputFeature]) -> (Tensor, Tensor, Tensor) {
let max_len = features
.iter()
.map(|feature| feature.input_ids.len())
@ -997,8 +1050,8 @@ impl TokenClassificationModel {
.map(|feature| &feature.input_ids)
.map(|input| {
let mut attention_mask = Vec::with_capacity(max_len);
attention_mask.resize(input.len(), 1);
attention_mask.resize(max_len, 0);
attention_mask.resize(input.len(), 1i64);
attention_mask.resize(max_len, 0i64);
attention_mask
})
.map(|input| Tensor::from_slice(&(input)))
@ -1011,6 +1064,9 @@ impl TokenClassificationModel {
for feature in features.iter_mut() {
feature.input_ids.resize(max_len, padding_index);
feature.offsets.resize(max_len, None);
feature
.token_type_ids
.resize(max_len, *feature.token_type_ids.last().unwrap_or(&0));
feature.reference_feature.resize(max_len, false);
}
@ -1019,9 +1075,15 @@ impl TokenClassificationModel {
.map(|input| Tensor::from_slice(input.input_ids.as_slice()))
.collect::<Vec<_>>();
let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.var_store.device());
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
(input_ids, attention_masks)
let padded_token_type_ids = features
.iter()
.map(|input| Tensor::from_slice(input.token_type_ids.as_slice()))
.collect::<Vec<_>>();
let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.device);
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.device);
let token_type_ids = Tensor::stack(&padded_token_type_ids, 0).to(self.device);
(input_ids, attention_masks, token_type_ids)
}
fn decode_token(

View File

@ -25,6 +25,7 @@
//! use tch::Device;
//!
//! fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::common::ModelResource;
//! let model_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
//! let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
//! let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
@ -35,7 +36,7 @@
//!
//! let translation_config = TranslationConfig::new(
//! ModelType::M2M100,
//! model_resource,
//! ModelResource::Torch(Box::new(model_resource)),
//! config_resource,
//! vocab_resource,
//! Some(merges_resource),

View File

@ -5,6 +5,7 @@ use tch::Device;
#[cfg(feature = "remote")]
use crate::{
pipelines::common::ModelResource,
pipelines::translation::{TranslationConfig, TranslationModel},
resources::ResourceProvider,
RustBertError,
@ -379,7 +380,7 @@ impl TranslationModelBuilder {
let translation_config = TranslationConfig::new(
translation_resources.model_type,
translation_resources.model_resource,
ModelResource::Torch(Box::new(translation_resources.model_resource)),
translation_resources.config_resource,
translation_resources.vocab_resource,
Some(translation_resources.merges_resource),

View File

@ -1,3 +1,5 @@
// Copyright 2018-2020 The HuggingFace Inc. team.
// Copyright 2020 Marian Team Authors
// Copyright 2019-2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -16,9 +18,10 @@ use crate::m2m_100::M2M100Generator;
use crate::marian::MarianGenerator;
use crate::mbart::MBartGenerator;
use crate::nllb::NLLBGenerator;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXConditionalGenerator;
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
use serde::{Deserialize, Serialize};
@ -934,7 +937,7 @@ pub struct TranslationConfig {
/// Model type used for translation
pub model_type: ModelType,
/// Model weights resource
pub model_resource: Box<dyn ResourceProvider + Send>,
pub model_resource: ModelResource,
/// Config resource
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource
@ -993,12 +996,14 @@ impl TranslationConfig {
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
/// MarianTargetLanguages, MarianVocabResources,
/// };
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::common::{ModelResource, ModelType};
/// use rust_bert::pipelines::translation::TranslationConfig;
/// use rust_bert::resources::RemoteResource;
/// use tch::Device;
///
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
/// let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
/// MarianModelResources::ROMANCE2ENGLISH,
/// )));
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
/// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
@ -1019,9 +1024,9 @@ impl TranslationConfig {
/// # Ok(())
/// # }
/// ```
pub fn new<RM, RC, RV, S, T>(
pub fn new<RC, RV, S, T>(
model_type: ModelType,
model_resource: RM,
model_resource: ModelResource,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
@ -1030,7 +1035,6 @@ impl TranslationConfig {
device: impl Into<Option<Device>>,
) -> TranslationConfig
where
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
S: AsRef<[Language]>,
@ -1040,7 +1044,7 @@ impl TranslationConfig {
TranslationConfig {
model_type,
model_resource: Box::new(model_resource),
model_resource,
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
@ -1068,6 +1072,7 @@ impl TranslationConfig {
impl From<TranslationConfig> for GenerateConfig {
fn from(config: TranslationConfig) -> GenerateConfig {
GenerateConfig {
model_type: config.model_type,
model_resource: config.model_resource,
config_resource: config.config_resource,
merges_resource: config.merges_resource,
@ -1102,49 +1107,31 @@ pub enum TranslationOption {
MBart(MBartGenerator),
/// Translator based on M2M100 model
M2M100(M2M100Generator),
/// Translator based on NLLB model
NLLB(NLLBGenerator),
/// Translator based on ONNX model
#[cfg(feature = "onnx")]
ONNX(ONNXConditionalGenerator),
}
impl TranslationOption {
pub fn new(config: TranslationConfig) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::Marian => Ok(TranslationOption::Marian(MarianGenerator::new(
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(TranslationOption::ONNX(
ONNXConditionalGenerator::new(config.into(), None, None)?,
)),
(ModelType::Marian, _) => Ok(TranslationOption::Marian(MarianGenerator::new(
config.into(),
)?)),
ModelType::T5 => Ok(TranslationOption::T5(T5Generator::new(config.into())?)),
ModelType::MBart => Ok(TranslationOption::MBart(MBartGenerator::new(
(ModelType::T5, _) => Ok(TranslationOption::T5(T5Generator::new(config.into())?)),
(ModelType::MBart, _) => Ok(TranslationOption::MBart(MBartGenerator::new(
config.into(),
)?)),
ModelType::M2M100 => Ok(TranslationOption::M2M100(M2M100Generator::new(
(ModelType::M2M100, _) => Ok(TranslationOption::M2M100(M2M100Generator::new(
config.into(),
)?)),
ModelType::NLLB => {
let config: GenerateConfig = config.into();
let tokenizer = TokenizerOption::from_file(
ModelType::NLLB,
config.vocab_resource.get_local_path()?.to_str().unwrap(),
Some(
config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"M2M100 expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?
.to_str()
.unwrap(),
),
false,
None,
None,
)?;
Ok(TranslationOption::NLLB(NLLBGenerator::new_with_tokenizer(
config, tokenizer,
)?))
}
(ModelType::NLLB, _) => Ok(TranslationOption::NLLB(NLLBGenerator::new(config.into())?)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Translation not implemented for {:?}!",
config.model_type
@ -1156,21 +1143,25 @@ impl TranslationOption {
config: TranslationConfig,
tokenizer: TokenizerOption,
) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::Marian => Ok(TranslationOption::Marian(
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(TranslationOption::ONNX(
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
)),
(ModelType::Marian, _) => Ok(TranslationOption::Marian(
MarianGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::T5 => Ok(TranslationOption::T5(T5Generator::new_with_tokenizer(
(ModelType::T5, _) => Ok(TranslationOption::T5(T5Generator::new_with_tokenizer(
config.into(),
tokenizer,
)?)),
ModelType::MBart => Ok(TranslationOption::MBart(
(ModelType::MBart, _) => Ok(TranslationOption::MBart(
MBartGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::M2M100 => Ok(TranslationOption::M2M100(
(ModelType::M2M100, _) => Ok(TranslationOption::M2M100(
M2M100Generator::new_with_tokenizer(config.into(), tokenizer)?,
)),
ModelType::NLLB => Ok(TranslationOption::NLLB(NLLBGenerator::new_with_tokenizer(
(ModelType::NLLB, _) => Ok(TranslationOption::NLLB(NLLBGenerator::new_with_tokenizer(
config.into(),
tokenizer,
)?)),
@ -1189,190 +1180,36 @@ impl TranslationOption {
Self::MBart(_) => ModelType::MBart,
Self::M2M100(_) => ModelType::M2M100,
Self::NLLB(_) => ModelType::NLLB,
#[cfg(feature = "onnx")]
Self::ONNX(_) => ModelType::ONNX,
}
}
/// Interface method to access tokenizer
/// Returns the `Tokenizer` for this TranslationOption
pub fn get_tokenizer(&self) -> &TokenizerOption {
match self {
Self::Marian(model_ref) => model_ref._get_tokenizer(),
Self::T5(model_ref) => model_ref._get_tokenizer(),
Self::MBart(model_ref) => model_ref._get_tokenizer(),
Self::M2M100(model_ref) => model_ref._get_tokenizer(),
Self::NLLB(model_ref) => model_ref._get_tokenizer(),
Self::Marian(ref generator) => generator.get_tokenizer(),
Self::T5(ref generator) => generator.get_tokenizer(),
Self::MBart(ref generator) => generator.get_tokenizer(),
Self::M2M100(ref generator) => generator.get_tokenizer(),
Self::NLLB(ref generator) => generator.get_tokenizer(),
#[cfg(feature = "onnx")]
Self::ONNX(ref generator) => generator.get_tokenizer(),
}
}
/// Interface method to access tokenizer
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
match self {
Self::Marian(model_ref) => model_ref._get_tokenizer_mut(),
Self::T5(model_ref) => model_ref._get_tokenizer_mut(),
Self::MBart(model_ref) => model_ref._get_tokenizer_mut(),
Self::M2M100(model_ref) => model_ref._get_tokenizer_mut(),
Self::NLLB(model_ref) => model_ref._get_tokenizer_mut(),
Self::Marian(model_ref) => model_ref.get_tokenizer_mut(),
Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
Self::MBart(model_ref) => model_ref.get_tokenizer_mut(),
Self::M2M100(model_ref) => model_ref.get_tokenizer_mut(),
Self::NLLB(model_ref) => model_ref.get_tokenizer_mut(),
#[cfg(feature = "onnx")]
Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
}
}
fn validate_and_get_prefix_and_forced_bos_id(
&self,
source_language: Option<&Language>,
target_language: Option<&Language>,
supported_source_languages: &HashSet<Language>,
supported_target_languages: &HashSet<Language>,
) -> Result<(Option<String>, Option<i64>), RustBertError> {
if let Some(source_language) = source_language {
if !supported_source_languages.contains(source_language) {
return Err(RustBertError::ValueError(format!(
"{source_language} not in list of supported languages: {supported_source_languages:?}",
)));
}
}
if let Some(target_language) = target_language {
if !supported_target_languages.contains(target_language) {
return Err(RustBertError::ValueError(format!(
"{target_language} not in list of supported languages: {supported_target_languages:?}"
)));
}
}
Ok(match *self {
Self::Marian(_) => {
if supported_target_languages.len() > 1 {
(
Some(format!(
">>{}<< ",
target_language.and_then(|l| l.get_iso_639_1_code()).ok_or_else(|| RustBertError::ValueError(format!(
"Missing target language for Marian \
(multiple languages supported by model: {supported_target_languages:?}, \
need to specify target language)",
)))?
)),
None,
)
} else {
(None, None)
}
}
Self::T5(_) => (
Some(format!(
"translate {} to {}:",
source_language.ok_or_else(|| RustBertError::ValueError(
"Missing source language for T5".to_string(),
))?,
target_language.ok_or_else(|| RustBertError::ValueError(
"Missing target language for T5".to_string(),
))?,
)),
None,
),
Self::MBart(ref model) => {
(
Some(format!(
">>{}<< ",
source_language.and_then(|l| l.get_iso_639_1_code()).ok_or_else(|| RustBertError::ValueError(format!(
"Missing source language for MBart\
(multiple languages supported by model: {supported_source_languages:?}, \
need to specify target language)"
)))?
)),
if let Some(target_language) = target_language {
Some(
model._get_tokenizer().convert_tokens_to_ids(&[format!(
">>{}<<",
target_language.get_iso_639_1_code().ok_or_else(|| {
RustBertError::ValueError(format!(
"This language has no ISO639-I code. Languages supported by model: {supported_source_languages:?}."
))
})?
)])[0],
)
} else {
return Err(RustBertError::ValueError(format!(
"Missing target language for MBart\
(multiple languages supported by model: {supported_target_languages:?}, \
need to specify target language)"
)));
},
)
}
Self::M2M100(ref model) => (
Some(match source_language {
Some(value) => {
let language_code = value.get_iso_639_1_code().ok_or_else(|| {
RustBertError::ValueError(format!(
"This language has no ISO639-I language code representation. \
languages supported by the model: {supported_target_languages:?}"
))
})?;
match language_code.len() {
2 => format!(">>{language_code}.<< "),
3 => format!(">>{language_code}<< "),
_ => {
return Err(RustBertError::ValueError(
"Invalid ISO 639-I code".to_string(),
));
}
}
}
None => {
return Err(RustBertError::ValueError(format!(
"Missing source language for M2M100 \
(multiple languages supported by model: {supported_source_languages:?}, \
need to specify target language)"
)));
}
}),
if let Some(target_language) = target_language {
let language_code = target_language.get_iso_639_1_code().ok_or_else(|| {
RustBertError::ValueError(format!(
"This language has no ISO639-I language code representation. \
languages supported by the model: {supported_target_languages:?}"
))
})?;
Some(
model._get_tokenizer().convert_tokens_to_ids(&[
match language_code.len() {
2 => format!(">>{language_code}.<<"),
3 => format!(">>{language_code}<<"),
_ => {
return Err(RustBertError::ValueError(
"Invalid ISO 639-3 code".to_string(),
));
}
},
])[0],
)
} else {
return Err(RustBertError::ValueError(format!(
"Missing target language for M2M100 \
(multiple languages supported by model: {supported_target_languages:?}, \
need to specify target language)",
)));
},
),
Self::NLLB(ref model) => {
let source_language = source_language
.and_then(Language::get_nllb_code)
.map(str::to_string)
.ok_or_else(|| RustBertError::ValueError(
format!("Missing source language for NLLB. Need to specify one from: {supported_source_languages:?}")
))?;
let target_language = target_language
.and_then(Language::get_nllb_code)
.map(str::to_string)
.map(|code| model._get_tokenizer().convert_tokens_to_ids(&[code])[0])
.ok_or_else(|| RustBertError::ValueError(
format!("Missing target language for NLLB. Need to specify one from: {supported_target_languages:?}")
))?;
(Some(source_language), Some(target_language))
}
})
}
/// Interface method to generate() of the particular models.
pub fn generate<S>(
&self,
@ -1415,6 +1252,19 @@ impl TranslationOption {
.map(|output| output.text)
.collect()
}
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => {
let generate_options =
forced_bos_token_id.map(|forced_bos_token_id| GenerateOptions {
forced_bos_token_id: Some(forced_bos_token_id),
..Default::default()
});
model
.generate(prompt_texts, generate_options)
.into_iter()
.map(|output| output.text)
.collect()
}
}
}
}
@ -1441,12 +1291,14 @@ impl TranslationModel {
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
/// MarianTargetLanguages, MarianVocabResources,
/// };
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::common::{ModelResource, ModelType};
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
/// use rust_bert::resources::RemoteResource;
/// use tch::Device;
///
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
/// let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
/// MarianModelResources::ROMANCE2ENGLISH,
/// )));
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
/// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
@ -1496,12 +1348,14 @@ impl TranslationModel {
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
/// MarianTargetLanguages, MarianVocabResources,
/// };
/// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
/// use rust_bert::pipelines::common::{ModelResource, ModelType, TokenizerOption};
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
/// use rust_bert::resources::{RemoteResource, ResourceProvider};
/// use tch::Device;
///
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
/// let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
/// MarianModelResources::ROMANCE2ENGLISH,
/// )));
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
/// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
@ -1575,12 +1429,14 @@ impl TranslationModel {
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
/// MarianTargetLanguages, MarianVocabResources,
/// };
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::common::{ModelResource, ModelType};
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
/// use rust_bert::resources::RemoteResource;
/// use tch::Device;
///
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ENGLISH2ROMANCE);
/// let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
/// MarianModelResources::ENGLISH2ROMANCE,
/// )));
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ENGLISH2ROMANCE);
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ENGLISH2ROMANCE);
/// let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ENGLISH2ROMANCE);
@ -1616,12 +1472,13 @@ impl TranslationModel {
where
S: AsRef<str> + Sync,
{
let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id(
source_language.into().as_ref(),
target_language.into().as_ref(),
&self.supported_source_languages,
&self.supported_target_languages,
)?;
let (prefix, forced_bos_token_id) =
self.model.get_tokenizer().get_prefix_and_forced_bos_id(
source_language.into().as_ref(),
target_language.into().as_ref(),
&self.supported_source_languages,
&self.supported_target_languages,
)?;
Ok(match prefix {
Some(value) => {
@ -1648,7 +1505,9 @@ mod test {
#[test]
#[ignore] // no need to run, compilation is enough to verify it is Send
fn test() {
let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
MarianModelResources::ROMANCE2ENGLISH,
)));
let config_resource =
RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);

View File

@ -105,7 +105,7 @@ use crate::deberta::DebertaForSequenceClassification;
use crate::distilbert::DistilBertModelClassifier;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::common::{ConfigOption, ModelResource, ModelType, TokenizerOption};
use crate::pipelines::sequence_classification::Label;
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForSequenceClassification;
@ -113,17 +113,18 @@ use crate::xlnet::XLNetForSequenceClassification;
use crate::RustBertError;
use rust_tokenizers::tokenizer::TruncationStrategy;
use rust_tokenizers::TokenizedInput;
use std::borrow::Borrow;
use std::ops::Deref;
use tch::kind::Kind::{Bool, Float};
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Tensor};
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
#[cfg(feature = "remote")]
use crate::{
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
resources::RemoteResource,
};
use std::ops::Deref;
use tch::kind::Kind::{Bool, Float};
use tch::nn::VarStore;
use tch::{no_grad, Device, Kind, Tensor};
/// # Configuration for ZeroShotClassificationModel
/// Contains information regarding the model to load and device to place the model on.
@ -131,7 +132,7 @@ pub struct ZeroShotClassificationConfig {
/// Model type
pub model_type: ModelType,
/// Model weights resource (default: pretrained BERT model on CoNLL)
pub model_resource: Box<dyn ResourceProvider + Send>,
pub model_resource: ModelResource,
/// Config resource (default: pretrained BERT model on CoNLL)
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Vocab resource (default: pretrained BERT model on CoNLL)
@ -159,9 +160,9 @@ impl ZeroShotClassificationConfig {
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<RM, RC, RV>(
pub fn new<RC, RV>(
model_type: ModelType,
model_resource: RM,
model_resource: ModelResource,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
@ -170,13 +171,12 @@ impl ZeroShotClassificationConfig {
add_prefix_space: impl Into<Option<bool>>,
) -> ZeroShotClassificationConfig
where
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
ZeroShotClassificationConfig {
model_type,
model_resource: Box::new(model_resource),
model_resource,
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
@ -190,13 +190,13 @@ impl ZeroShotClassificationConfig {
#[cfg(feature = "remote")]
impl Default for ZeroShotClassificationConfig {
/// Provides a defaultSST-2 sentiment analysis model (English)
/// Provides a default zero-shot classification model (English)
fn default() -> ZeroShotClassificationConfig {
ZeroShotClassificationConfig {
model_type: ModelType::Bart,
model_resource: Box::new(RemoteResource::from_pretrained(
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BartModelResources::BART_MNLI,
)),
))),
config_resource: Box::new(RemoteResource::from_pretrained(
BartConfigResources::BART_MNLI,
)),
@ -239,30 +239,38 @@ pub enum ZeroShotClassificationOption {
XLNet(XLNetForSequenceClassification),
/// Longformer for Sequence Classification
Longformer(LongformerForSequenceClassification),
/// ONNX model for Sequence Classification
#[cfg(feature = "onnx")]
ONNX(ONNXEncoder),
}
impl ZeroShotClassificationOption {
/// Instantiate a new zero shot classification model of the supplied type.
/// Instantiate a new zer-shot classification model of the supplied type.
///
/// # Arguments
///
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
pub fn new<'p, P>(
model_type: ModelType,
p: P,
config: &ConfigOption,
) -> Result<Self, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
match model_type {
/// * `ZeroShotClassificationConfig` - Zero-shot classification pipeline configuration. The type of model created will be inferred from the
/// `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
pub fn new(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
match config.model_resource {
ModelResource::Torch(_) => Self::new_torch(config),
#[cfg(feature = "onnx")]
ModelResource::ONNX(_) => Self::new_onnx(config),
}
}
fn new_torch(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
let device = config.device;
let weights_path = config.model_resource.get_torch_local_path()?;
let mut var_store = VarStore::new(device);
let model_config =
&ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
let model_type = config.model_type;
let model = match model_type {
ModelType::Bart => {
if let ConfigOption::Bart(config) = config {
Ok(ZeroShotClassificationOption::Bart(
BartForSequenceClassification::new(p, config)?,
if let ConfigOption::Bart(config) = model_config {
Ok(Self::Bart(
BartForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -271,9 +279,9 @@ impl ZeroShotClassificationOption {
}
}
ModelType::Deberta => {
if let ConfigOption::Deberta(config) = config {
Ok(ZeroShotClassificationOption::Deberta(
DebertaForSequenceClassification::new(p, config)?,
if let ConfigOption::Deberta(config) = model_config {
Ok(Self::Deberta(
DebertaForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -282,9 +290,9 @@ impl ZeroShotClassificationOption {
}
}
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
Ok(ZeroShotClassificationOption::Bert(
BertForSequenceClassification::new(p, config)?,
if let ConfigOption::Bert(config) = model_config {
Ok(Self::Bert(
BertForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -293,9 +301,9 @@ impl ZeroShotClassificationOption {
}
}
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
Ok(ZeroShotClassificationOption::DistilBert(
DistilBertModelClassifier::new(p, config)?,
if let ConfigOption::DistilBert(config) = model_config {
Ok(Self::DistilBert(
DistilBertModelClassifier::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -304,9 +312,9 @@ impl ZeroShotClassificationOption {
}
}
ModelType::MobileBert => {
if let ConfigOption::MobileBert(config) = config {
Ok(ZeroShotClassificationOption::MobileBert(
MobileBertForSequenceClassification::new(p, config)?,
if let ConfigOption::MobileBert(config) = model_config {
Ok(Self::MobileBert(
MobileBertForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -315,9 +323,9 @@ impl ZeroShotClassificationOption {
}
}
ModelType::Roberta => {
if let ConfigOption::Bert(config) = config {
Ok(ZeroShotClassificationOption::Roberta(
RobertaForSequenceClassification::new(p, config)?,
if let ConfigOption::Bert(config) = model_config {
Ok(Self::Roberta(
RobertaForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -326,9 +334,9 @@ impl ZeroShotClassificationOption {
}
}
ModelType::XLMRoberta => {
if let ConfigOption::Bert(config) = config {
Ok(ZeroShotClassificationOption::XLMRoberta(
RobertaForSequenceClassification::new(p, config)?,
if let ConfigOption::Bert(config) = model_config {
Ok(Self::XLMRoberta(
RobertaForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -337,9 +345,9 @@ impl ZeroShotClassificationOption {
}
}
ModelType::Albert => {
if let ConfigOption::Albert(config) = config {
Ok(ZeroShotClassificationOption::Albert(
AlbertForSequenceClassification::new(p, config)?,
if let ConfigOption::Albert(config) = model_config {
Ok(Self::Albert(
AlbertForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -348,9 +356,9 @@ impl ZeroShotClassificationOption {
}
}
ModelType::XLNet => {
if let ConfigOption::XLNet(config) = config {
Ok(ZeroShotClassificationOption::XLNet(
XLNetForSequenceClassification::new(p, config)?,
if let ConfigOption::XLNet(config) = model_config {
Ok(Self::XLNet(
XLNetForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -359,9 +367,9 @@ impl ZeroShotClassificationOption {
}
}
ModelType::Longformer => {
if let ConfigOption::Longformer(config) = config {
Ok(ZeroShotClassificationOption::Longformer(
LongformerForSequenceClassification::new(p, config)?,
if let ConfigOption::Longformer(config) = model_config {
Ok(Self::Longformer(
LongformerForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
@ -369,10 +377,36 @@ impl ZeroShotClassificationOption {
))
}
}
#[cfg(feature = "onnx")]
ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
"A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Zero shot classification not implemented for {model_type:?}!",
))),
}
}?;
var_store.load(weights_path)?;
Ok(model)
}
#[cfg(feature = "onnx")]
pub fn new_onnx(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
let environment = onnx_config.get_environment()?;
let encoder_file = config
.model_resource
.get_onnx_local_paths()?
.encoder_path
.ok_or(RustBertError::InvalidConfigurationError(
"An encoder file must be provided for zero-shot classification ONNX models."
.to_string(),
))?;
Ok(Self::ONNX(ONNXEncoder::new(
encoder_file,
&environment,
&onnx_config,
)?))
}
/// Returns the `ModelType` for this SequenceClassificationOption
@ -388,6 +422,8 @@ impl ZeroShotClassificationOption {
Self::Albert(_) => ModelType::Albert,
Self::XLNet(_) => ModelType::XLNet,
Self::Longformer(_) => ModelType::Longformer,
#[cfg(feature = "onnx")]
Self::ONNX(_) => ModelType::ONNX,
}
}
@ -503,6 +539,18 @@ impl ZeroShotClassificationOption {
.expect("Error in Longformer forward pass.")
.logits
}
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => model
.forward(
input_ids,
mask.map(|tensor| tensor.to_kind(Kind::Int64)).as_ref(),
token_type_ids,
position_ids,
input_embeds,
)
.expect("Error in ONNX forward pass.")
.logits
.unwrap(),
}
}
}
@ -530,7 +578,7 @@ pub type ZeroShotTemplate = Box<dyn Fn(&str) -> String>;
pub struct ZeroShotClassificationModel {
tokenizer: TokenizerOption,
zero_shot_classifier: ZeroShotClassificationOption,
var_store: VarStore,
device: Device,
}
impl ZeroShotClassificationModel {
@ -600,18 +648,13 @@ impl ZeroShotClassificationModel {
config: ZeroShotClassificationConfig,
tokenizer: TokenizerOption,
) -> Result<ZeroShotClassificationModel, RustBertError> {
let config_path = config.config_resource.get_local_path()?;
let device = config.device;
let zero_shot_classifier = ZeroShotClassificationOption::new(&config)?;
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let zero_shot_classifier =
ZeroShotClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
Ok(ZeroShotClassificationModel {
tokenizer,
zero_shot_classifier,
var_store,
device,
})
}
@ -631,7 +674,7 @@ impl ZeroShotClassificationModel {
labels: T,
template: Option<ZeroShotTemplate>,
max_len: usize,
) -> Result<(Tensor, Tensor), RustBertError>
) -> Result<(Tensor, Tensor, Tensor), RustBertError>
where
S: AsRef<[&'a str]>,
T: AsRef<[&'a str]>,
@ -659,7 +702,7 @@ impl ZeroShotClassificationModel {
})
.collect::<Vec<(&str, &str)>>();
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_pair_list(
let mut tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_pair_list(
text_pair_list.as_ref(),
max_len,
&TruncationStrategy::LongestFirst,
@ -675,25 +718,35 @@ impl ZeroShotClassificationModel {
.tokenizer
.get_pad_id()
.expect("The Tokenizer used for sequence classification should contain a PAD id");
let tokenized_input_tensors = tokenized_input
.into_iter()
.map(|mut input| {
let input_ids = tokenized_input
.iter_mut()
.map(|input| {
input.token_ids.resize(max_len, pad_id);
Tensor::from_slice(&(input.token_ids))
})
.collect::<Vec<_>>();
let token_type_ids = tokenized_input
.iter_mut()
.map(|input| {
input
.segment_ids
.resize(max_len, *input.segment_ids.last().unwrap_or(&0));
Tensor::from_slice(&(input.segment_ids))
})
.collect::<Vec<_>>();
let tokenized_input_tensors =
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device());
let mask = tokenized_input_tensors
let input_ids = Tensor::stack(input_ids.as_slice(), 0).to(self.device);
let token_type_ids = Tensor::stack(token_type_ids.as_slice(), 0)
.to(self.device)
.to_kind(Kind::Int64);
let mask = input_ids
.ne(self
.tokenizer
.get_pad_id()
.expect("The Tokenizer used for zero shot classification should contain a PAD id"))
.to_kind(Bool);
Ok((tokenized_input_tensors, mask))
Ok((input_ids, mask, token_type_ids))
}
/// Zero shot classification with 1 (and exactly 1) true label.
@ -762,14 +815,14 @@ impl ZeroShotClassificationModel {
T: AsRef<[&'a str]>,
{
let num_inputs = inputs.as_ref().len();
let (input_tensor, mask) =
let (input_tensor, mask, token_type_ids) =
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
let output = no_grad(|| {
let output = self.zero_shot_classifier.forward_t(
Some(&input_tensor),
Some(&mask),
None,
Some(&token_type_ids),
None,
None,
false,
@ -904,14 +957,14 @@ impl ZeroShotClassificationModel {
T: AsRef<[&'a str]>,
{
let num_inputs = inputs.as_ref().len();
let (input_tensor, mask) =
let (input_tensor, mask, token_type_ids) =
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
let output = no_grad(|| {
let output = self.zero_shot_classifier.forward_t(
Some(&input_tensor),
Some(&mask),
None,
Some(&token_type_ids),
None,
None,
false,

View File

@ -24,6 +24,7 @@
//! use tch::Device;
//!
//! fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::common::ModelResource;
//! let config_resource = Box::new(RemoteResource::from_pretrained(
//! ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
//! ));
@ -36,7 +37,7 @@
//!
//! let summarization_config = SummarizationConfig {
//! model_type: ModelType::ProphetNet,
//! model_resource: weights_resource,
//! model_resource: ModelResource::Torch(weights_resource),
//! config_resource,
//! vocab_resource,
//! merges_resource: None,

View File

@ -13,9 +13,8 @@
use std::borrow::Borrow;
use std::collections::HashMap;
use rust_tokenizers::tokenizer::TruncationStrategy;
use serde::{Deserialize, Serialize};
use tch::{nn, Kind, Tensor};
use tch::{nn, Device, Kind, Tensor};
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
@ -83,7 +82,7 @@ pub struct ProphetNetConfig {
pub activation_dropout: f64,
pub attention_dropout: f64,
pub decoder_ffn_dim: i64,
pub decoder_start_token_id: i64,
pub decoder_start_token_id: Option<i64>,
pub disable_ngram_loss: bool,
pub dropout: f64,
pub encoder_ffn_dim: i64,
@ -94,6 +93,8 @@ pub struct ProphetNetConfig {
pub max_position_embeddings: i64,
pub bos_token_id: i64,
pub eos_token_id: i64,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
pub ngram: i64,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
@ -120,7 +121,7 @@ impl Default for ProphetNetConfig {
activation_dropout: 0.1,
attention_dropout: 0.1,
decoder_ffn_dim: 4096,
decoder_start_token_id: 0,
decoder_start_token_id: Some(0),
disable_ngram_loss: false,
dropout: 0.1,
encoder_ffn_dim: 4096,
@ -131,6 +132,8 @@ impl Default for ProphetNetConfig {
max_position_embeddings: 512,
bos_token_id: 1,
eos_token_id: 2,
forced_bos_token_id: None,
forced_eos_token_id: None,
ngram: 2,
id2label: None,
label2id: None,
@ -381,7 +384,11 @@ impl ProphetNetForConditionalGeneration {
linear_config,
);
let decoder_start_token_id = config.decoder_start_token_id;
let decoder_start_token_id = config.decoder_start_token_id.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"`decoder_start_token_id` must be provided for ProphetNet models".to_string(),
)
})?;
let pad_token_id = config.pad_token_id;
let ngram = config.ngram;
@ -919,7 +926,7 @@ impl ProphetNetConditionalGenerator {
let pad_token_id = Some(config.pad_token_id);
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = Some(config.decoder_start_token_id);
let decoder_start_id = config.decoder_start_token_id;
let max_position_embeddings = config.max_position_embeddings;
Ok(ProphetNetConditionalGenerator {
@ -945,11 +952,11 @@ impl PrivateLanguageGenerator for ProphetNetConditionalGenerator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -972,8 +979,8 @@ impl PrivateLanguageGenerator for ProphetNetConditionalGenerator {
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(
@ -1060,48 +1067,6 @@ impl PrivateLanguageGenerator for ProphetNetConditionalGenerator {
}
}
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(
&self,
past: &mut Cache,

View File

@ -79,6 +79,8 @@ pub struct ReformerConfig {
pub chunk_size_feed_forward: Option<i64>,
pub eos_token_id: i64,
pub pad_token_id: i64,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
pub feed_forward_size: i64,
pub hash_seed: Option<i64>,
pub hidden_act: Activation,
@ -106,6 +108,7 @@ pub struct ReformerConfig {
pub label2id: Option<HashMap<String, i64>>,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub decoder_start_token_id: Option<i64>,
}
impl Config for ReformerConfig {}
@ -130,6 +133,8 @@ impl Default for ReformerConfig {
chunk_size_feed_forward: None,
eos_token_id: 2,
pad_token_id: 0,
forced_bos_token_id: None,
forced_eos_token_id: None,
feed_forward_size: 512,
hash_seed: None,
hidden_act: Activation::gelu,
@ -157,6 +162,7 @@ impl Default for ReformerConfig {
label2id: None,
output_attentions: None,
output_hidden_states: None,
decoder_start_token_id: None,
}
}
}
@ -1057,7 +1063,7 @@ impl ReformerGenerator {
let pad_token_id = Some(config.pad_token_id);
let vocab_size = config.vocab_size;
let is_encoder_decoder = false;
let decoder_start_id = None;
let decoder_start_id = config.decoder_start_token_id;
let max_position_embeddings = config.max_position_embeddings;
Ok(ReformerGenerator {
@ -1083,11 +1089,11 @@ impl PrivateLanguageGenerator for ReformerGenerator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -1110,8 +1116,8 @@ impl PrivateLanguageGenerator for ReformerGenerator {
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(

View File

@ -42,7 +42,7 @@ impl RobertaModelResources {
/// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
"distilroberta-base/model",
"https://cdn.huggingface.co/distilroberta-base-rust_model.ot",
"https://huggingface.co/distilroberta-base/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
pub const ROBERTA_QA: (&'static str, &'static str) = (

View File

@ -12,10 +12,9 @@
use std::borrow::Borrow;
use rust_tokenizers::tokenizer::TruncationStrategy;
use serde::{Deserialize, Serialize};
use tch::nn::{embedding, LinearConfig};
use tch::{nn, Tensor};
use tch::{nn, Device, Tensor};
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
@ -132,6 +131,8 @@ pub struct T5Config {
pub decoder_start_token_id: Option<i64>,
pub bos_token_id: Option<i64>,
pub eos_token_id: Option<i64>,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
pub initializer_factor: f64,
pub is_encoder_decoder: Option<bool>,
pub layer_norm_epsilon: f64,
@ -210,6 +211,8 @@ impl Default for T5Config {
decoder_start_token_id: None,
bos_token_id: None,
eos_token_id: Some(1),
forced_bos_token_id: None,
forced_eos_token_id: None,
initializer_factor: 1.0,
is_encoder_decoder: None,
layer_norm_epsilon: 1e-6,
@ -770,7 +773,7 @@ impl T5Generator {
let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = Some(0);
let decoder_start_id = config.decoder_start_token_id;
// T5 do not have an embedding matrix for position IDs and relies on relative positions instead
let max_position_embeddings = i64::MAX;
@ -797,11 +800,11 @@ impl PrivateLanguageGenerator for T5Generator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -824,8 +827,8 @@ impl PrivateLanguageGenerator for T5Generator {
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(
&self,
@ -906,48 +909,6 @@ impl PrivateLanguageGenerator for T5Generator {
}
}
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(
&self,
past: &mut Cache,

View File

@ -19,7 +19,7 @@
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::common::{ModelResource, ModelType};
//! use rust_bert::pipelines::generation_utils::LanguageGenerator;
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
//! use rust_bert::resources::RemoteResource;
@ -35,7 +35,7 @@
//! ));
//! let generate_config = TextGenerationConfig {
//! model_type: ModelType::XLNet,
//! model_resource,
//! model_resource: ModelResource::Torch(model_resource),
//! config_resource,
//! vocab_resource,
//! merges_resource: None,

View File

@ -1594,11 +1594,11 @@ impl PrivateLanguageGenerator for XLNetGenerator {
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
@ -1622,8 +1622,8 @@ impl PrivateLanguageGenerator for XLNetGenerator {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(

View File

@ -2,6 +2,7 @@ use rust_bert::bart::{
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
use rust_bert::pipelines::common::ModelResource;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::pipelines::zero_shot_classification::{
ZeroShotClassificationConfig, ZeroShotClassificationModel,
@ -90,7 +91,7 @@ fn bart_summarization_greedy() -> anyhow::Result<()> {
BartModelResources::DISTILBART_CNN_6_6,
));
let summarization_config = SummarizationConfig {
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -151,7 +152,7 @@ fn bart_summarization_beam_search() -> anyhow::Result<()> {
BartModelResources::DISTILBART_CNN_6_6,
));
let summarization_config = SummarizationConfig {
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),

View File

@ -6,7 +6,7 @@ use rust_bert::bert::{
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
BertModelResources, BertVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::question_answering::{
@ -106,7 +106,9 @@ fn bert_masked_lm_pipeline() -> anyhow::Result<()> {
// Set-up model
let config = MaskedLanguageConfig::new(
ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT,
))),
RemoteResource::from_pretrained(BertConfigResources::BERT),
RemoteResource::from_pretrained(BertVocabResources::BERT),
None,
@ -452,7 +454,9 @@ fn bert_question_answering() -> anyhow::Result<()> {
// Set-up question answering model
let config = QuestionAnsweringConfig {
model_type: ModelType::Bert,
model_resource: Box::new(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
BertModelResources::BERT_QA,
))),
config_resource: Box::new(RemoteResource::from_pretrained(
BertConfigResources::BERT_QA,
)),

View File

@ -5,7 +5,7 @@ use rust_bert::fnet::{
FNetConfig, FNetConfigResources, FNetForMaskedLM, FNetForMultipleChoice,
FNetForQuestionAnswering, FNetForTokenClassification, FNetModelResources, FNetVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel, SentimentPolarity};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
@ -76,9 +76,8 @@ fn fnet_masked_lm() -> anyhow::Result<()> {
assert_eq!("▁one", word_1);
assert_eq!("▁the", word_2);
let value = (f64::try_from(model_output.prediction_scores.get(0).get(4).max()).unwrap()
- 13.1721)
.abs();
let value =
(f64::try_from(model_output.prediction_scores.get(0).get(4).max())? - 13.1721).abs();
dbg!(value);
assert!(value < 1e-3);
Ok(())
@ -93,9 +92,9 @@ fn fnet_for_sequence_classification() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained(
FNetVocabResources::BASE_SST2,
));
let model_resource = Box::new(RemoteResource::from_pretrained(
let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
FNetModelResources::BASE_SST2,
));
)));
let sentiment_config = SentimentConfig {
model_type: ModelType::FNet,

View File

@ -2,7 +2,7 @@ use rust_bert::gpt2::{
GPT2Generator, GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources,
Gpt2ModelResources, Gpt2VocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::conversation::{
ConversationConfig, ConversationManager, ConversationModel,
};
@ -107,7 +107,7 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -139,7 +139,7 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -183,7 +183,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -240,7 +240,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -296,7 +296,7 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -369,7 +369,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
let generate_config = GenerateConfig {
max_length: Some(56),
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -419,7 +419,7 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
let generate_config = GenerateConfig {
max_length: Some(36),
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -485,7 +485,7 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
let generate_config = GenerateConfig {
max_length: Some(36),
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -566,7 +566,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
let generate_config = GenerateConfig {
max_length: Some(32),
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -616,7 +616,7 @@ fn gpt2_greedy_token_scores() -> anyhow::Result<()> {
let generate_config = GenerateConfig {
max_length: Some(16),
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -672,7 +672,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> {
let generate_config = GenerateConfig {
max_length: Some(16),
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),

View File

@ -7,6 +7,7 @@ use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer};
use rust_tokenizers::vocab::Vocab;
use std::convert::TryFrom;
use tch::{nn, Device, Tensor};
/// Equivalent Python code:
@ -105,7 +106,7 @@ fn gpt_j_correctness() -> anyhow::Result<()> {
Tensor::from_slice(
&input
.iter()
.map(|&e| i64::from(e != pad_token))
.map(|&e| i64::try_from(e != pad_token).unwrap())
.collect::<Vec<_>>(),
)
.to(device)

View File

@ -2,7 +2,7 @@ use rust_bert::gpt_neo::{
GptNeoConfig, GptNeoConfigResources, GptNeoForCausalLM, GptNeoMergesResources,
GptNeoModelResources, GptNeoVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
@ -125,7 +125,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> {
// Set-up model
let generation_config = TextGenerationConfig {
model_type: ModelType::GPTNeo,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),

View File

@ -7,7 +7,7 @@ use rust_bert::longformer::{
LongformerForTokenClassification, LongformerMergesResources, LongformerModelResources,
LongformerVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
@ -117,12 +117,12 @@ fn longformer_masked_lm() -> anyhow::Result<()> {
.prediction_scores
.get(0)
.get(4)
.double_value(&[i64::try_from(&index_1).unwrap()]);
.double_value(&[i64::try_from(&index_1)?]);
let score_2 = model_output
.prediction_scores
.get(1)
.get(7)
.double_value(&[i64::try_from(&index_2).unwrap()]);
.double_value(&[i64::try_from(&index_2)?]);
assert_eq!("Ġeye", word_1); // Outputs "person" : "Looks like one [eye] is missing"
assert_eq!("Ġsunny", word_2); // Outputs "pear" : "It was a nice and [sunny] day"
@ -384,7 +384,9 @@ fn longformer_for_question_answering() -> anyhow::Result<()> {
// Set-up Question Answering model
let config = QuestionAnsweringConfig::new(
ModelType::Longformer,
RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_SQUAD1),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
LongformerModelResources::LONGFORMER_BASE_SQUAD1,
))),
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_SQUAD1),
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_SQUAD1),
Some(RemoteResource::from_pretrained(

View File

@ -1,5 +1,5 @@
use rust_bert::longt5::{LongT5ConfigResources, LongT5ModelResources, LongT5VocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::RemoteResource;
@ -8,9 +8,9 @@ fn test_summarization_longt5() -> anyhow::Result<()> {
// Set-up translation model
let summarization_config = SummarizationConfig {
model_type: ModelType::LongT5,
model_resource: Box::new(RemoteResource::from_pretrained(
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
LongT5ModelResources::TGLOBAL_BASE_BOOK_SUMMARY,
)),
))),
config_resource: Box::new(RemoteResource::from_pretrained(
LongT5ConfigResources::TGLOBAL_BASE_BOOK_SUMMARY,
)),

View File

@ -2,7 +2,7 @@ use rust_bert::m2m_100::{
M2M100Config, M2M100ConfigResources, M2M100MergesResources, M2M100Model, M2M100ModelResources,
M2M100SourceLanguages, M2M100TargetLanguages, M2M100VocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::Config;
@ -78,7 +78,7 @@ fn m2m100_translation() -> anyhow::Result<()> {
let translation_config = TranslationConfig::new(
ModelType::M2M100,
model_resource,
ModelResource::Torch(Box::new(model_resource)),
config_resource,
vocab_resource,
Some(merges_resource),

View File

@ -2,7 +2,7 @@ use rust_bert::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{
Language, TranslationConfig, TranslationModel, TranslationModelBuilder,
};
@ -23,7 +23,7 @@ fn test_translation() -> anyhow::Result<()> {
let translation_config = TranslationConfig::new(
ModelType::Marian,
model_resource,
ModelResource::Torch(Box::new(model_resource)),
config_resource,
vocab_resource,
Some(merges_resource),

View File

@ -74,12 +74,12 @@ fn mobilebert_masked_model() -> anyhow::Result<()> {
.logits
.get(0)
.get(4)
.double_value(&[i64::try_from(&index_1).unwrap()]);
.double_value(&[i64::try_from(&index_1)?]);
let score_2 = model_output
.logits
.get(1)
.get(7)
.double_value(&[i64::try_from(&index_2).unwrap()]);
.double_value(&[i64::try_from(&index_2)?]);
assert_eq!("thing", word_1); // Outputs "person" : "Looks like one [person] is missing"
assert_eq!("sunny", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"

View File

@ -1,14 +1,17 @@
use rust_bert::nllb::{
NLLBConfigResources, NLLBLanguages, NLLBMergeResources, NLLBResources, NLLBVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
#[test]
// #[cfg_attr(not(feature = "all-tests"), ignore)]
fn nllb_translation() -> anyhow::Result<()> {
let model_resource = RemoteResource::from_pretrained(NLLBResources::NLLB_600M_DISTILLED);
let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
NLLBResources::NLLB_600M_DISTILLED,
)));
let config_resource = RemoteResource::from_pretrained(NLLBConfigResources::NLLB_600M_DISTILLED);
let vocab_resource = RemoteResource::from_pretrained(NLLBVocabResources::NLLB_600M_DISTILLED);
let merges_resource = RemoteResource::from_pretrained(NLLBMergeResources::NLLB_600M_DISTILLED);

295
tests/onnx.rs Normal file
View File

@ -0,0 +1,295 @@
#[cfg(feature = "onnx")]
mod tests {
extern crate anyhow;
use rust_bert::m2m_100::{M2M100SourceLanguages, M2M100TargetLanguages};
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig,
};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
#[test]
fn onnx_masked_lm() -> anyhow::Result<()> {
let masked_lm = MaskedLanguageModel::new(MaskedLanguageConfig::new(
ModelType::Bert,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/model.onnx",
"onnx-bert-base-uncased-for-masked-lm",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/config.json",
"onnx-bert-base-uncased-for-masked-lm",
),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/vocab.txt",
"onnx-bert-base-uncased-for-masked-lm",
),
None,
false,
None,
None,
Some(String::from("<mask>")),
))?;
let input = [
"Hello I am a <mask> student",
"Paris is the <mask> of France. It is <mask> in Europe.",
];
let output = masked_lm.predict(input)?;
assert_eq!(output.len(), 2);
assert_eq!(output[0].len(), 1);
assert_eq!(output[0][0].text, "university");
assert_eq!(output[0][0].id, 2755);
assert!((output[0][0].score - 10.0135).abs() < 1e-4);
assert_eq!(output[1].len(), 2);
assert_eq!(output[1][0].text, "capital");
assert_eq!(output[1][0].id, 2364);
assert!((output[1][0].score - 19.4008).abs() < 1e-4);
assert_eq!(output[1][1].text, "located");
assert_eq!(output[1][1].id, 1388);
assert!((output[1][1].score - 10.8547).abs() < 1e-4);
Ok(())
}
#[test]
fn onnx_question_answering() -> anyhow::Result<()> {
let qa_model = QuestionAnsweringModel::new(QuestionAnsweringConfig::new(
ModelType::Roberta,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/model.onnx",
"onnx-roberta-base-squad2",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/config.json",
"onnx-roberta-base-squad2",
),
RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/vocab.json",
"onnx-roberta-base-squad2",
),
Some(RemoteResource::new(
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/merges.txt",
"onnx-roberta-base-squad2",
)),
false,
None,
None,
))?;
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let output = qa_model.predict(&[qa_input], 1, 32);
assert_eq!(output.len(), 1);
assert_eq!(output[0].len(), 1);
assert_eq!(output[0][0].answer, " Amsterdam");
assert!((output[0][0].score - 0.9898).abs() < 1e-4);
Ok(())
}
#[test]
fn onnx_sequence_classification() -> anyhow::Result<()> {
let classification_model = SentimentModel::new(SequenceClassificationConfig::new(
ModelType::DistilBert,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/model.onnx",
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json",
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
),
RemoteResource::new(
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/vocab.txt",
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
),
None,
true,
None,
None,
))?;
let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
let output = classification_model.predict(input);
assert_eq!(output.len(), 3);
assert_eq!(output[0].polarity, SentimentPolarity::Positive);
assert!((output[0].score - 0.9981).abs() < 1e-4);
assert_eq!(output[1].polarity, SentimentPolarity::Negative);
assert!((output[1].score - 0.9927).abs() < 1e-4);
assert_eq!(output[2].polarity, SentimentPolarity::Positive);
assert!((output[2].score - 0.9997).abs() < 1e-4);
Ok(())
}
#[test]
fn onnx_token_classification() -> anyhow::Result<()> {
let token_classification_model = NERModel::new(TokenClassificationConfig::new(
ModelType::Bert,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/bert-base-NER/resolve/main/model.onnx",
"onnx-bert-base-NER",
))),
..Default::default()
}),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-NER/resolve/main/config.json",
"onnx-bert-base-NER",
),
RemoteResource::new(
"https://huggingface.co/optimum/bert-base-NER/resolve/main/vocab.txt",
"onnx-bert-base-NER",
),
None,
false,
None,
None,
LabelAggregationOption::First,
))?;
let input = ["Asked John Smith about Acme Corp", "Let's go to New York!"];
let output = token_classification_model.predict_full_entities(&input);
assert_eq!(output.len(), 2);
assert_eq!(output[0].len(), 2);
assert_eq!(output[0][0].word, "John Smith");
assert_eq!(output[0][0].label, "PER");
assert!((output[0][0].score - 0.9992).abs() < 1e-4);
assert_eq!(output[0][1].word, "Acme Corp");
assert_eq!(output[0][1].label, "ORG");
assert!((output[0][1].score - 0.0001).abs() < 1e-4);
assert_eq!(output[1].len(), 1);
assert_eq!(output[1][0].word, "New York");
assert_eq!(output[1][0].label, "LOC");
assert!((output[1][0].score - 0.9987).abs() < 1e-4);
Ok(())
}
#[test]
fn onnx_text_generation() -> anyhow::Result<()> {
let text_generation_model = TextGenerationModel::new(TextGenerationConfig {
model_type: ModelType::GPT2,
model_resource: ModelResource::ONNX(ONNXModelResources {
encoder_resource: None,
decoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/gpt2/resolve/main/decoder_model.onnx",
"onnx-gpt2",
))),
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/gpt2/resolve/main/decoder_with_past_model.onnx",
"onnx-gpt2",
))),
}),
config_resource: Box::new(RemoteResource::new(
"https://huggingface.co/optimum/gpt2/resolve/main/config.json",
"onnx-gpt2",
)),
vocab_resource: Box::new(RemoteResource::new(
"https://huggingface.co/gpt2/resolve/main/vocab.json",
"onnx-gpt2",
)),
merges_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/gpt2/resolve/main/merges.txt",
"onnx-gpt2",
))),
max_length: Some(30),
do_sample: false,
num_beams: 1,
temperature: 1.0,
num_return_sequences: 1,
..Default::default()
})?;
let prompts = ["It was a very nice and sunny"];
let output = text_generation_model.generate(&prompts, None);
assert_eq!(output.len(), 1);
assert_eq!(output[0], "It was a very nice and sunny day. I was very happy with the weather. I was very happy with the weather. I was very happy with");
Ok(())
}
#[test]
fn onnx_translation() -> anyhow::Result<()> {
let translation_model = TranslationModel::new(TranslationConfig::new(
ModelType::M2M100,
ModelResource::ONNX(ONNXModelResources {
encoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/encoder_model.onnx",
"onnx-m2m100_418M",
))),
decoder_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_model.onnx",
"onnx-m2m100_418M",
))),
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_with_past_model.onnx",
"onnx-m2m100_418M",
))),
}),
RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/config.json",
"onnx-m2m100_418M",
),
RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/vocab.json",
"onnx-m2m100_418M",
),
Some(RemoteResource::new(
"https://huggingface.co/optimum/m2m100_418M/resolve/main/sentencepiece.bpe.model",
"onnx-m2m100_418M",
)),
M2M100SourceLanguages::M2M100_418M,
M2M100TargetLanguages::M2M100_418M,
Device::cuda_if_available(),
))?;
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(translation_model.translate(
&[source_sentence],
Language::English,
Language::French,
)?);
outputs.extend(translation_model.translate(
&[source_sentence],
Language::English,
Language::Spanish,
)?);
outputs.extend(translation_model.translate(
&[source_sentence],
Language::English,
Language::Hindi,
)?);
assert_eq!(outputs.len(), 3);
assert_eq!(
outputs[0],
" Cette phrase sera traduite en plusieurs langues."
);
assert_eq!(outputs[1], " Esta frase se traducirá en varios idiomas.");
assert_eq!(outputs[2], " यह वाक्यांश कई भाषाओं में अनुवादित किया जाएगा।");
Ok(())
}
}

View File

@ -2,7 +2,7 @@ use rust_bert::openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfig, OpenAiGptConfigResources, OpenAiGptMergesResources,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::generation_utils::Cache;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, ResourceProvider};
@ -119,7 +119,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::OpenAiGpt,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -161,7 +161,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::OpenAiGpt,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -214,7 +214,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::OpenAiGpt,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
@ -283,7 +283,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::OpenAiGpt,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),

View File

@ -1,7 +1,7 @@
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::resources::RemoteResource;
use tch::Device;
@ -20,7 +20,7 @@ fn pegasus_summarization_greedy() -> anyhow::Result<()> {
let summarization_config = SummarizationConfig {
model_type: ModelType::Pegasus,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: None,

View File

@ -1,6 +1,6 @@
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::prophetnet::{
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
};
@ -22,7 +22,7 @@ fn prophetnet_summarization_greedy() -> anyhow::Result<()> {
let summarization_config = SummarizationConfig {
model_type: ModelType::ProphetNet,
model_resource: weights_resource,
model_resource: ModelResource::Torch(weights_resource),
config_resource,
vocab_resource,
merges_resource: None,

View File

@ -1,4 +1,4 @@
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::reformer::{
ReformerConfig, ReformerConfigResources, ReformerForQuestionAnswering,
@ -45,7 +45,7 @@ fn test_generation_reformer() -> anyhow::Result<()> {
// Set-up translation model
let generation_config = TextGenerationConfig {
model_type: ModelType::Reformer,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: None,

View File

@ -1,4 +1,4 @@
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
@ -319,7 +319,9 @@ fn roberta_question_answering() -> anyhow::Result<()> {
// Set-up question answering model
let config = QuestionAnsweringConfig::new(
ModelType::Roberta,
RemoteResource::from_pretrained(RobertaModelResources::ROBERTA_QA),
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA_QA,
))),
RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA_QA),
RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA_QA),
Some(RemoteResource::from_pretrained(
@ -354,9 +356,9 @@ fn xlm_roberta_german_ner() -> anyhow::Result<()> {
// Set-up question answering model
let ner_config = TokenClassificationConfig {
model_type: ModelType::XLMRoberta,
model_resource: Box::new(RemoteResource::from_pretrained(
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
RobertaModelResources::XLM_ROBERTA_NER_DE,
)),
))),
config_resource: Box::new(RemoteResource::from_pretrained(
RobertaConfigResources::XLM_ROBERTA_NER_DE,
)),

View File

@ -1,4 +1,4 @@
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
@ -26,7 +26,7 @@ fn test_translation_t5() -> anyhow::Result<()> {
let translation_config = TranslationConfig::new(
ModelType::T5,
model_resource,
ModelResource::Torch(Box::new(model_resource)),
config_resource,
vocab_resource,
None,
@ -65,7 +65,9 @@ fn test_summarization_t5() -> anyhow::Result<()> {
// Set-up translation model
let summarization_config = SummarizationConfig {
model_type: ModelType::T5,
model_resource: Box::new(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)),
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
T5ModelResources::T5_SMALL,
))),
config_resource: Box::new(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)),
vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
merges_resource: None,

View File

@ -1,4 +1,4 @@
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, ResourceProvider};
use rust_bert::xlnet::{
@ -208,7 +208,7 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
let generate_config = TextGenerationConfig {
model_type: ModelType::XLNet,
model_resource,
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: None,