Set batch size (bs) to 1 and correct the save-checkpoint function call

This commit is contained in:
arunraja-hub 2021-07-12 17:28:31 +00:00
parent 41a7423ca4
commit 4aabbacbe9
2 changed files with 6 additions and 7 deletions

View File

@@ -732,14 +732,13 @@ def main():
if jax.process_index() == 0:
save_model_checkpoint(model, training_args.output_dir, state, with_opt=False,
push_to_hub=training_args.push_to_hub, repo_name_or_path=training_args.output_dir)
if model_args.save_optimizer:
# this saves full state including optimizer
save_checkpoint(training_args.output_dir, state, state.step, keep=training_args.save_total_limit, overwrite=False)
# this saves full state including optimizer
# save_checkpoint(training_args.output_dir, state, state.step, keep=training_args.save_total_limit, overwrite=False)
if training_args.save_total_limit is not None:
rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
# save model after training is over
save_checkpoint(model, training_args.output_dir, state, with_opt=False, push_to_hub=training_args.push_to_hub)
save_model_checkpoint(model, training_args.output_dir, state, with_opt=False, push_to_hub=training_args.push_to_hub)

View File

@@ -6,9 +6,9 @@
--dataset_config_name="python" \
--text_column_name="func_code_string" \
--do_train --do_eval \
--block_size="1024" \
--per_device_train_batch_size="3" \
--per_device_eval_batch_size="3" \
--block_size="2048" \
--per_device_train_batch_size="1" \
--per_device_eval_batch_size="1" \
--preprocessing_num_workers="8" \
--learning_rate="1e-4" \
--warmup_steps="1000" \