2019-07-30 17:45:13 +03:00
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
2018-02-28 01:09:42 +03:00
|
|
|
#
|
2019-07-30 17:45:13 +03:00
|
|
|
# This source code is licensed under the MIT license found in the
|
|
|
|
# LICENSE file in the root directory of this source tree.
|
2018-02-28 01:09:42 +03:00
|
|
|
|
2018-09-25 21:02:34 +03:00
|
|
|
import argparse
|
2021-02-18 14:09:14 +03:00
|
|
|
import json
|
2020-06-04 04:49:00 +03:00
|
|
|
import os
|
|
|
|
import random
|
|
|
|
import sys
|
2020-10-19 04:13:29 +03:00
|
|
|
from io import StringIO
|
|
|
|
|
2018-02-28 01:09:42 +03:00
|
|
|
import torch
|
2020-03-10 21:48:30 +03:00
|
|
|
import torch.nn.functional as F
|
2020-06-04 04:49:00 +03:00
|
|
|
from fairseq import options, utils
|
2018-06-12 20:39:41 +03:00
|
|
|
from fairseq.data import Dictionary
|
|
|
|
from fairseq.data.language_pair_dataset import collate
|
2018-02-28 01:09:42 +03:00
|
|
|
from fairseq.models import (
|
|
|
|
FairseqEncoder,
|
2019-05-15 17:09:48 +03:00
|
|
|
FairseqEncoderDecoderModel,
|
2018-02-28 01:09:42 +03:00
|
|
|
FairseqIncrementalDecoder,
|
|
|
|
)
|
2020-03-24 02:06:30 +03:00
|
|
|
from fairseq.models.fairseq_encoder import EncoderOut
|
2020-10-20 10:31:00 +03:00
|
|
|
from fairseq.tasks import LegacyFairseqTask
|
2020-10-19 04:13:29 +03:00
|
|
|
from fairseq_cli import generate, interactive, preprocess, train, validate
|
2021-04-21 16:38:03 +03:00
|
|
|
import fairseq.distributed.utils as distributed_utils
|
|
|
|
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
2018-02-28 01:09:42 +03:00
|
|
|
|
|
|
|
|
2020-10-19 04:13:29 +03:00
|
|
|
def dummy_dictionary(vocab_size, prefix="token_"):
    """Build a :class:`Dictionary` containing ``vocab_size`` synthetic tokens.

    Tokens are named ``prefix + str(i)`` for ``i`` in ``[0, vocab_size)``.
    """
    dictionary = Dictionary()
    for idx in range(vocab_size):
        dictionary.add_symbol(f"{prefix}{idx}")
    dictionary.finalize(padding_factor=1)  # don't add extra padding symbols
    return dictionary
|
|
|
|
|
|
|
|
|
|
|
|
def dummy_dataloader(
    samples, padding_idx=1, eos_idx=2, batch_size=None,
):
    """Wrap *samples* in a :class:`torch.utils.data.DataLoader` and return an
    iterator over collated batches.

    If ``batch_size`` is None, all samples are placed in a single batch.
    """
    if batch_size is None:
        batch_size = len(samples)

    # fill in an "id" field for any sample that is missing one
    for idx, sample in enumerate(samples):
        sample.setdefault("id", idx)

    # create dataloader
    loader = torch.utils.data.DataLoader(
        TestDataset(samples),
        batch_size=batch_size,
        collate_fn=(lambda batch: collate(batch, padding_idx, eos_idx)),
    )
    return iter(loader)
