lint fixes (#2834)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Applied `black` and `isort` to fix failing CI

## PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2834

Reviewed By: vedanuj

Differential Revision: D33262876

Pulled By: dianaml0

fbshipit-source-id: 03215c276fcddda9f7c78971bf6ed7c5ac21b2ee
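To reproduce the fix locally, a minimal sketch (assuming `black` and `isort` are installed and the repository's default configuration applies; this helper is not part of the PR):

```python
# Run the same formatters CI expects, in place, over the whole repository.
import subprocess

for tool in ("black", "isort"):
    subprocess.run([tool, "."], check=True)  # check=True fails loudly on errors
```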
parent 5cd7a21cc1
commit 7fddb9d960
@@ -11,11 +11,11 @@ import os
 from typing import Any, Dict, Iterator, List
 
 import torch
-from fairseq import utils
-from fairseq.data import encoders
 from omegaconf import open_dict
 from torch import nn
 
+from fairseq import utils
+from fairseq.data import encoders
 
 logger = logging.getLogger(__name__)
 
@@ -132,7 +132,9 @@ class GeneratorHubInterface(nn.Module):
         batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
         return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]
 
-    def score(self, sentences: List[str], replace_newline_with_eos: bool = False, **kwargs):
+    def score(
+        self, sentences: List[str], replace_newline_with_eos: bool = False, **kwargs
+    ):
         if isinstance(sentences, str):
             return self.score(
                 [sentences], replace_newline_with_eos=replace_newline_with_eos, **kwargs
@@ -10,10 +10,12 @@ from typing import Optional
 
 import numpy as np
 import torch
+from omegaconf import II
+
 from fairseq import utils
 from fairseq.data import (
-    ConcatDataset,
     AppendTokenDataset,
+    ConcatDataset,
     Dictionary,
     IdDataset,
     LMContextWindowDataset,
@@ -33,8 +35,6 @@ from fairseq.data.indexed_dataset import get_available_dataset_impl
 from fairseq.data.shorten_dataset import maybe_shorten_dataset
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.tasks import LegacyFairseqTask, register_task
-from omegaconf import II
-
 
 SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
 SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
@@ -125,16 +125,14 @@ class MultilingualLanguageModelingConfig(FairseqDataclass):
     # TODO: legacy parameter kept for compatibility
     baseline_model: str = field(
         default="",
-        metadata={
-            "help": "path to the baseline model (default: none)"
-        },
+        metadata={"help": "path to the baseline model (default: none)"},
     )
 
     lang_to_offline_shard_ratio: str = field(
         default="",
         metadata={
             "help": "absolute path of tsv file location to indicate lang to offline shard ratio.",
-        }
+        },
     )
     # TODO common vars below add to parent
     seed: int = II("common.seed")
@@ -149,7 +147,9 @@ class MultilingualLanguageModelingConfig(FairseqDataclass):
     valid_subset: str = II("common.valid_subset")
 
 
-@register_task("multilingual_language_modeling", dataclass=MultilingualLanguageModelingConfig)
+@register_task(
+    "multilingual_language_modeling", dataclass=MultilingualLanguageModelingConfig
+)
 class MultilingualLanguageModelingTask(LegacyFairseqTask):
     """
     Train a language model.
@@ -216,11 +216,11 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
         dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
         if args.add_bos_token:
             languages, _ = cls._get_langs(args)
-            logger.info('----------------')
+            logger.info("----------------")
             for lang in languages:
                 dictionary.add_symbol(lang_token(lang))
-                logger.info(f'add language token: {lang_token(lang)}')
-            logger.info('----------------')
+                logger.info(f"add language token: {lang_token(lang)}")
+            logger.info("----------------")
 
         logger.info("dictionary: {} types".format(len(dictionary)))
         output_dictionary = dictionary
@@ -276,9 +276,7 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
         smoothed_prob = smoothed_prob / smoothed_prob.sum()
         return smoothed_prob
 
-    def load_dataset(
-        self, split: str, epoch=1, combine=False, **kwargs
-    ):
+    def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
         """Load a given dataset split.
 
         Args:
@@ -292,21 +290,28 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
         lang_to_offline_shard_ratio = {}
         assert os.path.exists(
             self.args.lang_to_offline_shard_ratio
-        ), "provided offline shard ratio file doesn't exist: {0}".format(self.args.lang_to_offline_shard_ratio)
+        ), "provided offline shard ratio file doesn't exist: {0}".format(
+            self.args.lang_to_offline_shard_ratio
+        )
         with open(self.args.lang_to_offline_shard_ratio) as fin:
             for line in fin:
-                lang, ratio = line.strip().split('\t')
+                lang, ratio = line.strip().split("\t")
                 ratio = float(ratio)
                 lang_to_offline_shard_ratio[lang] = ratio
 
         logger.info(
-            "Found offline sharded ratio: %s", lang_to_offline_shard_ratio,
+            "Found offline sharded ratio: %s",
+            lang_to_offline_shard_ratio,
         )
 
         if split == self.args.train_subset:
-            logger.info("Training on {0} languages: {1}".format(len(languages), languages))
+            logger.info(
+                "Training on {0} languages: {1}".format(len(languages), languages)
+            )
         else:
-            logger.info("Evaluating on {0} languages: {1}".format(len(languages), languages))
+            logger.info(
+                "Evaluating on {0} languages: {1}".format(len(languages), languages)
+            )
 
         tokens_per_sample = self.args.tokens_per_sample - int(self.args.add_bos_token)
 
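For context, the loop above reads `--lang-to-offline-shard-ratio` as a two-column tab-separated file mapping each language code to a float ratio; a hypothetical example of its contents (values made up):

```
en_XX	2.0
fr_XX	1.0
zh_CN	0.5
```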
@@ -388,15 +393,24 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
             )
         if split == self.args.train_subset:
             dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths))
             if lang_to_offline_shard_ratio is not None:
                 dataset_lengths_ratio_multiplier = []
                 for lang in languages:
-                    assert lang in lang_to_offline_shard_ratio, "Lang: {0} missing in offline shard ratio file: {1}".format(
-                        lang, self.args.lang_to_offline_shard_ratio,
+                    assert (
+                        lang in lang_to_offline_shard_ratio
+                    ), "Lang: {0} missing in offline shard ratio file: {1}".format(
+                        lang,
+                        self.args.lang_to_offline_shard_ratio,
                     )
-                    dataset_lengths_ratio_multiplier.append(lang_to_offline_shard_ratio[lang])
-                dataset_lengths_ratio_multiplier = np.array(dataset_lengths_ratio_multiplier)
-                true_dataset_lengths = dataset_lengths * dataset_lengths_ratio_multiplier
+                    dataset_lengths_ratio_multiplier.append(
+                        lang_to_offline_shard_ratio[lang]
+                    )
+                dataset_lengths_ratio_multiplier = np.array(
+                    dataset_lengths_ratio_multiplier
+                )
+                true_dataset_lengths = (
+                    dataset_lengths * dataset_lengths_ratio_multiplier
+                )
             else:
                 true_dataset_lengths = dataset_lengths
             # For train subset, additionally up or down sample languages.
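A note on the assert reshaping in the hunk above: `black` cannot break a long bare `assert condition, message` statement, so it parenthesizes the condition and splits inside the parentheses. A self-contained sketch of the resulting shape (names and file name made up):

```python
lang_to_offline_shard_ratio = {"en_XX": 2.0}
lang = "en_XX"
# Black-style wrapping: parenthesized condition, message after the comma.
assert (
    lang in lang_to_offline_shard_ratio
), "Lang: {0} missing in offline shard ratio file: {1}".format(lang, "ratios.tsv")
```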
@@ -410,7 +424,7 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
                 },
             )
             size_ratio = (sample_probs * true_dataset_lengths.sum()) / dataset_lengths
-            # TODO: add an option for shrinking all size ratios to below 1
+            # TODO: add an option for shrinking all size ratios to below 1
             # if self.args.multilang_sampling_alpha != 1:
             #     size_ratio /= size_ratio.max()
 
@@ -418,7 +432,7 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
             # 0.999999999999999999 -> 1
             # 1.000000000000000002 -> 1
             for i in range(len(size_ratio)):
-                size_ratio[i] = round(size_ratio[i], 8)
+                size_ratio[i] = round(size_ratio[i], 8)
 
             logger.info(
                 "Up/Down Sampling ratio by language: %s",
@@ -479,7 +493,9 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
             ],
         )
 
-    def build_dataset_for_inference(self, src_tokens, src_lengths, language="en_XX", **kwargs):
+    def build_dataset_for_inference(
+        self, src_tokens, src_lengths, language="en_XX", **kwargs
+    ):
         """
         Generate batches for inference. We prepend an eos token to src_tokens
         (or bos if `--add-bos-token` is set) and we append a <pad> to target.
@@ -518,12 +534,15 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
                         src_dataset,
                         pad_idx=self.source_dictionary.pad(),
                         left_pad=False,
-                        pad_length=max_seq_len
+                        pad_length=max_seq_len,
                     ),
                     "src_lengths": NumelDataset(src_dataset, reduce=False),
                 },
                 "target": PadDataset(
-                    tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, pad_length=max_seq_len,
+                    tgt_dataset,
+                    pad_idx=self.source_dictionary.pad(),
+                    left_pad=False,
+                    pad_length=max_seq_len,
                 ),
             },
             sizes=[np.array(src_lengths)],
@@ -531,7 +550,13 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
 
     @torch.no_grad()
     def inference_step(
-        self, generator, models, sample, language="en_XX", prefix_tokens=None, constraints=None
+        self,
+        generator,
+        models,
+        sample,
+        language="en_XX",
+        prefix_tokens=None,
+        constraints=None,
     ):
         # Generation will always be conditioned on bos_token
         if getattr(self.args, "add_bos_token", False):
@@ -555,7 +580,7 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
         return generator.generate(
             models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
         )
-
+
     def eval_lm_dataloader(
         self,
         dataset,
@@ -599,4 +624,4 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
     def target_dictionary(self):
         """Return the :class:`~fairseq.data.Dictionary` for the language
         model."""
-        return self.output_dictionary
+        return self.output_dictionary
@@ -20,13 +20,13 @@ from tests.utils import (
     generate_main,
     preprocess_lm_data,
     preprocess_translation_data,
-    train_translation_model,
     train_language_model,
+    train_translation_model,
 )
 
 
 @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
 class TestMultiGPU(unittest.TestCase):
 
     @staticmethod
     def parse_logs(logfile):
         logs = []
@@ -44,59 +44,85 @@ class TestMultiGPU(unittest.TestCase):
     def train_flags(self, mu):
         return [
             "--memory-efficient-fp16",
-            '--update-freq', '1',
-            '--seed', '1',
-            "--log-format", "json",
-            "--max-update", str(mu),
-            "--tokens-per-sample", "20",
-            "--batch-size", "2",
-            '--share-decoder-input-output-embed',
-            '--optimizer', 'adam',
-            '--max-valid-steps', '1',
-            '--pad-to-fixed-length',
-            '--sample-break-mode', 'none',
+            "--update-freq",
+            "1",
+            "--seed",
+            "1",
+            "--log-format",
+            "json",
+            "--max-update",
+            str(mu),
+            "--tokens-per-sample",
+            "20",
+            "--batch-size",
+            "2",
+            "--share-decoder-input-output-embed",
+            "--optimizer",
+            "adam",
+            "--max-valid-steps",
+            "1",
+            "--pad-to-fixed-length",
+            "--sample-break-mode",
+            "none",
         ]
 
-    def _test_resume_multilingual_training(self, extra_clargs, arch="transformer_lm_gpt2_tiny"):
+    def _test_resume_multilingual_training(
+        self, extra_clargs, arch="transformer_lm_gpt2_tiny"
+    ):
         languages = ["en_XX", "fr_XX", "zh_CN"]
         save_interval = 5
         mu = 10
-        flags = self.train_flags(mu) + [
-            "--save-interval-updates", str(save_interval),
-            "--log-interval", "1"
-        ] + extra_clargs
+        flags = (
+            self.train_flags(mu)
+            + ["--save-interval-updates", str(save_interval), "--log-interval", "1"]
+            + extra_clargs
+        )
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory("test_fp16") as data_dir:
                 log = os.path.join(data_dir, "train.log")
-                create_dummy_data(data_dir,
-                    num_examples=int(mu * 20 * self.world_size * 1.5), # make sure enough data for max updates
-                    languages=languages)
+                create_dummy_data(
+                    data_dir,
+                    num_examples=int(
+                        mu * 20 * self.world_size * 1.5
+                    ),  # make sure enough data for max updates
+                    languages=languages,
+                )
                 preprocess_lm_data(data_dir, languages)
                 train_language_model(
-                    data_dir, arch,
-                    flags + ["--log-file", log],
-                    task="multilingual_language_modeling",
+                    data_dir,
+                    arch,
+                    flags + ["--log-file", log],
+                    task="multilingual_language_modeling",
                     world_size=self.world_size,
                 )
                 log2 = os.path.join(data_dir, "resume.log")
                 ckpt_name = f"checkpoint_1_{save_interval}.pt"
                 restore_file = os.path.join(data_dir, ckpt_name)
                 train_language_model(
-                    data_dir, arch,
-                    flags + ["--log-file", log2, "--restore-file", restore_file, '--no-save'],
-                    task="multilingual_language_modeling",
+                    data_dir,
+                    arch,
+                    flags
+                    + ["--log-file", log2, "--restore-file", restore_file, "--no-save"],
+                    task="multilingual_language_modeling",
                     world_size=self.world_size,
                 )
 
                 l1 = self.parse_logs(log)
-                assert int(l1[-1]['train_num_updates']) == mu, f'The first run did not complete {mu} updates. Add more data'
+                assert (
+                    int(l1[-1]["train_num_updates"]) == mu
+                ), f"The first run did not complete {mu} updates. Add more data"
                 l2 = self.parse_logs(log2)
 
-                if int(l2[0]["num_updates"]) != save_interval+1:
-                    all_ckpt_files = [x for x in os.listdir(data_dir) if x.endswith('.pt')]
+                if int(l2[0]["num_updates"]) != save_interval + 1:
+                    all_ckpt_files = [
+                        x for x in os.listdir(data_dir) if x.endswith(".pt")
+                    ]
                     import shutil
-                    shutil.move(data_dir, 'last_failed_resume')
-                    raise AssertionError(f"Likely failed to load {ckpt_name}. {all_ckpt_files} \n LOGS: {l1} \n\n {l2}. ")
+
+                    shutil.move(data_dir, "last_failed_resume")
+                    raise AssertionError(
+                        f"Likely failed to load {ckpt_name}. {all_ckpt_files} \n LOGS: {l1} \n\n {l2}. "
+                    )
                 for k in [
                     "train_loss",
                     "train_num_updates",
@@ -105,7 +131,9 @@ class TestMultiGPU(unittest.TestCase):
         ]:
             from_scratch, resumed = float(l1[-1][k]), float(l2[-1][k])
             # This fails without rounding!
-            assert from_scratch == resumed, f"difference at {k} {from_scratch} != {resumed}"
+            assert (
+                from_scratch == resumed
+            ), f"difference at {k} {from_scratch} != {resumed}"
 
 
 @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
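The one-flag-per-line expansion of `train_flags` above is standard `black` output: once a collection exceeds the line length, or ends in a "magic" trailing comma, every element is placed on its own line. A minimal sketch (not from this diff):

```python
# Fits on one line and has no magic trailing comma: black leaves it alone.
compact = ["--seed", "1"]

# Magic trailing comma present: black explodes it, one element per line.
exploded = [
    "--update-freq",
    "1",
]
```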
@@ -164,7 +164,9 @@ def sequence_generator_setup():
     return tgt_dict, w1, w2, src_tokens, src_lengths, model
 
 
-def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False, languages=None):
+def create_dummy_data(
+    data_dir, num_examples=100, maxlen=20, alignment=False, languages=None
+):
     def _create_dummy_data(dir, filename):
         data = torch.rand(num_examples * maxlen)
         data = 97 + torch.floor(26 * data).int()
@@ -195,8 +197,15 @@ def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False, languages=None):
             )
             print(ex_str, file=h)
 
-    files_to_write = ['train.in', 'train.out', 'valid.in', 'valid.out', 'test.in', 'test.out']
-    if languages is None: # En only dummy dataset
+    files_to_write = [
+        "train.in",
+        "train.out",
+        "valid.in",
+        "valid.out",
+        "test.in",
+        "test.out",
+    ]
+    if languages is None:  # En only dummy dataset
         for f in files_to_write:
             _create_dummy_data(data_dir, f)
     else:
@@ -232,7 +241,7 @@ def preprocess_lm_data(data_dir, languages=None):
     else:
         for lang in languages:
             lang_dir = os.path.join(data_dir, lang)
-            assert(os.path.exists(lang_dir))
+            assert os.path.exists(lang_dir)
             preprocess_args = preprocess_parser.parse_args(
                 [
                     "--only-source",
@@ -248,8 +257,9 @@ def preprocess_lm_data(data_dir, languages=None):
             )
             preprocess.main(preprocess_args)
         shutil.copyfile(
-            os.path.join(data_dir, languages[0], 'dict.txt'), os.path.join(data_dir, 'dict.txt'))
+            os.path.join(data_dir, languages[0], "dict.txt"),
+            os.path.join(data_dir, "dict.txt"),
+        )
 
 
 def preprocess_translation_data(data_dir, extra_flags=None):