Stronger --checkpoint-activations test (#1505)

Summary:
- captures and inspects train and valid logs using unittest's `assertLogs`
- asserts that `--checkpoint-activations` does not change `train_loss` or `valid_loss`
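
For reference, the capture mechanism the new test builds on is unittest's `assertLogs`, which collects every `logging.LogRecord` emitted inside its `with` block; with `--log-format json`, each record's message parses back into a stats dict. A minimal, self-contained sketch of the pattern (the `"train"` logger name and the fake `train_loss` metric are illustrative stand-ins, not fairseq's actual output):

import json
import logging
import unittest


class LogCaptureSketch(unittest.TestCase):
    def test_last_json_entry_matches(self):
        def fake_run(loss):
            # Stand-in for a training run that logs one JSON stats line per step.
            logging.getLogger("train").info(json.dumps({"train_loss": loss}))

        # assertLogs collects the LogRecords emitted inside each with-block.
        with self.assertLogs("train", level="INFO") as run_a:
            fake_run(4.2)
        with self.assertLogs("train", level="INFO") as run_b:
            fake_run(4.2)

        # Parse the last captured record from each run and compare the metric,
        # mirroring the read_last_log_entry helper added in the diff below.
        loss_a = json.loads(run_a.records[-1].getMessage())["train_loss"]
        loss_b = json.loads(run_b.records[-1].getMessage())["train_loss"]
        self.assertEqual(loss_a, loss_b)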

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1505

Reviewed By: myleott

Differential Revision: D25544991

Pulled By: sshleifer

fbshipit-source-id: 2762095ab4e7c819a803b3556f5774db8c6b6f39
Authored by Sam Shleifer on 2020-12-16 19:06:45 -08:00; committed by Facebook GitHub Bot
parent 409032596b
commit c8a0659be5


@@ -5,13 +5,14 @@
import contextlib
import logging
+import json
import os
import random
import sys
import tempfile
import unittest
from io import StringIO
+from typing import List, Dict
import torch
from fairseq import options
from fairseq_cli import eval_lm, train, validate
@@ -249,29 +250,6 @@ class TestTranslation(unittest.TestCase):
)
generate_main(data_dir)
-
-    def test_transformer_with_activation_checkpointing(self):
-        with contextlib.redirect_stdout(StringIO()):
-            with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:
-                create_dummy_data(data_dir)
-                preprocess_translation_data(data_dir)
-                train_translation_model(
-                    data_dir,
-                    "transformer_iwslt_de_en",
-                    [
-                        "--encoder-layers",
-                        "2",
-                        "--decoder-layers",
-                        "2",
-                        "--encoder-embed-dim",
-                        "8",
-                        "--decoder-embed-dim",
-                        "8",
-                        "--checkpoint-activations",
-                    ],
-                    run_validation=True,
-                )
-                generate_main(data_dir)

def test_multilingual_transformer(self):
# test with all combinations of encoder/decoder lang tokens
encoder_langtok_flags = [
@@ -326,7 +304,9 @@
+ dec_ltok_flag,
)
-    @unittest.skipIf(sys.platform.lower() == "darwin", "skip latent depth test on MacOS")
+    @unittest.skipIf(
+        sys.platform.lower() == "darwin", "skip latent depth test on MacOS"
+    )
def test_multilingual_translation_latent_depth(self):
# test with latent depth in encoder, decoder, or both
encoder_latent_layer = [[], ["--encoder-latent-layer"]]
@@ -465,9 +445,7 @@ class TestTranslation(unittest.TestCase):
"test_translation_multi_simple_epoch_dict"
) as data_dir:
create_dummy_data(data_dir)
-                preprocess_translation_data(
-                    data_dir, extra_flags=[]
-                )
+                preprocess_translation_data(data_dir, extra_flags=[])
train_translation_model(
data_dir,
arch="transformer",
@@ -517,9 +495,7 @@ class TestTranslation(unittest.TestCase):
"test_translation_multi_simple_epoch_dict"
) as data_dir:
create_dummy_data(data_dir)
-                preprocess_translation_data(
-                    data_dir, extra_flags=[]
-                )
+                preprocess_translation_data(data_dir, extra_flags=[])
train_translation_model(
data_dir,
arch="transformer",
@@ -619,11 +595,17 @@ class TestTranslation(unittest.TestCase):
"0",
],
run_validation=True,
extra_valid_flags=["--user-dir", "examples/pointer_generator/pointer_generator_src"],
extra_valid_flags=[
"--user-dir",
"examples/pointer_generator/pointer_generator_src",
],
)
generate_main(
data_dir,
extra_flags=["--user-dir", "examples/pointer_generator/pointer_generator_src"],
extra_flags=[
"--user-dir",
"examples/pointer_generator/pointer_generator_src",
],
)
def test_lightconv(self):
@@ -953,7 +935,7 @@ class TestTranslation(unittest.TestCase):
data_dir,
[
"--model-overrides",
"{'encoder_layers_to_keep':'0,2','decoder_layers_to_keep':'1'}"
"{'encoder_layers_to_keep':'0,2','decoder_layers_to_keep':'1'}",
],
)
@@ -1080,7 +1062,9 @@ class TestLanguageModeling(unittest.TestCase):
def test_transformer_lm_with_adaptive_softmax(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory("test_transformer_lm_with_adaptive_softmax") as data_dir:
with tempfile.TemporaryDirectory(
"test_transformer_lm_with_adaptive_softmax"
) as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_language_model(
@@ -1199,7 +1183,8 @@ class TestLanguageModeling(unittest.TestCase):
train_language_model(
data_dir=data_dir,
arch="transformer_xl",
-                extra_flags=task_flags + [
+                extra_flags=task_flags
+                + [
"--n-layer",
"2",
],
@@ -1537,6 +1522,65 @@ class TestOptimizers(unittest.TestCase):
generate_main(data_dir)
+
+
+def read_last_log_entry(
+    logs: List[logging.LogRecord], logger_name: str
+) -> Dict[str, float]:
+    for x in reversed(logs):
+        if x.name == logger_name:
+            return json.loads(x.message)
+    raise ValueError(f"No entries from {logger_name} found in captured logs")
+
+
+class TestActivationCheckpointing(unittest.TestCase):
+    def test_activation_checkpointing_does_not_change_metrics(self):
+        """--checkpoint-activations should not change loss"""
+        base_flags = [
+            "--encoder-layers",
+            "2",
+            "--decoder-layers",
+            "2",
+            "--encoder-embed-dim",
+            "8",
+            "--decoder-embed-dim",
+            "8",
+            "--restore-file",
+            "x.pt",
+            "--log-format",
+            "json",
+            "--log-interval",
+            "1",
+            "--max-update",
+            "2",
+        ]
+
+        def _train(extra_flags):
+            with self.assertLogs() as logs:
+                train_translation_model(
+                    data_dir,
+                    "transformer_iwslt_de_en",
+                    base_flags + extra_flags,
+                    run_validation=True,
+                    extra_valid_flags=["--log-format", "json"],
+                )
+            return logs.records
+
+        with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:
+            create_dummy_data(data_dir, num_examples=20)
+            preprocess_translation_data(data_dir)
+            ckpt_logs = _train(["--checkpoint-activations"])
+            baseline_logs = _train([])
+            assert len(baseline_logs) == len(ckpt_logs)
+
+            baseline_train_stats = read_last_log_entry(baseline_logs, "train")
+            ckpt_train_stats = read_last_log_entry(ckpt_logs, "train")
+            assert baseline_train_stats["train_loss"] == ckpt_train_stats["train_loss"]
+
+            baseline_valid_stats = read_last_log_entry(baseline_logs, "valid")
+            ckpt_valid_stats = read_last_log_entry(ckpt_logs, "valid")
+            assert baseline_valid_stats["valid_loss"] == ckpt_valid_stats["valid_loss"]

def create_dummy_roberta_head_data(
data_dir, num_examples=100, maxlen=10, num_classes=2, regression=False
):
@@ -1653,7 +1697,12 @@ def train_roberta_head(data_dir, arch, num_classes=2, extra_flags=None):
def train_language_model(
-    data_dir, arch, extra_flags=None, run_validation=False, extra_valid_flags=None, task="language_modeling"
+    data_dir,
+    arch,
+    extra_flags=None,
+    run_validation=False,
+    extra_valid_flags=None,
+    task="language_modeling",
):
train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch(
@@ -1723,7 +1772,8 @@ def eval_lm_main(data_dir, extra_flags=None):
"--no-progress-bar",
"--num-workers",
"0",
-        ] + (extra_flags or []),
+        ]
+        + (extra_flags or []),
)
eval_lm.main(eval_lm_args)
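
To run just the new test class (assuming these hunks apply to fairseq's tests/test_binaries.py, as the surrounding class names suggest), something like the following should work from the repository root:

python -m unittest tests.test_binaries.TestActivationCheckpointing -v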