mirror of
https://github.com/CodedotAl/gpt-code-clippy.git
synced 2024-08-16 10:20:28 +03:00
gradient accumulation for APPS
This commit is contained in:
parent
feaf6b922a
commit
e2240fe910
@ -670,6 +670,34 @@ def main():
|
||||
loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
|
||||
return loss.mean()
|
||||
|
||||
#Function to write gradient checkpointing using Jax
|
||||
#https://github.com/cybertronai/gradient-checkpointing
|
||||
|
||||
def binomial_checkpoint(step, state):
|
||||
|
||||
#gradient accumulation
|
||||
def accumulate_gradient(loss_and_grad_fn, params, inputs, labels, accum_steps):
|
||||
"""Accumulate gradient over multiple steps to save on memory."""
|
||||
if accum_steps and accum_steps > 1:
|
||||
assert inputs.shape[0] % accum_steps == 0, (
|
||||
f'Bad accum_steps {accum_steps} for batch size {inputs.shape[0]}')
|
||||
step_size = inputs.shape[0] // accum_steps
|
||||
(l, _), g = loss_and_grad_fn(params, inputs[:step_size], labels[:step_size])
|
||||
|
||||
def acc_grad_and_loss(i, l_and_g):
|
||||
inps = jax.lax.dynamic_slice(inputs, (i * step_size, 0),
|
||||
(step_size,) + inputs.shape[1:])
|
||||
lbls = jax.lax.dynamic_slice(labels[..., jnp.newaxis], (i * step_size, 1),
|
||||
(step_size, 1)).squeeze(axis=-1)
|
||||
(li, _), gi = loss_and_grad_fn(params, inps, lbls)
|
||||
l, g = l_and_g
|
||||
return l + li, jax.tree_multimap(lambda x, y: x + y, g, gi)
|
||||
|
||||
l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
|
||||
l, g = jax.tree_map(lambda x: x / accum_steps, (l, g))
|
||||
return l, g
|
||||
else:
|
||||
return loss_and_grad_fn(params, inputs, labels)
|
||||
# Define gradient update step fn
|
||||
def train_step(state, batch):
|
||||
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
||||
@ -682,6 +710,7 @@ def main():
|
||||
|
||||
grad_fn = jax.value_and_grad(compute_loss)
|
||||
loss, grad = grad_fn(state.params)
|
||||
accumulate_gradient(grad_fn, state.params, batch, batch["labels"], grad_accum_steps)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
|
||||
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
||||
|
Loading…
Reference in New Issue
Block a user