diff --git a/run_clm_flax.py b/run_clm_flax.py
index 64a044d..5402288 100755
--- a/run_clm_flax.py
+++ b/run_clm_flax.py
@@ -285,8 +285,11 @@ def fake_update(state):
 def reinstantiate_states(opt_state):
     new_state = []
     for state in opt_state:
-        cls = getattr(optax, type(state).__name__)
-        new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
+        if isinstance(state, list):
+            new_state.append(reinstantiate_states(state))
+        else:
+            cls = getattr(optax, type(state).__name__)
+            new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
     return new_state
 
 def restore_model_checkpoint(save_dir, state):
@@ -622,7 +625,7 @@ def main():
         mask=decay_mask_fn,
     )
     optimizer = optax.chain(
-        optax.clip_grad_by_global_norm(1.),
+        optax.clip_by_global_norm(1.),
         optimizer
     )
     if training_args.gradient_accumulation_steps > 1:
@@ -699,7 +702,8 @@ def main():
     train_metrics = []
     resume_epoch = resume_step // (steps_per_epoch * grad_accum_steps)
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... ({resume_epoch+1}/{num_epochs})", position=0)
-    logger.info(f"Skipping to epoch {resume_epoch} step {resume_step // grad_accum_steps}")
+    if resume_step != 0:
+        logger.info(f"Skipping to epoch {resume_epoch} step {resume_step // grad_accum_steps}")
     for epoch in epochs:
         # ======================== Training ================================
         if epoch < resume_epoch:
diff --git a/run_clm_streaming_flax_clean.py b/run_clm_streaming_flax_clean.py
index 35cc646..0ef3bb1 100755
--- a/run_clm_streaming_flax_clean.py
+++ b/run_clm_streaming_flax_clean.py
@@ -290,8 +290,11 @@ def fake_update(state):
 def reinstantiate_states(opt_state):
     new_state = []
     for state in opt_state:
-        cls = getattr(optax, type(state).__name__)
-        new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
+        if isinstance(state, list):
+            new_state.append(reinstantiate_states(state))
+        else:
+            cls = getattr(optax, type(state).__name__)
+            new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
     return new_state
 
 def restore_model_checkpoint(save_dir, state):
diff --git a/run_clm_wikitext.sh b/run_clm_wikitext.sh
index e73253d..7f8a11f 100644
--- a/run_clm_wikitext.sh
+++ b/run_clm_wikitext.sh
@@ -1,6 +1,6 @@
 #! /bin/bash
 ./run_clm_flax.py \
-    --output_dir $HOME/tmp/gpt-neo-125M-test-2 \
+    --output_dir $HOME/tmp/gpt-neo-125M-test-3 \
     --model_name_or_path="EleutherAI/gpt-neo-125M" \
     --dataset_name="wikitext" \
     --dataset_config_name="wikitext-2-raw-v1" \
@@ -11,7 +11,6 @@
     --per_device_eval_batch_size="16" \
     --preprocessing_num_workers="8" \
     --learning_rate="2e-5" \
-    --adafactor \
     --warmup_steps="100" \
     --adam_beta1="0.9" \
     --adam_beta2="0.98" \
@@ -25,11 +24,12 @@
     --run_name="test-non-streaming" \
     --dtype="bfloat16" \
     --skip_memory_metrics="False" \
-    --save_steps="200" \
-    --save_strategy epoch \
+    --save_steps="20" \
+    --save_strategy steps \
     --save_total_limit 2 \
     --gradient_accumulation_steps 8 \
     --save_optimizer true \
-    --resume_from_checkpoint $HOME/tmp/gpt-neo-125M-test-2/ckpt-2591 \
+    --resume_from_checkpoint $HOME/tmp/gpt-neo-125M-test-3/ckpt-640 \
+    # --adafactor \
    # --max_train_samples="10000" \
    # --max_eval_samples="1000"
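
For reference, a minimal sketch (not part of the patch; hyperparameters are illustrative) of the corrected gradient-clipping chain. optax exposes clip_by_global_norm, not clip_grad_by_global_norm, which is why the transform name is changed in main():

    import optax

    # Illustrative values only; the script builds the inner AdamW optimizer
    # from its own training arguments (learning_rate, adam_beta1, adam_beta2).
    inner = optax.adamw(learning_rate=2e-5, b1=0.9, b2=0.98)
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),  # clip gradients before the AdamW update
        inner,
    )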