reinstantiate opt_state recursively when there are nested lists (#47)

arampacha 2021-07-15 03:09:40 +03:00 committed by GitHub
parent 6d4c9bfa23
commit 30f2ee4127
3 changed files with 18 additions and 11 deletions


@@ -285,8 +285,11 @@ def fake_update(state):
 def reinstantiate_states(opt_state):
     new_state = []
     for state in opt_state:
-        cls = getattr(optax, type(state).__name__)
-        new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
+        if isinstance(state, list):
+            new_state.append(reinstantiate_states(state))
+        else:
+            cls = getattr(optax, type(state).__name__)
+            new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
     return new_state
 
 def restore_model_checkpoint(save_dir, state):
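Note on the fix: reinstantiate_states rebuilds each restored state by looking up its canonical class on the optax module and reconstructing the namedtuple from its fields. Before this commit it assumed opt_state was a flat sequence, but chained or wrapped optimizers produce nested sequences of states, which a checkpoint round-trip can hand back as plain Python lists; the function now recurses into those. A minimal runnable sketch of the mechanic, using toy states rather than the training script's real opt_state:

import optax
import jax.numpy as jnp

def reinstantiate_states(opt_state):
    new_state = []
    for state in opt_state:
        if isinstance(state, list):
            # nested optimizer state: rebuild it recursively
            new_state.append(reinstantiate_states(state))
        else:
            # look up the canonical optax state class by name and
            # reconstruct the namedtuple from its fields
            cls = getattr(optax, type(state).__name__)
            new_state.append(cls(**{k: getattr(state, k) for k in state._fields}))
    return new_state

adam = optax.ScaleByAdamState(
    count=jnp.zeros([], jnp.int32),
    mu={"w": jnp.zeros(3)},
    nu={"w": jnp.zeros(3)},
)
rebuilt = reinstantiate_states([adam, [optax.EmptyState()]])
assert isinstance(rebuilt[0], optax.ScaleByAdamState)
assert isinstance(rebuilt[1], list)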
@@ -622,7 +625,7 @@ def main():
         mask=decay_mask_fn,
     )
     optimizer = optax.chain(
-        optax.clip_grad_by_global_norm(1.),
+        optax.clip_by_global_norm(1.),
         optimizer
     )
     if training_args.gradient_accumulation_steps > 1:
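For reference, optax exports clip_by_global_norm; there is no clip_grad_by_global_norm, so the old name would fail with an AttributeError when the optimizer is built. A standalone sketch of the corrected chain, with adamw as an assumed stand-in for the optimizer the script constructs:

import optax

# clip the global gradient norm to 1.0 before the inner update is applied
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=2e-5),  # assumed stand-in for the script's optimizer
)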
@@ -699,7 +702,8 @@ def main():
     train_metrics = []
     resume_epoch = resume_step // (steps_per_epoch * grad_accum_steps)
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... ({resume_epoch+1}/{num_epochs})", position=0)
-    logger.info(f"Skipping to epoch {resume_epoch} step {resume_step // grad_accum_steps}")
+    if resume_step != 0:
+        logger.info(f"Skipping to epoch {resume_epoch} step {resume_step // grad_accum_steps}")
     for epoch in epochs:
         # ======================== Training ================================
         if epoch < resume_epoch:
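The added guard keeps fresh runs (resume_step == 0) from logging a spurious skip message. As a worked example of the bookkeeping, with all concrete numbers assumed for illustration: resume_step appears to count micro-batches, so dividing by grad_accum_steps gives optimizer steps, and dividing by steps_per_epoch * grad_accum_steps gives whole epochs.

grad_accum_steps = 8    # matches --gradient_accumulation_steps 8 below
steps_per_epoch = 320   # assumed: optimizer steps per epoch
resume_step = 5120      # assumed: micro-batch counter restored from the checkpoint

resume_epoch = resume_step // (steps_per_epoch * grad_accum_steps)  # 5120 // 2560 = 2
optimizer_step = resume_step // grad_accum_steps                    # 5120 // 8 = 640
print(f"Skipping to epoch {resume_epoch} step {optimizer_step}")

The 640 here is picked only so the numbers line up with the ckpt-640 path in the launch script below; the real values depend on the dataset and batch size.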


@@ -290,8 +290,11 @@ def fake_update(state):
 def reinstantiate_states(opt_state):
     new_state = []
     for state in opt_state:
-        cls = getattr(optax, type(state).__name__)
-        new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
+        if isinstance(state, list):
+            new_state.append(reinstantiate_states(state))
+        else:
+            cls = getattr(optax, type(state).__name__)
+            new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
     return new_state
 
 def restore_model_checkpoint(save_dir, state):
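The same recursive fix is applied to the second training script changed by this commit. As for where the nesting comes from, a sketch under assumed setup (not necessarily the script's exact optimizer): optax.chain keeps one state per wrapped transform, and aliases like adamw are themselves chains, so opt_state is a nested sequence that can deserialize back as nested plain lists.

import optax
import jax.numpy as jnp

params = {"w": jnp.zeros(3)}
tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(2e-5))
opt_state = tx.init(params)

print(type(opt_state))     # one state entry per chained transform
print(type(opt_state[1]))  # adamw is itself a chain, so its state nests again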


@@ -1,6 +1,6 @@
 #! /bin/bash
 ./run_clm_flax.py \
-    --output_dir $HOME/tmp/gpt-neo-125M-test-2 \
+    --output_dir $HOME/tmp/gpt-neo-125M-test-3 \
     --model_name_or_path="EleutherAI/gpt-neo-125M" \
     --dataset_name="wikitext" \
     --dataset_config_name="wikitext-2-raw-v1" \
@@ -11,7 +11,6 @@
     --per_device_eval_batch_size="16" \
     --preprocessing_num_workers="8" \
     --learning_rate="2e-5" \
-    --adafactor \
     --warmup_steps="100" \
     --adam_beta1="0.9" \
     --adam_beta2="0.98" \
@@ -25,11 +24,12 @@
     --run_name="test-non-streaming" \
     --dtype="bfloat16" \
     --skip_memory_metrics="False" \
-    --save_steps="200" \
-    --save_strategy epoch \
+    --save_steps="20" \
+    --save_strategy steps \
     --save_total_limit 2 \
     --gradient_accumulation_steps 8 \
     --save_optimizer true \
-    --resume_from_checkpoint $HOME/tmp/gpt-neo-125M-test-2/ckpt-2591 \
+    --resume_from_checkpoint $HOME/tmp/gpt-neo-125M-test-3/ckpt-640 \
+    # --adafactor \
     # --max_train_samples="10000" \
     # --max_eval_samples="1000"
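The launch-script changes exercise the new resume path: checkpoints are now written every 20 optimizer steps instead of once per epoch, and the run resumes from one of them. A small sketch of the implied checkpoint rotation, with the ckpt-<step> naming taken from the paths above and the step counts assumed:

save_steps = 20                         # from --save_steps="20"
save_total_limit = 2                    # from --save_total_limit 2
completed_steps = 640                   # assumed current optimizer step

checkpoints = [f"ckpt-{s}" for s in range(save_steps, completed_steps + 1, save_steps)]
kept = checkpoints[-save_total_limit:]  # older checkpoints get pruned
print(kept)                             # ['ckpt-620', 'ckpt-640']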