Updated GPT-Neo, working half-precision greedy generation

Guillaume Becquin 2021-09-26 11:20:05 +02:00
parent 1fdd7757e8
commit 72fabcdbd1
20 changed files with 141 additions and 72 deletions

View File

@@ -58,7 +58,7 @@ features = ["doc-only"]
 [dependencies]
 rust_tokenizers = "~6.2.4"
-tch = "~0.5.0"
+tch = { version = "0.5.0", path = "E:/Coding/tch-rs" }
 serde_json = "1.0.66"
 serde = { version = "1.0.129", features = ["derive"] }
 dirs = "3.0.2"
@@ -73,5 +73,5 @@ half = "1.7.1"
 anyhow = "1.0.43"
 csv = "1.1.6"
 criterion = "0.3.5"
-torch-sys = "0.5.0"
+torch-sys = { version = "0.5.0", path = "E:/Coding/tch-rs/torch-sys" }
 tempfile = "3.2.0"

View File

@@ -25,16 +25,16 @@ use tch::Device;
 fn main() -> anyhow::Result<()> {
     // Set-up model resources
     let config_resource = Resource::Remote(RemoteResource::from_pretrained(
-        GptNeoConfigResources::GPT_NEO_1_3B,
+        GptNeoConfigResources::GPT_NEO_125M,
     ));
     let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
-        GptNeoVocabResources::GPT_NEO_1_3B,
+        GptNeoVocabResources::GPT_NEO_125M,
     ));
     let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
-        GptNeoMergesResources::GPT_NEO_1_3B,
+        GptNeoMergesResources::GPT_NEO_125M,
     ));
     let model_resource = Resource::Remote(RemoteResource::from_pretrained(
-        GptNeoModelResources::GPT_NEO_1_3B,
+        GptNeoModelResources::GPT_NEO_125M,
     ));
     let generate_config = TextGenerationConfig {
         model_type: ModelType::GPTNeo,
@@ -52,7 +52,9 @@ fn main() -> anyhow::Result<()> {
         ..Default::default()
     };
-    let model = TextGenerationModel::new(generate_config)?;
+    let mut model = TextGenerationModel::new(generate_config)?;
+    // model.half();
+    model.set_device(Device::cuda_if_available());
     let input_context_1 = "It was a very nice and sunny";
     let input_context_2 = "It was a gloom winter night, and";
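The example keeps `half()` commented out; for reference, a minimal end-to-end sketch of how the pieces added in this commit fit together, with the cast enabled behind a CUDA check (`do_sample: false` reflects that the commit validates greedy decoding, and most libtorch f16 kernels are GPU-only):

```rust
use rust_bert::gpt_neo::{
    GptNeoConfigResources, GptNeoMergesResources, GptNeoModelResources, GptNeoVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;

fn main() -> anyhow::Result<()> {
    // Same GPT-Neo 125M resources as in the example above.
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        GptNeoConfigResources::GPT_NEO_125M,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        GptNeoVocabResources::GPT_NEO_125M,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        GptNeoMergesResources::GPT_NEO_125M,
    ));
    let model_resource = Resource::Remote(RemoteResource::from_pretrained(
        GptNeoModelResources::GPT_NEO_125M,
    ));
    let generate_config = TextGenerationConfig {
        model_type: ModelType::GPTNeo,
        model_resource,
        config_resource,
        vocab_resource,
        merges_resource,
        do_sample: false, // greedy decoding, the case this commit fixes
        ..Default::default()
    };
    // The model must now be mutable to change precision or device.
    let mut model = TextGenerationModel::new(generate_config)?;
    let device = Device::cuda_if_available();
    model.set_device(device);
    if matches!(device, Device::Cuda(_)) {
        model.half(); // cast every weight in the var store to f16
    }
    let output = model.generate(&["It was a very nice and sunny"], None);
    for sentence in output {
        println!("{}", sentence);
    }
    Ok(())
}
```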

View File

@@ -128,7 +128,8 @@ impl AlbertSelfAttention {
             self.hidden_size,
         ));
-        let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs;
+        let context: Tensor =
+            Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + self.dense.bs.as_ref().unwrap();
         let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
         if !self.output_attentions {
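This tracks the updated tch API in which the bias of `nn::Linear` is an `Option<Tensor>` rather than a plain `Tensor`, hence the `as_ref().unwrap()`. A small sketch of that API shape (an assumption based on tch ~0.5's `LinearConfig` with a `bias` flag; names are illustrative):

```rust
use tch::{nn, Device, Kind, Tensor};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    // `nn::Linear::bs` is `Option<Tensor>`: `Some` under the default
    // config, `None` when the bias is disabled.
    let with_bias = nn::linear(vs.root() / "a", 4, 2, Default::default());
    let no_bias = nn::linear(
        vs.root() / "b",
        4,
        2,
        nn::LinearConfig {
            bias: false,
            ..Default::default()
        },
    );
    assert!(with_bias.bs.is_some());
    assert!(no_bias.bs.is_none());
    // Manual affine step mirroring the pattern in the hunk above.
    let x = Tensor::rand(&[3, 4], (Kind::Float, Device::Cpu));
    let y = x.matmul(&with_bias.ws.tr()) + with_bias.bs.as_ref().unwrap();
    assert_eq!(y.size(), vec![3, 2]);
}
```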

View File

@@ -1128,6 +1128,9 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }

View File

@@ -2,26 +2,6 @@ use crate::RustBertError;
 use half;
 use tch::{Kind, Scalar};

-pub(crate) fn get_positive_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
-    Ok(match kind {
-        Kind::Uint8 => Scalar::int(u8::MAX.into()),
-        Kind::Int8 => Scalar::int(i8::MAX.into()),
-        Kind::Int16 => Scalar::int(i16::MAX.into()),
-        Kind::Int => Scalar::int(i32::MAX.into()),
-        Kind::Int64 => Scalar::int(i64::MAX),
-        Kind::Half => Scalar::float(half::f16::MAX.into()),
-        Kind::Float => Scalar::float(f32::MAX.into()),
-        Kind::BFloat16 => Scalar::float(half::bf16::MAX.into()),
-        Kind::Double => Scalar::float(f64::MAX),
-        _ => {
-            return Err(RustBertError::ValueError(format!(
-                "Type not supported: attempted to get positive infinity for {:?}",
-                kind
-            )))
-        }
-    })
-}

 pub(crate) fn get_negative_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
     Ok(match kind {
         Kind::Uint8 => Scalar::int(u8::MIN.into()),
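`get_positive_infinity` is dropped, presumably because it has no call sites left after the attention refactor below; only `get_negative_infinity` remains. The per-kind extremes matter because f16 overflows long before f32 does. A tiny demonstration using the `half` crate already listed in Cargo.toml:

```rust
use half::f16;

fn main() {
    // f16 tops out at 65504; casting f32::MAX saturates to infinity.
    let big = f16::from_f32(f32::MAX);
    assert!(big.to_f32().is_infinite());
    // inf - inf = NaN: this is how an over-large additive mask poisons a
    // half-precision softmax.
    assert!((big.to_f32() - big.to_f32()).is_nan());
    println!("f16::MAX = {}", f16::MAX);
}
```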

View File

@@ -16,7 +16,7 @@ use crate::xlnet::XLNetConfig;
 use crate::RustBertError;
 use serde::{Deserialize, Serialize};
 use std::borrow::Borrow;
-use tch::{nn, Kind, Tensor};
+use tch::{nn, Tensor};

 #[allow(non_camel_case_types)]
 #[derive(Clone, Debug, Serialize, Deserialize, Copy)]
@@ -132,7 +132,7 @@ impl SequenceSummary {
         let mut output = match self.summary_type {
             SummaryType::last => hidden_states.select(1, -1),
             SummaryType::first => hidden_states.select(1, 0),
-            SummaryType::mean => hidden_states.mean_dim(&[1], false, Kind::Float),
+            SummaryType::mean => hidden_states.mean_dim(&[1], false, hidden_states.kind()),
             SummaryType::cls_index => {
                 let cls_index = if let Some(cls_index_value) = cls_index {
                     let mut expand_dim = vec![-1i64; cls_index_value.dim() - 1];
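Hardcoding `Kind::Float` silently upcasts a half-precision model's activations back to f32; taking the kind from the input keeps the pipeline in f16. A minimal sketch of the difference (CPU support for f16 reductions varies by libtorch version; on CUDA this is the real flow):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let hidden_states =
        Tensor::rand(&[2, 5, 8], (Kind::Float, Device::Cpu)).to_kind(Kind::Half);
    // Hardcoded kind: the reduction upcasts to f32 regardless of the input.
    let upcast = hidden_states.mean_dim(&[1], false, Kind::Float);
    // Kind taken from the input: the result stays in f16.
    let preserved = hidden_states.mean_dim(&[1], false, hidden_states.kind());
    assert_eq!(upcast.kind(), Kind::Float);
    assert_eq!(preserved.kind(), Kind::Half);
}
```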

View File

@@ -735,6 +735,9 @@ impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }

View File

@@ -15,7 +15,6 @@ use crate::gpt_neo::gpt_neo_model::AttentionLayerType;
 use crate::gpt_neo::GptNeoConfig;
 use crate::RustBertError;
 use std::borrow::Borrow;
-use tch::nn::Init;
 use tch::{nn, Device, Kind, Tensor};

 #[derive(Debug)]
@@ -207,23 +206,28 @@ pub(crate) trait GptNeoAttentionUtils {
         key: &Tensor,
         value: &Tensor,
         causal_mask: &Tensor,
-        masked_bias: &Tensor,
         attention_dropout: &Dropout,
         attention_mask: Option<&Tensor>,
         train: bool,
     ) -> (Tensor, Tensor) {
-        let mut attention_weights = query
-            .matmul(&key.transpose(-1, -2))
-            .where_self(causal_mask, &masked_bias.to_kind(query.kind()));
+        let query = query.to_kind(Kind::Float);
+        let key = key.to_kind(Kind::Float);
+        let attention_weights = query.matmul(&key.transpose(-1, -2));
+        let mut attention_weights = attention_weights.where_self(
+            causal_mask,
+            &Tensor::of_slice(&[-1e9f32]).to_device(attention_weights.device()),
+        );
         if let Some(attention_mask_value) = attention_mask {
             attention_weights = attention_weights + attention_mask_value;
         };
-        attention_weights = attention_weights
-            .softmax(-1, Kind::Float)
+        let attention_weights2 = attention_weights
+            .softmax(-1, attention_weights.kind())
             .to_kind(value.kind())
             .apply_t(attention_dropout, train);
-        let attention_output = attention_weights.matmul(value);
+        let attention_output = attention_weights2.matmul(value);
         (attention_output, attention_weights)
     }
 }
@@ -236,7 +240,6 @@ pub struct GptNeoSelfAttention {
     attention_dropout: Dropout,
     resid_dropout: Dropout,
     bias: Tensor,
-    masked_bias: Tensor,
     num_heads: i64,
     head_dim: i64,
     output_attentions: bool,
@@ -259,8 +262,6 @@ impl GptNeoSelfAttention {
         let bias = p.var_copy("bias", &bias_value);

-        let masked_bias = p.var("masked_bias", &[1], Init::Const(-1e9));
-
         let attention_dropout = Dropout::new(config.attention_dropout);
         let resid_dropout = Dropout::new(config.resid_dropout);
@@ -306,7 +307,6 @@ impl GptNeoSelfAttention {
             attention_dropout,
             resid_dropout,
             bias,
-            masked_bias,
             num_heads,
             head_dim,
             output_attentions,
@@ -357,7 +357,6 @@ impl GptNeoSelfAttention {
             &key,
             &value,
             &causal_mask,
-            &self.masked_bias,
             &self.attention_dropout,
             attention_mask,
             train,
@@ -384,7 +383,6 @@ pub struct GptNeoLocalSelfAttention {
     out_proj: nn::Linear,
     attention_dropout: Dropout,
     resid_dropout: Dropout,
-    masked_bias: Tensor,
     num_heads: i64,
     head_dim: i64,
     window_size: i64,
@@ -401,8 +399,6 @@ impl GptNeoLocalSelfAttention {
     {
         let p = p.borrow();

-        let masked_bias = p.var("masked_bias", &[1], Init::Const(-1e9));
-
         let attention_dropout = Dropout::new(config.attention_dropout);
         let resid_dropout = Dropout::new(config.resid_dropout);
@@ -449,7 +445,6 @@ impl GptNeoLocalSelfAttention {
             out_proj,
             attention_dropout,
             resid_dropout,
-            masked_bias,
             num_heads,
             head_dim,
             window_size,
@@ -523,7 +518,6 @@ impl GptNeoLocalSelfAttention {
             &key,
             &value,
             attention_mask,
-            &self.masked_bias,
             &self.attention_dropout,
             None,
             train,
@@ -539,7 +533,6 @@ impl GptNeoLocalSelfAttention {
         } else {
             None
         };
-
         Ok((attention_output, attention_weights))
     }
 }
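The refactor drops the learned `masked_bias` buffer in favor of a fixed `-1e9` fill, and computes the attention scores in f32 even when the weights are f16: the dot products and the large negative fill stay within f32 range, and only the final probabilities are cast back to the value kind. A self-contained sketch of the same recipe (shapes are illustrative; dropout and head splitting are omitted):

```rust
use tch::{Device, Kind, Tensor};

// Causal attention with scores kept in f32, mirroring the hunk above.
fn attend(query: &Tensor, key: &Tensor, value: &Tensor, causal_mask: &Tensor) -> Tensor {
    // Work in f32 even if query/key/value are f16: f16 overflows at 65504,
    // so accumulating dot products or adding -1e9 in f16 risks inf/NaN.
    let query = query.to_kind(Kind::Float);
    let key = key.to_kind(Kind::Float);
    let scores = query.matmul(&key.transpose(-1, -2));
    // A finite -1e9 fill is safe in f32 and vanishes after the softmax.
    let neg_fill = Tensor::of_slice(&[-1e9f32]).to_device(scores.device());
    let masked = scores.where_self(causal_mask, &neg_fill);
    // Softmax in f32, then cast the probabilities back to the value kind.
    masked
        .softmax(-1, Kind::Float)
        .to_kind(value.kind())
        .matmul(value)
}

fn main() {
    let device = Device::Cpu;
    let (seq, dim) = (4i64, 8i64);
    let q = Tensor::rand(&[1, seq, dim], (Kind::Float, device));
    let k = Tensor::rand(&[1, seq, dim], (Kind::Float, device));
    let v = Tensor::rand(&[1, seq, dim], (Kind::Float, device));
    // Lower-triangular causal mask: true where attention is allowed.
    let causal_mask = Tensor::ones(&[seq, seq], (Kind::Int64, device))
        .tril(0)
        .to_kind(Kind::Bool);
    let out = attend(&q, &k, &v, &causal_mask);
    assert_eq!(out.size(), vec![1, seq, dim]);
}
```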

View File

@@ -339,14 +339,6 @@ impl GptNeoModel {
         let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());

-        let global_attention_mask = attention_mask.map(|attention_mask_value| {
-            let global_attention_mask = attention_mask_value
-                .view([batch_size, -1])
-                .unsqueeze(1)
-                .unsqueeze(1);
-            (1 - global_attention_mask) * -1e4
-        });

         let local_attention_mask = GptNeoModel::create_local_attention_mask(
             batch_size,
             full_sequence_length,
@@ -358,12 +350,20 @@ impl GptNeoModel {
         let input_embeds = input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
         let position_embeds = position_ids.apply(&self.position_embeddings);

+        let global_attention_mask = attention_mask.map(|attention_mask_value| {
+            let global_attention_mask = attention_mask_value
+                .view([batch_size, -1])
+                .unsqueeze(1)
+                .unsqueeze(1);
+            let global_attention_mask = global_attention_mask.to_kind(position_embeds.kind());
+            (1 - global_attention_mask) * -1e4
+        });
+
         let mut hidden_state = input_embeds + position_embeds;
         if let Some(token_type_ids) = token_type_ids {
             hidden_state = hidden_state + token_type_ids.apply(&self.word_embeddings);
         };
         hidden_state = hidden_state.apply_t(&self.dropout, train);

         let mut output_shape = input_shape;
         output_shape.push(*hidden_state.size().last().unwrap());
@@ -711,6 +711,9 @@ impl PrivateLanguageGenerator<GptNeoForCausalLM, Gpt2Vocab, Gpt2Tokenizer> for G
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }
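Moving the global-mask construction after the position embeddings lets the mask adopt their kind, and the `-1e4` fill, unlike `-1e9`, is still representable in f16. A sketch of that conversion in isolation (sizes are made up; `[batch, 1, 1, seq]` is the broadcastable attention form):

```rust
use tch::{Kind, Tensor};

// Build a broadcastable additive attention mask in the embeddings' kind,
// following the hunk above.
fn global_attention_mask(attention_mask: &Tensor, embed_kind: Kind) -> Tensor {
    let (batch_size, _seq_len) = attention_mask.size2().unwrap();
    let mask = attention_mask
        .view([batch_size, -1])
        .unsqueeze(1)
        .unsqueeze(1)
        .to_kind(embed_kind);
    // 1 -> 0.0 (attend) and 0 -> -1e4 (blocked); -1e4 fits comfortably in
    // f16, whose largest finite magnitude is 65504.
    (1 - mask) * -1e4
}

fn main() {
    let attention_mask = Tensor::of_slice(&[1i64, 1, 1, 0]).view([1, 4]);
    let additive = global_attention_mask(&attention_mask, Kind::Half);
    assert_eq!(additive.kind(), Kind::Half);
    assert_eq!(additive.size(), vec![1, 1, 1, 4]);
}
```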

View File

@@ -726,6 +726,9 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }

View File

@@ -900,6 +900,9 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }

View File

@@ -936,6 +936,9 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }

View File

@@ -566,6 +566,9 @@ impl PrivateLanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTok
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }

View File

@@ -697,6 +697,9 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }

View File

@@ -280,6 +280,7 @@ pub(crate) mod private_generation_utils {
     fn get_model(&self) -> &T;
     fn _get_tokenizer(&self) -> &TokenizerOption;
     fn get_var_store(&self) -> &nn::VarStore;
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore;
     fn get_config(&self) -> &GenerateConfig;
     fn get_bos_id(&self) -> &Option<i64>;
     fn get_eos_ids(&self) -> &Option<Vec<i64>>;
@@ -488,7 +489,9 @@
         }
         if top_p < 1f64 {
             let (sorted_logits, sorted_indices) = logits.sort(-1, true);
-            let cumulative_probabilities = sorted_logits.softmax(-1, Float).cumsum(-1, Float);
+            let cumulative_probabilities = sorted_logits
+                .softmax(-1, sorted_logits.kind())
+                .cumsum(-1, sorted_logits.kind());
             let mut sorted_indices_to_remove =
                 cumulative_probabilities.ge(top_p).to_kind(Int64);
             if min_tokens_to_keep > 1 {
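Beyond the kind change, this is the nucleus-filtering step: sort descending, softmax, cumulative sum, drop the tail once its cumulative mass exceeds `top_p`. A standalone sketch of the algorithm (simplified from the surrounding function; the right-shift keeps the token that crosses the threshold):

```rust
use tch::{Kind, Tensor};

// Nucleus (top-p) filtering: remove tokens once the cumulative probability
// of the descending-sorted distribution exceeds `top_p`.
fn top_p_filter(logits: &Tensor, top_p: f64) -> Tensor {
    let (sorted_logits, sorted_indices) = logits.sort(-1, true);
    let cumulative_probabilities = sorted_logits
        .softmax(-1, sorted_logits.kind())
        .cumsum(-1, sorted_logits.kind());
    // Mark everything past the threshold, then shift right by one so the
    // crossing token (and always the top token) survives.
    let remove_sorted = cumulative_probabilities.gt(top_p);
    let shifted = remove_sorted.zeros_like();
    let seq_len = *remove_sorted.size().last().unwrap();
    shifted
        .narrow(-1, 1, seq_len - 1)
        .copy_(&remove_sorted.narrow(-1, 0, seq_len - 1));
    // Map the mask from sorted order back to vocabulary order.
    let inverse_permutation = sorted_indices.argsort(-1, false);
    let remove = shifted
        .to_kind(Kind::Int64)
        .gather(-1, &inverse_permutation, false);
    logits.masked_fill(&remove.to_kind(Kind::Bool), f64::NEG_INFINITY)
}

fn main() {
    let logits = Tensor::of_slice(&[2.0f32, 1.0, 0.5, -1.0]).view([1, 4]);
    top_p_filter(&logits, 0.9).print();
}
```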
@@ -563,7 +566,7 @@
             let mask = scores.new_full(
                 scores.size().as_slice(),
                 f64::INFINITY,
-                (Kind::Float, scores.device()),
+                (scores.kind(), scores.device()),
             );
             for idx in 0..scores.size()[0] {
                 let batch_id = idx / num_beams;
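Same principle for the banned-token mask: `new_full` inherits the device and now also the kind from `scores`, and +inf, unlike large finite fills, is exactly representable in f16. A sketch:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let scores = Tensor::rand(&[2, 6], (Kind::Float, Device::Cpu)).to_kind(Kind::Half);
    // Allocate in the scores' kind and device instead of hardcoding f32.
    let mask = scores.new_full(
        scores.size().as_slice(),
        f64::INFINITY,
        (scores.kind(), scores.device()),
    );
    assert_eq!(mask.kind(), Kind::Half);
    assert!(mask.double_value(&[0, 0]).is_infinite());
}
```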
@@ -750,14 +753,7 @@
         let mut past: Cache = Cache::None;
         let mut outputs: Tensor;
         let mut current_length = cur_len;
-        let mut scores_output = if output_scores {
-            Some(Tensor::zeros(
-                &[batch_size],
-                (Float, self.get_var_store().device()),
-            ))
-        } else {
-            None
-        };
+        let mut scores_output: Option<Tensor> = None;

         while current_length < gen_opt.max_length {
             let prepared_input = self.prepare_inputs_for_generation(
@@ -783,6 +779,13 @@
             outputs = temp.lm_logits;
             past = temp.cache;

+            if scores_output.is_none() & output_scores {
+                scores_output = Some(Tensor::zeros(
+                    &[batch_size],
+                    (outputs.kind(), self.get_var_store().device()),
+                ))
+            }
+
             let mut next_token_logits = outputs.select(1, -1);

             // Reduce probability for repeated inputs
             if gen_opt.repetition_penalty > 1f64 {
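Allocating `scores_output` eagerly would pin its kind to f32 before any logits exist; deferring the allocation to the first decoding step lets the accumulator adopt whatever kind the model actually emits (f16 after `half()`). The pattern in isolation:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let output_scores = true;
    let batch_size = 2i64;
    // Defer allocation until the first forward pass has produced logits.
    let mut scores_output: Option<Tensor> = None;
    for _step in 0..3 {
        // Stand-in for the model forward pass (f16 on CUDA in the real flow).
        let outputs = Tensor::rand(&[batch_size, 10], (Kind::Float, Device::Cpu));
        if scores_output.is_none() && output_scores {
            scores_output = Some(Tensor::zeros(
                &[batch_size],
                (outputs.kind(), outputs.device()),
            ));
        }
        if let Some(scores) = &mut scores_output {
            *scores += outputs.select(1, 0); // toy score update
        }
    }
    println!("{:?}", scores_output.map(|s| s.kind()));
}
```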
@@ -871,7 +874,7 @@
                     gen_opt.top_p,
                     1,
                 );
-                let probabilities = next_token_logits.softmax(-1, Float);
+                let probabilities = next_token_logits.softmax(-1, next_token_logits.kind());
                 probabilities.multinomial(1, false).squeeze_dim(1)
             } else {
                 next_token_logits.argmax(-1, false)
@@ -882,7 +885,7 @@
                 scores_output = Some(
                     prev_scores
                         + (&next_token_logits
-                            .log_softmax(-1, Float)
+                            .log_softmax(-1, next_token_logits.kind())
                             .gather(1, &next_token.reshape(&[-1, 1]), true)
                             .squeeze()
                             .masked_fill(&finished_mask, 0)),
@@ -1077,7 +1080,7 @@
                 gen_opt.forced_bos_token_id,
             );
-            let mut scores = next_token_logits.log_softmax(-1, Float);
+            let mut scores = next_token_logits.log_softmax(-1, next_token_logits.kind());

             // Do not allow eos token if min length is not reached
             if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
@@ -1170,7 +1173,7 @@
                 .contiguous()
                 .view((batch_size, group_size * vocab_size));

-            let probabilities = _scores.softmax(-1, Float);
+            let probabilities = _scores.softmax(-1, _scores.kind());
             let next_tokens = probabilities.multinomial(2 * group_size, false);
             let _scores = _scores.gather(-1, &next_tokens, false);
             let (_scores, next_scores_indices) = _scores.sort(1, true);
@@ -2004,6 +2007,18 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     fn get_tokenizer(&self) -> &TokenizerOption {
         self._get_tokenizer()
     }
+
+    fn half(&mut self) {
+        self.get_var_store_mut().half();
+    }
+
+    fn float(&mut self) {
+        self.get_var_store_mut().float();
+    }
+
+    fn set_device(&mut self, device: Device) {
+        self.get_var_store_mut().set_device(device);
+    }
 }

 #[derive(Debug)]
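These defaults are why every `PrivateLanguageGenerator` implementation in this commit gains the same three-line `get_var_store_mut`: one mutable accessor per model unlocks shared `half`/`float`/`set_device` for all of them. A distilled sketch of the wiring (trait and struct names are made up; it assumes a tch version whose `VarStore` exposes `half`, `float` and `set_device`, which is why the commit builds against a local tch-rs checkout):

```rust
use tch::{nn, Device};

// Distilled version of the trait wiring above.
trait GeneratorSketch {
    fn get_var_store_mut(&mut self) -> &mut nn::VarStore;

    // Default methods shared by every model type.
    fn half(&mut self) {
        self.get_var_store_mut().half();
    }
    fn float(&mut self) {
        self.get_var_store_mut().float();
    }
    fn set_device(&mut self, device: Device) {
        self.get_var_store_mut().set_device(device);
    }
}

struct DummyGenerator {
    var_store: nn::VarStore,
}

impl GeneratorSketch for DummyGenerator {
    // The only per-model boilerplate this design requires.
    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
        &mut self.var_store
    }
}

fn main() {
    let mut generator = DummyGenerator {
        var_store: nn::VarStore::new(Device::Cpu),
    };
    generator.set_device(Device::cuda_if_available());
    generator.half(); // casts all variables (none here) to f16
}
```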

View File

@@ -326,6 +326,36 @@ impl TextGenerationOption {
                 .collect(),
         }
     }
+
+    pub fn half(&mut self) {
+        match self {
+            Self::GPT(model_ref) => model_ref.half(),
+            Self::GPT2(model_ref) => model_ref.half(),
+            Self::GPTNeo(model_ref) => model_ref.half(),
+            Self::XLNet(model_ref) => model_ref.half(),
+            Self::Reformer(model_ref) => model_ref.half(),
+        }
+    }
+
+    pub fn float(&mut self) {
+        match self {
+            Self::GPT(model_ref) => model_ref.float(),
+            Self::GPT2(model_ref) => model_ref.float(),
+            Self::GPTNeo(model_ref) => model_ref.float(),
+            Self::XLNet(model_ref) => model_ref.float(),
+            Self::Reformer(model_ref) => model_ref.float(),
+        }
+    }
+
+    pub fn set_device(&mut self, device: Device) {
+        match self {
+            Self::GPT(model_ref) => model_ref.set_device(device),
+            Self::GPT2(model_ref) => model_ref.set_device(device),
+            Self::GPTNeo(model_ref) => model_ref.set_device(device),
+            Self::XLNet(model_ref) => model_ref.set_device(device),
+            Self::Reformer(model_ref) => model_ref.set_device(device),
+        }
+    }
 }
/// # TextGenerationModel to generate texts from a prompt
@@ -392,6 +422,18 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
         })
     }
+
+    pub fn half(&mut self) {
+        self.model.half();
+    }
+
+    pub fn float(&mut self) {
+        self.model.float();
+    }
+
+    pub fn set_device(&mut self, device: Device) {
+        self.model.set_device(device);
+    }
+
     /// Generate texts from provided prompts
     ///
     /// # Arguments

View File

@@ -999,6 +999,9 @@ impl
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }

View File

@@ -1105,6 +1105,9 @@ impl PrivateLanguageGenerator<ReformerModelWithLMHead, ReformerVocab, ReformerTo
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }

View File

@@ -798,6 +798,9 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }

View File

@@ -1620,6 +1620,9 @@ impl PrivateLanguageGenerator<XLNetLMHeadModel, XLNetVocab, XLNetTokenizer> for
     fn get_var_store(&self) -> &nn::VarStore {
         &self.var_store
     }
+    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
+        &mut self.var_store
+    }
     fn get_config(&self) -> &GenerateConfig {
         &self.generate_config
     }