rust-bert/examples/generation_gptj.rs

100 lines
3.3 KiB
Rust
Raw Normal View History

use std::path::PathBuf;
use rust_bert::gpt_j::{GptJConfigResources, GptJMergesResources, GptJVocabResources};
ONNX Support (#346) * Fixed Clippy warnings * Revert "Shallow clone optimization (#243)" This reverts commit ba584653bc8d563b8991b3ef6aa4e25d545b0ef3. * updated dependencies * tryouts * GPT2 tryouts * WIP GPT2 * input mapping * Cache storage * Initial GPT2 prototype * Initial ONNX Config and decoder implementation * ONNXDecoder first draft * Use Decoders in example * Automated tch-ort conversion, decoder implementation * ONNXCausalDecoder implementation * Refactored _get_var_store to be optional, added get_device to gen trait * updated example * Added decoder_start_token_id to ConfigOption * Addition of ONNXModelConfig, make max_position_embeddings optional * Addition of forward pass function for ONNXModel * Working ONNX causal decoder * Simplify tensor conversion * refactor translation to facilitate ONNX integration * Implementation of ONNXEncoder * Implementation of ONNXConditionalGenerator * working ONNXCausalGenerator * - Reworked model resources type for pipelines and generators * Aligned ONNXConditionalGenerator with other generators to use GenerateConfig for creation * Moved force_token_id_generation to common utils function, fixed tests, Translation implementation * generalized forced_bos and forced_eos tokens generation * Aligned the `encode_prompt_text` method across language models * Fix prompt encoding for causal generation * Fix prompt encoding for causal generation * Support for ONNX models for SequenceClassification * Support for ONNX models for TokenClassification * Support for ONNX models for POS and NER pipelines * Support for ONNX models for ZeroShotClassification pipeline * Support for ONNX models for QuestionAnswering pipeline * Support for ONNX models for MaskedLM pipeline * Added token_type_ids , updated layer cache i/o parsing for ONNX pipelines * Support for ONNX models for TextGenerationPipeline, updated examples for remote resources * Remove ONNX zero-shot classification example (lack of correct pretrained model) * Addition of tests for 
ONNX pipelines support * Made onnx feature optional * Fix device lookup with onnx feature enabled * Updates from main branch * Flexible tokenizer creation for M2M100 (NLLB support), make NLLB test optional due to their size * Fixed Clippy warnings * Addition of documentation for ONNX * Added documentation for ONNX support * upcoming tch 1.12 fixes * Fix merge conflicts * Fix merge conflicts (2) * Add download libtorch feature to ONNX tests * Add download-onnx feature * attempt to enable onnx download * add remote resources feature * onnx download * pin ort version * Update ort version
2023-05-30 09:20:25 +03:00
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{LocalResource, RemoteResource};
use tch::Device;
/// Equivalent Python code:
///
/// ```python
/// import torch
/// from transformers import AutoTokenizer, GPTJForCausalLM
///
/// device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
///
/// model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16).to(device)
///
/// tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", padding_side="left")
/// tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
///
/// prompts = ["It was a very nice and sunny", "It was a gloom winter night, and"]
/// inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)
///
/// with torch.no_grad():
/// gen_tokens = model.generate(
/// **inputs,
/// min_length=0,
/// max_length=32,
/// do_sample=False,
/// early_stopping=True,
/// num_beams=1,
/// num_return_sequences=1
/// )
///
/// gen_texts = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
/// ````
///
/// To run this test you need to download `pytorch_model.bin` from [EleutherAI GPT-J 6B
/// (float16)][gpt-j-6B-float16] and then convert its weights with:
///
/// ```
/// python utils/convert_model.py resources/gpt-j-6B-float16/pytorch_model.bin
/// ```
///
/// [gpt-j-6B-float16]: https://huggingface.co/EleutherAI/gpt-j-6B/tree/float16
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Box::new(RemoteResource::from_pretrained(
GptJConfigResources::GPT_J_6B_FLOAT16,
));
let vocab_resource = Box::new(RemoteResource::from_pretrained(
GptJVocabResources::GPT_J_6B_FLOAT16,
));
let merges_resource = Box::new(RemoteResource::from_pretrained(
GptJMergesResources::GPT_J_6B_FLOAT16,
));
let model_resource = Box::new(LocalResource::from(PathBuf::from(
"resources/gpt-j-6B-float16/rust_model.ot",
)));
// Set-up model
let generation_config = TextGenerationConfig {
model_type: ModelType::GPTJ,
ONNX Support (#346) * Fixed Clippy warnings * Revert "Shallow clone optimization (#243)" This reverts commit ba584653bc8d563b8991b3ef6aa4e25d545b0ef3. * updated dependencies * tryouts * GPT2 tryouts * WIP GPT2 * input mapping * Cache storage * Initial GPT2 prototype * Initial ONNX Config and decoder implementation * ONNXDecoder first draft * Use Decoders in example * Automated tch-ort conversion, decoder implementation * ONNXCausalDecoder implementation * Refactored _get_var_store to be optional, added get_device to gen trait * updated example * Added decoder_start_token_id to ConfigOption * Addition of ONNXModelConfig, make max_position_embeddigs optional * Addition of forward pass function for ONNXModel * Working ONNX causal decoder * Simplify tensor conversion * refactor translation to facilitate ONNX integration * Implementation of ONNXEncoder * Implementation of ONNXConditionalGenerator * working ONNXCausalGenerator * - Reworked model resources type for pipelines and generators * Aligned ONNXConditionalGenerator with other generators to use GenerateConfig for creation * Moved force_token_id_generation to common utils function, fixed tests, Translation implementation * generalized forced_bos and forced_eos tokens generation * Aligned the `encode_prompt_text` method across language models * Fix prompt encoding for causal generation * Fix prompt encoding for causal generation * Support for ONNX models for SequenceClassification * Support for ONNX models for TokenClassification * Support for ONNX models for POS and NER pipelines * Support for ONNX models for ZeroShotClassification pipeline * Support for ONNX models for QuestionAnswering pipeline * Support for ONNX models for MaskedLM pipeline * Added token_type_ids , updated layer cache i/o parsing for ONNX pipelines * Support for ONNX models for TextGenerationPipeline, updated examples for remote resources * Remove ONNX zero-shot classification example (lack of correct pretrained model) * Addition of tests for 
ONNX pipelines support * Made onnx feature optional * Fix device lookup with onnx feature enabled * Updates from main branch * Flexible tokenizer creation for M2M100 (NLLB support), make NLLB test optional du to their size * Fixed Clippy warnings * Addition of documentation for ONNX * Added documentation for ONNX support * upcoming tch 1.12 fixes * Fix merge conflicts * Fix merge conflicts (2) * Add download libtorch feature to ONNX tests * Add download-onnx feature * attempt to enable onnx download * add remote resources feature * onnx download * pin ort version * Update ort version
2023-05-30 09:20:25 +03:00
model_resource: ModelResource::Torch(model_resource),
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
min_length: 10,
max_length: Some(32),
do_sample: false,
early_stopping: true,
num_beams: 1,
num_return_sequences: 1,
device: Device::cuda_if_available(),
..Default::default()
};
let model = TextGenerationModel::new(generation_config)?;
// Generate text
let prompts = [
"It was a very nice and sunny",
"It was a gloom winter night, and",
];
let output = model.generate(&prompts, None)?;
assert_eq!(output.len(), 2);
assert_eq!(output[0], "It was a very nice and sunny day, and I was sitting in the garden of my house, enjoying the sun and the fresh air. I was thinking");
assert_eq!(output[1], "It was a gloom winter night, and the wind was howling. The snow was falling, and the temperature was dropping. The snow was coming down so hard");
Ok(())
}