From 985da51fbc1b885c7a3153d2557283b353a56f05 Mon Sep 17 00:00:00 2001
From: Zach
Date: Fri, 7 Apr 2023 04:33:34 +0000
Subject: [PATCH] fix: concat

---
 inference.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/inference.py b/inference.py
index e8585d3e..3facd91c 100644
--- a/inference.py
+++ b/inference.py
@@ -11,6 +11,7 @@ import torch.distributed as dist
 from transformers.trainer_pt_utils import ShardSampler, distributed_concat, nested_numpify
 from transformers import DefaultDataCollator
 from torch.utils.data import DataLoader
+import numpy as np
 
 
 def calc_cross_entropy_no_reduction(lm_logits, labels):
@@ -99,16 +100,16 @@ def inference(config):
             sequence_lengths = torch.tensor(sequence_lengths)
             pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
 
-            train_outputs["embeddings"].extend(pooled_logits)
+            train_outputs["embeddings"].append(pooled_logits)
             train_outputs["index"].extend(batch["index"].to(model.device))
             torch.cuda.empty_cache()
 
 
     dist.barrier()
     gathered_train = nested_numpify(distributed_concat(train_outputs))
-
-    gathered_train["index"] = [t.item() for t in gathered_train["index"]]
-    gathered_train["loss"] = [t.item() for t in gathered_train["loss"]]
+    gathered_train["index"] = np.concatenate(gathered_train["index"])
+    gathered_train["loss"] = np.concatenate(gathered_train["loss"])
+    gathered_train["embeddings"] = np.concatenate(gathered_train["embeddings"])
 
     df_train = Dataset.from_dict(gathered_train)
     df_train = df_train.sort("index")
@@ -146,7 +147,7 @@ def inference(config):
             sequence_lengths = torch.tensor(sequence_lengths)
             pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
-            val_outputs["embeddings"].extend(pooled_logits)
+            val_outputs["embeddings"].append(pooled_logits)
             val_outputs["index"].extend(batch["index"].to(model.device))
             torch.cuda.empty_cache()
 
@@ -154,8 +155,9 @@ def inference(config):
     dist.barrier()
     gathered_val = nested_numpify(distributed_concat(val_outputs))
 
-    gathered_val["index"] = [t.item() for t in gathered_val["index"]]
-    gathered_val["loss"] = [t.item() for t in gathered_val["loss"]]
+    gathered_val["index"] = np.concatenate(gathered_val["index"])
+    gathered_val["loss"] = np.concatenate(gathered_val["loss"])
+    gathered_val["embeddings"] = np.concatenate(gathered_val["embeddings"])
 
     df_val = Dataset.from_dict(gathered_val)
     df_val = df_val.sort("index")
@@ -165,7 +167,8 @@ def inference(config):
     val_dataset = val_dataset.add_column("is_train", [False] * len(val_dataset))
 
     df = concatenate_datasets([train_dataset, val_dataset])
-    df.to_json("epoch_1_checkpoint.jsonl", lines=True, orient="records", num_proc=64)
+    if local_rank == 0:
+        df.to_json("epoch_1_checkpoint.jsonl", lines=True, orient="records", num_proc=64)
 
 
 def main():
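
Note on the pattern above (not part of the patch): extend() unpacked each
(batch_size, hidden_size) pooled_logits tensor into batch_size separate 1-D
vectors. transformers' distributed_concat recurses into lists and dicts,
all_gathers each tensor, and concatenates along dim 0; for a 1-D embedding
that splices hidden dimensions from different ranks together instead of
stacking examples. Appending the whole 2-D batch keeps dim 0 as the example
axis, and a final np.concatenate flattens the per-batch list. The old
per-element .item() also stops working once all_gather returns one value per
rank for each entry. Below is a minimal single-process sketch of the
append-then-concatenate pattern; the shapes, the fake batches, and names like
hidden_size are illustrative assumptions, not taken from inference.py:

    import numpy as np
    import torch

    hidden_size = 4
    outputs = {"embeddings": [], "loss": [], "index": []}

    for start in range(0, 8, 2):  # two examples per fake batch
        pooled = torch.randn(2, hidden_size)  # stands in for pooled_logits
        outputs["embeddings"].append(pooled)  # keep the (batch, hidden) shape
        outputs["loss"].append(torch.randn(2))
        outputs["index"].append(torch.arange(start, start + 2))

    # nested_numpify(distributed_concat(...)) leaves each rank with a list of
    # per-batch arrays; np.concatenate joins that list along axis 0.
    arrays = {k: [t.numpy() for t in v] for k, v in outputs.items()}
    embeddings = np.concatenate(arrays["embeddings"])  # (8, hidden_size)
    loss = np.concatenate(arrays["loss"])              # (8,)
    index = np.concatenate(arrays["index"])            # (8,)
    assert embeddings.shape == (8, hidden_size)

The `if local_rank == 0:` guard follows the same logic on the output side:
after distributed_concat every rank holds the full gathered result, so writing
from all ranks would race on epoch_1_checkpoint.jsonl; rank 0 alone emits it.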