gradient accumulation for APPS

Mrinal18 2021-07-16 11:48:10 +00:00
parent feaf6b922a
commit e2240fe910


@@ -670,6 +670,34 @@ def main():
        loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
        return loss.mean()
    # Placeholder for gradient checkpointing (rematerialization) with JAX, see
    # https://github.com/cybertronai/gradient-checkpointing
    def binomial_checkpoint(step, state):
        pass  # checkpointing is not implemented in this commit
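    # A minimal sketch (not part of this commit) of how rematerialization could be
    # done with JAX's built-in jax.checkpoint (alias jax.remat): activations of the
    # wrapped function are recomputed on the backward pass instead of being stored.
    # `transformer_block` and `apply_block` are hypothetical names for illustration.
    #
    #   @jax.checkpoint
    #   def transformer_block(block_params, hidden_states):
    #       return apply_block(block_params, hidden_states)  # recomputed in backward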
    # gradient accumulation
    def accumulate_gradient(loss_and_grad_fn, params, inputs, labels, accum_steps):
        """Accumulate gradients over accum_steps micro-batches to save memory."""
        if accum_steps and accum_steps > 1:
            assert inputs.shape[0] % accum_steps == 0, (
                f'Bad accum_steps {accum_steps} for batch size {inputs.shape[0]}')
            step_size = inputs.shape[0] // accum_steps
            # Loss and gradients of the first micro-batch.
            l, g = loss_and_grad_fn(params, inputs[:step_size], labels[:step_size])

            def acc_grad_and_loss(i, l_and_g):
                # Slice out the i-th micro-batch; inputs and labels are both
                # (batch, seq_len) arrays of token ids here.
                inps = jax.lax.dynamic_slice(inputs, (i * step_size, 0),
                                             (step_size,) + inputs.shape[1:])
                lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0),
                                             (step_size,) + labels.shape[1:])
                li, gi = loss_and_grad_fn(params, inps, lbls)
                l, g = l_and_g
                # Running sums of the loss and of every gradient leaf.
                return l + li, jax.tree_multimap(lambda x, y: x + y, g, gi)

            l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
            # Average the accumulated loss and gradients over the micro-batches.
            l, g = jax.tree_map(lambda x: x / accum_steps, (l, g))
            return l, g
        else:
            return loss_and_grad_fn(params, inputs, labels)
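    # Intended usage (a sketch): `compute_loss` must accept (params, inputs, labels)
    # for this to work, and "input_ids" is assumed to be the model input key of the
    # tokenized batch.
    #
    #   grad_fn = jax.value_and_grad(compute_loss)
    #   loss, grad = accumulate_gradient(grad_fn, state.params,
    #                                    batch["input_ids"], batch["labels"],
    #                                    grad_accum_steps)
    #
    # Each micro-batch holds batch_size // accum_steps examples; the returned loss
    # and gradients are averaged over the accum_steps micro-batches.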
    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
@@ -682,6 +710,7 @@ def main():
        grad_fn = jax.value_and_grad(compute_loss)
        # Use the accumulated loss/gradients instead of a single full-batch call.
        # Assumes compute_loss accepts (params, inputs, labels) and that the
        # tokenized batch exposes its model inputs under "input_ids".
        loss, grad = accumulate_gradient(grad_fn, state.params, batch["input_ids"],
                                         batch["labels"], grad_accum_steps)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
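        # How this step is typically driven in the Flax example scripts (a sketch,
        # not shown in this hunk): the axis name given to jax.pmap must match the
        # "batch" axis used by jax.lax.pmean above, and batches are sharded across
        # devices, e.g. with flax.training.common_utils.shard.
        #
        #   p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))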