|
|
|
|
|
|
|
|
|
2018-09-25 21:02:34 +03:00
|
|
|
def sequence_generator_setup():
    """Build a deterministic fixture for sequence-generator tests.

    Returns ``(tgt_dict, w1, w2, src_tokens, src_lengths, model)`` where the
    model's decoder emits the scripted per-step probabilities stored in
    ``args.beam_probs`` below, so beam-search outputs are fully predictable.
    The inline comments on each row document the hypothesis prefix it applies
    to and the cumulative score when an <eos> is (or is not) emitted.
    """
    # construct dummy dictionary
    d = dummy_dictionary(vocab_size=2)

    eos = d.eos()
    # the two non-special vocab entries (indices after <eos>/<unk>/specials)
    w1 = 4
    w2 = 5

    # construct source data
    src_tokens = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
    src_lengths = torch.LongTensor([2, 2])

    args = argparse.Namespace()
    unk = 0.0
    # one FloatTensor per decoding step; rows are (sentence, beam) hypotheses,
    # columns are [eos, unk, w1, w2]
    args.beam_probs = [
        # step 0:
        torch.FloatTensor(
            [
                # eos      w1   w2
                # sentence 1:
                [0.0, unk, 0.9, 0.1],  # beam 1
                [0.0, unk, 0.9, 0.1],  # beam 2
                # sentence 2:
                [0.0, unk, 0.7, 0.3],
                [0.0, unk, 0.7, 0.3],
            ]
        ),
        # step 1:
        torch.FloatTensor(
            [
                # eos      w1   w2       prefix
                # sentence 1:
                [1.0, unk, 0.0, 0.0],  # w1: 0.9  (emit: w1 <eos>: 0.9*1.0)
                [0.0, unk, 0.9, 0.1],  # w2: 0.1
                # sentence 2:
                [0.25, unk, 0.35, 0.4],  # w1: 0.7  (don't emit: w1 <eos>: 0.7*0.25)
                [0.00, unk, 0.10, 0.9],  # w2: 0.3
            ]
        ),
        # step 2:
        torch.FloatTensor(
            [
                # eos      w1   w2       prefix
                # sentence 1:
                [0.0, unk, 0.1, 0.9],  # w2 w1: 0.1*0.9
                [
                    0.6,
                    unk,
                    0.2,
                    0.2,
                ],  # w2 w2: 0.1*0.1  (emit: w2 w2 <eos>: 0.1*0.1*0.6)
                # sentence 2:
                [
                    0.60,
                    unk,
                    0.4,
                    0.00,
                ],  # w1 w2: 0.7*0.4  (emit: w1 w2 <eos>: 0.7*0.4*0.6)
                [0.01, unk, 0.0, 0.99],  # w2 w2: 0.3*0.9
            ]
        ),
        # step 3:
        torch.FloatTensor(
            [
                # eos      w1   w2       prefix
                # sentence 1:
                [
                    1.0,
                    unk,
                    0.0,
                    0.0,
                ],  # w2 w1 w2: 0.1*0.9*0.9  (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
                [
                    1.0,
                    unk,
                    0.0,
                    0.0,
                ],  # w2 w1 w1: 0.1*0.9*0.1  (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
                # sentence 2:
                [
                    0.1,
                    unk,
                    0.5,
                    0.4,
                ],  # w2 w2 w2: 0.3*0.9*0.99  (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
                [
                    1.0,
                    unk,
                    0.0,
                    0.0,
                ],  # w1 w2 w1: 0.7*0.4*0.4  (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
            ]
        ),
    ]

    task = TestTranslationTask.setup_task(args, d, d)
    model = task.build_model(args)
    tgt_dict = task.target_dictionary

    return tgt_dict, w1, w2, src_tokens, src_lengths, model
