fix restoring from middle of epoch; fix defaulting transformer dropout params

This commit is contained in:
alexeib 2018-05-27 23:40:14 +01:00 committed by Myle Ott
parent 386847ee51
commit 978c125aee
5 changed files with 21 additions and 47 deletions

View File

@@ -10,4 +10,3 @@ from .token_block_dataset import TokenBlockDataset
from .language_dataset import LanguageDatasets
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .offset_dataset import OffsetDataset

View File

@@ -1,32 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from torch.utils.data import Dataset
class OffsetDataset(Dataset):
    """Wraps an existing dataset, but starts iterating from a particular offset."""

    def __init__(self, dataset, offset):
        """
        Args:
            dataset: dataset to wrap
            offset: integer offset from which to start iterating
        """
        super().__init__()
        assert len(dataset) >= offset
        self.dataset = dataset
        self.offset = offset

    def __getitem__(self, i):
        return self.dataset[i + self.offset]

    def __len__(self):
        return len(self.dataset) - self.offset
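For context, the removed wrapper above resumed mid-epoch by shifting indices rather than consuming an iterator. A minimal usage sketch (data and numbers are illustrative, not from the commit):

    batches = list(range(150))             # stand-in for a batch-level dataset
    resumed = OffsetDataset(batches, 50)   # pretend 50 batches were already trained on
    assert len(resumed) == 100             # 100 batches left in the epoch
    assert resumed[0] == 50                # index 0 maps to item 50 of the wrapped dataset

The commit drops this wrapper in favour of advancing a plain iterator with itertools.islice in the train.py hunk at the end of this diff.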

View File

@@ -31,11 +31,11 @@ class TransformerModel(FairseqModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
parser.add_argument('--relu-dropout', type=float, metavar='D',
help='dropout probability after ReLU in FFN')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
@@ -399,6 +399,9 @@ def base_architecture(args):
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
args.attention_dropout = getattr(args, 'attention_dropout', 0.)
args.relu_dropout = getattr(args, 'relu_dropout', 0.)
args.dropout = getattr(args, 'dropout', 0.1)
@register_model_architecture('transformer', 'transformer_iwslt_de_en')
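The dropout arguments lose their argparse defaults because a default attached to add_argument always populates the namespace, so the getattr(..., default) calls added to base_architecture just above could never take effect and named architectures could not pick their own values. Moving the defaults into base_architecture lets presets (and configs built without the command-line parser) supply them while an explicit flag still wins. A minimal sketch of the pattern, assuming unset flags are kept off the namespace (fairseq registers model arguments in a group with argument_default=argparse.SUPPRESS); my_big_architecture is an illustrative name, not part of the commit:

    import argparse

    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
    parser.add_argument('--attention-dropout', type=float, metavar='D',
                        help='dropout probability for attention weights')

    def base_architecture(args):
        # fill in a value only when neither the command line nor a preset set one
        args.attention_dropout = getattr(args, 'attention_dropout', 0.)

    def my_big_architecture(args):
        # a named architecture can now supply its own default
        args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
        base_architecture(args)

    args = parser.parse_args([])             # flag omitted -> attribute absent from the namespace
    my_big_architecture(args)
    print(args.attention_dropout)            # 0.1, the preset's default

    args = parser.parse_args(['--attention-dropout', '0.3'])
    my_big_architecture(args)
    print(args.attention_dropout)            # 0.3, the explicit flag still wins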

View File

@@ -6,6 +6,8 @@
# can be found in the PATENTS file in the same directory.
import unittest
import itertools
from unittest.mock import MagicMock, patch
import train
@@ -19,10 +21,8 @@ def mock_trainer(epoch, num_updates):
def mock_loader(length):
ds = MagicMock()
ds.__len__.return_value = length
loader = MagicMock()
loader.__next__.return_value = ds
loader.__next__.return_value = list(range(length))
return loader
@@ -42,16 +42,14 @@ class TestLoadCheckpoint(unittest.TestCase):
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 2)
self.assertEqual(len(ds), 50)
self.assertNotIsInstance(ds, MagicMock)
self.assertEqual(next(ds), 50)
def test_load_full_checkpoint(self):
trainer = mock_trainer(2, 150)
loader = mock_loader(150)
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 2)
self.assertEqual(len(ds), 150)
self.assertIsInstance(ds, MagicMock)
self.assertEqual(next(iter(ds)), 0)
def test_load_no_checkpoint(self):
trainer = mock_trainer(0, 0)
@@ -60,8 +58,7 @@ class TestLoadCheckpoint(unittest.TestCase):
epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
self.assertEqual(epoch, 1)
self.assertEqual(len(ds), 150)
self.assertIsInstance(ds, MagicMock)
self.assertEqual(next(iter(ds)), 0)
def tearDown(self):
patch.stopall()
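The updated assertions check that a mid-epoch restore returns a real iterator positioned at the first untrained batch rather than the MagicMock dataset itself. The expected value of 50 follows from arithmetic like the sketch below; the exact trainer numbers sit outside the visible hunk, so they are assumed here:

    updates_per_epoch = 150       # each mocked dataset has 150 batches
    trainer_updates = 200         # assumed: checkpoint saved during epoch 2, after 200 updates
    completed = trainer_updates - updates_per_epoch   # 50 batches already done in epoch 2
    # load_checkpoint should skip those 50 batches, so next(ds) yields index 50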

View File

@@ -11,8 +11,10 @@ import os
import math
import torch
from itertools import islice
from fairseq import criterions, models, options, progress_bar
from fairseq.data import data_utils, data_loaders, OffsetDataset
from fairseq.data import data_utils, data_loaders
from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
@@ -323,7 +325,12 @@ def load_checkpoint(args, trainer, train_dataloader):
updates += len(ds)
if ds is not None and updates > trainer_updates:
ds = OffsetDataset(ds, updates - trainer_updates)
completed_batches = len(ds) - (updates - trainer_updates)
assert completed_batches >= 0
ds = iter(ds)
# consume completed batches
next(islice(ds, completed_batches, completed_batches), None)
else:
ds = next(train_dataloader)
epoch += 1
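Instead of wrapping the dataset in OffsetDataset, the resume path now turns the current epoch's dataset into an iterator and consumes the batches that were already trained on. The consumption line uses the standard itertools "consume" idiom; a standalone sketch with illustrative numbers:

    from itertools import islice

    batches = iter(range(150))        # stand-in for the epoch's batch iterator
    completed_batches = 50

    # islice(it, n, n) yields nothing but advances the underlying iterator by n items;
    # the default passed to next() keeps this safe if fewer than n items remain
    next(islice(batches, completed_batches, completed_batches), None)

    print(next(batches))              # -> 50, the first batch still to train on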