Modularize generate.py (#351)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/351

This makes it easier for tasks to plug into generate.py/interactive.py.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/520

Differential Revision: D14183881

Pulled By: myleott

fbshipit-source-id: ede5e53ddc1215ed3b12b8f1eba048c946913c33
Authored by Myle Ott on 2019-02-22 10:06:22 -08:00, committed by Facebook Github Bot
parent 08e866f977
commit b65c579bed
11 changed files with 371 additions and 398 deletions
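
In short, the model ensemble is no longer bound to SequenceGenerator/SequenceScorer at construction time: it is passed to generate() on every call, and the scripts obtain their generator from the task. A rough sketch of the new calling convention, assuming the task, models, args and batch iterator have already been loaded by the usual checkpoint utilities:

from fairseq import utils

def translate_all(task, models, args, itr, use_cuda=True):
    """Sketch of the new task-level generation loop introduced by this commit."""
    # The task now owns generator construction: a SequenceScorer when
    # args.score_reference is set, otherwise a SequenceGenerator.
    generator = task.build_generator(args)
    results = []
    for sample in itr:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue
        # models are passed per call instead of living inside the generator
        hypos = task.inference_step(generator, models, sample)
        for sent_hypos in hypos:
            results.append(sent_hypos[0])  # keep the top hypothesis per sentence
    return results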

View File

@ -76,6 +76,8 @@ def main(parsed_args):
model.make_generation_fast_() model.make_generation_fast_()
if args.fp16: if args.fp16:
model.half() model.half()
if use_cuda:
model.cuda()
assert len(models) > 0 assert len(models) > 0
@ -95,9 +97,7 @@ def main(parsed_args):
).next_epoch_itr(shuffle=False) ).next_epoch_itr(shuffle=False)
gen_timer = StopwatchMeter() gen_timer = StopwatchMeter()
scorer = SequenceScorer(models, task.target_dictionary) scorer = SequenceScorer(task.target_dictionary)
if use_cuda:
scorer.cuda()
score_sum = 0. score_sum = 0.
count = 0 count = 0
@ -113,10 +113,18 @@ def main(parsed_args):
word_stats = dict() word_stats = dict()
with progress_bar.build_progress_bar(args, itr) as t: with progress_bar.build_progress_bar(args, itr) as t:
results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
wps_meter = TimeMeter() wps_meter = TimeMeter()
for _, src_tokens, __, hypos in results: for sample in t:
for hypo in hypos: sample = utils.move_to_cuda(sample) if use_cuda else sample
if 'net_input' not in sample:
continue
gen_timer.start()
hypos = scorer.generate(models, sample)
gen_timer.stop(sample['ntokens'])
for hypos_i in hypos:
hypo = hypos_i[0]
pos_scores = hypo['positional_scores'] pos_scores = hypo['positional_scores']
skipped_toks = 0 skipped_toks = 0
@ -162,7 +170,7 @@ def main(parsed_args):
if args.output_word_probs: if args.output_word_probs:
print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob)) print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
wps_meter.update(src_tokens.size(0)) wps_meter.update(sample['ntokens'])
t.log({'wps': round(wps_meter.avg)}) t.log({'wps': round(wps_meter.avg)})
avg_nll_loss = -score_sum / count avg_nll_loss = -score_sum / count

View File

@ -69,9 +69,6 @@ class BacktranslationDataset(FairseqDataset):
backtranslation_fn (callable): function to call to generate backtranslation_fn (callable): function to call to generate
backtranslations. This is typically the `generate` method of a backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object. :class:`~fairseq.sequence_generator.SequenceGenerator` object.
max_len_a, max_len_b (int, int): will be used to compute
`maxlen = max_len_a * src_len + max_len_b`, which will be passed
into *backtranslation_fn*.
output_collater (callable, optional): function to call on the output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch backtranslated samples to create the final batch
(default: ``tgt_dataset.collater``). (default: ``tgt_dataset.collater``).
@ -82,16 +79,12 @@ class BacktranslationDataset(FairseqDataset):
self, self,
tgt_dataset, tgt_dataset,
backtranslation_fn, backtranslation_fn,
max_len_a,
max_len_b,
output_collater=None, output_collater=None,
cuda=True, cuda=True,
**kwargs **kwargs
): ):
self.tgt_dataset = tgt_dataset self.tgt_dataset = tgt_dataset
self.backtranslation_fn = backtranslation_fn self.backtranslation_fn = backtranslation_fn
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.output_collater = output_collater if output_collater is not None \ self.output_collater = output_collater if output_collater is not None \
else tgt_dataset.collater else tgt_dataset.collater
self.cuda = cuda if torch.cuda.is_available() else False self.cuda = cuda if torch.cuda.is_available() else False
@ -130,12 +123,7 @@ class BacktranslationDataset(FairseqDataset):
samples=samples, samples=samples,
collate_fn=self.tgt_dataset.collater, collate_fn=self.tgt_dataset.collater,
generate_fn=( generate_fn=(
lambda net_input: self.backtranslation_fn( lambda net_input: self.backtranslation_fn(net_input)
net_input,
maxlen=int(
self.max_len_a * net_input['src_tokens'].size(1) + self.max_len_b
),
)
), ),
cuda=self.cuda, cuda=self.cuda,
) )
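
Since the length bounds now live inside the generator, callers of BacktranslationDataset pass a plain closure as backtranslation_fn, mirroring the updated test further below. A rough sketch, assuming BacktranslationDataset is importable from fairseq.data and that model, tgt_dataset and tgt_dict are already built:

from fairseq.data import BacktranslationDataset
from fairseq.sequence_generator import SequenceGenerator

def build_backtranslation_dataset(model, tgt_dataset, tgt_dict):
    """Sketch: length bounds now live in SequenceGenerator, so the
    backtranslation_fn is just a closure over generator.generate."""
    generator = SequenceGenerator(
        tgt_dict,
        beam_size=2,
        max_len_a=0,     # max output length = 0 * src_len + 200
        max_len_b=200,
    )
    return BacktranslationDataset(
        tgt_dataset=tgt_dataset,
        backtranslation_fn=(
            lambda net_input: generator.generate([model], {'net_input': net_input})
        ),
    )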

View File

@ -15,18 +15,34 @@ from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object): class SequenceGenerator(object):
def __init__( def __init__(
self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True, self,
normalize_scores=True, len_penalty=1., unk_penalty=0., retain_dropout=False, tgt_dict,
sampling=False, sampling_topk=-1, sampling_temperature=1., beam_size=1,
diverse_beam_groups=-1, diverse_beam_strength=0.5, max_len_a=0,
match_source_len=False, no_repeat_ngram_size=0 max_len_b=200,
min_len=1,
stop_early=True,
normalize_scores=True,
len_penalty=1.,
unk_penalty=0.,
retain_dropout=False,
sampling=False,
sampling_topk=-1,
sampling_temperature=1.,
diverse_beam_groups=-1,
diverse_beam_strength=0.5,
match_source_len=False,
no_repeat_ngram_size=0,
): ):
"""Generates translations of a given source sentence. """Generates translations of a given source sentence.
Args: Args:
tgt_dict (~fairseq.data.Dictionary): target dictionary
beam_size (int, optional): beam width (default: 1) beam_size (int, optional): beam width (default: 1)
min/maxlen (int, optional): the length of the generated output will max_len_a/b (int, optional): generate sequences of maximum length
be bounded by minlen and maxlen (not including end-of-sentence) ax + b, where x is the source length
min_len (int, optional): the minimum length of the generated output
(not including end-of-sentence)
stop_early (bool, optional): stop generation immediately after we stop_early (bool, optional): stop generation immediately after we
finalize beam_size hypotheses, even though longer hypotheses finalize beam_size hypotheses, even though longer hypotheses
might have better normalized scores (default: True) might have better normalized scores (default: True)
@ -50,16 +66,16 @@ class SequenceGenerator(object):
match_source_len (bool, optional): outputs should match the source match_source_len (bool, optional): outputs should match the source
length (default: False) length (default: False)
""" """
self.models = models
self.pad = tgt_dict.pad() self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk() self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos() self.eos = tgt_dict.eos()
self.vocab_size = len(tgt_dict) self.vocab_size = len(tgt_dict)
self.beam_size = beam_size self.beam_size = beam_size
self.minlen = minlen # the max beam size is the dictionary size - 1, since we never select pad
max_decoder_len = min(m.max_decoder_positions() for m in self.models) self.beam_size = min(beam_size, self.vocab_size - 1)
max_decoder_len -= 1 # we define maxlen not including the EOS marker self.max_len_a = max_len_a
self.maxlen = max_decoder_len if maxlen is None else min(maxlen, max_decoder_len) self.max_len_b = max_len_b
self.min_len = min_len
self.stop_early = stop_early self.stop_early = stop_early
self.normalize_scores = normalize_scores self.normalize_scores = normalize_scores
self.len_penalty = len_penalty self.len_penalty = len_penalty
@ -81,109 +97,51 @@ class SequenceGenerator(object):
else: else:
self.search = search.BeamSearch(tgt_dict) self.search = search.BeamSearch(tgt_dict)
def cuda(self): @torch.no_grad()
for model in self.models: def generate(self, models, sample=None, net_input=None, prefix_tokens=None, **kwargs):
model.cuda()
return self
def generate_batched_itr(
self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
cuda=False, timer=None, prefix_size=0,
):
"""Iterate over a batched dataset and yield individual translations.
Args:
maxlen_a/b (int, optional): generate sequences of maximum length
``ax + b``, where ``x`` is the source sentence length.
cuda (bool, optional): use GPU for generation
timer (StopwatchMeter, optional): time generations
prefix_size (int, optional): prefill the generation with the gold
prefix up to this length.
"""
if maxlen_b is None:
maxlen_b = self.maxlen
for sample in data_itr:
s = utils.move_to_cuda(sample) if cuda else sample
if 'net_input' not in s:
continue
input = s['net_input']
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in input.items()
if k != 'prev_output_tokens'
}
srclen = encoder_input['src_tokens'].size(1)
if timer is not None:
timer.start()
with torch.no_grad():
hypos = self.generate(
encoder_input,
beam_size=beam_size,
maxlen=int(maxlen_a*srclen + maxlen_b),
prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
)
if timer is not None:
timer.stop(sum(len(h[0]['tokens']) for h in hypos))
for i, id in enumerate(s['id'].data):
# remove padding
src = utils.strip_pad(input['src_tokens'].data[i, :], self.pad)
ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
yield id, src, ref, hypos[i]
def generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
"""Generate a batch of translations. """Generate a batch of translations.
Args: Args:
encoder_input (dict): dictionary containing the inputs to models (List[~fairseq.models.FairseqModel]): ensemble of models
*model.encoder.forward*. sample (dict): batch
beam_size (int, optional): overriding the beam size prefix_tokens (torch.LongTensor, optional): force decoder to begin
(default: *self.beam_size*). with these tokens
max_len (int, optional): maximum length of the generated sequence
prefix_tokens (LongTensor, optional): force decoder to begin with
these tokens
""" """
with torch.no_grad(): model = EnsembleModel(models)
return self._generate(encoder_input, beam_size, maxlen, prefix_tokens) if not self.retain_dropout:
model.eval()
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample['net_input'].items()
if k != 'prev_output_tokens'
}
def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
"""See generate"""
src_tokens = encoder_input['src_tokens'] src_tokens = encoder_input['src_tokens']
src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
bsz, srclen = src_tokens.size() bsz, src_len = src_tokens.size()
maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen beam_size = self.beam_size
if self.match_source_len: if self.match_source_len:
maxlen = src_lengths.max().item() max_len = src_lengths.max().item()
else:
max_len = min(
int(self.max_len_a * src_len + self.max_len_b),
# exclude the EOS marker
model.max_decoder_positions() - 1,
)
# the max beam size is the dictionary size - 1, since we never select pad # compute the encoder output for each beam
beam_size = beam_size if beam_size is not None else self.beam_size encoder_outs = model.forward_encoder(encoder_input)
beam_size = min(beam_size, self.vocab_size - 1) new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long()
encoder_outs = [] encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
incremental_states = {}
for model in self.models:
if not self.retain_dropout:
model.eval()
if isinstance(model.decoder, FairseqIncrementalDecoder):
incremental_states[model] = {}
else:
incremental_states[model] = None
# compute the encoder output for each beam
if hasattr(model, 'encoder'):
encoder_out = model.encoder(**encoder_input)
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long()
encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
else:
encoder_out = None
encoder_outs.append(encoder_out)
# initialize buffers # initialize buffers
scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0) scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
scores_buf = scores.clone() scores_buf = scores.clone()
tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad) tokens = src_tokens.new(bsz * beam_size, max_len + 2).fill_(self.pad)
tokens_buf = tokens.clone() tokens_buf = tokens.clone()
tokens[:, 0] = self.eos tokens[:, 0] = self.eos
attn, attn_buf = None, None attn, attn_buf = None, None
@ -218,13 +176,13 @@ class SequenceGenerator(object):
""" """
assert len(finalized[sent]) <= beam_size assert len(finalized[sent]) <= beam_size
if len(finalized[sent]) == beam_size: if len(finalized[sent]) == beam_size:
if self.stop_early or step == maxlen or unfinalized_scores is None: if self.stop_early or step == max_len or unfinalized_scores is None:
return True return True
# stop if the best unfinalized score is worse than the worst # stop if the best unfinalized score is worse than the worst
# finalized one # finalized one
best_unfinalized_score = unfinalized_scores[sent].max() best_unfinalized_score = unfinalized_scores[sent].max()
if self.normalize_scores: if self.normalize_scores:
best_unfinalized_score /= maxlen ** self.len_penalty best_unfinalized_score /= max_len ** self.len_penalty
if worst_finalized[sent]['score'] >= best_unfinalized_score: if worst_finalized[sent]['score'] >= best_unfinalized_score:
return True return True
return False return False
@ -326,20 +284,17 @@ class SequenceGenerator(object):
reorder_state = None reorder_state = None
batch_idxs = None batch_idxs = None
for step in range(maxlen + 1): # one extra step for EOS marker for step in range(max_len + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams # reorder decoder internal states based on the prev choice of beams
if reorder_state is not None: if reorder_state is not None:
if batch_idxs is not None: if batch_idxs is not None:
# update beam indices to take into account removed sentences # update beam indices to take into account removed sentences
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs) corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size) reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
for i, model in enumerate(self.models): model.reorder_incremental_state(reorder_state)
if isinstance(model.decoder, FairseqIncrementalDecoder): model.reorder_encoder_out(encoder_outs, reorder_state)
model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
if encoder_outs is not None and hasattr(model, 'encoder'):
encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
lprobs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states) lprobs, avg_attn_scores = model.forward_decoder(tokens[:, :step + 1], encoder_outs)
lprobs[:, self.pad] = -math.inf # never select pad lprobs[:, self.pad] = -math.inf # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
@ -356,7 +311,7 @@ class SequenceGenerator(object):
# Record attention scores # Record attention scores
if avg_attn_scores is not None: if avg_attn_scores is not None:
if attn is None: if attn is None:
attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2) attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
attn_buf = attn.clone() attn_buf = attn.clone()
nonpad_idxs = src_tokens.ne(self.pad) nonpad_idxs = src_tokens.ne(self.pad)
attn[:, :, step + 1].copy_(avg_attn_scores) attn[:, :, step + 1].copy_(avg_attn_scores)
@ -365,7 +320,7 @@ class SequenceGenerator(object):
scores_buf = scores_buf.type_as(lprobs) scores_buf = scores_buf.type_as(lprobs)
eos_bbsz_idx = buffer('eos_bbsz_idx') eos_bbsz_idx = buffer('eos_bbsz_idx')
eos_scores = buffer('eos_scores', type_of=scores) eos_scores = buffer('eos_scores', type_of=scores)
if step < maxlen: if step < max_len:
self.search.set_src_lengths(src_lengths) self.search.set_src_lengths(src_lengths)
if self.no_repeat_ngram_size > 0: if self.no_repeat_ngram_size > 0:
@ -387,9 +342,9 @@ class SequenceGenerator(object):
probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :] probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
cand_scores = torch.gather( cand_scores = torch.gather(
probs_slice, dim=1, probs_slice, dim=1,
index=prefix_tokens[:, step].view(-1, 1).data index=prefix_tokens[:, step].view(-1, 1)
).expand(-1, cand_size) ).expand(-1, cand_size)
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size)
cand_beams = torch.zeros_like(cand_indices) cand_beams = torch.zeros_like(cand_indices)
else: else:
cand_scores, cand_indices, cand_beams = self.search.step( cand_scores, cand_indices, cand_beams = self.search.step(
@ -401,7 +356,7 @@ class SequenceGenerator(object):
# make probs contain cumulative scores for each hypothesis # make probs contain cumulative scores for each hypothesis
lprobs.add_(scores[:, step - 1].unsqueeze(-1)) lprobs.add_(scores[:, step - 1].unsqueeze(-1))
# finalize all active hypotheses once we hit maxlen # finalize all active hypotheses once we hit max_len
# pick the hypothesis with the highest prob of EOS right now # pick the hypothesis with the highest prob of EOS right now
torch.sort( torch.sort(
lprobs[:, self.eos], lprobs[:, self.eos],
@ -421,7 +376,7 @@ class SequenceGenerator(object):
eos_mask = cand_indices.eq(self.eos) eos_mask = cand_indices.eq(self.eos)
finalized_sents = set() finalized_sents = set()
if step >= self.minlen: if step >= self.min_len:
# only consider eos when it's among the top beam_size indices # only consider eos when it's among the top beam_size indices
torch.masked_select( torch.masked_select(
cand_bbsz_idx[:, :beam_size], cand_bbsz_idx[:, :beam_size],
@ -440,7 +395,7 @@ class SequenceGenerator(object):
assert num_remaining_sent >= 0 assert num_remaining_sent >= 0
if num_remaining_sent == 0: if num_remaining_sent == 0:
break break
assert step < maxlen assert step < max_len
if len(finalized_sents) > 0: if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents) new_bsz = bsz - len(finalized_sents)
@ -543,14 +498,38 @@ class SequenceGenerator(object):
return finalized return finalized
def _decode(self, tokens, encoder_outs, incremental_states):
class EnsembleModel(torch.nn.Module):
"""A wrapper around an ensemble of models."""
def __init__(self, models):
super().__init__()
self.models = torch.nn.ModuleList(models)
self.incremental_states = None
if all(isinstance(m.decoder, FairseqIncrementalDecoder) for m in models):
self.incremental_states = {m: {} for m in models}
def has_encoder(self):
return hasattr(self.models[0], 'encoder')
def max_decoder_positions(self):
return min(m.max_decoder_positions() for m in self.models)
@torch.no_grad()
def forward_encoder(self, encoder_input):
if not self.has_encoder():
return None
return [model.encoder(**encoder_input) for model in self.models]
@torch.no_grad()
def forward_decoder(self, tokens, encoder_outs):
if len(self.models) == 1: if len(self.models) == 1:
return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True) return self._decode_one(tokens, self.models[0], encoder_outs[0], self.incremental_states, log_probs=True)
log_probs = [] log_probs = []
avg_attn = None avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs): for model, encoder_out in zip(self.models, encoder_outs):
probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=True) probs, attn = self._decode_one(tokens, model, encoder_out, self.incremental_states, log_probs=True)
log_probs.append(probs) log_probs.append(probs)
if attn is not None: if attn is not None:
if avg_attn is None: if avg_attn is None:
@ -563,19 +542,32 @@ class SequenceGenerator(object):
return avg_probs, avg_attn return avg_probs, avg_attn
def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs): def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
with torch.no_grad(): if self.incremental_states is not None:
if incremental_states[model] is not None: decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=self.incremental_states[model]))
decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model])) else:
else: decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out = list(model.decoder(tokens, encoder_out)) decoder_out[0] = decoder_out[0][:, -1:, :]
decoder_out[0] = decoder_out[0][:, -1:, :] attn = decoder_out[1]
attn = decoder_out[1] if type(attn) is dict:
attn = attn['attn']
if attn is not None:
if type(attn) is dict: if type(attn) is dict:
attn = attn['attn'] attn = attn['attn']
if attn is not None: attn = attn[:, -1, :]
if type(attn) is dict:
attn = attn['attn']
attn = attn[:, -1, :]
probs = model.get_normalized_probs(decoder_out, log_probs=log_probs) probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
probs = probs[:, -1, :] probs = probs[:, -1, :]
return probs, attn return probs, attn
def reorder_encoder_out(self, encoder_outs, new_order):
if not self.has_encoder():
return
return [
model.encoder.reorder_encoder_out(encoder_out, new_order)
for model, encoder_out in zip(self.models, encoder_outs)
]
def reorder_incremental_state(self, new_order):
if self.incremental_states is None:
return
for model in self.models:
model.decoder.reorder_incremental_state(self.incremental_states[model], new_order)
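
The net effect of these changes: the model list moves from __init__ to generate(), the maximum output length is computed per batch as max_len_a * src_len + max_len_b (capped by the ensemble's max decoder positions), and the per-model encoder/decoder bookkeeping is folded into the EnsembleModel wrapper. A small sketch of the new call pattern, assuming models, tgt_dict and a batch sample are given:

from fairseq.sequence_generator import SequenceGenerator

def beam_search(models, tgt_dict, sample):
    """Sketch of generation with the refactored SequenceGenerator."""
    generator = SequenceGenerator(
        tgt_dict,
        beam_size=5,
        max_len_a=0,      # output length bound: 0 * src_len + 200 tokens
        max_len_b=200,
        min_len=1,
    )
    # sample is an ordinary batch dict with a 'net_input' entry; the model
    # list is wrapped in an EnsembleModel internally for this call only.
    hypos = generator.generate(models, sample)
    # hypos[i][j] is the j-th best hypothesis for sentence i, a dict with
    # 'tokens', 'score', 'attention', 'alignment' and 'positional_scores'.
    return [sent_hypos[0]['tokens'] for sent_hypos in hypos]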

View File

@ -13,60 +13,23 @@ from fairseq import utils
class SequenceScorer(object): class SequenceScorer(object):
"""Scores the target for a given source sentence.""" """Scores the target for a given source sentence."""
def __init__(self, models, tgt_dict): def __init__(self, tgt_dict):
self.models = models
self.pad = tgt_dict.pad() self.pad = tgt_dict.pad()
def cuda(self): @torch.no_grad()
for model in self.models: def generate(self, models, sample, **kwargs):
model.cuda()
return self
def score_batched_itr(self, data_itr, cuda=False, timer=None):
"""Iterate over a batched dataset and yield scored translations."""
for sample in data_itr:
s = utils.move_to_cuda(sample) if cuda else sample
if timer is not None:
timer.start()
pos_scores, attn = self.score(s)
for i, id in enumerate(s['id'].data):
# remove padding from ref
src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
tgt_len = ref.numel()
pos_scores_i = pos_scores[i][:tgt_len]
score_i = pos_scores_i.sum() / tgt_len
if attn is not None:
attn_i = attn[i]
_, alignment = attn_i.max(dim=0)
else:
attn_i = alignment = None
hypos = [{
'tokens': ref,
'score': score_i,
'attention': attn_i,
'alignment': alignment,
'positional_scores': pos_scores_i,
}]
if timer is not None:
timer.stop(s['ntokens'])
# return results in the same format as SequenceGenerator
yield id, src, ref, hypos
def score(self, sample):
"""Score a batch of translations.""" """Score a batch of translations."""
net_input = sample['net_input'] net_input = sample['net_input']
# compute scores for each model in the ensemble # compute scores for each model in the ensemble
avg_probs = None avg_probs = None
avg_attn = None avg_attn = None
for model in self.models: for model in models:
with torch.no_grad(): model.eval()
model.eval() decoder_out = model.forward(**net_input)
decoder_out = model.forward(**net_input) attn = decoder_out[1]
attn = decoder_out[1]
probs = model.get_normalized_probs(decoder_out, log_probs=len(self.models) == 1, sample=sample).data probs = model.get_normalized_probs(decoder_out, log_probs=len(models) == 1, sample=sample)
if avg_probs is None: if avg_probs is None:
avg_probs = probs avg_probs = probs
else: else:
@ -77,13 +40,33 @@ class SequenceScorer(object):
avg_attn = attn avg_attn = attn
else: else:
avg_attn.add_(attn) avg_attn.add_(attn)
if len(self.models) > 1: if len(models) > 1:
avg_probs.div_(len(self.models)) avg_probs.div_(len(models))
avg_probs.log_() avg_probs.log_()
if avg_attn is not None: if avg_attn is not None:
avg_attn.div_(len(self.models)) avg_attn.div_(len(models))
avg_probs = avg_probs.gather( avg_probs = avg_probs.gather(
dim=2, dim=2,
index=sample['target'].data.unsqueeze(-1), index=sample['target'].unsqueeze(-1),
) ).squeeze(2)
return avg_probs.squeeze(2), avg_attn
hypos = []
for i in range(avg_probs.size(0)):
# remove padding from ref
ref = utils.strip_pad(sample['target'][i, :], self.pad) if sample['target'] is not None else None
tgt_len = ref.numel()
avg_probs_i = avg_probs[i][:tgt_len]
score_i = avg_probs_i.sum() / tgt_len
if avg_attn is not None:
avg_attn_i = avg_attn[i]
_, alignment = avg_attn_i.max(dim=0)
else:
avg_attn_i = alignment = None
hypos.append([{
'tokens': ref,
'score': score_i,
'attention': avg_attn_i,
'alignment': alignment,
'positional_scores': avg_probs_i,
}])
return hypos
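
SequenceScorer now mirrors the generator interface: it is built from the target dictionary alone, and generate(models, sample) returns one single-hypothesis list per sentence in the same format SequenceGenerator produces, which is what lets eval_lm.py above treat both uniformly. A rough sketch of turning that output into a per-sentence negative log-likelihood, assuming models, tgt_dict and sample are given:

from fairseq.sequence_scorer import SequenceScorer

def score_references(models, tgt_dict, sample):
    """Sketch: score the reference targets in `sample` with a model ensemble."""
    scorer = SequenceScorer(tgt_dict)
    hypos = scorer.generate(models, sample)
    nll_per_sentence = []
    for sent_hypos in hypos:
        hypo = sent_hypos[0]  # the scorer emits one "hypothesis": the reference
        # 'positional_scores' holds per-token log-probabilities of the reference
        nll_per_sentence.append(-hypo['positional_scores'].mean().item())
    return nll_per_sentence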

View File

@ -180,6 +180,31 @@ class FairseqTask(object):
from fairseq import criterions from fairseq import criterions
return criterions.build_criterion(args, self) return criterions.build_criterion(args, self)
def build_generator(self, args):
if args.score_reference:
from fairseq.sequence_scorer import SequenceScorer
return SequenceScorer(self.target_dictionary)
else:
from fairseq.sequence_generator import SequenceGenerator
return SequenceGenerator(
self.target_dictionary,
beam_size=args.beam,
max_len_a=args.max_len_a,
max_len_b=args.max_len_b,
min_len=args.min_len,
stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen,
unk_penalty=args.unkpen,
sampling=args.sampling,
sampling_topk=args.sampling_topk,
sampling_temperature=args.sampling_temperature,
diverse_beam_groups=args.diverse_beam_groups,
diverse_beam_strength=args.diverse_beam_strength,
match_source_len=args.match_source_len,
no_repeat_ngram_size=args.no_repeat_ngram_size,
)
def train_step(self, sample, model, criterion, optimizer, ignore_grad=False): def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
""" """
Do forward and backward, and return the loss as computed by *criterion* Do forward and backward, and return the loss as computed by *criterion*
@ -214,11 +239,9 @@ class FairseqTask(object):
loss, sample_size, logging_output = criterion(model, sample) loss, sample_size, logging_output = criterion(model, sample)
return loss, sample_size, logging_output return loss, sample_size, logging_output
def init_logging_output(self, sample): def inference_step(self, generator, models, sample, prefix_tokens=None):
return { with torch.no_grad():
'ntokens': sample['ntokens'] if sample is not None else 0, return generator.generate(models, sample, prefix_tokens=prefix_tokens)
'nsentences': sample['target'].size(0) if sample is not None else 0,
}
def grad_denom(self, sample_sizes, criterion): def grad_denom(self, sample_sizes, criterion):
return criterion.__class__.grad_denom(sample_sizes) return criterion.__class__.grad_denom(sample_sizes)
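
This hook is what the summary means by letting tasks plug into generate.py/interactive.py: a task can override build_generator (and, if needed, inference_step) to supply its own decoding strategy without the scripts knowing about it. A hypothetical sketch; MyConstrainedGenerator and the task name are made up, and a real task would also implement setup_task, dataset loading, and the dictionary properties:

from fairseq.sequence_generator import SequenceGenerator
from fairseq.tasks import FairseqTask, register_task

class MyConstrainedGenerator(SequenceGenerator):
    """Illustrative stand-in; a real task could implement any decoding logic."""
    pass

@register_task('my_translation')
class MyTranslationTask(FairseqTask):
    """Sketch of a task that supplies its own generator to generate.py."""

    def build_generator(self, args):
        # Anything exposing generate(models, sample, **kwargs) works here.
        return MyConstrainedGenerator(
            self.target_dictionary,
            beam_size=args.beam,
            max_len_a=args.max_len_a,
            max_len_b=args.max_len_b,
        )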

View File

@ -13,8 +13,6 @@ import torch
from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module from fairseq.utils import import_user_module
@ -59,6 +57,8 @@ def main(args):
) )
if args.fp16: if args.fp16:
model.half() model.half()
if use_cuda:
model.cuda()
# Load alignment dictionary for unknown word replacement # Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary) # (None if no unknown word replacement, empty if no path to align dictionary)
@ -82,20 +82,7 @@ def main(args):
# Initialize generator # Initialize generator
gen_timer = StopwatchMeter() gen_timer = StopwatchMeter()
if args.score_reference: generator = task.build_generator(args)
translator = SequenceScorer(models, task.target_dictionary)
else:
translator = SequenceGenerator(
models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
)
if use_cuda:
translator.cuda()
# Generate and compute BLEU score # Generate and compute BLEU score
if args.sacrebleu: if args.sacrebleu:
@ -105,79 +92,89 @@ def main(args):
num_sentences = 0 num_sentences = 0
has_target = True has_target = True
with progress_bar.build_progress_bar(args, itr) as t: with progress_bar.build_progress_bar(args, itr) as t:
if args.score_reference:
translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
else:
translations = translator.generate_batched_itr(
t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
)
wps_meter = TimeMeter() wps_meter = TimeMeter()
for sample_id, src_tokens, target_tokens, hypos in translations: for sample in t:
# Process input and ground truth sample = utils.move_to_cuda(sample) if use_cuda else sample
has_target = target_tokens is not None if 'net_input' not in sample:
target_tokens = target_tokens.int().cpu() if has_target else None continue
# Either retrieve the original sentences or regenerate them from tokens. prefix_tokens = None
if align_dict is not None: if args.prefix_size > 0:
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) prefix_tokens = sample['target'][:, :args.prefix_size]
target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
else: gen_timer.start()
if src_dict is not None: hypos = task.inference_step(generator, models, sample, prefix_tokens)
src_str = src_dict.string(src_tokens, args.remove_bpe) num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
gen_timer.stop(num_generated_tokens)
for i, sample_id in enumerate(sample['id'].tolist()):
has_target = sample['target'] is not None
# Remove padding
src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
target_tokens = None
if has_target:
target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
else: else:
src_str = "" if src_dict is not None:
if has_target: src_str = src_dict.string(src_tokens, args.remove_bpe)
target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True) else:
src_str = ""
if not args.quiet: if has_target:
if src_dict is not None: target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
print('S-{}\t{}'.format(sample_id, src_str))
if has_target:
print('T-{}\t{}'.format(sample_id, target_str))
# Process top predictions
for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
if not args.quiet: if not args.quiet:
print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str)) if src_dict is not None:
print('P-{}\t{}'.format( print('S-{}\t{}'.format(sample_id, src_str))
sample_id, if has_target:
' '.join(map( print('T-{}\t{}'.format(sample_id, target_str))
lambda x: '{:.4f}'.format(x),
hypo['positional_scores'].tolist(),
))
))
if args.print_alignment: # Process top predictions
print('A-{}\t{}'.format( for i, hypo in enumerate(hypos[i][:min(len(hypos), args.nbest)]):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
if not args.quiet:
print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
print('P-{}\t{}'.format(
sample_id, sample_id,
' '.join(map(lambda x: str(utils.item(x)), alignment)) ' '.join(map(
lambda x: '{:.4f}'.format(x),
hypo['positional_scores'].tolist(),
))
)) ))
# Score only the top hypothesis if args.print_alignment:
if has_target and i == 0: print('A-{}\t{}'.format(
if align_dict is not None or args.remove_bpe is not None: sample_id,
# Convert back to tokens for evaluation with unk replacement and/or without BPE ' '.join(map(lambda x: str(utils.item(x)), alignment))
target_tokens = tokenizer.Tokenizer.tokenize( ))
target_str, tgt_dict, add_if_not_exist=True)
if hasattr(scorer, 'add_string'):
scorer.add_string(target_str, hypo_str)
else:
scorer.add(target_tokens, hypo_tokens)
wps_meter.update(src_tokens.size(0)) # Score only the top hypothesis
if has_target and i == 0:
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(
target_str, tgt_dict, add_if_not_exist=True)
if hasattr(scorer, 'add_string'):
scorer.add_string(target_str, hypo_str)
else:
scorer.add(target_tokens, hypo_tokens)
wps_meter.update(num_generated_tokens)
t.log({'wps': round(wps_meter.avg)}) t.log({'wps': round(wps_meter.avg)})
num_sentences += 1 num_sentences += sample['nsentences']
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))

View File

@ -13,24 +13,24 @@ from collections import namedtuple
import fileinput import fileinput
import sys import sys
import numpy as np
import torch import torch
from fairseq import data, options, tasks, tokenizer, utils from fairseq import data, options, tasks, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator from fairseq.sequence_generator import SequenceGenerator
from fairseq.utils import import_user_module from fairseq.utils import import_user_module
Batch = namedtuple('Batch', 'srcs tokens lengths') Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments') Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
def buffered_read(input, buffer_size): def buffered_read(input, buffer_size):
buffer = [] buffer = []
for src_str in fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")): with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h:
buffer.append(src_str.strip()) for src_str in h:
if len(buffer) >= buffer_size: buffer.append(src_str.strip())
yield buffer if len(buffer) >= buffer_size:
buffer = [] yield buffer
buffer = []
if len(buffer) > 0: if len(buffer) > 0:
yield buffer yield buffer
@ -41,7 +41,7 @@ def make_batches(lines, args, task, max_positions):
tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long() tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long()
for src_str in lines for src_str in lines
] ]
lengths = np.array([t.numel() for t in tokens]) lengths = torch.LongTensor([t.numel() for t in tokens])
itr = task.get_batch_iterator( itr = task.get_batch_iterator(
dataset=task.build_dataset_for_inference(tokens, lengths), dataset=task.build_dataset_for_inference(tokens, lengths),
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
@ -50,10 +50,9 @@ def make_batches(lines, args, task, max_positions):
).next_epoch_itr(shuffle=False) ).next_epoch_itr(shuffle=False)
for batch in itr: for batch in itr:
yield Batch( yield Batch(
srcs=[lines[i] for i in batch['id']], ids=batch['id'],
tokens=batch['net_input']['src_tokens'], src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'],
lengths=batch['net_input']['src_lengths'], )
), batch['id']
def main(args): def main(args):
@ -83,6 +82,7 @@ def main(args):
) )
# Set dictionaries # Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary tgt_dict = task.target_dictionary
# Optimize ensemble for generation # Optimize ensemble for generation
@ -93,71 +93,16 @@ def main(args):
) )
if args.fp16: if args.fp16:
model.half() model.half()
if use_cuda:
model.cuda()
# Initialize generator # Initialize generator
translator = SequenceGenerator( generator = task.build_generator(args)
models, tgt_dict, beam_size=args.beam, minlen=args.min_len,
stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
)
if use_cuda:
translator.cuda()
# Load alignment dictionary for unknown word replacement # Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary) # (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk) align_dict = utils.load_align_dict(args.replace_unk)
def make_result(src_str, hypos):
result = Translation(
src_str='O\t{}'.format(src_str),
hypos=[],
pos_scores=[],
alignments=[],
)
# Process top predictions
for hypo in hypos[:min(len(hypos), args.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
result.pos_scores.append('P\t{}'.format(
' '.join(map(
lambda x: '{:.4f}'.format(x),
hypo['positional_scores'].tolist(),
))
))
result.alignments.append(
'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
if args.print_alignment else None
)
return result
def process_batch(batch):
tokens = batch.tokens
lengths = batch.lengths
if use_cuda:
tokens = tokens.cuda()
lengths = lengths.cuda()
encoder_input = {'src_tokens': tokens, 'src_lengths': lengths}
translations = translator.generate(
encoder_input,
maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
)
return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]
max_positions = utils.resolve_max_positions( max_positions = utils.resolve_max_positions(
task.max_positions(), task.max_positions(),
*[model.max_positions() for model in models] *[model.max_positions() for model in models]
@ -166,21 +111,55 @@ def main(args):
if args.buffer_size > 1: if args.buffer_size > 1:
print('| Sentence buffer size:', args.buffer_size) print('| Sentence buffer size:', args.buffer_size)
print('| Type the input sentence and press return:') print('| Type the input sentence and press return:')
start_id = 0
for inputs in buffered_read(args.input, args.buffer_size): for inputs in buffered_read(args.input, args.buffer_size):
indices = []
results = [] results = []
for batch, batch_indices in make_batches(inputs, args, task, max_positions): for batch in make_batches(inputs, args, task, max_positions):
indices.extend(batch_indices) src_tokens = batch.src_tokens
results.extend(process_batch(batch)) src_lengths = batch.src_lengths
if use_cuda:
src_tokens = src_tokens.cuda()
src_lengths = src_lengths.cuda()
for i in np.argsort(indices): sample = {
result = results[i] 'net_input': {
print(result.src_str) 'src_tokens': src_tokens,
for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments): 'src_lengths': src_lengths,
print(hypo) },
print(pos_scores) }
if align is not None: translations = task.inference_step(generator, models, sample)
print(align) for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
results.append((start_id + id, src_tokens_i, hypos))
# sort output to match input order
for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
src_str = src_dict.string(src_tokens, args.remove_bpe)
print('S-{}\t{}'.format(id, src_str))
# Process top predictions
for hypo in hypos[:min(len(hypos), args.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
print('P-{}\t{}'.format(
id,
' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
))
if args.print_alignment:
print('A-{}\t{}'.format(
id,
' '.join(map(lambda x: str(utils.item(x)), alignment))
))
# update running id counter
start_id += len(results)
def cli_main(): def cli_main():
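
In interactive mode there is no dataset batch, so interactive.py now assembles a minimal sample dict by hand before calling task.inference_step. A rough sketch of translating one pre-tokenized line this way, assuming the task, models and generator are set up as above:

import torch
from fairseq import tokenizer

def translate_line(line, task, models, generator, use_cuda=True):
    """Sketch: build the minimal 'net_input' sample that generate() expects."""
    tokens = tokenizer.Tokenizer.tokenize(
        line, task.source_dictionary, add_if_not_exist=False,
    ).long().unsqueeze(0)                          # shape: (1, src_len)
    lengths = torch.LongTensor([tokens.size(1)])
    if use_cuda:
        tokens, lengths = tokens.cuda(), lengths.cuda()
    sample = {'net_input': {'src_tokens': tokens, 'src_lengths': lengths}}
    hypos = task.inference_step(generator, models, sample)
    best = hypos[0][0]                             # top hypothesis, first sentence
    return task.target_dictionary.string(best['tokens'].int().cpu())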

View File

@ -42,7 +42,8 @@ setup(
install_requires=[ install_requires=[
'cffi', 'cffi',
'numpy', 'numpy',
'torch', # don't include torch, to support both release and nightly builds
#'torch',
'tqdm', 'tqdm',
], ],
packages=find_packages(exclude=['scripts', 'tests']), packages=find_packages(exclude=['scripts', 'tests']),

View File

@ -44,14 +44,13 @@ class TestBacktranslationDataset(unittest.TestCase):
) )
generator = SequenceGenerator( generator = SequenceGenerator(
models=[self.model],
tgt_dict=self.tgt_dict, tgt_dict=self.tgt_dict,
max_len_a=0,
max_len_b=200,
beam_size=2, beam_size=2,
unk_penalty=0, unk_penalty=0,
sampling=False, sampling=False,
) )
if self.cuda:
generator.cuda()
backtranslation_dataset = BacktranslationDataset( backtranslation_dataset = BacktranslationDataset(
tgt_dataset=TransformEosDataset( tgt_dataset=TransformEosDataset(
@ -60,9 +59,9 @@ class TestBacktranslationDataset(unittest.TestCase):
# remove eos from the input src # remove eos from the input src
remove_eos_from_src=remove_eos_from_input_src, remove_eos_from_src=remove_eos_from_input_src,
), ),
backtranslation_fn=generator.generate, backtranslation_fn=(
max_len_a=0, lambda net_input: generator.generate([self.model], {'net_input': net_input})
max_len_b=200, ),
output_collater=TransformEosDataset( output_collater=TransformEosDataset(
dataset=tgt_dataset, dataset=tgt_dataset,
eos=self.tgt_dict.eos(), eos=self.tgt_dict.eos(),

View File

@ -21,13 +21,15 @@ class TestSequenceGenerator(unittest.TestCase):
self.tgt_dict, self.w1, self.w2, src_tokens, src_lengths, self.model = ( self.tgt_dict, self.w1, self.w2, src_tokens, src_lengths, self.model = (
test_utils.sequence_generator_setup() test_utils.sequence_generator_setup()
) )
self.encoder_input = { self.sample = {
'src_tokens': src_tokens, 'src_lengths': src_lengths, 'net_input': {
'src_tokens': src_tokens, 'src_lengths': src_lengths,
},
} }
def test_with_normalization(self): def test_with_normalization(self):
generator = SequenceGenerator([self.model], self.tgt_dict) generator = SequenceGenerator(self.tgt_dict, beam_size=2)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos]) self.assertHypoTokens(hypos[0][0], [w1, eos])
@ -45,8 +47,8 @@ class TestSequenceGenerator(unittest.TestCase):
def test_without_normalization(self): def test_without_normalization(self):
# Sentence 1: unchanged from the normalized case # Sentence 1: unchanged from the normalized case
# Sentence 2: beams swap order # Sentence 2: beams swap order
generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False) generator = SequenceGenerator(self.tgt_dict, beam_size=2, normalize_scores=False)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos]) self.assertHypoTokens(hypos[0][0], [w1, eos])
@ -63,8 +65,8 @@ class TestSequenceGenerator(unittest.TestCase):
def test_with_lenpen_favoring_short_hypos(self): def test_with_lenpen_favoring_short_hypos(self):
lenpen = 0.6 lenpen = 0.6
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen) generator = SequenceGenerator(self.tgt_dict, beam_size=2, len_penalty=lenpen)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos]) self.assertHypoTokens(hypos[0][0], [w1, eos])
@ -81,8 +83,8 @@ class TestSequenceGenerator(unittest.TestCase):
def test_with_lenpen_favoring_long_hypos(self): def test_with_lenpen_favoring_long_hypos(self):
lenpen = 5.0 lenpen = 5.0
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen) generator = SequenceGenerator(self.tgt_dict, beam_size=2, len_penalty=lenpen)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos]) self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
@ -98,8 +100,8 @@ class TestSequenceGenerator(unittest.TestCase):
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen) self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)
def test_maxlen(self): def test_maxlen(self):
generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2) generator = SequenceGenerator(self.tgt_dict, beam_size=2, max_len_b=2)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos]) self.assertHypoTokens(hypos[0][0], [w1, eos])
@ -115,8 +117,8 @@ class TestSequenceGenerator(unittest.TestCase):
self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01]) self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])
def test_no_stop_early(self): def test_no_stop_early(self):
generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False) generator = SequenceGenerator(self.tgt_dict, stop_early=False, beam_size=2)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos]) self.assertHypoTokens(hypos[0][0], [w1, eos])
@ -212,11 +214,10 @@ class TestDiverseBeamSearch(unittest.TestCase):
def test_diverse_beam_search(self): def test_diverse_beam_search(self):
generator = SequenceGenerator( generator = SequenceGenerator(
[self.model], self.tgt_dict, self.tgt_dict, beam_size=2, diverse_beam_groups=2, diverse_beam_strength=0.,
beam_size=2, diverse_beam_groups=2, diverse_beam_strength=0.,
) )
encoder_input = {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths} sample = {'net_input': {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths}}
hypos = generator.generate(encoder_input) hypos = generator.generate([self.model], sample)
eos, w1, w2 = self.eos, self.w1, self.w2 eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, w1, eos]) self.assertHypoTokens(hypos[0][0], [w1, w1, eos])

View File

@ -85,10 +85,12 @@ class TestSequenceScorer(unittest.TestCase):
task = test_utils.TestTranslationTask.setup_task(args, d, d) task = test_utils.TestTranslationTask.setup_task(args, d, d)
model = task.build_model(args) model = task.build_model(args)
scorer = SequenceScorer([model], task.target_dictionary) scorer = SequenceScorer(task.target_dictionary)
for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr): for sample in data_itr:
self.assertHypoTokens(hypos[0], data[id]['target']) hypos = task.inference_step(scorer, [model], sample)
self.assertHypoScore(hypos[0], expected_scores[id]) for id, hypos_id in zip(sample['id'].tolist(), hypos):
self.assertHypoTokens(hypos_id[0], data[id]['target'])
self.assertHypoScore(hypos_id[0], expected_scores[id])
def assertHypoTokens(self, hypo, tokens): def assertHypoTokens(self, hypo, tokens):
self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens)) self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))