Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2670

Reviewed By: ngoyal2707

Differential Revision: D23982491

Pulled By: myleott

fbshipit-source-id: 629b791d6c05dd67b63dcc2da0313c6799f777f8
This commit is contained in:
Myle Ott 2020-09-29 07:26:18 -07:00 committed by Facebook GitHub Bot
parent a524832d1d
commit caea771afa
2 changed files with 26 additions and 13 deletions

View File

@ -403,9 +403,16 @@ def import_user_module(args):
module_path = fairseq_rel_path
module_parent, module_name = os.path.split(module_path)
if module_name not in sys.modules:
sys.path.insert(0, module_parent)
importlib.import_module(module_name)
if module_name in sys.modules:
module_bak = sys.modules[module_name]
del sys.modules[module_name]
else:
module_bak = None
sys.path.insert(0, module_parent)
importlib.import_module(module_name)
sys.modules['fairseq_user_dir'] = sys.modules[module_name]
if module_bak is not None and module_name != 'fairseq_user_dir':
sys.modules[module_name] = module_bak
def softmax(x, dim: int, onnx_trace: bool = False):

View File

@ -270,16 +270,22 @@ class TestTranslation(unittest.TestCase):
with tempfile.TemporaryDirectory('test_transformer_pointer_generator') as data_dir:
create_dummy_data(data_dir)
preprocess_summarization_data(data_dir)
train_translation_model(data_dir, 'transformer_pointer_generator', [
'--user-dir', 'examples/pointer_generator/src',
'--encoder-layers', '2',
'--decoder-layers', '2',
'--encoder-embed-dim', '8',
'--decoder-embed-dim', '8',
'--alignment-layer', '-1',
'--alignment-heads', '1',
'--source-position-markers', '0',
], run_validation=True)
train_translation_model(
data_dir,
'transformer_pointer_generator',
extra_flags=[
'--user-dir', 'examples/pointer_generator/src',
'--encoder-layers', '2',
'--decoder-layers', '2',
'--encoder-embed-dim', '8',
'--decoder-embed-dim', '8',
'--alignment-layer', '-1',
'--alignment-heads', '1',
'--source-position-markers', '0',
],
run_validation=True,
extra_valid_flags=['--user-dir', 'examples/pointer_generator/src'],
)
generate_main(
data_dir,
extra_flags=['--user-dir', 'examples/pointer_generator/src'],