|
|
|
|
|
|
|
|
|
2020-06-04 04:49:00 +03:00
|
|
|
def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False):
    """Write dummy parallel text data into *data_dir*.

    Creates ``{train,valid,test}.{in,out}`` files of ``num_examples`` lines,
    each line being 1..``maxlen`` random lowercase letters separated by
    spaces.  When ``alignment`` is True, also writes ``*.align`` files with
    random "src-tgt" index pairs for each sentence pair.
    """

    def _write_split(filename):
        # random codepoints in 'a'..'z' (97 + [0, 26))
        chars = 97 + torch.floor(26 * torch.rand(num_examples * maxlen)).int()
        with open(os.path.join(data_dir, filename), "w") as out:
            pos = 0
            for _ in range(num_examples):
                length = random.randint(1, maxlen)
                line = " ".join(chr(c) for c in chars[pos : pos + length])
                print(line, file=out)
                pos += length

    def _write_alignments(filename_src, filename_tgt, filename):
        with open(os.path.join(data_dir, filename_src), "r") as src_f, open(
            os.path.join(data_dir, filename_tgt), "r"
        ) as tgt_f, open(os.path.join(data_dir, filename), "w") as out:
            for src, tgt in zip(src_f, tgt_f):
                src_len = len(src.split())
                tgt_len = len(tgt.split())
                avg_len = (src_len + tgt_len) // 2
                num_alignments = random.randint(avg_len // 2, 2 * avg_len)
                src_indices = torch.floor(torch.rand(num_alignments) * src_len).int()
                tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int()
                pairs = [
                    "{}-{}".format(s, t)
                    for s, t in zip(src_indices, tgt_indices)
                ]
                print(" ".join(pairs), file=out)

    for split in ("train", "valid", "test"):
        _write_split(split + ".in")
        _write_split(split + ".out")

    if alignment:
        for split in ("train", "valid", "test"):
            _write_alignments(split + ".in", split + ".out", split + ".align")
|
2020-06-04 04:49:00 +03:00
|
|
|
|
|
|
|
|
|
|
|
def preprocess_lm_data(data_dir):
    """Binarize the monolingual ``*.out`` files in *data_dir* for LM tests."""
    preprocess_parser = options.get_preprocessing_parser()
    argv = ["--only-source"]
    for flag, split in (
        ("--trainpref", "train"),
        ("--validpref", "valid"),
        ("--testpref", "test"),
    ):
        argv += [flag, os.path.join(data_dir, split + ".out")]
    argv += ["--destdir", data_dir]
    preprocess.main(preprocess_parser.parse_args(argv))
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess_translation_data(data_dir, extra_flags=None):
    """Binarize the dummy "in"->"out" parallel data in *data_dir*.

    ``extra_flags`` is appended verbatim to the preprocessing command line.
    """
    preprocess_parser = options.get_preprocessing_parser()
    argv = ["--source-lang", "in", "--target-lang", "out"]
    for flag, split in (
        ("--trainpref", "train"),
        ("--validpref", "valid"),
        ("--testpref", "test"),
    ):
        argv += [flag, os.path.join(data_dir, split)]
    argv += [
        "--thresholdtgt",
        "0",
        "--thresholdsrc",
        "0",
        "--destdir",
        data_dir,
    ]
    preprocess.main(preprocess_parser.parse_args(argv + (extra_flags or [])))
|
|
|
|
|
|
|
|
|
2020-09-25 18:28:10 +03:00
|
|
|
def preprocess_summarization_data(data_dir, extra_flags=None):
    """Binarize the dummy "in"->"out" data with a joined dictionary, as used
    by summarization-style tests.

    Same as :func:`preprocess_translation_data` plus ``--joined-dictionary``.
    """
    preprocess_parser = options.get_preprocessing_parser()
    argv = ["--source-lang", "in", "--target-lang", "out"]
    for flag, split in (
        ("--trainpref", "train"),
        ("--validpref", "valid"),
        ("--testpref", "test"),
    ):
        argv += [flag, os.path.join(data_dir, split)]
    argv += [
        "--thresholdtgt",
        "0",
        "--thresholdsrc",
        "0",
        "--joined-dictionary",
        "--destdir",
        data_dir,
    ]
    preprocess.main(preprocess_parser.parse_args(argv + (extra_flags or [])))
|
|
|
|
|
|
|
|
|
2021-02-18 14:09:14 +03:00
|
|
|
def create_laser_data_and_config_json(data_dir):
    """Create dummy parallel corpora for several language pairs plus a LASER
    training config, written to ``laserconfig.json`` in *data_dir*.

    Target language "en" gets task id 0, "es" gets id 1.  Returns the config
    file handle (already closed when returned — callers presumably use only
    its ``.name``; verify before relying on it being open).
    """
    config_json = {}
    train_entries = []
    src_vocab = None
    tgt_vocab = None

    for src_lang in ("de", "fr", "ru", "tr", "zh"):
        for tgt_lang in ("en", "es"):
            pair_dir = os.path.join(data_dir, f"{src_lang}-{tgt_lang}")
            os.mkdir(pair_dir)
            create_dummy_data(pair_dir)
            preprocess_translation_data(pair_dir, ["--dataset-impl", "cached"])

            # every pair writes its own dicts; the last pair processed wins
            src_vocab = os.path.join(pair_dir, "dict.in.txt")
            tgt_vocab = os.path.join(pair_dir, "dict.out.txt")
            train_entries.append(
                {
                    "id": 0 if tgt_lang == "en" else 1,
                    "src": os.path.join(pair_dir, "train.in-out.in"),
                    "tgt": os.path.join(pair_dir, "train.in-out.out"),
                }
            )

    config_json["src_vocab"] = src_vocab
    config_json["tgt_vocab"] = tgt_vocab
    config_json["train"] = train_entries

    with open(os.path.join(data_dir, "laserconfig.json"), "w") as config_file:
        json.dump(config_json, config_file)

    return config_file
|
|
|
|
|
|
|
|
|
2020-10-19 04:13:29 +03:00
|
|
|
def train_translation_model(
    data_dir,
    arch,
    extra_flags=None,
    task="translation",
    run_validation=False,
    lang_flags=None,
    extra_valid_flags=None,
    world_size=1,
):
    """Train a tiny translation model for one epoch on the dummy data in
    *data_dir*, saving checkpoints there, and optionally run validation.

    ``lang_flags`` defaults to the dummy "in"->"out" language pair; the
    various ``extra_*`` lists are appended verbatim to the command lines.
    """
    if lang_flags is None:
        lang_flags = ["--source-lang", "in", "--target-lang", "out"]

    train_parser = options.get_training_parser()
    train_argv = [
        "--task",
        task,
        data_dir,
        "--save-dir",
        data_dir,
        "--arch",
        arch,
        "--optimizer",
        "nag",
        "--lr",
        "0.05",
        "--max-tokens",
        "500",
        "--max-epoch",
        "1",
        "--no-progress-bar",
        "--distributed-world-size",
        str(world_size),
        "--num-workers",
        "0",
    ]
    train_args = options.parse_args_and_arch(
        train_parser, train_argv + lang_flags + (extra_flags or [])
    )

    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if run_validation:
        # test validation against the last checkpoint
        validate_parser = options.get_validation_parser()
        valid_argv = [
            "--task",
            task,
            data_dir,
            "--path",
            os.path.join(data_dir, "checkpoint_last.pt"),
            "--valid-subset",
            "valid",
            "--max-tokens",
            "500",
            "--no-progress-bar",
            "--num-workers",
            "0",
        ]
        validate_args = options.parse_args_and_arch(
            validate_parser, valid_argv + lang_flags + (extra_valid_flags or [])
        )
        validate.main(validate_args)
|
|
|
|
|
|
|
|
|
2020-11-20 23:40:49 +03:00
|
|
|
def generate_main(data_dir, extra_flags=None, path=None):
    """Run generation from a trained checkpoint, first in batch mode and then
    interactively through a fake stdin.

    ``path`` defaults to ``checkpoint_last.pt`` in *data_dir*; ``extra_flags``
    defaults to ``["--print-alignment"]``.
    """
    if extra_flags is None:
        extra_flags = [
            "--print-alignment",
        ]
    if path is None:
        path = os.path.join(data_dir, "checkpoint_last.pt")
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            "--path",
            path,
            "--beam",
            "3",
            "--batch-size",
            "64",
            "--max-len-b",
            "5",
            "--gen-subset",
            "valid",
            "--no-progress-bar",
            "--num-workers",
            "0",
        ]
        + (extra_flags or []),
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively, reading a dummy sentence from stdin
    generate_args.buffer_size = 0
    generate_args.input = "-"
    generate_args.batch_size = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO("h e l l o\n")
    try:
        interactive.main(generate_args)
    finally:
        # restore stdin even if generation fails, so later tests in the same
        # process don't inherit the exhausted StringIO
        sys.stdin = orig_stdin
