Mirror of https://github.com/CodedotAl/gpt-code-clippy.git
Synced 2024-10-26 09:17:45 +03:00
Commit 92539be4cb:
* adds cleaner streaming script
* resume from checkpoint w/ MultiSteps
* adds gradient clipping
* upd run_clm_flax.py
import numpy as np
import threading
import queue
import multiprocessing
from collections import defaultdict

import jax
import jax.numpy as jnp

def make_batch(samples):
    # Convert each field to a jax array and duplicate input_ids as labels.
    batch = {k: jnp.array(v) for k, v in samples.items()}
    batch['labels'] = batch['input_ids'].copy()
    return batch

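# Illustrative note (not part of the original script): make_batch converts the
# per-key arrays of shape [batch_size, sequence_length] produced by the loaders
# below into jnp arrays and copies input_ids into labels, as is common for
# causal-LM training loops that handle the label shift themselves. For example:
#
#   samples = {"input_ids": np.zeros((2, 8), dtype=np.int32),
#              "attention_mask": np.ones((2, 8), dtype=np.int32)}
#   batch = make_batch(samples)
#   # batch["labels"].shape == (2, 8) and batch["labels"] equals batch["input_ids"]
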
class PrefetchDataloaderTread(threading.Thread):
    """Thread-based prefetch dataloader for an IterableDataset."""

    def __init__(self, dataset, max_steps, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.max_steps = max_steps
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        if shuffle:
            shuffled_dataset = dataset.shuffle(shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(dataset)
        self.queue = queue.Queue(prefetch_buffer)
        self.rem = defaultdict(list)
        self.start()

    def make_iter(self):
        # (Re)build the dataset iterator, reshuffling with a fresh seed each pass.
        if self.shuffle:
            shuffled_dataset = self.dataset.shuffle(self.shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(self.dataset)

    def __next__(self):
        batch = self.queue.get()
        if batch is None:
            # None is the sentinel the producer puts after max_steps batches.
            raise StopIteration
        return batch

    def run(self):
        i = 0
        while i < self.max_steps:
            i += 1
            # prepare next batch: start from the tokens left over from the previous one
            sample = self.rem.copy()
            l = len(sample["input_ids"])
            max_length = self.max_length
            # pull streamed examples until we have batch_size * sequence_length tokens
            while l < max_length:
                try:
                    next_sample = next(self.ds_iter)
                except StopIteration:
                    # reset generator if a pass through the dataset is completed
                    self.make_iter()
                    next_sample = next(self.ds_iter)

                l += len(next_sample["input_ids"])
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}

            # keep the overflow for the next batch, truncate the current one
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # regroup to shape [bs x seq_len]
            samples = {k: np.array([v[j * self.seq_len:(j + 1) * self.seq_len] for j in range(self.bs)]) for k, v in sample.items()}

            self.queue.put(make_batch(samples))
        # signal the consumer that max_steps batches have been produced
        self.queue.put(None)

    def __iter__(self):
        return self


class PrefetchDataloader(multiprocessing.Process):
    """Process-based prefetch dataloader for an IterableDataset."""

    def __init__(self, dataset, max_steps, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.max_steps = max_steps
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        self.make_iter()
        self.queue = multiprocessing.Queue(prefetch_buffer)
        self.rem = defaultdict(list)
        self.start()

    def make_iter(self):
        # (Re)build the dataset iterator, reshuffling with a fresh seed each pass.
        if self.shuffle:
            shuffled_dataset = self.dataset.shuffle(self.shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(self.dataset)

    def __next__(self):
        samples = self.queue.get()
        if samples is None:
            # None is the sentinel the producer puts after max_steps batches.
            raise StopIteration
        # conversion to jax arrays happens in the consumer process
        return make_batch(samples)

    def run(self):
        i = 0
        while i < self.max_steps:
            i += 1
            # prepare next batch: start from the tokens left over from the previous one
            sample = self.rem.copy()
            l = len(sample["input_ids"])
            max_length = self.max_length
            # pull streamed examples until we have batch_size * sequence_length tokens
            while l < max_length:
                try:
                    next_sample = next(self.ds_iter)
                except StopIteration:
                    # reset generator if a pass through the dataset is completed
                    self.make_iter()
                    next_sample = next(self.ds_iter)

                l += len(next_sample["input_ids"])
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}

            # keep the overflow for the next batch, truncate the current one
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # regroup to shape [bs x seq_len]
            samples = {k: np.array([v[j * self.seq_len:(j + 1) * self.seq_len] for j in range(self.bs)]) for k, v in sample.items()}

            # plain numpy arrays cross the process boundary; jnp conversion is done in __next__
            self.queue.put(samples)
        # signal the consumer that max_steps batches have been produced
        self.queue.put(None)

    def __iter__(self):
        return self
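

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script, added for illustration): a
# minimal, self-contained example of driving the thread-based loader with a
# toy streaming dataset. DummyTokenizedDataset is a hypothetical stand-in for
# an already-tokenized IterableDataset; the loaders only require that the
# dataset be iterable and expose a shuffle(buffer_size, seed=...) method.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import random

    class DummyTokenizedDataset:
        """Yields variable-length tokenized examples, mimicking a streaming dataset."""

        def __init__(self, num_examples=64, seed=0):
            self.num_examples = num_examples
            self.seed = seed

        def shuffle(self, buffer_size, seed=0):
            # A real streaming dataset shuffles within a buffer; the dummy just
            # returns a copy seeded differently so repeated passes differ.
            return DummyTokenizedDataset(self.num_examples, seed=seed)

        def __iter__(self):
            rng = random.Random(self.seed)
            for _ in range(self.num_examples):
                length = rng.randint(8, 32)
                ids = [rng.randint(0, 50256) for _ in range(length)]
                yield {"input_ids": ids, "attention_mask": [1] * length}

    loader = PrefetchDataloaderTread(
        DummyTokenizedDataset(),
        max_steps=4,
        batch_size=2,
        sequence_length=16,
        prefetch_buffer=2,
    )
    # Each batch is a dict of jnp arrays of shape (batch_size, sequence_length).
    for step, batch in enumerate(loader):
        print(step, {k: v.shape for k, v in batch.items()})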