fairseq/tests/test_dataclass_utils.py
Pierre Andrews bc1504d4d7 Hierarchical Configs
Summary:
This is a precursor to D29232595

The current behaviour when converting a dataclass to a namespace is that all the fields from all dataclasses (DCs) in the field hierarchy are flattened into the top level. This is also the legacy behaviour with `add_args`.

This makes it cumbersome to build reusable dataclasses, because we need to make sure that each field has a unique name. In the case of Transformer, for instance, we have a Decoder and an Encoder config that share a large part of their fields (embed_dim, layers, etc.). We can build a single dataclass for these fields that can be reused and extended in other implementations. To still be able to have a flat namespace, instead of adding all subfields as-is to the root namespace, we add the name of the field as a prefix to the arg in the namespace.

So:
`model.decoder.embed_dim` becomes `decoder_embed_dim` and `model.encoder.embed_dim` becomes `encoder_embed_dim`.

Reviewed By: myleott, dianaml0

Differential Revision: D29521386

fbshipit-source-id: f4bef036f0eeb620c6d8709ce97f96ae288848ef
2021-07-16 04:56:12 -07:00

88 lines
2.8 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from argparse import ArgumentParser
from dataclasses import dataclass, field
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
@dataclass
class A(FairseqDataclass):
    """Leaf config with two plain fields, reused as a sub-config by B, C and D.

    The tests in this file supply ``data`` as a positional argument and
    ``num_layers`` through the ``--num-layers`` flag.
    """

    data: str = field(
        default="test",
        metadata={"help": "the data input"},
    )
    num_layers: int = field(
        default=200,
        metadata={"help": "more layers is better?"},
    )
@dataclass
class B(FairseqDataclass):
    """Config that nests an ``A`` sub-config next to one field of its own."""

    # default_factory builds a fresh A for every B instance. Using an A()
    # instance as a plain default shares one object between all instances and
    # is rejected with ValueError by dataclasses on Python 3.11+, where
    # unhashable defaults are treated as mutable.
    bar: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
@dataclass
class D(FairseqDataclass):
    """Intermediate config: wraps an ``A`` under ``arch`` and adds ``foo``.

    Used as the ``encoder`` field of ``C`` to exercise two levels of nesting.
    """

    # default_factory builds a fresh A for every D instance. Using an A()
    # instance as a plain default shares one object between all instances and
    # is rejected with ValueError by dataclasses on Python 3.11+, where
    # unhashable defaults are treated as mutable.
    arch: A = field(default_factory=A)
    foo: int = field(default=0, metadata={"help": "not a bar"})
@dataclass
class C(FairseqDataclass):
    """Root config mixing top-level fields with nested sub-configs.

    ``encoder`` nests a ``D`` (which itself nests an ``A``) and ``decoder``
    nests an ``A`` directly, so the same leaf field names appear several times
    in the hierarchy.
    """

    data: str = field(default="test", metadata={"help": "root level data input"})
    # default_factory builds fresh sub-configs for every C instance. Using
    # D() / A() instances as plain defaults shares them between all instances
    # and is rejected with ValueError by dataclasses on Python 3.11+, where
    # unhashable defaults are treated as mutable.
    encoder: D = field(default_factory=D)
    decoder: A = field(default_factory=A)
    lr: int = field(default=0, metadata={"help": "learning rate"})
class TestDataclassUtils(unittest.TestCase):
    """Checks the argparse parsers produced by ``gen_parser_from_dataclass``."""

    # Flat dataclass: the option and the positional value are both parsed.
    def test_argparse_convert_basic(self):
        arg_parser = ArgumentParser()
        gen_parser_from_dataclass(arg_parser, A(), True)

        argv = ["--num-layers", "10", "the/data/path"]
        parsed = arg_parser.parse_args(argv)

        self.assertEqual(parsed.num_layers, 10)
        self.assertEqual(parsed.data, "the/data/path")

    # Nested dataclass without a prefix argument: the fields of the nested A
    # are reachable under their own names, next to B's own ``foo``.
    def test_argparse_recursive(self):
        arg_parser = ArgumentParser()
        gen_parser_from_dataclass(arg_parser, B(), True)

        argv = ["--num-layers", "10", "--foo", "10", "the/data/path"]
        parsed = arg_parser.parse_args(argv)

        self.assertEqual(parsed.num_layers, 10)
        self.assertEqual(parsed.foo, 10)
        self.assertEqual(parsed.data, "the/data/path")

    # Nested dataclass with an (empty) prefix argument: every nested field is
    # addressed by the chain of enclosing field names, e.g. encoder -> arch ->
    # data becomes ``--encoder-arch-data`` / ``encoder_arch_data``.
    def test_argparse_recursive_prefixing(self):
        self.maxDiff = None
        arg_parser = ArgumentParser()
        gen_parser_from_dataclass(arg_parser, C(), True, "")

        argv = [
            "--encoder-arch-data", "ENCODER_ARCH_DATA",
            "--encoder-arch-num-layers", "10",
            "--encoder-foo", "10",
            "--decoder-data", "DECODER_DATA",
            "--decoder-num-layers", "10",
            "--lr", "10",
            "the/data/path",
        ]
        parsed = arg_parser.parse_args(argv)

        # Namespace attribute -> expected parsed value, checked in order.
        expected = {
            "encoder_arch_data": "ENCODER_ARCH_DATA",
            "encoder_arch_num_layers": 10,
            "encoder_foo": 10,
            "decoder_data": "DECODER_DATA",
            "decoder_num_layers": 10,
            "lr": 10,
            "data": "the/data/path",
        }
        for attr_name, want in expected.items():
            self.assertEqual(getattr(parsed, attr_name), want)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()