|
|
|
|
|
|
|
|
|
2018-02-28 01:09:42 +03:00
|
|
|
class TestDataset(torch.utils.data.Dataset):
    """Trivial in-memory dataset wrapping a list of pre-built samples."""

    def __init__(self, data):
        super().__init__()
        self.data = data
        # placeholder; some fairseq utilities probe for a ``sizes``
        # attribute — it is intentionally left unset (None) here
        self.sizes = None

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
|
|
|
|
|
|
|
|
|
2020-09-10 03:00:56 +03:00
|
|
|
class TestTranslationTask(LegacyFairseqTask):
    """Toy translation task backed by pre-built dictionaries and model."""

    def __init__(self, args, src_dict, tgt_dict, model):
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.model = model

    @classmethod
    def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None):
        # alternate constructor used by the fairseq task machinery
        return cls(args, src_dict, tgt_dict, model)

    def build_model(self, args):
        return TestModel.build_model(args, self)

    @property
    def source_dictionary(self):
        return self.src_dict

    @property
    def target_dictionary(self):
        return self.tgt_dict
|
|
|
|
|
|
|
|
|
2019-05-15 17:09:48 +03:00
|
|
|
class TestModel(FairseqEncoderDecoderModel):
    """Encoder-decoder model pairing TestEncoder with TestIncrementalDecoder."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        return cls(
            TestEncoder(args, task.source_dictionary),
            TestIncrementalDecoder(args, task.target_dictionary),
        )
|
|
|
|
|
|
|
|
|
|
|
|
class TestEncoder(FairseqEncoder):
    """Pass-through encoder: the "encoding" is just the source tokens."""

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        return EncoderOut(
            encoder_out=src_tokens,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        # select along the batch dimension to follow beam reordering
        reordered = encoder_out.encoder_out.index_select(0, new_order)
        return EncoderOut(
            encoder_out=reordered,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )
|
2018-06-21 21:34:29 +03:00
|
|
|
|
2018-02-28 01:09:42 +03:00
|
|
|
|
|
|
|
class TestIncrementalDecoder(FairseqIncrementalDecoder):
    """Decoder that emits scripted probabilities instead of learned ones.

    The per-step distributions come from either ``args.probs`` (a
    bsz x steps x vocab tensor) or ``args.beam_probs`` (a list, one entry
    per step, giving probabilities over [eos, unk, rest-of-vocab]), which
    makes generation tests fully deterministic.
    """

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        # exactly one of these drives forward(); see class docstring
        assert hasattr(args, "beam_probs") or hasattr(args, "probs")
        args.max_decoder_positions = getattr(args, "max_decoder_positions", 100)
        self.args = args

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        if incremental_state is not None:
            # incremental decoding: only the most recent token matters
            prev_output_tokens = prev_output_tokens[:, -1:]
        bbsz = prev_output_tokens.size(0)
        vocab = len(self.dictionary)
        src_len = encoder_out.encoder_out.size(1)
        tgt_len = prev_output_tokens.size(1)

        # determine number of steps
        if incremental_state is not None:
            # cache step number
            step = utils.get_incremental_state(self, incremental_state, "step")
            if step is None:
                step = 0
            utils.set_incremental_state(self, incremental_state, "step", step + 1)
            steps = [step]
        else:
            # non-incremental: produce all target positions at once
            steps = list(range(tgt_len))

        # define output in terms of raw probs
        if hasattr(self.args, "probs"):
            assert (
                self.args.probs.dim() == 3
            ), "expected probs to have size bsz*steps*vocab"
            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
        else:
            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
            for i, step in enumerate(steps):
                # args.beam_probs gives the probability for every vocab element,
                # starting with eos, then unknown, and then the rest of the vocab
                if step < len(self.args.beam_probs):
                    probs[:, i, self.dictionary.eos() :] = self.args.beam_probs[step]
                else:
                    # past the scripted horizon: force <eos> with probability 1
                    probs[:, i, self.dictionary.eos()] = 1.0

        # random attention
        attn = torch.rand(bbsz, tgt_len, src_len)

        dev = prev_output_tokens.device
        return probs.to(dev), {"attn": [attn.to(dev)]}

    def get_normalized_probs(self, net_output, log_probs, _):
        # the decoder returns probabilities directly
        probs = net_output[0]
        if log_probs:
            return probs.log()
        else:
            return probs

    def max_positions(self):
        return self.args.max_decoder_positions
|
2020-03-10 21:48:30 +03:00
|
|
|
|
|
|
|
|
|
|
|
class TestReshapingEncoder(FairseqEncoder):
    """Encoder that reshapes the source tokens to shape (batch, -1, 2),
    zero-padding the time dimension to an even length when necessary."""

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        b_sz, t_sz = src_tokens.shape
        x = src_tokens
        if t_sz % 2:
            # odd sequence length: pad one trailing column so the
            # (batch, -1, 2) reshape below is valid
            x = F.pad(x, (0, 1))

        return EncoderOut(
            encoder_out=x.view(b_sz, -1, 2),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        # select along the batch dimension to follow beam reordering
        reordered = encoder_out.encoder_out.index_select(0, new_order)
        return EncoderOut(
            encoder_out=reordered,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )
|
2020-03-10 21:48:30 +03:00
|
|
|
|
|
|
|
|
|
|
|
class TestReshapingModel(FairseqEncoderDecoderModel):
    """Encoder-decoder model using the reshaping test encoder."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        return cls(
            TestReshapingEncoder(args, task.source_dictionary),
            TestIncrementalDecoder(args, task.target_dictionary),
        )
