Doc improvements

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/1606

Differential Revision: D19330727

Pulled By: myleott

fbshipit-source-id: dc6d100e42566efbc2ebc955689878ed8a820861
Myle Ott 2020-01-09 13:26:21 -08:00 committed by Facebook Github Bot
parent 075a4a5263
commit 0790c0cfc3
4 changed files with 25 additions and 4 deletions


@@ -96,13 +96,28 @@ batch) or `--tokens-per-sample` (max sequence length). You can also adjust
 number of GPUs.
 ### 3) Evaluate
 ```bash
 fairseq-eval-lm data-bin/wikitext-103 \
     --path checkpoints/transformer_wiki103/checkpoint_best.pt \
-    --sample-break-mode complete --max-tokens 3072 \
-    --context-window 2560 --softmax-batch 1024
+    --max-sentences 2 \
+    --tokens-per-sample 512 \
+    --context-window 400
+# | Evaluated 245569 tokens in 56.1s (4379.02 tokens/s)
+# | Loss: 3.4164, Perplexity: 30.46
 ```
+*Note:* The `--context-window` option controls how much context is provided to
+each token when computing perplexity. When the window size is 0, the dataset is
+chunked into segments of length 512 and perplexity is computed over each segment
+normally. However, this results in worse (higher) perplexity since tokens that
+appear earlier in each segment have less conditioning. When the maximum window
+size is used (511 in this case), then we compute perplexity for each token
+fully conditioned on 511 tokens of context. This slows down evaluation
+significantly, since we must run a separate forward pass for every token in the
+dataset, but results in better (lower) perplexity.
 ## Convolutional language models
 Please see the [convolutional LM README](conv_lm/README.md) for instructions to
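
To make the `--context-window` trade-off described in the note above concrete, here is a minimal sketch of sliding-window scoring. It is illustrative only, not fairseq's `eval_lm` implementation: `sliding_window_nll` is a hypothetical helper, and `model` is assumed to be a callable returning next-token logits of shape `[batch, time, vocab]`.

```python
# Illustrative sketch only (not fairseq's eval_lm code): score a long token
# stream in fixed-size chunks, re-feeding the tail of the previous chunk as
# extra context so early tokens in each chunk are better conditioned.
import torch

def sliding_window_nll(model, tokens, tokens_per_sample=512, context_window=400):
    """Average negative log-likelihood per scored token; perplexity = exp(result)."""
    assert 0 <= context_window < tokens_per_sample
    step = tokens_per_sample - context_window   # number of new tokens scored per chunk
    total_nll, total_count = 0.0, 0
    for start in range(0, len(tokens), step):
        window = tokens[max(0, start - context_window): start + step]
        if len(window) < 2:
            continue
        inp = torch.tensor(window[:-1]).unsqueeze(0)    # [1, T]
        tgt = torch.tensor(window[1:]).unsqueeze(0)     # [1, T]
        lprobs = torch.log_softmax(model(inp), dim=-1)  # assumes model(inp) -> [1, T, V] logits
        nll = -lprobs.gather(-1, tgt.unsqueeze(-1)).squeeze(-1)  # [1, T]
        # Only count targets that were not already scored by the previous chunk.
        n_new = nll.size(1) if start == 0 else min(step, len(tokens) - start, nll.size(1))
        total_nll += nll[0, -n_new:].sum().item()
        total_count += n_new
    return total_nll / total_count
```

With `context_window=0` each chunk is scored in isolation (fast, but early tokens in a chunk have little conditioning), while `context_window=tokens_per_sample-1` scores essentially one new token per forward pass with a nearly full context window (slow, but lower perplexity), which is the trade-off the note describes.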


@@ -327,6 +327,11 @@ def add_distributed_training_args(parser):
                        help='which GPU to use (usually configured automatically)')
     group.add_argument('--distributed-no-spawn', action='store_true',
                        help='do not spawn multiple processes even if multiple GPUs are visible')
+    # "c10d" is PyTorch's DDP implementation and provides the fastest
+    # training. "no_c10d" is a more robust, but slightly slower DDP
+    # implementation. Try this if you get warning messages about
+    # inconsistent gradients between workers, or if some of your model
+    # parameters are not always used.
     group.add_argument('--ddp-backend', default='c10d', type=str,
                        choices=['c10d', 'no_c10d'],
                        help='DistributedDataParallel backend')
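
As a rough illustration of the comment added above, and only a sketch rather than fairseq's actual backend code: `c10d` corresponds to wrapping the model in PyTorch's native `DistributedDataParallel`, while a `no_c10d`-style fallback can skip the wrapper and all-reduce gradients explicitly after `backward()`. The helper names `wrap_for_ddp` and `all_reduce_grads` are hypothetical.

```python
# Illustrative sketch only (not fairseq's implementation) of the --ddp-backend trade-off.
import torch
import torch.distributed as dist

def wrap_for_ddp(model, ddp_backend, device_id):
    if ddp_backend == 'c10d':
        # PyTorch's native DDP overlaps gradient all-reduce with the backward
        # pass (fastest), but it assumes every parameter receives a gradient
        # on every step.
        return torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device_id], output_device=device_id,
        )
    elif ddp_backend == 'no_c10d':
        # More robust fallback: leave the model unwrapped and synchronize
        # gradients manually once per update (see all_reduce_grads below).
        # Slightly slower, but parameters that are not always used are harmless.
        return model
    raise ValueError('unknown --ddp-backend: {}'.format(ddp_backend))

def all_reduce_grads(model):
    """Call after loss.backward() when using the manual ('no_c10d'-style) path."""
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad.data)
            p.grad.data.div_(world_size)
```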


@@ -402,7 +402,8 @@ class Trainer(object):
             ):
                 raise RuntimeError(
                     "Fatal error: gradients are inconsistent between workers. "
-                    "Try --ddp-backend=no_c10d."
+                    "Try --ddp-backend=no_c10d, which is a more robust but "
+                    "slightly slower DDP implementation."
                 )
         self.meters["oom"].update(ooms, len(samples))


@@ -348,7 +348,7 @@ def cli_main():
         args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
         args.distributed_rank = None  # set based on device id
         if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
-            print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
+            print('| NOTE: you may get faster training with: --ddp-backend=no_c10d')
         torch.multiprocessing.spawn(
             fn=distributed_main,
             args=(args, ),
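
One plausible reason for the hint printed above, offered as a hedged aside rather than a statement about fairseq's internals: when `--update-freq` is greater than 1 (gradient accumulation), PyTorch's native DDP all-reduces gradients on every `backward()` call by default, so each optimizer step pays the communication cost `update_freq` times, whereas an implementation that synchronizes once per update avoids the extra all-reduces. The sketch below shows the deferred-sync pattern using DDP's `no_sync()` context; `batches`, `compute_loss`, and `accumulate_and_step` are placeholder names.

```python
# Illustrative sketch only (not fairseq's training loop): defer gradient
# synchronization during gradient accumulation so only the last micro-batch
# triggers an all-reduce. `ddp_model` is assumed to be wrapped in
# torch.nn.parallel.DistributedDataParallel.
import contextlib

def accumulate_and_step(ddp_model, batches, optimizer, compute_loss):
    optimizer.zero_grad()
    for i, batch in enumerate(batches):            # len(batches) == update_freq
        is_last = (i == len(batches) - 1)
        # no_sync() accumulates gradients locally; the first forward/backward
        # pass outside the context synchronizes the accumulated gradients.
        ctx = contextlib.nullcontext() if is_last else ddp_model.no_sync()
        with ctx:
            compute_loss(ddp_model, batch).backward()
    optimizer.step()
```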