mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-05 16:47:24 +03:00
ONNX Support (#346)
* Fixed Clippy warnings
* Revert "Shallow clone optimization (#243)"
This reverts commit ba584653bc
.
* updated dependencies
* tryouts
* GPT2 tryouts
* WIP GPT2
* input mapping
* Cache storage
* Initial GPT2 prototype
* Initial ONNX Config and decoder implementation
* ONNXDecoder first draft
* Use Decoders in example
* Automated tch-ort conversion, decoder implementation
* ONNXCausalDecoder implementation
* Refactored _get_var_store to be optional, added get_device to gen trait
* updated example
* Added decoder_start_token_id to ConfigOption
* Addition of ONNXModelConfig, make max_position_embeddigs optional
* Addition of forward pass function for ONNXModel
* Working ONNX causal decoder
* Simplify tensor conversion
* refactor translation to facilitate ONNX integration
* Implementation of ONNXEncoder
* Implementation of ONNXConditionalGenerator
* working ONNXCausalGenerator
* - Reworked model resources type for pipelines and generators
* Aligned ONNXConditionalGenerator with other generators to use GenerateConfig for creation
* Moved force_token_id_generation to common utils function, fixed tests, Translation implementation
* generalized forced_bos and forced_eos tokens generation
* Aligned the `encode_prompt_text` method across language models
* Fix prompt encoding for causal generation
* Fix prompt encoding for causal generation
* Support for ONNX models for SequenceClassification
* Support for ONNX models for TokenClassification
* Support for ONNX models for POS and NER pipelines
* Support for ONNX models for ZeroShotClassification pipeline
* Support for ONNX models for QuestionAnswering pipeline
* Support for ONNX models for MaskedLM pipeline
* Added token_type_ids , updated layer cache i/o parsing for ONNX pipelines
* Support for ONNX models for TextGenerationPipeline, updated examples for remote resources
* Remove ONNX zero-shot classification example (lack of correct pretrained model)
* Addition of tests for ONNX pipelines support
* Made onnx feature optional
* Fix device lookup with onnx feature enabled
* Updates from main branch
* Flexible tokenizer creation for M2M100 (NLLB support), make NLLB test optional du to their size
* Fixed Clippy warnings
* Addition of documentation for ONNX
* Added documentation for ONNX support
* upcoming tch 1.12 fixes
* Fix merge conflicts
* Fix merge conflicts (2)
* Add download libtorch feature to ONNX tests
* Add download-onnx feature
* attempt to enable onnx download
* add remote resources feature
* onnx download
* pin ort version
* Update ort version
This commit is contained in:
parent
81cde55b25
commit
540c9268e7
18
.github/workflows/continuous-integration.yml
vendored
18
.github/workflows/continuous-integration.yml
vendored
@ -140,6 +140,24 @@ jobs:
|
||||
--test nllb
|
||||
--features download-libtorch
|
||||
|
||||
test-onnx:
|
||||
name: Integration tests (ONNX models)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
override: true
|
||||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --package rust-bert
|
||||
--features onnx
|
||||
--test onnx
|
||||
--features download-libtorch
|
||||
|
||||
convert-model:
|
||||
name: Model conversion test
|
||||
runs-on: ubuntu-latest
|
||||
|
13
CHANGELOG.md
13
CHANGELOG.md
@ -7,11 +7,22 @@ All notable changes to this project will be documented in this file. The format
|
||||
- Addition of `add_tokens` and `add_extra_ids` interface methods to the `TokenizerOption`. Allow building most pipeline with custom tokenizer via `new_with_tokenizer`.
|
||||
- Addition of `get_tokenizer` and `get_tokenizer_mut` methods to all pipelines allowing to get a (mutable) reference to the pipeline tokenizer.
|
||||
- Addition of a `get_embedding_dim` method to get the dimension of the embeddings for sentence embeddings pipelines
|
||||
- `get_vocab_size`, `get_decoder_start_token_id` and `get_prefix_and_forced_bos_id` for the `TokenizerOption` in pipelines
|
||||
- Addition of the [GPT-J](https://www.eleuther.ai/artifacts/gpt-j) model architecture
|
||||
- Addition of the [NLLB](https://arxiv.org/abs/2207.04672) model architecture and pretrained weights
|
||||
- Addition of support for ONNX models (encoder, decoders, encoder-decoders) via the [ort](https://github.com/pykeio/ort) onnxruntime bindings
|
||||
- Integration of ONNX models to the sequence classification, token classification, question answering, zero-shot classification, text generation, summarization and translation pipelines
|
||||
|
||||
## Changed
|
||||
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
|
||||
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer
|
||||
- (BREAKING) Simplified the generation traits (removal of LMHeadModel and elimination of unnecessary specification for LanguageGenerator)
|
||||
- (BREAKING) Upgraded to `torch` 2.0 (via `tch` 0.13.0). The process to automatically download the dependencies have changed, it must now be enabled via the `download-libtorch` feature flag.
|
||||
- Read the `decoder_start_token_id` from the provided configuration rather than using a hard-coded default value
|
||||
- (BREAKING) Changed the return type of the `LanguageGenerator` and pipelines functions `float`, `half`, `set_device` to `Result<(), RustBertError>` as these become fallible for ONNX models
|
||||
- (BREAKING) Wrapped the model resources specification for the pipeline `Config` objects into an `Enum` to allow handling both torch-based and ONNX models.
|
||||
The `model_resources` field now needs to be wrapped in the corresponding enum variant, e.g. `model_resources: ModelResources::TORCH(model_resource)` for Torch-based models
|
||||
- (BREAKING) Added the `forced_bos_token_id` and `forced_eos_token_id` fields to text generation models.
|
||||
If these are not None, this will trigger a forced BOS/EOS token generation at the first of `max_length` positions (aligns with the Pytorch Transformers library)
|
||||
|
||||
## Fixed
|
||||
- MIN/MAX computation for float-like (was set to infinity instead of min/max)
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "rust-bert"
|
||||
version = "0.20.1-alpha"
|
||||
version = "0.21.0-alpha"
|
||||
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
|
||||
edition = "2018"
|
||||
description = "Ready-to-use NLP pipelines and language models"
|
||||
@ -66,6 +66,7 @@ doc-only = ["tch/doc-only"]
|
||||
all-tests = []
|
||||
remote = ["cached-path", "dirs", "lazy_static"]
|
||||
download-libtorch = ["tch/download-libtorch"]
|
||||
onnx = ["ort", "ndarray"]
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc-only"]
|
||||
@ -84,6 +85,8 @@ regex = "1.6"
|
||||
cached-path = { version = "0.6", optional = true }
|
||||
dirs = { version = "4", optional = true }
|
||||
lazy_static = { version = "1", optional = true }
|
||||
ort = {version="1.14.8", optional = true, default-features = false, features = ["half"]}
|
||||
ndarray = {version="0.15", optional = true}
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1"
|
||||
@ -93,3 +96,5 @@ tokio = { version = "1.24", features = ["sync", "rt-multi-thread", "macros"] }
|
||||
torch-sys = "0.13.0"
|
||||
tempfile = "3"
|
||||
itertools = "0.10"
|
||||
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
|
||||
ort = {version="1.14.8", features = ["load-dynamic"]}
|
31
README.md
31
README.md
@ -5,7 +5,7 @@
|
||||
[![Documentation](https://docs.rs/rust-bert/badge.svg)](https://docs.rs/rust-bert)
|
||||
![License](https://img.shields.io/crates/l/rust_bert.svg)
|
||||
|
||||
Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using the [tch-rs](https://github.com/LaurentMazare/tch-rs) crate and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
|
||||
Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using [tch-rs](https://github.com/LaurentMazare/tch-rs) or [onnxruntime bindings](https://github.com/pykeio/ort) and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
|
||||
This repository exposes the model base architecture, task-specific heads (see below) and [ready-to-use pipelines](#ready-to-use-pipelines). [Benchmarks](#benchmarks) are available at the end of this document.
|
||||
|
||||
Get started with tasks including question answering, named entity recognition, translation, summarization, text generation, conversational agents and more in just a few lines of code:
|
||||
@ -35,6 +35,7 @@ The tasks currently supported include:
|
||||
- Language Generation
|
||||
- Masked Language Model
|
||||
- Sentence Embeddings
|
||||
- Keywords extraction
|
||||
|
||||
<details>
|
||||
<summary> <b>Expand to display the supported models/tasks matrix </b> </summary>
|
||||
@ -51,10 +52,12 @@ RoBERTa|✅|✅|✅| | | |✅| ✅|
|
||||
GPT| | | |✅ | | | | |
|
||||
GPT2| | | |✅ | | | | |
|
||||
GPT-Neo| | | |✅ | | | | |
|
||||
GPT-J| | | |✅ | | | | |
|
||||
BART|✅| | |✅ |✅| | | |
|
||||
Marian| | | | | |✅| | |
|
||||
MBart|✅| | |✅ | | | | |
|
||||
M2M100| | | |✅ | | | | |
|
||||
NLLB| | | |✅ | | | | |
|
||||
Electra | |✅| | | | |✅| |
|
||||
ALBERT |✅|✅|✅| | | |✅| ✅ |
|
||||
T5 | | | |✅ |✅|✅| | ✅ |
|
||||
@ -116,6 +119,32 @@ cd rust-bert
|
||||
cargo run --example sentence_embeddings
|
||||
```
|
||||
|
||||
## ONNX Support (Optional)
|
||||
|
||||
ONNX support can be enabled via the optional `onnx` feature. This crate then leverages the [ort](https://github.com/pykeio/ort) crate with bindings to the onnxruntime C++ library. We refer the user to this page project for further installation instructions/support.
|
||||
1. Enable the optional `onnx` feature. The `rust-bert` crate does not include any optional dependencies for `ort`, the end user should select the set of features that would be adequate for pulling the required `onnxruntime` C++ library.
|
||||
2. The current recommended installation is to use dynamic linking by pointing to an existing library location. Use the `load-dynamic` cargo feature for `ort`.
|
||||
3. set the `ORT_DYLIB_PATH` to point to the location of downloaded onnxruntime library (`onnxruntime.dll`/`libonnxruntime.so`/`libonnxruntime.dylib` depending on the operating system). These can be downloaded from the [release page](https://github.com/microsoft/onnxruntime/releases) of the onnxruntime project
|
||||
|
||||
Most architectures (including encoders, decoders and encoder-decoders) are supported. the library aims at keeping compatibility with models exported using the [optimum](https://github.com/huggingface/optimum) library. A detailed guide on how to export a Transformer model to ONNX using optimum is available at https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model
|
||||
The resources used to create ONNX models are similar to those based on Pytorch, replacing the pytorch by the ONNX model. Since ONNX models are less flexible than their Pytorch counterparts in the handling of optional arguments, exporting a decoder or encoder-decoder model to ONNX will usually result in multiple files. These files are expected (but not all are necessary) for use in this library as per the table below:
|
||||
|
||||
| Architecture | Encoder file | Decoder without past file | Decoder with past file |
|
||||
|-----------------------------|---------------|---------------------------|-------------------------|
|
||||
| Encoder (e.g. BERT) | required | not used | not used |
|
||||
| Decoder (e.g. GPT2) | not used | required | optional |
|
||||
| Encoder-decoder (e.g. BART) | required | required | optional |
|
||||
|
||||
Note that the computational efficiency will drop when the `decoder with past` file is optional but not provided
|
||||
since the model will not used cached past keys and values for the attention mechanism, leading to a high number of
|
||||
redundant computations. The Optimum library offers export options to ensure such a `decoder with past` model file is created.
|
||||
he base encoder and decoder model architecture are available (and exposed for convenience) in the `encoder` and `decoder` modules, respectively.
|
||||
|
||||
Generation models (pure decoder or encoder/decoder architectures) are available in the `models` module.
|
||||
ost pipelines are available for ONNX model checkpoints, including sequence classification, zero-shot classification,
|
||||
token classification (including named entity recognition and part-of-speech tagging), question answering, text generation, summarization and translation.
|
||||
These models use the same configuration and tokenizer files as their Pytorch counterparts when used in a pipeline. Examples leveraging ONNX models are given in the `./examples` directory
|
||||
|
||||
## Ready-to-use pipelines
|
||||
|
||||
Based on Hugging Face's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
|
||||
|
@ -5,7 +5,7 @@ use criterion::{black_box, Criterion};
|
||||
use rust_bert::gpt2::{
|
||||
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use std::time::{Duration, Instant};
|
||||
@ -14,7 +14,9 @@ use tch::Device;
|
||||
fn create_text_generation_model() -> TextGenerationModel {
|
||||
let config = TextGenerationConfig {
|
||||
model_type: ModelType::GPT2,
|
||||
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
|
||||
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
Gpt2ModelResources::GPT2,
|
||||
))),
|
||||
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
|
||||
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
|
||||
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
|
||||
|
@ -3,7 +3,7 @@ extern crate criterion;
|
||||
|
||||
use criterion::{black_box, Criterion};
|
||||
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::question_answering::{
|
||||
squad_processor, QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||
};
|
||||
@ -17,7 +17,9 @@ static BATCH_SIZE: usize = 64;
|
||||
fn create_qa_model() -> QuestionAnsweringModel {
|
||||
let config = QuestionAnsweringConfig::new(
|
||||
ModelType::Bert,
|
||||
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT_QA,
|
||||
))),
|
||||
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
|
||||
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
|
||||
None, //merges resource only relevant with ModelType::Roberta
|
||||
@ -52,7 +54,9 @@ fn qa_load_model(iters: u64) -> Duration {
|
||||
let start = Instant::now();
|
||||
let config = QuestionAnsweringConfig::new(
|
||||
ModelType::Bert,
|
||||
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT_QA,
|
||||
))),
|
||||
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
|
||||
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
|
||||
None, //merges resource only relevant with ModelType::Roberta
|
||||
|
@ -17,6 +17,7 @@ use std::sync::{Arc, RwLock};
|
||||
use rust_bert::bart::{
|
||||
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelResource;
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::resources::{BufferResource, RemoteResource, ResourceProvider};
|
||||
use tch::Device;
|
||||
@ -80,7 +81,7 @@ fn config(device: Device, model_data: Arc<RwLock<Vec<u8>>>) -> SummarizationConf
|
||||
let merges_resource = Box::new(RemoteResource::from_pretrained(
|
||||
BartMergesResources::DISTILBART_CNN_6_6,
|
||||
));
|
||||
let model_resource = Box::new(BufferResource { data: model_data });
|
||||
let model_resource = ModelResource::Torch(Box::new(BufferResource { data: model_data }));
|
||||
|
||||
SummarizationConfig {
|
||||
model_resource,
|
||||
|
@ -12,7 +12,7 @@
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
|
||||
use rust_bert::pipelines::sequence_classification::{
|
||||
SequenceClassificationConfig, SequenceClassificationModel,
|
||||
@ -26,7 +26,9 @@ fn main() -> anyhow::Result<()> {
|
||||
// Language identification
|
||||
let sequence_classification_config = SequenceClassificationConfig::new(
|
||||
ModelType::Roberta,
|
||||
RemoteResource::from_pretrained(RobertaModelResources::CODEBERTA_LANGUAGE_ID),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
RobertaModelResources::CODEBERTA_LANGUAGE_ID,
|
||||
))),
|
||||
RemoteResource::from_pretrained(RobertaConfigResources::CODEBERTA_LANGUAGE_ID),
|
||||
RemoteResource::from_pretrained(RobertaVocabResources::CODEBERTA_LANGUAGE_ID),
|
||||
Some(RemoteResource::from_pretrained(
|
||||
@ -56,7 +58,9 @@ fn main() -> anyhow::Result<()> {
|
||||
// Masked language model
|
||||
let config = MaskedLanguageConfig::new(
|
||||
ModelType::Roberta,
|
||||
RemoteResource::from_pretrained(RobertaModelResources::CODEBERT_MLM),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
RobertaModelResources::CODEBERT_MLM,
|
||||
))),
|
||||
RemoteResource::from_pretrained(RobertaConfigResources::CODEBERT_MLM),
|
||||
RemoteResource::from_pretrained(RobertaVocabResources::CODEBERT_MLM),
|
||||
Some(RemoteResource::from_pretrained(
|
||||
|
@ -17,7 +17,7 @@ extern crate anyhow;
|
||||
use rust_bert::gpt_neo::{
|
||||
GptNeoConfigResources, GptNeoMergesResources, GptNeoModelResources, GptNeoVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use tch::Device;
|
||||
@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
|
||||
));
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::GPTNeo,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -53,7 +53,7 @@ fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
|
||||
let mut model = TextGenerationModel::new(generate_config)?;
|
||||
model.set_device(Device::cuda_if_available());
|
||||
model.set_device(Device::cuda_if_available())?;
|
||||
|
||||
let input_context_1 = "It was a very nice and sunny";
|
||||
let input_context_2 = "It was a gloom winter night, and";
|
||||
|
@ -1,7 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use rust_bert::gpt_j::{GptJConfigResources, GptJMergesResources, GptJVocabResources};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
use rust_bert::resources::{LocalResource, RemoteResource};
|
||||
use tch::Device;
|
||||
@ -67,7 +67,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let generation_config = TextGenerationConfig {
|
||||
model_type: ModelType::GPTJ,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
use rust_bert::reformer::{
|
||||
ReformerConfigResources, ReformerModelResources, ReformerVocabResources,
|
||||
@ -35,7 +35,7 @@ fn main() -> anyhow::Result<()> {
|
||||
));
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::Reformer,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: None,
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
|
||||
@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::XLNet,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: None,
|
||||
|
@ -12,14 +12,16 @@
|
||||
|
||||
extern crate anyhow;
|
||||
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Set-up model
|
||||
let config = MaskedLanguageConfig::new(
|
||||
ModelType::Bert,
|
||||
RemoteResource::from_pretrained(BertModelResources::BERT),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT,
|
||||
))),
|
||||
RemoteResource::from_pretrained(BertConfigResources::BERT),
|
||||
RemoteResource::from_pretrained(BertVocabResources::BERT),
|
||||
None,
|
||||
|
36
examples/onnx-masked-lm.rs
Normal file
36
examples/onnx-masked-lm.rs
Normal file
@ -0,0 +1,36 @@
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
|
||||
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let masked_lm = MaskedLanguageModel::new(MaskedLanguageConfig::new(
|
||||
ModelType::Bert,
|
||||
ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/model.onnx",
|
||||
"onnx-bert-base-uncased-for-masked-lm",
|
||||
))),
|
||||
..Default::default()
|
||||
}),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/config.json",
|
||||
"onnx-bert-base-uncased-for-masked-lm",
|
||||
),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/vocab.txt",
|
||||
"onnx-bert-base-uncased-for-masked-lm",
|
||||
),
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
Some(String::from("<mask>")),
|
||||
))?;
|
||||
let input = [
|
||||
"Hello I am a <mask> student",
|
||||
"Paris is the <mask> of France. It is <mask> in Europe.",
|
||||
];
|
||||
let output = masked_lm.predict(input)?;
|
||||
println!("{:?}", output);
|
||||
Ok(())
|
||||
}
|
40
examples/onnx-question-answering.rs
Normal file
40
examples/onnx-question-answering.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
|
||||
use rust_bert::pipelines::question_answering::{
|
||||
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||
};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let qa_model = QuestionAnsweringModel::new(QuestionAnsweringConfig::new(
|
||||
ModelType::Roberta,
|
||||
ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/model.onnx",
|
||||
"onnx-roberta-base-squad2",
|
||||
))),
|
||||
..Default::default()
|
||||
}),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/config.json",
|
||||
"onnx-roberta-base-squad2",
|
||||
),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/vocab.json",
|
||||
"onnx-roberta-base-squad2",
|
||||
),
|
||||
Some(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/merges.txt",
|
||||
"onnx-roberta-base-squad2",
|
||||
)),
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
))?;
|
||||
let question = String::from("Where does Amy live ?");
|
||||
let context = String::from("Amy lives in Amsterdam");
|
||||
let qa_input = QaInput { question, context };
|
||||
|
||||
let output = qa_model.predict(&[qa_input], 1, 32);
|
||||
println!("{:?}", output);
|
||||
Ok(())
|
||||
}
|
37
examples/onnx-sequence-classification.rs
Normal file
37
examples/onnx-sequence-classification.rs
Normal file
@ -0,0 +1,37 @@
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
|
||||
use rust_bert::pipelines::sentiment::SentimentModel;
|
||||
use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
|
||||
use rust_bert::resources::RemoteResource;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let classification_model = SentimentModel::new(SequenceClassificationConfig::new(
|
||||
ModelType::DistilBert,
|
||||
ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/model.onnx",
|
||||
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
|
||||
))),
|
||||
..Default::default()
|
||||
}),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json",
|
||||
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
|
||||
),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/vocab.txt",
|
||||
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
|
||||
),
|
||||
None,
|
||||
true,
|
||||
None,
|
||||
None,
|
||||
))?;
|
||||
let input = [
|
||||
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
||||
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
|
||||
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
||||
];
|
||||
let output = classification_model.predict(input);
|
||||
println!("{:?}", output);
|
||||
Ok(())
|
||||
}
|
42
examples/onnx-text-generation.rs
Normal file
42
examples/onnx-text-generation.rs
Normal file
@ -0,0 +1,42 @@
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let text_generation_model = TextGenerationModel::new(TextGenerationConfig {
|
||||
model_type: ModelType::GPT2,
|
||||
model_resource: ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: None,
|
||||
decoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/gpt2/resolve/main/decoder_model.onnx",
|
||||
"onnx-gpt2",
|
||||
))),
|
||||
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/gpt2/resolve/main/decoder_with_past_model.onnx",
|
||||
"onnx-gpt2",
|
||||
))),
|
||||
}),
|
||||
config_resource: Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/gpt2/resolve/main/config.json",
|
||||
"onnx-gpt2",
|
||||
)),
|
||||
vocab_resource: Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/gpt2/resolve/main/vocab.json",
|
||||
"onnx-gpt2",
|
||||
)),
|
||||
merges_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/gpt2/resolve/main/merges.txt",
|
||||
"onnx-gpt2",
|
||||
))),
|
||||
max_length: Some(30),
|
||||
do_sample: false,
|
||||
num_beams: 1,
|
||||
temperature: 1.0,
|
||||
num_return_sequences: 1,
|
||||
..Default::default()
|
||||
})?;
|
||||
let prompts = ["It was a very nice and sunny"];
|
||||
let output = text_generation_model.generate(&prompts, None);
|
||||
println!("{:?}", output);
|
||||
Ok(())
|
||||
}
|
36
examples/onnx-token-classification.rs
Normal file
36
examples/onnx-token-classification.rs
Normal file
@ -0,0 +1,36 @@
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
|
||||
use rust_bert::pipelines::ner::NERModel;
|
||||
use rust_bert::pipelines::token_classification::{
|
||||
LabelAggregationOption, TokenClassificationConfig,
|
||||
};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let token_classification_model = NERModel::new(TokenClassificationConfig::new(
|
||||
ModelType::Bert,
|
||||
ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-NER/resolve/main/model.onnx",
|
||||
"onnx-bert-base-NER",
|
||||
))),
|
||||
..Default::default()
|
||||
}),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-NER/resolve/main/config.json",
|
||||
"onnx-bert-base-NER",
|
||||
),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-NER/resolve/main/vocab.txt",
|
||||
"onnx-bert-base-NER",
|
||||
),
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
LabelAggregationOption::First,
|
||||
))?;
|
||||
let input = ["Asked John Smith about Acme Corp", "Let's go to New York!"];
|
||||
let output = token_classification_model.predict_full_entities(&input);
|
||||
println!("{:?}", output);
|
||||
Ok(())
|
||||
}
|
63
examples/onnx-translation.rs
Normal file
63
examples/onnx-translation.rs
Normal file
@ -0,0 +1,63 @@
|
||||
use rust_bert::m2m_100::{M2M100SourceLanguages, M2M100TargetLanguages};
|
||||
use tch::Device;
|
||||
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let translation_model = TranslationModel::new(TranslationConfig::new(
|
||||
ModelType::M2M100,
|
||||
ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/encoder_model.onnx",
|
||||
"onnx-m2m100_418M",
|
||||
))),
|
||||
decoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_model.onnx",
|
||||
"onnx-m2m100_418M",
|
||||
))),
|
||||
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_with_past_model.onnx",
|
||||
"onnx-m2m100_418M",
|
||||
))),
|
||||
}),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/config.json",
|
||||
"onnx-m2m100_418M",
|
||||
),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/vocab.json",
|
||||
"onnx-m2m100_418M",
|
||||
),
|
||||
Some(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/sentencepiece.bpe.model",
|
||||
"onnx-m2m100_418M",
|
||||
)),
|
||||
M2M100SourceLanguages::M2M100_418M,
|
||||
M2M100TargetLanguages::M2M100_418M,
|
||||
Device::cuda_if_available(),
|
||||
))?;
|
||||
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(translation_model.translate(
|
||||
&[source_sentence],
|
||||
Language::English,
|
||||
Language::French,
|
||||
)?);
|
||||
outputs.extend(translation_model.translate(
|
||||
&[source_sentence],
|
||||
Language::English,
|
||||
Language::Spanish,
|
||||
)?);
|
||||
outputs.extend(translation_model.translate(
|
||||
&[source_sentence],
|
||||
Language::English,
|
||||
Language::Hindi,
|
||||
)?);
|
||||
|
||||
println!("{:?}", outputs);
|
||||
Ok(())
|
||||
}
|
@ -13,7 +13,7 @@
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::question_answering::{
|
||||
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||
};
|
||||
@ -23,7 +23,9 @@ fn main() -> anyhow::Result<()> {
|
||||
// Set-up Question Answering model
|
||||
let config = QuestionAnsweringConfig::new(
|
||||
ModelType::Bert,
|
||||
RemoteResource::from_pretrained(BertModelResources::BERT_QA),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT_QA,
|
||||
))),
|
||||
RemoteResource::from_pretrained(BertConfigResources::BERT_QA),
|
||||
RemoteResource::from_pretrained(BertVocabResources::BERT_QA),
|
||||
None, //merges resource only relevant with ModelType::Roberta
|
||||
|
@ -16,7 +16,7 @@ use rust_bert::longformer::{
|
||||
LongformerConfigResources, LongformerMergesResources, LongformerModelResources,
|
||||
LongformerVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::question_answering::{
|
||||
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||
};
|
||||
@ -26,7 +26,9 @@ fn main() -> anyhow::Result<()> {
|
||||
// Set-up Question Answering model
|
||||
let config = QuestionAnsweringConfig::new(
|
||||
ModelType::Longformer,
|
||||
RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_SQUAD1),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
LongformerModelResources::LONGFORMER_BASE_SQUAD1,
|
||||
))),
|
||||
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_SQUAD1),
|
||||
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_SQUAD1),
|
||||
Some(RemoteResource::from_pretrained(
|
||||
|
@ -13,7 +13,7 @@
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
|
||||
@ -25,9 +25,9 @@ fn main() -> anyhow::Result<()> {
|
||||
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||
FNetVocabResources::BASE_SST2,
|
||||
));
|
||||
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||
let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
FNetModelResources::BASE_SST2,
|
||||
));
|
||||
)));
|
||||
|
||||
let sentiment_config = SentimentConfig {
|
||||
model_type: ModelType::FNet,
|
||||
|
@ -15,6 +15,7 @@ extern crate anyhow;
|
||||
use rust_bert::bart::{
|
||||
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelResource;
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use tch::Device;
|
||||
@ -34,7 +35,7 @@ fn main() -> anyhow::Result<()> {
|
||||
));
|
||||
|
||||
let summarization_config = SummarizationConfig {
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
|
@ -13,7 +13,7 @@
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use tch::Device;
|
||||
@ -31,7 +31,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let summarization_config = SummarizationConfig {
|
||||
model_type: ModelType::Pegasus,
|
||||
model_resource: weights_resource,
|
||||
model_resource: ModelResource::Torch(weights_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: None,
|
||||
|
@ -12,7 +12,7 @@
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::prophetnet::{
|
||||
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
|
||||
@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let summarization_config = SummarizationConfig {
|
||||
model_type: ModelType::ProphetNet,
|
||||
model_resource: weights_resource,
|
||||
model_resource: ModelResource::Torch(weights_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: None,
|
||||
|
@ -12,7 +12,7 @@
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
||||
@ -24,7 +24,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let summarization_config = SummarizationConfig::new(
|
||||
ModelType::T5,
|
||||
weights_resource,
|
||||
ModelResource::Torch(Box::new(weights_resource)),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
None,
|
||||
|
@ -11,7 +11,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::ner::NERModel;
|
||||
use rust_bert::pipelines::token_classification::{
|
||||
LabelAggregationOption, TokenClassificationConfig,
|
||||
@ -22,7 +22,9 @@ fn main() -> anyhow::Result<()> {
|
||||
// Load a configuration
|
||||
let config = TokenClassificationConfig::new(
|
||||
ModelType::Bert,
|
||||
RemoteResource::from_pretrained(BertModelResources::BERT_NER),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT_NER,
|
||||
))),
|
||||
RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
|
||||
RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
|
||||
None, //merges resource only relevant with ModelType::Roberta
|
||||
|
@ -16,7 +16,7 @@ use rust_bert::m2m_100::{
|
||||
M2M100ConfigResources, M2M100MergesResources, M2M100ModelResources, M2M100SourceLanguages,
|
||||
M2M100TargetLanguages, M2M100VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use tch::Device;
|
||||
@ -32,7 +32,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::M2M100,
|
||||
model_resource,
|
||||
ModelResource::Torch(Box::new(model_resource)),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
Some(merges_resource),
|
||||
|
@ -17,7 +17,7 @@ use rust_bert::marian::{
|
||||
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
|
||||
MarianTargetLanguages, MarianVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use tch::Device;
|
||||
@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::Marian,
|
||||
model_resource,
|
||||
ModelResource::Torch(Box::new(model_resource)),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
Some(merges_resource),
|
||||
|
@ -16,7 +16,7 @@ use rust_bert::mbart::{
|
||||
MBartConfigResources, MBartModelResources, MBartSourceLanguages, MBartTargetLanguages,
|
||||
MBartVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use tch::Device;
|
||||
@ -32,7 +32,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::MBart,
|
||||
model_resource,
|
||||
ModelResource::Torch(Box::new(model_resource)),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
None,
|
||||
|
@ -12,7 +12,7 @@
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
||||
@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::T5,
|
||||
model_resource,
|
||||
ModelResource::Torch(Box::new(model_resource)),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
None,
|
||||
|
@ -23,7 +23,6 @@ use crate::pipelines::generation_utils::private_generation_utils::{
|
||||
};
|
||||
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
|
||||
use crate::{Config, RustBertError};
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Borrow;
|
||||
@ -100,12 +99,12 @@ impl BartConfigResources {
|
||||
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
|
||||
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
|
||||
"distilbart-cnn-6-6/config",
|
||||
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/config.json",
|
||||
"https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
|
||||
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
|
||||
"distilbart-cnn-12-6/config",
|
||||
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/config.json",
|
||||
"https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
@ -133,12 +132,12 @@ impl BartVocabResources {
|
||||
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
|
||||
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
|
||||
"distilbart-cnn-6-6/vocab",
|
||||
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/vocab.json",
|
||||
"https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/vocab.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
|
||||
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
|
||||
"distilbart-cnn-12-6/vocab",
|
||||
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/vocab.json",
|
||||
"https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/vocab.json",
|
||||
);
|
||||
}
|
||||
|
||||
@ -166,12 +165,12 @@ impl BartMergesResources {
|
||||
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-6-6>. Modified with conversion to C-array format.
|
||||
pub const DISTILBART_CNN_6_6: (&'static str, &'static str) = (
|
||||
"distilbart-cnn-6-6/merges",
|
||||
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-6-6/merges.txt",
|
||||
"https://huggingface.co/sshleifer/distilbart-cnn-6-6/resolve/main/merges.txt",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the Hugging Face team at <https://huggingface.co/sshleifer/distilbart-cnn-12-6>. Modified with conversion to C-array format.
|
||||
pub const DISTILBART_CNN_12_6: (&'static str, &'static str) = (
|
||||
"distilbart-cnn-12-6/merges",
|
||||
"https://cdn.huggingface.co/sshleifer/distilbart-cnn-12-6/merges.txt",
|
||||
"https://huggingface.co/sshleifer/distilbart-cnn-12-6/resolve/main/merges.txt",
|
||||
);
|
||||
}
|
||||
|
||||
@ -197,6 +196,8 @@ pub struct BartConfig {
|
||||
pub encoder_layers: i64,
|
||||
pub bos_token_id: Option<i64>,
|
||||
pub eos_token_id: Option<i64>,
|
||||
pub forced_bos_token_id: Option<i64>,
|
||||
pub forced_eos_token_id: Option<i64>,
|
||||
pub pad_token_id: Option<i64>,
|
||||
pub id2label: Option<HashMap<i64, String>>,
|
||||
pub label2id: Option<HashMap<String, i64>>,
|
||||
@ -240,6 +241,8 @@ impl Default for BartConfig {
|
||||
bos_token_id: Some(0),
|
||||
eos_token_id: Some(2),
|
||||
pad_token_id: Some(1),
|
||||
forced_bos_token_id: Some(0),
|
||||
forced_eos_token_id: Some(2),
|
||||
id2label: None,
|
||||
label2id: None,
|
||||
init_std: 0.02,
|
||||
@ -918,6 +921,8 @@ pub struct BartGenerator {
|
||||
generate_config: GenerateConfig,
|
||||
bos_token_id: Option<i64>,
|
||||
eos_token_ids: Option<Vec<i64>>,
|
||||
forced_bos_token_id: Option<i64>,
|
||||
forced_eos_token_id: Option<i64>,
|
||||
pad_token_id: Option<i64>,
|
||||
is_encoder_decoder: bool,
|
||||
vocab_size: i64,
|
||||
@ -1006,10 +1011,12 @@ impl BartGenerator {
|
||||
Some(value) => vec![value],
|
||||
None => vec![2],
|
||||
});
|
||||
let forced_bos_token_id = config.forced_bos_token_id;
|
||||
let forced_eos_token_id = config.forced_eos_token_id;
|
||||
let pad_token_id = Some(config.pad_token_id.unwrap_or(1));
|
||||
let vocab_size = config.vocab_size;
|
||||
let is_encoder_decoder = true;
|
||||
let decoder_start_id = Some(2);
|
||||
let decoder_start_id = config.decoder_start_token_id;
|
||||
let max_position_embeddings = config.max_position_embeddings;
|
||||
|
||||
Ok(BartGenerator {
|
||||
@ -1019,6 +1026,8 @@ impl BartGenerator {
|
||||
generate_config,
|
||||
bos_token_id,
|
||||
eos_token_ids,
|
||||
forced_bos_token_id,
|
||||
forced_eos_token_id,
|
||||
pad_token_id,
|
||||
is_encoder_decoder,
|
||||
vocab_size,
|
||||
@ -1026,14 +1035,6 @@ impl BartGenerator {
|
||||
max_position_embeddings,
|
||||
})
|
||||
}
|
||||
|
||||
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
|
||||
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
|
||||
.filter(|pos| !token_ids.contains(pos))
|
||||
.collect();
|
||||
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
|
||||
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
|
||||
}
|
||||
}
|
||||
|
||||
impl PrivateLanguageGenerator for BartGenerator {
|
||||
@ -1043,11 +1044,11 @@ impl PrivateLanguageGenerator for BartGenerator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -1058,6 +1059,12 @@ impl PrivateLanguageGenerator for BartGenerator {
|
||||
fn get_eos_ids(&self) -> Option<&Vec<i64>> {
|
||||
self.eos_token_ids.as_ref()
|
||||
}
|
||||
fn get_forced_bos_token_id(&self) -> Option<i64> {
|
||||
self.forced_bos_token_id
|
||||
}
|
||||
fn get_forced_eos_token_id(&self) -> Option<i64> {
|
||||
self.forced_eos_token_id
|
||||
}
|
||||
fn get_pad_id(&self) -> Option<i64> {
|
||||
self.pad_token_id
|
||||
}
|
||||
@ -1070,8 +1077,8 @@ impl PrivateLanguageGenerator for BartGenerator {
|
||||
fn get_decoder_start_id(&self) -> Option<i64> {
|
||||
self.decoder_start_id
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
@ -1119,25 +1126,6 @@ impl PrivateLanguageGenerator for BartGenerator {
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_scores_for_generation(
|
||||
&self,
|
||||
scores: &mut Tensor,
|
||||
current_length: i64,
|
||||
max_length: Option<i64>,
|
||||
forced_bos_token_id: Option<i64>,
|
||||
) {
|
||||
if current_length == 1 {
|
||||
self.force_token_id_generation(
|
||||
scores,
|
||||
&[forced_bos_token_id.unwrap_or_else(|| self.get_bos_id().unwrap())],
|
||||
);
|
||||
} else if let Some(max_length) = max_length {
|
||||
if current_length == max_length - 1 {
|
||||
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
|
||||
Some(self.model.encode(input_ids, attention_mask))
|
||||
}
|
||||
@ -1170,48 +1158,6 @@ impl PrivateLanguageGenerator for BartGenerator {
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: &[S],
|
||||
max_len: Option<i64>,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text,
|
||||
max_len
|
||||
.map(|max_len| max_len as usize)
|
||||
.unwrap_or(usize::MAX),
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|tokenized_input| tokenized_input.token_ids)
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
.into_iter()
|
||||
.map(|mut input| {
|
||||
let temp = vec![pad_token; max_len - input.len()];
|
||||
input.extend(temp);
|
||||
input
|
||||
})
|
||||
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
|
||||
.collect::<Vec<Tensor>>();
|
||||
|
||||
Tensor::stack(&token_ids, 0)
|
||||
}
|
||||
|
||||
fn reorder_cache(
|
||||
&self,
|
||||
past: &mut Cache,
|
||||
|
@ -1,3 +1,7 @@
|
||||
#[cfg(feature = "onnx")]
|
||||
use ndarray::ShapeError;
|
||||
#[cfg(feature = "onnx")]
|
||||
use ort::OrtError;
|
||||
use rust_tokenizers::error::TokenizerError;
|
||||
use tch::TchError;
|
||||
use thiserror::Error;
|
||||
@ -23,6 +27,14 @@ pub enum RustBertError {
|
||||
#[error("Value error: {0}")]
|
||||
ValueError(String),
|
||||
|
||||
#[error("Value error: {0}")]
|
||||
#[cfg(feature = "onnx")]
|
||||
OrtError(String),
|
||||
|
||||
#[error("Value error: {0}")]
|
||||
#[cfg(feature = "onnx")]
|
||||
NdArrayError(String),
|
||||
|
||||
#[error("Unsupported operation")]
|
||||
UnsupportedError,
|
||||
}
|
||||
@ -44,3 +56,16 @@ impl From<TchError> for RustBertError {
|
||||
RustBertError::TchError(error.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
impl From<OrtError> for RustBertError {
|
||||
fn from(error: OrtError) -> Self {
|
||||
RustBertError::OrtError(error.to_string())
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
impl From<ShapeError> for RustBertError {
|
||||
fn from(error: ShapeError) -> Self {
|
||||
RustBertError::NdArrayError(error.to_string())
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ use std::path::PathBuf;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// # In-memory raw buffer resource
|
||||
#[derive(Debug)]
|
||||
pub struct BufferResource {
|
||||
/// The data representing the underlying resource
|
||||
pub data: Arc<RwLock<Vec<u8>>>,
|
||||
|
@ -3,7 +3,7 @@ use crate::resources::{Resource, ResourceProvider};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// # Local resource
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
#[derive(PartialEq, Eq, Debug, Clone)]
|
||||
pub struct LocalResource {
|
||||
/// Local path for the resource
|
||||
pub local_path: PathBuf,
|
||||
|
@ -25,6 +25,7 @@ mod local;
|
||||
use crate::common::error::RustBertError;
|
||||
pub use buffer::BufferResource;
|
||||
pub use local::LocalResource;
|
||||
use std::fmt::Debug;
|
||||
use std::ops::DerefMut;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::RwLockWriteGuard;
|
||||
@ -37,7 +38,7 @@ pub enum Resource<'a> {
|
||||
|
||||
/// # Resource Trait that can provide the location or data for the model, and location of
|
||||
/// configuration or vocabulary resources
|
||||
pub trait ResourceProvider: Send + Sync {
|
||||
pub trait ResourceProvider: Debug + Send + Sync {
|
||||
/// Provides the local path for a resource.
|
||||
///
|
||||
/// # Returns
|
||||
|
@ -6,7 +6,7 @@ use lazy_static::lazy_static;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// # Remote resource that will be downloaded and cached locally on demand
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
#[derive(PartialEq, Eq, Clone, Debug)]
|
||||
pub struct RemoteResource {
|
||||
/// Remote path/url for the resource
|
||||
pub url: String,
|
||||
|
@ -88,6 +88,7 @@ pub struct FNetConfig {
|
||||
pub pad_token_id: Option<i64>,
|
||||
pub bos_token_id: Option<i64>,
|
||||
pub eos_token_id: Option<i64>,
|
||||
pub decoder_start_token_id: Option<i64>,
|
||||
pub id2label: Option<HashMap<i64, String>>,
|
||||
pub label2id: Option<HashMap<String, i64>>,
|
||||
pub output_attentions: Option<bool>,
|
||||
@ -112,6 +113,7 @@ impl Default for FNetConfig {
|
||||
pad_token_id: Some(3),
|
||||
bos_token_id: Some(1),
|
||||
eos_token_id: Some(2),
|
||||
decoder_start_token_id: None,
|
||||
id2label: None,
|
||||
label2id: None,
|
||||
output_attentions: None,
|
||||
|
@ -26,7 +26,7 @@ use serde::{Deserialize, Serialize};
|
||||
use std::borrow::{Borrow, BorrowMut};
|
||||
use tch::kind::Kind::Int64;
|
||||
use tch::nn::embedding;
|
||||
use tch::{nn, Kind, Tensor};
|
||||
use tch::{nn, Device, Kind, Tensor};
|
||||
|
||||
/// # GPT2 Pretrained model weight files
|
||||
pub struct Gpt2ModelResources;
|
||||
@ -194,6 +194,9 @@ pub struct Gpt2Config {
|
||||
pub output_hidden_states: Option<bool>,
|
||||
pub resid_pdrop: Option<f64>,
|
||||
pub vocab_size: i64,
|
||||
pub decoder_start_token_id: Option<i64>,
|
||||
pub forced_bos_token_id: Option<i64>,
|
||||
pub forced_eos_token_id: Option<i64>,
|
||||
}
|
||||
|
||||
impl Config for Gpt2Config {}
|
||||
@ -218,6 +221,9 @@ impl Default for Gpt2Config {
|
||||
output_hidden_states: None,
|
||||
resid_pdrop: Some(0.1),
|
||||
vocab_size: 50257,
|
||||
decoder_start_token_id: None,
|
||||
forced_bos_token_id: None,
|
||||
forced_eos_token_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -654,7 +660,7 @@ impl GPT2Generator {
|
||||
let max_position_embeddings = config.n_positions;
|
||||
let is_encoder_decoder = false;
|
||||
let vocab_size = config.vocab_size;
|
||||
let decoder_start_id = None;
|
||||
let decoder_start_id = config.decoder_start_token_id;
|
||||
|
||||
Ok(GPT2Generator {
|
||||
model,
|
||||
@ -679,11 +685,11 @@ impl PrivateLanguageGenerator for GPT2Generator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -706,8 +712,8 @@ impl PrivateLanguageGenerator for GPT2Generator {
|
||||
fn get_decoder_start_id(&self) -> Option<i64> {
|
||||
self.decoder_start_id
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
|
@ -135,6 +135,9 @@ pub struct GptJConfig {
|
||||
pub use_float16: bool,
|
||||
#[serde(default = "default_preload_on_cpu")]
|
||||
pub preload_on_cpu: bool,
|
||||
pub decoder_start_token_id: Option<i64>,
|
||||
pub forced_bos_token_id: Option<i64>,
|
||||
pub forced_eos_token_id: Option<i64>,
|
||||
}
|
||||
|
||||
impl Config for GptJConfig {}
|
||||
@ -163,6 +166,9 @@ impl Default for GptJConfig {
|
||||
scale_attn_weights: Some(true),
|
||||
use_float16: default_use_float16(),
|
||||
preload_on_cpu: default_preload_on_cpu(),
|
||||
decoder_start_token_id: None,
|
||||
forced_bos_token_id: None,
|
||||
forced_eos_token_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -630,7 +636,7 @@ impl GptJGenerator {
|
||||
let max_position_embeddings = config.n_positions;
|
||||
let is_encoder_decoder = false;
|
||||
let vocab_size = config.vocab_size;
|
||||
let decoder_start_id = None;
|
||||
let decoder_start_id = config.decoder_start_token_id;
|
||||
|
||||
Ok(GptJGenerator {
|
||||
model,
|
||||
@ -655,11 +661,11 @@ impl PrivateLanguageGenerator for GptJGenerator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -682,8 +688,8 @@ impl PrivateLanguageGenerator for GptJGenerator {
|
||||
fn get_decoder_start_id(&self) -> Option<i64> {
|
||||
self.decoder_start_id
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
|
@ -22,7 +22,7 @@ use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, L
|
||||
use crate::{Activation, Config, RustBertError};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::{Borrow, BorrowMut};
|
||||
use tch::{nn, Kind, Tensor};
|
||||
use tch::{nn, Device, Kind, Tensor};
|
||||
|
||||
/// # GPT-Neo Pretrained model weight files
|
||||
pub struct GptNeoModelResources;
|
||||
@ -127,6 +127,8 @@ pub struct GptNeoConfig {
|
||||
pub intermediate_size: Option<i64>,
|
||||
pub bos_token_id: i64,
|
||||
pub eos_token_id: i64,
|
||||
pub forced_bos_token_id: Option<i64>,
|
||||
pub forced_eos_token_id: Option<i64>,
|
||||
pub vocab_size: i64,
|
||||
pub num_layers: i64,
|
||||
pub num_heads: i64,
|
||||
@ -140,6 +142,7 @@ pub struct GptNeoConfig {
|
||||
pub output_attentions: Option<bool>,
|
||||
pub output_hidden_states: Option<bool>,
|
||||
pub resid_dropout: f64,
|
||||
pub decoder_start_token_id: Option<i64>,
|
||||
}
|
||||
|
||||
impl Config for GptNeoConfig {}
|
||||
@ -162,6 +165,8 @@ impl Default for GptNeoConfig {
|
||||
intermediate_size: None,
|
||||
bos_token_id: 50256,
|
||||
eos_token_id: 50256,
|
||||
forced_bos_token_id: None,
|
||||
forced_eos_token_id: None,
|
||||
vocab_size: 50257,
|
||||
num_layers: 24,
|
||||
num_heads: 16,
|
||||
@ -175,6 +180,7 @@ impl Default for GptNeoConfig {
|
||||
output_attentions: None,
|
||||
output_hidden_states: None,
|
||||
resid_dropout: 0.0,
|
||||
decoder_start_token_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -673,7 +679,7 @@ impl GptNeoGenerator {
|
||||
let pad_token_id = tokenizer.get_pad_id();
|
||||
let is_encoder_decoder = false;
|
||||
let vocab_size = config.vocab_size;
|
||||
let decoder_start_id = None;
|
||||
let decoder_start_id = config.decoder_start_token_id;
|
||||
let max_position_embeddings = config.max_position_embeddings;
|
||||
|
||||
Ok(GptNeoGenerator {
|
||||
@ -699,11 +705,11 @@ impl PrivateLanguageGenerator for GptNeoGenerator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -727,8 +733,8 @@ impl PrivateLanguageGenerator for GptNeoGenerator {
|
||||
self.decoder_start_id
|
||||
}
|
||||
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
|
@ -26,6 +26,7 @@
|
||||
//! use tch::Device;
|
||||
//!
|
||||
//! fn main() -> anyhow::Result<()> {
|
||||
//! use rust_bert::pipelines::common::ModelResource;
|
||||
//! let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||
//! GptNeoConfigResources::GPT_NEO_1_3B,
|
||||
//! ));
|
||||
@ -41,7 +42,7 @@
|
||||
//!
|
||||
//! let text_generation_config = TextGenerationConfig {
|
||||
//! model_type: ModelType::GPTNeo,
|
||||
//! model_resource,
|
||||
//! model_resource: ModelResource::Torch(model_resource),
|
||||
//! config_resource,
|
||||
//! vocab_resource,
|
||||
//! merges_resource: Some(merges_resource),
|
||||
|
31
src/lib.rs
31
src/lib.rs
@ -1,6 +1,6 @@
|
||||
//! # Ready-to-use NLP pipelines and Transformer-based models
|
||||
//!
|
||||
//! Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using the [tch-rs](https://github.com/LaurentMazare/tch-rs) crate and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
|
||||
//! Rust-native state-of-the-art Natural Language Processing models and pipelines. Port of Hugging Face's [Transformers library](https://github.com/huggingface/transformers), using [tch-rs](https://github.com/LaurentMazare/tch-rs) or [onnxruntime bindings](https://github.com/pykeio/ort) and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multi-threaded tokenization and GPU inference.
|
||||
//! This repository exposes the model base architecture, task-specific heads (see below) and [ready-to-use pipelines](#ready-to-use-pipelines). [Benchmarks](#benchmarks) are available at the end of this document.
|
||||
//!
|
||||
//! Get started with tasks including question answering, named entity recognition, translation, summarization, text generation, conversational agents and more in just a few lines of code:
|
||||
@ -42,6 +42,7 @@
|
||||
//! - Language Generation
|
||||
//! - Sentence Embeddings
|
||||
//! - Masked Language Model
|
||||
//! - Keywords extraction
|
||||
//!
|
||||
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
|
||||
//! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust
|
||||
@ -61,10 +62,12 @@
|
||||
//!GPT| | | |✅ | | | | |
|
||||
//!GPT2| | | |✅ | | | | |
|
||||
//!GPT-Neo| | | |✅ | | | | |
|
||||
//!GPT-J| | | |✅ | | | | |
|
||||
//!BART|✅| | |✅ |✅| | | |
|
||||
//!Marian| | | | | |✅| | |
|
||||
//!MBart|✅| | |✅ | | | | |
|
||||
//!M2M100| | | |✅ | | | | |
|
||||
//!NLLB| | | |✅ | | | | |
|
||||
//!Electra | |✅| | | | |✅| |
|
||||
//!ALBERT |✅|✅|✅| | | |✅| ✅ |
|
||||
//!T5 | | | |✅ |✅|✅| | ✅ |
|
||||
@ -109,6 +112,32 @@
|
||||
//! The CPU version of libtorch will be downloaded by default. To download a CUDA version, please set the environment variable `TORCH_CUDA_VERSION` to `cu118`.
|
||||
//! Note that the libtorch library is large (order of several GBs for the CUDA-enabled version) and the first build may therefore take several minutes to complete.
|
||||
//!
|
||||
//! ## ONNX Support (Optional)
|
||||
//!
|
||||
//! ONNX support can be enabled via the optional `onnx` feature. This crate then leverages the [ort](https://github.com/pykeio/ort) crate with bindings to the onnxruntime C++ library. We refer the user to this page project for further installation instructions/support.
|
||||
//! 1. Enable the optional `onnx` feature. The `rust-bert` crate does not include any optional dependencies for `ort`, the end user should select the set of features that would be adequate for pulling the required `onnxruntime` C++ library.
|
||||
//! 2. The current recommended installation is to use dynamic linking by pointing to an existing library location. Use the `load-dynamic` cargo feature for `ort`.
|
||||
//! 3. set the `ORT_DYLIB_PATH` to point to the location of downloaded onnxruntime library (`onnxruntime.dll`/`libonnxruntime.so`/`libonnxruntime.dylib` depending on the operating system). These can be downloaded from the [release page](https://github.com/microsoft/onnxruntime/releases) of the onnxruntime project
|
||||
//!
|
||||
//! Most architectures (including encoders, decoders and encoder-decoders) are supported. the library aims at keeping compatibility with models exported using the [optimum](https://github.com/huggingface/optimum) library. A detailed guide on how to export a Transformer model to ONNX using optimum is available at <https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model>
|
||||
//! The resources used to create ONNX models are similar to those based on Pytorch, replacing the pytorch by the ONNX model. Since ONNX models are less flexible than their Pytorch counterparts in the handling of optional arguments, exporting a decoder or encoder-decoder model to ONNX will usually result in multiple files. These files are expected (but not all are necessary) for use in this library as per the table below:
|
||||
//!
|
||||
//! | Architecture | Encoder file | Decoder without past file | Decoder with past file |
|
||||
//! -----------------------------|---------------|---------------------------|-------------------------
|
||||
//! | Encoder (e.g. BERT) | required | not used | not used |
|
||||
//! | Decoder (e.g. GPT2) | not used | required | optional |
|
||||
//! | Encoder-decoder (e.g. BART) | required | required | optional |
|
||||
//!
|
||||
//! Note that the computational efficiency will drop when the `decoder with past` file is optional but not provided
|
||||
//! since the model will not used cached past keys and values for the attention mechanism, leading to a high number of
|
||||
//! redundant computations. The Optimum library offers export options to ensure such a `decoder with past` model file is created.
|
||||
//! he base encoder and decoder model architecture are available (and exposed for convenience) in the `encoder` and `decoder` modules, respectively.
|
||||
//!
|
||||
//! Generation models (pure decoder or encoder/decoder architectures) are available in the `models` module.
|
||||
//! ost pipelines are available for ONNX model checkpoints, including sequence classification, zero-shot classification,
|
||||
//! token classification (including named entity recognition and part-of-speech tagging), question answering, text generation, summarization and translation.
|
||||
//! These models use the same configuration and tokenizer files as their Pytorch counterparts when used in a pipeline. Examples leveraging ONNX models are given in the `./examples` directory. More information on these can be found in the [`onnx` module](./pipelines/onnx/index.html)
|
||||
//!
|
||||
//! # Ready-to-use pipelines
|
||||
//!
|
||||
//! Based on Hugging Face's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. More information on these can be found in the [`pipelines` module](./pipelines/index.html)
|
||||
|
@ -31,11 +31,12 @@
|
||||
//!
|
||||
//! fn main() -> anyhow::Result<()> {
|
||||
//! // Set-up Question Answering model
|
||||
//! let config = QuestionAnsweringConfig::new(
|
||||
//! use rust_bert::pipelines::common::ModelResource;
|
||||
//! let config = QuestionAnsweringConfig::new(
|
||||
//! ModelType::Longformer,
|
||||
//! RemoteResource::from_pretrained(
|
||||
//! ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
//! LongformerModelResources::LONGFORMER_BASE_SQUAD1,
|
||||
//! ),
|
||||
//! ))),
|
||||
//! RemoteResource::from_pretrained(
|
||||
//! LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
|
||||
//! ),
|
||||
|
@ -19,11 +19,10 @@ use crate::pipelines::generation_utils::private_generation_utils::{
|
||||
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
|
||||
use crate::t5::{FeedForwardProj, T5Config, T5ModelOutput, TaskSpecificParams};
|
||||
use crate::{Config, RustBertError};
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Borrow;
|
||||
use tch::nn::{embedding, LinearConfig};
|
||||
use tch::{nn, Tensor};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
/// # LongT5 Pretrained model weight files
|
||||
pub struct LongT5ModelResources;
|
||||
@ -79,6 +78,8 @@ pub struct LongT5Config {
|
||||
pub decoder_start_token_id: Option<i64>,
|
||||
pub bos_token_id: Option<i64>,
|
||||
pub eos_token_id: Option<i64>,
|
||||
pub forced_bos_token_id: Option<i64>,
|
||||
pub forced_eos_token_id: Option<i64>,
|
||||
pub initializer_factor: f64,
|
||||
pub is_encoder_decoder: Option<bool>,
|
||||
pub layer_norm_epsilon: f64,
|
||||
@ -112,6 +113,8 @@ impl Default for LongT5Config {
|
||||
decoder_start_token_id: None,
|
||||
bos_token_id: None,
|
||||
eos_token_id: Some(1),
|
||||
forced_bos_token_id: None,
|
||||
forced_eos_token_id: None,
|
||||
initializer_factor: 1.0,
|
||||
is_encoder_decoder: None,
|
||||
layer_norm_epsilon: 1e-6,
|
||||
@ -145,6 +148,8 @@ impl From<&LongT5Config> for T5Config {
|
||||
decoder_start_token_id: val.decoder_start_token_id,
|
||||
bos_token_id: None,
|
||||
eos_token_id: val.eos_token_id,
|
||||
forced_bos_token_id: val.forced_bos_token_id,
|
||||
forced_eos_token_id: val.forced_eos_token_id,
|
||||
initializer_factor: val.initializer_factor,
|
||||
is_encoder_decoder: val.is_encoder_decoder,
|
||||
layer_norm_epsilon: val.layer_norm_epsilon,
|
||||
@ -600,7 +605,7 @@ impl LongT5Generator {
|
||||
let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
|
||||
let vocab_size = config.vocab_size;
|
||||
let is_encoder_decoder = true;
|
||||
let decoder_start_id = pad_token_id;
|
||||
let decoder_start_id = config.decoder_start_token_id;
|
||||
// longT5 do not have an embedding matrix for position IDs and relies on relative positions instead
|
||||
let max_position_embeddings = i64::MAX;
|
||||
|
||||
@ -627,11 +632,11 @@ impl PrivateLanguageGenerator for LongT5Generator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -654,8 +659,8 @@ impl PrivateLanguageGenerator for LongT5Generator {
|
||||
fn get_decoder_start_id(&self) -> Option<i64> {
|
||||
self.decoder_start_id
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
@ -738,48 +743,6 @@ impl PrivateLanguageGenerator for LongT5Generator {
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: &[S],
|
||||
max_len: Option<i64>,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text,
|
||||
max_len
|
||||
.map(|max_len| max_len as usize)
|
||||
.unwrap_or(usize::MAX),
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|tokenized_input| tokenized_input.token_ids)
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
.into_iter()
|
||||
.map(|mut input| {
|
||||
let temp = vec![pad_token; max_len - input.len()];
|
||||
input.extend(temp);
|
||||
input
|
||||
})
|
||||
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
|
||||
.collect::<Vec<Tensor>>();
|
||||
|
||||
Tensor::stack(&token_ids, 0)
|
||||
}
|
||||
|
||||
fn reorder_cache(
|
||||
&self,
|
||||
past: &mut Cache,
|
||||
|
@ -14,17 +14,16 @@ use crate::m2m_100::decoder::M2M100Decoder;
|
||||
use crate::m2m_100::encoder::M2M100Encoder;
|
||||
use crate::m2m_100::LayerState;
|
||||
use crate::mbart::{MBartConfig, MBartModelOutput};
|
||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||
use crate::pipelines::common::TokenizerOption;
|
||||
use crate::pipelines::generation_utils::private_generation_utils::{
|
||||
PreparedInput, PrivateLanguageGenerator,
|
||||
};
|
||||
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
|
||||
use crate::pipelines::translation::Language;
|
||||
use crate::{Config, RustBertError};
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
use std::borrow::Borrow;
|
||||
use tch::nn::{embedding, EmbeddingConfig};
|
||||
use tch::{nn, Kind, Tensor};
|
||||
use tch::{nn, Device, Kind, Tensor};
|
||||
|
||||
/// # M2M100 Pretrained model weight files
|
||||
pub struct M2M100ModelResources;
|
||||
@ -522,7 +521,7 @@ impl M2M100Generator {
|
||||
.get_local_path()?;
|
||||
|
||||
let tokenizer = TokenizerOption::from_file(
|
||||
ModelType::M2M100,
|
||||
generate_config.model_type,
|
||||
vocab_path.to_str().unwrap(),
|
||||
Some(merges_path.to_str().unwrap()),
|
||||
false,
|
||||
@ -555,7 +554,7 @@ impl M2M100Generator {
|
||||
let pad_token_id = Some(config.pad_token_id.unwrap_or(1));
|
||||
let vocab_size = config.vocab_size;
|
||||
let is_encoder_decoder = true;
|
||||
let decoder_start_id = Some(2);
|
||||
let decoder_start_id = config.decoder_start_token_id;
|
||||
let max_position_embeddings = config.max_position_embeddings;
|
||||
|
||||
Ok(M2M100Generator {
|
||||
@ -572,14 +571,6 @@ impl M2M100Generator {
|
||||
max_position_embeddings,
|
||||
})
|
||||
}
|
||||
|
||||
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
|
||||
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
|
||||
.filter(|pos| !token_ids.contains(pos))
|
||||
.collect();
|
||||
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
|
||||
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
|
||||
}
|
||||
}
|
||||
|
||||
impl PrivateLanguageGenerator for M2M100Generator {
|
||||
@ -589,11 +580,11 @@ impl PrivateLanguageGenerator for M2M100Generator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -616,8 +607,8 @@ impl PrivateLanguageGenerator for M2M100Generator {
|
||||
fn get_decoder_start_id(&self) -> Option<i64> {
|
||||
self.decoder_start_id
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
@ -665,22 +656,6 @@ impl PrivateLanguageGenerator for M2M100Generator {
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_scores_for_generation(
|
||||
&self,
|
||||
scores: &mut Tensor,
|
||||
current_length: i64,
|
||||
max_length: Option<i64>,
|
||||
forced_bos_token_id: Option<i64>,
|
||||
) {
|
||||
if current_length == 1 {
|
||||
self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
|
||||
} else if let Some(max_length) = max_length {
|
||||
if current_length == max_length - 1 {
|
||||
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
|
||||
Some(self.model.encode(input_ids, attention_mask))
|
||||
}
|
||||
@ -713,48 +688,6 @@ impl PrivateLanguageGenerator for M2M100Generator {
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: &[S],
|
||||
max_len: Option<i64>,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text,
|
||||
max_len
|
||||
.map(|max_len| max_len as usize)
|
||||
.unwrap_or(usize::MAX),
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|tokenized_input| tokenized_input.token_ids)
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
.into_iter()
|
||||
.map(|mut input| {
|
||||
let temp = vec![pad_token; max_len - input.len()];
|
||||
input.extend(temp);
|
||||
input
|
||||
})
|
||||
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
|
||||
.collect::<Vec<Tensor>>();
|
||||
|
||||
Tensor::stack(&token_ids, 0)
|
||||
}
|
||||
|
||||
fn reorder_cache(
|
||||
&self,
|
||||
past: &mut Cache,
|
||||
|
@ -14,15 +14,14 @@
|
||||
use crate::bart::{BartConfig, BartModel, BartModelOutput, LayerState};
|
||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||
use crate::pipelines::generation_utils::private_generation_utils::{
|
||||
PreparedInput, PrivateLanguageGenerator,
|
||||
force_token_id_generation, PreparedInput, PrivateLanguageGenerator,
|
||||
};
|
||||
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
|
||||
use crate::pipelines::translation::Language;
|
||||
use crate::{Config, RustBertError};
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
use std::borrow::Borrow;
|
||||
use tch::nn::Init;
|
||||
use tch::{nn, Kind, Tensor};
|
||||
use tch::{nn, Device, Kind, Tensor};
|
||||
|
||||
/// # Marian Pretrained model weight files
|
||||
pub struct MarianModelResources;
|
||||
@ -773,10 +772,10 @@ impl MarianGenerator {
|
||||
|
||||
let vocab_size = config.vocab_size;
|
||||
let is_encoder_decoder = true;
|
||||
let decoder_start_id =
|
||||
Some(tokenizer.get_pad_id().ok_or(RustBertError::TokenizerError(
|
||||
"The tokenizer must contain a pad token ID to be used as BOS".to_string(),
|
||||
))?);
|
||||
let decoder_start_id = match config.decoder_start_token_id {
|
||||
Some(start_token_id) => Some(start_token_id),
|
||||
None => pad_token_id,
|
||||
};
|
||||
let max_position_embeddings = config.max_position_embeddings;
|
||||
|
||||
Ok(MarianGenerator {
|
||||
@ -793,14 +792,6 @@ impl MarianGenerator {
|
||||
max_position_embeddings,
|
||||
})
|
||||
}
|
||||
|
||||
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
|
||||
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
|
||||
.filter(|pos| !token_ids.contains(pos))
|
||||
.collect();
|
||||
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
|
||||
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
|
||||
}
|
||||
}
|
||||
|
||||
impl PrivateLanguageGenerator for MarianGenerator {
|
||||
@ -810,11 +801,11 @@ impl PrivateLanguageGenerator for MarianGenerator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -837,8 +828,8 @@ impl PrivateLanguageGenerator for MarianGenerator {
|
||||
fn get_decoder_start_id(&self) -> Option<i64> {
|
||||
self.decoder_start_id
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
@ -901,7 +892,11 @@ impl PrivateLanguageGenerator for MarianGenerator {
|
||||
);
|
||||
if let Some(max_length) = max_length {
|
||||
if current_length == max_length - 1 {
|
||||
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
|
||||
force_token_id_generation(
|
||||
scores,
|
||||
self.get_eos_ids().as_ref().unwrap(),
|
||||
self.get_vocab_size(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -938,48 +933,6 @@ impl PrivateLanguageGenerator for MarianGenerator {
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: &[S],
|
||||
max_len: Option<i64>,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text,
|
||||
max_len
|
||||
.map(|max_len| max_len as usize)
|
||||
.unwrap_or(usize::MAX),
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|tokenized_input| tokenized_input.token_ids)
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
.into_iter()
|
||||
.map(|mut input| {
|
||||
let temp = vec![pad_token; max_len - input.len()];
|
||||
input.extend(temp);
|
||||
input
|
||||
})
|
||||
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
|
||||
.collect::<Vec<Tensor>>();
|
||||
|
||||
Tensor::stack(&token_ids, 0)
|
||||
}
|
||||
|
||||
fn reorder_cache(
|
||||
&self,
|
||||
past: &mut Cache,
|
||||
|
@ -22,13 +22,12 @@ use crate::pipelines::generation_utils::private_generation_utils::{
|
||||
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
|
||||
use crate::pipelines::translation::Language;
|
||||
use crate::{Activation, Config, RustBertError};
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Borrow;
|
||||
use std::collections::HashMap;
|
||||
use tch::kind::Kind::Int64;
|
||||
use tch::nn::{embedding, EmbeddingConfig, Init};
|
||||
use tch::{nn, Tensor};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
/// # MBART Pretrained model weight files
|
||||
pub struct MBartModelResources;
|
||||
@ -99,6 +98,7 @@ pub struct MBartConfig {
|
||||
pub bos_token_id: Option<i64>,
|
||||
pub eos_token_id: Option<i64>,
|
||||
pub pad_token_id: Option<i64>,
|
||||
pub forced_bos_token_id: Option<i64>,
|
||||
pub forced_eos_token_id: Option<i64>,
|
||||
pub decoder_start_token_id: Option<i64>,
|
||||
pub id2label: Option<HashMap<i64, String>>,
|
||||
@ -138,6 +138,7 @@ impl Default for MBartConfig {
|
||||
bos_token_id: Some(0),
|
||||
eos_token_id: Some(2),
|
||||
pad_token_id: Some(1),
|
||||
forced_bos_token_id: None,
|
||||
forced_eos_token_id: Some(2),
|
||||
decoder_start_token_id: None,
|
||||
id2label: None,
|
||||
@ -725,6 +726,7 @@ pub struct MBartGenerator {
|
||||
generate_config: GenerateConfig,
|
||||
bos_token_id: Option<i64>,
|
||||
eos_token_ids: Option<Vec<i64>>,
|
||||
forced_eos_token_id: Option<i64>,
|
||||
pad_token_id: Option<i64>,
|
||||
is_encoder_decoder: bool,
|
||||
vocab_size: i64,
|
||||
@ -805,10 +807,11 @@ impl MBartGenerator {
|
||||
Some(value) => vec![value],
|
||||
None => vec![2],
|
||||
});
|
||||
let forced_eos_token_id = config.forced_eos_token_id;
|
||||
let pad_token_id = Some(config.pad_token_id.unwrap_or(1));
|
||||
let vocab_size = config.vocab_size;
|
||||
let is_encoder_decoder = true;
|
||||
let decoder_start_id = Some(2);
|
||||
let decoder_start_id = config.decoder_start_token_id;
|
||||
let max_position_embeddings = config.max_position_embeddings;
|
||||
|
||||
Ok(MBartGenerator {
|
||||
@ -818,6 +821,7 @@ impl MBartGenerator {
|
||||
generate_config,
|
||||
bos_token_id,
|
||||
eos_token_ids,
|
||||
forced_eos_token_id,
|
||||
pad_token_id,
|
||||
is_encoder_decoder,
|
||||
vocab_size,
|
||||
@ -825,14 +829,6 @@ impl MBartGenerator {
|
||||
max_position_embeddings,
|
||||
})
|
||||
}
|
||||
|
||||
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
|
||||
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
|
||||
.filter(|pos| !token_ids.contains(pos))
|
||||
.collect();
|
||||
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
|
||||
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
|
||||
}
|
||||
}
|
||||
|
||||
impl PrivateLanguageGenerator for MBartGenerator {
|
||||
@ -842,11 +838,11 @@ impl PrivateLanguageGenerator for MBartGenerator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -857,6 +853,9 @@ impl PrivateLanguageGenerator for MBartGenerator {
|
||||
fn get_eos_ids(&self) -> Option<&Vec<i64>> {
|
||||
self.eos_token_ids.as_ref()
|
||||
}
|
||||
fn get_forced_eos_token_id(&self) -> Option<i64> {
|
||||
self.forced_eos_token_id
|
||||
}
|
||||
fn get_pad_id(&self) -> Option<i64> {
|
||||
self.pad_token_id
|
||||
}
|
||||
@ -915,24 +914,8 @@ impl PrivateLanguageGenerator for MBartGenerator {
|
||||
})
|
||||
}
|
||||
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
}
|
||||
|
||||
fn prepare_scores_for_generation(
|
||||
&self,
|
||||
scores: &mut Tensor,
|
||||
current_length: i64,
|
||||
max_length: Option<i64>,
|
||||
forced_bos_token_id: Option<i64>,
|
||||
) {
|
||||
if current_length == 1 {
|
||||
self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
|
||||
} else if let Some(max_length) = max_length {
|
||||
if current_length == max_length - 1 {
|
||||
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
|
||||
}
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
|
||||
@ -967,48 +950,6 @@ impl PrivateLanguageGenerator for MBartGenerator {
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: &[S],
|
||||
max_len: Option<i64>,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text,
|
||||
max_len
|
||||
.map(|max_len| max_len as usize)
|
||||
.unwrap_or(usize::MAX),
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|tokenized_input| tokenized_input.token_ids)
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
.into_iter()
|
||||
.map(|mut input| {
|
||||
let temp = vec![pad_token; max_len - input.len()];
|
||||
input.extend(temp);
|
||||
input
|
||||
})
|
||||
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
|
||||
.collect::<Vec<Tensor>>();
|
||||
|
||||
Tensor::stack(&token_ids, 0)
|
||||
}
|
||||
|
||||
fn reorder_cache(
|
||||
&self,
|
||||
past: &mut Cache,
|
||||
|
@ -24,7 +24,7 @@ use crate::{Config, RustBertError};
|
||||
use std::borrow::{Borrow, BorrowMut};
|
||||
use tch::kind::Kind::Int64;
|
||||
use tch::nn::embedding;
|
||||
use tch::{nn, Tensor};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
/// # GPT Pretrained model weight files
|
||||
pub struct OpenAiGptModelResources;
|
||||
@ -505,7 +505,7 @@ impl OpenAIGenerator {
|
||||
let pad_token_id = tokenizer.get_pad_id();
|
||||
let is_encoder_decoder = false;
|
||||
let vocab_size = config.vocab_size;
|
||||
let decoder_start_id = None;
|
||||
let decoder_start_id = config.decoder_start_token_id;
|
||||
let max_position_embeddings = config.n_positions;
|
||||
|
||||
Ok(OpenAIGenerator {
|
||||
@ -531,11 +531,11 @@ impl PrivateLanguageGenerator for OpenAIGenerator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -558,8 +558,8 @@ impl PrivateLanguageGenerator for OpenAIGenerator {
|
||||
fn get_decoder_start_id(&self) -> Option<i64> {
|
||||
self.decoder_start_id
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
|
@ -11,7 +11,6 @@
|
||||
// limitations under the License.
|
||||
|
||||
use crate::bart::BartModelOutput;
|
||||
use crate::common::kind::get_negative_infinity;
|
||||
use crate::mbart::MBartConfig;
|
||||
use crate::pegasus::decoder::PegasusDecoder;
|
||||
use crate::pegasus::encoder::PegasusEncoder;
|
||||
@ -22,10 +21,9 @@ use crate::pipelines::generation_utils::private_generation_utils::{
|
||||
};
|
||||
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
|
||||
use crate::{Config, RustBertError};
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
use std::borrow::Borrow;
|
||||
use tch::nn::{embedding, EmbeddingConfig, Init};
|
||||
use tch::{nn, Tensor};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
/// # Pegasus Pretrained model weight files
|
||||
pub struct PegasusModelResources;
|
||||
@ -510,14 +508,13 @@ impl PegasusConditionalGenerator {
|
||||
crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
|
||||
|
||||
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
|
||||
let eos_token_ids = Some(match config.eos_token_id {
|
||||
Some(value) => vec![value],
|
||||
None => vec![1],
|
||||
});
|
||||
let eos_token_ids = config
|
||||
.eos_token_id
|
||||
.map_or(Some(vec![1]), |value| Some(vec![value]));
|
||||
let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
|
||||
let vocab_size = config.vocab_size;
|
||||
let is_encoder_decoder = true;
|
||||
let decoder_start_id = Some(0);
|
||||
let decoder_start_id = config.decoder_start_token_id.or(Some(0));
|
||||
let max_position_embeddings = config.max_position_embeddings;
|
||||
|
||||
Ok(PegasusConditionalGenerator {
|
||||
@ -534,18 +531,6 @@ impl PegasusConditionalGenerator {
|
||||
max_position_embeddings,
|
||||
})
|
||||
}
|
||||
|
||||
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
|
||||
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
|
||||
.filter(|pos| !token_ids.contains(pos))
|
||||
.collect();
|
||||
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
|
||||
let _ = scores.index_fill_(
|
||||
1,
|
||||
&impossible_tokens,
|
||||
get_negative_infinity(scores.kind()).unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl PrivateLanguageGenerator for PegasusConditionalGenerator {
|
||||
@ -555,11 +540,11 @@ impl PrivateLanguageGenerator for PegasusConditionalGenerator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -582,8 +567,8 @@ impl PrivateLanguageGenerator for PegasusConditionalGenerator {
|
||||
fn get_decoder_start_id(&self) -> Option<i64> {
|
||||
self.decoder_start_id
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
@ -630,20 +615,6 @@ impl PrivateLanguageGenerator for PegasusConditionalGenerator {
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_scores_for_generation(
|
||||
&self,
|
||||
scores: &mut Tensor,
|
||||
current_length: i64,
|
||||
max_length: Option<i64>,
|
||||
_forced_bos_token_id: Option<i64>,
|
||||
) {
|
||||
if let Some(max_length) = max_length {
|
||||
if current_length == max_length - 1 {
|
||||
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
|
||||
Some(self.model.encode(input_ids, attention_mask))
|
||||
}
|
||||
@ -676,51 +647,6 @@ impl PrivateLanguageGenerator for PegasusConditionalGenerator {
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: &[S],
|
||||
max_len: Option<i64>,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text,
|
||||
max_len
|
||||
.map(|max_len| max_len as usize)
|
||||
.unwrap_or(usize::MAX),
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|tokenized_input| tokenized_input.token_ids)
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self
|
||||
._get_tokenizer()
|
||||
.get_pad_id()
|
||||
.expect("A padding token must be provided to encode prompt texts."),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
.into_iter()
|
||||
.map(|mut input| {
|
||||
let temp = vec![pad_token; max_len - input.len()];
|
||||
input.extend(temp);
|
||||
input
|
||||
})
|
||||
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
|
||||
.collect::<Vec<Tensor>>();
|
||||
|
||||
Tensor::stack(&token_ids, 0)
|
||||
}
|
||||
|
||||
fn reorder_cache(
|
||||
&self,
|
||||
past: &mut Cache,
|
||||
|
@ -36,8 +36,10 @@ use crate::mbart::MBartConfig;
|
||||
use crate::mobilebert::MobileBertConfig;
|
||||
use crate::openai_gpt::OpenAiGptConfig;
|
||||
use crate::pegasus::PegasusConfig;
|
||||
use crate::pipelines::translation::Language;
|
||||
use crate::prophetnet::ProphetNetConfig;
|
||||
use crate::reformer::ReformerConfig;
|
||||
use crate::resources::{Resource, ResourceProvider};
|
||||
use crate::roberta::RobertaConfig;
|
||||
use crate::t5::T5Config;
|
||||
use crate::xlnet::XLNetConfig;
|
||||
@ -52,9 +54,104 @@ use rust_tokenizers::tokenizer::{
|
||||
use rust_tokenizers::vocab::Vocab;
|
||||
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput, TokensWithOffsets};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::convert::TryFrom;
|
||||
use std::path::Path;
|
||||
use std::fmt::Debug;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tch::{Device, Kind, Tensor};
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
use crate::pipelines::onnx::ONNXModelConfig;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
/// Container for ONNX model resources, containing 3 optional resources (Encoder, Decoder and Decoder with past)
|
||||
pub struct ONNXModelResources {
|
||||
/// Model encoder resource
|
||||
pub encoder_resource: Option<Box<dyn ResourceProvider + Send>>,
|
||||
/// Model encoder resource
|
||||
pub decoder_resource: Option<Box<dyn ResourceProvider + Send>>,
|
||||
/// Model encoder resource
|
||||
pub decoder_with_past_resource: Option<Box<dyn ResourceProvider + Send>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Variants to store either a Torch model resource or ONNX resources
|
||||
pub enum ModelResource {
|
||||
Torch(Box<dyn ResourceProvider + Send>),
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNX(ONNXModelResources),
|
||||
}
|
||||
|
||||
impl ResourceProvider for ModelResource {
|
||||
fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
|
||||
match self {
|
||||
ModelResource::Torch(ref resource) => resource.get_local_path(),
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelResource::ONNX(_) => Err(RustBertError::UnsupportedError),
|
||||
}
|
||||
}
|
||||
fn get_resource(&self) -> Result<Resource, RustBertError> {
|
||||
match self {
|
||||
ModelResource::Torch(ref resource) => resource.get_resource(),
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelResource::ONNX(_) => Err(RustBertError::UnsupportedError),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ONNXLocalPaths {
|
||||
pub encoder_path: Option<PathBuf>,
|
||||
pub decoder_path: Option<PathBuf>,
|
||||
pub decoder_with_past_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl ModelResource {
|
||||
/// Provides the torch resource local path.
|
||||
/// Returns an error if the variant is not a `ModelResources::TORCH`
|
||||
pub fn get_torch_local_path(&self) -> Result<PathBuf, RustBertError> {
|
||||
match self {
|
||||
ModelResource::Torch(torch_resource) => torch_resource.get_local_path(),
|
||||
#[cfg(feature = "onnx")]
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!("Attempting to get the Torch local path but other weights variants were given: {:?}", self)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn get_onnx_local_paths(&self) -> Result<ONNXLocalPaths, RustBertError> {
|
||||
let (encoder_path, decoder_path, decoder_with_past_path) = match self {
|
||||
ModelResource::ONNX(onnx_model_resources) => Ok((
|
||||
onnx_model_resources
|
||||
.encoder_resource.as_ref()
|
||||
.map(|r| r.get_local_path()),
|
||||
onnx_model_resources
|
||||
.decoder_resource.as_ref()
|
||||
.map(|r| r.get_local_path()),
|
||||
onnx_model_resources
|
||||
.decoder_with_past_resource.as_ref()
|
||||
.map(|r| r.get_local_path()),
|
||||
)),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!("Attempting to get the ONNX local paths but other weights variants were given: {:?}", self)))
|
||||
}?;
|
||||
Ok(ONNXLocalPaths {
|
||||
encoder_path: encoder_path.transpose()?,
|
||||
decoder_path: decoder_path.transpose()?,
|
||||
decoder_with_past_path: decoder_with_past_path.transpose()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_device(_model_resource: ModelResource, device: Device) -> Device {
|
||||
#[cfg(feature = "onnx")]
|
||||
let device = if let ModelResource::ONNX(_) = _model_resource {
|
||||
Device::Cpu
|
||||
} else {
|
||||
device
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "onnx"))]
|
||||
let device = device;
|
||||
device
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, PartialEq, Eq)]
|
||||
/// # Identifies the type of model
|
||||
@ -92,6 +189,8 @@ pub enum ModelType {
|
||||
#[serde(alias = "m2m100")]
|
||||
NLLB,
|
||||
FNet,
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNX,
|
||||
}
|
||||
|
||||
/// # Abstraction that holds a model configuration, can be of any of the supported models
|
||||
@ -144,6 +243,9 @@ pub enum ConfigOption {
|
||||
M2M100(M2M100Config),
|
||||
/// FNet configuration
|
||||
FNet(FNetConfig),
|
||||
/// ONNX Model configuration
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNX(ONNXModelConfig),
|
||||
}
|
||||
|
||||
/// # Abstraction that holds a particular tokenizer, can be of any of the supported models
|
||||
@ -220,6 +322,8 @@ impl ConfigOption {
|
||||
ConfigOption::M2M100(M2M100Config::from_file(path))
|
||||
}
|
||||
ModelType::FNet => ConfigOption::FNet(FNetConfig::from_file(path)),
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelType::ONNX => ConfigOption::ONNX(ONNXModelConfig::from_file(path)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -293,6 +397,11 @@ impl ConfigOption {
|
||||
.id2label
|
||||
.as_ref()
|
||||
.expect("No label dictionary (id2label) provided in configuration file"),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(config) => config
|
||||
.id2label
|
||||
.as_ref()
|
||||
.expect("No label dictionary (id2label) provided in configuration file"),
|
||||
Self::T5(_) => panic!("T5 does not use a label mapping"),
|
||||
Self::LongT5(_) => panic!("LongT5 does not use a label mapping"),
|
||||
Self::OpenAiGpt(_) => panic!("OpenAI GPT does not use a label mapping"),
|
||||
@ -329,6 +438,132 @@ impl ConfigOption {
|
||||
Self::M2M100(config) => Some(config.max_position_embeddings),
|
||||
Self::FNet(config) => Some(config.max_position_embeddings),
|
||||
Self::Roberta(config) => Some(config.max_position_embeddings),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(config) => config.max_position_embeddings,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_vocab_size(&self) -> i64 {
|
||||
match self {
|
||||
Self::Bart(config) => config.vocab_size,
|
||||
Self::Bert(config) => config.vocab_size,
|
||||
Self::Deberta(config) => config.vocab_size,
|
||||
Self::DebertaV2(config) => config.vocab_size,
|
||||
Self::DistilBert(config) => config.vocab_size,
|
||||
Self::Electra(config) => config.vocab_size,
|
||||
Self::Marian(config) => config.vocab_size,
|
||||
Self::MobileBert(config) => config.vocab_size,
|
||||
Self::T5(config) => config.vocab_size,
|
||||
Self::LongT5(config) => config.vocab_size,
|
||||
Self::Albert(config) => config.vocab_size,
|
||||
Self::XLNet(config) => config.vocab_size,
|
||||
Self::GPT2(config) => config.vocab_size,
|
||||
Self::GPTJ(config) => config.vocab_size,
|
||||
Self::Reformer(config) => config.vocab_size,
|
||||
Self::ProphetNet(config) => config.vocab_size,
|
||||
Self::Longformer(config) => config.vocab_size,
|
||||
Self::Pegasus(config) => config.vocab_size,
|
||||
Self::OpenAiGpt(config) => config.vocab_size,
|
||||
Self::GPTNeo(config) => config.vocab_size,
|
||||
Self::MBart(config) => config.vocab_size,
|
||||
Self::M2M100(config) => config.vocab_size,
|
||||
Self::FNet(config) => config.vocab_size,
|
||||
Self::Roberta(config) => config.vocab_size,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(config) => config.vocab_size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_decoder_start_token_id(&self) -> Option<i64> {
|
||||
match self {
|
||||
Self::Bart(config) => config.decoder_start_token_id,
|
||||
Self::Bert(_) => None,
|
||||
Self::Deberta(_) => None,
|
||||
Self::DebertaV2(_) => None,
|
||||
Self::DistilBert(_) => None,
|
||||
Self::Electra(_) => None,
|
||||
Self::Marian(config) => config.decoder_start_token_id,
|
||||
Self::MobileBert(_) => None,
|
||||
Self::T5(config) => config.decoder_start_token_id,
|
||||
Self::LongT5(config) => config.decoder_start_token_id,
|
||||
Self::Albert(_) => None,
|
||||
Self::XLNet(_) => None,
|
||||
Self::GPT2(config) => config.decoder_start_token_id,
|
||||
Self::GPTJ(config) => config.decoder_start_token_id,
|
||||
Self::Reformer(config) => config.decoder_start_token_id,
|
||||
Self::ProphetNet(config) => config.decoder_start_token_id,
|
||||
Self::Longformer(_) => None,
|
||||
Self::Pegasus(config) => config.decoder_start_token_id,
|
||||
Self::OpenAiGpt(config) => config.decoder_start_token_id,
|
||||
Self::GPTNeo(config) => config.decoder_start_token_id,
|
||||
Self::MBart(config) => config.decoder_start_token_id,
|
||||
Self::M2M100(config) => config.decoder_start_token_id,
|
||||
Self::FNet(config) => config.decoder_start_token_id,
|
||||
Self::Roberta(_) => None,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(config) => config.decoder_start_token_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_forced_bos_token_id(&self) -> Option<i64> {
|
||||
match self {
|
||||
Self::Bart(config) => config.forced_bos_token_id,
|
||||
Self::Bert(_) => None,
|
||||
Self::Deberta(_) => None,
|
||||
Self::DebertaV2(_) => None,
|
||||
Self::DistilBert(_) => None,
|
||||
Self::Electra(_) => None,
|
||||
Self::Marian(config) => config.forced_bos_token_id,
|
||||
Self::MobileBert(_) => None,
|
||||
Self::T5(config) => config.forced_bos_token_id,
|
||||
Self::LongT5(config) => config.forced_bos_token_id,
|
||||
Self::Albert(_) => None,
|
||||
Self::XLNet(_) => None,
|
||||
Self::GPT2(config) => config.forced_bos_token_id,
|
||||
Self::GPTJ(config) => config.forced_bos_token_id,
|
||||
Self::Reformer(config) => config.forced_bos_token_id,
|
||||
Self::ProphetNet(config) => config.forced_bos_token_id,
|
||||
Self::Longformer(_) => None,
|
||||
Self::Pegasus(config) => config.forced_bos_token_id,
|
||||
Self::OpenAiGpt(config) => config.forced_bos_token_id,
|
||||
Self::GPTNeo(config) => config.forced_bos_token_id,
|
||||
Self::MBart(config) => config.forced_bos_token_id,
|
||||
Self::M2M100(config) => config.forced_bos_token_id,
|
||||
Self::FNet(_) => None,
|
||||
Self::Roberta(_) => None,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(config) => config.forced_bos_token_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_forced_eos_token_id(&self) -> Option<i64> {
|
||||
match self {
|
||||
Self::Bart(config) => config.forced_eos_token_id,
|
||||
Self::Bert(_) => None,
|
||||
Self::Deberta(_) => None,
|
||||
Self::DebertaV2(_) => None,
|
||||
Self::DistilBert(_) => None,
|
||||
Self::Electra(_) => None,
|
||||
Self::Marian(config) => config.forced_eos_token_id,
|
||||
Self::MobileBert(_) => None,
|
||||
Self::T5(config) => config.forced_eos_token_id,
|
||||
Self::LongT5(config) => config.forced_eos_token_id,
|
||||
Self::Albert(_) => None,
|
||||
Self::XLNet(_) => None,
|
||||
Self::GPT2(config) => config.forced_eos_token_id,
|
||||
Self::GPTJ(config) => config.forced_eos_token_id,
|
||||
Self::Reformer(config) => config.forced_eos_token_id,
|
||||
Self::ProphetNet(config) => config.forced_eos_token_id,
|
||||
Self::Longformer(_) => None,
|
||||
Self::Pegasus(config) => config.forced_eos_token_id,
|
||||
Self::OpenAiGpt(config) => config.forced_eos_token_id,
|
||||
Self::GPTNeo(config) => config.forced_eos_token_id,
|
||||
Self::MBart(config) => config.forced_eos_token_id,
|
||||
Self::M2M100(config) => config.forced_eos_token_id,
|
||||
Self::FNet(_) => None,
|
||||
Self::Roberta(_) => None,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(config) => config.forced_eos_token_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -670,6 +905,10 @@ impl TokenizerOption {
|
||||
lower_case,
|
||||
strip_accents.unwrap_or(false),
|
||||
)?),
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
|
||||
"Default Tokenizer not defined for generic ONNX models.".to_string(),
|
||||
))?,
|
||||
};
|
||||
Ok(tokenizer)
|
||||
}
|
||||
@ -1311,6 +1550,167 @@ impl TokenizerOption {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to prepare the input for translation models
|
||||
pub fn get_prefix_and_forced_bos_id(
|
||||
&self,
|
||||
source_language: Option<&Language>,
|
||||
target_language: Option<&Language>,
|
||||
supported_source_languages: &HashSet<Language>,
|
||||
supported_target_languages: &HashSet<Language>,
|
||||
) -> Result<(Option<String>, Option<i64>), RustBertError> {
|
||||
if let Some(source_language) = source_language {
|
||||
if !supported_source_languages.contains(source_language) {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{source_language} not in list of supported languages: {supported_source_languages:?}",
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(target_language) = target_language {
|
||||
if !supported_target_languages.contains(target_language) {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{target_language} not in list of supported languages: {supported_target_languages:?}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(match *self {
|
||||
Self::Marian(_) => {
|
||||
if supported_target_languages.len() > 1 {
|
||||
(
|
||||
Some(format!(
|
||||
">>{}<< ",
|
||||
target_language.and_then(|l| l.get_iso_639_1_code()).ok_or_else(|| RustBertError::ValueError(format!(
|
||||
"Missing target language for Marian \
|
||||
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||
need to specify target language)",
|
||||
)))?
|
||||
)),
|
||||
None,
|
||||
)
|
||||
} else {
|
||||
(None, None)
|
||||
}
|
||||
}
|
||||
Self::T5(_) => (
|
||||
Some(format!(
|
||||
"translate {} to {}:",
|
||||
source_language.ok_or_else(|| RustBertError::ValueError(
|
||||
"Missing source language for T5".to_string(),
|
||||
))?,
|
||||
target_language.ok_or_else(|| RustBertError::ValueError(
|
||||
"Missing target language for T5".to_string(),
|
||||
))?,
|
||||
)),
|
||||
None,
|
||||
),
|
||||
Self::MBart50(_) => {
|
||||
(
|
||||
Some(format!(
|
||||
">>{}<< ",
|
||||
source_language.and_then(|l| l.get_iso_639_1_code()).ok_or_else(|| RustBertError::ValueError(format!(
|
||||
"Missing source language for MBart\
|
||||
(multiple languages supported by model: {supported_source_languages:?}, \
|
||||
need to specify target language)"
|
||||
)))?
|
||||
)),
|
||||
if let Some(target_language) = target_language {
|
||||
Some(
|
||||
self.convert_tokens_to_ids(&[format!(
|
||||
">>{}<<",
|
||||
target_language.get_iso_639_1_code().ok_or_else(|| {
|
||||
RustBertError::ValueError(format!(
|
||||
"This language has no ISO639-I code. Languages supported by model: {supported_source_languages:?}."
|
||||
))
|
||||
})?
|
||||
)])[0],
|
||||
)
|
||||
} else {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing target language for MBart\
|
||||
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||
need to specify target language)"
|
||||
)));
|
||||
},
|
||||
)
|
||||
}
|
||||
Self::M2M100(_) => (
|
||||
Some(match source_language {
|
||||
Some(value) => {
|
||||
let language_code = value.get_iso_639_1_code().ok_or_else(|| {
|
||||
RustBertError::ValueError(format!(
|
||||
"This language has no ISO639-I language code representation. \
|
||||
languages supported by the model: {supported_target_languages:?}"
|
||||
))
|
||||
})?;
|
||||
match language_code.len() {
|
||||
2 => format!(">>{language_code}.<< "),
|
||||
3 => format!(">>{language_code}<< "),
|
||||
_ => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Invalid ISO 639-I code".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing source language for M2M100 \
|
||||
(multiple languages supported by model: {supported_source_languages:?}, \
|
||||
need to specify target language)"
|
||||
)));
|
||||
}
|
||||
}),
|
||||
if let Some(target_language) = target_language {
|
||||
let language_code = target_language.get_iso_639_1_code().ok_or_else(|| {
|
||||
RustBertError::ValueError(format!(
|
||||
"This language has no ISO639-I language code representation. \
|
||||
languages supported by the model: {supported_target_languages:?}"
|
||||
))
|
||||
})?;
|
||||
Some(
|
||||
self.convert_tokens_to_ids(&[
|
||||
match language_code.len() {
|
||||
2 => format!(">>{language_code}.<<"),
|
||||
3 => format!(">>{language_code}<<"),
|
||||
_ => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Invalid ISO 639-3 code".to_string(),
|
||||
));
|
||||
}
|
||||
},
|
||||
])[0],
|
||||
)
|
||||
} else {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing target language for M2M100 \
|
||||
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||
need to specify target language)",
|
||||
)));
|
||||
},
|
||||
),
|
||||
Self::NLLB(_) => {
|
||||
let source_language = source_language
|
||||
.and_then(Language::get_nllb_code)
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| RustBertError::ValueError(
|
||||
format!("Missing source language for NLLB. Need to specify one from: {supported_source_languages:?}")
|
||||
))?;
|
||||
|
||||
let target_language = target_language
|
||||
.and_then(Language::get_nllb_code)
|
||||
.map(str::to_string)
|
||||
.map(|code| self.convert_tokens_to_ids(&[code])[0])
|
||||
.ok_or_else(|| RustBertError::ValueError(
|
||||
format!("Missing target language for NLLB. Need to specify one from: {supported_target_languages:?}")
|
||||
))?;
|
||||
|
||||
(Some(source_language), Some(target_language))
|
||||
}
|
||||
_ => (None, None),
|
||||
})
|
||||
}
|
||||
|
||||
/// Interface method to convert tokens to ids
|
||||
pub fn convert_tokens_to_ids<S>(&self, tokens: &[S]) -> Vec<i64>
|
||||
where
|
||||
@ -1793,6 +2193,55 @@ impl TokenizerOption {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tokenize_and_pad<'a, S>(
|
||||
&self,
|
||||
input: S,
|
||||
max_length: usize,
|
||||
device: Device,
|
||||
) -> (Tensor, Tensor)
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let mut tokenized_input: Vec<TokenizedInput> = self.encode_list(
|
||||
input.as_ref(),
|
||||
max_length,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let pad_id = self
|
||||
.get_pad_id()
|
||||
.expect("The Tokenizer used for sequence classification should contain a PAD id");
|
||||
let tokenized_input_tensors: Vec<Tensor> = tokenized_input
|
||||
.iter_mut()
|
||||
.map(|input| {
|
||||
input.token_ids.resize(max_len, pad_id);
|
||||
Tensor::from_slice(&(input.token_ids))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let token_type_ids: Vec<Tensor> = tokenized_input
|
||||
.iter_mut()
|
||||
.map(|input| {
|
||||
input
|
||||
.segment_ids
|
||||
.resize(max_len, *input.segment_ids.last().unwrap_or(&0));
|
||||
Tensor::from_slice(&(input.segment_ids))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
(
|
||||
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(device),
|
||||
Tensor::stack(token_type_ids.as_slice(), 0)
|
||||
.to(device)
|
||||
.to_kind(Kind::Int64),
|
||||
)
|
||||
}
|
||||
|
||||
/// Interface method
|
||||
pub fn add_extra_ids(&mut self, num_extra_ids: i64) {
|
||||
match *self {
|
||||
|
@ -55,7 +55,7 @@
|
||||
//! from the 3rd party utilization of the pretrained system.
|
||||
use crate::common::error::RustBertError;
|
||||
use crate::gpt2::GPT2Generator;
|
||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
|
||||
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
||||
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
use crate::resources::ResourceProvider;
|
||||
@ -76,7 +76,7 @@ pub struct ConversationConfig {
|
||||
/// Model type
|
||||
pub model_type: ModelType,
|
||||
/// Model weights resource (default: DialoGPT-medium)
|
||||
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||
pub model_resource: ModelResource,
|
||||
/// Config resource (default: DialoGPT-medium)
|
||||
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Vocab resource (default: DialoGPT-medium)
|
||||
@ -122,9 +122,9 @@ impl Default for ConversationConfig {
|
||||
fn default() -> ConversationConfig {
|
||||
ConversationConfig {
|
||||
model_type: ModelType::GPT2,
|
||||
model_resource: Box::new(RemoteResource::from_pretrained(
|
||||
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
Gpt2ModelResources::DIALOGPT_MEDIUM,
|
||||
)),
|
||||
))),
|
||||
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||
Gpt2ConfigResources::DIALOGPT_MEDIUM,
|
||||
)),
|
||||
@ -157,6 +157,7 @@ impl Default for ConversationConfig {
|
||||
impl From<ConversationConfig> for GenerateConfig {
|
||||
fn from(config: ConversationConfig) -> GenerateConfig {
|
||||
GenerateConfig {
|
||||
model_type: config.model_type,
|
||||
model_resource: config.model_resource,
|
||||
config_resource: config.config_resource,
|
||||
merges_resource: config.merges_resource,
|
||||
|
@ -82,20 +82,24 @@ use crate::t5::LayerState as T5LayerState;
|
||||
use crate::xlnet::LayerState as XLNetLayerState;
|
||||
|
||||
use self::ordered_float::OrderedFloat;
|
||||
use crate::pipelines::common::TokenizerOption;
|
||||
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
|
||||
|
||||
extern crate ordered_float;
|
||||
#[cfg(feature = "onnx")]
|
||||
use crate::pipelines::onnx::ONNXLayerCache;
|
||||
use crate::RustBertError;
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
|
||||
resources::RemoteResource,
|
||||
};
|
||||
|
||||
extern crate ordered_float;
|
||||
|
||||
/// # Configuration for text generation
|
||||
pub struct GenerateConfig {
|
||||
/// Model type used for generation
|
||||
pub model_type: ModelType,
|
||||
/// Model weights resource (default: pretrained GPT2 model)
|
||||
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||
pub model_resource: ModelResource,
|
||||
/// Config resource (default: pretrained GPT2 model)
|
||||
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Vocab resource (default: pretrained GPT2 model)
|
||||
@ -138,7 +142,10 @@ pub struct GenerateConfig {
|
||||
impl Default for GenerateConfig {
|
||||
fn default() -> GenerateConfig {
|
||||
GenerateConfig {
|
||||
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
|
||||
model_type: ModelType::GPT2,
|
||||
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
Gpt2ModelResources::GPT2,
|
||||
))),
|
||||
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
|
||||
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
|
||||
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
|
||||
@ -223,17 +230,19 @@ pub enum Cache {
|
||||
ProphetNetCache(Option<Vec<(Option<ProphetNetLayerState>, Option<ProphetNetLayerState>)>>),
|
||||
GPTNeoCache(Option<Vec<Option<GPTNeoLayerState>>>),
|
||||
GPTJCache(Option<Vec<Option<GPTJLayerState>>>),
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNXCache(ONNXLayerCache),
|
||||
None,
|
||||
}
|
||||
|
||||
pub(crate) mod private_generation_utils {
|
||||
use rust_tokenizers::TokenIdsWithOffsets;
|
||||
use std::cmp::{max, min};
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryFrom;
|
||||
use std::mem;
|
||||
|
||||
use rust_tokenizers::tokenizer::{truncate_sequences, TruncationStrategy};
|
||||
use rust_tokenizers::TokenIdsWithOffsets;
|
||||
use tch::{nn, Device, Kind, Tensor};
|
||||
|
||||
use crate::pipelines::common::TokenizerOption;
|
||||
@ -242,7 +251,7 @@ pub(crate) mod private_generation_utils {
|
||||
};
|
||||
|
||||
use super::ordered_float::OrderedFloat;
|
||||
use crate::common::kind::get_positive_infinity;
|
||||
use crate::common::kind::{get_negative_infinity, get_positive_infinity};
|
||||
use crate::RustBertError;
|
||||
|
||||
pub struct InternalGenerateOptions<'a> {
|
||||
@ -283,17 +292,23 @@ pub(crate) mod private_generation_utils {
|
||||
|
||||
pub trait PrivateLanguageGenerator {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption;
|
||||
fn get_device(&self) -> Device;
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError>;
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption;
|
||||
fn get_var_store(&self) -> &nn::VarStore;
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore;
|
||||
fn get_config(&self) -> &GenerateConfig;
|
||||
fn get_bos_id(&self) -> Option<i64>;
|
||||
fn get_eos_ids(&self) -> Option<&Vec<i64>>;
|
||||
fn get_forced_bos_token_id(&self) -> Option<i64> {
|
||||
None
|
||||
}
|
||||
fn get_forced_eos_token_id(&self) -> Option<i64> {
|
||||
None
|
||||
}
|
||||
fn get_pad_id(&self) -> Option<i64>;
|
||||
fn is_encoder_decoder(&self) -> bool;
|
||||
fn get_vocab_size(&self) -> i64;
|
||||
fn get_decoder_start_id(&self) -> Option<i64>;
|
||||
fn get_max_positions_embeddings(&self) -> i64;
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64>;
|
||||
|
||||
fn forward_t(
|
||||
&self,
|
||||
@ -310,11 +325,32 @@ pub(crate) mod private_generation_utils {
|
||||
|
||||
fn prepare_scores_for_generation(
|
||||
&self,
|
||||
_scores: &mut Tensor,
|
||||
_current_length: i64,
|
||||
_max_length: Option<i64>,
|
||||
_forced_bos_token_id: Option<i64>,
|
||||
scores: &mut Tensor,
|
||||
current_length: i64,
|
||||
max_length: Option<i64>,
|
||||
forced_bos_token_id: Option<i64>,
|
||||
) {
|
||||
if current_length == 1 {
|
||||
if let Some(forced_bos_token_id) =
|
||||
forced_bos_token_id.or(self.get_forced_bos_token_id())
|
||||
{
|
||||
force_token_id_generation(
|
||||
scores,
|
||||
&[forced_bos_token_id],
|
||||
self.get_vocab_size(),
|
||||
);
|
||||
}
|
||||
} else if let Some(max_length) = max_length {
|
||||
if let Some(forced_eos_token_id) = self.get_forced_eos_token_id() {
|
||||
if current_length == max_length - 1 {
|
||||
force_token_id_generation(
|
||||
scores,
|
||||
&[forced_eos_token_id],
|
||||
self.get_vocab_size(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn encode(&self, _input_ids: &Tensor, _attention_mask: Option<&Tensor>) -> Option<Tensor> {
|
||||
@ -347,48 +383,66 @@ pub(crate) mod private_generation_utils {
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().tokenize_list(prompt_text);
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|prompt_tokens| self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens))
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let num_truncated_tokens = token_ids
|
||||
.iter()
|
||||
.map(|token_ids| {
|
||||
let token_ids = if self.is_encoder_decoder() {
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text,
|
||||
max_len
|
||||
.map(|max_len| {
|
||||
if token_ids.len() > max_len as usize {
|
||||
token_ids.len() - max_len as usize
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.unwrap_or(0)
|
||||
})
|
||||
.collect::<Vec<usize>>();
|
||||
.map(|max_len| max_len as usize)
|
||||
.unwrap_or(usize::MAX),
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
tokens
|
||||
.into_iter()
|
||||
.map(|tokenized_input| tokenized_input.token_ids)
|
||||
.collect::<Vec<Vec<i64>>>()
|
||||
} else {
|
||||
// Special tokens (e.g. BOS) are not added at the end of the prompt for causal generation
|
||||
let tokens = self._get_tokenizer().tokenize_list(prompt_text);
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|prompt_tokens| {
|
||||
self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens)
|
||||
})
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let token_ids = token_ids
|
||||
.into_iter()
|
||||
.zip(num_truncated_tokens)
|
||||
.map(|(tokens, num_truncated_tokens)| {
|
||||
truncate_sequences(
|
||||
TokenIdsWithOffsets {
|
||||
ids: tokens,
|
||||
offsets: vec![],
|
||||
reference_offsets: vec![],
|
||||
masks: vec![],
|
||||
},
|
||||
None,
|
||||
num_truncated_tokens,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
)
|
||||
.unwrap()
|
||||
.0
|
||||
.ids
|
||||
})
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
let num_truncated_tokens = token_ids
|
||||
.iter()
|
||||
.map(|token_ids| {
|
||||
max_len
|
||||
.map(|max_len| {
|
||||
if token_ids.len() > max_len as usize {
|
||||
token_ids.len() - max_len as usize
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.unwrap_or(0)
|
||||
})
|
||||
.collect::<Vec<usize>>();
|
||||
|
||||
token_ids
|
||||
.into_iter()
|
||||
.zip(num_truncated_tokens)
|
||||
.map(|(tokens, num_truncated_tokens)| {
|
||||
truncate_sequences(
|
||||
TokenIdsWithOffsets {
|
||||
ids: tokens,
|
||||
offsets: vec![],
|
||||
reference_offsets: vec![],
|
||||
masks: vec![],
|
||||
},
|
||||
None,
|
||||
num_truncated_tokens,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
)
|
||||
.unwrap()
|
||||
.0
|
||||
.ids
|
||||
})
|
||||
.collect::<Vec<Vec<i64>>>()
|
||||
};
|
||||
|
||||
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
|
||||
|
||||
@ -399,13 +453,20 @@ pub(crate) mod private_generation_utils {
|
||||
|
||||
let token_ids = token_ids
|
||||
.into_iter()
|
||||
.map(|input| {
|
||||
.map(|mut input| {
|
||||
let mut temp = vec![pad_token; max_len - input.len()];
|
||||
temp.extend(input);
|
||||
temp
|
||||
if self.is_encoder_decoder() {
|
||||
input.extend(temp);
|
||||
input
|
||||
} else {
|
||||
// Pad left for causal generation
|
||||
temp.extend(input);
|
||||
temp
|
||||
}
|
||||
})
|
||||
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
|
||||
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_device()))
|
||||
.collect::<Vec<Tensor>>();
|
||||
|
||||
Tensor::stack(&token_ids, 0)
|
||||
}
|
||||
|
||||
@ -767,9 +828,9 @@ pub(crate) mod private_generation_utils {
|
||||
output_scores: bool,
|
||||
) -> GeneratedOutputWithScores {
|
||||
let mut unfinished_sentences =
|
||||
Tensor::ones([batch_size], (Kind::Int64, self.get_var_store().device()));
|
||||
Tensor::ones([batch_size], (Kind::Int64, self.get_device()));
|
||||
let mut sentence_lengths: Tensor =
|
||||
Tensor::ones([batch_size], (Kind::Int64, self.get_var_store().device()));
|
||||
Tensor::ones([batch_size], (Kind::Int64, self.get_device()));
|
||||
let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) =
|
||||
self.split_bad_word_ids(gen_opt.bad_word_ids);
|
||||
let mut static_bad_words_mask: Option<Tensor> = None;
|
||||
@ -1024,7 +1085,7 @@ pub(crate) mod private_generation_utils {
|
||||
let vocab_size = self.get_vocab_size();
|
||||
let beam_scores = Tensor::ones(
|
||||
[batch_size, gen_opt.num_beams],
|
||||
(Kind::Float, self.get_var_store().device()),
|
||||
(Kind::Float, self.get_device()),
|
||||
) * -1e9;
|
||||
let _ = beam_scores
|
||||
.slice(1, 0, *beam_scores.size().last().unwrap(), num_sub_beams)
|
||||
@ -1033,11 +1094,11 @@ pub(crate) mod private_generation_utils {
|
||||
let mut beam_scores = beam_scores.view_([-1]);
|
||||
let mut beam_tokens = Tensor::zeros(
|
||||
[batch_size * gen_opt.num_beams],
|
||||
(Kind::Int64, self.get_var_store().device()),
|
||||
(Kind::Int64, self.get_device()),
|
||||
);
|
||||
let mut beam_indices = Tensor::zeros(
|
||||
[batch_size * gen_opt.num_beams],
|
||||
(Kind::Int64, self.get_var_store().device()),
|
||||
(Kind::Int64, self.get_device()),
|
||||
);
|
||||
let mut saved_beam_scores: Option<Vec<Tensor>> =
|
||||
if output_scores { Some(vec![]) } else { None };
|
||||
@ -1524,6 +1585,18 @@ pub(crate) mod private_generation_utils {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn force_token_id_generation(scores: &mut Tensor, token_ids: &[i64], vocab_size: i64) {
|
||||
let impossible_tokens: Vec<i64> = (0..vocab_size)
|
||||
.filter(|pos| !token_ids.contains(pos))
|
||||
.collect();
|
||||
let impossible_tokens = Tensor::from_slice(&impossible_tokens).to_device(scores.device());
|
||||
let _ = scores.index_fill_(
|
||||
1,
|
||||
&impossible_tokens,
|
||||
get_negative_infinity(scores.kind()).unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -1805,7 +1878,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
|
||||
generate_options.max_length
|
||||
});
|
||||
let encoding_max_len = if self.is_encoder_decoder() {
|
||||
Some(self.get_max_positions_embeddings())
|
||||
self.get_max_positions_embeddings()
|
||||
} else {
|
||||
max_length
|
||||
};
|
||||
@ -1819,9 +1892,7 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
|
||||
self.encode_prompt_text(prompts, encoding_max_len, pad_token_id)
|
||||
}
|
||||
None => match self.get_bos_id() {
|
||||
Some(bos_id) => {
|
||||
Tensor::ones([1, 1], (Int64, self.get_var_store().device())) * bos_id
|
||||
}
|
||||
Some(bos_id) => Tensor::ones([1, 1], (Int64, self.get_device())) * bos_id,
|
||||
None => panic!(
|
||||
"A model with a BOS token must be used to start generation with an empty input"
|
||||
),
|
||||
@ -2155,16 +2226,19 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
|
||||
self._get_tokenizer_mut()
|
||||
}
|
||||
|
||||
fn half(&mut self) {
|
||||
self.get_var_store_mut().half();
|
||||
fn half(&mut self) -> Result<(), RustBertError> {
|
||||
self.get_var_store_mut()?.half();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn float(&mut self) {
|
||||
self.get_var_store_mut().float();
|
||||
fn float(&mut self) -> Result<(), RustBertError> {
|
||||
self.get_var_store_mut()?.float();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_device(&mut self, device: Device) {
|
||||
self.get_var_store_mut().set_device(device);
|
||||
fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
|
||||
self.get_var_store_mut()?.set_device(device);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -22,9 +22,10 @@
|
||||
//!use rust_bert::resources::RemoteResource;
|
||||
//! fn main() -> anyhow::Result<()> {
|
||||
//!
|
||||
//! let config = MaskedLanguageConfig::new(
|
||||
//! use rust_bert::pipelines::common::ModelResource;
|
||||
//! let config = MaskedLanguageConfig::new(
|
||||
//! ModelType::Bert,
|
||||
//! RemoteResource::from_pretrained(BertModelResources::BERT),
|
||||
//! ModelResource::Torch(Box::new(RemoteResource::from_pretrained(BertModelResources::BERT))),
|
||||
//! RemoteResource::from_pretrained(BertConfigResources::BERT),
|
||||
//! RemoteResource::from_pretrained(BertVocabResources::BERT),
|
||||
//! None,
|
||||
@ -50,20 +51,23 @@ use crate::common::error::RustBertError;
|
||||
use crate::deberta::DebertaForMaskedLM;
|
||||
use crate::deberta_v2::DebertaV2ForMaskedLM;
|
||||
use crate::fnet::FNetForMaskedLM;
|
||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||
use crate::pipelines::common::{
|
||||
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
};
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::roberta::RobertaForMaskedLM;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
bert::{BertConfigResources, BertModelResources, BertVocabResources},
|
||||
resources::RemoteResource,
|
||||
};
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
use rust_tokenizers::TokenizedInput;
|
||||
use std::borrow::Borrow;
|
||||
use std::convert::TryFrom;
|
||||
use tch::nn::VarStore;
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
use tch::{no_grad, Device, Tensor};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Output container for masked language model pipeline.
|
||||
@ -82,7 +86,7 @@ pub struct MaskedLanguageConfig {
|
||||
/// Model type
|
||||
pub model_type: ModelType,
|
||||
/// Model weights resource (default: pretrained BERT model on CoNLL)
|
||||
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||
pub model_resource: ModelResource,
|
||||
/// Config resource (default: pretrained BERT model on CoNLL)
|
||||
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Vocab resource (default: pretrained BERT model on CoNLL)
|
||||
@ -113,9 +117,9 @@ impl MaskedLanguageConfig {
|
||||
/// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||
/// * mask_token - A token used for model to predict masking words..
|
||||
pub fn new<RM, RC, RV>(
|
||||
pub fn new<RC, RV>(
|
||||
model_type: ModelType,
|
||||
model_resource: RM,
|
||||
model_resource: ModelResource,
|
||||
config_resource: RC,
|
||||
vocab_resource: RV,
|
||||
merges_resource: Option<RV>,
|
||||
@ -125,13 +129,12 @@ impl MaskedLanguageConfig {
|
||||
mask_token: impl Into<Option<String>>,
|
||||
) -> MaskedLanguageConfig
|
||||
where
|
||||
RM: ResourceProvider + Send + 'static,
|
||||
RC: ResourceProvider + Send + 'static,
|
||||
RV: ResourceProvider + Send + 'static,
|
||||
{
|
||||
MaskedLanguageConfig {
|
||||
model_type,
|
||||
model_resource: Box::new(model_resource),
|
||||
model_resource,
|
||||
config_resource: Box::new(config_resource),
|
||||
vocab_resource: Box::new(vocab_resource),
|
||||
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||
@ -149,7 +152,9 @@ impl Default for MaskedLanguageConfig {
|
||||
fn default() -> MaskedLanguageConfig {
|
||||
MaskedLanguageConfig::new(
|
||||
ModelType::Bert,
|
||||
RemoteResource::from_pretrained(BertModelResources::BERT),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT,
|
||||
))),
|
||||
RemoteResource::from_pretrained(BertConfigResources::BERT),
|
||||
RemoteResource::from_pretrained(BertVocabResources::BERT),
|
||||
None,
|
||||
@ -176,28 +181,39 @@ pub enum MaskedLanguageOption {
|
||||
XLMRoberta(RobertaForMaskedLM),
|
||||
/// FNet for Masked Language
|
||||
FNet(FNetForMaskedLM),
|
||||
/// ONNX model for Masked Language
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNX(ONNXEncoder),
|
||||
}
|
||||
impl MaskedLanguageOption {
|
||||
/// Instantiate a new masked language model of the supplied type.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
|
||||
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
||||
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
||||
/// `model_type`)
|
||||
pub fn new<'p, P>(
|
||||
model_type: ModelType,
|
||||
p: P,
|
||||
config: &ConfigOption,
|
||||
) -> Result<Self, RustBertError>
|
||||
where
|
||||
P: Borrow<nn::Path<'p>>,
|
||||
{
|
||||
match model_type {
|
||||
/// * `MaskedLanguageConfig` - Masked language model pipeline configuration. The type of model created will be inferred from the
|
||||
/// `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
|
||||
pub fn new(config: &MaskedLanguageConfig) -> Result<Self, RustBertError> {
|
||||
match config.model_resource {
|
||||
ModelResource::Torch(_) => Self::new_torch(config),
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelResource::ONNX(_) => Self::new_onnx(config),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_torch(config: &MaskedLanguageConfig) -> Result<Self, RustBertError> {
|
||||
let device = config.device;
|
||||
let weights_path = config.model_resource.get_torch_local_path()?;
|
||||
let mut var_store = VarStore::new(device);
|
||||
let model_config =
|
||||
&ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
|
||||
let model_type = config.model_type;
|
||||
let model = match model_type {
|
||||
ModelType::Bert => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
Ok(MaskedLanguageOption::Bert(BertForMaskedLM::new(p, config)))
|
||||
if let ConfigOption::Bert(config) = model_config {
|
||||
Ok(MaskedLanguageOption::Bert(BertForMaskedLM::new(
|
||||
var_store.root(),
|
||||
config,
|
||||
)))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
"You can only supply a BertConfig for Bert!".to_string(),
|
||||
@ -205,9 +221,10 @@ impl MaskedLanguageOption {
|
||||
}
|
||||
}
|
||||
ModelType::Deberta => {
|
||||
if let ConfigOption::Deberta(config) = config {
|
||||
if let ConfigOption::Deberta(config) = model_config {
|
||||
Ok(MaskedLanguageOption::Deberta(DebertaForMaskedLM::new(
|
||||
p, config,
|
||||
var_store.root(),
|
||||
config,
|
||||
)))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -216,9 +233,10 @@ impl MaskedLanguageOption {
|
||||
}
|
||||
}
|
||||
ModelType::DebertaV2 => {
|
||||
if let ConfigOption::DebertaV2(config) = config {
|
||||
if let ConfigOption::DebertaV2(config) = model_config {
|
||||
Ok(MaskedLanguageOption::DebertaV2(DebertaV2ForMaskedLM::new(
|
||||
p, config,
|
||||
var_store.root(),
|
||||
config,
|
||||
)))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -227,9 +245,10 @@ impl MaskedLanguageOption {
|
||||
}
|
||||
}
|
||||
ModelType::Roberta => {
|
||||
if let ConfigOption::Roberta(config) = config {
|
||||
if let ConfigOption::Roberta(config) = model_config {
|
||||
Ok(MaskedLanguageOption::Roberta(RobertaForMaskedLM::new(
|
||||
p, config,
|
||||
var_store.root(),
|
||||
config,
|
||||
)))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -238,9 +257,10 @@ impl MaskedLanguageOption {
|
||||
}
|
||||
}
|
||||
ModelType::XLMRoberta => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
if let ConfigOption::Bert(config) = model_config {
|
||||
Ok(MaskedLanguageOption::XLMRoberta(RobertaForMaskedLM::new(
|
||||
p, config,
|
||||
var_store.root(),
|
||||
config,
|
||||
)))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -249,8 +269,11 @@ impl MaskedLanguageOption {
|
||||
}
|
||||
}
|
||||
ModelType::FNet => {
|
||||
if let ConfigOption::FNet(config) = config {
|
||||
Ok(MaskedLanguageOption::FNet(FNetForMaskedLM::new(p, config)))
|
||||
if let ConfigOption::FNet(config) = model_config {
|
||||
Ok(MaskedLanguageOption::FNet(FNetForMaskedLM::new(
|
||||
var_store.root(),
|
||||
config,
|
||||
)))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
"You can only supply a FNetConfig for FNet!".to_string(),
|
||||
@ -260,9 +283,29 @@ impl MaskedLanguageOption {
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Masked Language is not implemented for {model_type:?}!",
|
||||
))),
|
||||
}
|
||||
}?;
|
||||
var_store.load(weights_path)?;
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn new_onnx(config: &MaskedLanguageConfig) -> Result<Self, RustBertError> {
|
||||
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
|
||||
let environment = onnx_config.get_environment()?;
|
||||
let encoder_file = config
|
||||
.model_resource
|
||||
.get_onnx_local_paths()?
|
||||
.encoder_path
|
||||
.ok_or(RustBertError::InvalidConfigurationError(
|
||||
"An encoder file must be provided for masked language ONNX models.".to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self::ONNX(ONNXEncoder::new(
|
||||
encoder_file,
|
||||
&environment,
|
||||
&onnx_config,
|
||||
)?))
|
||||
}
|
||||
/// Returns the `ModelType` for this MaskedLanguageOption
|
||||
pub fn model_type(&self) -> ModelType {
|
||||
match *self {
|
||||
@ -272,6 +315,8 @@ impl MaskedLanguageOption {
|
||||
Self::Roberta(_) => ModelType::Roberta,
|
||||
Self::XLMRoberta(_) => ModelType::Roberta,
|
||||
Self::FNet(_) => ModelType::FNet,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(_) => ModelType::ONNX,
|
||||
}
|
||||
}
|
||||
|
||||
@ -350,6 +395,21 @@ impl MaskedLanguageOption {
|
||||
.expect("Error in FNet forward pass.")
|
||||
.prediction_scores
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(ref model) => {
|
||||
let attention_mask = input_ids.unwrap().ones_like();
|
||||
model
|
||||
.forward(
|
||||
input_ids,
|
||||
Some(&attention_mask),
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
)
|
||||
.expect("Error in ONNX forward pass.")
|
||||
.logits
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -359,7 +419,7 @@ pub struct MaskedLanguageModel {
|
||||
tokenizer: TokenizerOption,
|
||||
language_encode: MaskedLanguageOption,
|
||||
mask_token: Option<String>,
|
||||
var_store: VarStore,
|
||||
device: Device,
|
||||
max_length: usize,
|
||||
}
|
||||
|
||||
@ -428,25 +488,21 @@ impl MaskedLanguageModel {
|
||||
config: MaskedLanguageConfig,
|
||||
tokenizer: TokenizerOption,
|
||||
) -> Result<MaskedLanguageModel, RustBertError> {
|
||||
let language_encode = MaskedLanguageOption::new(&config)?;
|
||||
let config_path = config.config_resource.get_local_path()?;
|
||||
let device = config.device;
|
||||
|
||||
let mut var_store = VarStore::new(device);
|
||||
let model_config = ConfigOption::from_file(config.model_type, config_path);
|
||||
let max_length = model_config
|
||||
.get_max_len()
|
||||
.map(|v| v as usize)
|
||||
.unwrap_or(usize::MAX);
|
||||
|
||||
let language_encode =
|
||||
MaskedLanguageOption::new(config.model_type, var_store.root(), &model_config)?;
|
||||
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
|
||||
let mask_token = config.mask_token;
|
||||
let device = get_device(config.model_resource, config.device);
|
||||
Ok(MaskedLanguageModel {
|
||||
tokenizer,
|
||||
language_encode,
|
||||
mask_token,
|
||||
var_store,
|
||||
device,
|
||||
max_length,
|
||||
})
|
||||
}
|
||||
@ -481,33 +537,6 @@ impl MaskedLanguageModel {
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn prepare_for_model<'a, S>(&self, input: S) -> Tensor
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(
|
||||
input.as_ref(),
|
||||
self.max_length,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input_tensors = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::from_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
|
||||
}
|
||||
|
||||
/// Mask texts
|
||||
///
|
||||
/// # Arguments
|
||||
@ -544,30 +573,22 @@ impl MaskedLanguageModel {
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let input_tensor = if let Some(mask_token) = &self.mask_token {
|
||||
let (input_ids, token_type_ids) = if let Some(mask_token) = &self.mask_token {
|
||||
let input_with_replaced_mask = self.replace_mask_token(input.as_ref(), mask_token)?;
|
||||
self.prepare_for_model(
|
||||
self.tokenizer.tokenize_and_pad(
|
||||
input_with_replaced_mask
|
||||
.iter()
|
||||
.map(|w| w.as_str())
|
||||
.collect::<Vec<&str>>(),
|
||||
.collect::<Vec<&str>>()
|
||||
.as_slice(),
|
||||
self.max_length,
|
||||
self.device,
|
||||
)
|
||||
} else {
|
||||
self.prepare_for_model(input.as_ref())
|
||||
self.tokenizer
|
||||
.tokenize_and_pad(input.as_ref(), self.max_length, self.device)
|
||||
};
|
||||
|
||||
let output = no_grad(|| {
|
||||
self.language_encode.forward_t(
|
||||
Some(&input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
});
|
||||
// get the position of mask_token in input texts
|
||||
let mask_token_id =
|
||||
self.tokenizer
|
||||
@ -575,7 +596,21 @@ impl MaskedLanguageModel {
|
||||
.ok_or_else(|| RustBertError::InvalidConfigurationError(
|
||||
"Tokenizer does not have a mask token id, Please use a tokenizer/model with a mask token.".into(),
|
||||
))?;
|
||||
let mask_token_mask = input_tensor.eq(mask_token_id);
|
||||
let mask_token_mask = input_ids.eq(mask_token_id);
|
||||
|
||||
let output = no_grad(|| {
|
||||
self.language_encode.forward_t(
|
||||
Some(&input_ids),
|
||||
None,
|
||||
Some(&token_type_ids),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
});
|
||||
|
||||
let mut output_tokens = Vec::with_capacity(input.as_ref().len());
|
||||
for input_id in 0..input.as_ref().len() as i64 {
|
||||
let mut sequence_tokens = vec![];
|
||||
|
@ -490,3 +490,6 @@ pub mod text_generation;
|
||||
pub mod token_classification;
|
||||
pub mod translation;
|
||||
pub mod zero_shot_classification;
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
pub mod onnx;
|
||||
|
@ -90,11 +90,12 @@
|
||||
//! use tch::Device;
|
||||
//!
|
||||
//! # fn main() -> anyhow::Result<()> {
|
||||
//! use rust_bert::pipelines::common::ModelResource;
|
||||
//! let ner_config = TokenClassificationConfig {
|
||||
//! model_type: ModelType::XLMRoberta,
|
||||
//! model_resource: Box::new(RemoteResource::from_pretrained(
|
||||
//! model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
//! RobertaModelResources::XLM_ROBERTA_NER_DE,
|
||||
//! )),
|
||||
//! ))),
|
||||
//! config_resource: Box::new(RemoteResource::from_pretrained(
|
||||
//! RobertaConfigResources::XLM_ROBERTA_NER_DE,
|
||||
//! )),
|
||||
|
44
src/pipelines/onnx/common.rs
Normal file
44
src/pipelines/onnx/common.rs
Normal file
@ -0,0 +1,44 @@
|
||||
use ort::Session;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct InputOutputNameMapping {
|
||||
pub(crate) input_names: Vec<String>,
|
||||
pub(crate) output_names: HashMap<String, usize>,
|
||||
pub(crate) key_value_output_names: HashMap<String, usize>,
|
||||
}
|
||||
|
||||
pub(crate) fn get_input_output_mapping(session: &Session) -> InputOutputNameMapping {
|
||||
let input_names = session
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|input| input.name.clone())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let output_names = session
|
||||
.outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(pos, output)| (output.name.clone(), pos))
|
||||
.collect::<HashMap<String, usize>>();
|
||||
|
||||
let mut key_value_output_names = output_names
|
||||
.iter()
|
||||
.filter(|(name, _)| name.contains(".key") | name.contains(".value"))
|
||||
.map(|(name, pos)| (name.clone(), *pos))
|
||||
.collect::<HashMap<String, usize>>();
|
||||
|
||||
if key_value_output_names.is_empty() {
|
||||
key_value_output_names = output_names
|
||||
.iter()
|
||||
.filter(|(name, _)| name.contains("key_value"))
|
||||
.map(|(name, pos)| (name.clone(), *pos))
|
||||
.collect::<HashMap<String, usize>>();
|
||||
}
|
||||
|
||||
InputOutputNameMapping {
|
||||
input_names,
|
||||
output_names,
|
||||
key_value_output_names,
|
||||
}
|
||||
}
|
105
src/pipelines/onnx/config.rs
Normal file
105
src/pipelines/onnx/config.rs
Normal file
@ -0,0 +1,105 @@
|
||||
/// # Configuration for ONNX environment and sessions
|
||||
use crate::RustBertError;
|
||||
use ort::{
|
||||
AllocatorType, Environment, ExecutionProvider, GraphOptimizationLevel, MemType, SessionBuilder,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tch::Device;
|
||||
|
||||
pub(crate) static INPUT_IDS_NAME: &str = "input_ids";
|
||||
pub(crate) static ATTENTION_MASK_NAME: &str = "attention_mask";
|
||||
pub(crate) static ENCODER_HIDDEN_STATES_NAME: &str = "encoder_hidden_states";
|
||||
pub(crate) static ENCODER_ATTENTION_MASK_NAME: &str = "encoder_attention_mask";
|
||||
pub(crate) static TOKEN_TYPE_IDS: &str = "token_type_ids";
|
||||
pub(crate) static POSITION_IDS: &str = "position_ids";
|
||||
pub(crate) static INPUT_EMBEDS: &str = "input_embeds";
|
||||
pub(crate) static LAST_HIDDEN_STATE: &str = "last_hidden_state";
|
||||
pub(crate) static LOGITS: &str = "logits";
|
||||
pub(crate) static START_LOGITS: &str = "start_logits";
|
||||
pub(crate) static END_LOGITS: &str = "end_logits";
|
||||
|
||||
#[derive(Default)]
|
||||
/// # ONNX Environment configuration
|
||||
/// See <https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions>
|
||||
pub struct ONNXEnvironmentConfig {
|
||||
pub optimization_level: Option<GraphOptimizationLevel>,
|
||||
pub execution_providers: Option<Vec<ExecutionProvider>>,
|
||||
pub num_intra_threads: Option<i16>,
|
||||
pub num_inter_threads: Option<i16>,
|
||||
pub parallel_execution: Option<bool>,
|
||||
pub enable_memory_pattern: Option<bool>,
|
||||
pub allocator: Option<AllocatorType>,
|
||||
pub memory_type: Option<MemType>,
|
||||
}
|
||||
|
||||
impl ONNXEnvironmentConfig {
|
||||
/// Create a new `ONNXEnvironmentConfig` from a `tch::Device`.
|
||||
/// This helper function maps torch device to ONNXRuntime execution providers
|
||||
pub fn from_device(device: Device) -> Self {
|
||||
let mut execution_providers = Vec::new();
|
||||
if let Device::Cuda(_) = device {
|
||||
execution_providers.push(ExecutionProvider::cuda());
|
||||
};
|
||||
execution_providers.push(ExecutionProvider::cpu());
|
||||
ONNXEnvironmentConfig {
|
||||
execution_providers: Some(execution_providers),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
///Build a session builder from an `ONNXEnvironmentConfig`.
|
||||
pub fn get_session_builder(
|
||||
&self,
|
||||
environment: &Arc<Environment>,
|
||||
) -> Result<SessionBuilder, RustBertError> {
|
||||
let mut session_builder = SessionBuilder::new(environment)?;
|
||||
match &self.optimization_level {
|
||||
Some(GraphOptimizationLevel::Level3) | None => {}
|
||||
Some(GraphOptimizationLevel::Level2) => {
|
||||
session_builder =
|
||||
session_builder.with_optimization_level(GraphOptimizationLevel::Level2)?
|
||||
}
|
||||
Some(GraphOptimizationLevel::Level1) => {
|
||||
session_builder =
|
||||
session_builder.with_optimization_level(GraphOptimizationLevel::Level1)?
|
||||
}
|
||||
Some(GraphOptimizationLevel::Disable) => {
|
||||
session_builder =
|
||||
session_builder.with_optimization_level(GraphOptimizationLevel::Disable)?
|
||||
}
|
||||
}
|
||||
if let Some(num_intra_threads) = self.num_intra_threads {
|
||||
session_builder = session_builder.with_intra_threads(num_intra_threads)?;
|
||||
}
|
||||
if let Some(num_inter_threads) = self.num_inter_threads {
|
||||
session_builder = session_builder.with_inter_threads(num_inter_threads)?;
|
||||
}
|
||||
if let Some(parallel_execution) = self.parallel_execution {
|
||||
session_builder = session_builder.with_parallel_execution(parallel_execution)?;
|
||||
}
|
||||
if let Some(enable_memory_pattern) = self.enable_memory_pattern {
|
||||
session_builder = session_builder.with_memory_pattern(enable_memory_pattern)?;
|
||||
}
|
||||
if let Some(allocator) = &self.allocator {
|
||||
session_builder = session_builder.with_allocator(allocator.clone())?;
|
||||
}
|
||||
if let Some(memory_type) = &self.memory_type {
|
||||
session_builder = session_builder.with_memory_type(memory_type.clone())?;
|
||||
}
|
||||
Ok(session_builder)
|
||||
}
|
||||
|
||||
///Build an ONNXEnvironment from an `ONNXEnvironmentConfig`.
|
||||
pub fn get_environment(&self) -> Result<Arc<Environment>, RustBertError> {
|
||||
Ok(Arc::new(
|
||||
Environment::builder()
|
||||
.with_name("Default environment")
|
||||
.with_execution_providers(
|
||||
self.execution_providers
|
||||
.clone()
|
||||
.unwrap_or(vec![ExecutionProvider::cpu()]),
|
||||
)
|
||||
.build()?,
|
||||
))
|
||||
}
|
||||
}
|
57
src/pipelines/onnx/conversion.rs
Normal file
57
src/pipelines/onnx/conversion.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use crate::RustBertError;
|
||||
use ndarray::IxDyn;
|
||||
use ort::tensor::{DynOrtTensor, FromArray, InputTensor};
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use tch::{Kind, Tensor};
|
||||
|
||||
pub(crate) fn ort_tensor_to_tch(ort_tensor: &DynOrtTensor<IxDyn>) -> Result<Tensor, RustBertError> {
|
||||
let ort_tensor = ort_tensor.try_extract::<f32>()?.view().to_owned();
|
||||
Ok(Tensor::try_from(ort_tensor)?)
|
||||
}
|
||||
|
||||
pub(crate) fn tch_tensor_to_ort(tch_tensor: &Tensor) -> Result<InputTensor, RustBertError> {
|
||||
let kind = tch_tensor.kind();
|
||||
Ok(match kind{
|
||||
Kind::Int64 => {
|
||||
let array: ndarray::ArrayD<i64> = tch_tensor.try_into()?;
|
||||
InputTensor::from_array(array)
|
||||
}
|
||||
Kind::Float => {
|
||||
let array: ndarray::ArrayD<f32> = tch_tensor.try_into()?;
|
||||
InputTensor::from_array(array)
|
||||
}
|
||||
Kind::Int => {
|
||||
let array: ndarray::ArrayD<i32> = tch_tensor.try_into()?;
|
||||
InputTensor::from_array(array)
|
||||
}
|
||||
Kind::Double => {
|
||||
let array: ndarray::ArrayD<f64> = tch_tensor.try_into()?;
|
||||
InputTensor::from_array(array)
|
||||
}
|
||||
Kind::Half => {
|
||||
let array: ndarray::ArrayD<half::f16> = tch_tensor.try_into()?;
|
||||
InputTensor::from_array(array)
|
||||
}
|
||||
Kind::Int16 => {
|
||||
let array: ndarray::ArrayD<i16> = tch_tensor.try_into()?;
|
||||
InputTensor::from_array(array)
|
||||
}
|
||||
Kind::Int8 => {
|
||||
let array: ndarray::ArrayD<i8> = tch_tensor.try_into()?;
|
||||
InputTensor::from_array(array)
|
||||
}
|
||||
Kind::Uint8 => {
|
||||
let array: ndarray::ArrayD<u8> = tch_tensor.try_into()?;
|
||||
InputTensor::from_array(array)
|
||||
}
|
||||
Kind::BFloat16 => {
|
||||
let array: ndarray::ArrayD<half::bf16> = tch_tensor.try_into()?;
|
||||
InputTensor::from_array(array)
|
||||
}
|
||||
_ => {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Type not supported: attempted to get convert torch tensor to ndarray infinity for {kind:?}",
|
||||
)))
|
||||
}
|
||||
})
|
||||
}
|
113
src/pipelines/onnx/decoder.rs
Normal file
113
src/pipelines/onnx/decoder.rs
Normal file
@ -0,0 +1,113 @@
|
||||
use crate::pipelines::generation_utils::{Cache, LMModelOutput};
|
||||
use crate::pipelines::onnx::common::{get_input_output_mapping, InputOutputNameMapping};
|
||||
use crate::pipelines::onnx::config::{
|
||||
ONNXEnvironmentConfig, ATTENTION_MASK_NAME, ENCODER_ATTENTION_MASK_NAME,
|
||||
ENCODER_HIDDEN_STATES_NAME, INPUT_IDS_NAME, POSITION_IDS,
|
||||
};
|
||||
use crate::pipelines::onnx::conversion::{ort_tensor_to_tch, tch_tensor_to_ort};
|
||||
use crate::pipelines::onnx::models::ONNXLayerCache;
|
||||
use crate::RustBertError;
|
||||
use ort::{Environment, Session};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tch::Tensor;
|
||||
|
||||
pub struct ONNXDecoder {
|
||||
session: Session,
|
||||
name_mapping: InputOutputNameMapping,
|
||||
use_cache: bool,
|
||||
}
|
||||
|
||||
impl ONNXDecoder {
|
||||
pub fn new(
|
||||
model_file: PathBuf,
|
||||
use_cache: bool,
|
||||
environment: &Arc<Environment>,
|
||||
onnx_config: &ONNXEnvironmentConfig,
|
||||
) -> Result<Self, RustBertError> {
|
||||
let session = onnx_config
|
||||
.get_session_builder(environment)?
|
||||
.with_model_from_file(model_file)?;
|
||||
let name_mapping = get_input_output_mapping(&session);
|
||||
Ok(Self {
|
||||
session,
|
||||
name_mapping,
|
||||
use_cache,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
input_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
encoder_attention_mask: Option<&Tensor>,
|
||||
position_ids: Option<&Tensor>,
|
||||
layer_states: Option<&ONNXLayerCache>,
|
||||
) -> Result<LMModelOutput, RustBertError> {
|
||||
let mut input_dict = HashMap::new();
|
||||
if let Some(input_ids) = input_ids {
|
||||
input_dict.insert(INPUT_IDS_NAME, input_ids);
|
||||
}
|
||||
if let Some(attention_mask) = attention_mask {
|
||||
input_dict.insert(ATTENTION_MASK_NAME, attention_mask);
|
||||
}
|
||||
if let Some(encoder_hidden_states) = encoder_hidden_states {
|
||||
input_dict.insert(ENCODER_HIDDEN_STATES_NAME, encoder_hidden_states);
|
||||
}
|
||||
if let Some(encoder_attention_mask) = encoder_attention_mask {
|
||||
input_dict.insert(ENCODER_ATTENTION_MASK_NAME, encoder_attention_mask);
|
||||
}
|
||||
if let Some(position_ids) = position_ids {
|
||||
input_dict.insert(POSITION_IDS, position_ids);
|
||||
}
|
||||
|
||||
let inputs = self
|
||||
.name_mapping
|
||||
.input_names
|
||||
.iter()
|
||||
.map(|input_name| {
|
||||
if let Some(tensor) = input_dict.remove(input_name.as_str()) {
|
||||
Ok(tch_tensor_to_ort(tensor)?)
|
||||
} else {
|
||||
let layer_states = layer_states.ok_or_else(|| {
|
||||
RustBertError::OrtError(format!(
|
||||
"{input_name} not found and cache was not provided."
|
||||
))
|
||||
})?;
|
||||
let input_pos = layer_states
|
||||
.values
|
||||
.get(&input_name.replace("past", "present"))
|
||||
.or_else(|| {
|
||||
layer_states
|
||||
.values
|
||||
.get(&input_name.replace("past_key_values", "present"))
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
let found_keys = layer_states.values.keys().collect::<Vec<&String>>();
|
||||
RustBertError::OrtError(format!(
|
||||
"{input_name} not found in cache ({found_keys:?})."
|
||||
))
|
||||
})?;
|
||||
tch_tensor_to_ort(input_pos)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, RustBertError>>()?;
|
||||
|
||||
let outputs = self.session.run(inputs)?;
|
||||
|
||||
let lm_logits =
|
||||
ort_tensor_to_tch(&outputs[*self.name_mapping.output_names.get("logits").unwrap()])?;
|
||||
let cache = if self.use_cache {
|
||||
Cache::ONNXCache(ONNXLayerCache::from_ort_output(
|
||||
&outputs,
|
||||
&self.name_mapping.key_value_output_names,
|
||||
)?)
|
||||
} else {
|
||||
Cache::None
|
||||
};
|
||||
|
||||
Ok(LMModelOutput { lm_logits, cache })
|
||||
}
|
||||
}
|
228
src/pipelines/onnx/encoder.rs
Normal file
228
src/pipelines/onnx/encoder.rs
Normal file
@ -0,0 +1,228 @@
|
||||
use crate::pipelines::onnx::common::{get_input_output_mapping, InputOutputNameMapping};
|
||||
use crate::pipelines::onnx::config::{
|
||||
ONNXEnvironmentConfig, ATTENTION_MASK_NAME, END_LOGITS, INPUT_EMBEDS, INPUT_IDS_NAME,
|
||||
LAST_HIDDEN_STATE, LOGITS, POSITION_IDS, START_LOGITS, TOKEN_TYPE_IDS,
|
||||
};
|
||||
use crate::pipelines::onnx::conversion::{ort_tensor_to_tch, tch_tensor_to_ort};
|
||||
use crate::RustBertError;
|
||||
use ort::{Environment, Session};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tch::Tensor;
|
||||
|
||||
/// # ONNX Encoder model
|
||||
/// Container for an ONNX encoder model and the corresponding session. Can be used individually for
|
||||
/// pure-encoder models (e.g. BERT) or as part of encoder/decoder architectures.
|
||||
pub struct ONNXEncoder {
|
||||
session: Session,
|
||||
name_mapping: InputOutputNameMapping,
|
||||
}
|
||||
|
||||
impl ONNXEncoder {
|
||||
/// Create a new `ONNXEncoder`. Requires a pointer to the model file for
|
||||
/// the encoder, a reference to an environment and an ONNX environment configuration.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use ort::Environment;
|
||||
/// use rust_bert::pipelines::onnx::config::ONNXEnvironmentConfig;
|
||||
/// use rust_bert::pipelines::onnx::ONNXEncoder;
|
||||
/// use std::path::PathBuf;
|
||||
/// use std::sync::Arc;
|
||||
/// let environment = Arc::new(Environment::default());
|
||||
/// let onnx_config = ONNXEnvironmentConfig::default();
|
||||
/// let model_file = PathBuf::from("path/to/model.onnx");
|
||||
///
|
||||
/// let encoder = ONNXEncoder::new(model_file, &environment, &onnx_config).unwrap();
|
||||
/// ```
|
||||
pub fn new(
|
||||
model_file: PathBuf,
|
||||
environment: &Arc<Environment>,
|
||||
onnx_config: &ONNXEnvironmentConfig,
|
||||
) -> Result<Self, RustBertError> {
|
||||
let session = onnx_config
|
||||
.get_session_builder(environment)?
|
||||
.with_model_from_file(model_file)?;
|
||||
let name_mapping = get_input_output_mapping(&session);
|
||||
Ok(Self {
|
||||
session,
|
||||
name_mapping,
|
||||
})
|
||||
}
|
||||
|
||||
/// Forward pass through the model.
|
||||
///
|
||||
/// The outputs provided by the model depend on the underlying ONNX model and are all marked as optional to support a broad range of
|
||||
/// encoder stacks for multiple stacks. The end-user should extract the required output that is provided by the model exported.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
|
||||
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
|
||||
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *SEP*) and 1 for the second sentence. If None set to 0.
|
||||
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
|
||||
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `ONNXEncoderModelOutput` containing:
|
||||
/// - `last_hidden_state` - Optional `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
|
||||
/// - `logits` - Optional `Tensor` of shape (*batch size*, *num_labels*)
|
||||
/// - `start_logits` - Optional `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
|
||||
/// - `end_logits` - Optional `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
|
||||
/// - `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
|
||||
/// - `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use rust_bert::bert::{BertModel, BertConfig, BertEmbeddings};
|
||||
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
|
||||
/// # use rust_bert::Config;
|
||||
/// # use std::path::Path;
|
||||
/// # let config_path = Path::new("path/to/config.json");
|
||||
/// # let device = Device::Cpu;
|
||||
/// # let vs = nn::VarStore::new(device);
|
||||
/// # let config = BertConfig::from_file(config_path);
|
||||
/// # let bert_model: BertModel<BertEmbeddings> = BertModel::new(&vs.root(), &config);
|
||||
/// let (batch_size, sequence_length) = (64, 128);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let model_output = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(
|
||||
/// Some(&input_tensor),
|
||||
/// Some(&mask),
|
||||
/// Some(&token_type_ids),
|
||||
/// Some(&position_ids),
|
||||
/// None,
|
||||
/// None,
|
||||
/// None,
|
||||
/// false,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// });
|
||||
/// ```
|
||||
pub fn forward(
|
||||
&self,
|
||||
input_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
position_ids: Option<&Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
) -> Result<ONNXEncoderModelOutput, RustBertError> {
|
||||
let mut input_dict = HashMap::new();
|
||||
if let Some(input_ids) = input_ids {
|
||||
input_dict.insert(INPUT_IDS_NAME, input_ids);
|
||||
}
|
||||
if let Some(attention_mask) = attention_mask {
|
||||
input_dict.insert(ATTENTION_MASK_NAME, attention_mask);
|
||||
}
|
||||
if let Some(token_type_ids) = token_type_ids {
|
||||
input_dict.insert(TOKEN_TYPE_IDS, token_type_ids);
|
||||
}
|
||||
if let Some(position_ids) = position_ids {
|
||||
input_dict.insert(POSITION_IDS, position_ids);
|
||||
}
|
||||
if let Some(input_embeds) = input_embeds {
|
||||
input_dict.insert(INPUT_EMBEDS, input_embeds);
|
||||
}
|
||||
|
||||
let inputs = self
|
||||
.name_mapping
|
||||
.input_names
|
||||
.iter()
|
||||
.map(|input_name| {
|
||||
if let Some(tensor) = input_dict.remove(input_name.as_str()) {
|
||||
tch_tensor_to_ort(tensor)
|
||||
} else {
|
||||
Err(RustBertError::OrtError(format!(
|
||||
"{input_name} not found but expected by model."
|
||||
)))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, RustBertError>>()?;
|
||||
|
||||
let outputs = self.session.run(inputs)?;
|
||||
|
||||
let last_hidden_state = self
|
||||
.name_mapping
|
||||
.output_names
|
||||
.get(LAST_HIDDEN_STATE)
|
||||
.map(|pos| ort_tensor_to_tch(&outputs[*pos]))
|
||||
.transpose()?;
|
||||
let logits = self
|
||||
.name_mapping
|
||||
.output_names
|
||||
.get(LOGITS)
|
||||
.map(|pos| ort_tensor_to_tch(&outputs[*pos]))
|
||||
.transpose()?;
|
||||
let start_logits = self
|
||||
.name_mapping
|
||||
.output_names
|
||||
.get(START_LOGITS)
|
||||
.map(|pos| ort_tensor_to_tch(&outputs[*pos]))
|
||||
.transpose()?;
|
||||
let end_logits = self
|
||||
.name_mapping
|
||||
.output_names
|
||||
.get(END_LOGITS)
|
||||
.map(|pos| ort_tensor_to_tch(&outputs[*pos]))
|
||||
.transpose()?;
|
||||
|
||||
let (hidden_states, attentions) = if self.name_mapping.output_names.len() > 1 {
|
||||
let hidden_states = self
|
||||
.name_mapping
|
||||
.output_names
|
||||
.iter()
|
||||
.filter(|(name, _)| name.contains("hidden_states"))
|
||||
.map(|(_, position)| outputs.get(*position))
|
||||
.map(|array| array.map(|array_value| ort_tensor_to_tch(array_value).unwrap()))
|
||||
.collect::<Option<Vec<_>>>();
|
||||
|
||||
let attentions = self
|
||||
.name_mapping
|
||||
.output_names
|
||||
.iter()
|
||||
.filter(|(name, _)| name.contains("attentions"))
|
||||
.map(|(_, position)| outputs.get(*position))
|
||||
.map(|array| array.map(|array_value| ort_tensor_to_tch(array_value).unwrap()))
|
||||
.collect::<Option<Vec<_>>>();
|
||||
(hidden_states, attentions)
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
Ok(ONNXEncoderModelOutput {
|
||||
last_hidden_state,
|
||||
logits,
|
||||
start_logits,
|
||||
end_logits,
|
||||
hidden_states,
|
||||
attentions,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// # ONNX encoder model output.
|
||||
/// The outputs provided by the model depend on the underlying ONNX model and are all marked as optional to support a broad range of
|
||||
/// encoder stacks for multiple stacks. The end-user should extract the required output that is provided by the model exported.
|
||||
pub struct ONNXEncoderModelOutput {
|
||||
/// Last hidden states, typically used by masked language model encoder models
|
||||
pub last_hidden_state: Option<Tensor>,
|
||||
/// logits, typically used by models with a sequence of classification head
|
||||
pub logits: Option<Tensor>,
|
||||
/// logits marking the start location of a span (e.g. for extractive question answering tasks)
|
||||
pub start_logits: Option<Tensor>,
|
||||
/// logits marking the end location of a span (e.g. for extractive question answering tasks)
|
||||
pub end_logits: Option<Tensor>,
|
||||
/// Hidden states for intermediate layers of the model
|
||||
pub hidden_states: Option<Vec<Tensor>>,
|
||||
/// Attention weights for intermediate layers of the model
|
||||
pub attentions: Option<Vec<Tensor>>,
|
||||
}
|
115
src/pipelines/onnx/mod.rs
Normal file
115
src/pipelines/onnx/mod.rs
Normal file
@ -0,0 +1,115 @@
|
||||
//! # ONNX model support
|
||||
//!
|
||||
//! This crate allows running inference on models that were exported to ONNX via [onnxruntime](https://onnxruntime.ai/about.html)
|
||||
//! [bindings](https://github.com/pykeio/ort). In order to use ONNX model the corresponding optional feature (`onnx`) should be turned on.
|
||||
//! This will include the optional `ort` and `ndarray` dependencies. The `rust-bert` crate does not include any optional dependencies for `ort`,
|
||||
//! the end user should select the set of features that would be adequate for pulling the required `onnxruntime` C++ library. The current recommended
|
||||
//! installation is to use dynamic linking by pointing to an existing library location:
|
||||
//! - Use the `load-dynamic` cargo feature for `ort`
|
||||
//! - set the `ORT_DYLIB_PATH` to point to the location of downloaded onnxruntime library (`onnxruntime.dll`/`libonnxruntime.so`/`libonnxruntime.dylib`
|
||||
//! depending on the operating system). These can be downloaded from the [release page](https://github.com/microsoft/onnxruntime/releases) of the onnxruntime project
|
||||
//!
|
||||
//! For troubleshooting issues when using an ONNX model, it is recommended to add the `tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }`
|
||||
//! dependency, and use the `tracing_subscriber::fmt::init();` instruction in the `main` binary.
|
||||
//!
|
||||
//! Most architectures (including encoders, decoders and encoder-decoders) are supported.
|
||||
//! the library aims at keeping compatibility with models exported using the [optimum](https://github.com/huggingface/optimum) library.
|
||||
//! A detailed guide on how to export a Transformer model to ONNX using optimum is available at https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model
|
||||
//!
|
||||
//! The resources used to create ONNX models are similar to those based on Pytorch, replacing the pytorch by the ONNX model. Since ONNX models
|
||||
//! are less flexible than their Pytorch counterparts in the handling of optional arguments, exporting a decoder or encoder-decoder model to ONNX will usually
|
||||
//! result in multiple files. These files are expected (but not all are necessary) for use in this library as per the table below:
|
||||
//!
|
||||
//! | Architecture | Encoder file | Decoder without past file | Decoder with past file |
|
||||
//! |----------------------|---------------|----------------------------|-------------------------|
|
||||
//! | Encoder (e.g. BERT) | required | not used | not used |
|
||||
//! | Decoder (e.g. GPT2) | not used | required | optional |
|
||||
//! | Encoder-decoder (e.g. BART) | required | required | optional |
|
||||
//!
|
||||
//! Note that the computational efficiency will drop when the `decoder with past` file is optional but not provided
|
||||
//! since the model will not used cached past keys and values for the attention mechanism, leading to a high number of
|
||||
//! redundant computations. The Optimum library offers export options to ensure such a `decoder with past` model file is created.
|
||||
//!
|
||||
//! The base encoder and decoder model architecture are available (and exposed for convenience) in the `encoder` and `decoder` modules, respectively.
|
||||
//! Generation models (pure decoder or encoder/decoder architectures) are available in the `models` module.
|
||||
//!
|
||||
//! Most pipelines are available for ONNX model checkpoints, including sequence classification, zero-shot classification,
|
||||
//! token classification (including named entity recognition and part-of-speech tagging), question answering, text generation, summarization and translation.
|
||||
//!
|
||||
//! These models use the same configuration and tokenizer files as their Pytorch counterparts when used in a pipeline. The following is
|
||||
//! an example of a translation model based on a ONNX export of M2M100:
|
||||
//! ```no_run
|
||||
//! use rust_bert::m2m_100::{M2M100SourceLanguages, M2M100TargetLanguages};
|
||||
//! use tch::Device;
|
||||
//!
|
||||
//! use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
|
||||
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
//! use rust_bert::resources::RemoteResource;
|
||||
//!
|
||||
//! fn main() -> anyhow::Result<()> {
|
||||
//! let translation_model = TranslationModel::new(TranslationConfig::new(
|
||||
//! ModelType::M2M100,
|
||||
//! ModelResource::ONNX(ONNXModelResources {
|
||||
//! encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/encoder_model.onnx",
|
||||
//! "onnx-m2m100_418M",
|
||||
//! ))),
|
||||
//! decoder_resource: Some(Box::new(RemoteResource::new(
|
||||
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_model.onnx",
|
||||
//! "onnx-m2m100_418M",
|
||||
//! ))),
|
||||
//! decoder_with_past_resource: Some(Box::new(RemoteResource::new(
|
||||
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_with_past_model.onnx",
|
||||
//! "onnx-m2m100_418M",
|
||||
//! ))),
|
||||
//! }),
|
||||
//! RemoteResource::new(
|
||||
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/config.json",
|
||||
//! "onnx-m2m100_418M",
|
||||
//! ),
|
||||
//! RemoteResource::new(
|
||||
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/vocab.json",
|
||||
//! "onnx-m2m100_418M",
|
||||
//! ),
|
||||
//! Some(RemoteResource::new(
|
||||
//! "https://huggingface.co/optimum/m2m100_418M/resolve/main/sentencepiece.bpe.model",
|
||||
//! "onnx-m2m100_418M",
|
||||
//! )),
|
||||
//! M2M100SourceLanguages::M2M100_418M,
|
||||
//! M2M100TargetLanguages::M2M100_418M,
|
||||
//! Device::cuda_if_available(),
|
||||
//! ))?;
|
||||
//!
|
||||
//! let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
//!
|
||||
//! let mut outputs = Vec::new();
|
||||
//! outputs.extend(translation_model.translate(
|
||||
//! &[source_sentence],
|
||||
//! Language::English,
|
||||
//! Language::French,
|
||||
//! )?);
|
||||
//! outputs.extend(translation_model.translate(
|
||||
//! &[source_sentence],
|
||||
//! Language::English,
|
||||
//! Language::Spanish,
|
||||
//! )?);
|
||||
//! outputs.extend(translation_model.translate(
|
||||
//! &[source_sentence],
|
||||
//! Language::English,
|
||||
//! Language::Hindi,
|
||||
//! )?);
|
||||
//!
|
||||
//! println!("{:?}", outputs);
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
mod common;
|
||||
pub mod config;
|
||||
mod conversion;
|
||||
mod decoder;
|
||||
mod encoder;
|
||||
mod models;
|
||||
|
||||
pub use encoder::{ONNXEncoder, ONNXEncoderModelOutput};
|
||||
pub use models::{ONNXCausalGenerator, ONNXConditionalGenerator, ONNXLayerCache, ONNXModelConfig};
|
1130
src/pipelines/onnx/models.rs
Normal file
1130
src/pipelines/onnx/models.rs
Normal file
File diff suppressed because it is too large
Load Diff
@ -92,7 +92,10 @@ use {
|
||||
mobilebert::{
|
||||
MobileBertConfigResources, MobileBertModelResources, MobileBertVocabResources,
|
||||
},
|
||||
pipelines::{common::ModelType, token_classification::LabelAggregationOption},
|
||||
pipelines::{
|
||||
common::{ModelResource, ModelType},
|
||||
token_classification::LabelAggregationOption,
|
||||
},
|
||||
resources::RemoteResource,
|
||||
},
|
||||
tch::Device,
|
||||
@ -121,9 +124,9 @@ impl Default for POSConfig {
|
||||
POSConfig {
|
||||
token_classification_config: TokenClassificationConfig {
|
||||
model_type: ModelType::MobileBert,
|
||||
model_resource: Box::new(RemoteResource::from_pretrained(
|
||||
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
MobileBertModelResources::MOBILEBERT_ENGLISH_POS,
|
||||
)),
|
||||
))),
|
||||
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||
MobileBertConfigResources::MOBILEBERT_ENGLISH_POS,
|
||||
)),
|
||||
@ -142,6 +145,14 @@ impl Default for POSConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TokenClassificationConfig> for POSConfig {
|
||||
fn from(token_classification_config: TokenClassificationConfig) -> Self {
|
||||
POSConfig {
|
||||
token_classification_config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<POSConfig> for TokenClassificationConfig {
|
||||
fn from(pos_config: POSConfig) -> Self {
|
||||
pos_config.token_classification_config
|
||||
|
@ -51,23 +51,27 @@ use crate::distilbert::DistilBertForQuestionAnswering;
|
||||
use crate::fnet::FNetForQuestionAnswering;
|
||||
use crate::longformer::LongformerForQuestionAnswering;
|
||||
use crate::mobilebert::MobileBertForQuestionAnswering;
|
||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||
use crate::pipelines::common::{
|
||||
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
};
|
||||
use crate::reformer::ReformerForQuestionAnswering;
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::roberta::RobertaForQuestionAnswering;
|
||||
use crate::xlnet::XLNetForQuestionAnswering;
|
||||
use rust_tokenizers::{Offset, TokenIdsWithOffsets, TokenizedInput};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Borrow;
|
||||
use std::cmp::min;
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tch::kind::Kind::Float;
|
||||
use tch::nn::VarStore;
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
use tch::{no_grad, Device, Kind, Tensor};
|
||||
|
||||
use crate::deberta_v2::DebertaV2ForQuestionAnswering;
|
||||
#[cfg(feature = "onnx")]
|
||||
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
|
||||
@ -88,6 +92,7 @@ pub struct QaInput {
|
||||
struct QaFeature {
|
||||
pub input_ids: Vec<i64>,
|
||||
pub offsets: Vec<Option<Offset>>,
|
||||
pub token_type_ids: Vec<i8>,
|
||||
pub p_mask: Vec<i8>,
|
||||
pub example_index: i64,
|
||||
}
|
||||
@ -128,7 +133,7 @@ fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
|
||||
/// Contains information regarding the model to load and device to place the model on.
|
||||
pub struct QuestionAnsweringConfig {
|
||||
/// Model weights resource (default: pretrained DistilBERT model on SQuAD)
|
||||
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||
pub model_resource: ModelResource,
|
||||
/// Config resource (default: pretrained DistilBERT model on SQuAD)
|
||||
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Vocab resource (default: pretrained DistilBERT model on SQuAD)
|
||||
@ -166,9 +171,9 @@ impl QuestionAnsweringConfig {
|
||||
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||
/// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||
pub fn new<RM, RC, RV>(
|
||||
pub fn new<RC, RV>(
|
||||
model_type: ModelType,
|
||||
model_resource: RM,
|
||||
model_resource: ModelResource,
|
||||
config_resource: RC,
|
||||
vocab_resource: RV,
|
||||
merges_resource: Option<RV>,
|
||||
@ -177,13 +182,12 @@ impl QuestionAnsweringConfig {
|
||||
add_prefix_space: impl Into<Option<bool>>,
|
||||
) -> QuestionAnsweringConfig
|
||||
where
|
||||
RM: ResourceProvider + Send + 'static,
|
||||
RC: ResourceProvider + Send + 'static,
|
||||
RV: ResourceProvider + Send + 'static,
|
||||
{
|
||||
QuestionAnsweringConfig {
|
||||
model_type,
|
||||
model_resource: Box::new(model_resource),
|
||||
model_resource,
|
||||
config_resource: Box::new(config_resource),
|
||||
vocab_resource: Box::new(vocab_resource),
|
||||
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||
@ -212,9 +216,9 @@ impl QuestionAnsweringConfig {
|
||||
/// * max_query_length - Optional maximum question token length. Defaults to 64.
|
||||
/// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128.
|
||||
/// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15.
|
||||
pub fn custom_new<RM, RC, RV>(
|
||||
pub fn custom_new<RC, RV>(
|
||||
model_type: ModelType,
|
||||
model_resource: RM,
|
||||
model_resource: ModelResource,
|
||||
config_resource: RC,
|
||||
vocab_resource: RV,
|
||||
merges_resource: Option<RV>,
|
||||
@ -227,13 +231,12 @@ impl QuestionAnsweringConfig {
|
||||
max_answer_length: impl Into<Option<usize>>,
|
||||
) -> QuestionAnsweringConfig
|
||||
where
|
||||
RM: ResourceProvider + Send + 'static,
|
||||
RC: ResourceProvider + Send + 'static,
|
||||
RV: ResourceProvider + Send + 'static,
|
||||
{
|
||||
QuestionAnsweringConfig {
|
||||
model_type,
|
||||
model_resource: Box::new(model_resource),
|
||||
model_resource,
|
||||
config_resource: Box::new(config_resource),
|
||||
vocab_resource: Box::new(vocab_resource),
|
||||
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||
@ -253,9 +256,9 @@ impl QuestionAnsweringConfig {
|
||||
impl Default for QuestionAnsweringConfig {
|
||||
fn default() -> QuestionAnsweringConfig {
|
||||
QuestionAnsweringConfig {
|
||||
model_resource: Box::new(RemoteResource::from_pretrained(
|
||||
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT_SQUAD,
|
||||
)),
|
||||
))),
|
||||
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
||||
)),
|
||||
@ -303,6 +306,9 @@ pub enum QuestionAnsweringOption {
|
||||
Longformer(LongformerForQuestionAnswering),
|
||||
/// FNet for Question Answering
|
||||
FNet(FNetForQuestionAnswering),
|
||||
/// ONNX model for Question Answering
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNX(ONNXEncoder),
|
||||
}
|
||||
|
||||
impl QuestionAnsweringOption {
|
||||
@ -310,23 +316,30 @@ impl QuestionAnsweringOption {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
|
||||
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
||||
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
||||
/// `model_type`)
|
||||
pub fn new<'p, P>(
|
||||
model_type: ModelType,
|
||||
p: P,
|
||||
config: &ConfigOption,
|
||||
) -> Result<Self, RustBertError>
|
||||
where
|
||||
P: Borrow<nn::Path<'p>>,
|
||||
{
|
||||
match model_type {
|
||||
/// * `QuestionAnsweringConfig` - Question answering pipeline configuration. The type of model created will be inferred from the
|
||||
/// `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
|
||||
pub fn new(config: &QuestionAnsweringConfig) -> Result<Self, RustBertError> {
|
||||
match config.model_resource {
|
||||
ModelResource::Torch(_) => Self::new_torch(config),
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelResource::ONNX(_) => Self::new_onnx(config),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_torch(config: &QuestionAnsweringConfig) -> Result<Self, RustBertError> {
|
||||
let device = config.device;
|
||||
let weights_path = config.model_resource.get_torch_local_path()?;
|
||||
let mut var_store = VarStore::new(device);
|
||||
let model_config = &mut ConfigOption::from_file(
|
||||
config.model_type,
|
||||
config.config_resource.get_local_path()?,
|
||||
);
|
||||
let model_type = config.model_type;
|
||||
let model = match model_type {
|
||||
ModelType::Bert => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
if let ConfigOption::Bert(config) = model_config {
|
||||
Ok(QuestionAnsweringOption::Bert(
|
||||
BertForQuestionAnswering::new(p, config),
|
||||
BertForQuestionAnswering::new(var_store.root(), config),
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -335,9 +348,9 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
ModelType::Deberta => {
|
||||
if let ConfigOption::Deberta(config) = config {
|
||||
if let ConfigOption::Deberta(config) = model_config {
|
||||
Ok(QuestionAnsweringOption::Deberta(
|
||||
DebertaForQuestionAnswering::new(p, config),
|
||||
DebertaForQuestionAnswering::new(var_store.root(), config),
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -346,9 +359,9 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
ModelType::DebertaV2 => {
|
||||
if let ConfigOption::DebertaV2(config) = config {
|
||||
if let ConfigOption::DebertaV2(config) = model_config {
|
||||
Ok(QuestionAnsweringOption::DebertaV2(
|
||||
DebertaV2ForQuestionAnswering::new(p, config),
|
||||
DebertaV2ForQuestionAnswering::new(var_store.root(), config),
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -357,9 +370,10 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
ModelType::DistilBert => {
|
||||
if let ConfigOption::DistilBert(config) = config {
|
||||
if let ConfigOption::DistilBert(ref mut config) = model_config {
|
||||
config.sinusoidal_pos_embds = false;
|
||||
Ok(QuestionAnsweringOption::DistilBert(
|
||||
DistilBertForQuestionAnswering::new(p, config),
|
||||
DistilBertForQuestionAnswering::new(var_store.root(), config),
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -368,9 +382,9 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
ModelType::MobileBert => {
|
||||
if let ConfigOption::MobileBert(config) = config {
|
||||
if let ConfigOption::MobileBert(config) = model_config {
|
||||
Ok(QuestionAnsweringOption::MobileBert(
|
||||
MobileBertForQuestionAnswering::new(p, config),
|
||||
MobileBertForQuestionAnswering::new(var_store.root(), config),
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -379,9 +393,9 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
ModelType::Roberta => {
|
||||
if let ConfigOption::Roberta(config) = config {
|
||||
if let ConfigOption::Roberta(config) = model_config {
|
||||
Ok(QuestionAnsweringOption::Roberta(
|
||||
RobertaForQuestionAnswering::new(p, config),
|
||||
RobertaForQuestionAnswering::new(var_store.root(), config),
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -390,9 +404,9 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
ModelType::XLMRoberta => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
if let ConfigOption::Bert(config) = model_config {
|
||||
Ok(QuestionAnsweringOption::XLMRoberta(
|
||||
RobertaForQuestionAnswering::new(p, config),
|
||||
RobertaForQuestionAnswering::new(var_store.root(), config),
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -401,9 +415,9 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
ModelType::Albert => {
|
||||
if let ConfigOption::Albert(config) = config {
|
||||
if let ConfigOption::Albert(config) = model_config {
|
||||
Ok(QuestionAnsweringOption::Albert(
|
||||
AlbertForQuestionAnswering::new(p, config),
|
||||
AlbertForQuestionAnswering::new(var_store.root(), config),
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -412,9 +426,9 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
ModelType::XLNet => {
|
||||
if let ConfigOption::XLNet(config) = config {
|
||||
if let ConfigOption::XLNet(config) = model_config {
|
||||
Ok(QuestionAnsweringOption::XLNet(
|
||||
XLNetForQuestionAnswering::new(p, config)?,
|
||||
XLNetForQuestionAnswering::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -423,9 +437,9 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
ModelType::Reformer => {
|
||||
if let ConfigOption::Reformer(config) = config {
|
||||
if let ConfigOption::Reformer(config) = model_config {
|
||||
Ok(QuestionAnsweringOption::Reformer(
|
||||
ReformerForQuestionAnswering::new(p, config)?,
|
||||
ReformerForQuestionAnswering::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -434,9 +448,9 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
ModelType::Longformer => {
|
||||
if let ConfigOption::Longformer(config) = config {
|
||||
if let ConfigOption::Longformer(config) = model_config {
|
||||
Ok(QuestionAnsweringOption::Longformer(
|
||||
LongformerForQuestionAnswering::new(p, config),
|
||||
LongformerForQuestionAnswering::new(var_store.root(), config),
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -445,9 +459,9 @@ impl QuestionAnsweringOption {
|
||||
}
|
||||
}
|
||||
ModelType::FNet => {
|
||||
if let ConfigOption::FNet(config) = config {
|
||||
if let ConfigOption::FNet(config) = model_config {
|
||||
Ok(QuestionAnsweringOption::FNet(
|
||||
FNetForQuestionAnswering::new(p, config),
|
||||
FNetForQuestionAnswering::new(var_store.root(), config),
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -458,7 +472,28 @@ impl QuestionAnsweringOption {
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"QuestionAnswering not implemented for {model_type:?}!",
|
||||
))),
|
||||
}
|
||||
}?;
|
||||
var_store.load(weights_path)?;
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn new_onnx(config: &QuestionAnsweringConfig) -> Result<Self, RustBertError> {
|
||||
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
|
||||
let environment = onnx_config.get_environment()?;
|
||||
let encoder_file = config
|
||||
.model_resource
|
||||
.get_onnx_local_paths()?
|
||||
.encoder_path
|
||||
.ok_or(RustBertError::InvalidConfigurationError(
|
||||
"An encoder file must be provided for question answering ONNX models.".to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self::ONNX(ONNXEncoder::new(
|
||||
encoder_file,
|
||||
&environment,
|
||||
&onnx_config,
|
||||
)?))
|
||||
}
|
||||
|
||||
/// Returns the `ModelType` for this SequenceClassificationOption
|
||||
@ -476,6 +511,8 @@ impl QuestionAnsweringOption {
|
||||
Self::Reformer(_) => ModelType::Reformer,
|
||||
Self::Longformer(_) => ModelType::Longformer,
|
||||
Self::FNet(_) => ModelType::FNet,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(_) => ModelType::ONNX,
|
||||
}
|
||||
}
|
||||
|
||||
@ -485,6 +522,7 @@ impl QuestionAnsweringOption {
|
||||
input_ids: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
_token_type_ids: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> (Tensor, Tensor) {
|
||||
match *self {
|
||||
@ -547,6 +585,19 @@ impl QuestionAnsweringOption {
|
||||
.expect("Error in fnet forward pass");
|
||||
(outputs.start_logits, outputs.end_logits)
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(ref model) => {
|
||||
let outputs = model
|
||||
.forward(
|
||||
input_ids,
|
||||
mask.map(|tensor| tensor.to_kind(Kind::Int64)).as_ref(),
|
||||
_token_type_ids,
|
||||
None,
|
||||
input_embeds,
|
||||
)
|
||||
.expect("Error in ONNX forward pass.");
|
||||
(outputs.start_logits.unwrap(), outputs.end_logits.unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -561,7 +612,7 @@ pub struct QuestionAnsweringModel {
|
||||
max_query_length: usize,
|
||||
max_answer_len: usize,
|
||||
qa_model: QuestionAnsweringOption,
|
||||
var_store: VarStore,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl QuestionAnsweringModel {
|
||||
@ -631,8 +682,7 @@ impl QuestionAnsweringModel {
|
||||
question_answering_config: QuestionAnsweringConfig,
|
||||
tokenizer: TokenizerOption,
|
||||
) -> Result<QuestionAnsweringModel, RustBertError> {
|
||||
let config_path = question_answering_config.config_resource.get_local_path()?;
|
||||
let device = question_answering_config.device;
|
||||
let qa_model = QuestionAnsweringOption::new(&question_answering_config)?;
|
||||
|
||||
let pad_idx = tokenizer
|
||||
.get_pad_id()
|
||||
@ -640,19 +690,6 @@ impl QuestionAnsweringModel {
|
||||
let sep_idx = tokenizer
|
||||
.get_sep_id()
|
||||
.expect("The Tokenizer used for Question Answering should contain a SEP id");
|
||||
let mut var_store = VarStore::new(device);
|
||||
let mut model_config =
|
||||
ConfigOption::from_file(question_answering_config.model_type, config_path);
|
||||
|
||||
if let ConfigOption::DistilBert(ref mut config) = model_config {
|
||||
config.sinusoidal_pos_embds = false;
|
||||
};
|
||||
|
||||
let qa_model = QuestionAnsweringOption::new(
|
||||
question_answering_config.model_type,
|
||||
var_store.root(),
|
||||
&model_config,
|
||||
)?;
|
||||
|
||||
if question_answering_config.max_seq_length
|
||||
< (question_answering_config.max_query_length
|
||||
@ -668,8 +705,10 @@ impl QuestionAnsweringModel {
|
||||
question_answering_config.doc_stride
|
||||
)));
|
||||
}
|
||||
|
||||
crate::resources::load_weights(&question_answering_config.model_resource, &mut var_store)?;
|
||||
let device = get_device(
|
||||
question_answering_config.model_resource,
|
||||
question_answering_config.device,
|
||||
);
|
||||
Ok(QuestionAnsweringModel {
|
||||
tokenizer,
|
||||
pad_idx,
|
||||
@ -679,7 +718,7 @@ impl QuestionAnsweringModel {
|
||||
max_query_length: question_answering_config.max_query_length,
|
||||
max_answer_len: question_answering_config.max_answer_length,
|
||||
qa_model,
|
||||
var_store,
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
@ -758,11 +797,16 @@ impl QuestionAnsweringModel {
|
||||
let end = start + min(len_features - start, batch_size);
|
||||
let batch_features = &mut features[start..end];
|
||||
no_grad(|| {
|
||||
let (input_ids, attention_masks) = self.pad_features(batch_features);
|
||||
let (input_ids, attention_masks, token_type_ids) =
|
||||
self.pad_features(batch_features);
|
||||
|
||||
let (start_logits, end_logits) =
|
||||
self.qa_model
|
||||
.forward_t(Some(&input_ids), Some(&attention_masks), None, false);
|
||||
let (start_logits, end_logits) = self.qa_model.forward_t(
|
||||
Some(&input_ids),
|
||||
Some(&attention_masks),
|
||||
None,
|
||||
Some(&token_type_ids),
|
||||
false,
|
||||
);
|
||||
|
||||
let start_logits = start_logits.detach();
|
||||
let end_logits = end_logits.detach();
|
||||
@ -935,6 +979,7 @@ impl QuestionAnsweringModel {
|
||||
let qa_feature = QaFeature {
|
||||
input_ids: encoded_span.token_ids,
|
||||
offsets: encoded_span.token_offsets,
|
||||
token_type_ids: encoded_span.segment_ids,
|
||||
p_mask,
|
||||
example_index,
|
||||
};
|
||||
@ -947,7 +992,7 @@ impl QuestionAnsweringModel {
|
||||
spans
|
||||
}
|
||||
|
||||
fn pad_features(&self, features: &mut [QaFeature]) -> (Tensor, Tensor) {
|
||||
fn pad_features(&self, features: &mut [QaFeature]) -> (Tensor, Tensor, Tensor) {
|
||||
let max_len = features
|
||||
.iter()
|
||||
.map(|feature| feature.input_ids.len())
|
||||
@ -970,6 +1015,9 @@ impl QuestionAnsweringModel {
|
||||
feature.offsets.resize(max_len, None);
|
||||
feature.p_mask.resize(max_len, 1);
|
||||
feature.input_ids.resize(max_len, self.pad_idx);
|
||||
feature
|
||||
.token_type_ids
|
||||
.resize(max_len, *feature.token_type_ids.last().unwrap_or(&0));
|
||||
}
|
||||
|
||||
let padded_input_ids = features
|
||||
@ -977,9 +1025,17 @@ impl QuestionAnsweringModel {
|
||||
.map(|input| Tensor::from_slice(input.input_ids.as_slice()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.var_store.device());
|
||||
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
|
||||
(input_ids, attention_masks)
|
||||
let padded_token_type_ids = features
|
||||
.iter_mut()
|
||||
.map(|input| Tensor::from_slice(input.token_type_ids.as_slice()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.device);
|
||||
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.device);
|
||||
let token_type_ids = Tensor::stack(&padded_token_type_ids, 0)
|
||||
.to(self.device)
|
||||
.to_kind(Kind::Int64);
|
||||
(input_ids, attention_masks, token_type_ids)
|
||||
}
|
||||
|
||||
fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> {
|
||||
|
@ -22,8 +22,9 @@
|
||||
//! # fn main() -> anyhow::Result<()> {
|
||||
//!
|
||||
//! //Load a configuration
|
||||
//! use rust_bert::pipelines::common::ModelResource;
|
||||
//! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
|
||||
//! RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2),
|
||||
//! ModelResource::Torch(Box::new(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2))),
|
||||
//! RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
|
||||
//! RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
|
||||
//! None, // Merge resources
|
||||
@ -66,20 +67,21 @@ use crate::distilbert::DistilBertModelClassifier;
|
||||
use crate::fnet::FNetForSequenceClassification;
|
||||
use crate::longformer::LongformerForSequenceClassification;
|
||||
use crate::mobilebert::MobileBertForSequenceClassification;
|
||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||
use crate::pipelines::common::{
|
||||
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
};
|
||||
use crate::reformer::ReformerForSequenceClassification;
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::roberta::RobertaForSequenceClassification;
|
||||
use crate::xlnet::XLNetForSequenceClassification;
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
use rust_tokenizers::TokenizedInput;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Borrow;
|
||||
use std::collections::HashMap;
|
||||
use tch::nn::VarStore;
|
||||
use tch::{nn, no_grad, Device, Kind, Tensor};
|
||||
use tch::{no_grad, Device, Kind, Tensor};
|
||||
|
||||
use crate::deberta_v2::DebertaV2ForSequenceClassification;
|
||||
#[cfg(feature = "onnx")]
|
||||
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
distilbert::{DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources},
|
||||
@ -106,7 +108,7 @@ pub struct SequenceClassificationConfig {
|
||||
/// Model type
|
||||
pub model_type: ModelType,
|
||||
/// Model weights resource (default: pretrained BERT model on CoNLL)
|
||||
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||
pub model_resource: ModelResource,
|
||||
/// Config resource (default: pretrained BERT model on CoNLL)
|
||||
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Vocab resource (default: pretrained BERT model on CoNLL)
|
||||
@ -134,9 +136,9 @@ impl SequenceClassificationConfig {
|
||||
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||
/// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||
pub fn new<RM, RC, RV>(
|
||||
pub fn new<RC, RV>(
|
||||
model_type: ModelType,
|
||||
model_resource: RM,
|
||||
model_resource: ModelResource,
|
||||
config_resource: RC,
|
||||
vocab_resource: RV,
|
||||
merges_resource: Option<RV>,
|
||||
@ -145,13 +147,12 @@ impl SequenceClassificationConfig {
|
||||
add_prefix_space: impl Into<Option<bool>>,
|
||||
) -> SequenceClassificationConfig
|
||||
where
|
||||
RM: ResourceProvider + Send + 'static,
|
||||
RC: ResourceProvider + Send + 'static,
|
||||
RV: ResourceProvider + Send + 'static,
|
||||
{
|
||||
SequenceClassificationConfig {
|
||||
model_type,
|
||||
model_resource: Box::new(model_resource),
|
||||
model_resource,
|
||||
config_resource: Box::new(config_resource),
|
||||
vocab_resource: Box::new(vocab_resource),
|
||||
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||
@ -169,7 +170,9 @@ impl Default for SequenceClassificationConfig {
|
||||
fn default() -> SequenceClassificationConfig {
|
||||
SequenceClassificationConfig::new(
|
||||
ModelType::DistilBert,
|
||||
RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT_SST2,
|
||||
))),
|
||||
RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2),
|
||||
RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2),
|
||||
None,
|
||||
@ -209,6 +212,9 @@ pub enum SequenceClassificationOption {
|
||||
Longformer(LongformerForSequenceClassification),
|
||||
/// FNet for Sequence Classification
|
||||
FNet(FNetForSequenceClassification),
|
||||
/// ONNX Model for Sequence Classification
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNX(ONNXEncoder),
|
||||
}
|
||||
|
||||
impl SequenceClassificationOption {
|
||||
@ -216,23 +222,28 @@ impl SequenceClassificationOption {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
|
||||
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
||||
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
||||
/// `model_type`)
|
||||
pub fn new<'p, P>(
|
||||
model_type: ModelType,
|
||||
p: P,
|
||||
config: &ConfigOption,
|
||||
) -> Result<Self, RustBertError>
|
||||
where
|
||||
P: Borrow<nn::Path<'p>>,
|
||||
{
|
||||
match model_type {
|
||||
/// * `SequenceClassificationConfig` - Sequence classification pipeline configuration. The type of model created will be inferred from the
|
||||
/// `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
|
||||
pub fn new(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
|
||||
match config.model_resource {
|
||||
ModelResource::Torch(_) => Self::new_torch(config),
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelResource::ONNX(_) => Self::new_onnx(config),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_torch(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
|
||||
let device = config.device;
|
||||
let weights_path = config.model_resource.get_torch_local_path()?;
|
||||
let mut var_store = VarStore::new(device);
|
||||
let model_config =
|
||||
&ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
|
||||
let model_type = config.model_type;
|
||||
let model = match model_type {
|
||||
ModelType::Bert => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
Ok(SequenceClassificationOption::Bert(
|
||||
BertForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Bert(config) = model_config {
|
||||
Ok(Self::Bert(
|
||||
BertForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -241,9 +252,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Deberta => {
|
||||
if let ConfigOption::Deberta(config) = config {
|
||||
Ok(SequenceClassificationOption::Deberta(
|
||||
DebertaForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Deberta(config) = model_config {
|
||||
Ok(Self::Deberta(
|
||||
DebertaForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -252,9 +263,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::DebertaV2 => {
|
||||
if let ConfigOption::DebertaV2(config) = config {
|
||||
Ok(SequenceClassificationOption::DebertaV2(
|
||||
DebertaV2ForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::DebertaV2(config) = model_config {
|
||||
Ok(Self::DebertaV2(
|
||||
DebertaV2ForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -263,9 +274,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::DistilBert => {
|
||||
if let ConfigOption::DistilBert(config) = config {
|
||||
Ok(SequenceClassificationOption::DistilBert(
|
||||
DistilBertModelClassifier::new(p, config)?,
|
||||
if let ConfigOption::DistilBert(config) = model_config {
|
||||
Ok(Self::DistilBert(
|
||||
DistilBertModelClassifier::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -274,9 +285,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::MobileBert => {
|
||||
if let ConfigOption::MobileBert(config) = config {
|
||||
Ok(SequenceClassificationOption::MobileBert(
|
||||
MobileBertForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::MobileBert(config) = model_config {
|
||||
Ok(Self::MobileBert(
|
||||
MobileBertForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -285,9 +296,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Roberta => {
|
||||
if let ConfigOption::Roberta(config) = config {
|
||||
Ok(SequenceClassificationOption::Roberta(
|
||||
RobertaForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Roberta(config) = model_config {
|
||||
Ok(Self::Roberta(
|
||||
RobertaForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -296,9 +307,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::XLMRoberta => {
|
||||
if let ConfigOption::Roberta(config) = config {
|
||||
Ok(SequenceClassificationOption::XLMRoberta(
|
||||
RobertaForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Roberta(config) = model_config {
|
||||
Ok(Self::XLMRoberta(
|
||||
RobertaForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -307,9 +318,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Albert => {
|
||||
if let ConfigOption::Albert(config) = config {
|
||||
Ok(SequenceClassificationOption::Albert(
|
||||
AlbertForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Albert(config) = model_config {
|
||||
Ok(Self::Albert(
|
||||
AlbertForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -318,9 +329,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::XLNet => {
|
||||
if let ConfigOption::XLNet(config) = config {
|
||||
Ok(SequenceClassificationOption::XLNet(
|
||||
XLNetForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::XLNet(config) = model_config {
|
||||
Ok(Self::XLNet(
|
||||
XLNetForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -329,9 +340,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Bart => {
|
||||
if let ConfigOption::Bart(config) = config {
|
||||
Ok(SequenceClassificationOption::Bart(
|
||||
BartForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Bart(config) = model_config {
|
||||
Ok(Self::Bart(
|
||||
BartForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -340,9 +351,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Reformer => {
|
||||
if let ConfigOption::Reformer(config) = config {
|
||||
Ok(SequenceClassificationOption::Reformer(
|
||||
ReformerForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Reformer(config) = model_config {
|
||||
Ok(Self::Reformer(
|
||||
ReformerForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -351,9 +362,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Longformer => {
|
||||
if let ConfigOption::Longformer(config) = config {
|
||||
Ok(SequenceClassificationOption::Longformer(
|
||||
LongformerForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Longformer(config) = model_config {
|
||||
Ok(Self::Longformer(
|
||||
LongformerForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -362,9 +373,9 @@ impl SequenceClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::FNet => {
|
||||
if let ConfigOption::FNet(config) = config {
|
||||
Ok(SequenceClassificationOption::FNet(
|
||||
FNetForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::FNet(config) = model_config {
|
||||
Ok(Self::FNet(
|
||||
FNetForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -372,10 +383,36 @@ impl SequenceClassificationOption {
|
||||
))
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
|
||||
"A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
|
||||
)),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Sequence Classification not implemented for {model_type:?}!",
|
||||
))),
|
||||
}
|
||||
}?;
|
||||
var_store.load(weights_path)?;
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn new_onnx(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
|
||||
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
|
||||
let environment = onnx_config.get_environment()?;
|
||||
let encoder_file = config
|
||||
.model_resource
|
||||
.get_onnx_local_paths()?
|
||||
.encoder_path
|
||||
.ok_or(RustBertError::InvalidConfigurationError(
|
||||
"An encoder file must be provided for sequence classification ONNX models."
|
||||
.to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self::ONNX(ONNXEncoder::new(
|
||||
encoder_file,
|
||||
&environment,
|
||||
&onnx_config,
|
||||
)?))
|
||||
}
|
||||
|
||||
/// Returns the `ModelType` for this SequenceClassificationOption
|
||||
@ -394,6 +431,8 @@ impl SequenceClassificationOption {
|
||||
Self::Reformer(_) => ModelType::Reformer,
|
||||
Self::Longformer(_) => ModelType::Longformer,
|
||||
Self::FNet(_) => ModelType::FNet,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(_) => ModelType::ONNX,
|
||||
}
|
||||
}
|
||||
|
||||
@ -534,6 +573,21 @@ impl SequenceClassificationOption {
|
||||
.expect("Error in FNet forward pass.")
|
||||
.logits
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(ref model) => {
|
||||
let attention_mask = input_ids.unwrap().ones_like();
|
||||
model
|
||||
.forward(
|
||||
input_ids,
|
||||
Some(&attention_mask),
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
)
|
||||
.expect("Error in ONNX forward pass.")
|
||||
.logits
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -543,7 +597,7 @@ pub struct SequenceClassificationModel {
|
||||
tokenizer: TokenizerOption,
|
||||
sequence_classifier: SequenceClassificationOption,
|
||||
label_mapping: HashMap<i64, String>,
|
||||
var_store: VarStore,
|
||||
device: Device,
|
||||
max_length: usize,
|
||||
}
|
||||
|
||||
@ -615,23 +669,20 @@ impl SequenceClassificationModel {
|
||||
tokenizer: TokenizerOption,
|
||||
) -> Result<SequenceClassificationModel, RustBertError> {
|
||||
let config_path = config.config_resource.get_local_path()?;
|
||||
let device = config.device;
|
||||
let sequence_classifier = SequenceClassificationOption::new(&config)?;
|
||||
|
||||
let mut var_store = VarStore::new(device);
|
||||
let model_config = ConfigOption::from_file(config.model_type, config_path);
|
||||
let max_length = model_config
|
||||
.get_max_len()
|
||||
.map(|v| v as usize)
|
||||
.unwrap_or(usize::MAX);
|
||||
let sequence_classifier =
|
||||
SequenceClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
|
||||
let label_mapping = model_config.get_label_mapping().clone();
|
||||
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
|
||||
let device = get_device(config.model_resource, config.device);
|
||||
Ok(SequenceClassificationModel {
|
||||
tokenizer,
|
||||
sequence_classifier,
|
||||
label_mapping,
|
||||
var_store,
|
||||
device,
|
||||
max_length,
|
||||
})
|
||||
}
|
||||
@ -645,36 +696,6 @@ impl SequenceClassificationModel {
|
||||
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
|
||||
fn prepare_for_model<'a, S>(&self, input: S) -> Tensor
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(
|
||||
input.as_ref(),
|
||||
self.max_length,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let pad_id = self
|
||||
.tokenizer
|
||||
.get_pad_id()
|
||||
.expect("The Tokenizer used for sequence classification should contain a PAD id");
|
||||
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
|
||||
.into_iter()
|
||||
.map(|mut input| {
|
||||
input.token_ids.resize(max_len, pad_id);
|
||||
Tensor::from_slice(&(input.token_ids))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
|
||||
}
|
||||
|
||||
/// Classify texts
|
||||
///
|
||||
/// # Arguments
|
||||
@ -705,12 +726,14 @@ impl SequenceClassificationModel {
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let input_tensor = self.prepare_for_model(input.as_ref());
|
||||
let (input_ids, token_type_ids) =
|
||||
self.tokenizer
|
||||
.tokenize_and_pad(input.as_ref(), self.max_length, self.device);
|
||||
let output = no_grad(|| {
|
||||
let output = self.sequence_classifier.forward_t(
|
||||
Some(&input_tensor),
|
||||
None,
|
||||
Some(&input_ids),
|
||||
None,
|
||||
Some(&token_type_ids),
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
@ -774,12 +797,14 @@ impl SequenceClassificationModel {
|
||||
input: &[&str],
|
||||
threshold: f64,
|
||||
) -> Result<Vec<Vec<Label>>, RustBertError> {
|
||||
let input_tensor = self.prepare_for_model(input);
|
||||
let (input_ids, token_type_ids) =
|
||||
self.tokenizer
|
||||
.tokenize_and_pad(input.as_ref(), self.max_length, self.device);
|
||||
let output = no_grad(|| {
|
||||
let output = self.sequence_classifier.forward_t(
|
||||
Some(&input_tensor),
|
||||
None,
|
||||
Some(&input_ids),
|
||||
None,
|
||||
Some(&token_type_ids),
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
|
@ -67,14 +67,15 @@ use tch::Device;
|
||||
use crate::bart::BartGenerator;
|
||||
use crate::common::error::RustBertError;
|
||||
use crate::pegasus::PegasusConditionalGenerator;
|
||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
|
||||
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
use crate::prophetnet::ProphetNetConditionalGenerator;
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::t5::T5Generator;
|
||||
|
||||
use crate::longt5::LongT5Generator;
|
||||
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
||||
#[cfg(feature = "onnx")]
|
||||
use crate::pipelines::onnx::ONNXConditionalGenerator;
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
|
||||
@ -88,7 +89,7 @@ pub struct SummarizationConfig {
|
||||
/// Model type
|
||||
pub model_type: ModelType,
|
||||
/// Model weights resource (default: pretrained BART model on CNN-DM)
|
||||
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||
pub model_resource: ModelResource,
|
||||
/// Config resource (default: pretrained BART model on CNN-DM)
|
||||
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Vocab resource (default: pretrained BART model on CNN-DM)
|
||||
@ -133,25 +134,24 @@ impl SummarizationConfig {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
||||
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
|
||||
/// * model_resource - The `ModelResources` pointing to the model to load (e.g. model.ot)
|
||||
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
|
||||
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||
/// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
|
||||
pub fn new<RM, RC, RV>(
|
||||
pub fn new<RC, RV>(
|
||||
model_type: ModelType,
|
||||
model_resource: RM,
|
||||
model_resource: ModelResource,
|
||||
config_resource: RC,
|
||||
vocab_resource: RV,
|
||||
merges_resource: Option<RV>,
|
||||
) -> SummarizationConfig
|
||||
where
|
||||
RM: ResourceProvider + Send + 'static,
|
||||
RC: ResourceProvider + Send + 'static,
|
||||
RV: ResourceProvider + Send + 'static,
|
||||
{
|
||||
SummarizationConfig {
|
||||
model_type,
|
||||
model_resource: Box::new(model_resource),
|
||||
model_resource,
|
||||
config_resource: Box::new(config_resource),
|
||||
vocab_resource: Box::new(vocab_resource),
|
||||
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||
@ -179,7 +179,9 @@ impl Default for SummarizationConfig {
|
||||
fn default() -> SummarizationConfig {
|
||||
SummarizationConfig::new(
|
||||
ModelType::Bart,
|
||||
RemoteResource::from_pretrained(BartModelResources::BART_CNN),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
BartModelResources::BART_CNN,
|
||||
))),
|
||||
RemoteResource::from_pretrained(BartConfigResources::BART_CNN),
|
||||
RemoteResource::from_pretrained(BartVocabResources::BART_CNN),
|
||||
Some(RemoteResource::from_pretrained(
|
||||
@ -192,6 +194,7 @@ impl Default for SummarizationConfig {
|
||||
impl From<SummarizationConfig> for GenerateConfig {
|
||||
fn from(config: SummarizationConfig) -> GenerateConfig {
|
||||
GenerateConfig {
|
||||
model_type: config.model_type,
|
||||
model_resource: config.model_resource,
|
||||
config_resource: config.config_resource,
|
||||
merges_resource: config.merges_resource,
|
||||
@ -227,22 +230,29 @@ pub enum SummarizationOption {
|
||||
ProphetNet(ProphetNetConditionalGenerator),
|
||||
/// Summarizer based on Pegasus model
|
||||
Pegasus(PegasusConditionalGenerator),
|
||||
/// Summarizer based on ONNX model
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNX(ONNXConditionalGenerator),
|
||||
}
|
||||
|
||||
impl SummarizationOption {
|
||||
pub fn new(config: SummarizationConfig) -> Result<Self, RustBertError> {
|
||||
match config.model_type {
|
||||
ModelType::Bart => Ok(SummarizationOption::Bart(BartGenerator::new(
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
|
||||
ONNXConditionalGenerator::new(config.into(), None, None)?,
|
||||
)),
|
||||
(ModelType::Bart, _) => Ok(SummarizationOption::Bart(BartGenerator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::T5 => Ok(SummarizationOption::T5(T5Generator::new(config.into())?)),
|
||||
ModelType::LongT5 => Ok(SummarizationOption::LongT5(LongT5Generator::new(
|
||||
(ModelType::T5, _) => Ok(SummarizationOption::T5(T5Generator::new(config.into())?)),
|
||||
(ModelType::LongT5, _) => Ok(SummarizationOption::LongT5(LongT5Generator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::ProphetNet => Ok(SummarizationOption::ProphetNet(
|
||||
(ModelType::ProphetNet, _) => Ok(SummarizationOption::ProphetNet(
|
||||
ProphetNetConditionalGenerator::new(config.into())?,
|
||||
)),
|
||||
ModelType::Pegasus => Ok(SummarizationOption::Pegasus(
|
||||
(ModelType::Pegasus, _) => Ok(SummarizationOption::Pegasus(
|
||||
PegasusConditionalGenerator::new(config.into())?,
|
||||
)),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
@ -256,21 +266,25 @@ impl SummarizationOption {
|
||||
config: SummarizationConfig,
|
||||
tokenizer: TokenizerOption,
|
||||
) -> Result<Self, RustBertError> {
|
||||
match config.model_type {
|
||||
ModelType::Bart => Ok(SummarizationOption::Bart(
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
|
||||
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
|
||||
)),
|
||||
(ModelType::Bart, _) => Ok(SummarizationOption::Bart(
|
||||
BartGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::T5 => Ok(SummarizationOption::T5(T5Generator::new_with_tokenizer(
|
||||
(ModelType::T5, _) => Ok(SummarizationOption::T5(T5Generator::new_with_tokenizer(
|
||||
config.into(),
|
||||
tokenizer,
|
||||
)?)),
|
||||
ModelType::LongT5 => Ok(SummarizationOption::LongT5(
|
||||
(ModelType::LongT5, _) => Ok(SummarizationOption::LongT5(
|
||||
LongT5Generator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::ProphetNet => Ok(SummarizationOption::ProphetNet(
|
||||
(ModelType::ProphetNet, _) => Ok(SummarizationOption::ProphetNet(
|
||||
ProphetNetConditionalGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::Pegasus => Ok(SummarizationOption::Pegasus(
|
||||
(ModelType::Pegasus, _) => Ok(SummarizationOption::Pegasus(
|
||||
PegasusConditionalGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
@ -288,28 +302,34 @@ impl SummarizationOption {
|
||||
Self::LongT5(_) => ModelType::LongT5,
|
||||
Self::ProphetNet(_) => ModelType::ProphetNet,
|
||||
Self::Pegasus(_) => ModelType::Pegasus,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(_) => ModelType::ONNX,
|
||||
}
|
||||
}
|
||||
|
||||
/// Interface method to access tokenizer
|
||||
pub fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
match self {
|
||||
Self::Bart(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::T5(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::LongT5(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::ProphetNet(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::Pegasus(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::Bart(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::T5(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::LongT5(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::ProphetNet(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::Pegasus(model_ref) => model_ref.get_tokenizer(),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(model_ref) => model_ref.get_tokenizer(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Interface method to access tokenizer
|
||||
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
match self {
|
||||
Self::Bart(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::T5(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::LongT5(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::ProphetNet(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::Pegasus(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::Bart(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::LongT5(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::ProphetNet(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::Pegasus(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -344,6 +364,12 @@ impl SummarizationOption {
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(ref model) => model
|
||||
.generate(prompt_texts, None)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -38,14 +38,15 @@ use crate::gpt2::GPT2Generator;
|
||||
use crate::gpt_j::GptJGenerator;
|
||||
use crate::gpt_neo::GptNeoGenerator;
|
||||
use crate::openai_gpt::OpenAIGenerator;
|
||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
||||
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
|
||||
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
|
||||
use crate::reformer::ReformerGenerator;
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::t5::T5Generator;
|
||||
use crate::xlnet::XLNetGenerator;
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
use crate::pipelines::onnx::ONNXCausalGenerator;
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
|
||||
@ -59,7 +60,7 @@ pub struct TextGenerationConfig {
|
||||
/// Model type
|
||||
pub model_type: ModelType,
|
||||
/// Model weights resource (default: pretrained BART model on CNN-DM)
|
||||
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||
pub model_resource: ModelResource,
|
||||
/// Config resource (default: pretrained BART model on CNN-DM)
|
||||
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Vocab resource (default: pretrained BART model on CNN-DM)
|
||||
@ -104,25 +105,24 @@ impl TextGenerationConfig {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
||||
/// * model_resource - The `ResourceProvider` pointing to the model to load (e.g. model.ot)
|
||||
/// * model_resource - The `ModelResources` pointing to the model to load (e.g. model.ot)
|
||||
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
|
||||
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||
/// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
|
||||
pub fn new<RM, RC, RV>(
|
||||
pub fn new<RC, RV>(
|
||||
model_type: ModelType,
|
||||
model_resource: RM,
|
||||
model_resource: ModelResource,
|
||||
config_resource: RC,
|
||||
vocab_resource: RV,
|
||||
merges_resource: Option<RV>,
|
||||
) -> TextGenerationConfig
|
||||
where
|
||||
RM: ResourceProvider + Send + 'static,
|
||||
RC: ResourceProvider + Send + 'static,
|
||||
RV: ResourceProvider + Send + 'static,
|
||||
{
|
||||
TextGenerationConfig {
|
||||
model_type,
|
||||
model_resource: Box::new(model_resource),
|
||||
model_resource,
|
||||
config_resource: Box::new(config_resource),
|
||||
vocab_resource: Box::new(vocab_resource),
|
||||
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||
@ -150,7 +150,9 @@ impl Default for TextGenerationConfig {
|
||||
fn default() -> TextGenerationConfig {
|
||||
TextGenerationConfig::new(
|
||||
ModelType::GPT2,
|
||||
RemoteResource::from_pretrained(Gpt2ModelResources::GPT2_MEDIUM),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
Gpt2ModelResources::GPT2_MEDIUM,
|
||||
))),
|
||||
RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM),
|
||||
RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM),
|
||||
Some(RemoteResource::from_pretrained(
|
||||
@ -163,6 +165,7 @@ impl Default for TextGenerationConfig {
|
||||
impl From<TextGenerationConfig> for GenerateConfig {
|
||||
fn from(config: TextGenerationConfig) -> GenerateConfig {
|
||||
GenerateConfig {
|
||||
model_type: config.model_type,
|
||||
model_resource: config.model_resource,
|
||||
config_resource: config.config_resource,
|
||||
merges_resource: config.merges_resource,
|
||||
@ -202,30 +205,37 @@ pub enum TextGenerationOption {
|
||||
Reformer(ReformerGenerator),
|
||||
/// Text Generator based on T5 model
|
||||
T5(T5Generator),
|
||||
/// ONNX model for text generation
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNX(ONNXCausalGenerator),
|
||||
}
|
||||
|
||||
impl TextGenerationOption {
|
||||
pub fn new(config: TextGenerationConfig) -> Result<Self, RustBertError> {
|
||||
match config.model_type {
|
||||
ModelType::GPT2 => Ok(TextGenerationOption::GPT2(GPT2Generator::new(
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
|
||||
ONNXCausalGenerator::new(config.into(), None, None)?,
|
||||
)),
|
||||
(ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(GPT2Generator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::OpenAiGpt => Ok(TextGenerationOption::GPT(OpenAIGenerator::new(
|
||||
(ModelType::OpenAiGpt, _) => Ok(TextGenerationOption::GPT(OpenAIGenerator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::XLNet => Ok(TextGenerationOption::XLNet(XLNetGenerator::new(
|
||||
(ModelType::XLNet, _) => Ok(TextGenerationOption::XLNet(XLNetGenerator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::Reformer => Ok(TextGenerationOption::Reformer(ReformerGenerator::new(
|
||||
(ModelType::Reformer, _) => Ok(TextGenerationOption::Reformer(ReformerGenerator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::GPTNeo => Ok(TextGenerationOption::GPTNeo(GptNeoGenerator::new(
|
||||
(ModelType::GPTNeo, _) => Ok(TextGenerationOption::GPTNeo(GptNeoGenerator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::GPTJ => Ok(TextGenerationOption::GPTJ(GptJGenerator::new(
|
||||
(ModelType::GPTJ, _) => Ok(TextGenerationOption::GPTJ(GptJGenerator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::T5 => Ok(TextGenerationOption::T5(T5Generator::new(config.into())?)),
|
||||
(ModelType::T5, _) => Ok(TextGenerationOption::T5(T5Generator::new(config.into())?)),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Text generation not implemented for {:?}!",
|
||||
config.model_type
|
||||
@ -237,26 +247,30 @@ impl TextGenerationOption {
|
||||
config: TextGenerationConfig,
|
||||
tokenizer: TokenizerOption,
|
||||
) -> Result<Self, RustBertError> {
|
||||
match config.model_type {
|
||||
ModelType::GPT2 => Ok(TextGenerationOption::GPT2(
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
|
||||
ONNXCausalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
|
||||
)),
|
||||
(ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(
|
||||
GPT2Generator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::OpenAiGpt => Ok(TextGenerationOption::GPT(
|
||||
(ModelType::OpenAiGpt, _) => Ok(TextGenerationOption::GPT(
|
||||
OpenAIGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::XLNet => Ok(TextGenerationOption::XLNet(
|
||||
(ModelType::XLNet, _) => Ok(TextGenerationOption::XLNet(
|
||||
XLNetGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::Reformer => Ok(TextGenerationOption::Reformer(
|
||||
(ModelType::Reformer, _) => Ok(TextGenerationOption::Reformer(
|
||||
ReformerGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::GPTNeo => Ok(TextGenerationOption::GPTNeo(
|
||||
(ModelType::GPTNeo, _) => Ok(TextGenerationOption::GPTNeo(
|
||||
GptNeoGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::GPTJ => Ok(TextGenerationOption::GPTJ(
|
||||
(ModelType::GPTJ, _) => Ok(TextGenerationOption::GPTJ(
|
||||
GptJGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::T5 => Ok(TextGenerationOption::T5(T5Generator::new_with_tokenizer(
|
||||
(ModelType::T5, _) => Ok(TextGenerationOption::T5(T5Generator::new_with_tokenizer(
|
||||
config.into(),
|
||||
tokenizer,
|
||||
)?)),
|
||||
@ -277,32 +291,37 @@ impl TextGenerationOption {
|
||||
Self::XLNet(_) => ModelType::XLNet,
|
||||
Self::Reformer(_) => ModelType::Reformer,
|
||||
Self::T5(_) => ModelType::T5,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(_) => ModelType::ONNX,
|
||||
}
|
||||
}
|
||||
|
||||
/// Interface method to access tokenizer
|
||||
pub fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
match self {
|
||||
Self::GPT(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::GPT2(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::GPTNeo(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::GPTJ(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::XLNet(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::Reformer(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::T5(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::GPT(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::GPT2(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::GPTNeo(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::GPTJ(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::XLNet(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::Reformer(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::T5(model_ref) => model_ref.get_tokenizer(),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(model_ref) => model_ref.get_tokenizer(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Interface method to access tokenizer
|
||||
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
match self {
|
||||
Self::GPT(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::GPT2(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::GPTNeo(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::GPTJ(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::XLNet(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::Reformer(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::T5(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::GPT(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::GPT2(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::GPTNeo(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::GPTJ(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::XLNet(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::Reformer(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -357,10 +376,16 @@ impl TextGenerationOption {
|
||||
.into_iter()
|
||||
.map(|output| output.indices)
|
||||
.collect(),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(ref model) => model
|
||||
.generate_indices(prompt_texts, generate_options)
|
||||
.into_iter()
|
||||
.map(|output| output.indices)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn half(&mut self) {
|
||||
pub fn half(&mut self) -> Result<(), RustBertError> {
|
||||
match self {
|
||||
Self::GPT(model_ref) => model_ref.half(),
|
||||
Self::GPT2(model_ref) => model_ref.half(),
|
||||
@ -369,10 +394,14 @@ impl TextGenerationOption {
|
||||
Self::XLNet(model_ref) => model_ref.half(),
|
||||
Self::Reformer(model_ref) => model_ref.half(),
|
||||
Self::T5(model_ref) => model_ref.half(),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(_) => Err(RustBertError::OrtError(
|
||||
"Type casting not supported for ONNX models.".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn float(&mut self) {
|
||||
pub fn float(&mut self) -> Result<(), RustBertError> {
|
||||
match self {
|
||||
Self::GPT(model_ref) => model_ref.float(),
|
||||
Self::GPT2(model_ref) => model_ref.float(),
|
||||
@ -381,10 +410,14 @@ impl TextGenerationOption {
|
||||
Self::XLNet(model_ref) => model_ref.float(),
|
||||
Self::Reformer(model_ref) => model_ref.float(),
|
||||
Self::T5(model_ref) => model_ref.float(),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(_) => Err(RustBertError::OrtError(
|
||||
"Type casting not supported for ONNX models.".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_device(&mut self, device: Device) {
|
||||
pub fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
|
||||
match self {
|
||||
Self::GPT(model_ref) => model_ref.set_device(device),
|
||||
Self::GPT2(model_ref) => model_ref.set_device(device),
|
||||
@ -393,6 +426,10 @@ impl TextGenerationOption {
|
||||
Self::XLNet(model_ref) => model_ref.set_device(device),
|
||||
Self::Reformer(model_ref) => model_ref.set_device(device),
|
||||
Self::T5(model_ref) => model_ref.set_device(device),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(_) => Err(RustBertError::OrtError(
|
||||
"Device assignment not supported for ONNX models.".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -520,16 +557,16 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
|
||||
self.model.get_tokenizer_mut()
|
||||
}
|
||||
|
||||
pub fn half(&mut self) {
|
||||
self.model.half();
|
||||
pub fn half(&mut self) -> Result<(), RustBertError> {
|
||||
self.model.half()
|
||||
}
|
||||
|
||||
pub fn float(&mut self) {
|
||||
self.model.float();
|
||||
pub fn float(&mut self) -> Result<(), RustBertError> {
|
||||
self.model.float()
|
||||
}
|
||||
|
||||
pub fn set_device(&mut self, device: Device) {
|
||||
self.model.set_device(device);
|
||||
pub fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
|
||||
self.model.set_device(device)
|
||||
}
|
||||
|
||||
/// Generate texts from provided prompts
|
||||
|
@ -21,11 +21,12 @@
|
||||
//! use rust_bert::pipelines::common::ModelType;
|
||||
//! # fn main() -> anyhow::Result<()> {
|
||||
//!
|
||||
//! use rust_bert::pipelines::common::ModelResource;
|
||||
//! //Load a configuration
|
||||
//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
|
||||
//! let config = TokenClassificationConfig::new(
|
||||
//! ModelType::Bert,
|
||||
//! RemoteResource::from_pretrained(BertModelResources::BERT_NER),
|
||||
//! ModelResource::Torch(Box::new(RemoteResource::from_pretrained(BertModelResources::BERT_NER))),
|
||||
//! RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
|
||||
//! RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
|
||||
//! None, //merges resource only relevant with ModelType::Roberta
|
||||
@ -120,7 +121,9 @@ use crate::electra::ElectraForTokenClassification;
|
||||
use crate::fnet::FNetForTokenClassification;
|
||||
use crate::longformer::LongformerForTokenClassification;
|
||||
use crate::mobilebert::MobileBertForTokenClassification;
|
||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||
use crate::pipelines::common::{
|
||||
get_device, ConfigOption, ModelResource, ModelType, TokenizerOption,
|
||||
};
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::roberta::RobertaForTokenClassification;
|
||||
use crate::xlnet::XLNetForTokenClassification;
|
||||
@ -131,13 +134,14 @@ use rust_tokenizers::{
|
||||
TokenizedInput,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Borrow;
|
||||
use std::cmp::min;
|
||||
use std::collections::HashMap;
|
||||
use tch::nn::VarStore;
|
||||
use tch::{nn, no_grad, Device, Kind, Tensor};
|
||||
use tch::{no_grad, Device, Kind, Tensor};
|
||||
|
||||
use crate::deberta_v2::DebertaV2ForTokenClassification;
|
||||
#[cfg(feature = "onnx")]
|
||||
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
bert::{BertConfigResources, BertModelResources, BertVocabResources},
|
||||
@ -195,6 +199,8 @@ struct InputFeature {
|
||||
offsets: Vec<Option<Offset>>,
|
||||
/// Token category (mask)
|
||||
mask: Vec<Mask>,
|
||||
/// Token type ids (mask)
|
||||
token_type_ids: Vec<i64>,
|
||||
/// per-token flag indicating if this feature carries the output label for this token
|
||||
reference_feature: Vec<bool>,
|
||||
/// Reference example index (long inputs may be broken into multiple input features)
|
||||
@ -222,7 +228,7 @@ pub struct TokenClassificationConfig {
|
||||
/// Model type
|
||||
pub model_type: ModelType,
|
||||
/// Model weights resource (default: pretrained BERT model on CoNLL)
|
||||
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||
pub model_resource: ModelResource,
|
||||
/// Config resource (default: pretrained BERT model on CoNLL)
|
||||
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Vocab resource (default: pretrained BERT model on CoNLL)
|
||||
@ -254,9 +260,9 @@ impl TokenClassificationConfig {
|
||||
/// * vocab - The `ResourceProvider` pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||
/// * vocab - An optional `ResourceProvider` pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||
pub fn new<RM, RC, RV>(
|
||||
pub fn new<RC, RV>(
|
||||
model_type: ModelType,
|
||||
model_resource: RM,
|
||||
model_resource: ModelResource,
|
||||
config_resource: RC,
|
||||
vocab_resource: RV,
|
||||
merges_resource: Option<RV>,
|
||||
@ -266,13 +272,12 @@ impl TokenClassificationConfig {
|
||||
label_aggregation_function: LabelAggregationOption,
|
||||
) -> TokenClassificationConfig
|
||||
where
|
||||
RM: ResourceProvider + Send + 'static,
|
||||
RC: ResourceProvider + Send + 'static,
|
||||
RV: ResourceProvider + Send + 'static,
|
||||
{
|
||||
TokenClassificationConfig {
|
||||
model_type,
|
||||
model_resource: Box::new(model_resource),
|
||||
model_resource,
|
||||
config_resource: Box::new(config_resource),
|
||||
vocab_resource: Box::new(vocab_resource),
|
||||
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||
@ -292,7 +297,9 @@ impl Default for TokenClassificationConfig {
|
||||
fn default() -> TokenClassificationConfig {
|
||||
TokenClassificationConfig::new(
|
||||
ModelType::Bert,
|
||||
RemoteResource::from_pretrained(BertModelResources::BERT_NER),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT_NER,
|
||||
))),
|
||||
RemoteResource::from_pretrained(BertConfigResources::BERT_NER),
|
||||
RemoteResource::from_pretrained(BertVocabResources::BERT_NER),
|
||||
None,
|
||||
@ -331,30 +338,38 @@ pub enum TokenClassificationOption {
|
||||
Longformer(LongformerForTokenClassification),
|
||||
/// FNet for Token Classification
|
||||
FNet(FNetForTokenClassification),
|
||||
/// ONNX model for Token Classification
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNX(ONNXEncoder),
|
||||
}
|
||||
|
||||
impl TokenClassificationOption {
|
||||
/// Instantiate a new token sequence classification model of the supplied type.
|
||||
/// Instantiate a new sequence classification model of the supplied type.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
|
||||
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
||||
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
||||
/// `model_type`)
|
||||
pub fn new<'p, P>(
|
||||
model_type: ModelType,
|
||||
p: P,
|
||||
config: &ConfigOption,
|
||||
) -> Result<Self, RustBertError>
|
||||
where
|
||||
P: Borrow<nn::Path<'p>>,
|
||||
{
|
||||
match model_type {
|
||||
/// * `TokenClassificationConfig` - Token classification pipeline configuration. The type of model created will be inferred from the
|
||||
/// `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
|
||||
pub fn new(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
|
||||
match config.model_resource {
|
||||
ModelResource::Torch(_) => Self::new_torch(config),
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelResource::ONNX(_) => Self::new_onnx(config),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_torch(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
|
||||
let device = config.device;
|
||||
let weights_path = config.model_resource.get_torch_local_path()?;
|
||||
let mut var_store = VarStore::new(device);
|
||||
let model_config =
|
||||
&ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
|
||||
let model_type = config.model_type;
|
||||
let model = match model_type {
|
||||
ModelType::Bert => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
Ok(TokenClassificationOption::Bert(
|
||||
BertForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::Bert(config) = model_config {
|
||||
Ok(Self::Bert(
|
||||
BertForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -363,9 +378,9 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Deberta => {
|
||||
if let ConfigOption::Deberta(config) = config {
|
||||
Ok(TokenClassificationOption::Deberta(
|
||||
DebertaForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::Deberta(config) = model_config {
|
||||
Ok(Self::Deberta(
|
||||
DebertaForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -374,9 +389,9 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::DebertaV2 => {
|
||||
if let ConfigOption::DebertaV2(config) = config {
|
||||
Ok(TokenClassificationOption::DebertaV2(
|
||||
DebertaV2ForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::DebertaV2(config) = model_config {
|
||||
Ok(Self::DebertaV2(
|
||||
DebertaV2ForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -385,9 +400,9 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::DistilBert => {
|
||||
if let ConfigOption::DistilBert(config) = config {
|
||||
Ok(TokenClassificationOption::DistilBert(
|
||||
DistilBertForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::DistilBert(config) = model_config {
|
||||
Ok(Self::DistilBert(
|
||||
DistilBertForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -396,9 +411,9 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::MobileBert => {
|
||||
if let ConfigOption::MobileBert(config) = config {
|
||||
Ok(TokenClassificationOption::MobileBert(
|
||||
MobileBertForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::MobileBert(config) = model_config {
|
||||
Ok(Self::MobileBert(
|
||||
MobileBertForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -407,9 +422,9 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Roberta => {
|
||||
if let ConfigOption::Roberta(config) = config {
|
||||
Ok(TokenClassificationOption::Roberta(
|
||||
RobertaForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::Roberta(config) = model_config {
|
||||
Ok(Self::Roberta(
|
||||
RobertaForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -418,9 +433,9 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::XLMRoberta => {
|
||||
if let ConfigOption::Roberta(config) = config {
|
||||
Ok(TokenClassificationOption::XLMRoberta(
|
||||
RobertaForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::Roberta(config) = model_config {
|
||||
Ok(Self::XLMRoberta(
|
||||
RobertaForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -429,9 +444,9 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Electra => {
|
||||
if let ConfigOption::Electra(config) = config {
|
||||
Ok(TokenClassificationOption::Electra(
|
||||
ElectraForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::Electra(config) = model_config {
|
||||
Ok(Self::Electra(
|
||||
ElectraForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -440,9 +455,9 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Albert => {
|
||||
if let ConfigOption::Albert(config) = config {
|
||||
Ok(TokenClassificationOption::Albert(
|
||||
AlbertForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::Albert(config) = model_config {
|
||||
Ok(Self::Albert(
|
||||
AlbertForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -451,9 +466,9 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::XLNet => {
|
||||
if let ConfigOption::XLNet(config) = config {
|
||||
Ok(TokenClassificationOption::XLNet(
|
||||
XLNetForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::XLNet(config) = model_config {
|
||||
Ok(Self::XLNet(
|
||||
XLNetForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -462,9 +477,9 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Longformer => {
|
||||
if let ConfigOption::Longformer(config) = config {
|
||||
Ok(TokenClassificationOption::Longformer(
|
||||
LongformerForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::Longformer(config) = model_config {
|
||||
Ok(Self::Longformer(
|
||||
LongformerForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -473,9 +488,9 @@ impl TokenClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::FNet => {
|
||||
if let ConfigOption::FNet(config) = config {
|
||||
Ok(TokenClassificationOption::FNet(
|
||||
FNetForTokenClassification::new(p, config)?,
|
||||
if let ConfigOption::FNet(config) = model_config {
|
||||
Ok(Self::FNet(
|
||||
FNetForTokenClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -483,10 +498,36 @@ impl TokenClassificationOption {
|
||||
))
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
|
||||
"A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
|
||||
)),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Token classification not implemented for {model_type:?}!"
|
||||
))),
|
||||
}
|
||||
}?;
|
||||
var_store.load(weights_path)?;
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn new_onnx(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
|
||||
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
|
||||
let environment = onnx_config.get_environment()?;
|
||||
let encoder_file = config
|
||||
.model_resource
|
||||
.get_onnx_local_paths()?
|
||||
.encoder_path
|
||||
.ok_or(RustBertError::InvalidConfigurationError(
|
||||
"An encoder file must be provided for token classification ONNX models."
|
||||
.to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self::ONNX(ONNXEncoder::new(
|
||||
encoder_file,
|
||||
&environment,
|
||||
&onnx_config,
|
||||
)?))
|
||||
}
|
||||
|
||||
/// Returns the `ModelType` for this TokenClassificationOption
|
||||
@ -504,6 +545,8 @@ impl TokenClassificationOption {
|
||||
Self::XLNet(_) => ModelType::XLNet,
|
||||
Self::Longformer(_) => ModelType::Longformer,
|
||||
Self::FNet(_) => ModelType::FNet,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(_) => ModelType::ONNX,
|
||||
}
|
||||
}
|
||||
|
||||
@ -637,6 +680,12 @@ impl TokenClassificationOption {
|
||||
.expect("Error in fnet forward_t")
|
||||
.logits
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(ref model) => model
|
||||
.forward(input_ids, mask, token_type_ids, position_ids, input_embeds)
|
||||
.expect("Error in ONNX forward pass.")
|
||||
.logits
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -646,7 +695,7 @@ pub struct TokenClassificationModel {
|
||||
tokenizer: TokenizerOption,
|
||||
token_sequence_classifier: TokenClassificationOption,
|
||||
label_mapping: HashMap<i64, String>,
|
||||
var_store: VarStore,
|
||||
device: Device,
|
||||
label_aggregation_function: LabelAggregationOption,
|
||||
max_length: usize,
|
||||
batch_size: usize,
|
||||
@ -720,25 +769,23 @@ impl TokenClassificationModel {
|
||||
tokenizer: TokenizerOption,
|
||||
) -> Result<TokenClassificationModel, RustBertError> {
|
||||
let config_path = config.config_resource.get_local_path()?;
|
||||
let device = config.device;
|
||||
let token_sequence_classifier = TokenClassificationOption::new(&config)?;
|
||||
|
||||
let label_aggregation_function = config.label_aggregation_function;
|
||||
|
||||
let mut var_store = VarStore::new(device);
|
||||
let model_config = ConfigOption::from_file(config.model_type, config_path);
|
||||
let max_length = model_config
|
||||
.get_max_len()
|
||||
.map(|v| v as usize)
|
||||
.unwrap_or(usize::MAX);
|
||||
let token_sequence_classifier =
|
||||
TokenClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
|
||||
let label_mapping = model_config.get_label_mapping().clone();
|
||||
let batch_size = config.batch_size;
|
||||
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
|
||||
let device = get_device(config.model_resource, config.device);
|
||||
Ok(TokenClassificationModel {
|
||||
tokenizer,
|
||||
token_sequence_classifier,
|
||||
label_mapping,
|
||||
var_store,
|
||||
device,
|
||||
label_aggregation_function,
|
||||
max_length,
|
||||
batch_size,
|
||||
@ -815,6 +862,11 @@ impl TokenClassificationModel {
|
||||
input_ids: encoded_span.token_ids,
|
||||
offsets: encoded_span.token_offsets,
|
||||
mask: encoded_span.mask,
|
||||
token_type_ids: encoded_span
|
||||
.segment_ids
|
||||
.into_iter()
|
||||
.map(|segment_id| segment_id as i64)
|
||||
.collect(),
|
||||
reference_feature,
|
||||
example_index,
|
||||
};
|
||||
@ -923,11 +975,12 @@ impl TokenClassificationModel {
|
||||
|
||||
no_grad(|| {
|
||||
let batch_features = &mut features[start..end];
|
||||
let (input_ids, attention_masks) = self.pad_features(batch_features);
|
||||
let (input_ids, attention_masks, token_type_ids) =
|
||||
self.pad_features(batch_features);
|
||||
let output = self.token_sequence_classifier.forward_t(
|
||||
Some(&input_ids),
|
||||
Some(&attention_masks),
|
||||
None,
|
||||
Some(&token_type_ids),
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
@ -985,7 +1038,7 @@ impl TokenClassificationModel {
|
||||
tokens
|
||||
}
|
||||
|
||||
fn pad_features(&self, features: &mut [InputFeature]) -> (Tensor, Tensor) {
|
||||
fn pad_features(&self, features: &mut [InputFeature]) -> (Tensor, Tensor, Tensor) {
|
||||
let max_len = features
|
||||
.iter()
|
||||
.map(|feature| feature.input_ids.len())
|
||||
@ -997,8 +1050,8 @@ impl TokenClassificationModel {
|
||||
.map(|feature| &feature.input_ids)
|
||||
.map(|input| {
|
||||
let mut attention_mask = Vec::with_capacity(max_len);
|
||||
attention_mask.resize(input.len(), 1);
|
||||
attention_mask.resize(max_len, 0);
|
||||
attention_mask.resize(input.len(), 1i64);
|
||||
attention_mask.resize(max_len, 0i64);
|
||||
attention_mask
|
||||
})
|
||||
.map(|input| Tensor::from_slice(&(input)))
|
||||
@ -1011,6 +1064,9 @@ impl TokenClassificationModel {
|
||||
for feature in features.iter_mut() {
|
||||
feature.input_ids.resize(max_len, padding_index);
|
||||
feature.offsets.resize(max_len, None);
|
||||
feature
|
||||
.token_type_ids
|
||||
.resize(max_len, *feature.token_type_ids.last().unwrap_or(&0));
|
||||
feature.reference_feature.resize(max_len, false);
|
||||
}
|
||||
|
||||
@ -1019,9 +1075,15 @@ impl TokenClassificationModel {
|
||||
.map(|input| Tensor::from_slice(input.input_ids.as_slice()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.var_store.device());
|
||||
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
|
||||
(input_ids, attention_masks)
|
||||
let padded_token_type_ids = features
|
||||
.iter()
|
||||
.map(|input| Tensor::from_slice(input.token_type_ids.as_slice()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.device);
|
||||
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.device);
|
||||
let token_type_ids = Tensor::stack(&padded_token_type_ids, 0).to(self.device);
|
||||
(input_ids, attention_masks, token_type_ids)
|
||||
}
|
||||
|
||||
fn decode_token(
|
||||
|
@ -25,6 +25,7 @@
|
||||
//! use tch::Device;
|
||||
//!
|
||||
//! fn main() -> anyhow::Result<()> {
|
||||
//! use rust_bert::pipelines::common::ModelResource;
|
||||
//! let model_resource = RemoteResource::from_pretrained(M2M100ModelResources::M2M100_418M);
|
||||
//! let config_resource = RemoteResource::from_pretrained(M2M100ConfigResources::M2M100_418M);
|
||||
//! let vocab_resource = RemoteResource::from_pretrained(M2M100VocabResources::M2M100_418M);
|
||||
@ -35,7 +36,7 @@
|
||||
//!
|
||||
//! let translation_config = TranslationConfig::new(
|
||||
//! ModelType::M2M100,
|
||||
//! model_resource,
|
||||
//! ModelResource::Torch(Box::new(model_resource)),
|
||||
//! config_resource,
|
||||
//! vocab_resource,
|
||||
//! Some(merges_resource),
|
||||
|
@ -5,6 +5,7 @@ use tch::Device;
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
pipelines::common::ModelResource,
|
||||
pipelines::translation::{TranslationConfig, TranslationModel},
|
||||
resources::ResourceProvider,
|
||||
RustBertError,
|
||||
@ -379,7 +380,7 @@ impl TranslationModelBuilder {
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
translation_resources.model_type,
|
||||
translation_resources.model_resource,
|
||||
ModelResource::Torch(Box::new(translation_resources.model_resource)),
|
||||
translation_resources.config_resource,
|
||||
translation_resources.vocab_resource,
|
||||
Some(translation_resources.merges_resource),
|
||||
|
@ -1,3 +1,5 @@
|
||||
// Copyright 2018-2020 The HuggingFace Inc. team.
|
||||
// Copyright 2020 Marian Team Authors
|
||||
// Copyright 2019-2020 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -16,9 +18,10 @@ use crate::m2m_100::M2M100Generator;
|
||||
use crate::marian::MarianGenerator;
|
||||
use crate::mbart::MBartGenerator;
|
||||
use crate::nllb::NLLBGenerator;
|
||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
|
||||
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
|
||||
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
|
||||
#[cfg(feature = "onnx")]
|
||||
use crate::pipelines::onnx::ONNXConditionalGenerator;
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::t5::T5Generator;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -934,7 +937,7 @@ pub struct TranslationConfig {
|
||||
/// Model type used for translation
|
||||
pub model_type: ModelType,
|
||||
/// Model weights resource
|
||||
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||
pub model_resource: ModelResource,
|
||||
/// Config resource
|
||||
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Vocab resource
|
||||
@ -993,12 +996,14 @@ impl TranslationConfig {
|
||||
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
|
||||
/// MarianTargetLanguages, MarianVocabResources,
|
||||
/// };
|
||||
/// use rust_bert::pipelines::common::ModelType;
|
||||
/// use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
/// use rust_bert::pipelines::translation::TranslationConfig;
|
||||
/// use rust_bert::resources::RemoteResource;
|
||||
/// use tch::Device;
|
||||
///
|
||||
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
|
||||
/// let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
/// MarianModelResources::ROMANCE2ENGLISH,
|
||||
/// )));
|
||||
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
|
||||
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
|
||||
/// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
|
||||
@ -1019,9 +1024,9 @@ impl TranslationConfig {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new<RM, RC, RV, S, T>(
|
||||
pub fn new<RC, RV, S, T>(
|
||||
model_type: ModelType,
|
||||
model_resource: RM,
|
||||
model_resource: ModelResource,
|
||||
config_resource: RC,
|
||||
vocab_resource: RV,
|
||||
merges_resource: Option<RV>,
|
||||
@ -1030,7 +1035,6 @@ impl TranslationConfig {
|
||||
device: impl Into<Option<Device>>,
|
||||
) -> TranslationConfig
|
||||
where
|
||||
RM: ResourceProvider + Send + 'static,
|
||||
RC: ResourceProvider + Send + 'static,
|
||||
RV: ResourceProvider + Send + 'static,
|
||||
S: AsRef<[Language]>,
|
||||
@ -1040,7 +1044,7 @@ impl TranslationConfig {
|
||||
|
||||
TranslationConfig {
|
||||
model_type,
|
||||
model_resource: Box::new(model_resource),
|
||||
model_resource,
|
||||
config_resource: Box::new(config_resource),
|
||||
vocab_resource: Box::new(vocab_resource),
|
||||
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||
@ -1068,6 +1072,7 @@ impl TranslationConfig {
|
||||
impl From<TranslationConfig> for GenerateConfig {
|
||||
fn from(config: TranslationConfig) -> GenerateConfig {
|
||||
GenerateConfig {
|
||||
model_type: config.model_type,
|
||||
model_resource: config.model_resource,
|
||||
config_resource: config.config_resource,
|
||||
merges_resource: config.merges_resource,
|
||||
@ -1102,49 +1107,31 @@ pub enum TranslationOption {
|
||||
MBart(MBartGenerator),
|
||||
/// Translator based on M2M100 model
|
||||
M2M100(M2M100Generator),
|
||||
/// Translator based on NLLB model
|
||||
NLLB(NLLBGenerator),
|
||||
/// Translator based on ONNX model
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNX(ONNXConditionalGenerator),
|
||||
}
|
||||
|
||||
impl TranslationOption {
|
||||
pub fn new(config: TranslationConfig) -> Result<Self, RustBertError> {
|
||||
match config.model_type {
|
||||
ModelType::Marian => Ok(TranslationOption::Marian(MarianGenerator::new(
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(TranslationOption::ONNX(
|
||||
ONNXConditionalGenerator::new(config.into(), None, None)?,
|
||||
)),
|
||||
(ModelType::Marian, _) => Ok(TranslationOption::Marian(MarianGenerator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::T5 => Ok(TranslationOption::T5(T5Generator::new(config.into())?)),
|
||||
ModelType::MBart => Ok(TranslationOption::MBart(MBartGenerator::new(
|
||||
(ModelType::T5, _) => Ok(TranslationOption::T5(T5Generator::new(config.into())?)),
|
||||
(ModelType::MBart, _) => Ok(TranslationOption::MBart(MBartGenerator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::M2M100 => Ok(TranslationOption::M2M100(M2M100Generator::new(
|
||||
(ModelType::M2M100, _) => Ok(TranslationOption::M2M100(M2M100Generator::new(
|
||||
config.into(),
|
||||
)?)),
|
||||
ModelType::NLLB => {
|
||||
let config: GenerateConfig = config.into();
|
||||
let tokenizer = TokenizerOption::from_file(
|
||||
ModelType::NLLB,
|
||||
config.vocab_resource.get_local_path()?.to_str().unwrap(),
|
||||
Some(
|
||||
config
|
||||
.merges_resource
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
RustBertError::InvalidConfigurationError(
|
||||
"M2M100 expects a merges resources to be provided".to_string(),
|
||||
)
|
||||
})?
|
||||
.get_local_path()?
|
||||
.to_str()
|
||||
.unwrap(),
|
||||
),
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
)?;
|
||||
|
||||
Ok(TranslationOption::NLLB(NLLBGenerator::new_with_tokenizer(
|
||||
config, tokenizer,
|
||||
)?))
|
||||
}
|
||||
(ModelType::NLLB, _) => Ok(TranslationOption::NLLB(NLLBGenerator::new(config.into())?)),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Translation not implemented for {:?}!",
|
||||
config.model_type
|
||||
@ -1156,21 +1143,25 @@ impl TranslationOption {
|
||||
config: TranslationConfig,
|
||||
tokenizer: TokenizerOption,
|
||||
) -> Result<Self, RustBertError> {
|
||||
match config.model_type {
|
||||
ModelType::Marian => Ok(TranslationOption::Marian(
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(TranslationOption::ONNX(
|
||||
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
|
||||
)),
|
||||
(ModelType::Marian, _) => Ok(TranslationOption::Marian(
|
||||
MarianGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::T5 => Ok(TranslationOption::T5(T5Generator::new_with_tokenizer(
|
||||
(ModelType::T5, _) => Ok(TranslationOption::T5(T5Generator::new_with_tokenizer(
|
||||
config.into(),
|
||||
tokenizer,
|
||||
)?)),
|
||||
ModelType::MBart => Ok(TranslationOption::MBart(
|
||||
(ModelType::MBart, _) => Ok(TranslationOption::MBart(
|
||||
MBartGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::M2M100 => Ok(TranslationOption::M2M100(
|
||||
(ModelType::M2M100, _) => Ok(TranslationOption::M2M100(
|
||||
M2M100Generator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
)),
|
||||
ModelType::NLLB => Ok(TranslationOption::NLLB(NLLBGenerator::new_with_tokenizer(
|
||||
(ModelType::NLLB, _) => Ok(TranslationOption::NLLB(NLLBGenerator::new_with_tokenizer(
|
||||
config.into(),
|
||||
tokenizer,
|
||||
)?)),
|
||||
@ -1189,190 +1180,36 @@ impl TranslationOption {
|
||||
Self::MBart(_) => ModelType::MBart,
|
||||
Self::M2M100(_) => ModelType::M2M100,
|
||||
Self::NLLB(_) => ModelType::NLLB,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(_) => ModelType::ONNX,
|
||||
}
|
||||
}
|
||||
|
||||
/// Interface method to access tokenizer
|
||||
/// Returns the `Tokenizer` for this TranslationOption
|
||||
pub fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
match self {
|
||||
Self::Marian(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::T5(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::MBart(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::M2M100(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::NLLB(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::Marian(ref generator) => generator.get_tokenizer(),
|
||||
Self::T5(ref generator) => generator.get_tokenizer(),
|
||||
Self::MBart(ref generator) => generator.get_tokenizer(),
|
||||
Self::M2M100(ref generator) => generator.get_tokenizer(),
|
||||
Self::NLLB(ref generator) => generator.get_tokenizer(),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(ref generator) => generator.get_tokenizer(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Interface method to access tokenizer
|
||||
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
match self {
|
||||
Self::Marian(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::T5(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::MBart(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::M2M100(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::NLLB(model_ref) => model_ref._get_tokenizer_mut(),
|
||||
Self::Marian(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::MBart(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::M2M100(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
Self::NLLB(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_and_get_prefix_and_forced_bos_id(
|
||||
&self,
|
||||
source_language: Option<&Language>,
|
||||
target_language: Option<&Language>,
|
||||
supported_source_languages: &HashSet<Language>,
|
||||
supported_target_languages: &HashSet<Language>,
|
||||
) -> Result<(Option<String>, Option<i64>), RustBertError> {
|
||||
if let Some(source_language) = source_language {
|
||||
if !supported_source_languages.contains(source_language) {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{source_language} not in list of supported languages: {supported_source_languages:?}",
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(target_language) = target_language {
|
||||
if !supported_target_languages.contains(target_language) {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{target_language} not in list of supported languages: {supported_target_languages:?}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(match *self {
|
||||
Self::Marian(_) => {
|
||||
if supported_target_languages.len() > 1 {
|
||||
(
|
||||
Some(format!(
|
||||
">>{}<< ",
|
||||
target_language.and_then(|l| l.get_iso_639_1_code()).ok_or_else(|| RustBertError::ValueError(format!(
|
||||
"Missing target language for Marian \
|
||||
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||
need to specify target language)",
|
||||
)))?
|
||||
)),
|
||||
None,
|
||||
)
|
||||
} else {
|
||||
(None, None)
|
||||
}
|
||||
}
|
||||
Self::T5(_) => (
|
||||
Some(format!(
|
||||
"translate {} to {}:",
|
||||
source_language.ok_or_else(|| RustBertError::ValueError(
|
||||
"Missing source language for T5".to_string(),
|
||||
))?,
|
||||
target_language.ok_or_else(|| RustBertError::ValueError(
|
||||
"Missing target language for T5".to_string(),
|
||||
))?,
|
||||
)),
|
||||
None,
|
||||
),
|
||||
Self::MBart(ref model) => {
|
||||
(
|
||||
Some(format!(
|
||||
">>{}<< ",
|
||||
source_language.and_then(|l| l.get_iso_639_1_code()).ok_or_else(|| RustBertError::ValueError(format!(
|
||||
"Missing source language for MBart\
|
||||
(multiple languages supported by model: {supported_source_languages:?}, \
|
||||
need to specify target language)"
|
||||
)))?
|
||||
)),
|
||||
if let Some(target_language) = target_language {
|
||||
Some(
|
||||
model._get_tokenizer().convert_tokens_to_ids(&[format!(
|
||||
">>{}<<",
|
||||
target_language.get_iso_639_1_code().ok_or_else(|| {
|
||||
RustBertError::ValueError(format!(
|
||||
"This language has no ISO639-I code. Languages supported by model: {supported_source_languages:?}."
|
||||
))
|
||||
})?
|
||||
)])[0],
|
||||
)
|
||||
} else {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing target language for MBart\
|
||||
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||
need to specify target language)"
|
||||
)));
|
||||
},
|
||||
)
|
||||
}
|
||||
Self::M2M100(ref model) => (
|
||||
Some(match source_language {
|
||||
Some(value) => {
|
||||
let language_code = value.get_iso_639_1_code().ok_or_else(|| {
|
||||
RustBertError::ValueError(format!(
|
||||
"This language has no ISO639-I language code representation. \
|
||||
languages supported by the model: {supported_target_languages:?}"
|
||||
))
|
||||
})?;
|
||||
match language_code.len() {
|
||||
2 => format!(">>{language_code}.<< "),
|
||||
3 => format!(">>{language_code}<< "),
|
||||
_ => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Invalid ISO 639-I code".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing source language for M2M100 \
|
||||
(multiple languages supported by model: {supported_source_languages:?}, \
|
||||
need to specify target language)"
|
||||
)));
|
||||
}
|
||||
}),
|
||||
if let Some(target_language) = target_language {
|
||||
let language_code = target_language.get_iso_639_1_code().ok_or_else(|| {
|
||||
RustBertError::ValueError(format!(
|
||||
"This language has no ISO639-I language code representation. \
|
||||
languages supported by the model: {supported_target_languages:?}"
|
||||
))
|
||||
})?;
|
||||
Some(
|
||||
model._get_tokenizer().convert_tokens_to_ids(&[
|
||||
match language_code.len() {
|
||||
2 => format!(">>{language_code}.<<"),
|
||||
3 => format!(">>{language_code}<<"),
|
||||
_ => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Invalid ISO 639-3 code".to_string(),
|
||||
));
|
||||
}
|
||||
},
|
||||
])[0],
|
||||
)
|
||||
} else {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"Missing target language for M2M100 \
|
||||
(multiple languages supported by model: {supported_target_languages:?}, \
|
||||
need to specify target language)",
|
||||
)));
|
||||
},
|
||||
),
|
||||
Self::NLLB(ref model) => {
|
||||
let source_language = source_language
|
||||
.and_then(Language::get_nllb_code)
|
||||
.map(str::to_string)
|
||||
.ok_or_else(|| RustBertError::ValueError(
|
||||
format!("Missing source language for NLLB. Need to specify one from: {supported_source_languages:?}")
|
||||
))?;
|
||||
|
||||
let target_language = target_language
|
||||
.and_then(Language::get_nllb_code)
|
||||
.map(str::to_string)
|
||||
.map(|code| model._get_tokenizer().convert_tokens_to_ids(&[code])[0])
|
||||
.ok_or_else(|| RustBertError::ValueError(
|
||||
format!("Missing target language for NLLB. Need to specify one from: {supported_target_languages:?}")
|
||||
))?;
|
||||
|
||||
(Some(source_language), Some(target_language))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Interface method to generate() of the particular models.
|
||||
pub fn generate<S>(
|
||||
&self,
|
||||
@ -1415,6 +1252,19 @@ impl TranslationOption {
|
||||
.map(|output| output.text)
|
||||
.collect()
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(ref model) => {
|
||||
let generate_options =
|
||||
forced_bos_token_id.map(|forced_bos_token_id| GenerateOptions {
|
||||
forced_bos_token_id: Some(forced_bos_token_id),
|
||||
..Default::default()
|
||||
});
|
||||
model
|
||||
.generate(prompt_texts, generate_options)
|
||||
.into_iter()
|
||||
.map(|output| output.text)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1441,12 +1291,14 @@ impl TranslationModel {
|
||||
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
|
||||
/// MarianTargetLanguages, MarianVocabResources,
|
||||
/// };
|
||||
/// use rust_bert::pipelines::common::ModelType;
|
||||
/// use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
|
||||
/// use rust_bert::resources::RemoteResource;
|
||||
/// use tch::Device;
|
||||
///
|
||||
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
|
||||
/// let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
/// MarianModelResources::ROMANCE2ENGLISH,
|
||||
/// )));
|
||||
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
|
||||
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
|
||||
/// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
|
||||
@ -1496,12 +1348,14 @@ impl TranslationModel {
|
||||
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
|
||||
/// MarianTargetLanguages, MarianVocabResources,
|
||||
/// };
|
||||
/// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
|
||||
/// use rust_bert::pipelines::common::{ModelResource, ModelType, TokenizerOption};
|
||||
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
|
||||
/// use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||
/// use tch::Device;
|
||||
///
|
||||
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
|
||||
/// let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
/// MarianModelResources::ROMANCE2ENGLISH,
|
||||
/// )));
|
||||
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
|
||||
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
|
||||
/// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
|
||||
@ -1575,12 +1429,14 @@ impl TranslationModel {
|
||||
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
|
||||
/// MarianTargetLanguages, MarianVocabResources,
|
||||
/// };
|
||||
/// use rust_bert::pipelines::common::ModelType;
|
||||
/// use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
/// use rust_bert::resources::RemoteResource;
|
||||
/// use tch::Device;
|
||||
///
|
||||
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ENGLISH2ROMANCE);
|
||||
/// let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
/// MarianModelResources::ENGLISH2ROMANCE,
|
||||
/// )));
|
||||
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ENGLISH2ROMANCE);
|
||||
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ENGLISH2ROMANCE);
|
||||
/// let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ENGLISH2ROMANCE);
|
||||
@ -1616,12 +1472,13 @@ impl TranslationModel {
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id(
|
||||
source_language.into().as_ref(),
|
||||
target_language.into().as_ref(),
|
||||
&self.supported_source_languages,
|
||||
&self.supported_target_languages,
|
||||
)?;
|
||||
let (prefix, forced_bos_token_id) =
|
||||
self.model.get_tokenizer().get_prefix_and_forced_bos_id(
|
||||
source_language.into().as_ref(),
|
||||
target_language.into().as_ref(),
|
||||
&self.supported_source_languages,
|
||||
&self.supported_target_languages,
|
||||
)?;
|
||||
|
||||
Ok(match prefix {
|
||||
Some(value) => {
|
||||
@ -1648,7 +1505,9 @@ mod test {
|
||||
#[test]
|
||||
#[ignore] // no need to run, compilation is enough to verify it is Send
|
||||
fn test() {
|
||||
let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
|
||||
let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
)));
|
||||
let config_resource =
|
||||
RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
|
||||
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
|
||||
|
@ -105,7 +105,7 @@ use crate::deberta::DebertaForSequenceClassification;
|
||||
use crate::distilbert::DistilBertModelClassifier;
|
||||
use crate::longformer::LongformerForSequenceClassification;
|
||||
use crate::mobilebert::MobileBertForSequenceClassification;
|
||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||
use crate::pipelines::common::{ConfigOption, ModelResource, ModelType, TokenizerOption};
|
||||
use crate::pipelines::sequence_classification::Label;
|
||||
use crate::resources::ResourceProvider;
|
||||
use crate::roberta::RobertaForSequenceClassification;
|
||||
@ -113,17 +113,18 @@ use crate::xlnet::XLNetForSequenceClassification;
|
||||
use crate::RustBertError;
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
use rust_tokenizers::TokenizedInput;
|
||||
use std::borrow::Borrow;
|
||||
use std::ops::Deref;
|
||||
use tch::kind::Kind::{Bool, Float};
|
||||
use tch::nn::VarStore;
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::{
|
||||
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
|
||||
resources::RemoteResource,
|
||||
};
|
||||
use std::ops::Deref;
|
||||
use tch::kind::Kind::{Bool, Float};
|
||||
use tch::nn::VarStore;
|
||||
use tch::{no_grad, Device, Kind, Tensor};
|
||||
|
||||
/// # Configuration for ZeroShotClassificationModel
|
||||
/// Contains information regarding the model to load and device to place the model on.
|
||||
@ -131,7 +132,7 @@ pub struct ZeroShotClassificationConfig {
|
||||
/// Model type
|
||||
pub model_type: ModelType,
|
||||
/// Model weights resource (default: pretrained BERT model on CoNLL)
|
||||
pub model_resource: Box<dyn ResourceProvider + Send>,
|
||||
pub model_resource: ModelResource,
|
||||
/// Config resource (default: pretrained BERT model on CoNLL)
|
||||
pub config_resource: Box<dyn ResourceProvider + Send>,
|
||||
/// Vocab resource (default: pretrained BERT model on CoNLL)
|
||||
@ -159,9 +160,9 @@ impl ZeroShotClassificationConfig {
|
||||
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||
/// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||
pub fn new<RM, RC, RV>(
|
||||
pub fn new<RC, RV>(
|
||||
model_type: ModelType,
|
||||
model_resource: RM,
|
||||
model_resource: ModelResource,
|
||||
config_resource: RC,
|
||||
vocab_resource: RV,
|
||||
merges_resource: Option<RV>,
|
||||
@ -170,13 +171,12 @@ impl ZeroShotClassificationConfig {
|
||||
add_prefix_space: impl Into<Option<bool>>,
|
||||
) -> ZeroShotClassificationConfig
|
||||
where
|
||||
RM: ResourceProvider + Send + 'static,
|
||||
RC: ResourceProvider + Send + 'static,
|
||||
RV: ResourceProvider + Send + 'static,
|
||||
{
|
||||
ZeroShotClassificationConfig {
|
||||
model_type,
|
||||
model_resource: Box::new(model_resource),
|
||||
model_resource,
|
||||
config_resource: Box::new(config_resource),
|
||||
vocab_resource: Box::new(vocab_resource),
|
||||
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
|
||||
@ -190,13 +190,13 @@ impl ZeroShotClassificationConfig {
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
impl Default for ZeroShotClassificationConfig {
|
||||
/// Provides a defaultSST-2 sentiment analysis model (English)
|
||||
/// Provides a default zero-shot classification model (English)
|
||||
fn default() -> ZeroShotClassificationConfig {
|
||||
ZeroShotClassificationConfig {
|
||||
model_type: ModelType::Bart,
|
||||
model_resource: Box::new(RemoteResource::from_pretrained(
|
||||
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
BartModelResources::BART_MNLI,
|
||||
)),
|
||||
))),
|
||||
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||
BartConfigResources::BART_MNLI,
|
||||
)),
|
||||
@ -239,30 +239,38 @@ pub enum ZeroShotClassificationOption {
|
||||
XLNet(XLNetForSequenceClassification),
|
||||
/// Longformer for Sequence Classification
|
||||
Longformer(LongformerForSequenceClassification),
|
||||
/// ONNX model for Sequence Classification
|
||||
#[cfg(feature = "onnx")]
|
||||
ONNX(ONNXEncoder),
|
||||
}
|
||||
|
||||
impl ZeroShotClassificationOption {
|
||||
/// Instantiate a new zero shot classification model of the supplied type.
|
||||
/// Instantiate a new zer-shot classification model of the supplied type.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
|
||||
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
||||
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
||||
/// `model_type`)
|
||||
pub fn new<'p, P>(
|
||||
model_type: ModelType,
|
||||
p: P,
|
||||
config: &ConfigOption,
|
||||
) -> Result<Self, RustBertError>
|
||||
where
|
||||
P: Borrow<nn::Path<'p>>,
|
||||
{
|
||||
match model_type {
|
||||
/// * `ZeroShotClassificationConfig` - Zero-shot classification pipeline configuration. The type of model created will be inferred from the
|
||||
/// `ModelResources` (Torch or ONNX) and `ModelType` (Architecture for Torch models) variants provided and
|
||||
pub fn new(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
|
||||
match config.model_resource {
|
||||
ModelResource::Torch(_) => Self::new_torch(config),
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelResource::ONNX(_) => Self::new_onnx(config),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_torch(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
|
||||
let device = config.device;
|
||||
let weights_path = config.model_resource.get_torch_local_path()?;
|
||||
let mut var_store = VarStore::new(device);
|
||||
let model_config =
|
||||
&ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
|
||||
let model_type = config.model_type;
|
||||
let model = match model_type {
|
||||
ModelType::Bart => {
|
||||
if let ConfigOption::Bart(config) = config {
|
||||
Ok(ZeroShotClassificationOption::Bart(
|
||||
BartForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Bart(config) = model_config {
|
||||
Ok(Self::Bart(
|
||||
BartForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -271,9 +279,9 @@ impl ZeroShotClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Deberta => {
|
||||
if let ConfigOption::Deberta(config) = config {
|
||||
Ok(ZeroShotClassificationOption::Deberta(
|
||||
DebertaForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Deberta(config) = model_config {
|
||||
Ok(Self::Deberta(
|
||||
DebertaForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -282,9 +290,9 @@ impl ZeroShotClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Bert => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
Ok(ZeroShotClassificationOption::Bert(
|
||||
BertForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Bert(config) = model_config {
|
||||
Ok(Self::Bert(
|
||||
BertForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -293,9 +301,9 @@ impl ZeroShotClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::DistilBert => {
|
||||
if let ConfigOption::DistilBert(config) = config {
|
||||
Ok(ZeroShotClassificationOption::DistilBert(
|
||||
DistilBertModelClassifier::new(p, config)?,
|
||||
if let ConfigOption::DistilBert(config) = model_config {
|
||||
Ok(Self::DistilBert(
|
||||
DistilBertModelClassifier::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -304,9 +312,9 @@ impl ZeroShotClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::MobileBert => {
|
||||
if let ConfigOption::MobileBert(config) = config {
|
||||
Ok(ZeroShotClassificationOption::MobileBert(
|
||||
MobileBertForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::MobileBert(config) = model_config {
|
||||
Ok(Self::MobileBert(
|
||||
MobileBertForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -315,9 +323,9 @@ impl ZeroShotClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Roberta => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
Ok(ZeroShotClassificationOption::Roberta(
|
||||
RobertaForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Bert(config) = model_config {
|
||||
Ok(Self::Roberta(
|
||||
RobertaForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -326,9 +334,9 @@ impl ZeroShotClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::XLMRoberta => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
Ok(ZeroShotClassificationOption::XLMRoberta(
|
||||
RobertaForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Bert(config) = model_config {
|
||||
Ok(Self::XLMRoberta(
|
||||
RobertaForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -337,9 +345,9 @@ impl ZeroShotClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Albert => {
|
||||
if let ConfigOption::Albert(config) = config {
|
||||
Ok(ZeroShotClassificationOption::Albert(
|
||||
AlbertForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Albert(config) = model_config {
|
||||
Ok(Self::Albert(
|
||||
AlbertForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -348,9 +356,9 @@ impl ZeroShotClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::XLNet => {
|
||||
if let ConfigOption::XLNet(config) = config {
|
||||
Ok(ZeroShotClassificationOption::XLNet(
|
||||
XLNetForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::XLNet(config) = model_config {
|
||||
Ok(Self::XLNet(
|
||||
XLNetForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -359,9 +367,9 @@ impl ZeroShotClassificationOption {
|
||||
}
|
||||
}
|
||||
ModelType::Longformer => {
|
||||
if let ConfigOption::Longformer(config) = config {
|
||||
Ok(ZeroShotClassificationOption::Longformer(
|
||||
LongformerForSequenceClassification::new(p, config)?,
|
||||
if let ConfigOption::Longformer(config) = model_config {
|
||||
Ok(Self::Longformer(
|
||||
LongformerForSequenceClassification::new(var_store.root(), config)?,
|
||||
))
|
||||
} else {
|
||||
Err(RustBertError::InvalidConfigurationError(
|
||||
@ -369,10 +377,36 @@ impl ZeroShotClassificationOption {
|
||||
))
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
|
||||
"A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
|
||||
)),
|
||||
_ => Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Zero shot classification not implemented for {model_type:?}!",
|
||||
))),
|
||||
}
|
||||
}?;
|
||||
var_store.load(weights_path)?;
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn new_onnx(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
|
||||
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
|
||||
let environment = onnx_config.get_environment()?;
|
||||
let encoder_file = config
|
||||
.model_resource
|
||||
.get_onnx_local_paths()?
|
||||
.encoder_path
|
||||
.ok_or(RustBertError::InvalidConfigurationError(
|
||||
"An encoder file must be provided for zero-shot classification ONNX models."
|
||||
.to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self::ONNX(ONNXEncoder::new(
|
||||
encoder_file,
|
||||
&environment,
|
||||
&onnx_config,
|
||||
)?))
|
||||
}
|
||||
|
||||
/// Returns the `ModelType` for this SequenceClassificationOption
|
||||
@ -388,6 +422,8 @@ impl ZeroShotClassificationOption {
|
||||
Self::Albert(_) => ModelType::Albert,
|
||||
Self::XLNet(_) => ModelType::XLNet,
|
||||
Self::Longformer(_) => ModelType::Longformer,
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(_) => ModelType::ONNX,
|
||||
}
|
||||
}
|
||||
|
||||
@ -503,6 +539,18 @@ impl ZeroShotClassificationOption {
|
||||
.expect("Error in Longformer forward pass.")
|
||||
.logits
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
Self::ONNX(ref model) => model
|
||||
.forward(
|
||||
input_ids,
|
||||
mask.map(|tensor| tensor.to_kind(Kind::Int64)).as_ref(),
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
)
|
||||
.expect("Error in ONNX forward pass.")
|
||||
.logits
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -530,7 +578,7 @@ pub type ZeroShotTemplate = Box<dyn Fn(&str) -> String>;
|
||||
pub struct ZeroShotClassificationModel {
|
||||
tokenizer: TokenizerOption,
|
||||
zero_shot_classifier: ZeroShotClassificationOption,
|
||||
var_store: VarStore,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl ZeroShotClassificationModel {
|
||||
@ -600,18 +648,13 @@ impl ZeroShotClassificationModel {
|
||||
config: ZeroShotClassificationConfig,
|
||||
tokenizer: TokenizerOption,
|
||||
) -> Result<ZeroShotClassificationModel, RustBertError> {
|
||||
let config_path = config.config_resource.get_local_path()?;
|
||||
let device = config.device;
|
||||
let zero_shot_classifier = ZeroShotClassificationOption::new(&config)?;
|
||||
|
||||
let mut var_store = VarStore::new(device);
|
||||
let model_config = ConfigOption::from_file(config.model_type, config_path);
|
||||
let zero_shot_classifier =
|
||||
ZeroShotClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
|
||||
crate::resources::load_weights(&config.model_resource, &mut var_store)?;
|
||||
Ok(ZeroShotClassificationModel {
|
||||
tokenizer,
|
||||
zero_shot_classifier,
|
||||
var_store,
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
@ -631,7 +674,7 @@ impl ZeroShotClassificationModel {
|
||||
labels: T,
|
||||
template: Option<ZeroShotTemplate>,
|
||||
max_len: usize,
|
||||
) -> Result<(Tensor, Tensor), RustBertError>
|
||||
) -> Result<(Tensor, Tensor, Tensor), RustBertError>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
T: AsRef<[&'a str]>,
|
||||
@ -659,7 +702,7 @@ impl ZeroShotClassificationModel {
|
||||
})
|
||||
.collect::<Vec<(&str, &str)>>();
|
||||
|
||||
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_pair_list(
|
||||
let mut tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_pair_list(
|
||||
text_pair_list.as_ref(),
|
||||
max_len,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
@ -675,25 +718,35 @@ impl ZeroShotClassificationModel {
|
||||
.tokenizer
|
||||
.get_pad_id()
|
||||
.expect("The Tokenizer used for sequence classification should contain a PAD id");
|
||||
let tokenized_input_tensors = tokenized_input
|
||||
.into_iter()
|
||||
.map(|mut input| {
|
||||
let input_ids = tokenized_input
|
||||
.iter_mut()
|
||||
.map(|input| {
|
||||
input.token_ids.resize(max_len, pad_id);
|
||||
Tensor::from_slice(&(input.token_ids))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let token_type_ids = tokenized_input
|
||||
.iter_mut()
|
||||
.map(|input| {
|
||||
input
|
||||
.segment_ids
|
||||
.resize(max_len, *input.segment_ids.last().unwrap_or(&0));
|
||||
Tensor::from_slice(&(input.segment_ids))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let tokenized_input_tensors =
|
||||
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device());
|
||||
|
||||
let mask = tokenized_input_tensors
|
||||
let input_ids = Tensor::stack(input_ids.as_slice(), 0).to(self.device);
|
||||
let token_type_ids = Tensor::stack(token_type_ids.as_slice(), 0)
|
||||
.to(self.device)
|
||||
.to_kind(Kind::Int64);
|
||||
let mask = input_ids
|
||||
.ne(self
|
||||
.tokenizer
|
||||
.get_pad_id()
|
||||
.expect("The Tokenizer used for zero shot classification should contain a PAD id"))
|
||||
.to_kind(Bool);
|
||||
|
||||
Ok((tokenized_input_tensors, mask))
|
||||
Ok((input_ids, mask, token_type_ids))
|
||||
}
|
||||
|
||||
/// Zero shot classification with 1 (and exactly 1) true label.
|
||||
@ -762,14 +815,14 @@ impl ZeroShotClassificationModel {
|
||||
T: AsRef<[&'a str]>,
|
||||
{
|
||||
let num_inputs = inputs.as_ref().len();
|
||||
let (input_tensor, mask) =
|
||||
let (input_tensor, mask, token_type_ids) =
|
||||
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
|
||||
|
||||
let output = no_grad(|| {
|
||||
let output = self.zero_shot_classifier.forward_t(
|
||||
Some(&input_tensor),
|
||||
Some(&mask),
|
||||
None,
|
||||
Some(&token_type_ids),
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
@ -904,14 +957,14 @@ impl ZeroShotClassificationModel {
|
||||
T: AsRef<[&'a str]>,
|
||||
{
|
||||
let num_inputs = inputs.as_ref().len();
|
||||
let (input_tensor, mask) =
|
||||
let (input_tensor, mask, token_type_ids) =
|
||||
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
|
||||
|
||||
let output = no_grad(|| {
|
||||
let output = self.zero_shot_classifier.forward_t(
|
||||
Some(&input_tensor),
|
||||
Some(&mask),
|
||||
None,
|
||||
Some(&token_type_ids),
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
|
@ -24,6 +24,7 @@
|
||||
//! use tch::Device;
|
||||
//!
|
||||
//! fn main() -> anyhow::Result<()> {
|
||||
//! use rust_bert::pipelines::common::ModelResource;
|
||||
//! let config_resource = Box::new(RemoteResource::from_pretrained(
|
||||
//! ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
|
||||
//! ));
|
||||
@ -36,7 +37,7 @@
|
||||
//!
|
||||
//! let summarization_config = SummarizationConfig {
|
||||
//! model_type: ModelType::ProphetNet,
|
||||
//! model_resource: weights_resource,
|
||||
//! model_resource: ModelResource::Torch(weights_resource),
|
||||
//! config_resource,
|
||||
//! vocab_resource,
|
||||
//! merges_resource: None,
|
||||
|
@ -13,9 +13,8 @@
|
||||
use std::borrow::Borrow;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tch::{nn, Kind, Tensor};
|
||||
use tch::{nn, Device, Kind, Tensor};
|
||||
|
||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||
use crate::pipelines::generation_utils::private_generation_utils::{
|
||||
@ -83,7 +82,7 @@ pub struct ProphetNetConfig {
|
||||
pub activation_dropout: f64,
|
||||
pub attention_dropout: f64,
|
||||
pub decoder_ffn_dim: i64,
|
||||
pub decoder_start_token_id: i64,
|
||||
pub decoder_start_token_id: Option<i64>,
|
||||
pub disable_ngram_loss: bool,
|
||||
pub dropout: f64,
|
||||
pub encoder_ffn_dim: i64,
|
||||
@ -94,6 +93,8 @@ pub struct ProphetNetConfig {
|
||||
pub max_position_embeddings: i64,
|
||||
pub bos_token_id: i64,
|
||||
pub eos_token_id: i64,
|
||||
pub forced_bos_token_id: Option<i64>,
|
||||
pub forced_eos_token_id: Option<i64>,
|
||||
pub ngram: i64,
|
||||
pub id2label: Option<HashMap<i64, String>>,
|
||||
pub label2id: Option<HashMap<String, i64>>,
|
||||
@ -120,7 +121,7 @@ impl Default for ProphetNetConfig {
|
||||
activation_dropout: 0.1,
|
||||
attention_dropout: 0.1,
|
||||
decoder_ffn_dim: 4096,
|
||||
decoder_start_token_id: 0,
|
||||
decoder_start_token_id: Some(0),
|
||||
disable_ngram_loss: false,
|
||||
dropout: 0.1,
|
||||
encoder_ffn_dim: 4096,
|
||||
@ -131,6 +132,8 @@ impl Default for ProphetNetConfig {
|
||||
max_position_embeddings: 512,
|
||||
bos_token_id: 1,
|
||||
eos_token_id: 2,
|
||||
forced_bos_token_id: None,
|
||||
forced_eos_token_id: None,
|
||||
ngram: 2,
|
||||
id2label: None,
|
||||
label2id: None,
|
||||
@ -381,7 +384,11 @@ impl ProphetNetForConditionalGeneration {
|
||||
linear_config,
|
||||
);
|
||||
|
||||
let decoder_start_token_id = config.decoder_start_token_id;
|
||||
let decoder_start_token_id = config.decoder_start_token_id.ok_or_else(|| {
|
||||
RustBertError::InvalidConfigurationError(
|
||||
"`decoder_start_token_id` must be provided for ProphetNet models".to_string(),
|
||||
)
|
||||
})?;
|
||||
let pad_token_id = config.pad_token_id;
|
||||
let ngram = config.ngram;
|
||||
|
||||
@ -919,7 +926,7 @@ impl ProphetNetConditionalGenerator {
|
||||
let pad_token_id = Some(config.pad_token_id);
|
||||
let vocab_size = config.vocab_size;
|
||||
let is_encoder_decoder = true;
|
||||
let decoder_start_id = Some(config.decoder_start_token_id);
|
||||
let decoder_start_id = config.decoder_start_token_id;
|
||||
let max_position_embeddings = config.max_position_embeddings;
|
||||
|
||||
Ok(ProphetNetConditionalGenerator {
|
||||
@ -945,11 +952,11 @@ impl PrivateLanguageGenerator for ProphetNetConditionalGenerator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -972,8 +979,8 @@ impl PrivateLanguageGenerator for ProphetNetConditionalGenerator {
|
||||
fn get_decoder_start_id(&self) -> Option<i64> {
|
||||
self.decoder_start_id
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
@ -1060,48 +1067,6 @@ impl PrivateLanguageGenerator for ProphetNetConditionalGenerator {
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: &[S],
|
||||
max_len: Option<i64>,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text,
|
||||
max_len
|
||||
.map(|max_len| max_len as usize)
|
||||
.unwrap_or(usize::MAX),
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|tokenized_input| tokenized_input.token_ids)
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
.into_iter()
|
||||
.map(|mut input| {
|
||||
let temp = vec![pad_token; max_len - input.len()];
|
||||
input.extend(temp);
|
||||
input
|
||||
})
|
||||
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
|
||||
.collect::<Vec<Tensor>>();
|
||||
|
||||
Tensor::stack(&token_ids, 0)
|
||||
}
|
||||
|
||||
fn reorder_cache(
|
||||
&self,
|
||||
past: &mut Cache,
|
||||
|
@ -79,6 +79,8 @@ pub struct ReformerConfig {
|
||||
pub chunk_size_feed_forward: Option<i64>,
|
||||
pub eos_token_id: i64,
|
||||
pub pad_token_id: i64,
|
||||
pub forced_bos_token_id: Option<i64>,
|
||||
pub forced_eos_token_id: Option<i64>,
|
||||
pub feed_forward_size: i64,
|
||||
pub hash_seed: Option<i64>,
|
||||
pub hidden_act: Activation,
|
||||
@ -106,6 +108,7 @@ pub struct ReformerConfig {
|
||||
pub label2id: Option<HashMap<String, i64>>,
|
||||
pub output_attentions: Option<bool>,
|
||||
pub output_hidden_states: Option<bool>,
|
||||
pub decoder_start_token_id: Option<i64>,
|
||||
}
|
||||
|
||||
impl Config for ReformerConfig {}
|
||||
@ -130,6 +133,8 @@ impl Default for ReformerConfig {
|
||||
chunk_size_feed_forward: None,
|
||||
eos_token_id: 2,
|
||||
pad_token_id: 0,
|
||||
forced_bos_token_id: None,
|
||||
forced_eos_token_id: None,
|
||||
feed_forward_size: 512,
|
||||
hash_seed: None,
|
||||
hidden_act: Activation::gelu,
|
||||
@ -157,6 +162,7 @@ impl Default for ReformerConfig {
|
||||
label2id: None,
|
||||
output_attentions: None,
|
||||
output_hidden_states: None,
|
||||
decoder_start_token_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1057,7 +1063,7 @@ impl ReformerGenerator {
|
||||
let pad_token_id = Some(config.pad_token_id);
|
||||
let vocab_size = config.vocab_size;
|
||||
let is_encoder_decoder = false;
|
||||
let decoder_start_id = None;
|
||||
let decoder_start_id = config.decoder_start_token_id;
|
||||
let max_position_embeddings = config.max_position_embeddings;
|
||||
|
||||
Ok(ReformerGenerator {
|
||||
@ -1083,11 +1089,11 @@ impl PrivateLanguageGenerator for ReformerGenerator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -1110,8 +1116,8 @@ impl PrivateLanguageGenerator for ReformerGenerator {
|
||||
fn get_decoder_start_id(&self) -> Option<i64> {
|
||||
self.decoder_start_id
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
|
@ -42,7 +42,7 @@ impl RobertaModelResources {
|
||||
/// Shared under Apache 2.0 license by the Hugging Face Inc. team at <https://huggingface.co/distilroberta-base>. Modified with conversion to C-array format.
|
||||
pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
|
||||
"distilroberta-base/model",
|
||||
"https://cdn.huggingface.co/distilroberta-base-rust_model.ot",
|
||||
"https://huggingface.co/distilroberta-base/resolve/main/rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at <https://huggingface.co/deepset/roberta-base-squad2>. Modified with conversion to C-array format.
|
||||
pub const ROBERTA_QA: (&'static str, &'static str) = (
|
||||
|
@ -12,10 +12,9 @@
|
||||
|
||||
use std::borrow::Borrow;
|
||||
|
||||
use rust_tokenizers::tokenizer::TruncationStrategy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tch::nn::{embedding, LinearConfig};
|
||||
use tch::{nn, Tensor};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
use crate::pipelines::common::{ModelType, TokenizerOption};
|
||||
use crate::pipelines::generation_utils::private_generation_utils::{
|
||||
@ -132,6 +131,8 @@ pub struct T5Config {
|
||||
pub decoder_start_token_id: Option<i64>,
|
||||
pub bos_token_id: Option<i64>,
|
||||
pub eos_token_id: Option<i64>,
|
||||
pub forced_bos_token_id: Option<i64>,
|
||||
pub forced_eos_token_id: Option<i64>,
|
||||
pub initializer_factor: f64,
|
||||
pub is_encoder_decoder: Option<bool>,
|
||||
pub layer_norm_epsilon: f64,
|
||||
@ -210,6 +211,8 @@ impl Default for T5Config {
|
||||
decoder_start_token_id: None,
|
||||
bos_token_id: None,
|
||||
eos_token_id: Some(1),
|
||||
forced_bos_token_id: None,
|
||||
forced_eos_token_id: None,
|
||||
initializer_factor: 1.0,
|
||||
is_encoder_decoder: None,
|
||||
layer_norm_epsilon: 1e-6,
|
||||
@ -770,7 +773,7 @@ impl T5Generator {
|
||||
let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
|
||||
let vocab_size = config.vocab_size;
|
||||
let is_encoder_decoder = true;
|
||||
let decoder_start_id = Some(0);
|
||||
let decoder_start_id = config.decoder_start_token_id;
|
||||
// T5 do not have an embedding matrix for position IDs and relies on relative positions instead
|
||||
let max_position_embeddings = i64::MAX;
|
||||
|
||||
@ -797,11 +800,11 @@ impl PrivateLanguageGenerator for T5Generator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -824,8 +827,8 @@ impl PrivateLanguageGenerator for T5Generator {
|
||||
fn get_decoder_start_id(&self) -> Option<i64> {
|
||||
self.decoder_start_id
|
||||
}
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
fn forward_t(
|
||||
&self,
|
||||
@ -906,48 +909,6 @@ impl PrivateLanguageGenerator for T5Generator {
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: &[S],
|
||||
max_len: Option<i64>,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text,
|
||||
max_len
|
||||
.map(|max_len| max_len as usize)
|
||||
.unwrap_or(usize::MAX),
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|tokenized_input| tokenized_input.token_ids)
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
.into_iter()
|
||||
.map(|mut input| {
|
||||
let temp = vec![pad_token; max_len - input.len()];
|
||||
input.extend(temp);
|
||||
input
|
||||
})
|
||||
.map(|tokens| Tensor::from_slice(&tokens).to(self.get_var_store().device()))
|
||||
.collect::<Vec<Tensor>>();
|
||||
|
||||
Tensor::stack(&token_ids, 0)
|
||||
}
|
||||
|
||||
fn reorder_cache(
|
||||
&self,
|
||||
past: &mut Cache,
|
||||
|
@ -19,7 +19,7 @@
|
||||
//!
|
||||
//! ```no_run
|
||||
//! # fn main() -> anyhow::Result<()> {
|
||||
//! use rust_bert::pipelines::common::ModelType;
|
||||
//! use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
//! use rust_bert::pipelines::generation_utils::LanguageGenerator;
|
||||
//! use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
//! use rust_bert::resources::RemoteResource;
|
||||
@ -35,7 +35,7 @@
|
||||
//! ));
|
||||
//! let generate_config = TextGenerationConfig {
|
||||
//! model_type: ModelType::XLNet,
|
||||
//! model_resource,
|
||||
//! model_resource: ModelResource::Torch(model_resource),
|
||||
//! config_resource,
|
||||
//! vocab_resource,
|
||||
//! merges_resource: None,
|
||||
|
@ -1594,11 +1594,11 @@ impl PrivateLanguageGenerator for XLNetGenerator {
|
||||
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
|
||||
&mut self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
&self.var_store
|
||||
fn get_device(&self) -> Device {
|
||||
self.var_store.device()
|
||||
}
|
||||
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
|
||||
&mut self.var_store
|
||||
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
|
||||
Ok(&mut self.var_store)
|
||||
}
|
||||
fn get_config(&self) -> &GenerateConfig {
|
||||
&self.generate_config
|
||||
@ -1622,8 +1622,8 @@ impl PrivateLanguageGenerator for XLNetGenerator {
|
||||
self.decoder_start_id
|
||||
}
|
||||
|
||||
fn get_max_positions_embeddings(&self) -> i64 {
|
||||
self.max_position_embeddings
|
||||
fn get_max_positions_embeddings(&self) -> Option<i64> {
|
||||
Some(self.max_position_embeddings)
|
||||
}
|
||||
|
||||
fn forward_t(
|
||||
|
@ -2,6 +2,7 @@ use rust_bert::bart::{
|
||||
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
|
||||
BartVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelResource;
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::pipelines::zero_shot_classification::{
|
||||
ZeroShotClassificationConfig, ZeroShotClassificationModel,
|
||||
@ -90,7 +91,7 @@ fn bart_summarization_greedy() -> anyhow::Result<()> {
|
||||
BartModelResources::DISTILBART_CNN_6_6,
|
||||
));
|
||||
let summarization_config = SummarizationConfig {
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -151,7 +152,7 @@ fn bart_summarization_beam_search() -> anyhow::Result<()> {
|
||||
BartModelResources::DISTILBART_CNN_6_6,
|
||||
));
|
||||
let summarization_config = SummarizationConfig {
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
|
@ -6,7 +6,7 @@ use rust_bert::bert::{
|
||||
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
|
||||
BertModelResources, BertVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
|
||||
use rust_bert::pipelines::ner::NERModel;
|
||||
use rust_bert::pipelines::question_answering::{
|
||||
@ -106,7 +106,9 @@ fn bert_masked_lm_pipeline() -> anyhow::Result<()> {
|
||||
// Set-up model
|
||||
let config = MaskedLanguageConfig::new(
|
||||
ModelType::Bert,
|
||||
RemoteResource::from_pretrained(BertModelResources::BERT),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT,
|
||||
))),
|
||||
RemoteResource::from_pretrained(BertConfigResources::BERT),
|
||||
RemoteResource::from_pretrained(BertVocabResources::BERT),
|
||||
None,
|
||||
@ -452,7 +454,9 @@ fn bert_question_answering() -> anyhow::Result<()> {
|
||||
// Set-up question answering model
|
||||
let config = QuestionAnsweringConfig {
|
||||
model_type: ModelType::Bert,
|
||||
model_resource: Box::new(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
|
||||
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT_QA,
|
||||
))),
|
||||
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||
BertConfigResources::BERT_QA,
|
||||
)),
|
||||
|
@ -5,7 +5,7 @@ use rust_bert::fnet::{
|
||||
FNetConfig, FNetConfigResources, FNetForMaskedLM, FNetForMultipleChoice,
|
||||
FNetForQuestionAnswering, FNetForTokenClassification, FNetModelResources, FNetVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel, SentimentPolarity};
|
||||
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||
use rust_bert::Config;
|
||||
@ -76,9 +76,8 @@ fn fnet_masked_lm() -> anyhow::Result<()> {
|
||||
|
||||
assert_eq!("▁one", word_1);
|
||||
assert_eq!("▁the", word_2);
|
||||
let value = (f64::try_from(model_output.prediction_scores.get(0).get(4).max()).unwrap()
|
||||
- 13.1721)
|
||||
.abs();
|
||||
let value =
|
||||
(f64::try_from(model_output.prediction_scores.get(0).get(4).max())? - 13.1721).abs();
|
||||
dbg!(value);
|
||||
assert!(value < 1e-3);
|
||||
Ok(())
|
||||
@ -93,9 +92,9 @@ fn fnet_for_sequence_classification() -> anyhow::Result<()> {
|
||||
let vocab_resource = Box::new(RemoteResource::from_pretrained(
|
||||
FNetVocabResources::BASE_SST2,
|
||||
));
|
||||
let model_resource = Box::new(RemoteResource::from_pretrained(
|
||||
let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
FNetModelResources::BASE_SST2,
|
||||
));
|
||||
)));
|
||||
|
||||
let sentiment_config = SentimentConfig {
|
||||
model_type: ModelType::FNet,
|
||||
|
@ -2,7 +2,7 @@ use rust_bert::gpt2::{
|
||||
GPT2Generator, GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources,
|
||||
Gpt2ModelResources, Gpt2VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::conversation::{
|
||||
ConversationConfig, ConversationManager, ConversationModel,
|
||||
};
|
||||
@ -107,7 +107,7 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
|
||||
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::GPT2,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -139,7 +139,7 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
|
||||
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::GPT2,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -183,7 +183,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
|
||||
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::GPT2,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -240,7 +240,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
|
||||
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::GPT2,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -296,7 +296,7 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()
|
||||
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::GPT2,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -369,7 +369,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
|
||||
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: Some(56),
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -419,7 +419,7 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
|
||||
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: Some(36),
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -485,7 +485,7 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
|
||||
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: Some(36),
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -566,7 +566,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
|
||||
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: Some(32),
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -616,7 +616,7 @@ fn gpt2_greedy_token_scores() -> anyhow::Result<()> {
|
||||
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: Some(16),
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -672,7 +672,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> {
|
||||
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: Some(16),
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
|
@ -7,6 +7,7 @@ use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer};
|
||||
use rust_tokenizers::vocab::Vocab;
|
||||
use std::convert::TryFrom;
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
/// Equivalent Python code:
|
||||
@ -105,7 +106,7 @@ fn gpt_j_correctness() -> anyhow::Result<()> {
|
||||
Tensor::from_slice(
|
||||
&input
|
||||
.iter()
|
||||
.map(|&e| i64::from(e != pad_token))
|
||||
.map(|&e| i64::try_from(e != pad_token).unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.to(device)
|
||||
|
@ -2,7 +2,7 @@ use rust_bert::gpt_neo::{
|
||||
GptNeoConfig, GptNeoConfigResources, GptNeoForCausalLM, GptNeoMergesResources,
|
||||
GptNeoModelResources, GptNeoVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||
use rust_bert::Config;
|
||||
@ -125,7 +125,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> {
|
||||
// Set-up model
|
||||
let generation_config = TextGenerationConfig {
|
||||
model_type: ModelType::GPTNeo,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
|
@ -7,7 +7,7 @@ use rust_bert::longformer::{
|
||||
LongformerForTokenClassification, LongformerMergesResources, LongformerModelResources,
|
||||
LongformerVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::question_answering::{
|
||||
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||
};
|
||||
@ -117,12 +117,12 @@ fn longformer_masked_lm() -> anyhow::Result<()> {
|
||||
.prediction_scores
|
||||
.get(0)
|
||||
.get(4)
|
||||
.double_value(&[i64::try_from(&index_1).unwrap()]);
|
||||
.double_value(&[i64::try_from(&index_1)?]);
|
||||
let score_2 = model_output
|
||||
.prediction_scores
|
||||
.get(1)
|
||||
.get(7)
|
||||
.double_value(&[i64::try_from(&index_2).unwrap()]);
|
||||
.double_value(&[i64::try_from(&index_2)?]);
|
||||
|
||||
assert_eq!("Ġeye", word_1); // Outputs "person" : "Looks like one [eye] is missing"
|
||||
assert_eq!("Ġsunny", word_2); // Outputs "pear" : "It was a nice and [sunny] day"
|
||||
@ -384,7 +384,9 @@ fn longformer_for_question_answering() -> anyhow::Result<()> {
|
||||
// Set-up Question Answering model
|
||||
let config = QuestionAnsweringConfig::new(
|
||||
ModelType::Longformer,
|
||||
RemoteResource::from_pretrained(LongformerModelResources::LONGFORMER_BASE_SQUAD1),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
LongformerModelResources::LONGFORMER_BASE_SQUAD1,
|
||||
))),
|
||||
RemoteResource::from_pretrained(LongformerConfigResources::LONGFORMER_BASE_SQUAD1),
|
||||
RemoteResource::from_pretrained(LongformerVocabResources::LONGFORMER_BASE_SQUAD1),
|
||||
Some(RemoteResource::from_pretrained(
|
||||
|
@ -1,5 +1,5 @@
|
||||
use rust_bert::longt5::{LongT5ConfigResources, LongT5ModelResources, LongT5VocabResources};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
|
||||
@ -8,9 +8,9 @@ fn test_summarization_longt5() -> anyhow::Result<()> {
|
||||
// Set-up translation model
|
||||
let summarization_config = SummarizationConfig {
|
||||
model_type: ModelType::LongT5,
|
||||
model_resource: Box::new(RemoteResource::from_pretrained(
|
||||
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
LongT5ModelResources::TGLOBAL_BASE_BOOK_SUMMARY,
|
||||
)),
|
||||
))),
|
||||
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||
LongT5ConfigResources::TGLOBAL_BASE_BOOK_SUMMARY,
|
||||
)),
|
||||
|
@ -2,7 +2,7 @@ use rust_bert::m2m_100::{
|
||||
M2M100Config, M2M100ConfigResources, M2M100MergesResources, M2M100Model, M2M100ModelResources,
|
||||
M2M100SourceLanguages, M2M100TargetLanguages, M2M100VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||
use rust_bert::Config;
|
||||
@ -78,7 +78,7 @@ fn m2m100_translation() -> anyhow::Result<()> {
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::M2M100,
|
||||
model_resource,
|
||||
ModelResource::Torch(Box::new(model_resource)),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
Some(merges_resource),
|
||||
|
@ -2,7 +2,7 @@ use rust_bert::marian::{
|
||||
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
|
||||
MarianTargetLanguages, MarianVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::translation::{
|
||||
Language, TranslationConfig, TranslationModel, TranslationModelBuilder,
|
||||
};
|
||||
@ -23,7 +23,7 @@ fn test_translation() -> anyhow::Result<()> {
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::Marian,
|
||||
model_resource,
|
||||
ModelResource::Torch(Box::new(model_resource)),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
Some(merges_resource),
|
||||
|
@ -74,12 +74,12 @@ fn mobilebert_masked_model() -> anyhow::Result<()> {
|
||||
.logits
|
||||
.get(0)
|
||||
.get(4)
|
||||
.double_value(&[i64::try_from(&index_1).unwrap()]);
|
||||
.double_value(&[i64::try_from(&index_1)?]);
|
||||
let score_2 = model_output
|
||||
.logits
|
||||
.get(1)
|
||||
.get(7)
|
||||
.double_value(&[i64::try_from(&index_2).unwrap()]);
|
||||
.double_value(&[i64::try_from(&index_2)?]);
|
||||
|
||||
assert_eq!("thing", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
||||
assert_eq!("sunny", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"
|
||||
|
@ -1,14 +1,17 @@
|
||||
use rust_bert::nllb::{
|
||||
NLLBConfigResources, NLLBLanguages, NLLBMergeResources, NLLBResources, NLLBVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use tch::Device;
|
||||
|
||||
#[test]
|
||||
// #[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn nllb_translation() -> anyhow::Result<()> {
|
||||
let model_resource = RemoteResource::from_pretrained(NLLBResources::NLLB_600M_DISTILLED);
|
||||
let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
NLLBResources::NLLB_600M_DISTILLED,
|
||||
)));
|
||||
let config_resource = RemoteResource::from_pretrained(NLLBConfigResources::NLLB_600M_DISTILLED);
|
||||
let vocab_resource = RemoteResource::from_pretrained(NLLBVocabResources::NLLB_600M_DISTILLED);
|
||||
let merges_resource = RemoteResource::from_pretrained(NLLBMergeResources::NLLB_600M_DISTILLED);
|
||||
|
295
tests/onnx.rs
Normal file
295
tests/onnx.rs
Normal file
@ -0,0 +1,295 @@
|
||||
#[cfg(feature = "onnx")]
|
||||
mod tests {
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::m2m_100::{M2M100SourceLanguages, M2M100TargetLanguages};
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType, ONNXModelResources};
|
||||
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
|
||||
use rust_bert::pipelines::ner::NERModel;
|
||||
use rust_bert::pipelines::question_answering::{
|
||||
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||
};
|
||||
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
|
||||
use rust_bert::pipelines::sequence_classification::SequenceClassificationConfig;
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
use rust_bert::pipelines::token_classification::{
|
||||
LabelAggregationOption, TokenClassificationConfig,
|
||||
};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use tch::Device;
|
||||
|
||||
#[test]
|
||||
fn onnx_masked_lm() -> anyhow::Result<()> {
|
||||
let masked_lm = MaskedLanguageModel::new(MaskedLanguageConfig::new(
|
||||
ModelType::Bert,
|
||||
ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/model.onnx",
|
||||
"onnx-bert-base-uncased-for-masked-lm",
|
||||
))),
|
||||
..Default::default()
|
||||
}),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/config.json",
|
||||
"onnx-bert-base-uncased-for-masked-lm",
|
||||
),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-uncased-for-masked-lm/resolve/main/vocab.txt",
|
||||
"onnx-bert-base-uncased-for-masked-lm",
|
||||
),
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
Some(String::from("<mask>")),
|
||||
))?;
|
||||
let input = [
|
||||
"Hello I am a <mask> student",
|
||||
"Paris is the <mask> of France. It is <mask> in Europe.",
|
||||
];
|
||||
let output = masked_lm.predict(input)?;
|
||||
|
||||
assert_eq!(output.len(), 2);
|
||||
assert_eq!(output[0].len(), 1);
|
||||
assert_eq!(output[0][0].text, "university");
|
||||
assert_eq!(output[0][0].id, 2755);
|
||||
assert!((output[0][0].score - 10.0135).abs() < 1e-4);
|
||||
assert_eq!(output[1].len(), 2);
|
||||
assert_eq!(output[1][0].text, "capital");
|
||||
assert_eq!(output[1][0].id, 2364);
|
||||
assert!((output[1][0].score - 19.4008).abs() < 1e-4);
|
||||
assert_eq!(output[1][1].text, "located");
|
||||
assert_eq!(output[1][1].id, 1388);
|
||||
assert!((output[1][1].score - 10.8547).abs() < 1e-4);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onnx_question_answering() -> anyhow::Result<()> {
|
||||
let qa_model = QuestionAnsweringModel::new(QuestionAnsweringConfig::new(
|
||||
ModelType::Roberta,
|
||||
ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/model.onnx",
|
||||
"onnx-roberta-base-squad2",
|
||||
))),
|
||||
..Default::default()
|
||||
}),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/config.json",
|
||||
"onnx-roberta-base-squad2",
|
||||
),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/vocab.json",
|
||||
"onnx-roberta-base-squad2",
|
||||
),
|
||||
Some(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/roberta-base-squad2/resolve/main/merges.txt",
|
||||
"onnx-roberta-base-squad2",
|
||||
)),
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
))?;
|
||||
let question = String::from("Where does Amy live ?");
|
||||
let context = String::from("Amy lives in Amsterdam");
|
||||
let qa_input = QaInput { question, context };
|
||||
|
||||
let output = qa_model.predict(&[qa_input], 1, 32);
|
||||
assert_eq!(output.len(), 1);
|
||||
assert_eq!(output[0].len(), 1);
|
||||
assert_eq!(output[0][0].answer, " Amsterdam");
|
||||
assert!((output[0][0].score - 0.9898).abs() < 1e-4);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onnx_sequence_classification() -> anyhow::Result<()> {
|
||||
let classification_model = SentimentModel::new(SequenceClassificationConfig::new(
|
||||
ModelType::DistilBert,
|
||||
ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/model.onnx",
|
||||
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
|
||||
))),
|
||||
..Default::default()
|
||||
}),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json",
|
||||
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
|
||||
),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/vocab.txt",
|
||||
"onnx-distilbert-base-uncased-finetuned-sst-2-english",
|
||||
),
|
||||
None,
|
||||
true,
|
||||
None,
|
||||
None,
|
||||
))?;
|
||||
let input = [
|
||||
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
||||
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
|
||||
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
||||
];
|
||||
let output = classification_model.predict(input);
|
||||
assert_eq!(output.len(), 3);
|
||||
assert_eq!(output[0].polarity, SentimentPolarity::Positive);
|
||||
assert!((output[0].score - 0.9981).abs() < 1e-4);
|
||||
assert_eq!(output[1].polarity, SentimentPolarity::Negative);
|
||||
assert!((output[1].score - 0.9927).abs() < 1e-4);
|
||||
assert_eq!(output[2].polarity, SentimentPolarity::Positive);
|
||||
assert!((output[2].score - 0.9997).abs() < 1e-4);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onnx_token_classification() -> anyhow::Result<()> {
|
||||
let token_classification_model = NERModel::new(TokenClassificationConfig::new(
|
||||
ModelType::Bert,
|
||||
ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-NER/resolve/main/model.onnx",
|
||||
"onnx-bert-base-NER",
|
||||
))),
|
||||
..Default::default()
|
||||
}),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-NER/resolve/main/config.json",
|
||||
"onnx-bert-base-NER",
|
||||
),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/bert-base-NER/resolve/main/vocab.txt",
|
||||
"onnx-bert-base-NER",
|
||||
),
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
LabelAggregationOption::First,
|
||||
))?;
|
||||
let input = ["Asked John Smith about Acme Corp", "Let's go to New York!"];
|
||||
let output = token_classification_model.predict_full_entities(&input);
|
||||
assert_eq!(output.len(), 2);
|
||||
assert_eq!(output[0].len(), 2);
|
||||
assert_eq!(output[0][0].word, "John Smith");
|
||||
assert_eq!(output[0][0].label, "PER");
|
||||
assert!((output[0][0].score - 0.9992).abs() < 1e-4);
|
||||
assert_eq!(output[0][1].word, "Acme Corp");
|
||||
assert_eq!(output[0][1].label, "ORG");
|
||||
assert!((output[0][1].score - 0.0001).abs() < 1e-4);
|
||||
assert_eq!(output[1].len(), 1);
|
||||
assert_eq!(output[1][0].word, "New York");
|
||||
assert_eq!(output[1][0].label, "LOC");
|
||||
assert!((output[1][0].score - 0.9987).abs() < 1e-4);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onnx_text_generation() -> anyhow::Result<()> {
|
||||
let text_generation_model = TextGenerationModel::new(TextGenerationConfig {
|
||||
model_type: ModelType::GPT2,
|
||||
model_resource: ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: None,
|
||||
decoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/gpt2/resolve/main/decoder_model.onnx",
|
||||
"onnx-gpt2",
|
||||
))),
|
||||
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/gpt2/resolve/main/decoder_with_past_model.onnx",
|
||||
"onnx-gpt2",
|
||||
))),
|
||||
}),
|
||||
config_resource: Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/gpt2/resolve/main/config.json",
|
||||
"onnx-gpt2",
|
||||
)),
|
||||
vocab_resource: Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/gpt2/resolve/main/vocab.json",
|
||||
"onnx-gpt2",
|
||||
)),
|
||||
merges_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/gpt2/resolve/main/merges.txt",
|
||||
"onnx-gpt2",
|
||||
))),
|
||||
max_length: Some(30),
|
||||
do_sample: false,
|
||||
num_beams: 1,
|
||||
temperature: 1.0,
|
||||
num_return_sequences: 1,
|
||||
..Default::default()
|
||||
})?;
|
||||
let prompts = ["It was a very nice and sunny"];
|
||||
let output = text_generation_model.generate(&prompts, None);
|
||||
assert_eq!(output.len(), 1);
|
||||
assert_eq!(output[0], "It was a very nice and sunny day. I was very happy with the weather. I was very happy with the weather. I was very happy with");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn onnx_translation() -> anyhow::Result<()> {
|
||||
let translation_model = TranslationModel::new(TranslationConfig::new(
|
||||
ModelType::M2M100,
|
||||
ModelResource::ONNX(ONNXModelResources {
|
||||
encoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/encoder_model.onnx",
|
||||
"onnx-m2m100_418M",
|
||||
))),
|
||||
decoder_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_model.onnx",
|
||||
"onnx-m2m100_418M",
|
||||
))),
|
||||
decoder_with_past_resource: Some(Box::new(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/decoder_with_past_model.onnx",
|
||||
"onnx-m2m100_418M",
|
||||
))),
|
||||
}),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/config.json",
|
||||
"onnx-m2m100_418M",
|
||||
),
|
||||
RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/vocab.json",
|
||||
"onnx-m2m100_418M",
|
||||
),
|
||||
Some(RemoteResource::new(
|
||||
"https://huggingface.co/optimum/m2m100_418M/resolve/main/sentencepiece.bpe.model",
|
||||
"onnx-m2m100_418M",
|
||||
)),
|
||||
M2M100SourceLanguages::M2M100_418M,
|
||||
M2M100TargetLanguages::M2M100_418M,
|
||||
Device::cuda_if_available(),
|
||||
))?;
|
||||
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(translation_model.translate(
|
||||
&[source_sentence],
|
||||
Language::English,
|
||||
Language::French,
|
||||
)?);
|
||||
outputs.extend(translation_model.translate(
|
||||
&[source_sentence],
|
||||
Language::English,
|
||||
Language::Spanish,
|
||||
)?);
|
||||
outputs.extend(translation_model.translate(
|
||||
&[source_sentence],
|
||||
Language::English,
|
||||
Language::Hindi,
|
||||
)?);
|
||||
|
||||
assert_eq!(outputs.len(), 3);
|
||||
assert_eq!(
|
||||
outputs[0],
|
||||
" Cette phrase sera traduite en plusieurs langues."
|
||||
);
|
||||
assert_eq!(outputs[1], " Esta frase se traducirá en varios idiomas.");
|
||||
assert_eq!(outputs[2], " यह वाक्यांश कई भाषाओं में अनुवादित किया जाएगा।");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -2,7 +2,7 @@ use rust_bert::openai_gpt::{
|
||||
OpenAIGPTLMHeadModel, OpenAiGptConfig, OpenAiGptConfigResources, OpenAiGptMergesResources,
|
||||
OpenAiGptModelResources, OpenAiGptVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::generation_utils::Cache;
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||
@ -119,7 +119,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
|
||||
// Set-up model
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::OpenAiGpt,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -161,7 +161,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
|
||||
// Set-up model
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::OpenAiGpt,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -214,7 +214,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
|
||||
// Set-up model
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::OpenAiGpt,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
@ -283,7 +283,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::
|
||||
// Set-up model
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::OpenAiGpt,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: Some(merges_resource),
|
||||
|
@ -1,7 +1,7 @@
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
|
||||
use rust_bert::pegasus::{PegasusConfigResources, PegasusModelResources, PegasusVocabResources};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
use tch::Device;
|
||||
|
||||
@ -20,7 +20,7 @@ fn pegasus_summarization_greedy() -> anyhow::Result<()> {
|
||||
|
||||
let summarization_config = SummarizationConfig {
|
||||
model_type: ModelType::Pegasus,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: None,
|
||||
|
@ -1,6 +1,6 @@
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::prophetnet::{
|
||||
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
|
||||
};
|
||||
@ -22,7 +22,7 @@ fn prophetnet_summarization_greedy() -> anyhow::Result<()> {
|
||||
|
||||
let summarization_config = SummarizationConfig {
|
||||
model_type: ModelType::ProphetNet,
|
||||
model_resource: weights_resource,
|
||||
model_resource: ModelResource::Torch(weights_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: None,
|
||||
|
@ -1,4 +1,4 @@
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
use rust_bert::reformer::{
|
||||
ReformerConfig, ReformerConfigResources, ReformerForQuestionAnswering,
|
||||
@ -45,7 +45,7 @@ fn test_generation_reformer() -> anyhow::Result<()> {
|
||||
// Set-up translation model
|
||||
let generation_config = TextGenerationConfig {
|
||||
model_type: ModelType::Reformer,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: None,
|
||||
|
@ -1,4 +1,4 @@
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::ner::NERModel;
|
||||
use rust_bert::pipelines::question_answering::{
|
||||
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
|
||||
@ -319,7 +319,9 @@ fn roberta_question_answering() -> anyhow::Result<()> {
|
||||
// Set-up question answering model
|
||||
let config = QuestionAnsweringConfig::new(
|
||||
ModelType::Roberta,
|
||||
RemoteResource::from_pretrained(RobertaModelResources::ROBERTA_QA),
|
||||
ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
RobertaModelResources::ROBERTA_QA,
|
||||
))),
|
||||
RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA_QA),
|
||||
RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA_QA),
|
||||
Some(RemoteResource::from_pretrained(
|
||||
@ -354,9 +356,9 @@ fn xlm_roberta_german_ner() -> anyhow::Result<()> {
|
||||
// Set-up question answering model
|
||||
let ner_config = TokenClassificationConfig {
|
||||
model_type: ModelType::XLMRoberta,
|
||||
model_resource: Box::new(RemoteResource::from_pretrained(
|
||||
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
RobertaModelResources::XLM_ROBERTA_NER_DE,
|
||||
)),
|
||||
))),
|
||||
config_resource: Box::new(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::XLM_ROBERTA_NER_DE,
|
||||
)),
|
||||
|
@ -1,4 +1,4 @@
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use rust_bert::resources::RemoteResource;
|
||||
@ -26,7 +26,7 @@ fn test_translation_t5() -> anyhow::Result<()> {
|
||||
|
||||
let translation_config = TranslationConfig::new(
|
||||
ModelType::T5,
|
||||
model_resource,
|
||||
ModelResource::Torch(Box::new(model_resource)),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
None,
|
||||
@ -65,7 +65,9 @@ fn test_summarization_t5() -> anyhow::Result<()> {
|
||||
// Set-up translation model
|
||||
let summarization_config = SummarizationConfig {
|
||||
model_type: ModelType::T5,
|
||||
model_resource: Box::new(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)),
|
||||
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
|
||||
T5ModelResources::T5_SMALL,
|
||||
))),
|
||||
config_resource: Box::new(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)),
|
||||
vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
|
||||
merges_resource: None,
|
||||
|
@ -1,4 +1,4 @@
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::common::{ModelResource, ModelType};
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
use rust_bert::resources::{RemoteResource, ResourceProvider};
|
||||
use rust_bert::xlnet::{
|
||||
@ -208,7 +208,7 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
|
||||
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::XLNet,
|
||||
model_resource,
|
||||
model_resource: ModelResource::Torch(model_resource),
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource: None,
|
||||
|
Loading…
Reference in New Issue
Block a user