Modularize generate.py (#351)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/351

This makes it easier for tasks to plug into generate.py/interactive.py (a minimal sketch of the new entry points follows the changed-files summary below).
Pull Request resolved: https://github.com/pytorch/fairseq/pull/520

Differential Revision: D14183881

Pulled By: myleott

fbshipit-source-id: ede5e53ddc1215ed3b12b8f1eba048c946913c33
This commit is contained in:
Myle Ott 2019-02-22 10:06:22 -08:00 committed by Facebook Github Bot
parent 08e866f977
commit b65c579bed
11 changed files with 371 additions and 398 deletions
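For orientation, here is a minimal sketch of the new task-level entry points the diffs below introduce (`task.build_generator()` and `task.inference_step()`). It assumes `args`, `task`, `models`, and a collated `sample` have already been prepared as in generate.py; the helper name is made up for illustration and is not part of this commit.

```python
from fairseq import utils

def translate_batch(args, task, models, sample, use_cuda=False):
    # The task now owns generator construction, so a custom task can override
    # build_generator() to plug its own generation strategy into generate.py.
    generator = task.build_generator(args)   # SequenceGenerator, or SequenceScorer with --score-reference
    sample = utils.move_to_cuda(sample) if use_cuda else sample
    # inference_step() wraps generator.generate(models, sample) in torch.no_grad()
    hypos = task.inference_step(generator, models, sample)
    return hypos                              # hypos[i][j]: j-th best hypothesis for sentence i
```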

View File

@ -76,6 +76,8 @@ def main(parsed_args):
model.make_generation_fast_()
if args.fp16:
model.half()
if use_cuda:
model.cuda()
assert len(models) > 0
@ -95,9 +97,7 @@ def main(parsed_args):
).next_epoch_itr(shuffle=False)
gen_timer = StopwatchMeter()
scorer = SequenceScorer(models, task.target_dictionary)
if use_cuda:
scorer.cuda()
scorer = SequenceScorer(task.target_dictionary)
score_sum = 0.
count = 0
@ -113,10 +113,18 @@ def main(parsed_args):
word_stats = dict()
with progress_bar.build_progress_bar(args, itr) as t:
results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
wps_meter = TimeMeter()
for _, src_tokens, __, hypos in results:
for hypo in hypos:
for sample in t:
sample = utils.move_to_cuda(sample) if use_cuda else sample
if 'net_input' not in sample:
continue
gen_timer.start()
hypos = scorer.generate(models, sample)
gen_timer.stop(sample['ntokens'])
for hypos_i in hypos:
hypo = hypos_i[0]
pos_scores = hypo['positional_scores']
skipped_toks = 0
@ -162,7 +170,7 @@ def main(parsed_args):
if args.output_word_probs:
print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
wps_meter.update(src_tokens.size(0))
wps_meter.update(sample['ntokens'])
t.log({'wps': round(wps_meter.avg)})
avg_nll_loss = -score_sum / count
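A condensed sketch of the new scoring loop above: the scorer no longer holds the models or moves them to GPU itself, it receives the ensemble on every call. The inf-score and word-stats handling from eval_lm is omitted, and the helper name is an illustration only.

```python
from fairseq import utils

def score_language_model(models, scorer, itr, use_cuda=False):
    """Sketch only: accumulate token log-probs with the new scorer API."""
    score_sum, count = 0., 0
    for sample in itr:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue
        hypos = scorer.generate(models, sample)   # one single-element hypo list per sentence
        for hypos_i in hypos:
            pos_scores = hypos_i[0]['positional_scores']
            score_sum += pos_scores.sum().item()
            count += pos_scores.numel()
    return -score_sum / count                     # average negative log-likelihood
```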

View File

@ -69,9 +69,6 @@ class BacktranslationDataset(FairseqDataset):
backtranslation_fn (callable): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
max_len_a, max_len_b (int, int): will be used to compute
`maxlen = max_len_a * src_len + max_len_b`, which will be passed
into *backtranslation_fn*.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch
(default: ``tgt_dataset.collater``).
@ -82,16 +79,12 @@ class BacktranslationDataset(FairseqDataset):
self,
tgt_dataset,
backtranslation_fn,
max_len_a,
max_len_b,
output_collater=None,
cuda=True,
**kwargs
):
self.tgt_dataset = tgt_dataset
self.backtranslation_fn = backtranslation_fn
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.output_collater = output_collater if output_collater is not None \
else tgt_dataset.collater
self.cuda = cuda if torch.cuda.is_available() else False
@ -130,12 +123,7 @@ class BacktranslationDataset(FairseqDataset):
samples=samples,
collate_fn=self.tgt_dataset.collater,
generate_fn=(
lambda net_input: self.backtranslation_fn(
net_input,
maxlen=int(
self.max_len_a * net_input['src_tokens'].size(1) + self.max_len_b
),
)
lambda net_input: self.backtranslation_fn(net_input)
),
cuda=self.cuda,
)
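With max_len_a/max_len_b removed from BacktranslationDataset, the length budget now lives in the generator itself. A sketch of the new wiring, mirroring the updated test further down (the import paths and the concrete hyper-parameters are assumptions):

```python
from fairseq.data import BacktranslationDataset
from fairseq.sequence_generator import SequenceGenerator

def build_backtranslation_dataset(tgt_dataset, tgt_dict, model):
    # max output length = max_len_a * src_len + max_len_b, now enforced by the generator
    generator = SequenceGenerator(tgt_dict, beam_size=2, max_len_a=0, max_len_b=200)
    return BacktranslationDataset(
        tgt_dataset=tgt_dataset,
        # backtranslation_fn now only receives the net_input; no maxlen argument
        backtranslation_fn=(
            lambda net_input: generator.generate([model], {'net_input': net_input})
        ),
    )
```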

View File

@ -15,18 +15,34 @@ from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
def __init__(
self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
normalize_scores=True, len_penalty=1., unk_penalty=0., retain_dropout=False,
sampling=False, sampling_topk=-1, sampling_temperature=1.,
diverse_beam_groups=-1, diverse_beam_strength=0.5,
match_source_len=False, no_repeat_ngram_size=0
self,
tgt_dict,
beam_size=1,
max_len_a=0,
max_len_b=200,
min_len=1,
stop_early=True,
normalize_scores=True,
len_penalty=1.,
unk_penalty=0.,
retain_dropout=False,
sampling=False,
sampling_topk=-1,
sampling_temperature=1.,
diverse_beam_groups=-1,
diverse_beam_strength=0.5,
match_source_len=False,
no_repeat_ngram_size=0,
):
"""Generates translations of a given source sentence.
Args:
tgt_dict (~fairseq.data.Dictionary): target dictionary
beam_size (int, optional): beam width (default: 1)
min/maxlen (int, optional): the length of the generated output will
be bounded by minlen and maxlen (not including end-of-sentence)
max_len_a/b (int, optional): generate sequences of maximum length
ax + b, where x is the source length
min_len (int, optional): the minimum length of the generated output
(not including end-of-sentence)
stop_early (bool, optional): stop generation immediately after we
finalize beam_size hypotheses, even though longer hypotheses
might have better normalized scores (default: True)
@ -50,16 +66,16 @@ class SequenceGenerator(object):
match_source_len (bool, optional): outputs should match the source
length (default: False)
"""
self.models = models
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos()
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
self.minlen = minlen
max_decoder_len = min(m.max_decoder_positions() for m in self.models)
max_decoder_len -= 1 # we define maxlen not including the EOS marker
self.maxlen = max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
# the max beam size is the dictionary size - 1, since we never select pad
self.beam_size = min(beam_size, self.vocab_size - 1)
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.min_len = min_len
self.stop_early = stop_early
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
@ -81,109 +97,51 @@ class SequenceGenerator(object):
else:
self.search = search.BeamSearch(tgt_dict)
def cuda(self):
for model in self.models:
model.cuda()
return self
def generate_batched_itr(
self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
cuda=False, timer=None, prefix_size=0,
):
"""Iterate over a batched dataset and yield individual translations.
Args:
maxlen_a/b (int, optional): generate sequences of maximum length
``ax + b``, where ``x`` is the source sentence length.
cuda (bool, optional): use GPU for generation
timer (StopwatchMeter, optional): time generations
prefix_size (int, optional): prefill the generation with the gold
prefix up to this length.
"""
if maxlen_b is None:
maxlen_b = self.maxlen
for sample in data_itr:
s = utils.move_to_cuda(sample) if cuda else sample
if 'net_input' not in s:
continue
input = s['net_input']
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in input.items()
if k != 'prev_output_tokens'
}
srclen = encoder_input['src_tokens'].size(1)
if timer is not None:
timer.start()
with torch.no_grad():
hypos = self.generate(
encoder_input,
beam_size=beam_size,
maxlen=int(maxlen_a*srclen + maxlen_b),
prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
)
if timer is not None:
timer.stop(sum(len(h[0]['tokens']) for h in hypos))
for i, id in enumerate(s['id'].data):
# remove padding
src = utils.strip_pad(input['src_tokens'].data[i, :], self.pad)
ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
yield id, src, ref, hypos[i]
def generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
@torch.no_grad()
def generate(self, models, sample=None, net_input=None, prefix_tokens=None, **kwargs):
"""Generate a batch of translations.
Args:
encoder_input (dict): dictionary containing the inputs to
*model.encoder.forward*.
beam_size (int, optional): overriding the beam size
(default: *self.beam_size*).
max_len (int, optional): maximum length of the generated sequence
prefix_tokens (LongTensor, optional): force decoder to begin with
these tokens
models (List[~fairseq.models.FairseqModel]): ensemble of models
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
"""
with torch.no_grad():
return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)
model = EnsembleModel(models)
if not self.retain_dropout:
model.eval()
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample['net_input'].items()
if k != 'prev_output_tokens'
}
def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
"""See generate"""
src_tokens = encoder_input['src_tokens']
src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
bsz, srclen = src_tokens.size()
maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
bsz, src_len = src_tokens.size()
beam_size = self.beam_size
if self.match_source_len:
maxlen = src_lengths.max().item()
max_len = src_lengths.max().item()
else:
max_len = min(
int(self.max_len_a * src_len + self.max_len_b),
# exclude the EOS marker
model.max_decoder_positions() - 1,
)
# the max beam size is the dictionary size - 1, since we never select pad
beam_size = beam_size if beam_size is not None else self.beam_size
beam_size = min(beam_size, self.vocab_size - 1)
encoder_outs = []
incremental_states = {}
for model in self.models:
if not self.retain_dropout:
model.eval()
if isinstance(model.decoder, FairseqIncrementalDecoder):
incremental_states[model] = {}
else:
incremental_states[model] = None
# compute the encoder output for each beam
if hasattr(model, 'encoder'):
encoder_out = model.encoder(**encoder_input)
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long()
encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
else:
encoder_out = None
encoder_outs.append(encoder_out)
# compute the encoder output for each beam
encoder_outs = model.forward_encoder(encoder_input)
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long()
encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
# initialize buffers
scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0)
scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
scores_buf = scores.clone()
tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
tokens = src_tokens.new(bsz * beam_size, max_len + 2).fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos
attn, attn_buf = None, None
@ -218,13 +176,13 @@ class SequenceGenerator(object):
"""
assert len(finalized[sent]) <= beam_size
if len(finalized[sent]) == beam_size:
if self.stop_early or step == maxlen or unfinalized_scores is None:
if self.stop_early or step == max_len or unfinalized_scores is None:
return True
# stop if the best unfinalized score is worse than the worst
# finalized one
best_unfinalized_score = unfinalized_scores[sent].max()
if self.normalize_scores:
best_unfinalized_score /= maxlen ** self.len_penalty
best_unfinalized_score /= max_len ** self.len_penalty
if worst_finalized[sent]['score'] >= best_unfinalized_score:
return True
return False
@ -326,20 +284,17 @@ class SequenceGenerator(object):
reorder_state = None
batch_idxs = None
for step in range(maxlen + 1): # one extra step for EOS marker
for step in range(max_len + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
if batch_idxs is not None:
# update beam indices to take into account removed sentences
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
for i, model in enumerate(self.models):
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
if encoder_outs is not None and hasattr(model, 'encoder'):
encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
model.reorder_incremental_state(reorder_state)
model.reorder_encoder_out(encoder_outs, reorder_state)
lprobs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
lprobs, avg_attn_scores = model.forward_decoder(tokens[:, :step + 1], encoder_outs)
lprobs[:, self.pad] = -math.inf # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
@ -356,7 +311,7 @@ class SequenceGenerator(object):
# Record attention scores
if avg_attn_scores is not None:
if attn is None:
attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
attn_buf = attn.clone()
nonpad_idxs = src_tokens.ne(self.pad)
attn[:, :, step + 1].copy_(avg_attn_scores)
@ -365,7 +320,7 @@ class SequenceGenerator(object):
scores_buf = scores_buf.type_as(lprobs)
eos_bbsz_idx = buffer('eos_bbsz_idx')
eos_scores = buffer('eos_scores', type_of=scores)
if step < maxlen:
if step < max_len:
self.search.set_src_lengths(src_lengths)
if self.no_repeat_ngram_size > 0:
@ -387,9 +342,9 @@ class SequenceGenerator(object):
probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
cand_scores = torch.gather(
probs_slice, dim=1,
index=prefix_tokens[:, step].view(-1, 1).data
index=prefix_tokens[:, step].view(-1, 1)
).expand(-1, cand_size)
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size)
cand_beams = torch.zeros_like(cand_indices)
else:
cand_scores, cand_indices, cand_beams = self.search.step(
@ -401,7 +356,7 @@ class SequenceGenerator(object):
# make probs contain cumulative scores for each hypothesis
lprobs.add_(scores[:, step - 1].unsqueeze(-1))
# finalize all active hypotheses once we hit maxlen
# finalize all active hypotheses once we hit max_len
# pick the hypothesis with the highest prob of EOS right now
torch.sort(
lprobs[:, self.eos],
@ -421,7 +376,7 @@ class SequenceGenerator(object):
eos_mask = cand_indices.eq(self.eos)
finalized_sents = set()
if step >= self.minlen:
if step >= self.min_len:
# only consider eos when it's among the top beam_size indices
torch.masked_select(
cand_bbsz_idx[:, :beam_size],
@ -440,7 +395,7 @@ class SequenceGenerator(object):
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
assert step < maxlen
assert step < max_len
if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents)
@ -543,14 +498,38 @@ class SequenceGenerator(object):
return finalized
def _decode(self, tokens, encoder_outs, incremental_states):
class EnsembleModel(torch.nn.Module):
"""A wrapper around an ensemble of models."""
def __init__(self, models):
super().__init__()
self.models = torch.nn.ModuleList(models)
self.incremental_states = None
if all(isinstance(m.decoder, FairseqIncrementalDecoder) for m in models):
self.incremental_states = {m: {} for m in models}
def has_encoder(self):
return hasattr(self.models[0], 'encoder')
def max_decoder_positions(self):
return min(m.max_decoder_positions() for m in self.models)
@torch.no_grad()
def forward_encoder(self, encoder_input):
if not self.has_encoder():
return None
return [model.encoder(**encoder_input) for model in self.models]
@torch.no_grad()
def forward_decoder(self, tokens, encoder_outs):
if len(self.models) == 1:
return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)
return self._decode_one(tokens, self.models[0], encoder_outs[0], self.incremental_states, log_probs=True)
log_probs = []
avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs):
probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=True)
probs, attn = self._decode_one(tokens, model, encoder_out, self.incremental_states, log_probs=True)
log_probs.append(probs)
if attn is not None:
if avg_attn is None:
@ -563,19 +542,32 @@ class SequenceGenerator(object):
return avg_probs, avg_attn
def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
with torch.no_grad():
if incremental_states[model] is not None:
decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model]))
else:
decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1:, :]
attn = decoder_out[1]
if self.incremental_states is not None:
decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=self.incremental_states[model]))
else:
decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1:, :]
attn = decoder_out[1]
if type(attn) is dict:
attn = attn['attn']
if attn is not None:
if type(attn) is dict:
attn = attn['attn']
if attn is not None:
if type(attn) is dict:
attn = attn['attn']
attn = attn[:, -1, :]
attn = attn[:, -1, :]
probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
probs = probs[:, -1, :]
return probs, attn
def reorder_encoder_out(self, encoder_outs, new_order):
if not self.has_encoder():
return
return [
model.encoder.reorder_encoder_out(encoder_out, new_order)
for model, encoder_out in zip(self.models, encoder_outs)
]
def reorder_incremental_state(self, new_order):
if self.incremental_states is None:
return
for model in self.models:
model.decoder.reorder_incremental_state(self.incremental_states[model], new_order)
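The decoding loop in _generate() now talks to the ensemble only through this wrapper. Below is a compressed, greedy sketch of that interaction, assuming `models`, a prepared `encoder_input` dict (without prev_output_tokens), and pad/eos indices are given; beam search, EOS stopping, and the length/penalty logic are deliberately omitted, and the function itself is not part of the diff.

```python
import torch
from fairseq.sequence_generator import EnsembleModel

@torch.no_grad()
def greedy_decode_sketch(models, encoder_input, src_tokens, pad, eos, max_len=200):
    model = EnsembleModel(models)
    model.eval()
    encoder_outs = model.forward_encoder(encoder_input)   # one encoder_out per model
    bsz = src_tokens.size(0)
    tokens = src_tokens.new_full((bsz, max_len + 2), pad)
    tokens[:, 0] = eos
    for step in range(max_len + 1):
        # forward_decoder averages log-probs (and attention) over the ensemble
        lprobs, _ = model.forward_decoder(tokens[:, :step + 1], encoder_outs)
        tokens[:, step + 1] = lprobs.argmax(dim=-1)        # greedy pick; no EOS check
    return tokens
```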

View File

@ -13,60 +13,23 @@ from fairseq import utils
class SequenceScorer(object):
"""Scores the target for a given source sentence."""
def __init__(self, models, tgt_dict):
self.models = models
def __init__(self, tgt_dict):
self.pad = tgt_dict.pad()
def cuda(self):
for model in self.models:
model.cuda()
return self
def score_batched_itr(self, data_itr, cuda=False, timer=None):
"""Iterate over a batched dataset and yield scored translations."""
for sample in data_itr:
s = utils.move_to_cuda(sample) if cuda else sample
if timer is not None:
timer.start()
pos_scores, attn = self.score(s)
for i, id in enumerate(s['id'].data):
# remove padding from ref
src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
tgt_len = ref.numel()
pos_scores_i = pos_scores[i][:tgt_len]
score_i = pos_scores_i.sum() / tgt_len
if attn is not None:
attn_i = attn[i]
_, alignment = attn_i.max(dim=0)
else:
attn_i = alignment = None
hypos = [{
'tokens': ref,
'score': score_i,
'attention': attn_i,
'alignment': alignment,
'positional_scores': pos_scores_i,
}]
if timer is not None:
timer.stop(s['ntokens'])
# return results in the same format as SequenceGenerator
yield id, src, ref, hypos
def score(self, sample):
@torch.no_grad()
def generate(self, models, sample, **kwargs):
"""Score a batch of translations."""
net_input = sample['net_input']
# compute scores for each model in the ensemble
avg_probs = None
avg_attn = None
for model in self.models:
with torch.no_grad():
model.eval()
decoder_out = model.forward(**net_input)
attn = decoder_out[1]
for model in models:
model.eval()
decoder_out = model.forward(**net_input)
attn = decoder_out[1]
probs = model.get_normalized_probs(decoder_out, log_probs=len(self.models) == 1, sample=sample).data
probs = model.get_normalized_probs(decoder_out, log_probs=len(models) == 1, sample=sample)
if avg_probs is None:
avg_probs = probs
else:
@ -77,13 +40,33 @@ class SequenceScorer(object):
avg_attn = attn
else:
avg_attn.add_(attn)
if len(self.models) > 1:
avg_probs.div_(len(self.models))
if len(models) > 1:
avg_probs.div_(len(models))
avg_probs.log_()
if avg_attn is not None:
avg_attn.div_(len(self.models))
avg_attn.div_(len(models))
avg_probs = avg_probs.gather(
dim=2,
index=sample['target'].data.unsqueeze(-1),
)
return avg_probs.squeeze(2), avg_attn
index=sample['target'].unsqueeze(-1),
).squeeze(2)
hypos = []
for i in range(avg_probs.size(0)):
# remove padding from ref
ref = utils.strip_pad(sample['target'][i, :], self.pad) if sample['target'] is not None else None
tgt_len = ref.numel()
avg_probs_i = avg_probs[i][:tgt_len]
score_i = avg_probs_i.sum() / tgt_len
if avg_attn is not None:
avg_attn_i = avg_attn[i]
_, alignment = avg_attn_i.max(dim=0)
else:
avg_attn_i = alignment = None
hypos.append([{
'tokens': ref,
'score': score_i,
'attention': avg_attn_i,
'alignment': alignment,
'positional_scores': avg_probs_i,
}])
return hypos
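A short sketch of calling the refactored scorer directly; it now mirrors SequenceGenerator.generate() and returns one single-element hypothesis list per sentence. The ensemble, target dictionary, and a collated sample with a 'target' key are assumed to exist.

```python
from fairseq.sequence_scorer import SequenceScorer

def score_references(models, tgt_dict, sample):
    scorer = SequenceScorer(tgt_dict)
    hypos = scorer.generate(models, sample)
    for sent_hypos in hypos:
        hypo = sent_hypos[0]
        # 'score' is the average reference log-probability, 'positional_scores' is per token
        print(hypo['score'], hypo['positional_scores'])
    return hypos
```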

View File

@ -180,6 +180,31 @@ class FairseqTask(object):
from fairseq import criterions
return criterions.build_criterion(args, self)
def build_generator(self, args):
if args.score_reference:
from fairseq.sequence_scorer import SequenceScorer
return SequenceScorer(self.target_dictionary)
else:
from fairseq.sequence_generator import SequenceGenerator
return SequenceGenerator(
self.target_dictionary,
beam_size=args.beam,
max_len_a=args.max_len_a,
max_len_b=args.max_len_b,
min_len=args.min_len,
stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen,
unk_penalty=args.unkpen,
sampling=args.sampling,
sampling_topk=args.sampling_topk,
sampling_temperature=args.sampling_temperature,
diverse_beam_groups=args.diverse_beam_groups,
diverse_beam_strength=args.diverse_beam_strength,
match_source_len=args.match_source_len,
no_repeat_ngram_size=args.no_repeat_ngram_size,
)
def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
"""
Do forward and backward, and return the loss as computed by *criterion*
@ -214,11 +239,9 @@ class FairseqTask(object):
loss, sample_size, logging_output = criterion(model, sample)
return loss, sample_size, logging_output
def init_logging_output(self, sample):
return {
'ntokens': sample['ntokens'] if sample is not None else 0,
'nsentences': sample['target'].size(0) if sample is not None else 0,
}
def inference_step(self, generator, models, sample, prefix_tokens=None):
with torch.no_grad():
return generator.generate(models, sample, prefix_tokens=prefix_tokens)
def grad_denom(self, sample_sizes, criterion):
return criterion.__class__.grad_denom(sample_sizes)
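These two hooks are the plug-in point the summary refers to: a task can return its own generator from build_generator() and leave inference_step() untouched. A hypothetical example follows; the task name, the TranslationTask subclassing, and the hard-coded sampling flag are all made up for illustration.

```python
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask


@register_task('my_sampling_translation')
class MySamplingTranslationTask(TranslationTask):

    def build_generator(self, args):
        from fairseq.sequence_generator import SequenceGenerator
        # Hypothetical override: always sample instead of doing beam search
        return SequenceGenerator(
            self.target_dictionary,
            beam_size=args.beam,
            max_len_a=args.max_len_a,
            max_len_b=args.max_len_b,
            sampling=True,
        )
```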

View File

@ -13,8 +13,6 @@ import torch
from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module
@ -59,6 +57,8 @@ def main(args):
)
if args.fp16:
model.half()
if use_cuda:
model.cuda()
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
@ -82,20 +82,7 @@ def main(args):
# Initialize generator
gen_timer = StopwatchMeter()
if args.score_reference:
translator = SequenceScorer(models, task.target_dictionary)
else:
translator = SequenceGenerator(
models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
)
if use_cuda:
translator.cuda()
generator = task.build_generator(args)
# Generate and compute BLEU score
if args.sacrebleu:
@ -105,79 +92,89 @@ def main(args):
num_sentences = 0
has_target = True
with progress_bar.build_progress_bar(args, itr) as t:
if args.score_reference:
translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
else:
translations = translator.generate_batched_itr(
t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
)
wps_meter = TimeMeter()
for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth
has_target = target_tokens is not None
target_tokens = target_tokens.int().cpu() if has_target else None
for sample in t:
sample = utils.move_to_cuda(sample) if use_cuda else sample
if 'net_input' not in sample:
continue
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
else:
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe)
prefix_tokens = None
if args.prefix_size > 0:
prefix_tokens = sample['target'][:, :args.prefix_size]
gen_timer.start()
hypos = task.inference_step(generator, models, sample, prefix_tokens)
num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
gen_timer.stop(num_generated_tokens)
for i, sample_id in enumerate(sample['id'].tolist()):
has_target = sample['target'] is not None
# Remove padding
src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
target_tokens = None
if has_target:
target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
else:
src_str = ""
if has_target:
target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
if not args.quiet:
if src_dict is not None:
print('S-{}\t{}'.format(sample_id, src_str))
if has_target:
print('T-{}\t{}'.format(sample_id, target_str))
# Process top predictions
for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe)
else:
src_str = ""
if has_target:
target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
if not args.quiet:
print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
print('P-{}\t{}'.format(
sample_id,
' '.join(map(
lambda x: '{:.4f}'.format(x),
hypo['positional_scores'].tolist(),
))
))
if src_dict is not None:
print('S-{}\t{}'.format(sample_id, src_str))
if has_target:
print('T-{}\t{}'.format(sample_id, target_str))
if args.print_alignment:
print('A-{}\t{}'.format(
# Process top predictions
for i, hypo in enumerate(hypos[i][:min(len(hypos), args.nbest)]):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
if not args.quiet:
print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
print('P-{}\t{}'.format(
sample_id,
' '.join(map(lambda x: str(utils.item(x)), alignment))
' '.join(map(
lambda x: '{:.4f}'.format(x),
hypo['positional_scores'].tolist(),
))
))
# Score only the top hypothesis
if has_target and i == 0:
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(
target_str, tgt_dict, add_if_not_exist=True)
if hasattr(scorer, 'add_string'):
scorer.add_string(target_str, hypo_str)
else:
scorer.add(target_tokens, hypo_tokens)
if args.print_alignment:
print('A-{}\t{}'.format(
sample_id,
' '.join(map(lambda x: str(utils.item(x)), alignment))
))
wps_meter.update(src_tokens.size(0))
# Score only the top hypothesis
if has_target and i == 0:
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(
target_str, tgt_dict, add_if_not_exist=True)
if hasattr(scorer, 'add_string'):
scorer.add_string(target_str, hypo_str)
else:
scorer.add(target_tokens, hypo_tokens)
wps_meter.update(num_generated_tokens)
t.log({'wps': round(wps_meter.avg)})
num_sentences += 1
num_sentences += sample['nsentences']
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))

View File

@ -13,24 +13,24 @@ from collections import namedtuple
import fileinput
import sys
import numpy as np
import torch
from fairseq import data, options, tasks, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator
from fairseq.utils import import_user_module
Batch = namedtuple('Batch', 'srcs tokens lengths')
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
def buffered_read(input, buffer_size):
buffer = []
for src_str in fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")):
buffer.append(src_str.strip())
if len(buffer) >= buffer_size:
yield buffer
buffer = []
with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h:
for src_str in h:
buffer.append(src_str.strip())
if len(buffer) >= buffer_size:
yield buffer
buffer = []
if len(buffer) > 0:
yield buffer
@ -41,7 +41,7 @@ def make_batches(lines, args, task, max_positions):
tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long()
for src_str in lines
]
lengths = np.array([t.numel() for t in tokens])
lengths = torch.LongTensor([t.numel() for t in tokens])
itr = task.get_batch_iterator(
dataset=task.build_dataset_for_inference(tokens, lengths),
max_tokens=args.max_tokens,
@ -50,10 +50,9 @@ def make_batches(lines, args, task, max_positions):
).next_epoch_itr(shuffle=False)
for batch in itr:
yield Batch(
srcs=[lines[i] for i in batch['id']],
tokens=batch['net_input']['src_tokens'],
lengths=batch['net_input']['src_lengths'],
), batch['id']
ids=batch['id'],
src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'],
)
def main(args):
@ -83,6 +82,7 @@ def main(args):
)
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Optimize ensemble for generation
@ -93,71 +93,16 @@ def main(args):
)
if args.fp16:
model.half()
if use_cuda:
model.cuda()
# Initialize generator
translator = SequenceGenerator(
models, tgt_dict, beam_size=args.beam, minlen=args.min_len,
stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
)
if use_cuda:
translator.cuda()
generator = task.build_generator(args)
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
def make_result(src_str, hypos):
result = Translation(
src_str='O\t{}'.format(src_str),
hypos=[],
pos_scores=[],
alignments=[],
)
# Process top predictions
for hypo in hypos[:min(len(hypos), args.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
result.pos_scores.append('P\t{}'.format(
' '.join(map(
lambda x: '{:.4f}'.format(x),
hypo['positional_scores'].tolist(),
))
))
result.alignments.append(
'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
if args.print_alignment else None
)
return result
def process_batch(batch):
tokens = batch.tokens
lengths = batch.lengths
if use_cuda:
tokens = tokens.cuda()
lengths = lengths.cuda()
encoder_input = {'src_tokens': tokens, 'src_lengths': lengths}
translations = translator.generate(
encoder_input,
maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
)
return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]
max_positions = utils.resolve_max_positions(
task.max_positions(),
*[model.max_positions() for model in models]
@ -166,21 +111,55 @@ def main(args):
if args.buffer_size > 1:
print('| Sentence buffer size:', args.buffer_size)
print('| Type the input sentence and press return:')
start_id = 0
for inputs in buffered_read(args.input, args.buffer_size):
indices = []
results = []
for batch, batch_indices in make_batches(inputs, args, task, max_positions):
indices.extend(batch_indices)
results.extend(process_batch(batch))
for batch in make_batches(inputs, args, task, max_positions):
src_tokens = batch.src_tokens
src_lengths = batch.src_lengths
if use_cuda:
src_tokens = src_tokens.cuda()
src_lengths = src_lengths.cuda()
for i in np.argsort(indices):
result = results[i]
print(result.src_str)
for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
print(hypo)
print(pos_scores)
if align is not None:
print(align)
sample = {
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
},
}
translations = task.inference_step(generator, models, sample)
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
results.append((start_id + id, src_tokens_i, hypos))
# sort output to match input order
for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
src_str = src_dict.string(src_tokens, args.remove_bpe)
print('S-{}\t{}'.format(id, src_str))
# Process top predictions
for hypo in hypos[:min(len(hypos), args.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
print('P-{}\t{}'.format(
id,
' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
))
if args.print_alignment:
print('A-{}\t{}'.format(
id,
' '.join(map(lambda x: str(utils.item(x)), alignment))
))
# update running id counter
start_id += len(results)
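For a single ad-hoc sentence the batching machinery above can be bypassed entirely. A sketch under that assumption (the helper name is made up; the tokenizer call and inference_step usage follow this file):

```python
import torch
from fairseq import tokenizer, utils

def translate_line(line, task, models, generator, use_cuda=False):
    tokens = tokenizer.Tokenizer.tokenize(
        line, task.source_dictionary, add_if_not_exist=False,
    ).long()
    sample = {'net_input': {
        'src_tokens': tokens.unsqueeze(0),                # (1, src_len) single-sentence "batch"
        'src_lengths': torch.LongTensor([tokens.numel()]),
    }}
    sample = utils.move_to_cuda(sample) if use_cuda else sample
    hypos = task.inference_step(generator, models, sample)
    return hypos[0]                                       # hypotheses for the one input sentence
```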
def cli_main():

View File

@ -42,7 +42,8 @@ setup(
install_requires=[
'cffi',
'numpy',
'torch',
# don't include torch, to support both release and nightly builds
#'torch',
'tqdm',
],
packages=find_packages(exclude=['scripts', 'tests']),

View File

@ -44,14 +44,13 @@ class TestBacktranslationDataset(unittest.TestCase):
)
generator = SequenceGenerator(
models=[self.model],
tgt_dict=self.tgt_dict,
max_len_a=0,
max_len_b=200,
beam_size=2,
unk_penalty=0,
sampling=False,
)
if self.cuda:
generator.cuda()
backtranslation_dataset = BacktranslationDataset(
tgt_dataset=TransformEosDataset(
@ -60,9 +59,9 @@ class TestBacktranslationDataset(unittest.TestCase):
# remove eos from the input src
remove_eos_from_src=remove_eos_from_input_src,
),
backtranslation_fn=generator.generate,
max_len_a=0,
max_len_b=200,
backtranslation_fn=(
lambda net_input: generator.generate([self.model], {'net_input': net_input})
),
output_collater=TransformEosDataset(
dataset=tgt_dataset,
eos=self.tgt_dict.eos(),

View File

@ -21,13 +21,15 @@ class TestSequenceGenerator(unittest.TestCase):
self.tgt_dict, self.w1, self.w2, src_tokens, src_lengths, self.model = (
test_utils.sequence_generator_setup()
)
self.encoder_input = {
'src_tokens': src_tokens, 'src_lengths': src_lengths,
self.sample = {
'net_input': {
'src_tokens': src_tokens, 'src_lengths': src_lengths,
},
}
def test_with_normalization(self):
generator = SequenceGenerator([self.model], self.tgt_dict)
hypos = generator.generate(self.encoder_input, beam_size=2)
generator = SequenceGenerator(self.tgt_dict, beam_size=2)
hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
@ -45,8 +47,8 @@ class TestSequenceGenerator(unittest.TestCase):
def test_without_normalization(self):
# Sentence 1: unchanged from the normalized case
# Sentence 2: beams swap order
generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
hypos = generator.generate(self.encoder_input, beam_size=2)
generator = SequenceGenerator(self.tgt_dict, beam_size=2, normalize_scores=False)
hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
@ -63,8 +65,8 @@ class TestSequenceGenerator(unittest.TestCase):
def test_with_lenpen_favoring_short_hypos(self):
lenpen = 0.6
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.encoder_input, beam_size=2)
generator = SequenceGenerator(self.tgt_dict, beam_size=2, len_penalty=lenpen)
hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
@ -81,8 +83,8 @@ class TestSequenceGenerator(unittest.TestCase):
def test_with_lenpen_favoring_long_hypos(self):
lenpen = 5.0
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.encoder_input, beam_size=2)
generator = SequenceGenerator(self.tgt_dict, beam_size=2, len_penalty=lenpen)
hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
@ -98,8 +100,8 @@ class TestSequenceGenerator(unittest.TestCase):
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)
def test_maxlen(self):
generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
hypos = generator.generate(self.encoder_input, beam_size=2)
generator = SequenceGenerator(self.tgt_dict, beam_size=2, max_len_b=2)
hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
@ -115,8 +117,8 @@ class TestSequenceGenerator(unittest.TestCase):
self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])
def test_no_stop_early(self):
generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
hypos = generator.generate(self.encoder_input, beam_size=2)
generator = SequenceGenerator(self.tgt_dict, stop_early=False, beam_size=2)
hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
@ -212,11 +214,10 @@ class TestDiverseBeamSearch(unittest.TestCase):
def test_diverse_beam_search(self):
generator = SequenceGenerator(
[self.model], self.tgt_dict,
beam_size=2, diverse_beam_groups=2, diverse_beam_strength=0.,
self.tgt_dict, beam_size=2, diverse_beam_groups=2, diverse_beam_strength=0.,
)
encoder_input = {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths}
hypos = generator.generate(encoder_input)
sample = {'net_input': {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths}}
hypos = generator.generate([self.model], sample)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, w1, eos])

View File

@ -85,10 +85,12 @@ class TestSequenceScorer(unittest.TestCase):
task = test_utils.TestTranslationTask.setup_task(args, d, d)
model = task.build_model(args)
scorer = SequenceScorer([model], task.target_dictionary)
for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
self.assertHypoTokens(hypos[0], data[id]['target'])
self.assertHypoScore(hypos[0], expected_scores[id])
scorer = SequenceScorer(task.target_dictionary)
for sample in data_itr:
hypos = task.inference_step(scorer, [model], sample)
for id, hypos_id in zip(sample['id'].tolist(), hypos):
self.assertHypoTokens(hypos_id[0], data[id]['target'])
self.assertHypoScore(hypos_id[0], expected_scores[id])
def assertHypoTokens(self, hypo, tokens):
self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))