checkpoint functions added to script (#27)

* saves the optimizer state together with the model (the resulting checkpoint layout is sketched below)
* enables resuming training from a saved checkpoint
* removes old checkpoints so that at most `save_total_limit` are kept
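For reference, a checkpoint directory written by `save_checkpoint` (in the diff below) looks roughly like this; the directory name, `output_dir`, and step number `200` are illustrative, and note that the directory name uses `state.step - 1` while `training_state.json` stores `state.step`:

    output_dir/ckpt-200/
        config.json           # written by model.save_pretrained
        flax_model.msgpack    # model weights
        opt_state.msgpack     # serialized optimizer state (only when with_opt=True)
        training_state.json   # {"step": 201}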

Co-authored-by: arampacha <aruthart@gmail.com>
Arun Raja committed 2021-07-08 19:34:16 +08:00 (via GitHub)
parent 9c4c40756f
commit 2736967001
4 changed files with 75 additions and 18 deletions


@@ -29,6 +29,8 @@ import time
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Callable, Optional
+import json
+import shutil
 import datasets
 from datasets import Dataset, load_dataset
@@ -43,6 +45,7 @@ from flax import jax_utils, traverse_util
 from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+from flax.serialization import to_bytes, from_bytes
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -236,6 +239,50 @@ def create_learning_rate_fn(
 def mb_item(x):
     return x.item() if hasattr(x, "item") else x
 
+
+# checkpoint functions
+def save_checkpoint(model, save_dir, state, with_opt: bool = True, push_to_hub: bool = False):
+    state = jax_utils.unreplicate(state)
+    print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
+    save_dir = f"{save_dir}/ckpt-{mb_item(state.step)-1}"
+    model.save_pretrained(
+        save_dir,
+        params=state.params,
+        push_to_hub=push_to_hub,
+        commit_message=f"Saving weights and logs at step {mb_item(state.step)}",
+    )
+    if with_opt:
+        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
+            f.write(to_bytes(state.opt_state))
+        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
+            json.dump({"step": state.step.item()}, f)
+    print("checkpoint saved")
+
+
+def restore_checkpoint(save_dir, state):
+    print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
+    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
+        params = from_bytes(state.params, f.read())
+    with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
+        opt_state = from_bytes(state.opt_state, f.read())
+    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
+        training_state = json.load(f)
+    step = training_state["step"]
+    print("checkpoint restored")
+    return state.replace(step=step, params=params, opt_state=opt_state), step
+
+
+def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
+    """Removes older checkpoints so that at most `save_total_limit` checkpoints are kept"""
+    # TODO: what to remove is decided by step number only; we may want to improve that
+    ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt*")]
+    # sort checkpoints by step
+    ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))
+    ckpts_to_delete = ckpts_sorted[:-save_total_limit]
+    for ckpt in ckpts_to_delete:
+        logger.info(f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})")
+        shutil.rmtree(ckpt)
+
+
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
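A minimal round trip with `to_bytes`/`from_bytes`, the serialization pair used by `save_checkpoint` and `restore_checkpoint` above. This is a sketch, not part of the commit; the `params` pytree is a made-up stand-in for something like `state.opt_state`:

    import jax.numpy as jnp
    from flax.serialization import to_bytes, from_bytes

    params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
    raw = to_bytes(params)              # msgpack bytes, as written to opt_state.msgpack
    restored = from_bytes(params, raw)  # the target pytree supplies the tree structure

`from_bytes` needs a template with the same structure as what was saved, which is why `restore_checkpoint` deserializes into the freshly created `state.params` and `state.opt_state`.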
@@ -464,7 +511,7 @@ def main():
     # enable wandb tracking
     has_wandb = find_spec("wandb") is not None
-    if jax.process_index() == 0 and has_wandb and "wandb" in training_args.report_to:
+    if jax.process_index() == 0 and has_wandb and ("wandb" in training_args.report_to):
         try:
             import wandb
             wandb.init(
@@ -478,6 +525,7 @@
         except ImportError as e:
             print(e)
             has_wandb = False
+
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
@@ -530,9 +578,14 @@
         weight_decay=training_args.weight_decay,
         mask=decay_mask_fn,
     )
 
     # Setup train state
     state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
+    if training_args.resume_from_checkpoint:
+        state, resume_step = restore_checkpoint(training_args.resume_from_checkpoint, state)
+    else:
+        resume_step = 0
 
     def loss_fn(logits, labels):
         shift_logits = logits[..., :-1, :]
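What resuming produces, as a sketch (the path is illustrative): `restore_checkpoint` replaces `params`, `opt_state`, and `step` on the train state and also returns the step, which the training loop below uses to skip already-completed steps:

    state, resume_step = restore_checkpoint("output/ckpt-200", state)
    assert state.step == resume_step  # both come from training_state.json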
@@ -605,19 +658,22 @@ def main():
         steps_per_epoch = len(train_dataset) // train_batch_size
         # train
         for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+            cur_step = epoch * (len(train_dataset) // train_batch_size) + step
+            # skip to the step from which we are resuming
+            if cur_step < resume_step:
+                continue
+
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
             train_metrics.append(train_metric)
-            cur_step = epoch * (len(train_dataset) // train_batch_size) + step
 
             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                 # Save metrics
                 train_metric = unreplicate(train_metric)
                 train_time += time.time() - train_start
                 if has_tensorboard and jax.process_index() == 0:
                     write_train_metric(summary_writer, train_metrics, train_time, cur_step)
-                if has_wandb and jax.process_index() == 0:
+                if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
                     # TODO: add accumulation of metrics
                     _metrics = {k if k == "learning_rate" else f"train_{k}": mb_item(v.mean()) for k, v in train_metric.items()}
                     wandb.log({"training_step": cur_step, **_metrics}, commit=True)
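One caveat, visible in the loop above: the resume skip `continue`s before `batch = next(train_loader)`, so skipped steps do not consume batches and a resumed epoch restarts from the top of that epoch's data stream. A possible fast-forward (a sketch, not part of this commit):

    # discard the batches that were already consumed before the checkpoint
    for _ in range(resume_step % steps_per_epoch):
        next(train_loader)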
@@ -657,20 +713,17 @@ def main():
                 if has_tensorboard and jax.process_index() == 0:
                     # cur_step = epoch * (len(train_dataset) // train_batch_size)
                     write_eval_metric(summary_writer, eval_metrics, cur_step)
-                if has_wandb and jax.process_index() == 0:
+                if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
                     _metrics = {f"eval_{k}": mb_item(v) for k, v in eval_metrics.items()}
                     wandb.log({"eval_step": cur_step, **_metrics})
 
             if cur_step % training_args.save_steps == 0 and cur_step > 0:
                 # save a checkpoint every save_steps steps and push it to the hub
                 if jax.process_index() == 0:
-                    params = jax.device_get(unreplicate(state.params))
-                    model.save_pretrained(
-                        training_args.output_dir,
-                        params=params,
-                        push_to_hub=training_args.push_to_hub,
-                        commit_message=f"Saving weights and logs of step {cur_step}",
-                    )
+                    save_checkpoint(model, training_args.output_dir, state, push_to_hub=training_args.push_to_hub)
+                    rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
 
 if __name__ == "__main__":

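A worked example of the ordering used by `rotate_checkpoints` (directory names are illustrative): the numeric sort key matters because a plain string sort would misorder step numbers of different widths:

    ckpts = ["ckpt-900", "ckpt-1000", "ckpt-1100"]
    print(sorted(ckpts, key=lambda x: int(x.split("-")[-1]))[:-2])  # ['ckpt-900'], the oldest
    print(sorted(ckpts)[:-2])  # ['ckpt-1000'] -- a string sort would delete the wrong one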

@@ -19,10 +19,12 @@
     --overwrite_output_dir \
     --num_train_epochs="1" \
     --logging_steps="100" \
-    --eval_steps="100" \
+    --eval_steps="200" \
     --push_to_hub="False" \
     --report_to="all" \
     --dtype="bfloat16" \
-    --skip_memory_metrics="False" \
+    --skip_memory_metrics="True" \
+    --save_steps="200" \
+    #--resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ckpt_201 \
     # --max_train_samples="10000" \
     # --max_eval_samples="1000"
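To actually resume, uncomment that line and point it at a checkpoint directory written by `save_checkpoint`; the path here is illustrative:

    --resume_from_checkpoint="$HOME/gpt-neo-125M-code-clippy/ckpt-200" \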


@@ -22,4 +22,5 @@
     --adafactor="False" \
-    --skip_memory_metrics="False"
+    --skip_memory_metrics="False" \
+    --resume_from_checkpoint="False"
     # --max_train_samples="10000" \
     # --max_eval_samples="1000"


@@ -18,4 +18,5 @@
     --weight_decay="0.01" \
     --overwrite_output_dir \
     --num_train_epochs="1" \
-    --push_to_hub="False"
+    --push_to_hub="False" \
+    --resume_from_checkpoint="False"