rust-bert/examples/sentiment_analysis_fnet.rs
guillaume-be 540c9268e7
ONNX Support (#346)
* Fixed Clippy warnings

* Revert "Shallow clone optimization (#243)"

This reverts commit ba584653bc.

* updated dependencies

* tryouts

* GPT2 tryouts

* WIP GPT2

* input mapping

* Cache storage

* Initial GPT2 prototype

* Initial ONNX Config and decoder implementation

* ONNXDecoder first draft

* Use Decoders in example

* Automated tch-ort conversion, decoder implementation

* ONNXCausalDecoder implementation

* Refactored _get_var_store to be optional, added get_device to gen trait

* updated example

* Added decoder_start_token_id to ConfigOption

* Addition of ONNXModelConfig, make max_position_embeddings optional

* Addition of forward pass function for ONNXModel

* Working ONNX causal decoder

* Simplify tensor conversion

* refactor translation to facilitate ONNX integration

* Implementation of ONNXEncoder

* Implementation of ONNXConditionalGenerator

* working ONNXCausalGenerator

* Reworked model resources type for pipelines and generators

* Aligned ONNXConditionalGenerator with other generators to use GenerateConfig for creation

* Moved force_token_id_generation to common utils function, fixed tests, Translation implementation

* generalized forced_bos and forced_eos tokens generation

* Aligned the `encode_prompt_text` method across language models

* Fix prompt encoding for causal generation

* Support for ONNX models for SequenceClassification

* Support for ONNX models for TokenClassification

* Support for ONNX models for POS and NER pipelines

* Support for ONNX models for ZeroShotClassification pipeline

* Support for ONNX models for QuestionAnswering pipeline

* Support for ONNX models for MaskedLM pipeline

* Added token_type_ids, updated layer cache i/o parsing for ONNX pipelines

* Support for ONNX models for TextGenerationPipeline, updated examples for remote resources

* Remove ONNX zero-shot classification example (lack of correct pretrained model)

* Addition of tests for ONNX pipelines support

* Made onnx feature optional

* Fix device lookup with onnx feature enabled

* Updates from main branch

* Flexible tokenizer creation for M2M100 (NLLB support), make NLLB test optional due to their size

* Fixed Clippy warnings

* Addition of documentation for ONNX

* Added documentation for ONNX support

* upcoming tch 1.12 fixes

* Fix merge conflicts

* Fix merge conflicts (2)

* Add download libtorch feature to ONNX tests

* Add download-onnx feature

* attempt to enable onnx download

* add remote resources feature

* onnx download

* pin ort version

* Update ort version
2023-05-30 07:20:25 +01:00

// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::fnet::{FNetConfigResources, FNetModelResources, FNetVocabResources};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::sentiment::{SentimentConfig, SentimentModel};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
    //    Set-up classifier
    let config_resource = Box::new(RemoteResource::from_pretrained(
        FNetConfigResources::BASE_SST2,
    ));
    let vocab_resource = Box::new(RemoteResource::from_pretrained(
        FNetVocabResources::BASE_SST2,
    ));
    let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
        FNetModelResources::BASE_SST2,
    )));

    let sentiment_config = SentimentConfig {
        model_type: ModelType::FNet,
        model_resource,
        config_resource,
        vocab_resource,
        ..Default::default()
    };

    let sentiment_classifier = SentimentModel::new(sentiment_config)?;

    //    Define input
    let input = [
        "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
        "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
        "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
    ];

    //    Run model
    let output = sentiment_classifier.predict(input);
    for sentiment in output {
        println!("{sentiment:?}");
    }

    Ok(())
}