Make checkpoint wrapper pickleable (#1603)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1603

Test Plan: Imported from OSS

Reviewed By: sshleifer

Differential Revision: D26237760

Pulled By: myleott

fbshipit-source-id: 73c67bdea4b5b16e3159a5d4f0151e514e853357
Myle Ott 2021-02-06 08:05:41 -08:00 committed by Facebook GitHub Bot
parent 0f93bd1a7d
commit 5a170841f2
2 changed files with 33 additions and 22 deletions
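
Why this works: pickle serializes plain functions by reference to their module-level qualified name, so the nested _checkpointed_forward closure that checkpoint_wrapper previously assigned to m.forward made the wrapped module unpicklable. A functools.partial is picklable as long as its target function and bound arguments are, which holds once _checkpointed_forward lives at module level. A minimal standalone sketch of the pattern (illustrative names such as ToyModule and wrap_picklable are not from this diff):

    import functools
    import pickle


    def _checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs):
        # Module-level function: pickle can find it by qualified name, and a
        # functools.partial built from it is picklable as long as its bound
        # arguments (here a bound method and a bool) are picklable too.
        return original_forward(*args, **kwargs)


    class ToyModule:
        def forward(self, x):
            return x + 1


    def wrap_unpicklable(m):
        original_forward = m.forward

        def _closure(*args, **kwargs):
            # Defined inside a function, so pickling m later fails with a
            # "Can't pickle local object" error.
            return original_forward(*args, **kwargs)

        m.forward = _closure
        return m


    def wrap_picklable(m):
        # The pattern this commit adopts: no closure, state passed explicitly.
        m.forward = functools.partial(_checkpointed_forward, m.forward, False)
        return m


    m = wrap_picklable(ToyModule())
    assert m.forward(1) == 2
    pickle.loads(pickle.dumps(m))  # round-trips; the closure version would not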


@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
 from typing import Any, Dict, List, Tuple, Union
 
 import torch
@@ -25,29 +26,32 @@ def checkpoint_wrapper(m, offload_to_cpu=False):
         checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
         a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))
     """
-    original_forward = m.forward
-
-    def _checkpointed_forward(*args, **kwargs):
-        # Autograd Functions in PyTorch work best with positional args, since
-        # the backward must return gradients (or None) for every input argument.
-        # We can flatten keyword arguments to make this easier.
-        kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
-        parent_ctx_dict = {"offload": offload_to_cpu}
-        output = CheckpointFunction.apply(
-            original_forward, parent_ctx_dict, kwarg_keys, *flat_args
-        )
-        if isinstance(output, torch.Tensor):
-            return output
-        else:
-            packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
-            if packed_non_tensor_outputs:
-                output = unpack_non_tensors(output, packed_non_tensor_outputs)
-            return output
-
-    m.forward = _checkpointed_forward
+    m.forward = functools.partial(
+        _checkpointed_forward,
+        m.forward,  # original_forward
+        offload_to_cpu,
+    )
     return m
 
 
+def _checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs):
+    # Autograd Functions in PyTorch work best with positional args, since
+    # the backward must return gradients (or None) for every input argument.
+    # We can flatten keyword arguments to make this easier.
+    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
+    parent_ctx_dict = {"offload": offload_to_cpu}
+    output = CheckpointFunction.apply(
+        original_forward, parent_ctx_dict, kwarg_keys, *flat_args
+    )
+    if isinstance(output, torch.Tensor):
+        return output
+    else:
+        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
+        if packed_non_tensor_outputs:
+            output = unpack_non_tensors(output, packed_non_tensor_outputs)
+        return output
+
+
 def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]:
     """
     Usage::
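
With forward bound through functools.partial to the module-level _checkpointed_forward, a checkpoint-wrapped module should survive pickling. A hedged usage sketch; the import path is assumed from the file being edited and may differ:

    import pickle

    import torch
    import torch.nn as nn

    # Assumed import path for the module edited above.
    from fairseq.modules.checkpoint_activations import checkpoint_wrapper

    ffn = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 32))
    wrapped = checkpoint_wrapper(ffn, offload_to_cpu=True)

    # forward is now a functools.partial over a module-level function, so the
    # wrapped module can be serialized; the old closure-based wrapper raised a
    # "Can't pickle local object" error here.
    restored = pickle.loads(pickle.dumps(wrapped))
    out = restored(torch.rand(2, 32))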


@@ -12,7 +12,9 @@ from torch.utils.checkpoint import checkpoint
 
 
 class Model(nn.Module):
-    def __init__(self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False):
+    def __init__(
+        self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs
+    ):
         super().__init__()
         torch.manual_seed(0)
         self.use_pytorch_checkpoint = use_pytorch_checkpoint
@@ -23,7 +25,7 @@ class Model(nn.Module):
             nn.Linear(128, 32),
         )
         if use_fairseq_checkpoint:
-            self.ffn = checkpoint_wrapper(self.ffn)
+            self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
         self.out = nn.Linear(32, 1)
 
     def forward(self, x):
@@ -60,6 +62,11 @@ class TestActivationCheckpointing(unittest.TestCase):
         torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"])
         torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"])
 
+        model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device)
+        fairseq_cpt_offload = get_loss_and_gnorm(model)
+        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"])
+        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"])
+
     def test_checkpoint_wrapper_cpu(self):
         self._test_checkpoint_wrapper(device=torch.device("cpu"))
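
The hunk above asserts parity for the new offload_to_cpu path but does not assert the pickling behavior that motivates the commit. A hypothetical companion test for TestActivationCheckpointing (not part of this commit; the name test_checkpoint_wrapper_pickle is illustrative) might look like:

    def test_checkpoint_wrapper_pickle(self):
        # Hypothetical test, not in this diff: forward is now a
        # functools.partial over a module-level function, so the wrapped
        # model should survive a pickle round-trip with identical weights.
        import pickle

        model = Model(use_fairseq_checkpoint=True)
        restored = pickle.loads(pickle.dumps(model))
        for p, q in zip(model.parameters(), restored.parameters()):
            torch.testing.assert_allclose(p, q)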