GPTNeo hypers (#35)

* store full train state

* close dl processes when done

* gpt3 scheduler
arampacha 2021-07-10 17:52:50 +03:00 committed by GitHub
parent 9af5a7faaf
commit 240816c55e
4 changed files with 185 additions and 143 deletions

.gitignore

@ -131,4 +131,7 @@ dmypy.json
# Pyre type checker
.pyre/
wandb/
_tmp/
_*.py
_*.ipynb


@ -6,30 +6,32 @@
--data_dir /home/shared/code-clippy-dataset/merged-data \
--text_column_name="text" \
--do_train --do_eval \
--block_size="1024" \
--per_device_train_batch_size="16" \
--block_size="2048" \
--per_device_train_batch_size="8" \
--per_device_eval_batch_size="16" \
--preprocessing_num_workers="8" \
--learning_rate="3e-4" \
--learning_rate="6e-4" \
--adafactor \
--warmup_steps="250" \
--max_steps 10000 \
--warmup_steps 3000 \
--decay_steps 5000 \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--adam_beta2="0.95" \
--weight_decay="0.01" \
--overwrite_output_dir \
--max_steps 50000 \
--num_train_epochs="1" \
--logging_steps="50" \
--eval_steps="500" \
--eval_steps="50" \
--push_to_hub="False" \
--report_to="all" \
--dtype="bfloat16" \
--skip_memory_metrics="False" \
--save_steps="500" \
--save_steps="50" \
--save_total_limit 2 \
--gradient_accumulation_steps 2 \
--gradient_accumulation_steps 8 \
--report_to="wandb" \
--max_eval_samples="1000" \
--save_optimizer true \
# --resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ckpt_201 \
# --max_train_samples="10000" \
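# Batch arithmetic implied by the flags above (the device count is an
# assumption, not stated in this commit; a TPU v3-8, for example, exposes 8 devices):
#   effective_batch = per_device_train_batch_size * num_devices * gradient_accumulation_steps
#                   = 8 * 8 * 8 = 512 sequences per optimizer step
#   tokens_per_step = effective_batch * block_size = 512 * 2048 = 1,048,576 (~1M tokens)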


@ -34,9 +34,6 @@ import json
import shutil
from collections import defaultdict
import numpy as np
# from queue import Queue
# import threading
from multiprocessing import Process, Queue
import datasets
from datasets import Dataset, load_dataset
from tqdm import tqdm
@ -50,6 +47,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training.checkpoints import save_checkpoint, restore_checkpoint
from flax.serialization import to_bytes, from_bytes
from transformers import (
    CONFIG_MAPPING,
@ -64,6 +62,7 @@ from transformers import (
from transformers.testing_utils import CaptureLogger
from importlib.util import find_spec
from utils import PrefetchDataloader, make_batch

logger = logging.getLogger(__name__)
@ -107,7 +106,19 @@ class ModelArguments:
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
},
)
save_optimizer: Optional[bool] = field(
default=True,
metadata={"help": "Whether to store full train state including optimizer."},
)
repo_path_or_name: Optional[str] = field(
default=None,
metadata={"help": "Path to the modelhub repo directory"},
)
repo_url: Optional[str] = field(
default=None,
metadata={"help": "URL of the modelhub repo"},
)
decay_steps: int = field(default=None, metadata={"help":"Number of steps from peak to final learning rate"})
@dataclass
class DataTrainingArguments:
@ -194,7 +205,7 @@ class TrainState(train_state.TrainState):
    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))

# The functions below are currently unused and will probably be removed.
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
    num_samples = len(samples_idx)
    samples_to_remove = num_samples % batch_size
@ -234,105 +245,6 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
    grouped_samples = group_texts(samples)
    return grouped_samples

def make_batch(samples):
    batch = {k: jnp.array(v) for k, v in samples.items()}
    batch['labels'] = batch['input_ids'].copy()
    return batch

# class PrefetchDataloader(threading.Thread):
#     "Prefetch dataloader for IterableDataset"
#     def __init__(self, dataset, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
#         super().__init__(daemon=True)
#         self.bs = batch_size
#         self.seq_len = sequence_length
#         self.max_length = batch_size * sequence_length
#         self.prefetch_buffer = prefetch_buffer
#         self.shuffle = shuffle
#         self.shuffle_buffer = shuffle_buffer
#         self.seed = seed
#         self.dataset = dataset
#         if shuffle:
#             shuffled_dataset = dataset.shuffle(shuffle_buffer, seed=self.seed)
#             self.seed += 1
#             self.ds_iter = iter(shuffled_dataset)
#         else:
#             self.ds_iter = iter(dataset)
#         self.queue = Queue(prefetch_buffer)
#         self.rem = defaultdict(list)
#         self.start()

#     def __next__(self):
#         batch = self.queue.get()
#         return batch

#     def run(self):
#         while True:
#             # prepare next batch
#             sample = self.rem.copy()
#             l = len(sample["input_ids"])
#             max_length = self.max_length
#             while l < max_length:
#                 next_sample = next(self.ds_iter)
#                 l += len(next_sample["input_ids"])
#                 sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}
#             self.rem = {k: v[max_length:] for k, v in sample.items()}
#             sample = {k: v[:max_length] for k, v in sample.items()}
#             # regroup to shape [bs x seq_len]
#             samples = {k: np.array([v[i * self.seq_len:(i + 1) * self.seq_len] for i in range(self.bs)]) for k, v in sample.items()}
#             self.queue.put(make_batch(samples))

#     def __iter__(self):
#         return self

class PrefetchDataloader(Process):
    "Prefetch dataloader for IterableDataset"
    def __init__(self, dataset, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        if shuffle:
            shuffled_dataset = dataset.shuffle(shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(dataset)
        self.queue = Queue(prefetch_buffer)
        self.rem = defaultdict(list)
        self.start()

    def __next__(self):
        return make_batch(self.queue.get())

    def run(self):
        while True:
            # prepare next batch
            sample = self.rem.copy()
            l = len(sample["input_ids"])
            max_length = self.max_length
            while l < max_length:
                next_sample = next(self.ds_iter)
                l += len(next_sample["input_ids"])
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # regroup to shape [bs x seq_len]
            samples = {k: np.array([v[i * self.seq_len:(i + 1) * self.seq_len] for i in range(self.bs)]) for k, v in sample.items()}
            self.queue.put(samples)

    def __iter__(self):
        return self
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
@ -382,13 +294,24 @@ def create_learning_rate_fn(
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn

def gpt3_schedule(warmup_steps,
                  total_steps,
                  peak_lr,
                  end_lr):
    def sch(step):
        warmup_pct = jnp.clip(step, 0, warmup_steps) / warmup_steps
        anneal_pct = jnp.clip(step - warmup_steps, 0, total_steps) / total_steps
        return warmup_pct * peak_lr - (peak_lr - end_lr) * (1 - jnp.cos(jnp.pi * anneal_pct)) / 2
    return sch
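
# Illustrative check of the schedule shape, using the values from the run
# script above (ignoring the gradient-accumulation scaling applied in main()):
#   sch = gpt3_schedule(warmup_steps=3000, total_steps=5000, peak_lr=6e-4, end_lr=6e-5)
#   sch(0)    -> 0.0      linear warmup from zero
#   sch(3000) -> 6e-4     peak at the end of warmup
#   sch(5500) -> 3.3e-4   halfway through the cosine anneal
#   sch(8000) -> 6e-5     floor; clipped to end_lr thereafter
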
# utils
def mb_item(x):
    return x.item() if hasattr(x, "item") else x

# checkpoint functions
def save_checkpoint(model, save_dir, state, with_opt:bool=True, push_to_hub:bool=False):
def save_model_checkpoint(model, save_dir, state, with_opt: bool = False, push_to_hub: bool = False, **kwargs):
    state = jax_utils.unreplicate(state)
    logger.info(f"SAVING CHECKPOINT IN {save_dir}...")
    save_dir = f"{save_dir}/ckpt-{mb_item(state.step)-1}"
@ -397,6 +320,7 @@ def save_checkpoint(model, save_dir, state, with_opt:bool=True, push_to_hub:bool
        params=state.params,
        push_to_hub=push_to_hub,
        commit_message=f"Saving weights and logs at step {mb_item(state.step)-1}",
        **kwargs,
    )
    if with_opt:
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
@ -405,7 +329,7 @@ def save_checkpoint(model, save_dir, state, with_opt:bool=True, push_to_hub:bool
            json.dump({"step": state.step.item()}, f)
    logger.info("checkpoint saved")

def restore_checkpoint(save_dir, state):
def restore_model_checkpoint(save_dir, state):
    logger.info(f"RESTORING CHECKPOINT FROM {save_dir}...")
    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
        params = from_bytes(state.params, f.read())
@ -618,14 +542,7 @@ def main():
    # train_iter = iter(train_dataset)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    # steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = training_args.max_steps

    train_dl = PrefetchDataloader(
    train_loader = PrefetchDataloader(
        tokenized_dataset,
        int(training_args.per_device_train_batch_size) * jax.device_count(),
        block_size,
@ -687,15 +604,14 @@ def main():
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    # steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = training_args.max_steps
    total_train_steps = training_args.max_steps * training_args.gradient_accumulation_steps

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        total_train_steps,
        train_batch_size,
        training_args.warmup_steps,
    gpt3_schedule_fn = gpt3_schedule(
        training_args.warmup_steps * training_args.gradient_accumulation_steps,
        model_args.decay_steps * training_args.gradient_accumulation_steps,
        training_args.learning_rate,
        training_args.learning_rate / 10.
    )
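    # warmup and decay are scaled by gradient_accumulation_steps to match
    # total_train_steps above, which counts micro-batches rather than
    # optimizer steps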

    # We use Optax's "masking" functionality to not apply weight decay
@ -718,11 +634,11 @@ def main():
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn,
            learning_rate=gpt3_schedule_fn,
        )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            learning_rate=gpt3_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
@ -737,7 +653,8 @@ def main():
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)

    if training_args.resume_from_checkpoint:
        state, resume_step = restore_checkpoint(training_args.resume_from_checkpoint, state)
        state = restore_checkpoint(training_args.resume_from_checkpoint, state)
        resume_step = state.step
    else:
        resume_step = 0
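    # with the local helpers renamed to save_model_checkpoint/restore_model_checkpoint,
    # `restore_checkpoint` here resolves to flax.training.checkpoints.restore_checkpoint
    # (imported above), which reloads the full TrainState (params, optimizer state
    # and step) written by flax's save_checkpoint in the save block further down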
@ -763,7 +680,7 @@ def main():
        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)}
        metrics = {"loss": loss, "learning_rate": gpt3_schedule_fn(state.step // grad_accum_steps)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return new_state, metrics
@ -787,11 +704,9 @@ def main():
    state = state.replicate()

    logger.info("***** Running training *****")
    # logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed and grad_accum) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")
    logger.info(f"  Total optimization steps = {training_args.max_steps}")

    if not training_args.skip_memory_metrics:
        server = jax.profiler.start_server(9999)
@ -799,7 +714,7 @@ def main():
    train_time = 0
    train_metrics = []
    # TODO: figure out training duration
    steps = tqdm(range(total_train_steps // grad_accum_steps), position=0, initial=resume_step)
    steps = tqdm(range(training_args.max_steps), position=0, initial=resume_step)
    for step in range(total_train_steps):
        # ======================== Training ================================
        train_start = time.time()
@ -810,9 +725,10 @@
        if cur_step < resume_step:
            continue

        # using advance_iter_and_group_samples seems to make training slower
        # samples = advance_iter_and_group_samples(iter(tokenized_dataset), int(training_args.per_device_train_batch_size) * jax.device_count(), block_size)
        # batch = shard(make_batch(samples))
        batch = shard(next(train_dl))
        batch = shard(next(train_loader))
        # logger.info(f"{batch['input_ids'].shape}")
        state, train_metric = p_train_step(state, batch)
        train_metrics.append(train_metric)
@ -842,7 +758,7 @@ def main():
            # eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
            eval_loader = PrefetchDataloader(
                tokenized_eval_dataset,
                int(training_args.per_device_eval_batch_size) * jax.device_count(),
                eval_batch_size,
                block_size,
                prefetch_buffer=data_args.prefetch_buffer,
                shuffle=False,
@ -862,7 +778,7 @@
                eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
            except OverflowError:
                eval_metrics["perplexity"] = float("inf")
            eval_loader.close()

            # Print metrics and update progress bar
            desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
            steps.write(desc)
@ -879,12 +795,18 @@ def main():
        if cur_step % (training_args.save_steps * grad_accum_steps) == 0 and cur_step > 0:
            # save checkpoint after each epoch and push checkpoint to the hub
            if jax.process_index() == 0:
                save_checkpoint(model, training_args.output_dir, state, push_to_hub=training_args.push_to_hub)
                save_model_checkpoint(model, training_args.output_dir, state, with_opt=False,
                                      push_to_hub=training_args.push_to_hub, repo_path_or_name=training_args.output_dir)
                if model_args.save_optimizer:
                    # this saves the full train state, including the optimizer
                    save_checkpoint(training_args.output_dir, state, state.step, keep=training_args.save_total_limit, overwrite=False)
                if training_args.save_total_limit is not None:
                    rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)

    train_loader.close()

    # save model after training is over
    save_checkpoint(model, training_args.output_dir, state, with_opt=False, push_to_hub=training_args.push_to_hub)
    save_model_checkpoint(model, training_args.output_dir, state, with_opt=False,
                          push_to_hub=training_args.push_to_hub, repo_path_or_name=training_args.output_dir)
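
The commit now writes two checkpoint flavors: model-only weights via save_model_checkpoint (optionally pushed to the hub) and, when --save_optimizer is set, the full TrainState via flax.training.checkpoints.save_checkpoint. A minimal round-trip sketch of the latter, assuming the same imports (ckpt_dir and the pre-built state are placeholders):

    from flax.training.checkpoints import save_checkpoint, restore_checkpoint

    # write the full train state (params + optimizer state + step), keeping the 2 newest
    save_checkpoint(ckpt_dir, state, step=int(state.step), keep=2, overwrite=False)

    # later: restore into a freshly created TrainState of the same structure
    state = restore_checkpoint(ckpt_dir, state)
    resume_step = int(state.step)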

utils.py (new file)

@ -0,0 +1,115 @@
import numpy as np
import threading
import queue
from multiprocessing import Process, Queue
from collections import defaultdict
import jax
import jax.numpy as jnp


def make_batch(samples):
    batch = {k: jnp.array(v) for k, v in samples.items()}
    batch['labels'] = batch['input_ids'].copy()
    return batch
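
# Example of the contract (values are illustrative):
#   make_batch({"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]})
#   returns jnp arrays plus a "labels" entry equal to "input_ids"; the one-token
#   shift for causal LM is applied later in the training loss, not here
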
class PrefetchDataloaderThread(threading.Thread):
    "Prefetch dataloader for IterableDataset"
    def __init__(self, dataset, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        if shuffle:
            shuffled_dataset = dataset.shuffle(shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(dataset)
        self.queue = queue.Queue(prefetch_buffer)
        self.rem = defaultdict(list)
        self.start()

    def __next__(self):
        batch = self.queue.get()
        return batch

    def run(self):
        while True:
            # prepare next batch
            sample = self.rem.copy()
            l = len(sample["input_ids"])
            max_length = self.max_length
            while l < max_length:
                next_sample = next(self.ds_iter)
                l += len(next_sample["input_ids"])
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # regroup to shape [bs x seq_len]
            samples = {k: np.array([v[i * self.seq_len:(i + 1) * self.seq_len] for i in range(self.bs)]) for k, v in sample.items()}
            self.queue.put(make_batch(samples))

    def __iter__(self):
        return self

class PrefetchDataloader(Process):
    "Prefetch dataloader for IterableDataset"
    def __init__(self, dataset, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        self.make_iter()
        self.queue = Queue(prefetch_buffer)
        self.rem = defaultdict(list)
        self.start()

    def make_iter(self):
        if self.shuffle:
            shuffled_dataset = self.dataset.shuffle(self.shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(self.dataset)

    def __next__(self):
        return make_batch(self.queue.get())

    def run(self):
        while True:
            # prepare next batch
            sample = self.rem.copy()
            l = len(sample["input_ids"])
            max_length = self.max_length
            while l < max_length:
                try:
                    next_sample = next(self.ds_iter)
                except StopIteration:
                    # reset the generator when a full pass over the dataset
                    # completes, then retry; without this `continue` the stale
                    # `next_sample` from the previous iteration would be reused
                    self.make_iter()
                    continue
                l += len(next_sample["input_ids"])
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # regroup to shape [bs x seq_len]
            samples = {k: np.array([v[i * self.seq_len:(i + 1) * self.seq_len] for i in range(self.bs)]) for k, v in sample.items()}
            self.queue.put(samples)

    def close(self):
        # multiprocessing.Process.close() raises ValueError while the worker is
        # still alive, so stop the daemon process before releasing its resources
        self.terminate()
        self.join()
        super().close()

    def __iter__(self):
        return self
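
A short usage sketch for the process-based loader (the tokenized streaming dataset and the shapes are illustrative, not part of this file):

    # tokenized_dataset: an IterableDataset whose samples hold token-id lists,
    # e.g. {"input_ids": [...], "attention_mask": [...]}
    loader = PrefetchDataloader(tokenized_dataset, batch_size=64, sequence_length=2048,
                                prefetch_buffer=2, shuffle=True)
    batch = next(loader)   # dict of jnp arrays, each of shape [64, 2048]
    assert (batch["labels"] == batch["input_ids"]).all()
    loader.close()         # stop the daemon worker once training is done

Because batches cross a process boundary, run() enqueues plain numpy arrays and __next__ converts them to jnp arrays with make_batch on the consumer side.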