lint fixes (#2834)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Applied `black` and `isort` to fix failing CI

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙂

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2834

Reviewed By: vedanuj

Differential Revision: D33262876

Pulled By: dianaml0

fbshipit-source-id: 03215c276fcddda9f7c78971bf6ed7c5ac21b2ee
This commit is contained in:
Diana Liskovich 2021-12-29 11:49:57 -08:00 committed by Facebook GitHub Bot
parent 5cd7a21cc1
commit 7fddb9d960
4 changed files with 142 additions and 77 deletions

View File

@ -11,11 +11,11 @@ import os
from typing import Any, Dict, Iterator, List
import torch
from fairseq import utils
from fairseq.data import encoders
from omegaconf import open_dict
from torch import nn
from fairseq import utils
from fairseq.data import encoders
logger = logging.getLogger(__name__)
@ -132,7 +132,9 @@ class GeneratorHubInterface(nn.Module):
batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]
def score(self, sentences: List[str], replace_newline_with_eos: bool = False, **kwargs):
def score(
self, sentences: List[str], replace_newline_with_eos: bool = False, **kwargs
):
if isinstance(sentences, str):
return self.score(
[sentences], replace_newline_with_eos=replace_newline_with_eos, **kwargs

View File

@ -10,10 +10,12 @@ from typing import Optional
import numpy as np
import torch
from omegaconf import II
from fairseq import utils
from fairseq.data import (
ConcatDataset,
AppendTokenDataset,
ConcatDataset,
Dictionary,
IdDataset,
LMContextWindowDataset,
@ -33,8 +35,6 @@ from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import LegacyFairseqTask, register_task
from omegaconf import II
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
@ -125,16 +125,14 @@ class MultilingualLanguageModelingConfig(FairseqDataclass):
# TODO: legacy parameter kept for compatibility
baseline_model: str = field(
default="",
metadata={
"help": "path to the baseline model (default: none)"
},
metadata={"help": "path to the baseline model (default: none)"},
)
lang_to_offline_shard_ratio: str = field(
default="",
metadata={
"help": "absolute path of tsv file location to indicate lang to offline shard ratio.",
}
},
)
# TODO common vars below add to parent
seed: int = II("common.seed")
@ -149,7 +147,9 @@ class MultilingualLanguageModelingConfig(FairseqDataclass):
valid_subset: str = II("common.valid_subset")
@register_task("multilingual_language_modeling", dataclass=MultilingualLanguageModelingConfig)
@register_task(
"multilingual_language_modeling", dataclass=MultilingualLanguageModelingConfig
)
class MultilingualLanguageModelingTask(LegacyFairseqTask):
"""
Train a language model.
@ -216,11 +216,11 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
if args.add_bos_token:
languages, _ = cls._get_langs(args)
logger.info('----------------')
logger.info("----------------")
for lang in languages:
dictionary.add_symbol(lang_token(lang))
logger.info(f'add language token: {lang_token(lang)}')
logger.info('----------------')
logger.info(f"add language token: {lang_token(lang)}")
logger.info("----------------")
logger.info("dictionary: {} types".format(len(dictionary)))
output_dictionary = dictionary
@ -276,9 +276,7 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
smoothed_prob = smoothed_prob / smoothed_prob.sum()
return smoothed_prob
def load_dataset(
self, split: str, epoch=1, combine=False, **kwargs
):
def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
@ -292,21 +290,28 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
lang_to_offline_shard_ratio = {}
assert os.path.exists(
self.args.lang_to_offline_shard_ratio
), "provided offline shard ratio file doesn't exist: {0}".format(self.args.lang_to_offline_shard_ratio)
), "provided offline shard ratio file doesn't exist: {0}".format(
self.args.lang_to_offline_shard_ratio
)
with open(self.args.lang_to_offline_shard_ratio) as fin:
for line in fin:
lang, ratio = line.strip().split('\t')
lang, ratio = line.strip().split("\t")
ratio = float(ratio)
lang_to_offline_shard_ratio[lang] = ratio
logger.info(
"Found offline sharded ratio: %s", lang_to_offline_shard_ratio,
"Found offline sharded ratio: %s",
lang_to_offline_shard_ratio,
)
if split == self.args.train_subset:
logger.info("Training on {0} languages: {1}".format(len(languages), languages))
logger.info(
"Training on {0} languages: {1}".format(len(languages), languages)
)
else:
logger.info("Evaluating on {0} languages: {1}".format(len(languages), languages))
logger.info(
"Evaluating on {0} languages: {1}".format(len(languages), languages)
)
tokens_per_sample = self.args.tokens_per_sample - int(self.args.add_bos_token)
@ -388,15 +393,24 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
)
if split == self.args.train_subset:
dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths))
if lang_to_offline_shard_ratio is not None:
if lang_to_offline_shard_ratio is not None:
dataset_lengths_ratio_multiplier = []
for lang in languages:
assert lang in lang_to_offline_shard_ratio, "Lang: {0} missing in offline shard ratio file: {1}".format(
lang, self.args.lang_to_offline_shard_ratio,
assert (
lang in lang_to_offline_shard_ratio
), "Lang: {0} missing in offline shard ratio file: {1}".format(
lang,
self.args.lang_to_offline_shard_ratio,
)
dataset_lengths_ratio_multiplier.append(lang_to_offline_shard_ratio[lang])
dataset_lengths_ratio_multiplier = np.array(dataset_lengths_ratio_multiplier)
true_dataset_lengths = dataset_lengths * dataset_lengths_ratio_multiplier
dataset_lengths_ratio_multiplier.append(
lang_to_offline_shard_ratio[lang]
)
dataset_lengths_ratio_multiplier = np.array(
dataset_lengths_ratio_multiplier
)
true_dataset_lengths = (
dataset_lengths * dataset_lengths_ratio_multiplier
)
else:
true_dataset_lengths = dataset_lengths
# For train subset, additionally up or down sample languages.
@ -410,7 +424,7 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
},
)
size_ratio = (sample_probs * true_dataset_lengths.sum()) / dataset_lengths
# TODO: add an option for shrinking all size ratios to below 1
# TODO: add an option for shrinking all size ratios to below 1
# if self.args.multilang_sampling_alpha != 1:
# size_ratio /= size_ratio.max()
@ -418,7 +432,7 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
# 0.999999999999999999 -> 1
# 1.000000000000000002 -> 1
for i in range(len(size_ratio)):
size_ratio[i] = round(size_ratio[i], 8)
size_ratio[i] = round(size_ratio[i], 8)
logger.info(
"Up/Down Sampling ratio by language: %s",
@ -479,7 +493,9 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
],
)
def build_dataset_for_inference(self, src_tokens, src_lengths, language="en_XX", **kwargs):
def build_dataset_for_inference(
self, src_tokens, src_lengths, language="en_XX", **kwargs
):
"""
Generate batches for inference. We prepend an eos token to src_tokens
(or bos if `--add-bos-token` is set) and we append a <pad> to target.
@ -518,12 +534,15 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
src_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
pad_length=max_seq_len
pad_length=max_seq_len,
),
"src_lengths": NumelDataset(src_dataset, reduce=False),
},
"target": PadDataset(
tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, pad_length=max_seq_len,
tgt_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
pad_length=max_seq_len,
),
},
sizes=[np.array(src_lengths)],
@ -531,7 +550,13 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
@torch.no_grad()
def inference_step(
self, generator, models, sample, language="en_XX", prefix_tokens=None, constraints=None
self,
generator,
models,
sample,
language="en_XX",
prefix_tokens=None,
constraints=None,
):
# Generation will always be conditioned on bos_token
if getattr(self.args, "add_bos_token", False):
@ -555,7 +580,7 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
return generator.generate(
models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
)
def eval_lm_dataloader(
self,
dataset,
@ -599,4 +624,4 @@ class MultilingualLanguageModelingTask(LegacyFairseqTask):
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.output_dictionary
return self.output_dictionary

View File

@ -20,13 +20,13 @@ from tests.utils import (
generate_main,
preprocess_lm_data,
preprocess_translation_data,
train_translation_model,
train_language_model,
train_translation_model,
)
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestMultiGPU(unittest.TestCase):
@staticmethod
def parse_logs(logfile):
logs = []
@ -44,59 +44,85 @@ class TestMultiGPU(unittest.TestCase):
def train_flags(self, mu):
return [
"--memory-efficient-fp16",
'--update-freq', '1',
'--seed', '1',
"--log-format", "json",
"--max-update", str(mu),
"--tokens-per-sample", "20",
"--batch-size", "2",
'--share-decoder-input-output-embed',
'--optimizer', 'adam',
'--max-valid-steps', '1',
'--pad-to-fixed-length',
'--sample-break-mode', 'none',
"--update-freq",
"1",
"--seed",
"1",
"--log-format",
"json",
"--max-update",
str(mu),
"--tokens-per-sample",
"20",
"--batch-size",
"2",
"--share-decoder-input-output-embed",
"--optimizer",
"adam",
"--max-valid-steps",
"1",
"--pad-to-fixed-length",
"--sample-break-mode",
"none",
]
def _test_resume_multilingual_training(self, extra_clargs, arch="transformer_lm_gpt2_tiny"):
def _test_resume_multilingual_training(
self, extra_clargs, arch="transformer_lm_gpt2_tiny"
):
languages = ["en_XX", "fr_XX", "zh_CN"]
save_interval = 5
mu = 10
flags = self.train_flags(mu) + [
"--save-interval-updates", str(save_interval),
"--log-interval", "1"
] + extra_clargs
flags = (
self.train_flags(mu)
+ ["--save-interval-updates", str(save_interval), "--log-interval", "1"]
+ extra_clargs
)
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory("test_fp16") as data_dir:
log = os.path.join(data_dir, "train.log")
create_dummy_data(data_dir,
num_examples=int(mu * 20 * self.world_size * 1.5), # make sure enough data for max updates
languages=languages)
create_dummy_data(
data_dir,
num_examples=int(
mu * 20 * self.world_size * 1.5
), # make sure enough data for max updates
languages=languages,
)
preprocess_lm_data(data_dir, languages)
train_language_model(
data_dir, arch,
flags + ["--log-file", log],
task="multilingual_language_modeling",
data_dir,
arch,
flags + ["--log-file", log],
task="multilingual_language_modeling",
world_size=self.world_size,
)
log2 = os.path.join(data_dir, "resume.log")
ckpt_name = f"checkpoint_1_{save_interval}.pt"
restore_file = os.path.join(data_dir, ckpt_name)
train_language_model(
data_dir, arch,
flags + ["--log-file", log2, "--restore-file", restore_file, '--no-save'],
task="multilingual_language_modeling",
data_dir,
arch,
flags
+ ["--log-file", log2, "--restore-file", restore_file, "--no-save"],
task="multilingual_language_modeling",
world_size=self.world_size,
)
l1 = self.parse_logs(log)
assert int(l1[-1]['train_num_updates']) == mu, f'The first run did not complete {mu} updates. Add more data'
assert (
int(l1[-1]["train_num_updates"]) == mu
), f"The first run did not complete {mu} updates. Add more data"
l2 = self.parse_logs(log2)
if int(l2[0]["num_updates"]) != save_interval+1:
all_ckpt_files = [x for x in os.listdir(data_dir) if x.endswith('.pt')]
if int(l2[0]["num_updates"]) != save_interval + 1:
all_ckpt_files = [
x for x in os.listdir(data_dir) if x.endswith(".pt")
]
import shutil
shutil.move(data_dir, 'last_failed_resume')
raise AssertionError(f"Likely failed to load {ckpt_name}. {all_ckpt_files} \n LOGS: {l1} \n\n {l2}. ")
shutil.move(data_dir, "last_failed_resume")
raise AssertionError(
f"Likely failed to load {ckpt_name}. {all_ckpt_files} \n LOGS: {l1} \n\n {l2}. "
)
for k in [
"train_loss",
"train_num_updates",
@ -105,7 +131,9 @@ class TestMultiGPU(unittest.TestCase):
]:
from_scratch, resumed = float(l1[-1][k]), float(l2[-1][k])
# This fails without rounding!
assert from_scratch == resumed, f"difference at {k} {from_scratch} != {resumed}"
assert (
from_scratch == resumed
), f"difference at {k} {from_scratch} != {resumed}"
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")

View File

@ -164,7 +164,9 @@ def sequence_generator_setup():
return tgt_dict, w1, w2, src_tokens, src_lengths, model
def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False, languages=None):
def create_dummy_data(
data_dir, num_examples=100, maxlen=20, alignment=False, languages=None
):
def _create_dummy_data(dir, filename):
data = torch.rand(num_examples * maxlen)
data = 97 + torch.floor(26 * data).int()
@ -195,8 +197,15 @@ def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False, la
)
print(ex_str, file=h)
files_to_write = ['train.in', 'train.out', 'valid.in', 'valid.out', 'test.in', 'test.out']
if languages is None: # En only dummy dataset
files_to_write = [
"train.in",
"train.out",
"valid.in",
"valid.out",
"test.in",
"test.out",
]
if languages is None: # En only dummy dataset
for f in files_to_write:
_create_dummy_data(data_dir, f)
else:
@ -232,7 +241,7 @@ def preprocess_lm_data(data_dir, languages=None):
else:
for lang in languages:
lang_dir = os.path.join(data_dir, lang)
assert(os.path.exists(lang_dir))
assert os.path.exists(lang_dir)
preprocess_args = preprocess_parser.parse_args(
[
"--only-source",
@ -248,8 +257,9 @@ def preprocess_lm_data(data_dir, languages=None):
)
preprocess.main(preprocess_args)
shutil.copyfile(
os.path.join(data_dir, languages[0], 'dict.txt'), os.path.join(data_dir, 'dict.txt'))
os.path.join(data_dir, languages[0], "dict.txt"),
os.path.join(data_dir, "dict.txt"),
)
def preprocess_translation_data(data_dir, extra_flags=None):