average loss over task tokens only (#57)

arampacha 2021-07-18 06:31:02 +03:00 committed by GitHub
parent 65a3b335d2
commit a6f6f5ca34
2 changed files with 5 additions and 4 deletions


@@ -1,6 +1,6 @@
 #! /bin/bash
 ./run_clm_apps.py \
-    --output_dir /home/shared/models/gpt-code-clippy-apps-4 \
+    --output_dir /home/shared/models/gpt-code-clippy-1.3B-apps \
     --model_name_or_path EleutherAI/gpt-neo-1.3B \
     --dataset_name ./apps.py \
     --dataset_config_name formatted \
@@ -16,8 +16,8 @@
     --weight_decay="0.1" \
     --overwrite_output_dir \
     --num_train_epochs="5" \
-    --logging_steps="20" \
-    --eval_steps="1000" \
+    --logging_steps="50" \
+    --eval_steps="2000" \
     --push_to_hub="False" \
     --report_to="wandb" \
     --dtype="bfloat16" \
@@ -28,6 +28,7 @@
     --gradient_accumulation_steps 1 \
     --adafactor true \
     --all_data true \
+    --seed 842 \
     # --resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ckpt_201 \
     # --max_train_samples="10000" \
     # --max_eval_samples="1000"


@@ -635,7 +635,7 @@ def main():
         shift_logits = logits[..., :-1, :]
         shift_labels = labels[..., 1:]
         loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1])) * labels_mask[..., 1:]
-        return loss.mean()
+        return (loss.sum(axis=-1) / labels_mask.sum(axis=-1)).mean()

     # Define gradient update step fn
     def train_step(state, batch):
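
For readers skimming the diff, here is a minimal, self-contained sketch (not the repository's code; the function name, shapes, and demo inputs are illustrative) of what the new return line computes: the per-token cross-entropy is zeroed outside the task span, summed per example, divided by that example's number of task tokens, and then averaged over the batch, rather than taking a flat mean over every position.

import jax
import jax.numpy as jnp
import optax

def masked_loss(logits, labels, labels_mask):
    # logits: (batch, seq, vocab); labels, labels_mask: (batch, seq)
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    shift_mask = labels_mask[..., 1:]
    # per-token cross-entropy, zeroed outside the task (answer) span
    per_token = optax.softmax_cross_entropy(
        shift_logits, jax.nn.one_hot(shift_labels, shift_logits.shape[-1])
    ) * shift_mask
    # average over each example's task tokens, then over the batch
    return (per_token.sum(axis=-1) / shift_mask.sum(axis=-1)).mean()

# toy batch: 2 sequences of 8 tokens, task tokens marked with 1 in the mask
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (2, 8, 16))
labels = jnp.zeros((2, 8), dtype=jnp.int32)
labels_mask = jnp.array([[0, 0, 0, 1, 1, 1, 1, 1],
                         [0, 0, 1, 1, 1, 1, 1, 0]], dtype=jnp.float32)
print(masked_loss(logits, labels, labels_mask))

One small difference: the committed line divides by labels_mask.sum(axis=-1) (the unshifted mask), while the sketch divides by the shifted mask; the two denominators can differ by at most one token per example.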