dl hotfix

This commit is contained in:
arampacha 2021-07-10 10:50:57 +00:00
parent 5df86d34b3
commit 9af5a7faaf
3 changed files with 6 additions and 7 deletions

View File

@ -27,9 +27,9 @@
--skip_memory_metrics="False" \
--save_steps="500" \
--save_total_limit 2 \
--gradient_accumulation_steps 4 \
--gradient_accumulation_steps 2 \
--report_to="wandb" \
--max_eval_samples="5000" \
--max_eval_samples="1000" \
# --resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ckpt_201 \
# --max_train_samples="10000" \

View File

@ -276,7 +276,7 @@ def make_batch(samples):
# l += len(next_sample["input_ids"])
# sample = {k:sample[k]+next_sample[k] for k in next_sample.keys()}
# self.rem = {k:v[-max_length:] for k,v in sample.items()}
# self.rem = {k:v[max_length:] for k,v in sample.items()}
# sample = {k:v[:max_length] for k,v in sample.items()}
# # regroup to shape [bs x seq_len]
# samples = {k:np.array([v[i*self.seq_len:(i+1)*self.seq_len] for i in range(self.bs)]) for k,v in sample.items()}
@ -323,7 +323,7 @@ class PrefetchDataloader(Process):
l += len(next_sample["input_ids"])
sample = {k:sample[k]+next_sample[k] for k in next_sample.keys()}
self.rem = {k:v[-max_length:] for k,v in sample.items()}
self.rem = {k:v[max_length:] for k,v in sample.items()}
sample = {k:v[:max_length] for k,v in sample.items()}
# regroup to shape [bs x seq_len]
samples = {k:np.array([v[i*self.seq_len:(i+1)*self.seq_len] for i in range(self.bs)]) for k,v in sample.items()}
@ -366,7 +366,6 @@ def write_train_metric(summary_writer, train_metrics, train_time, step):
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
def write_eval_metric(summary_writer, eval_metrics, step):
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)

View File

@ -276,7 +276,7 @@ def make_batch(samples):
# l += len(next_sample["input_ids"])
# sample = {k:sample[k]+next_sample[k] for k in next_sample.keys()}
# self.rem = {k:v[-max_length:] for k,v in sample.items()}
# self.rem = {k:v[max_length:] for k,v in sample.items()}
# sample = {k:v[:max_length] for k,v in sample.items()}
# # regroup to shape [bs x seq_len]
# samples = {k:np.array([v[i*self.seq_len:(i+1)*self.seq_len] for i in range(self.bs)]) for k,v in sample.items()}
@ -323,7 +323,7 @@ class PrefetchDataloader(Process):
l += len(next_sample["input_ids"])
sample = {k:sample[k]+next_sample[k] for k in next_sample.keys()}
self.rem = {k:v[-max_length:] for k,v in sample.items()}
self.rem = {k:v[max_length:] for k,v in sample.items()}
sample = {k:v[:max_length] for k,v in sample.items()}
# regroup to shape [bs x seq_len]
samples = {k:np.array([v[i*self.seq_len:(i+1)*self.seq_len] for i in range(self.bs)]) for k,v in sample.items()}