mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
50158da3a7
Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/pytorch/fairseq/pull/3593 Reviewed By: msbaines Differential Revision: D28992614 Pulled By: dianaml0 fbshipit-source-id: b2dfcab472a65c41536e78600a0e6b3745dc3a08
139 lines
4.8 KiB
Python
139 lines
4.8 KiB
Python
import os
|
|
import shutil
|
|
import tempfile
|
|
import unittest
|
|
|
|
from fairseq import options
|
|
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
|
from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored
|
|
from .utils import create_dummy_data, preprocess_lm_data, train_language_model
|
|
|
|
|
|
def make_lm_config(
    data_dir=None,
    extra_flags=None,
    task="language_modeling",
    arch="transformer_lm_gpt2_tiny",
):
    """Build an OmegaConf training config for a tiny language-model run.

    The config is produced by parsing a fixed set of training flags with
    fairseq's training parser and converting the resulting namespace with
    ``convert_namespace_to_omegaconf``.

    Args:
        data_dir: optional data directory. When given, it is placed right
            after the task name in the argument list (presumably parsed by the
            task as its data path) and is also used as ``--save-dir``. When
            ``None``, neither is passed and the parser defaults apply.
        extra_flags: optional iterable of additional command-line flags,
            appended after the defaults so they can override them.
        task: name of the fairseq task to configure.
        arch: name of the model architecture to configure.

    Returns:
        The OmegaConf config built from the parsed training arguments.
    """
    task_args = [task]
    save_dir_args = []
    if data_dir is not None:
        task_args.append(data_dir)
        # Only pass --save-dir when there is a directory to point it at:
        # putting ``None`` into the argv list makes argparse silently produce
        # a bogus save directory instead of using the default.
        save_dir_args = ["--save-dir", data_dir]

    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            *task_args,
            "--arch",
            arch,
            "--optimizer",
            "adam",
            "--lr",
            "0.0001",
            "--max-tokens",
            "500",
            "--tokens-per-sample",
            "500",
            *save_dir_args,
            "--max-epoch",
            "1",
        ]
        + list(extra_flags or []),
    )
    return convert_namespace_to_omegaconf(train_args)
|
|
|
|
|
|
def write_empty_file(path):
    """Create a zero-byte file at ``path``, truncating any existing content.

    Afterwards, assert that the path exists on disk as a sanity check.
    """
    with open(path, "w") as empty_file:
        # Opening in "w" mode already creates/truncates; nothing to write.
        empty_file.write("")
    file_exists = os.path.exists(path)
    assert file_exists
|
|
|
|
|
|
class TestValidSubsetsErrors(unittest.TestCase):
    """Check when ``raise_if_valid_subsets_unintentionally_ignored`` raises.

    Each case creates a combination of split files on disk and command-line
    flags, then checks whether a ``ValueError`` is raised for it.
    """

    def _test_case(self, paths, extra_flags):
        """Create empty ``{split}.bin`` files and run the ignored-subset check.

        Args:
            paths: split names (e.g. ``"valid"``, ``"valid1"``) to create as
                empty ``.bin`` files; a ``train.bin`` is always added.
            extra_flags: additional command-line flags for the training config.

        Raises:
            ValueError: propagated from
                ``raise_if_valid_subsets_unintentionally_ignored``.
        """
        with tempfile.TemporaryDirectory() as data_dir:
            for split in [*paths, "train"]:
                write_empty_file(os.path.join(data_dir, f"{split}.bin"))
            cfg = make_lm_config(data_dir, extra_flags=extra_flags)
            raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_default_raises(self):
        """Extra valid splits that are neither listed nor handled raise."""
        with self.assertRaises(ValueError):
            self._test_case(["valid", "valid1"], [])
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )

    def test_partially_specified_valid_subsets(self):
        """Listing only some valid splits raises unless unused ones are ignored."""
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )
        # Fix with ignore unused
        self._test_case(
            ["valid", "valid1", "valid2"],
            ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"],
        )

    # Backward-compatible alias for the previous, misnamed method. Without the
    # ``test`` prefix it was never collected by unittest/pytest; the alias is
    # not collected either, so the test does not run twice.
    partially_specified_valid_subsets = test_partially_specified_valid_subsets

    def test_legal_configs(self):
        """Configurations that use, combine, or ignore every split do not raise."""
        self._test_case(["valid"], [])
        self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"])
        self._test_case(["valid", "valid1"], ["--combine-val"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"])
        self._test_case(
            ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"]
        )
        # valid.bin doesn't need to be ignored.
        self._test_case(["valid1"], ["--valid-subset", "valid1"])

    def test_disable_validation(self):
        """With validation disabled, extra valid splits never raise."""
        self._test_case([], ["--disable-validation"])
        self._test_case(["valid", "valid1"], ["--disable-validation"])

    def test_dummy_task(self):
        """The check does not raise for the ``dummy_lm`` task (no data dir)."""
        cfg = make_lm_config(task="dummy_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_masked_dummy_task(self):
        """The check does not raise for ``dummy_masked_lm`` (no data dir)."""
        cfg = make_lm_config(task="dummy_masked_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)
|
|
|
|
|
|
class TestCombineValidSubsets(unittest.TestCase):
    """End-to-end checks of how an extra validation subset shows up in logs.

    Each test trains a tiny LM for zero updates on dummy data where
    ``valid1`` is a byte-for-byte copy of the binarized ``valid`` split, then
    inspects the messages logged during training.
    """

    def _train(self, extra_flags):
        """Run a zero-update training job and return the captured log messages.

        Args:
            extra_flags: list of additional command-line flags for training.

        Returns:
            The message text of every log record captured by ``assertLogs``.
        """
        with self.assertLogs() as logs:
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir, num_examples=20)
                preprocess_lm_data(data_dir)

                # Duplicate the binarized "valid" split as a second subset "valid1".
                for ext in ("bin", "idx"):
                    shutil.copyfile(
                        os.path.join(data_dir, f"valid.{ext}"),
                        os.path.join(data_dir, f"valid1.{ext}"),
                    )
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--max-update", "0", "--log-format", "json"] + extra_flags,
                    run_validation=False,
                )
        # getMessage() rebuilds the text from msg/args, so it does not depend
        # on a formatter having already populated ``record.message``.
        return [record.getMessage() for record in logs.records]

    def test_combined(self):
        """With --combine-valid-subsets, valid1 is used but not reported alone."""
        logs = self._train(["--combine-valid-subsets"])
        # Some log line mentions valid1 (presumably its data being loaded) ...
        self.assertTrue(any("valid1" in msg for msg in logs))
        # ... but no separate valid1 metrics are logged: they are combined.
        self.assertFalse(any("valid1_ppl" in msg for msg in logs))

    def test_subsets(self):
        """With --valid-subset valid,valid1, each subset logs its own metrics."""
        logs = self._train(["--valid-subset", "valid,valid1"])
        self.assertTrue(any("valid_ppl" in msg for msg in logs))
        self.assertTrue(any("valid1_ppl" in msg for msg in logs))
|