adding gradnorm clipping to existing optimizer

arampacha 2021-07-14 15:30:03 +00:00
parent a6b30210b5
commit 424856930c
2 changed files with 22 additions and 2 deletions


@@ -28,7 +28,7 @@
--save_total_limit 2 \
--gradient_accumulation_steps 4 \
--report_to="wandb" \
--run_name="gpt-code-clippy-125m-3e-4-256" \
--run_name="resume-test" \
--max_eval_samples 2000 \
--save_optimizer true \
--resume_from_checkpoint $HOME/gpt-code-clippy-125M-256 \


@@ -37,6 +37,7 @@ from flax import training
import numpy as np
import datasets
from datasets import Dataset, load_dataset
from optax._src.wrappers import MultiSteps
from tqdm import tqdm
import jax
@@ -60,6 +61,7 @@ from transformers import (
    TrainingArguments,
    is_tensorboard_available,
)
from transformers import training_args
from transformers.testing_utils import CaptureLogger
from importlib.util import find_spec
@@ -376,6 +378,19 @@ def _state_update(state):
    ms_state_dict["inner_opt_state"] = new_inner_opt_state
    return state.replace(opt_state=optax.MultiStepsState(**ms_state_dict))


def add_clipping(state, k_steps):
    new_inner_opt = optax.chain(
        optax.clip_by_global_norm(1.),
        state.tx.inner_opt
    )
    new_opt = optax.MultiSteps(new_inner_opt, k_steps)
    new_inner_opt_state = new_inner_opt.init(state.params)
    new_inner_opt_state = [new_inner_opt_state[0], state.opt_state.inner_opt_state]
    ms_state_dict = {k: getattr(state.opt_state, k) for k in state.opt_state._fields}
    ms_state_dict["inner_opt_state"] = new_inner_opt_state
    return state.replace(tx=new_opt, opt_state=optax.MultiStepsState(**ms_state_dict))


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
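Note on the state surgery above: add_clipping rebuilds the MultiSteps wrapper so that a clip_by_global_norm transform runs ahead of the restored inner optimizer, while the checkpointed inner optimizer state is carried over through the existing MultiStepsState fields. Below is a minimal, self-contained sketch of the same pattern, not the script's code; the parameter pytree, learning rate, and accumulation step count are illustrative assumptions.

# Minimal sketch (assumed names and values): combine a freshly initialized
# clipping slot with a previously restored inner optimizer state, mirroring
# the state surgery done by add_clipping above.
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((2, 2))}                          # toy parameter pytree

old_inner_opt = optax.adamw(learning_rate=3e-4)            # optimizer the checkpoint was saved with
restored_inner_state = old_inner_opt.init(params)          # stand-in for the restored inner opt state

# the chained optimizer's state is a tuple: (clip state, inner optimizer state)
new_inner_opt = optax.chain(optax.clip_by_global_norm(1.0), old_inner_opt)
fresh_chain_state = new_inner_opt.init(params)

# keep the newly initialized clipping slot, reuse the restored AdamW moments
patched_inner_state = (fresh_chain_state[0], restored_inner_state)

new_opt = optax.MultiSteps(new_inner_opt, every_k_schedule=4)   # 4 = assumed accumulation steps
ms_state = new_opt.init(params)                            # MultiStepsState (a NamedTuple)
ms_state = ms_state._replace(inner_opt_state=patched_inner_state)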
@@ -668,6 +683,11 @@ def main():
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )
    if training_args.resume_from_checkpoint is None:
        optimizer = optax.chain(
            optax.clip_by_global_norm(1.),
            optimizer
        )
    if training_args.gradient_accumulation_steps > 1:
        optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
    grad_accum_steps = training_args.gradient_accumulation_steps
@@ -679,7 +699,7 @@
        state = restore_checkpoint(training_args.resume_from_checkpoint, state)
        state = _state_update(state)
        state = add_clipping(state, training_args.gradient_accumulation_steps)
        resume_step = mb_item(state.step)
    else:
        resume_step = 0
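For the fresh-run path (no checkpoint to resume from), the change is simpler: the existing optimizer is chained after clip_by_global_norm before being wrapped in MultiSteps, so the inner optimizer only ever sees clipped gradients. A minimal sketch of that pattern follows, with illustrative shapes, learning rate, and step count rather than the script's values.

# Minimal sketch (illustrative values, not the script's code): clip by global
# norm before the optimizer, then wrap the chain for gradient accumulation.
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3, 3)), "b": jnp.zeros((3,))}     # toy parameters

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),                         # rescale grads so their global norm is at most 1.0
    optax.adamw(learning_rate=3e-4),                        # existing inner optimizer
)
optimizer = optax.MultiSteps(optimizer, every_k_schedule=4) # assumed accumulation steps

opt_state = optimizer.init(params)
grads = {"w": jnp.full((3, 3), 10.0), "b": jnp.full((3,), 10.0)}  # deliberately large gradients
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

With this wrapping, MultiSteps accumulates raw micro-batch gradients and, under its default use_grad_mean=True, hands the averaged gradient to the inner chain, so the clipping and AdamW update run once per full accumulation window.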