Grad accum (#30)

* adds gradient accumulation

* adds saving when training is finished
arampacha 2021-07-08 22:16:49 +03:00 committed by GitHub
parent 2736967001
commit 19462b5239
3 changed files with 44 additions and 26 deletions

View File

@@ -242,23 +242,23 @@ def mb_item(x):
#checkpoint functions
def save_checkpoint(model, save_dir, state, with_opt:bool=True, push_to_hub:bool=False):
state = jax_utils.unreplicate(state)
print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
logger.info(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
save_dir = f"{save_dir}/ckpt-{mb_item(state.step)-1}"
model.save_pretrained(
save_dir,
params=state.params,
push_to_hub=push_to_hub,
commit_message=f"Saving weights and logs at step {mb_item(state.step)}",
commit_message=f"Saving weights and logs at step {mb_item(state.step)-1}",
)
if with_opt:
with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
f.write(to_bytes(state.opt_state))
with open(os.path.join(save_dir, "training_state.json"), "w") as f:
json.dump({"step": state.step.item()}, f)
print("checkpoint saved")
logger.info("checkpoint saved")
def restore_checkpoint(save_dir, state):
print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
logger.info(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
params = from_bytes(state.params, f.read())
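As a reference point, here is a minimal sketch of the opt_state/step round trip that save_checkpoint above and restore_checkpoint (continued in the next hunk) implement, assuming a Flax TrainState-style object with .params, .opt_state and .step; it is not the repo's exact code. Note that logging.Logger.info has no end parameter (unlike print), so the sketch emits one message per call.

import json
import os
from flax.serialization import to_bytes, from_bytes

def save_opt_state(state, save_dir):
    # `state` is assumed to be unreplicated already (jax_utils.unreplicate)
    with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
        f.write(to_bytes(state.opt_state))
    with open(os.path.join(save_dir, "training_state.json"), "w") as f:
        json.dump({"step": int(state.step)}, f)

def load_opt_state(state, save_dir):
    with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())
    with open(os.path.join(save_dir, "training_state.json")) as f:
        step = json.load(f)["step"]
    return state.replace(step=step, opt_state=opt_state)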
@@ -269,13 +269,13 @@ def restore_checkpoint(save_dir, state):
training_state = json.load(f)
step = training_state["step"]
print("checkpoint restored")
logger.info("checkpoint restored")
return state.replace(step=step, params=params, opt_state=opt_state), step
def rotate_checkpoints(ckpt_dir:str, save_total_limit:int):
"Removes older checkpoints so that `save_total_limit` checkpoints are kept"
# TODO: what to remove is decided using step number only, we might want to improve that
- ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt*")]
+ ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
# sort checkpoints by step
ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))
ckpts_to_delete = ckpts_sorted[:-save_total_limit]
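The tightened glob ("ckpt-*" instead of "ckpt*") matters because the sort key parses the trailing step number: a directory that merely starts with "ckpt" but lacks a -<step> suffix would make int() raise. A self-contained sketch of the rotation logic, assuming checkpoint directories named ckpt-<step>:

import shutil
from pathlib import Path

def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
    "Keep only the `save_total_limit` newest ckpt-<step> directories."
    if save_total_limit is None or save_total_limit <= 0:
        return
    ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
    # sort numerically by step so that ckpt-900 counts as older than ckpt-1000
    ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))
    for ckpt in ckpts_sorted[:-save_total_limit]:
        shutil.rmtree(ckpt, ignore_errors=True)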
@@ -533,7 +533,7 @@ def main():
# Store some constant
num_epochs = int(training_args.num_train_epochs)
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
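Worth spelling out: train_batch_size is now the effective batch per optimizer update rather than per forward pass, so steps_per_epoch and total_train_steps count optimizer updates. Hypothetical numbers, purely to illustrate the arithmetic:

# hypothetical values, not taken from the diff
per_device_train_batch_size = 3
device_count = 8                      # e.g. one TPU v3-8
gradient_accumulation_steps = 4

micro_batch = per_device_train_batch_size * device_count            # 24 examples per forward/backward pass
train_batch_size = micro_batch * gradient_accumulation_steps        # 96 examples per optimizer update
steps_per_epoch = 96_000 // train_batch_size                        # 1000 optimizer updates for a 96k-example dataset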
@@ -578,7 +578,10 @@ def main():
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
+ if training_args.gradient_accumulation_steps > 1:
+     optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
+ grad_accum_steps = training_args.gradient_accumulation_steps
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
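This is the heart of the change: optax.MultiSteps wraps the existing gradient transformation, accumulates gradients across every_k_schedule calls, and only lets the wrapped optimizer emit a real update on the k-th call (zero updates in between), so the rest of the loop can keep calling apply_gradients on every micro-batch. A sketch of the wiring, assuming the adamw branch of the script (it can also run with adafactor):

import optax

def make_optimizer(learning_rate_fn, weight_decay, decay_mask_fn, grad_accum_steps):
    tx = optax.adamw(
        learning_rate=learning_rate_fn,   # schedule indexed by optimizer-update count
        b1=0.9,
        b2=0.98,
        weight_decay=weight_decay,
        mask=decay_mask_fn,
    )
    if grad_accum_steps > 1:
        # accumulate and apply adamw only every `grad_accum_steps` micro-batches
        tx = optax.MultiSteps(tx, every_k_schedule=grad_accum_steps)
    return tx

Because MultiSteps exposes the usual init/update interface, passing it to TrainState.create as tx (as the diff does) requires no other changes.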
@@ -609,7 +612,7 @@ def main():
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return new_state, metrics
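Because state.step in the TrainState advances once per micro-batch while the schedule inside adamw only advances when MultiSteps actually applies an update, the logged learning rate is read at state.step // grad_accum_steps. A sketch with a warmup-plus-linear-decay schedule of the kind the script builds; all concrete values here are hypothetical:

import optax

total_train_steps = 10_000      # hypothetical
warmup_steps = 1_000            # hypothetical
peak_lr = 3e-4                  # hypothetical

linear_decay_lr_schedule_fn = optax.join_schedules(
    schedules=[
        optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps),
        optax.linear_schedule(init_value=peak_lr, end_value=0.0,
                              transition_steps=total_train_steps - warmup_steps),
    ],
    boundaries=[warmup_steps],
)

grad_accum_steps = 4            # hypothetical
micro_steps_seen = 8_000        # what state.step would hold after 8k micro-batches
lr_for_logging = linear_decay_lr_schedule_fn(micro_steps_seen // grad_accum_steps)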
@@ -636,7 +639,7 @@ def main():
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {num_epochs}")
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed and grad_accum) = {train_batch_size}")
logger.info(f" Total optimization steps = {total_train_steps}")
if not training_args.skip_memory_metrics:
@@ -654,10 +657,12 @@ def main():
rng, input_rng = jax.random.split(rng)
# Generate an epoch by shuffling sampling indices from the train dataset
- train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size // grad_accum_steps, shuffle=True)
steps_per_epoch = len(train_dataset) // train_batch_size
# train
- for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+ steps_trained_progress_bar = tqdm(range(steps_per_epoch), desc="Training...", position=1,
+                                   leave=False, initial=(resume_step // grad_accum_steps))
+ for step in range(steps_per_epoch * grad_accum_steps):
cur_step = epoch * (len(train_dataset) // train_batch_size) + step
# skip to the step from which we are resuming
if cur_step < resume_step:
@@ -666,8 +671,10 @@ def main():
batch = next(train_loader)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
+ if step % grad_accum_steps == 0:
+     steps_trained_progress_bar.update(1)
- if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+ if cur_step % (training_args.logging_steps * grad_accum_steps)== 0 and cur_step > 0:
# Save metrics
train_metric = unreplicate(train_metric)
train_time += time.time() - train_start
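Putting the loop changes from the last two hunks together: the data loader yields micro-batches of train_batch_size // grad_accum_steps examples, the inner loop runs grad_accum_steps times as many iterations per epoch, and the tqdm bar advances only once per optimizer step. A condensed sketch that reuses the script's names and drops the resume/skip and logging branches:

micro_batch_size = train_batch_size // grad_accum_steps
train_loader = data_loader(input_rng, train_dataset, micro_batch_size, shuffle=True)

steps_per_epoch = len(train_dataset) // train_batch_size        # optimizer updates per epoch
micro_steps_per_epoch = steps_per_epoch * grad_accum_steps       # forward/backward passes per epoch

for step in range(micro_steps_per_epoch):
    batch = next(train_loader)
    state, train_metric = p_train_step(state, batch)             # MultiSteps decides when a real update happens
    if step % grad_accum_steps == 0:
        steps_trained_progress_bar.update(1)                     # one tick per optimizer step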
@@ -684,7 +691,7 @@ def main():
train_metrics = []
- if cur_step % training_args.eval_steps == 0 and cur_step > 0:
+ if cur_step % (training_args.eval_steps * grad_accum_steps) == 0 and cur_step > 0:
# ======================== Evaluating ==============================
eval_metrics = []
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
@@ -717,11 +724,15 @@ def main():
_metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
wandb.log({"eval_step":cur_step, **_metrics})
- if cur_step % training_args.save_steps == 0 and cur_step > 0:
+ if cur_step % (training_args.save_steps * grad_accum_steps) == 0 and cur_step > 0:
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
save_checkpoint(model, training_args.output_dir, state, push_to_hub=training_args.push_to_hub)
- rotate_checkpoints(training_args.outpu_dir, training_args.save_total_limit)
+ if training_args.save_total_limit is not None:
+     rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
+ # save model after training is over
+ save_checkpoint(model, training_args.output_dir, state, with_opt=False, push_to_hub=training_args.push_to_hub)
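Since cur_step counts micro-batches, all three cadence checks are scaled by grad_accum_steps so that logging, evaluation and checkpointing still happen once per N optimizer updates, and, per the commit message, a final weights-only checkpoint (with_opt=False, so no opt_state.msgpack) is written once training finishes. Sketch of the save path, reusing the names above:

save_every = training_args.save_steps * grad_accum_steps   # logging_steps and eval_steps are scaled the same way

if cur_step % save_every == 0 and cur_step > 0 and jax.process_index() == 0:
    save_checkpoint(model, training_args.output_dir, state, push_to_hub=training_args.push_to_hub)
    if training_args.save_total_limit is not None:
        rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)

# after the epoch loop: save weights only, skipping the (large) optimizer state
save_checkpoint(model, training_args.output_dir, state, with_opt=False, push_to_hub=training_args.push_to_hub)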

View File

@@ -12,7 +12,7 @@
--preprocessing_num_workers="8" \
--learning_rate="3e-4" \
--adafactor \
--warmup_steps="200" \
--warmup_steps="50" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--weight_decay="0.01" \
@@ -25,6 +25,8 @@
--dtype="bfloat16" \
--skip_memory_metrics="True" \
--save_steps="200" \
- #--resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ckpt_201 \
+ --save_total_limit 2 \
+ --gradient_accumulation_steps 1 \
+ # --resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ckpt_201 \
# --max_train_samples="10000" \
# --max_eval_samples="1000"

View File

@@ -7,10 +7,10 @@
--text_column_name="func_code_string" \
--do_train --do_eval \
--block_size="1024" \
--per_device_train_batch_size="2" \
--per_device_eval_batch_size="2" \
--per_device_train_batch_size="3" \
--per_device_eval_batch_size="3" \
--preprocessing_num_workers="8" \
--learning_rate="3e-4" \
--learning_rate="1e-4" \
--warmup_steps="1000" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
@@ -19,8 +19,13 @@
--num_train_epochs="1" \
--push_to_hub="False" \
--dtype="bfloat16" \
--adafactor="False" \
--skip_memory_metrics="False"
# --max_train_samples="10000" \
# --max_eval_samples="1000" \
--resume_from_checkpoint="False"
--adafactor \
--skip_memory_metrics="False" \
--gradient_accumulation_steps 1 \
--report_to="none" \
--run_name="test_13b" \
--max_train_samples="10000" \
--max_eval_samples="1000" \
--save_total_limit 1 \
# --resume_from_checkpoint="None" \