mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
0dfd6b6240
Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2678 Reviewed By: Mortimerp9 Differential Revision: D32653381 Pulled By: dianaml0 fbshipit-source-id: 2810d14867cd7d64f4d340740e2b590b82de47fe
88 lines
2.8 KiB
Python
88 lines
2.8 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import unittest
|
|
from argparse import ArgumentParser
|
|
from dataclasses import dataclass, field
|
|
|
|
from fairseq.dataclass import FairseqDataclass
|
|
from fairseq.dataclass.utils import gen_parser_from_dataclass
|
|
|
|
|
|
@dataclass
class A(FairseqDataclass):
    """Leaf config holding one string option and one integer option.

    The tests in this module pass ``data`` to the generated parser as a bare
    positional value, while ``num_layers`` is given as ``--num-layers``.
    """

    data: str = field(
        metadata={"help": "the data input"},
        default="test",
    )
    num_layers: int = field(
        metadata={"help": "more layers is better?"},
        default=200,
    )
|
|
|
|
|
|
@dataclass
class B(FairseqDataclass):
    """Config that nests an ``A`` under ``bar`` next to a scalar ``foo``.

    Used to check that the options of a nested dataclass field end up in the
    parent parser when no prefix is requested.
    """

    # default_factory (instead of default=A()) gives every B its own A:
    # a single shared default instance would leak mutations between configs,
    # and Python 3.11+ refuses unhashable dataclass instances as plain field
    # defaults ("ValueError: mutable default ... use default_factory").
    bar: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
|
|
|
|
|
|
@dataclass
class D(FairseqDataclass):
    """Config that nests an ``A`` under ``arch`` next to a scalar ``foo``.

    When parsed with a prefix (as the encoder of ``C``), the nested options
    are exposed with an ``arch`` component, e.g. ``--encoder-arch-num-layers``.
    """

    # default_factory (instead of default=A()) gives every D its own A and
    # keeps the class definable on Python 3.11+, which rejects unhashable
    # dataclass instances as plain field defaults.
    arch: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
|
|
|
|
|
|
@dataclass
class C(FairseqDataclass):
    """Root config with two levels of nesting, used for prefix tests.

    Fields:
        data: root-level positional data input.
        encoder: a ``D`` (which itself nests an ``A`` under ``arch``).
        decoder: an ``A``.
        lr: root-level integer option.
    """

    data: str = field(default="test", metadata={"help": "root level data input"})
    # default_factory (instead of default=D()/A()) gives every C fresh nested
    # configs and avoids Python 3.11+'s "mutable default ... use
    # default_factory" ValueError for unhashable dataclass instances.
    encoder: D = field(default_factory=D)
    decoder: A = field(default_factory=A)
    lr: int = field(default=0, metadata={"help": "learning rate"})
|
|
|
|
|
|
class TestDataclassUtils(unittest.TestCase):
    """Tests for turning fairseq dataclasses into argparse parsers."""

    def _parse_with(self, config, argv, *extra):
        """Build a fresh parser from ``config`` and parse ``argv`` with it.

        ``True`` is forwarded as the third positional argument of
        ``gen_parser_from_dataclass``; ``extra`` carries any further
        positional arguments (the prefix, in the prefixing test).
        """
        parser = ArgumentParser()
        gen_parser_from_dataclass(parser, config, True, *extra)
        return parser.parse_args(argv)

    def _assert_attrs(self, namespace, expected):
        """Assert every ``attribute -> value`` pair of ``expected``, in order."""
        for attribute, value in expected.items():
            self.assertEqual(getattr(namespace, attribute), value)

    def test_argparse_convert_basic(self):
        namespace = self._parse_with(A(), ["--num-layers", "10", "the/data/path"])
        self._assert_attrs(namespace, {"num_layers": 10, "data": "the/data/path"})

    def test_argparse_recursive(self):
        # Options of the nested A (under B.bar) land directly on the namespace.
        argv = ["--num-layers", "10", "--foo", "10", "the/data/path"]
        namespace = self._parse_with(B(), argv)
        self._assert_attrs(
            namespace,
            {"num_layers": 10, "foo": 10, "data": "the/data/path"},
        )

    def test_argparse_recursive_prefixing(self):
        self.maxDiff = None
        # An empty prefix makes nested fields carry their field path
        # (encoder -> arch -> data becomes --encoder-arch-data).
        argv = [
            "--encoder-arch-data", "ENCODER_ARCH_DATA",
            "--encoder-arch-num-layers", "10",
            "--encoder-foo", "10",
            "--decoder-data", "DECODER_DATA",
            "--decoder-num-layers", "10",
            "--lr", "10",
            "the/data/path",
        ]
        namespace = self._parse_with(C(), argv, "")
        self._assert_attrs(
            namespace,
            {
                "encoder_arch_data": "ENCODER_ARCH_DATA",
                "encoder_arch_num_layers": 10,
                "encoder_foo": 10,
                "decoder_data": "DECODER_DATA",
                "decoder_num_layers": 10,
                "lr": 10,
                "data": "the/data/path",
            },
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the test cases in this module when the file is executed directly.
    unittest.main()
|