gpt-code-clippy/utils.py

import numpy as np
import threading
import queue
import multiprocessing
from collections import defaultdict

import jax.numpy as jnp


def make_batch(samples):
    "Convert stacked numpy samples to jnp arrays and add labels for causal LM."
    batch = {k: jnp.array(v) for k, v in samples.items()}
    # For causal language modeling the labels are simply a copy of the inputs.
    batch["labels"] = batch["input_ids"].copy()
    return batch
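
# Illustration (hypothetical values): for
#   samples = {"input_ids": np.zeros((2, 4), int), "attention_mask": np.ones((2, 4), int)}
# make_batch returns jnp arrays of the same shapes, plus a "labels" array that
# duplicates "input_ids".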


class PrefetchDataloaderThread(threading.Thread):
    "Background-thread prefetching dataloader for an IterableDataset."

    def __init__(self, dataset, max_steps, batch_size, sequence_length,
                 prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.max_steps = max_steps
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        self.make_iter()
        self.queue = queue.Queue(prefetch_buffer)
        # Leftover tokens that did not fit into the previous batch.
        self.rem = defaultdict(list)
        self.start()
    def make_iter(self):
        "(Re)create the dataset iterator, reshuffling with a fresh seed each pass."
        if self.shuffle:
            shuffled_dataset = self.dataset.shuffle(self.shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(self.dataset)

    def __next__(self):
        batch = self.queue.get()
        if batch is None:
            # `None` is the sentinel the producer puts after max_steps batches.
            raise StopIteration
        return batch
    def run(self):
        for _ in range(self.max_steps):
            # Prepare the next batch, starting from tokens left over last time.
            sample = self.rem.copy()
            length = len(sample["input_ids"])
            max_length = self.max_length
            while length < max_length:
                try:
                    next_sample = next(self.ds_iter)
                except StopIteration:
                    # Reset the iterator once a full pass over the dataset completes.
                    self.make_iter()
                    next_sample = next(self.ds_iter)
                length += len(next_sample["input_ids"])
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}
            # Keep the overflow for the next batch, truncate the current one.
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # Regroup the flat token stream to shape [bs, seq_len].
            samples = {
                k: np.array([v[j * self.seq_len:(j + 1) * self.seq_len] for j in range(self.bs)])
                for k, v in sample.items()
            }
            self.queue.put(make_batch(samples))
        # Sentinel: tells the consumer that max_steps batches have been produced.
        self.queue.put(None)

    def __iter__(self):
        return self


class PrefetchDataloader(multiprocessing.Process):
    "Background-process prefetching dataloader for an IterableDataset."

    def __init__(self, dataset, max_steps, batch_size, sequence_length,
                 prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.max_steps = max_steps
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        self.make_iter()
        self.queue = multiprocessing.Queue(prefetch_buffer)
        # Leftover tokens that did not fit into the previous batch.
        self.rem = defaultdict(list)
        self.start()

    def make_iter(self):
        "(Re)create the dataset iterator, reshuffling with a fresh seed each pass."
        if self.shuffle:
            shuffled_dataset = self.dataset.shuffle(self.shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(self.dataset)

    def __next__(self):
        samples = self.queue.get()
        if samples is None:
            # `None` is the sentinel the producer puts after max_steps batches.
            raise StopIteration
        # Device arrays are created in the consumer process, so only plain
        # numpy arrays cross the process boundary.
        return make_batch(samples)

    def run(self):
        for _ in range(self.max_steps):
            # Prepare the next batch, starting from tokens left over last time.
            sample = self.rem.copy()
            length = len(sample["input_ids"])
            max_length = self.max_length
            while length < max_length:
                try:
                    next_sample = next(self.ds_iter)
                except StopIteration:
                    # Reset the iterator once a full pass over the dataset completes.
                    self.make_iter()
                    next_sample = next(self.ds_iter)
                length += len(next_sample["input_ids"])
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}
            # Keep the overflow for the next batch, truncate the current one.
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # Regroup the flat token stream to shape [bs, seq_len].
            samples = {
                k: np.array([v[j * self.seq_len:(j + 1) * self.seq_len] for j in range(self.bs)])
                for k, v in sample.items()
            }
            self.queue.put(samples)
        # Sentinel: tells the consumer that max_steps batches have been produced.
        self.queue.put(None)

    def __iter__(self):
        return self


class PrefetchDataloaderWithFilter(multiprocessing.Process):
    "Background-process prefetching dataloader that skips very short/long samples."

    def __init__(self, dataset, max_steps, batch_size, sequence_length,
                 prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.max_steps = max_steps
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        self.make_iter()
        self.queue = multiprocessing.Queue(prefetch_buffer)
        # Leftover tokens that did not fit into the previous batch.
        self.rem = defaultdict(list)
        self.start()

    def make_iter(self):
        "(Re)create the dataset iterator, reshuffling with a fresh seed each pass."
        if self.shuffle:
            shuffled_dataset = self.dataset.shuffle(self.shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(self.dataset)

    def __next__(self):
        samples = self.queue.get()
        if samples is None:
            # `None` is the sentinel the producer puts after max_steps batches.
            raise StopIteration
        return make_batch(samples)

    def run(self):
        for _ in range(self.max_steps):
            # Prepare the next batch, starting from tokens left over last time.
            sample = self.rem.copy()
            length = len(sample["input_ids"])
            max_length = self.max_length
            while length < max_length:
                try:
                    next_sample = next(self.ds_iter)
                except StopIteration:
                    # Reset the iterator once a full pass over the dataset completes.
                    self.make_iter()
                    next_sample = next(self.ds_iter)
                sample_len = len(next_sample["input_ids"])
                # Online filtering: drop samples shorter than 200 or longer
                # than 200,000 tokens.
                if sample_len < 200 or sample_len > 200_000:
                    continue
                length += sample_len
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}
            # Keep the overflow for the next batch, truncate the current one.
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # Regroup the flat token stream to shape [bs, seq_len].
            samples = {
                k: np.array([v[j * self.seq_len:(j + 1) * self.seq_len] for j in range(self.bs)])
                for k, v in sample.items()
            }
            self.queue.put(samples)
        # Sentinel: tells the consumer that max_steps batches have been produced.
        self.queue.put(None)

    def __iter__(self):
        return self
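

# Minimal usage sketch, not part of the original module: a toy in-memory
# dataset stands in for a tokenized streaming dataset. Any iterable of dicts
# of token-id lists works with shuffle=False; with shuffle=True the dataset
# must expose a .shuffle(buffer_size, seed=...) method, as Hugging Face
# streaming datasets do. Assumes the "fork" start method (the Linux default),
# since the dataset iterator is created before the process is started.
if __name__ == "__main__":
    import itertools

    class ToyDataset:
        "Yields fake tokenized samples of varying length, forever."

        def __iter__(self):
            rng = np.random.default_rng(0)
            for _ in itertools.count():
                n = int(rng.integers(250, 1000))
                ids = rng.integers(0, 50257, size=n).tolist()
                yield {"input_ids": ids, "attention_mask": [1] * n}

    loader = PrefetchDataloader(
        ToyDataset(), max_steps=3, batch_size=2, sequence_length=128,
        prefetch_buffer=2, shuffle=False,
    )
    for batch in loader:
        # Each value is a jnp array of shape (batch_size, sequence_length).
        print({k: v.shape for k, v in batch.items()})
    loader.join()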