Mirror of https://github.com/facebookresearch/fairseq.git
Synced 2024-10-05 13:17:39 +03:00
Make checkpoint wrapper pickleable (#1603)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1603

Test Plan: Imported from OSS

Reviewed By: sshleifer

Differential Revision: D26237760

Pulled By: myleott

fbshipit-source-id: 73c67bdea4b5b16e3159a5d4f0151e514e853357

This commit is contained in:
parent 0f93bd1a7d
commit 5a170841f2
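Background, not part of the commit message: the previous wrapper replaced m.forward with a function defined inside checkpoint_wrapper, and Python's pickle cannot serialize such local objects, so wrapped modules broke pickle and copy.deepcopy. A functools.partial over a module-level function serializes fine, which is what this change switches to. A minimal, self-contained sketch of the difference (illustrative names only, not from the diff):

import functools
import pickle


def _scale(factor, x):
    # Module-level function: pickle records it by qualified name.
    return factor * x


def make_closure(factor):
    def scaled(x):  # local function, not reachable by name for pickle
        return factor * x
    return scaled


partial_fn = functools.partial(_scale, 2.0)
pickle.loads(pickle.dumps(partial_fn))  # round-trips fine

try:
    pickle.dumps(make_closure(2.0))
except (AttributeError, pickle.PicklingError) as err:
    print(err)  # e.g. "Can't pickle local object 'make_closure.<locals>.scaled'"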
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import functools
 from typing import Any, Dict, List, Tuple, Union

 import torch
@@ -25,29 +26,32 @@ def checkpoint_wrapper(m, offload_to_cpu=False):
         checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
         a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))
     """
-    original_forward = m.forward
-
-    def _checkpointed_forward(*args, **kwargs):
-        # Autograd Functions in PyTorch work best with positional args, since
-        # the backward must return gradients (or None) for every input argument.
-        # We can flatten keyword arguments to make this easier.
-        kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
-        parent_ctx_dict = {"offload": offload_to_cpu}
-        output = CheckpointFunction.apply(
-            original_forward, parent_ctx_dict, kwarg_keys, *flat_args
-        )
-        if isinstance(output, torch.Tensor):
-            return output
-        else:
-            packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
-            if packed_non_tensor_outputs:
-                output = unpack_non_tensors(output, packed_non_tensor_outputs)
-            return output
-
-    m.forward = _checkpointed_forward
+    m.forward = functools.partial(
+        _checkpointed_forward,
+        m.forward,  # original_forward
+        offload_to_cpu,
+    )
     return m


+def _checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs):
+    # Autograd Functions in PyTorch work best with positional args, since
+    # the backward must return gradients (or None) for every input argument.
+    # We can flatten keyword arguments to make this easier.
+    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
+    parent_ctx_dict = {"offload": offload_to_cpu}
+    output = CheckpointFunction.apply(
+        original_forward, parent_ctx_dict, kwarg_keys, *flat_args
+    )
+    if isinstance(output, torch.Tensor):
+        return output
+    else:
+        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
+        if packed_non_tensor_outputs:
+            output = unpack_non_tensors(output, packed_non_tensor_outputs)
+        return output
+
+
 def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]:
     """
     Usage::
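With the closure gone, the patched forward on a wrapped module is just a functools.partial that binds the original bound forward and the offload_to_cpu flag in front of the module-level _checkpointed_forward, so the whole module can round-trip through pickle. A hedged usage sketch, assuming the wrapper is importable from fairseq.modules.checkpoint_activations (the file path is not shown in this diff):

import pickle

import torch.nn as nn

from fairseq.modules.checkpoint_activations import checkpoint_wrapper  # assumed import path

ffn = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 32))
wrapped = checkpoint_wrapper(ffn, offload_to_cpu=False)

# Before this change, pickling typically failed with
# "Can't pickle local object 'checkpoint_wrapper.<locals>._checkpointed_forward'".
restored = pickle.loads(pickle.dumps(wrapped))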
@@ -12,7 +12,9 @@ from torch.utils.checkpoint import checkpoint


 class Model(nn.Module):
-    def __init__(self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False):
+    def __init__(
+        self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs
+    ):
         super().__init__()
         torch.manual_seed(0)
         self.use_pytorch_checkpoint = use_pytorch_checkpoint
@@ -23,7 +25,7 @@ class Model(nn.Module):
             nn.Linear(128, 32),
         )
         if use_fairseq_checkpoint:
-            self.ffn = checkpoint_wrapper(self.ffn)
+            self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
         self.out = nn.Linear(32, 1)

     def forward(self, x):
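The **kwargs added to Model.__init__ are forwarded untouched to checkpoint_wrapper; that is what lets the new assertions in the next hunk construct Model(use_fairseq_checkpoint=True, offload_to_cpu=True). A tiny sketch of the same passthrough pattern in isolation (all names below are illustrative, not from the test):

import torch.nn as nn


def wrapper(module, offload_to_cpu=False):
    # Stand-in for checkpoint_wrapper: just records the flag for the demo.
    module.offload_to_cpu = offload_to_cpu
    return module


class TinyModel(nn.Module):
    def __init__(self, use_wrapper=False, **kwargs):
        super().__init__()
        self.ffn = nn.Linear(8, 8)
        if use_wrapper:
            # Extra constructor kwargs flow through to the wrapper unchanged.
            self.ffn = wrapper(self.ffn, **kwargs)


assert TinyModel(use_wrapper=True, offload_to_cpu=True).ffn.offload_to_cpu is True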
@@ -60,6 +62,11 @@ class TestActivationCheckpointing(unittest.TestCase):
         torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"])
         torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"])

+        model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device)
+        fairseq_cpt_offload = get_loss_and_gnorm(model)
+        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"])
+        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"])
+
     def test_checkpoint_wrapper_cpu(self):
         self._test_checkpoint_wrapper(device=torch.device("cpu"))