Small changes to make tests more reliable (#1572)

Summary:
After this, `python setup.py test` should be more reliable (including when multiple GPUs are present)

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1572

Reviewed By: alexeib

Differential Revision: D25984113

Pulled By: myleott

fbshipit-source-id: 7fef27ae90c079c07f592ed9fb350ccf8b56d23d
Authored by Myle Ott on 2021-01-21 07:32:08 -08:00, committed by Facebook GitHub Bot
parent 9a1c49706b
commit cfbf0dddbc
10 changed files with 42 additions and 38 deletions

View File

@@ -136,8 +136,8 @@ pip install --editable ./
 # on MacOS:
 # CFLAGS="-stdlib=libc++" pip install --editable ./
 
-# to install the latest stable release (0.10.1)
-# pip install fairseq==0.10.1
+# to install the latest stable release (0.10.x)
+# pip install fairseq
 ```
 
 * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:

View File

@@ -24,6 +24,9 @@ cpdef list batch_by_size_vec(
     int64_t max_sentences,
     int32_t bsz_mult,
 ):
+    if indices.shape[0] == 0:
+        return []
+
     assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
         f"Sentences lengths should not exceed max_tokens={max_tokens}"
     )
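
For context: `np.max` on a zero-length vector raises a `ValueError`, so without the early return an empty index array would crash inside the assertion instead of simply yielding no batches. A minimal reproduction of that failure mode:

```python
import numpy as np

# np.max over an empty vector has no identity value and raises, which is
# presumably why empty `indices` must short-circuit before the assert.
num_tokens_vec = np.array([], dtype=np.int64)
try:
    np.max(num_tokens_vec)
except ValueError as err:
    print(err)  # "zero-size array to reduction operation maximum ..."
```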

View File

@@ -301,10 +301,6 @@ def distributed_main(i, main, cfg: FairseqConfig, kwargs):
     main(cfg, **kwargs)
 
-    # make sure checkpoints finish saving
-    if torch.distributed.is_initialized():
-        torch.distributed.barrier()
-
 
 def call_main(cfg: FairseqConfig, main, **kwargs):
     if cfg.distributed_training.distributed_init_method is None:
@@ -323,6 +319,7 @@ def call_main(cfg: FairseqConfig, main, **kwargs):
                     torch.cuda.device_count(),
                     cfg.distributed_training.distributed_world_size,
                 ),
+                join=True,
             )
         else:
             distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
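
For reference, `torch.multiprocessing.spawn` with `join=True` (also its default) blocks the parent until every worker process exits and re-raises a worker's exception in the parent, so failures inside spawned trainers surface in the calling test instead of being lost. A minimal sketch with a hypothetical worker standing in for `distributed_main`:

```python
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    # hypothetical stand-in for distributed_main; spawn passes the worker
    # index `rank` as the first positional argument automatically
    print(f"worker {rank} of {world_size} running")


if __name__ == "__main__":
    # join=True blocks here until both workers finish; if one of them
    # raises, the exception is re-raised in this parent process.
    mp.spawn(fn=_worker, args=(2,), nprocs=2, join=True)
```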

View File

@@ -173,7 +173,7 @@ class RobertaHubInterface(nn.Module):
             add_if_not_exist=False,
         )
-        masked_index = (tokens == self.task.mask_idx).nonzero()
+        masked_index = (tokens == self.task.mask_idx).nonzero(as_tuple=False)
         if tokens.dim() == 1:
             tokens = tokens.unsqueeze(0)
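
For reference, passing `as_tuple=False` keeps the old return shape (an `(N, dims)` index tensor) while silencing the deprecation warning that newer PyTorch versions emit for a bare `.nonzero()` call. A small example with made-up token ids:

```python
import torch

tokens = torch.tensor([5, 3, 7, 3])
mask_idx = 3  # hypothetical mask token id

# one row per matching position, same result the bare .nonzero() gave
masked_index = (tokens == mask_idx).nonzero(as_tuple=False)
print(masked_index.tolist())  # [[1], [3]]
```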

View File

@@ -19,7 +19,7 @@ except ImportError:
 
 def is_cuda_extension_usable() -> bool:
     """Check whether ngram_repeat_block_cuda is built properly"""
-    if not EXTENSION_BUILT:
+    if not EXTENSION_BUILT or not torch.cuda.is_available():
         return False
     bsz = 2
     tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], dtype=torch.long, device="cuda")
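
The added `torch.cuda.is_available()` check matters because the smoke test that follows allocates tensors with `device="cuda"`, which raises on CPU-only hosts. A sketch of the same guard pattern outside fairseq (names are illustrative, not from the diff):

```python
import torch


def cuda_feature_usable(extension_built: bool) -> bool:
    # Require both that the optional extension compiled and that a GPU is
    # actually visible before running a smoke test that touches CUDA.
    if not extension_built or not torch.cuda.is_available():
        return False
    probe = torch.ones(2, 2, device="cuda")
    return bool(probe.sum().item() == 4)
```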

View File

@@ -86,8 +86,10 @@ class SequenceGenerator(nn.Module):
         self.temperature = temperature
         self.match_source_len = match_source_len
 
-        self.no_repeat_ngram_size = no_repeat_ngram_size
-        self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
+        if no_repeat_ngram_size > 0:
+            self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
+        else:
+            self.repeat_ngram_blocker = None
 
         assert temperature > 0, "--temperature must be greater than 0"
@@ -373,8 +375,10 @@ class SequenceGenerator(nn.Module):
             if self.should_set_src_lengths:
                 self.search.set_src_lengths(src_lengths)
 
-            if self.no_repeat_ngram_size > 0:
-                lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
+            if self.repeat_ngram_blocker is not None:
+                lprobs = self.repeat_ngram_blocker(
+                    tokens, lprobs, bsz, beam_size, step
+                )
 
             # Shape: (batch, cand_size)
             cand_scores, cand_indices, cand_beams = self.search.step(
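
The change above builds the blocker once and gates on the object rather than re-checking the configuration at every decoding step, so `NGramRepeatBlock` (and its optional CUDA kernel) is never touched when n-gram blocking is disabled. A standalone sketch of that construct-or-None pattern, with placeholder classes rather than the real fairseq ones:

```python
class NGramRepeatBlock:
    """Placeholder for fairseq's blocker, for illustration only."""

    def __init__(self, ngram_size: int):
        self.ngram_size = ngram_size

    def __call__(self, tokens, lprobs, bsz, beam_size, step):
        return lprobs  # the real implementation masks repeated n-grams


class Generator:
    """Minimal sketch of the construct-or-None pattern from the diff."""

    def __init__(self, no_repeat_ngram_size: int = 0):
        # Build the optional helper once; None means the feature is off.
        self.repeat_ngram_blocker = (
            NGramRepeatBlock(no_repeat_ngram_size)
            if no_repeat_ngram_size > 0
            else None
        )

    def prune(self, tokens, lprobs, bsz, beam_size, step):
        # Gate on the object itself instead of the raw config value.
        if self.repeat_ngram_blocker is not None:
            lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
        return lprobs
```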

View File

@@ -242,18 +242,19 @@ def get_files(path, relative_to="fairseq"):
     return all_files
 
 
-try:
-    # symlink examples into fairseq package so package_data accepts them
-    fairseq_examples = os.path.join("fairseq", "examples")
-    if "build_ext" not in sys.argv[1:] and not os.path.exists(fairseq_examples):
-        os.symlink(os.path.join("..", "examples"), fairseq_examples)
-
-    package_data = {
-        "fairseq": (
-            get_files(fairseq_examples) + get_files(os.path.join("fairseq", "config"))
-        )
-    }
-    do_setup(package_data)
-finally:
-    if "build_ext" not in sys.argv[1:] and os.path.exists(fairseq_examples):
-        os.unlink(fairseq_examples)
+if __name__ == "__main__":
+    try:
+        # symlink examples into fairseq package so package_data accepts them
+        fairseq_examples = os.path.join("fairseq", "examples")
+        if "build_ext" not in sys.argv[1:] and not os.path.exists(fairseq_examples):
+            os.symlink(os.path.join("..", "examples"), fairseq_examples)
+
+        package_data = {
+            "fairseq": (
+                get_files(fairseq_examples) + get_files(os.path.join("fairseq", "config"))
+            )
+        }
+        do_setup(package_data)
+    finally:
+        if "build_ext" not in sys.argv[1:] and os.path.exists(fairseq_examples):
+            os.unlink(fairseq_examples)
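
A likely motivation for the new `if __name__ == "__main__":` guard: with the `spawn` start method, each worker re-imports the parent's main module (typically under the name `__mp_main__`), so unguarded module-level code in `setup.py` (such as the `do_setup` call) would run again in every process spawned during `python setup.py test`. A small stdlib illustration of that rule, with a hypothetical worker:

```python
import multiprocessing as mp

print(f"imported as {__name__}")  # parent: __main__; spawned workers: __mp_main__


def _worker(rank: int):
    print(f"worker {rank} started")


if __name__ == "__main__":
    # Only the parent reaches this block; without the guard, each spawned
    # worker re-importing this file would re-run the module-level code
    # (in setup.py's case, re-invoking do_setup during the test run).
    mp.set_start_method("spawn")
    procs = [mp.Process(target=_worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```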

View File

@@ -17,6 +17,7 @@ def spawn_and_init(fn, world_size, args=None):
             fn=functools.partial(init_and_run, fn, args),
             args=(world_size, tmp_file.name,),
             nprocs=world_size,
+            join=True,
         )

View File

@@ -1623,8 +1623,9 @@ class TestActivationCheckpointing(unittest.TestCase):
         with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir:
-            create_dummy_data(data_dir, num_examples=20)
-            preprocess_translation_data(data_dir)
+            with self.assertLogs():
+                create_dummy_data(data_dir, num_examples=20)
+                preprocess_translation_data(data_dir)
 
             ckpt_logs = _train(["--checkpoint-activations"])
             baseline_logs = _train([])
             assert len(baseline_logs) == len(ckpt_logs)
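
Moving the data preparation inside `self.assertLogs()` keeps its log output captured by the test (the context manager also fails if nothing is logged at all), presumably so stray records from preprocessing do not interfere with the logs the two `_train` calls compare afterwards. For reference, the basic behavior of `assertLogs`:

```python
import logging
import unittest


class AssertLogsExample(unittest.TestCase):
    def test_capture(self):
        # assertLogs captures records emitted inside the block and fails
        # the test if nothing is logged at the given level or above.
        with self.assertLogs(level=logging.INFO) as cm:
            logging.getLogger("prep").info("building dummy data")
        self.assertEqual(len(cm.records), 1)


if __name__ == "__main__":
    unittest.main()
```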

View File

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import functools
 import random
 import unittest
 from multiprocessing import Manager
@@ -141,16 +142,12 @@ class TestBMUF(unittest.TestCase):
     def bmuf_process(self, cfg, args, iterations):
-        processes = []
         results = Manager().dict()
-        ctx = torch.multiprocessing.get_context("spawn")
-        for rank in range(args.distributed_world_size):
-            p = ctx.Process(
-                target=single_gpu_training, args=(cfg, args, rank, iterations, results)
-            )
-            p.start()
-            processes.append(p)
-        for p in processes:
-            p.join()
+        torch.multiprocessing.spawn(
+            fn=functools.partial(single_gpu_training, cfg, args),
+            args=(iterations, results),
+            nprocs=args.distributed_world_size,
+            join=True,
+        )
         return results
 
     def test_bmuf_sync(self):
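
For reference, `torch.multiprocessing.spawn` calls `fn(i, *args)` with the worker index `i` as the first positional argument, which is why the fixed leading arguments are bound with `functools.partial` and only `(iterations, results)` go through `args=`. A standalone sketch with `cfg` and `args` stubbed out as `None`:

```python
import functools
from multiprocessing import Manager

import torch.multiprocessing as mp


def single_gpu_training(cfg, args, rank, iterations, results):
    # hypothetical body; the real function runs `iterations` training steps
    results[rank] = iterations


if __name__ == "__main__":
    results = Manager().dict()
    mp.spawn(
        # partial binds cfg and args; spawn then supplies the rank, and the
        # remaining (iterations, results) come from args= below
        fn=functools.partial(single_gpu_training, None, None),
        args=(10, results),
        nprocs=2,
        join=True,
    )
    print(dict(results))  # e.g. {0: 10, 1: 10}
```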