fix: concat

This commit is contained in:
Zach 2023-04-07 04:33:34 +00:00
parent 1b14b1f723
commit 985da51fbc

View File

@ -11,6 +11,7 @@ import torch.distributed as dist
from transformers.trainer_pt_utils import ShardSampler, distributed_concat, nested_numpify
from transformers import DefaultDataCollator
from torch.utils.data import DataLoader
import numpy as np
def calc_cross_entropy_no_reduction(lm_logits, labels):
@ -99,16 +100,16 @@ def inference(config):
sequence_lengths = torch.tensor(sequence_lengths)
pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
train_outputs["embeddings"].extend(pooled_logits)
train_outputs["embeddings"].append(pooled_logits)
train_outputs["index"].extend(batch["index"].to(model.device))
torch.cuda.empty_cache()
dist.barrier()
gathered_train = nested_numpify(distributed_concat(train_outputs))
gathered_train["index"] = [t.item() for t in gathered_train["index"]]
gathered_train["loss"] = [t.item() for t in gathered_train["loss"]]
gathered_train["index"] = np.concatenate(gathered_train["index"])
gathered_train["loss"] = np.concatenate(gathered_train["loss"])
gathered_train["embeddings"] = np.concatenate(gathered_train["embeddings"])
df_train = Dataset.from_dict(gathered_train)
df_train = df_train.sort("index")
@ -146,7 +147,7 @@ def inference(config):
sequence_lengths = torch.tensor(sequence_lengths)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
val_outputs["embeddings"].extend(pooled_logits)
val_outputs["embeddings"].append(pooled_logits)
val_outputs["index"].extend(batch["index"].to(model.device))
torch.cuda.empty_cache()
@ -154,8 +155,9 @@ def inference(config):
dist.barrier()
gathered_val = nested_numpify(distributed_concat(val_outputs))
gathered_val["index"] = [t.item() for t in gathered_val["index"]]
gathered_val["loss"] = [t.item() for t in gathered_val["loss"]]
gathered_val["index"] = np.concatenate(gathered_val["index"])
gathered_val["loss"] = np.concatenate(gathered_val["loss"])
gathered_val["embeddings"] = np.concatenate(gathered_val["embeddings"])
df_val = Dataset.from_dict(gathered_val)
df_val = df_val.sort("index")
@ -165,7 +167,8 @@ def inference(config):
val_dataset = val_dataset.add_column("is_train", [False] * len(val_dataset))
df = concatenate_datasets([train_dataset, val_dataset])
df.to_json("epoch_1_checkpoint.jsonl", lines=True, orient="records", num_proc=64)
if local_rank == 0:
df.to_json("epoch_1_checkpoint.jsonl", lines=True, orient="records", num_proc=64)
def main():