From 72fabcdbd1f5b32688efa38d2a4ec53c2b1cbd0d Mon Sep 17 00:00:00 2001
From: Guillaume Becquin
Date: Sun, 26 Sep 2021 11:20:05 +0200
Subject: [PATCH] Updated GPT-Neo, working half precision greedy generation

---
 Cargo.toml                         |  4 +--
 examples/generation_gpt_neo.rs     | 12 +++++----
 src/albert/attention.rs            |  3 ++-
 src/bart/bart_model.rs             |  3 +++
 src/common/kind.rs                 | 20 --------------
 src/common/summary.rs              |  4 +--
 src/gpt2/gpt2_model.rs             |  3 +++
 src/gpt_neo/attention.rs           | 31 +++++++++------------
 src/gpt_neo/gpt_neo_model.rs       | 21 ++++++++-------
 src/m2m_100/m2m_100_model.rs       |  3 +++
 src/marian/marian_model.rs         |  3 +++
 src/mbart/mbart_model.rs           |  3 +++
 src/openai_gpt/openai_gpt_model.rs |  3 +++
 src/pegasus/pegasus_model.rs       |  3 +++
 src/pipelines/generation_utils.rs  | 43 ++++++++++++++++++++----------
 src/pipelines/text_generation.rs   | 42 +++++++++++++++++++++++++++++
 src/prophetnet/prophetnet_model.rs |  3 +++
 src/reformer/reformer_model.rs     |  3 +++
 src/t5/t5_model.rs                 |  3 +++
 src/xlnet/xlnet_model.rs           |  3 +++
 20 files changed, 141 insertions(+), 72 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 73731d5..b6322a1 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -58,7 +58,7 @@ features = ["doc-only"]
 
 [dependencies]
 rust_tokenizers = "~6.2.4"
-tch = "~0.5.0"
+tch = { version = "0.5.0", path = "E:/Coding/tch-rs" }
 serde_json = "1.0.66"
 serde = { version = "1.0.129", features = ["derive"] }
 dirs = "3.0.2"
@@ -73,5 +73,5 @@ half = "1.7.1"
 anyhow = "1.0.43"
 csv = "1.1.6"
 criterion = "0.3.5"
-torch-sys = "0.5.0"
+torch-sys = { version = "0.5.0", path = "E:/Coding/tch-rs/torch-sys" }
 tempfile = "3.2.0"
diff --git a/examples/generation_gpt_neo.rs b/examples/generation_gpt_neo.rs
index 1501289..97d9a51 100644
--- a/examples/generation_gpt_neo.rs
+++ b/examples/generation_gpt_neo.rs
@@ -25,16 +25,16 @@ use tch::Device;
 fn main() -> anyhow::Result<()> {
     // Set-up model resources
     let config_resource = Resource::Remote(RemoteResource::from_pretrained(
-        GptNeoConfigResources::GPT_NEO_1_3B,
+        GptNeoConfigResources::GPT_NEO_125M,
     ));
     let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
-        GptNeoVocabResources::GPT_NEO_1_3B,
+        GptNeoVocabResources::GPT_NEO_125M,
     ));
     let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
-        GptNeoMergesResources::GPT_NEO_1_3B,
+        GptNeoMergesResources::GPT_NEO_125M,
     ));
     let model_resource = Resource::Remote(RemoteResource::from_pretrained(
-        GptNeoModelResources::GPT_NEO_1_3B,
+        GptNeoModelResources::GPT_NEO_125M,
     ));
     let generate_config = TextGenerationConfig {
         model_type: ModelType::GPTNeo,
@@ -52,7 +52,9 @@ fn main() -> anyhow::Result<()> {
         ..Default::default()
     };
 
-    let model = TextGenerationModel::new(generate_config)?;
+    let mut model = TextGenerationModel::new(generate_config)?;
+    // model.half();
+    model.set_device(Device::cuda_if_available());
 
     let input_context_1 = "It was a very nice and sunny";
     let input_context_2 = "It was a gloom winter night, and";
diff --git a/src/albert/attention.rs b/src/albert/attention.rs
index 8a99ebb..43b65bf 100644
--- a/src/albert/attention.rs
+++ b/src/albert/attention.rs
@@ -128,7 +128,8 @@ impl AlbertSelfAttention {
             self.hidden_size,
         ));
 
-        let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs;
+        let context: Tensor =
+            Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + self.dense.bs.as_ref().unwrap();
 
         let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
         if !self.output_attentions {
diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs
index af94184..38dc8ab 100644
--- a/src/bart/bart_model.rs
+++ b/src/bart/bart_model.rs
@@ -1128,6 +1128,9 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
diff --git a/src/common/kind.rs b/src/common/kind.rs
index 1893735..c6a4fc7 100644
--- a/src/common/kind.rs
+++ b/src/common/kind.rs
@@ -2,26 +2,6 @@ use crate::RustBertError;
 use half;
 use tch::{Kind, Scalar};
 
-pub(crate) fn get_positive_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
-    Ok(match kind {
-        Kind::Uint8 => Scalar::int(u8::MAX.into()),
-        Kind::Int8 => Scalar::int(i8::MAX.into()),
-        Kind::Int16 => Scalar::int(i16::MAX.into()),
-        Kind::Int => Scalar::int(i32::MAX.into()),
-        Kind::Int64 => Scalar::int(i64::MAX),
-        Kind::Half => Scalar::float(half::f16::MAX.into()),
-        Kind::Float => Scalar::float(f32::MAX.into()),
-        Kind::BFloat16 => Scalar::float(half::bf16::MAX.into()),
-        Kind::Double => Scalar::float(f64::MAX),
-        _ => {
-            return Err(RustBertError::ValueError(format!(
-                "Type not supported: attempted to get positive infinity for {:?}",
-                kind
-            )))
-        }
-    })
-}
-
 pub(crate) fn get_negative_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
     Ok(match kind {
         Kind::Uint8 => Scalar::int(u8::MIN.into()),
diff --git a/src/common/summary.rs b/src/common/summary.rs
index 6a61210..a937bf7 100644
--- a/src/common/summary.rs
+++ b/src/common/summary.rs
@@ -16,7 +16,7 @@ use crate::xlnet::XLNetConfig;
 use crate::RustBertError;
 use serde::{Deserialize, Serialize};
 use std::borrow::Borrow;
-use tch::{nn, Kind, Tensor};
+use tch::{nn, Tensor};
 
 #[allow(non_camel_case_types)]
 #[derive(Clone, Debug, Serialize, Deserialize, Copy)]
@@ -132,7 +132,7 @@ impl SequenceSummary {
         let mut output = match self.summary_type {
             SummaryType::last => hidden_states.select(1, -1),
             SummaryType::first => hidden_states.select(1, 0),
-            SummaryType::mean => hidden_states.mean_dim(&[1], false, Kind::Float),
+            SummaryType::mean => hidden_states.mean_dim(&[1], false, hidden_states.kind()),
             SummaryType::cls_index => {
                 let cls_index = if let Some(cls_index_value) = cls_index {
                     let mut expand_dim = vec![-1i64; cls_index_value.dim() - 1];
diff --git a/src/gpt2/gpt2_model.rs b/src/gpt2/gpt2_model.rs
index a774a81..2747b0d 100644
--- a/src/gpt2/gpt2_model.rs
+++ b/src/gpt2/gpt2_model.rs
@@ -735,6 +735,9 @@ impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
diff --git a/src/gpt_neo/attention.rs b/src/gpt_neo/attention.rs
index f5e20a9..26cb64b 100644
--- a/src/gpt_neo/attention.rs
+++ b/src/gpt_neo/attention.rs
@@ -15,7 +15,6 @@ use crate::gpt_neo::gpt_neo_model::AttentionLayerType;
 use crate::gpt_neo::GptNeoConfig;
 use crate::RustBertError;
 use std::borrow::Borrow;
-use tch::nn::Init;
 use tch::{nn, Device, Kind, Tensor};
 
 #[derive(Debug)]
@@ -207,23 +206,28 @@ pub(crate) trait GptNeoAttentionUtils {
         key: &Tensor,
         value: &Tensor,
         causal_mask: &Tensor,
-        masked_bias: &Tensor,
         attention_dropout: &Dropout,
         attention_mask: Option<&Tensor>,
         train: bool,
     ) -> (Tensor, Tensor) {
-        let mut attention_weights = query
-            .matmul(&key.transpose(-1, -2))
-            .where_self(causal_mask, &masked_bias.to_kind(query.kind()));
+        let query = query.to_kind(Kind::Float);
+        let key = key.to_kind(Kind::Float);
+
+        let attention_weights = query.matmul(&key.transpose(-1, -2));
+        let mut attention_weights = attention_weights.where_self(
+            causal_mask,
+            &Tensor::of_slice(&[-1e9f32]).to_device(attention_weights.device()),
+        );
 
         if let Some(attention_mask_value) = attention_mask {
            attention_weights = attention_weights + attention_mask_value;
         };
-        attention_weights = attention_weights
-            .softmax(-1, Kind::Float)
+        let attention_weights2 = attention_weights
+            .softmax(-1, attention_weights.kind())
+            .to_kind(value.kind())
             .apply_t(attention_dropout, train);
-        let attention_output = attention_weights.matmul(value);
+        let attention_output = attention_weights2.matmul(value);
 
         (attention_output, attention_weights)
     }
 }
@@ -236,7 +240,6 @@ pub struct GptNeoSelfAttention {
     attention_dropout: Dropout,
     resid_dropout: Dropout,
     bias: Tensor,
-    masked_bias: Tensor,
     num_heads: i64,
     head_dim: i64,
     output_attentions: bool,
@@ -259,8 +262,6 @@ impl GptNeoSelfAttention {
 
         let bias = p.var_copy("bias", &bias_value);
 
-        let masked_bias = p.var("masked_bias", &[1], Init::Const(-1e9));
-
         let attention_dropout = Dropout::new(config.attention_dropout);
         let resid_dropout = Dropout::new(config.resid_dropout);
 
@@ -306,7 +307,6 @@ impl GptNeoSelfAttention {
             attention_dropout,
             resid_dropout,
             bias,
-            masked_bias,
             num_heads,
             head_dim,
             output_attentions,
@@ -357,7 +357,6 @@ impl GptNeoSelfAttention {
             &key,
             &value,
             &causal_mask,
-            &self.masked_bias,
             &self.attention_dropout,
             attention_mask,
             train,
@@ -384,7 +383,6 @@ pub struct GptNeoLocalSelfAttention {
     out_proj: nn::Linear,
     attention_dropout: Dropout,
     resid_dropout: Dropout,
-    masked_bias: Tensor,
     num_heads: i64,
     head_dim: i64,
     window_size: i64,
@@ -401,8 +399,6 @@ impl GptNeoLocalSelfAttention {
 
         let p = p.borrow();
 
-        let masked_bias = p.var("masked_bias", &[1], Init::Const(-1e9));
-
         let attention_dropout = Dropout::new(config.attention_dropout);
         let resid_dropout = Dropout::new(config.resid_dropout);
 
@@ -449,7 +445,6 @@ impl GptNeoLocalSelfAttention {
             out_proj,
             attention_dropout,
             resid_dropout,
-            masked_bias,
             num_heads,
             head_dim,
             window_size,
@@ -523,7 +518,6 @@ impl GptNeoLocalSelfAttention {
             &key,
             &value,
             attention_mask,
-            &self.masked_bias,
             &self.attention_dropout,
             None,
             train,
@@ -539,7 +533,6 @@ impl GptNeoLocalSelfAttention {
         } else {
             None
         };
-
         Ok((attention_output, attention_weights))
     }
 }
diff --git a/src/gpt_neo/gpt_neo_model.rs b/src/gpt_neo/gpt_neo_model.rs
index 47941d9..34456f9 100644
--- a/src/gpt_neo/gpt_neo_model.rs
+++ b/src/gpt_neo/gpt_neo_model.rs
@@ -339,14 +339,6 @@ impl GptNeoModel {
 
         let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
 
-        let global_attention_mask = attention_mask.map(|attention_mask_value| {
-            let global_attention_mask = attention_mask_value
-                .view([batch_size, -1])
-                .unsqueeze(1)
-                .unsqueeze(1);
-            (1 - global_attention_mask) * -1e4
-        });
-
         let local_attention_mask = GptNeoModel::create_local_attention_mask(
             batch_size,
             full_sequence_length,
@@ -358,12 +350,20 @@ impl GptNeoModel {
         let input_embeds = input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
 
         let position_embeds = position_ids.apply(&self.position_embeddings);
 
+        let global_attention_mask = attention_mask.map(|attention_mask_value| {
+            let global_attention_mask = attention_mask_value
+                .view([batch_size, -1])
+                .unsqueeze(1)
+                .unsqueeze(1);
+            let global_attention_mask = global_attention_mask.to_kind(position_embeds.kind());
+            (1 - global_attention_mask) * -1e4
+        });
+
         let mut hidden_state = input_embeds + position_embeds;
         if let Some(token_type_ids) = token_type_ids {
             hidden_state = hidden_state + token_type_ids.apply(&self.word_embeddings);
         };
         hidden_state = hidden_state.apply_t(&self.dropout, train);
-
         let mut output_shape = input_shape;
         output_shape.push(*hidden_state.size().last().unwrap());
@@ -711,6 +711,9 @@ impl PrivateLanguageGenerator<GptNeoForCausalLM, Gpt2Vocab, Gpt2Tokenizer> for G
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
diff --git a/src/m2m_100/m2m_100_model.rs b/src/m2m_100/m2m_100_model.rs
index b1369ee..7d876d0 100644
--- a/src/m2m_100/m2m_100_model.rs
+++ b/src/m2m_100/m2m_100_model.rs
@@ -726,6 +726,9 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M1
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs
index 2b38f51..d506b9f 100644
--- a/src/marian/marian_model.rs
+++ b/src/marian/marian_model.rs
@@ -900,6 +900,9 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
diff --git a/src/mbart/mbart_model.rs b/src/mbart/mbart_model.rs
index 99bc85d..a9e40ac 100644
--- a/src/mbart/mbart_model.rs
+++ b/src/mbart/mbart_model.rs
@@ -936,6 +936,9 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
diff --git a/src/openai_gpt/openai_gpt_model.rs b/src/openai_gpt/openai_gpt_model.rs
index ac069bf..aede485 100644
--- a/src/openai_gpt/openai_gpt_model.rs
+++ b/src/openai_gpt/openai_gpt_model.rs
@@ -566,6 +566,9 @@ impl PrivateLanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTok
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
diff --git a/src/pegasus/pegasus_model.rs b/src/pegasus/pegasus_model.rs
index ee07d14..6328eff 100644
--- a/src/pegasus/pegasus_model.rs
+++ b/src/pegasus/pegasus_model.rs
@@ -697,6 +697,9 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs
index 02c7e51..bdd681d 100644
--- a/src/pipelines/generation_utils.rs
+++ b/src/pipelines/generation_utils.rs
@@ -280,6 +280,7 @@ pub(crate) mod private_generation_utils {
         fn get_model(&self) -> &T;
         fn _get_tokenizer(&self) -> &TokenizerOption;
         fn get_var_store(&self) -> &nn::VarStore;
+        fn get_var_store_mut(&mut self) -> &mut nn::VarStore;
         fn get_config(&self) -> &GenerateConfig;
         fn get_bos_id(&self) -> &Option<i64>;
         fn get_eos_ids(&self) -> &Option<Vec<i64>>;
@@ -488,7 +489,9 @@
             }
             if top_p < 1f64 {
                 let (sorted_logits, sorted_indices) = logits.sort(-1, true);
-                let cumulative_probabilities = sorted_logits.softmax(-1, Float).cumsum(-1, Float);
+                let cumulative_probabilities = sorted_logits
+                    .softmax(-1, sorted_logits.kind())
+                    .cumsum(-1, sorted_logits.kind());
                 let mut sorted_indices_to_remove =
                     cumulative_probabilities.ge(top_p).to_kind(Int64);
                 if min_tokens_to_keep > 1 {
@@ -563,7 +566,7 @@ pub(crate) mod private_generation_utils {
                 let mask = scores.new_full(
                     scores.size().as_slice(),
                     f64::INFINITY,
-                    (Kind::Float, scores.device()),
+                    (scores.kind(), scores.device()),
                 );
                 for idx in 0..scores.size()[0] {
                     let batch_id = idx / num_beams;
@@ -750,14 +753,7 @@ pub(crate) mod private_generation_utils {
             let mut past: Cache = Cache::None;
             let mut outputs: Tensor;
             let mut current_length = cur_len;
-            let mut scores_output = if output_scores {
-                Some(Tensor::zeros(
-                    &[batch_size],
-                    (Float, self.get_var_store().device()),
-                ))
-            } else {
-                None
-            };
+            let mut scores_output: Option<Tensor> = None;
 
             while current_length < gen_opt.max_length {
                 let prepared_input = self.prepare_inputs_for_generation(
@@ -783,6 +779,13 @@ pub(crate) mod private_generation_utils {
                 outputs = temp.lm_logits;
                 past = temp.cache;
 
+                if scores_output.is_none() & output_scores {
+                    scores_output = Some(Tensor::zeros(
+                        &[batch_size],
+                        (outputs.kind(), self.get_var_store().device()),
+                    ))
+                }
+
                 let mut next_token_logits = outputs.select(1, -1);
                 // Reduce probability for repeated inputs
                 if gen_opt.repetition_penalty > 1f64 {
@@ -871,7 +874,7 @@
                         gen_opt.top_p,
                         1,
                     );
-                    let probabilities = next_token_logits.softmax(-1, Float);
+                    let probabilities = next_token_logits.softmax(-1, next_token_logits.kind());
                     probabilities.multinomial(1, false).squeeze_dim(1)
                 } else {
                     next_token_logits.argmax(-1, false)
@@ -882,7 +885,7 @@
                     scores_output = Some(
                         prev_scores
                             + (&next_token_logits
-                                .log_softmax(-1, Float)
+                                .log_softmax(-1, next_token_logits.kind())
                                .gather(1, &next_token.reshape(&[-1, 1]), true)
                                .squeeze()
                                .masked_fill(&finished_mask, 0)),
                     );
@@ -1077,7 +1080,7 @@
                     gen_opt.forced_bos_token_id,
                 );
 
-                let mut scores = next_token_logits.log_softmax(-1, Float);
+                let mut scores = next_token_logits.log_softmax(-1, next_token_logits.kind());
 
                 // Do not allow eos token if min length is not reached
                 if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
@@ -1170,7 +1173,7 @@
                         .contiguous()
                         .view((batch_size, group_size * vocab_size));
 
-                    let probabilities = _scores.softmax(-1, Float);
+                    let probabilities = _scores.softmax(-1, _scores.kind());
                     let next_tokens = probabilities.multinomial(2 * group_size, false);
                     let _scores = _scores.gather(-1, &next_tokens, false);
                     let (_scores, next_scores_indices) = _scores.sort(1, true);
@@ -2004,6 +2007,18 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     fn get_tokenizer(&self) -> &TokenizerOption {
         self._get_tokenizer()
     }
+
+    fn half(&mut self) {
+        self.get_var_store_mut().half();
+    }
+
+    fn float(&mut self) {
+        self.get_var_store_mut().float();
+    }
+
+    fn set_device(&mut self, device: Device) {
+        self.get_var_store_mut().set_device(device);
+    }
 }
 
 #[derive(Debug)]
diff --git a/src/pipelines/text_generation.rs b/src/pipelines/text_generation.rs
index b32a7f4..83d6338 100644
--- a/src/pipelines/text_generation.rs
+++ b/src/pipelines/text_generation.rs
@@ -326,6 +326,36 @@
             .collect(),
         }
     }
+
+    pub fn half(&mut self) {
+        match self {
+            Self::GPT(model_ref) => model_ref.half(),
+            Self::GPT2(model_ref) => model_ref.half(),
+            Self::GPTNeo(model_ref) => model_ref.half(),
+            Self::XLNet(model_ref) => model_ref.half(),
+            Self::Reformer(model_ref) => model_ref.half(),
+        }
+    }
+
+    pub fn float(&mut self) {
+        match self {
+            Self::GPT(model_ref) => model_ref.float(),
+            Self::GPT2(model_ref) => model_ref.float(),
+            Self::GPTNeo(model_ref) => model_ref.float(),
+            Self::XLNet(model_ref) => model_ref.float(),
+            Self::Reformer(model_ref) => model_ref.float(),
+        }
+    }
+
+    pub fn set_device(&mut self, device: Device) {
+        match self {
+            Self::GPT(model_ref) => model_ref.set_device(device),
+            Self::GPT2(model_ref) => model_ref.set_device(device),
+            Self::GPTNeo(model_ref) => model_ref.set_device(device),
+            Self::XLNet(model_ref) => model_ref.set_device(device),
+            Self::Reformer(model_ref) => model_ref.set_device(device),
+        }
+    }
 }
 
 /// # TextGenerationModel to generate texts from a prompt
@@ -392,6 +422,18 @@ with people, even a bishop, begging for his blessing. "
         })
     }
 
+    pub fn half(&mut self) {
+        self.model.half();
+    }
+
+    pub fn float(&mut self) {
+        self.model.float();
+    }
+
+    pub fn set_device(&mut self, device: Device) {
+        self.model.set_device(device);
+    }
+
     /// Generate texts from provided prompts
     ///
     /// # Arguments
diff --git a/src/prophetnet/prophetnet_model.rs b/src/prophetnet/prophetnet_model.rs
index 065c60d..02a26d2 100644
--- a/src/prophetnet/prophetnet_model.rs
+++ b/src/prophetnet/prophetnet_model.rs
@@ -999,6 +999,9 @@ impl
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
diff --git a/src/reformer/reformer_model.rs b/src/reformer/reformer_model.rs
index 0155918..a765594 100644
--- a/src/reformer/reformer_model.rs
+++ b/src/reformer/reformer_model.rs
@@ -1105,6 +1105,9 @@ impl PrivateLanguageGenerator<ReformerModelWithLMHead, ReformerVocab, ReformerTo
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs
index 6813759..068e477 100644
--- a/src/t5/t5_model.rs
+++ b/src/t5/t5_model.rs
@@ -798,6 +798,9 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
diff --git a/src/xlnet/xlnet_model.rs b/src/xlnet/xlnet_model.rs
index e1bb422..6cf411c 100644
--- a/src/xlnet/xlnet_model.rs
+++ b/src/xlnet/xlnet_model.rs
@@ -1620,6 +1620,9 @@ impl PrivateLanguageGenerator<XLNetLMHeadModel, XLNetVocab, XLNetTokenizer> for
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
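
Note on the dtype handling above: the recurring edit in this patch replaces hard-coded `Kind::Float` arguments to `softmax`, `log_softmax`, `cumsum` and `mean_dim` with the input tensor's own `kind()`, so activations stay in fp16 once `VarStore::half()` has been applied. The GPT-Neo `_attn` rewrite additionally up-casts the score computation to fp32, because an additive -1e9 mask is not representable in fp16 (fp16 maxes out around 65504). A minimal standalone sketch of that pattern against the tch 0.5 API; the free function `attend` below is illustrative only and omits the dropout and returned attention weights of the real trait method in src/gpt_neo/attention.rs:

    use tch::{Kind, Tensor};

    /// Mixed-precision attention core: compute and mask scores in fp32,
    /// then cast back to the value dtype for the final matmul.
    fn attend(query: &Tensor, key: &Tensor, value: &Tensor, causal_mask: &Tensor) -> Tensor {
        // Up-cast so the masked softmax runs in fp32 even when the model
        // weights (and therefore query/key/value) are fp16.
        let query = query.to_kind(Kind::Float);
        let key = key.to_kind(Kind::Float);
        let scores = query.matmul(&key.transpose(-1, -2));
        // causal_mask is a boolean tensor; masked positions receive -1e9,
        // which is safe in fp32 but would overflow in fp16.
        let scores = scores.where_self(
            causal_mask,
            &Tensor::of_slice(&[-1e9f32]).to_device(scores.device()),
        );
        scores
            .softmax(-1, scores.kind())
            .to_kind(value.kind()) // back to fp16 when value is fp16
            .matmul(value)
    }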
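End-to-end, the user-facing addition is the `half()`, `float()` and `set_device()` methods on `TextGenerationModel`, forwarded through `TextGenerationOption` to the new `LanguageGenerator` defaults, which mutate the underlying `nn::VarStore`. Below is a minimal usage sketch mirroring the updated examples/generation_gpt_neo.rs; the resource setup is copied from that example, `VarStore::set_device` is assumed to come from the locally patched tch-rs checkout that Cargo.toml now points at, and note that the example itself still keeps `model.half()` commented out:

    use rust_bert::gpt_neo::{
        GptNeoConfigResources, GptNeoMergesResources, GptNeoModelResources, GptNeoVocabResources,
    };
    use rust_bert::pipelines::common::ModelType;
    use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
    use rust_bert::resources::{RemoteResource, Resource};
    use tch::Device;

    fn main() -> anyhow::Result<()> {
        let generate_config = TextGenerationConfig {
            model_type: ModelType::GPTNeo,
            model_resource: Resource::Remote(RemoteResource::from_pretrained(
                GptNeoModelResources::GPT_NEO_125M,
            )),
            config_resource: Resource::Remote(RemoteResource::from_pretrained(
                GptNeoConfigResources::GPT_NEO_125M,
            )),
            vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
                GptNeoVocabResources::GPT_NEO_125M,
            )),
            merges_resource: Resource::Remote(RemoteResource::from_pretrained(
                GptNeoMergesResources::GPT_NEO_125M,
            )),
            do_sample: false, // greedy decoding, as in the patch title
            ..Default::default()
        };

        // The new mutators take &mut self: they reach the tch VarStore
        // through the get_var_store_mut() plumbing added in this patch.
        let mut model = TextGenerationModel::new(generate_config)?;
        model.set_device(Device::cuda_if_available());
        model.half(); // cast weights to fp16; model.float() reverts

        let output = model.generate(&["It was a very nice and sunny"], None);
        println!("{:?}", output);
        Ok(())
    }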