|
2020-05-10 16:11:24 +03:00
|
|
|
|
|
|
|
|
|
|
|
class TestAdditionalInputEncoder(FairseqEncoder):
    """Pass-through encoder that requires a non-None ``fancy_other_input``
    keyword argument, used to test extra-input plumbing."""

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        # the whole point of this encoder: verify the extra input arrived
        assert "fancy_other_input" in kwargs
        assert kwargs["fancy_other_input"] is not None
        return EncoderOut(
            encoder_out=src_tokens,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        # select along the batch dimension to follow beam reordering
        reordered = encoder_out.encoder_out.index_select(0, new_order)
        return EncoderOut(
            encoder_out=reordered,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )
|
|
|
|
|
|
|
|
|
|
|
|
class TestAdditionalInputModel(FairseqEncoderDecoderModel):
    """Encoder-decoder model that forwards extra keyword arguments to both
    the encoder and the decoder."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        return cls(
            TestAdditionalInputEncoder(args, task.source_dictionary),
            TestIncrementalDecoder(args, task.target_dictionary),
        )

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        # propagate **kwargs so the encoder can assert on fancy_other_input
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
|
2021-04-21 16:38:03 +03:00
|
|
|
|
|
|
|
|
|
|
|
def train_language_model(
    data_dir,
    arch,
    extra_flags=None,
    run_validation=False,
    extra_valid_flags=None,
    task="language_modeling",
    world_size=1,
):
    """Train a tiny language model for one epoch on the dummy data in
    *data_dir*, saving checkpoints there, and optionally run validation."""
    train_parser = options.get_training_parser()
    train_argv = [
        "--task",
        task,
        data_dir,
        "--arch",
        arch,
        "--optimizer",
        "adam",
        "--lr",
        "0.0001",
        "--max-tokens",
        "500",
        "--tokens-per-sample",
        "500",
        "--save-dir",
        data_dir,
        "--max-epoch",
        "1",
        "--no-progress-bar",
        "--distributed-world-size",
        str(world_size),
        "--ddp-backend",
        "no_c10d",
        "--num-workers",
        "0",
    ]
    train_args = options.parse_args_and_arch(
        train_parser, train_argv + (extra_flags or [])
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if run_validation:
        # test validation against the last checkpoint
        validate_parser = options.get_validation_parser()
        valid_argv = [
            "--task",
            task,
            data_dir,
            "--path",
            os.path.join(data_dir, "checkpoint_last.pt"),
            "--valid-subset",
            "valid",
            "--max-tokens",
            "500",
            "--no-progress-bar",
            "--num-workers",
            "0",
        ]
        validate_args = options.parse_args_and_arch(
            validate_parser, valid_argv + (extra_valid_flags or [])
        )
        validate.main(validate_args)
|