--combine-valid-sets (#1843)

Summary:
- `--combine-valid-subsets` (alias: `--combine-val`) causes valid.bin, valid1.bin, ... to be concatenated; all validation metrics are then reported together.
- `--valid-subset` works as before: if you pass `--valid-subset valid1,valid2`, valid1_loss and valid2_loss are logged separately.
- If the user passes `--valid-subset valid` (the default) and we see files named valid1, valid2, we raise an error. The user must pass `--ignore-unused-valid-subsets` to override. Previously, valid1 and valid2 were silently ignored.
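
For illustration, a sketch of the resulting behaviors; the directory name and file layout here are hypothetical:

    # data-bin/ contains train.bin, valid.bin, valid1.bin, valid2.bin (plus .idx files)
    fairseq-train data-bin ...                                  # raises ValueError (previously skipped valid1/valid2 silently)
    fairseq-train data-bin --combine-valid-subsets ...          # one combined valid set, a single valid_loss
    fairseq-train data-bin --valid-subset valid,valid1,valid2   # valid_loss, valid1_loss, valid2_loss logged separately
    fairseq-train data-bin --ignore-unused-valid-subsets ...    # only valid.bin is used; the extras are deliberately ignored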

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1843

Reviewed By: myleott

Differential Revision: D28323815

Pulled By: sshleifer

fbshipit-source-id: dfd46076d3f684e36f8dacfadd38fd0038ce6755
Sam Shleifer, 2021-05-10 23:42:41 -07:00 (committed by Facebook GitHub Bot)
parent a2314b4e8a
commit 97969ac5f5
5 changed files with 182 additions and 7 deletions

fairseq/data/data_utils.py

@@ -10,7 +10,7 @@ except ImportError:
import contextlib
import itertools
import logging
import os
import re
import warnings
from typing import Optional, Tuple
@@ -18,7 +18,8 @@ import numpy as np
import torch
from fairseq.file_io import PathManager
from fairseq import utils
import os

logger = logging.getLogger(__name__)
@@ -68,7 +69,6 @@ def collate_tokens(
        copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
    return res


def load_indexed_dataset(
    path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
@@ -558,3 +558,33 @@ def get_bucketed_sizes(orig_sizes, buckets):
        sizes[mask] = end_val
        start_val = end_val
    return sizes


def _find_extra_valid_paths(dataset_path: str) -> set:
    paths = utils.split_paths(dataset_path)
    all_valid_paths = set()
    for sub_dir in paths:
        contents = PathManager.ls(sub_dir)
        valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
        all_valid_paths |= {os.path.basename(p) for p in valid_paths}
    # Remove .bin, .idx etc
    roots = {os.path.splitext(p)[0] for p in all_valid_paths}
    return roots
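
# Illustrative sketch (not part of this commit): which names the pattern above
# catches, for a hypothetical directory listing.
#
#   >>> import re
#   >>> names = ["valid.bin", "valid1.bin", "valid1.idx", "valid12.idx", "train.bin"]
#   >>> [n for n in names if re.match("valid*[0-9].*", n)]
#   ['valid1.bin', 'valid1.idx', 'valid12.idx']
#
# "valid.bin" has no digit after "valid", so it is never treated as an extra
# subset; after os.path.splitext the roots become {"valid1", "valid12"}.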


def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
    """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
    if (
        train_cfg.dataset.ignore_unused_valid_subsets
        or train_cfg.dataset.combine_valid_subsets
        or train_cfg.dataset.disable_validation
    ):
        return
    other_paths = _find_extra_valid_paths(train_cfg.task.data)
    specified_subsets = train_cfg.dataset.valid_subset.split(",")
    ignored_paths = [p for p in other_paths if p not in specified_subsets]
    if ignored_paths:
        advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
        msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
        raise ValueError(msg)
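
For context, a minimal sketch of calling the new check directly, using a duck-typed config whose attribute names mirror the fields read above; the data path is hypothetical and assumed to exist on disk:

from types import SimpleNamespace

cfg = SimpleNamespace(
    dataset=SimpleNamespace(
        ignore_unused_valid_subsets=False,
        combine_valid_subsets=None,
        disable_validation=False,
        valid_subset="valid",
    ),
    task=SimpleNamespace(data="data-bin/example"),
)
# If data-bin/example also contains valid1.bin/valid1.idx, this raises:
#   ValueError: Valid paths ['valid1'] will be ignored. Set --combine-val to
#   combine them or --ignore-unused-valid-subsets to ignore them.
raise_if_valid_subsets_unintentionally_ignored(cfg)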

fairseq/dataclass/configs.py

@@ -426,6 +426,19 @@ class DatasetConfig(FairseqDataclass):
" (e.g. train, valid, test)"
},
)
combine_valid_subsets: Optional[bool] = field(
default=None,
metadata={
"help": "comma separated list of data subsets to use for validation"
" (e.g. train, valid, test)",
"argparse_alias": "--combine-val",
},
)
ignore_unused_valid_subsets: Optional[bool] = field(
default=False,
metadata={"help": "do not raise error if valid subsets are ignored"},
)
validate_interval: int = field(
default=1, metadata={"help": "validate every N epochs"}
)
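
    # Usage note (sketch): fairseq generates CLI flags from field names, so the
    # two new fields above surface as --combine-valid-subsets (with --combine-val
    # as the argparse_alias) and --ignore-unused-valid-subsets; both are plain
    # boolean switches, as exercised by the tests below.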

fairseq/tasks/language_modeling.py

@@ -201,7 +201,7 @@ class LanguageModelingTask(LegacyFairseqTask):
        """Load a given dataset split.

        Args:
-            split (str): name of the split (e.g., train, valid, test)
+            split (str): name of the split (e.g., train, valid, valid1, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0

fairseq_cli/train.py

@@ -23,7 +23,7 @@ from fairseq import (
    tasks,
    utils,
)
-from fairseq.data import iterators
+from fairseq.data import iterators, data_utils
from fairseq.data.plasma_utils import PlasmaStore
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
@@ -114,8 +114,12 @@ def main(cfg: FairseqConfig) -> None:
    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
-    for valid_sub_split in cfg.dataset.valid_subset.split(","):
-        task.load_dataset(valid_sub_split, combine=False, epoch=1)
+    data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
+    if cfg.dataset.combine_valid_subsets:
+        task.load_dataset("valid", combine=True, epoch=1)
+    else:
+        for valid_sub_split in cfg.dataset.valid_subset.split(","):
+            task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
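
# Sketch of the resulting split loading (hypothetical files valid.bin, valid1.bin):
#   --combine-valid-subsets       -> task.load_dataset("valid", combine=True);
#                                    one "valid" dataset spanning valid and valid1
#   --valid-subset valid,valid1   -> two datasets, "valid" and "valid1", each
#                                    validated and logged separately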

tests/test_valid_subset_checks.py

@@ -0,0 +1,128 @@
import os
import shutil
import tempfile
import unittest

from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored

from .utils import create_dummy_data, preprocess_lm_data, train_language_model


def make_lm_config(
    data_dir,
    extra_flags=None,
    task="language_modeling",
    arch="transformer_lm_gpt2_tiny",
):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            task,
            data_dir,
            "--arch",
            arch,
            "--optimizer",
            "adam",
            "--lr",
            "0.0001",
            "--max-tokens",
            "500",
            "--tokens-per-sample",
            "500",
            "--save-dir",
            data_dir,
            "--max-epoch",
            "1",
        ]
        + (extra_flags or []),
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    return cfg


def write_empty_file(path):
    with open(path, "w"):
        pass
    assert os.path.exists(path)


class TestValidSubsetsErrors(unittest.TestCase):
    """Test various filesystem/flag combinations and ensure that errors are raised as expected"""

    def _test_case(self, paths, extra_flags):
        with tempfile.TemporaryDirectory() as data_dir:
            for p in paths + ["train"]:
                write_empty_file(os.path.join(data_dir, f"{p}.bin"))
            cfg = make_lm_config(data_dir, extra_flags=extra_flags)
            raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_default_raises(self):
        with self.assertRaises(ValueError):
            self._test_case(["valid", "valid1"], [])
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )

    def test_partially_specified_valid_subsets(self):
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )
        # Fix with ignore unused
        self._test_case(
            ["valid", "valid1", "valid2"],
            ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"],
        )

    def test_legal_configs(self):
        self._test_case(["valid"], [])
        self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"])
        self._test_case(["valid", "valid1"], ["--combine-val"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"])
        self._test_case(
            ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"]
        )
        self._test_case(
            ["valid1"], ["--valid-subset", "valid1"]
        )  # valid.bin doesn't need to be ignored.

    def test_disable_validation(self):
        self._test_case([], ["--disable-validation"])
        self._test_case(["valid", "valid1"], ["--disable-validation"])


class TestCombineValidSubsets(unittest.TestCase):
    def _train(self, extra_flags):
        with self.assertLogs() as logs:
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir, num_examples=20)
                preprocess_lm_data(data_dir)
                shutil.copyfile(f"{data_dir}/valid.bin", f"{data_dir}/valid1.bin")
                shutil.copyfile(f"{data_dir}/valid.idx", f"{data_dir}/valid1.idx")
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--max-update", "0", "--log-format", "json"] + extra_flags,
                    run_validation=False,
                )
        return [x.message for x in logs.records]

    def test_combined(self):
        flags = ["--combine-valid-subsets"]
        logs = self._train(flags)
        assert any("valid1" in x for x in logs)  # examples were loaded from valid1
        assert not any("valid1_ppl" in x for x in logs)  # metrics are combined

    def test_subsets(self):
        flags = ["--valid-subset", "valid,valid1"]
        logs = self._train(flags)
        assert any("valid_ppl" in x for x in logs)  # valid metrics logged separately
        assert any("valid1_ppl" in x for x in logs)  # valid1 metrics logged separately