From c448862185b1836ad7c861b15ea20f3b4fdbc0d9 Mon Sep 17 00:00:00 2001 From: Romain Leroux Date: Wed, 15 Feb 2023 20:10:47 +0100 Subject: [PATCH] Add GPT-J support (#285) (#288) * Add GPT-J support (#285) * Improve GPT-J implementation * Improve GPT-J tests * Adapt GPT-J to latest master branch * Specify how to convert GPT-J weights instead of providing them --- .github/workflows/continuous-integration.yml | 3 +- examples/generation_gptj.rs | 100 +++ src/gpt_j/attention.rs | 323 +++++++ src/gpt_j/gpt_j_model.rs | 836 +++++++++++++++++++ src/gpt_j/mod.rs | 58 ++ src/gpt_j/transformer.rs | 131 +++ src/lib.rs | 5 +- src/pipelines/common.rs | 19 +- src/pipelines/generation_utils.rs | 2 + src/pipelines/text_generation.rs | 16 + tests/gpt_j.rs | 172 ++++ utils/convert_model.py | 66 +- 12 files changed, 1704 insertions(+), 27 deletions(-) create mode 100644 examples/generation_gptj.rs create mode 100644 src/gpt_j/attention.rs create mode 100644 src/gpt_j/gpt_j_model.rs create mode 100644 src/gpt_j/mod.rs create mode 100644 src/gpt_j/transformer.rs create mode 100644 tests/gpt_j.rs diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index e9ba0e2..1cc7e1d 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -131,6 +131,7 @@ jobs: args: --package rust-bert --test sentence_embeddings --test longt5 + --test gpt_j convert-model: name: Model conversion test @@ -179,4 +180,4 @@ jobs: - uses: actions-rs/cargo@v1 with: command: clippy - args: --all-targets --all-features -- -D warnings -A clippy::assign_op_pattern -A clippy::upper-case-acronyms \ No newline at end of file + args: --all-targets --all-features -- -D warnings -A clippy::assign_op_pattern -A clippy::upper-case-acronyms diff --git a/examples/generation_gptj.rs b/examples/generation_gptj.rs new file mode 100644 index 0000000..391d53d --- /dev/null +++ b/examples/generation_gptj.rs @@ -0,0 +1,100 @@ +use std::path::PathBuf; + +use rust_bert::gpt_j::{GptJConfigResources, GptJMergesResources, GptJVocabResources}; +use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; +use rust_bert::resources::{LocalResource, RemoteResource}; +use tch::Device; + +/// Equivalent Python code: +/// +/// ```python +/// import torch +/// from transformers import AutoTokenizer, GPTJForCausalLM +/// +/// device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +/// +/// model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16).to(device) +/// +/// tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", padding_side="left") +/// tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token}) +/// +/// prompts = ["It was a very nice and sunny", "It was a gloom winter night, and"] +/// inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device) +/// +/// with torch.no_grad(): +/// gen_tokens = model.generate( +/// **inputs, +/// min_length=0, +/// max_length=32, +/// do_sample=False, +/// early_stopping=True, +/// num_beams=1, +/// num_return_sequences=1 +/// ) +/// +/// gen_texts = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True) +/// ```` +/// +/// To run this test you need to download `pytorch_model.bin` from [EleutherAI GPT-J 6B +/// (float16)][gpt-j-6B-float16] and then convert its weights with: +/// +/// ``` +/// python utils/convert_model.py resources/gpt-j-6B-float16/pytorch_model.bin +/// ``` +/// +/// [gpt-j-6B-float16]: https://huggingface.co/EleutherAI/gpt-j-6B/tree/float16 +/// +fn main() -> anyhow::Result<()> { + // Resources paths + + let config_resource = Box::new(RemoteResource::from_pretrained( + GptJConfigResources::GPT_J_6B_FLOAT16, + )); + + let vocab_resource = Box::new(RemoteResource::from_pretrained( + GptJVocabResources::GPT_J_6B_FLOAT16, + )); + + let merges_resource = Box::new(RemoteResource::from_pretrained( + GptJMergesResources::GPT_J_6B_FLOAT16, + )); + + let model_resource = Box::new(LocalResource::from(PathBuf::from( + "resources/gpt-j-6B-float16/rust_model.ot", + ))); + + // Set-up model + + let generation_config = TextGenerationConfig { + model_type: ModelType::GPTJ, + model_resource, + config_resource, + vocab_resource, + merges_resource: Some(merges_resource), + min_length: 10, + max_length: Some(32), + do_sample: false, + early_stopping: true, + num_beams: 1, + num_return_sequences: 1, + device: Device::cuda_if_available(), + ..Default::default() + }; + + let model = TextGenerationModel::new(generation_config)?; + + // Generate text + + let prompts = [ + "It was a very nice and sunny", + "It was a gloom winter night, and", + ]; + let output = model.generate(&prompts, None); + + assert_eq!(output.len(), 2); + assert_eq!(output[0], "It was a very nice and sunny day, and I was sitting in the garden of my house, enjoying the sun and the fresh air. I was thinking"); + assert_eq!(output[1], "It was a gloom winter night, and the wind was howling. The snow was falling, and the temperature was dropping. The snow was coming down so hard"); + + Ok(()) +} diff --git a/src/gpt_j/attention.rs b/src/gpt_j/attention.rs new file mode 100644 index 0000000..cbfb87d --- /dev/null +++ b/src/gpt_j/attention.rs @@ -0,0 +1,323 @@ +// Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved. +// Copyright 2022 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::common::dropout::Dropout; +use crate::common::kind::get_min; +use crate::gpt_j::gpt_j_model::GptJConfig; +use std::borrow::Borrow; +use tch::nn::Linear; +use tch::{nn, IndexOp, Kind, NewAxis, Tensor}; + +#[derive(Debug)] +/// # Cache for GPT-J attention layers +/// Stores the cached value of key and value +pub struct LayerState { + /// Cached keys + pub prev_key: Tensor, + /// Cached values + pub prev_value: Tensor, +} + +impl Clone for LayerState { + fn clone(&self) -> Self { + LayerState { + prev_key: self.prev_key.copy(), + prev_value: self.prev_value.copy(), + } + } +} + +impl LayerState { + pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) { + self.prev_key = self.prev_key.index_select(0, new_indices); + self.prev_value = self.prev_value.index_select(0, new_indices); + } +} + +pub struct GptJAttention { + bias: Tensor, + attn_dropout: Dropout, + resid_dropout: Dropout, + scale_attn: f32, + k_proj: Linear, + v_proj: Linear, + q_proj: Linear, + out_proj: Linear, + output_attentions: bool, + dim_per_head: i64, + n_head: i64, + rotary_dim: Option, + scale: bool, + use_cache: bool, +} + +impl GptJAttention { + pub fn new<'p, P>(p: P, config: &GptJConfig) -> GptJAttention + where + P: Borrow>, + { + let p = p.borrow(); + + let max_positions = config.n_positions; + let bias = Tensor::ones(&[max_positions, max_positions], (Kind::Uint8, p.device())) + .tril(0) + .view([1, 1, max_positions, max_positions]) + .requires_grad_(false); + let bias = p.var_copy("bias", &bias); + + let attn_pdrop = config.attn_pdrop.unwrap_or(0.1); + let resid_pdrop = config.resid_pdrop.unwrap_or(0.1); + let output_attentions = config.output_attentions.unwrap_or(false); + + let attn_dropout = Dropout::new(attn_pdrop); + let resid_dropout = Dropout::new(resid_pdrop); + + assert_eq!( + config.n_embd % config.n_head, + 0, + "Attention hidden states not a multiple of the number of heads" + ); + let dim_per_head = config.n_embd / config.n_head; + + let scale_attn = (dim_per_head as f32).sqrt(); + + let linear_config = nn::LinearConfig { + bias: false, + ..Default::default() + }; + let k_proj = nn::linear(p / "k_proj", config.n_embd, config.n_embd, linear_config); + if config.use_float16 { + (p / "k_proj").half(); + } + let v_proj = nn::linear(p / "v_proj", config.n_embd, config.n_embd, linear_config); + if config.use_float16 { + (p / "v_proj").half(); + } + let q_proj = nn::linear(p / "q_proj", config.n_embd, config.n_embd, linear_config); + if config.use_float16 { + (p / "q_proj").half(); + } + let out_proj = nn::linear(p / "out_proj", config.n_embd, config.n_embd, linear_config); + if config.use_float16 { + (p / "out_proj").half(); + } + + GptJAttention { + bias, + attn_dropout, + resid_dropout, + output_attentions, + scale_attn, + k_proj, + v_proj, + q_proj, + out_proj, + dim_per_head, + n_head: config.n_head, + rotary_dim: config.rotary_dim, + scale: config.scale_attn_weights.unwrap_or(true), + use_cache: config.use_cache.unwrap_or(true), + } + } + + fn split_heads( + tensor: &Tensor, + num_heads: i64, + attention_head_size: i64, + rotary: bool, + ) -> Tensor { + let mut new_shape = tensor.size(); + let _ = new_shape.pop(); + new_shape.extend_from_slice(&[num_heads, attention_head_size]); + let tensor = tensor.view(new_shape.as_slice()); + if rotary { + tensor + } else if tensor.size().len() == 5 { + tensor.permute(&[0, 1, 3, 2, 4]) // (batch, blocks, head, block_length, head_features) + } else if tensor.size().len() == 4 { + tensor.permute(&[0, 2, 1, 3]) // (batch, head, seq_length, head_features) + } else { + panic!( + "Input tensor should either be a rotary head, or its rank be one of [4, 5] but is: {}", + tensor.size().len() + ) + } + } + + fn merge_heads(tensor: &Tensor, num_heads: i64, attention_head_size: i64) -> Tensor { + let tensor = if tensor.size().len() == 5 { + tensor.permute(&[0, 1, 3, 2, 4]).contiguous() + } else if tensor.size().len() == 4 { + tensor.permute(&[0, 2, 1, 3]).contiguous() + } else { + panic!( + "Input tensor rank should be one of [4, 5], but is: {}", + tensor.size().len() + ) + }; + let mut new_shape = tensor.size(); + new_shape.truncate(new_shape.len() - 2); + new_shape.push(num_heads * attention_head_size); + tensor.view(new_shape.as_slice()) + } + + fn attention( + &self, + query: &Tensor, + key: &Tensor, + value: &Tensor, + attention_mask: Option<&Tensor>, + train: bool, + ) -> (Tensor, Tensor) { + let query = query.to_kind(Kind::Float); + let key = key.to_kind(Kind::Float); + + let attention_weights = query.matmul(&key.transpose(-1, -2)); + + let query_dims = query.size(); + let key_dims = key.size(); + let query_length = query_dims[query_dims.len() - 2]; + let key_length = key_dims[key_dims.len() - 2]; + + let causal_mask = &self + .bias + .slice(2, key_length - query_length, key_length, 1) + .slice(3, 0, key_length, 1) + .to_kind(Kind::Bool) + .to_device(attention_weights.device()); + + let mask_value = get_min(attention_weights.kind()).unwrap(); + let mask_value = Tensor::full( + &attention_weights.size(), + mask_value, + (attention_weights.kind(), attention_weights.device()), + ); + + let mut attention_weights = attention_weights.where_self(causal_mask, &mask_value); + if self.scale { + attention_weights /= self.scale_attn; + } + if let Some(attention_mask_value) = attention_mask { + attention_weights += attention_mask_value; + }; + let attention_weights = attention_weights.softmax(-1, attention_weights.kind()); + let attention_weights = attention_weights + .to_kind(value.kind()) + .apply_t(&self.attn_dropout, train); + + let attention_output = attention_weights.matmul(value); + + (attention_output, attention_weights) + } + + pub fn forward_t( + &self, + hidden_states: &Tensor, + attention_mask: Option<&Tensor>, + layer_past: Option<&LayerState>, + train: bool, + ) -> (Tensor, Option, Option) { + let query = hidden_states.apply(&self.q_proj); + let key = hidden_states.apply(&self.k_proj); + let value = hidden_states.apply(&self.v_proj); + + let mut query = Self::split_heads(&query, self.n_head, self.dim_per_head, true); + let mut key = Self::split_heads(&key, self.n_head, self.dim_per_head, true); + let mut value = Self::split_heads(&value, self.n_head, self.dim_per_head, false); + + let mut seq_len = key.size()[1]; + let mut offset = 0; + + if let Some(layer_past) = layer_past { + offset = layer_past.prev_key.size()[layer_past.prev_key.size().len() - 2]; + seq_len += offset + }; + + if let Some(rotary_dim) = self.rotary_dim { + let k_rot = key.slice(3, 0, rotary_dim, 1); + let k_pass = key.slice(3, rotary_dim, key.size()[3], 1); + + let q_rot = query.slice(3, 0, rotary_dim, 1); + let q_pass = query.slice(3, rotary_dim, query.size()[3], 1); + + let sincos = fixed_pos_embedding(&k_rot, seq_len); + let k_rot = apply_rotary_pos_emb(&k_rot, &sincos, offset); + let q_rot = apply_rotary_pos_emb(&q_rot, &sincos, offset); + + key = Tensor::cat(&[k_rot, k_pass], -1); + query = Tensor::cat(&[q_rot, q_pass], -1); + } else { + let sincos = fixed_pos_embedding(&key, seq_len); + key = apply_rotary_pos_emb(&key, &sincos, offset); + query = apply_rotary_pos_emb(&query, &sincos, offset); + } + + key = key.permute(&[0, 2, 1, 3]); + query = query.permute(&[0, 2, 1, 3]); + + if let Some(layer_past) = layer_past { + key = Tensor::cat(&[&layer_past.prev_key, &key], -2); + value = Tensor::cat(&[&layer_past.prev_value, &value], -2); + } + + let present = self.use_cache.then(|| LayerState { + prev_key: key.copy(), + prev_value: value.copy(), + }); + + let (attn_output, attn_weights) = + self.attention(&query, &key, &value, attention_mask, train); + + let attn_output = Self::merge_heads(&attn_output, self.n_head, self.dim_per_head) + .apply(&self.out_proj) + .apply_t(&self.resid_dropout, train); + + let attn_weights = self.output_attentions.then_some(attn_weights); + + (attn_output, present, attn_weights) + } +} + +fn fixed_pos_embedding(x: &Tensor, seq_len: i64) -> (Tensor, Tensor) { + let dim = x.size()[x.size().len() - 1]; + let inv_freq = 1.0 + / Tensor::pow_scalar( + 10_000, + &(Tensor::arange_start_step(0, dim, 2, (x.kind(), x.device())) / dim), + ); + let sinusoid_inp = Tensor::einsum( + "i , j -> i j", + &[Tensor::arange(seq_len, (x.kind(), x.device())), inv_freq], + None, + ); + (sinusoid_inp.sin(), sinusoid_inp.cos()) +} + +fn apply_rotary_pos_emb(x: &Tensor, (sin, cos): &(Tensor, Tensor), offset: i64) -> Tensor { + let sin = duplicate_interleave(sin).i((NewAxis, offset..x.size()[1] + offset, NewAxis, ..)); + let cos = duplicate_interleave(cos).i((NewAxis, offset..x.size()[1] + offset, NewAxis, ..)); + (x * cos) + (rotate_every_two(x) * sin) +} + +/// A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. +fn duplicate_interleave(m: &Tensor) -> Tensor { + let dim0 = m.size()[0]; + m.view([-1, 1]) // flatten the matrix + .repeat(&[1, 2]) // repeat all elements into the 2nd dimension + .view([dim0, -1]) // reshape into a matrix, interleaving the copy +} + +fn rotate_every_two(x: &Tensor) -> Tensor { + let x1 = x.slice(3, 0, x.size()[3], 2); + let x2 = x.slice(3, 1, x.size()[3], 2); + Tensor::stack(&[-x2, x1], -1).flatten(-2, -1) +} diff --git a/src/gpt_j/gpt_j_model.rs b/src/gpt_j/gpt_j_model.rs new file mode 100644 index 0000000..073eb75 --- /dev/null +++ b/src/gpt_j/gpt_j_model.rs @@ -0,0 +1,836 @@ +// Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved. +// Copyright 2022 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::common::activations::Activation; +use crate::common::dropout::Dropout; +use crate::common::embeddings::process_ids_embeddings_pair; +use crate::common::kind::get_min; +use crate::gpt_j::attention::LayerState; +use crate::gpt_j::transformer::GptJBlock; +use crate::pipelines::common::{ModelType, TokenizerOption}; +use crate::pipelines::generation_utils::private_generation_utils::{ + PreparedInput, PrivateLanguageGenerator, +}; +use crate::pipelines::generation_utils::{ + Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, +}; +use crate::{Config, RustBertError}; +use rust_tokenizers::tokenizer::Gpt2Tokenizer; +use rust_tokenizers::vocab::Gpt2Vocab; +use serde::{Deserialize, Serialize}; +use std::borrow::{Borrow, BorrowMut}; +use tch::nn::{embedding, Linear}; +use tch::{nn, Device, Tensor}; + +/// # GPT-J Pretrained model weight files +pub struct GptJModelResources; + +/// # GPT-J Pretrained model config files +pub struct GptJConfigResources; + +/// # GPT-J Pretrained model vocab files +pub struct GptJVocabResources; + +/// # GPT-J Pretrained model merges files +pub struct GptJMergesResources; + +/// Model weights for Rust are not available out of the box for GPT-J but can be created +/// simply with the following command: +/// +/// ``` +/// python utils/convert_model.py path/to/gpt_j/pytorch_model.bin +/// ``` +/// +/// Where `pytorch_model.bin` was downloaded from [EleutherAI GPT-J 6B][gpt-j-6B] or +/// [EleutherAI GPT-J 6B (float16)][gpt-j-6B-float16]. Note that to convert GPT-J 6B you +/// will need about 32 Gb of RAM, and converting GPT-J 6B float16 requires about 12 Gb +/// of RAM. +/// +/// [gpt-j-6B]: https://huggingface.co/EleutherAI/gpt-j-6B/tree/main +/// [gpt-j-6B-float16]:https://huggingface.co/EleutherAI/gpt-j-6B/tree/float16 +/// +impl GptJModelResources { + pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = ( + "gpt-j-tiny-random/model", + "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/rust_model.ot", + ); +} + +impl GptJConfigResources { + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. + pub const GPT_J_6B: (&'static str, &'static str) = ( + "gpt-j-6B/config", + "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/config.json", + ); + pub const GPT_J_6B_FLOAT16: (&'static str, &'static str) = ( + "gpt-j-6B/config", + "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/float16/config.json", + ); + pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = ( + "gpt-j-tiny-random/config", + "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/config.json", + ); +} + +impl GptJVocabResources { + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. + pub const GPT_J_6B: (&'static str, &'static str) = ( + "gpt-j-6B/vocab", + "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/vocab.json", + ); + pub const GPT_J_6B_FLOAT16: (&'static str, &'static str) = ( + "gpt-j-6B/vocab", + "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/float16/vocab.json", + ); + pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = ( + "gpt-j-tiny-random/vocab", + "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/vocab.json", + ); +} + +impl GptJMergesResources { + /// Shared under Apache 2.0 license by the EleutherAI contributors at . Modified with conversion to C-array format. + pub const GPT_J_6B: (&'static str, &'static str) = ( + "gpt-j-6B/merges", + "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/merges.txt", + ); + pub const GPT_J_6B_FLOAT16: (&'static str, &'static str) = ( + "gpt-j-6B/merges", + "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/float16/merges.txt", + ); + pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = ( + "gpt-j-tiny-random/merges", + "https://huggingface.co/anton-l/gpt-j-tiny-random/resolve/main/merges.txt", + ); +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +/// # GPT-J model configuration +/// Defines the GPT-J model architecture (e.g. number of layers, hidden layer size, vocab size...). +pub struct GptJConfig { + pub attn_pdrop: Option, + pub embd_pdrop: Option, + pub hidden_dropout_prob: Option, + pub afn: Option, + pub initializer_range: f64, + pub layer_norm_epsilon: f64, + pub n_embd: i64, + pub n_head: i64, + pub n_layer: i64, + pub n_positions: i64, + pub n_inner: Option, + pub num_labels: Option, + pub use_cache: Option, + pub output_attentions: Option, + pub output_hidden_states: Option, + pub resid_pdrop: Option, + pub rotary_dim: Option, + pub vocab_size: i64, + pub scale_attn_weights: Option, + #[serde(default = "default_use_float16")] + pub use_float16: bool, + #[serde(default = "default_preload_on_cpu")] + pub preload_on_cpu: bool, +} + +impl Config for GptJConfig {} + +impl Default for GptJConfig { + fn default() -> Self { + GptJConfig { + attn_pdrop: Some(0.1), + embd_pdrop: Some(0.1), + hidden_dropout_prob: None, + afn: Some(Activation::gelu_new), + initializer_range: 0.02, + layer_norm_epsilon: 1e-5, + n_embd: 4096, + n_head: 16, + n_layer: 28, + n_positions: 2048, + n_inner: None, + num_labels: None, + use_cache: None, + output_attentions: None, + output_hidden_states: None, + resid_pdrop: Some(0.1), + rotary_dim: Some(64), + vocab_size: 50400, + scale_attn_weights: Some(true), + use_float16: default_use_float16(), + preload_on_cpu: default_preload_on_cpu(), + } + } +} + +fn default_use_float16() -> bool { + true +} + +fn default_preload_on_cpu() -> bool { + true +} + +/// # GPT-J Base model +/// Base architecture for GPT-J model. Usually complemented with a task-specific head, such as a language model head. +/// It is made of the following blocks: +/// - `wte`: `token` embeddings +/// - `h`: Encoder (transformer) made of a vector of layers. Each layer is made of a multi-head attention layer, a layer-normalization layer, and a MLP made of linear layers. +/// - `output_past`: flag indicating if the model should return a past state. This can be fed back to the model to improve the quality of text generated. +/// - `output_hidden_states`: flag indicating if the model should return all hidden states (as opposed to only the last layer) +/// - `output_attentions`: flag indicating if the model should return activation weights +pub struct GptJModel { + wte: nn::Embedding, + drop: Dropout, + ln_f: nn::LayerNorm, + h: Vec, + use_cache: bool, + output_hidden_states: bool, + output_attentions: bool, +} + +impl GptJModel { + /// Build a new `GptJModel` + /// + /// # Arguments + /// + /// * `p` - Variable store path for the root of the GPT-J model + /// * `config` - `GptJConfig` object defining the model architecture + /// + /// # Example + /// + /// ```no_run + /// use rust_bert::gpt_j::{GptJConfig, GptJModel}; + /// use rust_bert::Config; + /// use std::path::Path; + /// use tch::{nn, Device}; + /// + /// let config_path = Path::new("path/to/config.json"); + /// let device = Device::Cpu; + /// let p = nn::VarStore::new(device); + /// let config = GptJConfig::from_file(config_path); + /// let gpt_j: GptJModel = GptJModel::new(&p.root() / "gpt_j", &config); + /// ``` + pub fn new<'p, P>(p: P, config: &GptJConfig) -> GptJModel + where + P: Borrow>, + { + let p = p.borrow() / "transformer"; + + let wte = embedding( + &p / "wte", + config.vocab_size, + config.n_embd, + Default::default(), + ); + if config.use_float16 { + (&(&p / "wte") / "weight").half() + }; + + let embd_pdrop = config.embd_pdrop.unwrap_or(0.1); + let drop = Dropout::new(embd_pdrop); + + let layer_norm_config = nn::LayerNormConfig { + eps: config.layer_norm_epsilon, + ..Default::default() + }; + let ln_f = nn::layer_norm(&p / "ln_f", vec![config.n_embd], layer_norm_config); + if config.use_float16 { + (&p / "ln_f").half() + }; + + let mut h: Vec = vec![]; + let h_path = &p / "h"; + for layer_index in 0..config.n_layer { + h.push(GptJBlock::new(&h_path / layer_index, config)); + } + + let use_cache = config.use_cache.unwrap_or(true); + let output_attentions = config.output_attentions.unwrap_or(false); + let output_hidden_states = config.output_hidden_states.unwrap_or(false); + + GptJModel { + wte, + drop, + ln_f, + h, + use_cache, + output_hidden_states, + output_attentions, + } + } + + /// Forward pass through the model + /// + /// # Arguments + /// + /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) + /// * `layer_past` - Optional vector of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values. + /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 + /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`) + /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings. + /// * `_position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input. + /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. + /// + /// # Returns + /// + /// * `GptJModelOutput` containing: + /// - `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the activations of the last hidden state + /// - `cache` - `Option>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*) + /// - `all_hidden_states` - `Option>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*) + /// - `all_attentions` - `Option>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*) + /// + /// # Example + /// + /// ```no_run + /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use rust_bert::Config; + /// # use std::path::Path; + /// # use tch::kind::Kind::{Int64, Double}; + /// use rust_bert::gpt_j::{GptJConfig, GptJModel, LayerState}; + /// # let config_path = Path::new("path/to/config.json"); + /// # let vocab_path = Path::new("path/to/vocab.txt"); + /// # let device = Device::Cpu; + /// # let vs = nn::VarStore::new(device); + /// # let config = GptJConfig::from_file(config_path); + /// # let gpt_j_model: GptJModel = GptJModel::new(&vs.root(), &config); + /// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56); + /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); + /// let mut past: Vec> = Vec::with_capacity(config.n_layer as usize); + /// for _ in 0..config.n_layer as usize { + /// past.push(Some(LayerState { + /// prev_key: Tensor::rand( + /// &[ + /// batch_size, + /// config.n_head, + /// past_sequence_length, + /// config.n_embd / config.n_head, + /// ], + /// (Double, device), + /// ), + /// prev_value: Tensor::rand( + /// &[ + /// batch_size, + /// config.n_head, + /// past_sequence_length, + /// config.n_embd / config.n_head, + /// ], + /// (Double, device), + /// ), + /// })) + /// } + /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); + /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device)); + /// + /// let model_output = no_grad(|| { + /// gpt_j_model + /// .forward_t( + /// Some(&input_tensor), + /// Some(&past), + /// Some(&attention_mask), + /// Some(&token_type_ids), + /// None, + /// None, + /// false, + /// ) + /// .unwrap() + /// }); + /// ``` + pub fn forward_t( + &self, + input_ids: Option<&Tensor>, + layer_past: Option>>, + attention_mask: Option<&Tensor>, + token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + input_embeds: Option<&Tensor>, + train: bool, + ) -> Result { + let (calc_input_embeddings, _input_size, _device) = + process_ids_embeddings_pair(input_ids, input_embeds, &self.wte)?; + + let input_embeddings = + input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap()); + + let (layer_past, _layer_past_length) = match layer_past { + Some(value) => { + if value.len() != self.h.len() { + return Err(RustBertError::ValueError(format!( + "Past activations vector length ({}) must be equal to the number of layers ({})", + value.len(), + self.h.len() + ))); + } else { + let length = value.len(); + (value, length) + } + } + None => { + let mut out = Vec::with_capacity(self.h.len()); + out.resize_with(self.h.len(), || None); + (out, 0) + } + }; + + let kind_min = get_min(input_embeddings.kind())?; + let attention_mask: Option = attention_mask.map(|value| { + let attention_mask = value + .view((input_embeddings.size()[0], -1)) + .unsqueeze(1) + .unsqueeze(2) + .to_kind(input_embeddings.kind()); + + (attention_mask.ones_like() - attention_mask.to_kind(input_embeddings.kind())) + * kind_min + }); + + let mut hidden_state: Tensor = input_embeddings.copy(); + if let Some(token_type_ids) = token_type_ids { + let token_type_embeds = token_type_ids.apply(&self.wte); + hidden_state = hidden_state + token_type_embeds; + } + hidden_state = hidden_state.apply_t(&self.drop, train); + + let mut all_presents: Option>> = self.use_cache.then(Vec::new); + let mut all_hidden_states: Option> = self.output_hidden_states.then(Vec::new); + let mut all_attentions: Option> = self.output_attentions.then(Vec::new); + + for (layer, past) in self.h.iter().zip(layer_past) { + let temp = + layer.forward_t(&hidden_state, past.as_ref(), attention_mask.as_ref(), train); + hidden_state = temp.0; + if let Some(presents) = all_presents.borrow_mut() { + presents.push(temp.1); + }; + if let Some(attentions) = all_attentions.borrow_mut() { + attentions.push(std::mem::take(&mut temp.2.unwrap())); + }; + if let Some(hidden_states) = all_hidden_states.borrow_mut() { + hidden_states.push(std::mem::take(&mut hidden_state)); + }; + } + + let output = hidden_state.apply(&self.ln_f); + + Ok(GptJModelOutput { + output, + cache: all_presents, + all_hidden_states, + all_attentions, + }) + } +} + +/// # GPT-J Language Modeling head +/// GPT-J model with a decoding head (linear layer without bias). The weights of the linear layer are tied to the word embeddings +/// It is made of the following blocks: +/// - `transformer`: Base GptJModel +pub struct GptJLMHeadModel { + transformer: GptJModel, + lm_head: Linear, +} + +impl GptJLMHeadModel { + /// Build a new `GptJLMHeadModel` + /// + /// # Arguments + /// + /// * `p` - Variable store path for the root of the GPT-J model + /// * `config` - `GptJConfig` object defining the model architecture + /// + /// # Example + /// + /// ```no_run + /// use rust_bert::gpt_j::{GptJLMHeadModel, GptJConfig}; + /// use rust_bert::Config; + /// use std::path::Path; + /// use tch::{nn, Device}; + /// + /// let config_path = Path::new("path/to/config.json"); + /// let device = Device::Cpu; + /// let p = nn::VarStore::new(device); + /// let config = GptJConfig::from_file(config_path); + /// let gpt_j: GptJLMHeadModel = GptJLMHeadModel::new(&p.root() / "gpt_j", &config); + /// ``` + pub fn new<'p, P>(p: P, config: &GptJConfig) -> GptJLMHeadModel + where + P: Borrow>, + { + let p = p.borrow(); + + let transformer = GptJModel::new(p, config); + let lm_head = nn::linear( + p / "lm_head", + config.n_embd, + config.vocab_size, + Default::default(), + ); + if config.use_float16 { + (p / "lm_head").half(); + } + + GptJLMHeadModel { + transformer, + lm_head, + } + } +} + +impl LMHeadModel for GptJLMHeadModel { + /// Forward pass through the model + /// + /// # Arguments + /// + /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) + /// * `layer_past` - Optional vector of size *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values. + /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 + /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`) + /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings. + /// * `_position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input. + /// * `_encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*). Unused for GPT-J + /// * `_decoder_input_ids` - Optional tensor of shape (*batch size*, *target_sequence_length*). Unused for GPT_J + /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. + /// + /// + /// # Returns + /// + /// * `LMModelOutput` containing: + /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position + /// - `cache` - `GptJCache` made of `Option>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*) + /// + /// # Example + /// + /// ```no_run + /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use rust_bert::Config; + /// # use std::path::Path; + /// # use tch::kind::Kind::{Int64, Double}; + /// use rust_bert::gpt_j::{GptJLMHeadModel, GptJConfig}; + /// use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; + /// # let config_path = Path::new("path/to/config.json"); + /// # let vocab_path = Path::new("path/to/vocab.txt"); + /// # let device = Device::Cpu; + /// # let vs = nn::VarStore::new(device); + /// # let config = GptJConfig::from_file(config_path); + /// # let mut gpt_j_model: GptJLMHeadModel = GptJLMHeadModel::new(&vs.root(), &config); + /// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56); + /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); + /// let mut past: Vec = Vec::with_capacity(config.n_layer as usize); + /// for _ in 0..config.n_layer as usize { + /// past.push(Tensor::rand( + /// &[ + /// 2, + /// batch_size, + /// config.n_head, + /// past_sequence_length, + /// config.n_embd / config.n_head, + /// ], + /// (Double, device), + /// )) + /// } + /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); + /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device)); + /// let position_ids = Tensor::arange(sequence_length, (Int64, device)) + /// .expand(&[batch_size, sequence_length], true); + /// + /// let model_output = no_grad(|| { + /// gpt_j_model + /// .forward_t( + /// Some(&input_tensor), + /// Cache::GPTJCache(Some(past)), + /// Some(&attention_mask), + /// Some(&token_type_ids), + /// None, + /// None, + /// None, + /// None, + /// false, + /// ) + /// .unwrap() + /// }); + /// ``` + fn forward_t( + &self, + input_ids: Option<&Tensor>, + layer_past: Cache, + attention_mask: Option<&Tensor>, + token_type_ids: Option<&Tensor>, + position_ids: Option<&Tensor>, + input_embeds: Option<&Tensor>, + _encoder_outputs: Option<&Tensor>, + _decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match layer_past { + Cache::GPTJCache(layer_past) => self.transformer.forward_t( + input_ids, + layer_past, + attention_mask, + token_type_ids, + position_ids, + input_embeds, + train, + ), + Cache::None => self.transformer.forward_t( + input_ids, + None, + attention_mask, + token_type_ids, + position_ids, + input_embeds, + train, + ), + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with GPT-J Model".into(), + )); + } + }?; + + let lm_logits = base_model_output.output.apply(&self.lm_head); + + Ok(LMModelOutput { + lm_logits, + cache: Cache::GPTJCache(base_model_output.cache), + }) + } +} + +/// Container for the GPT-J model output. +pub struct GptJModelOutput { + /// Hidden state of the last layer of the decoder, or logits for a custom head + /// module after the decoder (e.g. vocabulary logits for language modeling tasks) + pub output: Tensor, + /// Cached attention layers keys and values if the model is used for generation + pub cache: Option>>, + /// Hidden states for all intermediate layers + pub all_hidden_states: Option>, + /// Attention weights for all intermediate layers + pub all_attentions: Option>, +} + +/// # Language generation model based on the GPT-J architecture +pub struct GptJGenerator { + model: GptJLMHeadModel, + tokenizer: TokenizerOption, + var_store: nn::VarStore, + generate_config: GenerateConfig, + bos_token_id: Option, + eos_token_ids: Option>, + pad_token_id: Option, + is_encoder_decoder: bool, + vocab_size: i64, + decoder_start_id: Option, + max_position_embeddings: i64, +} + +impl GptJGenerator { + /// Build a new `GptJGenerator` + /// + /// # Arguments + /// + /// * `generate_config` - `GenerateConfig` object containing the resource references (model, vocabulary, configuration), generation options and device placement (CPU/GPU) + /// + /// # Example + /// + /// ```no_run + /// # fn main() -> anyhow::Result<()> { + /// use rust_bert::gpt_j::GptJGenerator; + /// use rust_bert::pipelines::generation_utils::GenerateConfig; + /// + /// let generate_config = GenerateConfig { + /// max_length: 30, + /// do_sample: true, + /// num_beams: 5, + /// temperature: 1.1, + /// num_return_sequences: 3, + /// ..Default::default() + /// }; + /// let gpt_j_generator = GptJGenerator::new(generate_config)?; + /// # Ok(()) + /// # } + /// ``` + pub fn new(generate_config: GenerateConfig) -> Result { + let vocab_path = generate_config.vocab_resource.get_local_path()?; + let merges_path = generate_config + .merges_resource + .as_ref() + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "GPT-J expects a merges resources to be provided".to_string(), + ) + })? + .get_local_path()?; + + let tokenizer = TokenizerOption::from_file( + ModelType::GPTJ, + vocab_path.to_str().unwrap(), + Some(merges_path.to_str().unwrap()), + false, + None, + None, + )?; + + Self::new_with_tokenizer(generate_config, tokenizer) + } + + pub fn new_with_tokenizer( + generate_config: GenerateConfig, + tokenizer: TokenizerOption, + ) -> Result { + let config_path = generate_config.config_resource.get_local_path()?; + let weights_path = generate_config.model_resource.get_local_path()?; + let device = generate_config.device; + + generate_config.validate(); + let mut var_store = nn::VarStore::new(device); + + let config = GptJConfig::from_file(config_path); + let model = GptJLMHeadModel::new(var_store.root(), &config); + if config.preload_on_cpu && device != Device::Cpu { + var_store.set_device(Device::Cpu); + } + var_store.load(weights_path)?; + if device != Device::Cpu { + var_store.set_device(device); + } + + let bos_token_id = tokenizer.get_bos_id(); + let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]); + let pad_token_id = tokenizer.get_pad_id(); + let max_position_embeddings = config.n_positions; + let is_encoder_decoder = false; + let vocab_size = config.vocab_size; + let decoder_start_id = None; + + Ok(GptJGenerator { + model, + tokenizer, + var_store, + generate_config, + bos_token_id, + eos_token_ids, + pad_token_id, + is_encoder_decoder, + vocab_size, + decoder_start_id, + max_position_embeddings, + }) + } +} + +impl PrivateLanguageGenerator for GptJGenerator { + fn get_model(&self) -> &GptJLMHeadModel { + &self.model + } + fn _get_tokenizer(&self) -> &TokenizerOption { + &self.tokenizer + } + fn get_var_store(&self) -> &nn::VarStore { + &self.var_store + } + fn get_var_store_mut(&mut self) -> &mut nn::VarStore { + &mut self.var_store + } + fn get_config(&self) -> &GenerateConfig { + &self.generate_config + } + fn get_bos_id(&self) -> Option { + self.bos_token_id + } + fn get_eos_ids(&self) -> Option<&Vec> { + self.eos_token_ids.as_ref() + } + fn get_pad_id(&self) -> Option { + self.pad_token_id + } + fn is_encoder_decoder(&self) -> bool { + self.is_encoder_decoder + } + fn get_vocab_size(&self) -> i64 { + self.vocab_size + } + fn get_decoder_start_id(&self) -> Option { + self.decoder_start_id + } + fn get_max_positions_embeddings(&self) -> i64 { + self.max_position_embeddings + } + + fn prepare_inputs_for_generation<'a>( + &self, + input_ids: Tensor, + _encoder_outputs: Option<&'a Tensor>, + past: Cache, + attention_mask: Tensor, + ) -> PreparedInput<'a> { + match past { + Cache::GPTJCache(past) => { + if past.is_some() { + PreparedInput { + prepared_input: Some(input_ids.select(1, -1).unsqueeze(-1)), + prepared_attention_mask: Some(attention_mask), + prepared_encoder_output: None, + prepared_decoder_input: None, + prepared_position_ids: None, + prepared_past: Cache::GPTJCache(past), + } + } else { + PreparedInput { + prepared_input: Some(input_ids), + prepared_attention_mask: Some(attention_mask), + prepared_encoder_output: None, + prepared_decoder_input: None, + prepared_position_ids: None, + prepared_past: Cache::GPTJCache(None), + } + } + } + Cache::None => PreparedInput { + prepared_input: Some(input_ids), + prepared_attention_mask: Some(attention_mask), + prepared_encoder_output: None, + prepared_decoder_input: None, + prepared_position_ids: None, + prepared_past: Cache::GPTJCache(None), + }, + _ => panic!("Cache type incompatible with GPT-J"), + } + } + + fn reorder_cache( + &self, + past: &mut Cache, + _encoder_outputs: Option, + beam_indices: &Tensor, + ) -> Option { + match past { + Cache::GPTJCache(cached_decoder_state) => match cached_decoder_state { + Some(old_cache) => { + for layer_state in old_cache.iter_mut() { + if layer_state.is_some() { + layer_state.as_mut().unwrap().reorder_cache(beam_indices) + }; + } + None + } + None => None, + }, + Cache::None => None, + _ => { + panic!("Invalid cache for GPT-J model"); + } + } + } +} + +impl LanguageGenerator for GptJGenerator {} diff --git a/src/gpt_j/mod.rs b/src/gpt_j/mod.rs new file mode 100644 index 0000000..ef375ed --- /dev/null +++ b/src/gpt_j/mod.rs @@ -0,0 +1,58 @@ +//! # GPT-J +//! +//! Implementation of the GPT-J language model +//! +//! # Model set-up and pre-trained weights loading +//! +//! ```no_run +//! # fn main() -> anyhow::Result<()> { +//! # +//! use tch::{nn, Device}; +//! # use std::path::PathBuf; +//! use rust_bert::gpt_j::{GptJLMHeadModel, GptJConfig}; +//! use rust_bert::resources::{LocalResource, ResourceProvider}; +//! use rust_bert::Config; +//! use rust_tokenizers::tokenizer::Gpt2Tokenizer; +//! +//! let config_resource = LocalResource { +//! local_path: PathBuf::from("path/to/config.json"), +//! }; +//! let vocab_resource = LocalResource { +//! local_path: PathBuf::from("path/to/vocab.txt"), +//! }; +//! let merges_resource = LocalResource { +//! local_path: PathBuf::from("path/to/vocab.txt"), +//! }; +//! let weights_resource = LocalResource { +//! local_path: PathBuf::from("path/to/model.ot"), +//! }; +//! let config_path = config_resource.get_local_path()?; +//! let vocab_path = vocab_resource.get_local_path()?; +//! let merges_path = merges_resource.get_local_path()?; +//! let weights_path = weights_resource.get_local_path()?; +//! +//! let device = Device::cuda_if_available(); +//! let mut vs = nn::VarStore::new(device); +//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file( +//! vocab_path.to_str().unwrap(), +//! merges_path.to_str().unwrap(), +//! true, +//! )?; +//! let config = GptJConfig::from_file(config_path); +//! let gpt_j_model = GptJLMHeadModel::new(&vs.root(), &config); +//! vs.load(weights_path)?; +//! +//! # Ok(()) +//! # } +//! ``` + +mod attention; +mod gpt_j_model; +mod transformer; + +pub use gpt_j_model::{ + GptJConfig, GptJConfigResources, GptJGenerator, GptJLMHeadModel, GptJMergesResources, + GptJModel, GptJModelOutput, GptJModelResources, GptJVocabResources, +}; + +pub use attention::LayerState; diff --git a/src/gpt_j/transformer.rs b/src/gpt_j/transformer.rs new file mode 100644 index 0000000..a00878c --- /dev/null +++ b/src/gpt_j/transformer.rs @@ -0,0 +1,131 @@ +// Copyright 2021 The Eleuther AI and HuggingFace Inc. team. All rights reserved. +// Copyright 2022 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::common::activations::{Activation, TensorFunction}; +use crate::common::dropout::Dropout; +use crate::gpt_j::attention::{GptJAttention, LayerState}; +use crate::gpt_j::gpt_j_model::GptJConfig; +use std::borrow::Borrow; +use tch::nn::Linear; +use tch::{nn, Tensor}; + +pub struct GptJMLP { + fc_in: Linear, + fc_out: Linear, + activation: TensorFunction, + dropout: Dropout, +} + +impl GptJMLP { + pub fn new<'p, P>(p: P, config: &GptJConfig) -> GptJMLP + where + P: Borrow>, + { + let p = p.borrow(); + + let intermediate_size = if let Some(n_inner) = config.n_inner { + n_inner + } else { + 4 * config.n_embd + }; + let fc_in = nn::linear( + p / "fc_in", + config.n_embd, + intermediate_size, + Default::default(), + ); + if config.use_float16 { + (p / "fc_in").half() + }; + let fc_out = nn::linear( + p / "fc_out", + intermediate_size, + config.n_embd, + Default::default(), + ); + if config.use_float16 { + (p / "fc_out").half() + }; + + let activation = match &config.afn { + Some(activation_enum) => match activation_enum { + Activation::gelu => &Activation::gelu_new, + default => default, + }, + None => &Activation::gelu_new, + } + .get_function(); + + let resid_pdrop = config.resid_pdrop.unwrap_or(0.1); + let dropout = Dropout::new(resid_pdrop); + + GptJMLP { + fc_in, + fc_out, + activation, + dropout, + } + } + + pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor { + let h = (self.activation.get_fn())(&hidden_states.apply(&self.fc_in)); + h.apply(&self.fc_out).apply_t(&self.dropout, train) + } +} + +pub struct GptJBlock { + ln_1: nn::LayerNorm, + attn: GptJAttention, + mlp: GptJMLP, +} + +impl GptJBlock { + pub fn new<'p, P>(p: P, config: &GptJConfig) -> GptJBlock + where + P: Borrow>, + { + let p = p.borrow(); + + let layer_norm_config = nn::LayerNormConfig { + eps: config.layer_norm_epsilon, + ..Default::default() + }; + let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config); + if config.use_float16 { + (p / "ln_1").half() + }; + let attn = GptJAttention::new(p / "attn", config); + let mlp = GptJMLP::new(p / "mlp", config); + + GptJBlock { ln_1, attn, mlp } + } + + pub fn forward_t( + &self, + hidden_states: &Tensor, + layer_past: Option<&LayerState>, + attention_mask: Option<&Tensor>, + train: bool, + ) -> (Tensor, Option, Option) { + let residual = hidden_states; + let hidden_states = hidden_states.apply(&self.ln_1); + + let (attn_output, present, attn_weights) = + self.attn + .forward_t(&hidden_states, attention_mask, layer_past, train); + + let feed_forward_hidden_states = self.mlp.forward_t(&hidden_states, train); + let hidden_states = attn_output + feed_forward_hidden_states + residual; + + (hidden_states, present, attn_weights) + } +} diff --git a/src/lib.rs b/src/lib.rs index a077b30..4c7a435 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -584,7 +584,7 @@ //! # ; //! ``` //! -//!   +//!   //!
//! 11. Sentence embeddings //! @@ -600,7 +600,7 @@ //! "this is an example sentence", //! "each sentence is converted" //! ]; -//! +//! //! let output = model.encode(&sentences); //! # Ok(()) //! # } @@ -708,6 +708,7 @@ pub mod distilbert; pub mod electra; pub mod fnet; pub mod gpt2; +pub mod gpt_j; pub mod gpt_neo; pub mod longformer; pub mod longt5; diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index d8555dd..b4f6f10 100644 --- a/src/pipelines/common.rs +++ b/src/pipelines/common.rs @@ -26,6 +26,7 @@ use crate::distilbert::DistilBertConfig; use crate::electra::ElectraConfig; use crate::fnet::FNetConfig; use crate::gpt2::Gpt2Config; +use crate::gpt_j::GptJConfig; use crate::gpt_neo::GptNeoConfig; use crate::longformer::LongformerConfig; use crate::longt5::LongT5Config; @@ -78,6 +79,7 @@ pub enum ModelType { Albert, XLNet, GPT2, + GPTJ, OpenAiGpt, Reformer, ProphetNet, @@ -119,6 +121,8 @@ pub enum ConfigOption { XLNet(XLNetConfig), /// GPT2 configuration GPT2(Gpt2Config), + /// GPT-J configuration + GPTJ(GptJConfig), /// Reformer configuration Reformer(ReformerConfig), /// RoBERTa configuration @@ -196,6 +200,7 @@ impl ConfigOption { ModelType::Albert => ConfigOption::Albert(AlbertConfig::from_file(path)), ModelType::XLNet => ConfigOption::XLNet(XLNetConfig::from_file(path)), ModelType::GPT2 => ConfigOption::GPT2(Gpt2Config::from_file(path)), + ModelType::GPTJ => ConfigOption::GPTJ(GptJConfig::from_file(path)), ModelType::GPTNeo => ConfigOption::GPTNeo(GptNeoConfig::from_file(path)), ModelType::OpenAiGpt => ConfigOption::OpenAiGpt(OpenAiGptConfig::from_file(path)), ModelType::Reformer => ConfigOption::Reformer(ReformerConfig::from_file(path)), @@ -285,6 +290,7 @@ impl ConfigOption { Self::LongT5(_) => panic!("LongT5 does not use a label mapping"), Self::OpenAiGpt(_) => panic!("OpenAI GPT does not use a label mapping"), Self::GPT2(_) => panic!("GPT2 does not use a label mapping"), + Self::GPTJ(_) => panic!("GPT-J does not use a label mapping"), Self::GPTNeo(_) => panic!("GPT-Neo does not use a label mapping"), Self::Pegasus(_) => panic!("Pegasus does not use a label mapping"), } @@ -305,6 +311,7 @@ impl ConfigOption { Self::Albert(config) => Some(config.max_position_embeddings), Self::XLNet(_) => None, Self::GPT2(config) => Some(config.n_positions), + Self::GPTJ(config) => Some(config.n_positions), Self::Reformer(config) => Some(config.max_position_embeddings), Self::ProphetNet(config) => Some(config.max_position_embeddings), Self::Longformer(config) => Some(config.max_position_embeddings), @@ -555,11 +562,13 @@ impl TokenizerOption { } TokenizerOption::Reformer(ReformerTokenizer::from_file(vocab_path, lower_case)?) } - ModelType::GPT2 | ModelType::GPTNeo => TokenizerOption::GPT2(Gpt2Tokenizer::from_file( - vocab_path, - merges_path.expect("No merges specified!"), - lower_case, - )?), + ModelType::GPT2 | ModelType::GPTNeo | ModelType::GPTJ => { + TokenizerOption::GPT2(Gpt2Tokenizer::from_file( + vocab_path, + merges_path.expect("No merges specified!"), + lower_case, + )?) + } ModelType::OpenAiGpt => TokenizerOption::OpenAiGpt(OpenAiGptTokenizer::from_file( vocab_path, merges_path.expect("No merges specified!"), diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs index f9787c4..71e68c7 100644 --- a/src/pipelines/generation_utils.rs +++ b/src/pipelines/generation_utils.rs @@ -74,6 +74,7 @@ use tch::{no_grad, Device, Tensor}; use crate::bart::LayerState as BartLayerState; use crate::common::error::RustBertError; use crate::common::resources::ResourceProvider; +use crate::gpt_j::LayerState as GPTJLayerState; use crate::gpt_neo::LayerState as GPTNeoLayerState; use crate::pipelines::generation_utils::private_generation_utils::{ InternalGenerateOptions, PrivateLanguageGenerator, @@ -224,6 +225,7 @@ pub enum Cache { ReformerCache(Option>>), ProphetNetCache(Option, Option)>>), GPTNeoCache(Option>>), + GPTJCache(Option>>), None, } diff --git a/src/pipelines/text_generation.rs b/src/pipelines/text_generation.rs index ee4d3ba..58a8a95 100644 --- a/src/pipelines/text_generation.rs +++ b/src/pipelines/text_generation.rs @@ -35,6 +35,7 @@ use tch::Device; use crate::common::error::RustBertError; use crate::gpt2::GPT2Generator; +use crate::gpt_j::GptJGenerator; use crate::gpt_neo::GptNeoGenerator; use crate::openai_gpt::OpenAIGenerator; use crate::pipelines::common::{ModelType, TokenizerOption}; @@ -192,6 +193,8 @@ pub enum TextGenerationOption { GPT(OpenAIGenerator), /// Text Generator based on GPT-Neo model GPTNeo(GptNeoGenerator), + /// Text Generator based on GPT-J model + GPTJ(GptJGenerator), /// Text Generator based on XLNet model XLNet(XLNetGenerator), /// Text Generator based on Reformer model @@ -216,6 +219,9 @@ impl TextGenerationOption { ModelType::GPTNeo => Ok(TextGenerationOption::GPTNeo(GptNeoGenerator::new( config.into(), )?)), + ModelType::GPTJ => Ok(TextGenerationOption::GPTJ(GptJGenerator::new( + config.into(), + )?)), _ => Err(RustBertError::InvalidConfigurationError(format!( "Text generation not implemented for {:?}!", config.model_type @@ -229,6 +235,7 @@ impl TextGenerationOption { Self::GPT(_) => ModelType::OpenAiGpt, Self::GPT2(_) => ModelType::GPT2, Self::GPTNeo(_) => ModelType::GPTNeo, + Self::GPTJ(_) => ModelType::GPTJ, Self::XLNet(_) => ModelType::XLNet, Self::Reformer(_) => ModelType::Reformer, } @@ -240,6 +247,7 @@ impl TextGenerationOption { Self::GPT(model_ref) => model_ref._get_tokenizer(), Self::GPT2(model_ref) => model_ref._get_tokenizer(), Self::GPTNeo(model_ref) => model_ref._get_tokenizer(), + Self::GPTJ(model_ref) => model_ref._get_tokenizer(), Self::XLNet(model_ref) => model_ref._get_tokenizer(), Self::Reformer(model_ref) => model_ref._get_tokenizer(), } @@ -276,6 +284,11 @@ impl TextGenerationOption { .into_iter() .map(|output| output.indices) .collect(), + Self::GPTJ(ref model) => model + .generate_indices(prompt_texts, generate_options) + .into_iter() + .map(|output| output.indices) + .collect(), Self::XLNet(ref model) => model .generate_indices(prompt_texts, generate_options) .into_iter() @@ -294,6 +307,7 @@ impl TextGenerationOption { Self::GPT(model_ref) => model_ref.half(), Self::GPT2(model_ref) => model_ref.half(), Self::GPTNeo(model_ref) => model_ref.half(), + Self::GPTJ(model_ref) => model_ref.half(), Self::XLNet(model_ref) => model_ref.half(), Self::Reformer(model_ref) => model_ref.half(), } @@ -304,6 +318,7 @@ impl TextGenerationOption { Self::GPT(model_ref) => model_ref.float(), Self::GPT2(model_ref) => model_ref.float(), Self::GPTNeo(model_ref) => model_ref.float(), + Self::GPTJ(model_ref) => model_ref.float(), Self::XLNet(model_ref) => model_ref.float(), Self::Reformer(model_ref) => model_ref.float(), } @@ -314,6 +329,7 @@ impl TextGenerationOption { Self::GPT(model_ref) => model_ref.set_device(device), Self::GPT2(model_ref) => model_ref.set_device(device), Self::GPTNeo(model_ref) => model_ref.set_device(device), + Self::GPTJ(model_ref) => model_ref.set_device(device), Self::XLNet(model_ref) => model_ref.set_device(device), Self::Reformer(model_ref) => model_ref.set_device(device), } diff --git a/tests/gpt_j.rs b/tests/gpt_j.rs new file mode 100644 index 0000000..a39ebbc --- /dev/null +++ b/tests/gpt_j.rs @@ -0,0 +1,172 @@ +use rust_bert::gpt_j::{ + GptJConfig, GptJConfigResources, GptJLMHeadModel, GptJMergesResources, GptJModelResources, + GptJVocabResources, +}; +use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; +use rust_bert::resources::{RemoteResource, ResourceProvider}; +use rust_bert::Config; +use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer}; +use rust_tokenizers::vocab::Vocab; +use tch::{nn, Device, Tensor}; + +/// Equivalent Python code: +/// +/// ```python +/// import torch +/// from transformers import AutoTokenizer, GPTJForCausalLM +/// +/// device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +/// +/// model = GPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random").to(device) +/// if torch.cuda.is_available(): model = model.half() +/// +/// tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random", padding_side="left") +/// tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token}) +/// +/// prompts = ["It was a very nice and sunny", "It was a gloom winter night, and"] +/// inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device) +/// +/// with torch.no_grad(): +/// model.forward(**inputs).logits +/// ``` +#[test] +fn gpt_j_correctness() -> anyhow::Result<()> { + // Resources paths + + let config_resource = Box::new(RemoteResource::from_pretrained( + GptJConfigResources::GPT_J_TINY_RANDOM, + )); + + let vocab_resource = Box::new(RemoteResource::from_pretrained( + GptJVocabResources::GPT_J_TINY_RANDOM, + )); + + let merges_resource = Box::new(RemoteResource::from_pretrained( + GptJMergesResources::GPT_J_TINY_RANDOM, + )); + + let model_resource = Box::new(RemoteResource::from_pretrained( + GptJModelResources::GPT_J_TINY_RANDOM, + )); + + let device = Device::cuda_if_available(); + + // Set-up tokenizer + + let vocab_path = vocab_resource.get_local_path()?; + let merges_path = merges_resource.get_local_path()?; + let lower_case = false; + let tokenizer = Gpt2Tokenizer::from_file( + vocab_path.to_str().unwrap(), + merges_path.to_str().unwrap(), + lower_case, + )?; + + // Set-up model + + let mut vs = nn::VarStore::new(device); + let config_path = config_resource.get_local_path()?; + let weights_path = model_resource.get_local_path()?; + let mut config = GptJConfig::from_file(config_path); + config.use_float16 = matches!(device, Device::Cuda(_)); + let model = GptJLMHeadModel::new(vs.root(), &config); + vs.load(weights_path)?; + + // Tokenize prompts + + let prompts = [ + "It was a very nice and sunny", + "It was a gloom winter night, and", + ]; + + let pad_token = tokenizer.vocab().get_eos_value(); + let &pad_token = tokenizer + .vocab() + .special_values() + .get(pad_token) + .unwrap_or(&2); + + let tokens = Tokenizer::tokenize_list(&tokenizer, &prompts); + let max_len = tokens.iter().map(|input| input.len()).max().unwrap_or(0); + + let token_ids = tokens + .into_iter() + .map(|prompt_tokens| { + let token_ids = tokenizer.convert_tokens_to_ids(&prompt_tokens); + let mut padded = vec![pad_token; max_len - token_ids.len()]; + padded.extend(token_ids); + padded + }) + .collect::>>(); + + let token_masks = token_ids + .iter() + .map(|input| { + Tensor::of_slice( + &input + .iter() + .map(|&e| i64::from(e != pad_token)) + .collect::>(), + ) + .to(device) + }) + .collect::>(); + + let token_ids = token_ids + .into_iter() + .map(|tokens| Tensor::of_slice(&tokens).to(device)) + .collect::>(); + + let input_tensor = Tensor::stack(&token_ids, 0); + let attention_tensor = Tensor::stack(&token_masks, 0); + + // Run model inference + + let logits = tch::no_grad(|| { + model.forward_t( + Some(&input_tensor), + Cache::None, + // None, + Some(&attention_tensor), + None, + None, + None, + None, + None, + false, + ) + })? + .lm_logits; + + if matches!(device, Device::Cpu) { + assert!((logits.double_value(&[0, 0, 0]) - -0.8343).abs() < 1e-4); + assert!((logits.double_value(&[0, 0, 1]) - 0.0203).abs() < 1e-4); + assert!((logits.double_value(&[0, 0, 2]) - 0.4745).abs() < 1e-4); + assert!((logits.double_value(&[0, 0, 50397]) - 0.2641).abs() < 1e-4); + assert!((logits.double_value(&[0, 0, 50398]) - 0.1926).abs() < 1e-4); + assert!((logits.double_value(&[0, 0, 50399]) - 0.0204).abs() < 1e-4); + + assert!((logits.double_value(&[1, 0, 0]) - -0.0647).abs() < 1e-4); + assert!((logits.double_value(&[1, 0, 1]) - 0.0105).abs() < 1e-4); + assert!((logits.double_value(&[1, 0, 2]) - -0.3448).abs() < 1e-4); + assert!((logits.double_value(&[1, 0, 50397]) - -0.0445).abs() < 1e-4); + assert!((logits.double_value(&[1, 0, 50398]) - 0.0639).abs() < 1e-4); + assert!((logits.double_value(&[1, 0, 50399]) - -0.1167).abs() < 1e-4); + } else { + assert!((logits.double_value(&[0, 0, 0]) - -0.1110).abs() < 1e-2); + assert!((logits.double_value(&[0, 0, 1]) - 0.0565).abs() < 1e-2); + assert!((logits.double_value(&[0, 0, 2]) - 0.1273).abs() < 1e-2); + assert!((logits.double_value(&[0, 0, 50397]) - -0.1879).abs() < 1e-2); + assert!((logits.double_value(&[0, 0, 50398]) - -0.1114).abs() < 1e-2); + assert!((logits.double_value(&[0, 0, 50399]) - -0.3042).abs() < 1e-2); + + assert!((logits.double_value(&[1, 0, 0]) - -0.0651).abs() < 1e-2); + assert!((logits.double_value(&[1, 0, 1]) - 0.0107).abs() < 1e-2); + assert!((logits.double_value(&[1, 0, 2]) - -0.3452).abs() < 1e-2); + assert!((logits.double_value(&[1, 0, 50397]) - -0.0436).abs() < 1e-2); + assert!((logits.double_value(&[1, 0, 50398]) - 0.0645).abs() < 1e-2); + assert!((logits.double_value(&[1, 0, 50399]) - -0.1166).abs() < 1e-2); + } + + Ok(()) +} diff --git a/utils/convert_model.py b/utils/convert_model.py index 3dc7f08..c555c4d 100644 --- a/utils/convert_model.py +++ b/utils/convert_model.py @@ -1,25 +1,41 @@ -from pathlib import Path -import numpy as np -import torch -import subprocess import argparse +import numpy as np +import subprocess import sys +import torch +from pathlib import Path from torch import Tensor if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert") - parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings") - parser.add_argument("--skip_lm_head", action="store_true", help="Skip language model head") + parser.add_argument( + "source_file", help="Absolute path to the Pytorch weights file to convert" + ) + parser.add_argument( + "--skip_embeddings", + action="store_true", + help="Skip shared embeddings", + ) + parser.add_argument( + "--skip_lm_head", action="store_true", help="Skip language model head" + ) parser.add_argument("--prefix", help="Add a prefix on weight names") - parser.add_argument("--suffix", action="store_true", help="Split weight names on '.' and keep only last part") + parser.add_argument( + "--suffix", + action="store_true", + help="Split weight names on '.' and keep only last part", + ) + parser.add_argument( + "--dtype", + help="Convert weights to a specific numpy DataType (float32, float16, ...)", + ) args = parser.parse_args() source_file = Path(args.source_file) target_folder = source_file.parent - weights = torch.load(str(source_file), map_location='cpu') + weights = torch.load(str(source_file), map_location="cpu") nps = {} for k, v in weights.items(): @@ -29,7 +45,7 @@ if __name__ == "__main__": "model.encoder.embed_tokens.weight", "encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", - "decoder.embed_tokens.weight" + "decoder.embed_tokens.weight", }: continue if args.skip_lm_head: @@ -40,18 +56,30 @@ if __name__ == "__main__": if args.prefix: k = args.prefix + k if args.suffix: - k = k.split('.')[-1] + k = k.split(".")[-1] if isinstance(v, Tensor): - nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32)) - print(f'converted {k} - {str(sys.getsizeof(nps[k]))} bytes') + tensor = v.cpu().numpy() + if args.dtype is not None: + nps[k] = np.ascontiguousarray(tensor.astype(np.dtype(args.dtype))) + else: + nps[k] = np.ascontiguousarray(tensor) + print(f"converted {k} - {str(sys.getsizeof(nps[k]))} bytes") else: - print(f'skipped non-tensor object: {k}') - np.savez(target_folder / 'model.npz', **nps) + print(f"skipped non-tensor object: {k}") + np.savez(target_folder / "model.npz", **nps) - source = str(target_folder / 'model.npz') - target = str(target_folder / 'rust_model.ot') + source = str(target_folder / "model.npz") + target = str(target_folder / "rust_model.ot") - toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve() + toml_location = (Path(__file__).resolve() / ".." / ".." / "Cargo.toml").resolve() subprocess.run( - ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target], + [ + "cargo", + "run", + "--bin=convert-tensor", + "--manifest-path=%s" % toml_location, + "--", + source, + target, + ], )