Add test for activation checkpointing (#1453)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1453

Test Plan: Imported from OSS

Reviewed By: sshleifer

Differential Revision: D25108463

Pulled By: myleott

fbshipit-source-id: 3cebce9be7fe503401eabba3f483c26847e7a3c0
Myle Ott, 2020-11-20 12:40:49 -08:00, committed by Facebook GitHub Bot
parent 94f59bb67b
commit fa113ff1de
2 changed files with 26 additions and 2 deletions


@@ -6,6 +6,7 @@
 from typing import Any, Dict, List, Tuple, Union
 import torch
+import torch.utils.checkpoint as checkpoint
 from fairseq import utils
@@ -133,7 +134,7 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args):
         if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
-            torch.utils.checkpoint.check_backward_validity(args)
+            checkpoint.check_backward_validity(args)
         ctx.run_function = run_function
         ctx.kwarg_keys = kwarg_keys
@@ -165,7 +166,7 @@
         )
         tensor_inputs = ctx.saved_tensors
-        tensor_inputs = torch.utils.checkpoint.detach_variable(tensor_inputs)
+        tensor_inputs = checkpoint.detach_variable(tensor_inputs)
         inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)
         # Store the current states.
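
Both edits in this file are a pure rename: the module is imported once as checkpoint, and the fully qualified torch.utils.checkpoint.* call sites are shortened accordingly. For context, CheckpointFunction implements the standard recompute-in-backward pattern. The following is a minimal sketch of that pattern built on the same two helpers the diff touches; it is an illustration, not fairseq's implementation, which additionally handles keyword arguments, non-tensor inputs, and RNG-state save/restore:

# Minimal sketch of recompute-in-backward activation checkpointing.
# Assumes run_function takes and returns tensors; unlike fairseq's
# CheckpointFunction, it omits kwargs, non-tensor I/O, and RNG state
# (so e.g. dropout would differ between the two forward passes).
import torch
import torch.utils.checkpoint as checkpoint


class SimpleCheckpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, *args):
        if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
            checkpoint.check_backward_validity(args)
        ctx.run_function = run_function
        ctx.save_for_backward(*args)  # keep only the inputs, not the activations
        with torch.no_grad():  # forward pass without recording a graph
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Re-run the forward pass, this time recording a graph, then
        # backpropagate through the freshly recomputed activations.
        inputs = checkpoint.detach_variable(ctx.saved_tensors)
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        if torch.is_tensor(outputs):
            outputs = (outputs,)
        torch.autograd.backward(outputs, grad_outputs)
        return (None,) + tuple(inp.grad for inp in inputs)

A call such as SimpleCheckpoint.apply(layer, x), where layer is any callable built from tensor ops, stores only the inputs during forward and rebuilds the intermediate activations on demand in backward, trading recomputation time for memory.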


@@ -249,6 +249,29 @@ class TestTranslation(unittest.TestCase):
                 )
                 generate_main(data_dir)
 
+    def test_transformer_with_activation_checkpointing(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(
+                    data_dir,
+                    "transformer_iwslt_de_en",
+                    [
+                        "--encoder-layers",
+                        "2",
+                        "--decoder-layers",
+                        "2",
+                        "--encoder-embed-dim",
+                        "8",
+                        "--decoder-embed-dim",
+                        "8",
+                        "--checkpoint-activations",
+                    ],
+                    run_validation=True,
+                )
+                generate_main(data_dir)
+
     def test_multilingual_transformer(self):
         # test with all combinations of encoder/decoder lang tokens
         encoder_langtok_flags = [
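
The new test trains a tiny transformer (2 layers, 8-dimensional embeddings) with --checkpoint-activations, runs validation, and then decodes via generate_main, so the recompute path is exercised end to end rather than in isolation. In stock PyTorch terms, the flag enables roughly the following per layer; this is a sketch using torch.utils.checkpoint, not fairseq's own wrapper, with dimensions chosen to echo the test model:

# Rough stock-PyTorch analogue of what --checkpoint-activations enables:
# each layer runs under torch.utils.checkpoint, so its internal activations
# are freed after forward and recomputed during backward.
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

layers = nn.ModuleList(
    nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=8)
    for _ in range(2)  # 2 layers, 8-dim embeddings, echoing the test's flags
)
x = torch.randn(10, 4, 8, requires_grad=True)  # (seq_len, batch, embed_dim)

for layer in layers:
    x = checkpoint.checkpoint(layer, x)  # recompute this layer during backward

x.sum().backward()  # gradients flow through the recomputed activations

The trade-off is deliberate: each checkpointed layer's forward runs twice per training step, in exchange for not holding its intermediate activations in memory between the forward and backward passes.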