fix: pyarrow filter

This commit is contained in:
Zach Nussbaum 2023-04-07 19:04:19 +00:00
parent 7a9f6d1cdc
commit 4b51e6ef37

View File

@ -12,6 +12,8 @@ from transformers.trainer_pt_utils import nested_numpify
from transformers import DefaultDataCollator
from torch.utils.data import DataLoader, DistributedSampler
import numpy as np
import pyarrow as pa
from pyarrow import compute as pc
def calc_cross_entropy_no_reduction(lm_logits, labels):
@ -116,7 +118,13 @@ def inference(config):
df_train = df_train.sort("index")
curr_idx = df_train["index"]
filtered_train = train_dataset.filter(lambda example: example["index"] in curr_idx)
# compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = train_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_train = Dataset.from_dict(filtered_table.to_pydict())
filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
filtered_train = filtered_train.add_column("loss", df_train["loss"])
@ -167,7 +175,13 @@ def inference(config):
df_val = df_val.sort("index")
curr_idx = df_val["index"]
filtered_val = val_dataset.filter(lambda example: example["index"] in curr_idx)
# compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = val_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_val = Dataset.from_dict(filtered_table.to_pydict())
filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
filtered_val = filtered_val.add_column("loss", df_val["loss"])