Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2024-11-09 16:55:32 +03:00)
fix: concat
commit 985da51fbc (parent 1b14b1f723)
Changed file: inference.py (19 lines changed)
@@ -11,6 +11,7 @@ import torch.distributed as dist
 from transformers.trainer_pt_utils import ShardSampler, distributed_concat, nested_numpify
 from transformers import DefaultDataCollator
 from torch.utils.data import DataLoader
+import numpy as np


 def calc_cross_entropy_no_reduction(lm_logits, labels):
@@ -99,16 +100,16 @@ def inference(config):
             sequence_lengths = torch.tensor(sequence_lengths)
             pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]

-            train_outputs["embeddings"].extend(pooled_logits)
+            train_outputs["embeddings"].append(pooled_logits)
             train_outputs["index"].extend(batch["index"].to(model.device))

             torch.cuda.empty_cache()

     dist.barrier()
     gathered_train = nested_numpify(distributed_concat(train_outputs))
-
-    gathered_train["index"] = [t.item() for t in gathered_train["index"]]
-    gathered_train["loss"] = [t.item() for t in gathered_train["loss"]]
+    gathered_train["index"] = np.concatenate(gathered_train["index"])
+    gathered_train["loss"] = np.concatenate(gathered_train["loss"])
+    gathered_train["embeddings"] = np.concatenate(gathered_train["embeddings"])

     df_train = Dataset.from_dict(gathered_train)
     df_train = df_train.sort("index")
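Why the change from extend to append matters (a minimal sketch with made-up shapes, not code from the repository): list.extend() iterates a 2-D tensor over its first dimension, so the list fills with per-row 1-D tensors, whereas list.append() keeps each (batch, hidden) block whole, and the np.concatenate added above can then stack the blocks into a single (num_samples, hidden) array.

import numpy as np
import torch

# Two hypothetical batches of pooled embeddings, shape (batch, hidden).
batches = [torch.randn(4, 8), torch.randn(2, 8)]

rows, blocks = [], []
for b in batches:
    rows.extend(b)     # old behaviour: six separate tensors of shape (8,)
    blocks.append(b)   # new behaviour: two tensors of shapes (4, 8) and (2, 8)

# With whole blocks, a single concatenate restores one (6, 8) array,
# mirroring the np.concatenate(gathered_train["embeddings"]) call in the diff.
stacked = np.concatenate([b.numpy() for b in blocks])
assert stacked.shape == (6, 8)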
@@ -146,7 +147,7 @@ def inference(config):
             sequence_lengths = torch.tensor(sequence_lengths)
             pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

-            val_outputs["embeddings"].extend(pooled_logits)
+            val_outputs["embeddings"].append(pooled_logits)
             val_outputs["index"].extend(batch["index"].to(model.device))

             torch.cuda.empty_cache()
@@ -154,8 +155,9 @@ def inference(config):
     dist.barrier()
     gathered_val = nested_numpify(distributed_concat(val_outputs))

-    gathered_val["index"] = [t.item() for t in gathered_val["index"]]
-    gathered_val["loss"] = [t.item() for t in gathered_val["loss"]]
+    gathered_val["index"] = np.concatenate(gathered_val["index"])
+    gathered_val["loss"] = np.concatenate(gathered_val["loss"])
+    gathered_val["embeddings"] = np.concatenate(gathered_val["embeddings"])

     df_val = Dataset.from_dict(gathered_val)
     df_val = df_val.sort("index")
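Both splits end with the same gather-and-rebuild step. The toy sketch below uses invented values; only the np.concatenate, Dataset.from_dict and sort("index") calls mirror the diff. It shows how the flattened fields become a dataset restored to its original sample order after the distributed gather.

import numpy as np
from datasets import Dataset

# Stand-in for nested_numpify(distributed_concat(...)): each field is a
# list of per-batch numpy arrays, in whatever order the ranks produced them.
gathered = {
    "index": [np.array([2, 3]), np.array([0, 1])],
    "loss": [np.array([0.50, 0.70]), np.array([0.10, 0.20])],
    "embeddings": [np.ones((2, 4)), np.zeros((2, 4))],
}

# Flatten every field into one array, as the commit now does.
gathered = {k: np.concatenate(v) for k, v in gathered.items()}

# Rebuild a Dataset and recover the original ordering via the "index" column.
ds = Dataset.from_dict(gathered).sort("index")
print(ds["index"])  # [0, 1, 2, 3]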
@@ -165,7 +167,8 @@ def inference(config):
     val_dataset = val_dataset.add_column("is_train", [False] * len(val_dataset))

     df = concatenate_datasets([train_dataset, val_dataset])
-    df.to_json("epoch_1_checkpoint.jsonl", lines=True, orient="records", num_proc=64)
+    if local_rank == 0:
+        df.to_json("epoch_1_checkpoint.jsonl", lines=True, orient="records", num_proc=64)


 def main():
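The final hunk makes the JSONL dump rank-aware: after distributed_concat every process holds the same gathered table, so only rank 0 should write the file. A minimal sketch of that pattern follows; dump_checkpoint and the trailing barrier are illustrative, not part of the commit.

import torch.distributed as dist

def dump_checkpoint(df, local_rank, path="epoch_1_checkpoint.jsonl"):
    # Every rank holds the same gathered data, so a single writer is enough
    # and avoids several processes clobbering the same file.
    if local_rank == 0:
        df.to_json(path, lines=True, orient="records", num_proc=64)
    # Optional: keep the other ranks from racing ahead before the write finishes.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()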