Don't generate during training, add --quiet to generate.py

This commit is contained in:
Sergey Edunov 2017-10-11 16:11:28 -07:00
parent a8260d52d6
commit 8f058ea0fb
5 changed files with 13 additions and 33 deletions

View File

@ -17,7 +17,7 @@ from fairseq.dictionary import Dictionary
from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
def load_with_check(path, src=None, dst=None):
def load_with_check(path, load_splits, src=None, dst=None):
"""Loads the train, valid, and test sets from the specified folder
and check that training files exist."""
@ -43,12 +43,12 @@ def load_with_check(path, src=None, dst=None):
else:
raise ValueError('training file not found for {}-{}'.format(src, dst))
dataset = load(path, src, dst)
dataset = load(path, load_splits, src, dst)
return dataset
def load(path, src, dst):
"""Loads the train, valid, and test sets from the specified folder."""
def load(path, load_splits, src, dst):
"""Loads specified data splits (e.g. test, train or valid) from the path."""
langcode = '{}-{}'.format(src, dst)
@ -59,7 +59,7 @@ def load(path, src, dst):
dst_dict = Dictionary.load(fmt_path('dict.{}.txt', dst))
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
for split in ['train', 'valid', 'test']:
for split in load_splits:
for k in itertools.count():
prefix = "{}{}".format(split, k if k > 0 else '')
src_path = fmt_path('{}.{}.{}', prefix, langcode, src)

View File

@ -104,6 +104,8 @@ def add_generation_args(parser):
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
group.add_argument('--unk-replace-dict', default='', type=str,
help='performs unk word replacement')
group.add_argument('--quiet', action='store_true',
help='Only print final scores')
return group

View File

@ -94,7 +94,7 @@ def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
return epoch, batch_offset
def load_ensemble_for_inference(filenames, data_path):
def load_ensemble_for_inference(filenames, data_path, split):
# load model architectures and weights
states = []
for filename in filenames:
@ -106,7 +106,7 @@ def load_ensemble_for_inference(filenames, data_path):
# load dataset
args = states[0]['args']
dataset = data.load(data_path, args.source_lang, args.target_lang)
dataset = data.load(data_path, [split], args.source_lang, args.target_lang)
# build models
ensemble = []

View File

@ -38,7 +38,7 @@ def main():
# Load model and dataset
print('| loading model(s) from {}'.format(', '.join(args.path)))
models, dataset = utils.load_ensemble_for_inference(args.path, args.data)
models, dataset = utils.load_ensemble_for_inference(args.path, args.data, args.gen_subset)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
@ -81,6 +81,8 @@ def main():
bpe_symbol = '@@ ' if args.remove_bpe else None
def display_hypotheses(id, src, orig, ref, hypos):
if args.quiet:
return
id_str = '' if id is None else '-{}'.format(id)
src_str = to_sentence(dataset.src_dict, src, bpe_symbol)
print('S{}\t{}'.format(id_str, src_str))

View File

@ -29,9 +29,6 @@ def main():
dataset_args.add_argument('--valid-subset', default='valid', metavar='SPLIT',
help='comma separated list of data subsets '
' to use for validation (train, valid, valid1, test, test1)')
dataset_args.add_argument('--test-subset', default='test', metavar='SPLIT',
help='comma separated list of data subsets '
'to use for testing (train, valid, test)')
options.add_optimization_args(parser)
options.add_checkpoint_args(parser)
options.add_model_args(parser)
@ -48,7 +45,7 @@ def main():
torch.manual_seed(args.seed)
# Load dataset
dataset = data.load_with_check(args.data, args.source_lang, args.target_lang)
dataset = data.load_with_check(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args, so that it's saved in checkpoints
args.source_lang, args.target_lang = dataset.src, dataset.dst
@ -100,13 +97,6 @@ def main():
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
# Generate on test set and compute BLEU score
for beam in [1, 5, 10, 20]:
for subset in args.test_subset.split(','):
scorer = score_test(args, trainer.get_model(), dataset, subset, beam,
cuda_device=(0 if num_gpus > 0 else None))
print('| Test on {} with beam={}: {}'.format(subset, beam, scorer.result_string()))
# Stop multiprocessing
trainer.stop()
@ -192,19 +182,5 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
return val_loss
def score_test(args, model, dataset, subset, beam, cuda_device):
"""Evaluate the model on the test set and return the BLEU scorer."""
translator = SequenceGenerator([model], dataset.dst_dict, beam_size=beam)
if torch.cuda.is_available():
translator.cuda()
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
itr = dataset.dataloader(subset, batch_size=4, max_positions=args.max_positions)
for _, _, ref, hypos in translator.generate_batched_itr(itr, cuda_device=cuda_device):
scorer.add(ref.int().cpu(), hypos[0]['tokens'].int().cpu())
return scorer
if __name__ == '__main__':
main()