mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-09-22 06:39:29 +03:00
Don't generate during training, add --quiet to generate.py
This commit is contained in:
parent
a8260d52d6
commit
8f058ea0fb
@ -17,7 +17,7 @@ from fairseq.dictionary import Dictionary
|
||||
from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
|
||||
|
||||
|
||||
def load_with_check(path, src=None, dst=None):
|
||||
def load_with_check(path, load_splits, src=None, dst=None):
|
||||
"""Loads the train, valid, and test sets from the specified folder
|
||||
and check that training files exist."""
|
||||
|
||||
@ -43,12 +43,12 @@ def load_with_check(path, src=None, dst=None):
|
||||
else:
|
||||
raise ValueError('training file not found for {}-{}'.format(src, dst))
|
||||
|
||||
dataset = load(path, src, dst)
|
||||
dataset = load(path, load_splits, src, dst)
|
||||
return dataset
|
||||
|
||||
|
||||
def load(path, src, dst):
|
||||
"""Loads the train, valid, and test sets from the specified folder."""
|
||||
def load(path, load_splits, src, dst):
|
||||
"""Loads specified data splits (e.g. test, train or valid) from the path."""
|
||||
|
||||
langcode = '{}-{}'.format(src, dst)
|
||||
|
||||
@ -59,7 +59,7 @@ def load(path, src, dst):
|
||||
dst_dict = Dictionary.load(fmt_path('dict.{}.txt', dst))
|
||||
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
|
||||
|
||||
for split in ['train', 'valid', 'test']:
|
||||
for split in load_splits:
|
||||
for k in itertools.count():
|
||||
prefix = "{}{}".format(split, k if k > 0 else '')
|
||||
src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
|
||||
|
@ -104,6 +104,8 @@ def add_generation_args(parser):
|
||||
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
|
||||
group.add_argument('--unk-replace-dict', default='', type=str,
|
||||
help='performs unk word replacement')
|
||||
group.add_argument('--quiet', action='store_true',
|
||||
help='Only print final scores')
|
||||
|
||||
return group
|
||||
|
||||
|
@ -94,7 +94,7 @@ def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
|
||||
return epoch, batch_offset
|
||||
|
||||
|
||||
def load_ensemble_for_inference(filenames, data_path):
|
||||
def load_ensemble_for_inference(filenames, data_path, split):
|
||||
# load model architectures and weights
|
||||
states = []
|
||||
for filename in filenames:
|
||||
@ -106,7 +106,7 @@ def load_ensemble_for_inference(filenames, data_path):
|
||||
|
||||
# load dataset
|
||||
args = states[0]['args']
|
||||
dataset = data.load(data_path, args.source_lang, args.target_lang)
|
||||
dataset = data.load(data_path, [split], args.source_lang, args.target_lang)
|
||||
|
||||
# build models
|
||||
ensemble = []
|
||||
|
@ -38,7 +38,7 @@ def main():
|
||||
|
||||
# Load model and dataset
|
||||
print('| loading model(s) from {}'.format(', '.join(args.path)))
|
||||
models, dataset = utils.load_ensemble_for_inference(args.path, args.data)
|
||||
models, dataset = utils.load_ensemble_for_inference(args.path, args.data, args.gen_subset)
|
||||
|
||||
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
|
||||
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
|
||||
@ -81,6 +81,8 @@ def main():
|
||||
|
||||
bpe_symbol = '@@ ' if args.remove_bpe else None
|
||||
def display_hypotheses(id, src, orig, ref, hypos):
|
||||
if args.quiet:
|
||||
return
|
||||
id_str = '' if id is None else '-{}'.format(id)
|
||||
src_str = to_sentence(dataset.src_dict, src, bpe_symbol)
|
||||
print('S{}\t{}'.format(id_str, src_str))
|
||||
|
26
train.py
26
train.py
@ -29,9 +29,6 @@ def main():
|
||||
dataset_args.add_argument('--valid-subset', default='valid', metavar='SPLIT',
|
||||
help='comma separated list ofdata subsets '
|
||||
' to use for validation (train, valid, valid1,test, test1)')
|
||||
dataset_args.add_argument('--test-subset', default='test', metavar='SPLIT',
|
||||
help='comma separated list ofdata subset '
|
||||
'to use for testing (train, valid, test)')
|
||||
options.add_optimization_args(parser)
|
||||
options.add_checkpoint_args(parser)
|
||||
options.add_model_args(parser)
|
||||
@ -48,7 +45,7 @@ def main():
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
# Load dataset
|
||||
dataset = data.load_with_check(args.data, args.source_lang, args.target_lang)
|
||||
dataset = data.load_with_check(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
|
||||
if args.source_lang is None or args.target_lang is None:
|
||||
# record inferred languages in args, so that it's saved in checkpoints
|
||||
args.source_lang, args.target_lang = dataset.src, dataset.dst
|
||||
@ -100,13 +97,6 @@ def main():
|
||||
train_meter.stop()
|
||||
print('| done training in {:.1f} seconds'.format(train_meter.sum))
|
||||
|
||||
# Generate on test set and compute BLEU score
|
||||
for beam in [1, 5, 10, 20]:
|
||||
for subset in args.test_subset.split(','):
|
||||
scorer = score_test(args, trainer.get_model(), dataset, subset, beam,
|
||||
cuda_device=(0 if num_gpus > 0 else None))
|
||||
print('| Test on {} with beam={}: {}'.format(subset, beam, scorer.result_string()))
|
||||
|
||||
# Stop multiprocessing
|
||||
trainer.stop()
|
||||
|
||||
@ -192,19 +182,5 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
|
||||
return val_loss
|
||||
|
||||
|
||||
def score_test(args, model, dataset, subset, beam, cuda_device):
|
||||
"""Evaluate the model on the test set and return the BLEU scorer."""
|
||||
|
||||
translator = SequenceGenerator([model], dataset.dst_dict, beam_size=beam)
|
||||
if torch.cuda.is_available():
|
||||
translator.cuda()
|
||||
|
||||
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
|
||||
itr = dataset.dataloader(subset, batch_size=4, max_positions=args.max_positions)
|
||||
for _, _, ref, hypos in translator.generate_batched_itr(itr, cuda_device=cuda_device):
|
||||
scorer.add(ref.int().cpu(), hypos[0]['tokens'].int().cpu())
|
||||
return scorer
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user