Mirror of https://github.com/CodedotAl/gpt-code-clippy.git
reinstantiate opt_state recursively when there are nested lists (#47)
parent 6d4c9bfa23
commit 30f2ee4127
@@ -285,8 +285,11 @@ def fake_update(state):
 def reinstantiate_states(opt_state):
     new_state = []
     for state in opt_state:
-        cls = getattr(optax, type(state).__name__)
-        new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
+        if isinstance(state, list):
+            new_state.append(reinstantiate_states(state))
+        else:
+            cls = getattr(optax, type(state).__name__)
+            new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
     return new_state
 
 def restore_model_checkpoint(save_dir, state):
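For context, a minimal sketch of the nesting the recursion has to cope with; the toy params and the clip/adamw chain are illustrative assumptions, not the repository's actual setup:

import jax.numpy as jnp
import optax

params = {"w": jnp.ones((2, 2))}

# optax.chain() yields an opt_state whose elements are per-transform states,
# and composed transforms such as adamw nest their own sub-states inside it.
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(2e-5))
opt_state = optimizer.init(params)  # outer container holds nested state objects

# In a restored checkpoint such nested groups can show up as plain lists,
# which is what the new isinstance(state, list) branch recurses into before
# rebuilding each NamedTuple from its optax class:
# opt_state = reinstantiate_states(restored_opt_state)  # hypothetical usage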
@@ -622,7 +625,7 @@ def main():
         mask=decay_mask_fn,
     )
     optimizer = optax.chain(
-        optax.clip_grad_by_global_norm(1.),
+        optax.clip_by_global_norm(1.),
         optimizer
     )
     if training_args.gradient_accumulation_steps > 1:
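The one-line fix above matters because optax exports clip_by_global_norm but, as far as I know, has never exported clip_grad_by_global_norm, so the old name would fail at runtime. A small sketch of the corrected chain, with adamw standing in for the optimizer configured with mask=decay_mask_fn above:

import optax

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),   # clip gradients to a global norm of 1.0
    optax.adamw(learning_rate=2e-5),  # stand-in for the optimizer built earlier
)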
@@ -699,7 +702,8 @@ def main():
     train_metrics = []
     resume_epoch = resume_step // (steps_per_epoch * grad_accum_steps)
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... ({resume_epoch+1}/{num_epochs})", position=0)
-    logger.info(f"Skipping to epoch {resume_epoch} step {resume_step // grad_accum_steps}")
+    if resume_step != 0:
+        logger.info(f"Skipping to epoch {resume_epoch} step {resume_step // grad_accum_steps}")
     for epoch in epochs:
         # ======================== Training ================================
         if epoch < resume_epoch:
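The added guard keeps the "Skipping to epoch" message from being logged on a fresh run where resume_step is 0. To make the resume arithmetic concrete, here is a worked example with made-up numbers (none of these values come from the repository):

steps_per_epoch = 320   # hypothetical steps per epoch
grad_accum_steps = 8    # hypothetical gradient-accumulation factor
resume_step = 5120      # hypothetical step counter restored from the checkpoint

resume_epoch = resume_step // (steps_per_epoch * grad_accum_steps)
print(resume_epoch)                     # 2   -> the first two epochs are skipped
print(resume_step // grad_accum_steps)  # 640 -> the step value that gets logged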
The same change is applied to the same function in a second script, hence the near-identical hunk at a different line range:

@@ -290,8 +290,11 @@ def fake_update(state):
 def reinstantiate_states(opt_state):
     new_state = []
     for state in opt_state:
-        cls = getattr(optax, type(state).__name__)
-        new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
+        if isinstance(state, list):
+            new_state.append(reinstantiate_states(state))
+        else:
+            cls = getattr(optax, type(state).__name__)
+            new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
     return new_state
 
 def restore_model_checkpoint(save_dir, state):
The remaining hunks update the shell script that launches run_clm_flax.py: a new output directory, the --adafactor flag dropped, step-based checkpointing instead of per-epoch saves, and a new checkpoint to resume from.

@@ -1,6 +1,6 @@
 #! /bin/bash
 ./run_clm_flax.py \
-    --output_dir $HOME/tmp/gpt-neo-125M-test-2 \
+    --output_dir $HOME/tmp/gpt-neo-125M-test-3 \
     --model_name_or_path="EleutherAI/gpt-neo-125M" \
     --dataset_name="wikitext" \
     --dataset_config_name="wikitext-2-raw-v1" \
@@ -11,7 +11,6 @@
     --per_device_eval_batch_size="16" \
     --preprocessing_num_workers="8" \
     --learning_rate="2e-5" \
-    --adafactor \
     --warmup_steps="100" \
     --adam_beta1="0.9" \
     --adam_beta2="0.98" \
@@ -25,11 +24,12 @@
     --run_name="test-non-streaming" \
     --dtype="bfloat16" \
     --skip_memory_metrics="False" \
-    --save_steps="200" \
-    --save_strategy epoch \
+    --save_steps="20" \
+    --save_strategy steps \
     --save_total_limit 2 \
     --gradient_accumulation_steps 8 \
     --save_optimizer true \
-    --resume_from_checkpoint $HOME/tmp/gpt-neo-125M-test-2/ckpt-2591 \
+    --resume_from_checkpoint $HOME/tmp/gpt-neo-125M-test-3/ckpt-640 \
     # --adafactor \
     # --max_train_samples="10000" \
     # --max_eval_samples="1000"