Improve tests

Myle Ott 2018-02-03 12:19:22 -05:00
parent c7033ef794
commit 90c2973179
9 changed files with 440 additions and 54 deletions


@@ -20,6 +20,10 @@ at::Type& getDataType(const char* dtype) {
    return at::getType(at::kCUDA, at::kFloat);
  } else if (strcmp(dtype, "torch.FloatTensor") == 0) {
    return at::getType(at::kCPU, at::kFloat);
+ } else if (strcmp(dtype, "torch.cuda.DoubleTensor") == 0) {
+   return at::getType(at::kCUDA, at::kDouble);
+ } else if (strcmp(dtype, "torch.DoubleTensor") == 0) {
+   return at::getType(at::kCPU, at::kDouble);
  } else {
    throw std::runtime_error(std::string("Unsupported data type: ") + dtype);
  }
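For reference, the dispatch this hunk extends maps PyTorch tensor type strings to ATen types, and the two Double entries are the new ones. A rough Python-side sketch of the same table (the helper name get_data_type and the dict are illustrative, not part of the commit):

    import torch

    _DTYPES = {
        'torch.cuda.FloatTensor': ('cuda', torch.float32),
        'torch.FloatTensor': ('cpu', torch.float32),
        'torch.cuda.DoubleTensor': ('cuda', torch.float64),  # newly supported
        'torch.DoubleTensor': ('cpu', torch.float64),        # newly supported
    }

    def get_data_type(dtype):
        if dtype not in _DTYPES:
            raise RuntimeError('Unsupported data type: ' + dtype)
        return _DTYPES[dtype]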


@@ -28,6 +28,7 @@ class CrossEntropyCriterion(FairseqCriterion):
        """
        net_output = model(**sample['net_input'])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
+       lprobs = lprobs.view(-1, lprobs.size(-1))
        target = sample['target'].view(-1)
        loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
                          reduce=reduce)
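The criterion now flattens lprobs itself because the model returns decoder output with batch and time dimensions intact (see the FairseqModel hunk below). A minimal sketch of the pattern with illustrative shapes, written against the modern reduction= argument rather than the deprecated size_average/reduce flags used here:

    import torch
    import torch.nn.functional as F

    bsz, tgt_len, vocab, pad = 2, 3, 5, 1  # illustrative sizes; pad=1 follows fairseq's convention
    lprobs = torch.randn(bsz, tgt_len, vocab).log_softmax(dim=-1)
    target = torch.randint(vocab, (bsz, tgt_len))

    lprobs = lprobs.view(-1, lprobs.size(-1))  # (bsz*tgt_len, vocab)
    target = target.view(-1)                   # (bsz*tgt_len,)
    loss = F.nll_loss(lprobs, target, ignore_index=pad, reduction='sum')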


@@ -68,6 +68,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
        """
        net_output = model(**sample['net_input'])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
+       lprobs = lprobs.view(-1, lprobs.size(-1))
        target = sample['target'].view(-1)
        loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, None, reduce)
        nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
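The same flattening applies before the label-smoothed loss. As a rough sketch of what the loss computes, with plain tensor ops standing in for the custom LabelSmoothedNLLLoss autograd function (eps is the smoothing weight; the exact normalization in fairseq may differ):

    def label_smoothed_nll(lprobs, target, eps, padding_idx):
        # lprobs: (N, vocab) log-probabilities; target: (N,) gold indices
        nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        smooth = -lprobs.sum(dim=-1)  # cross-entropy against a uniform prior
        pad_mask = target.eq(padding_idx)
        nll = nll.masked_fill(pad_mask, 0.)
        smooth = smooth.masked_fill(pad_mask, 0.)
        return (1. - eps) * nll.sum() + (eps / lprobs.size(-1)) * smooth.sum()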


@@ -43,7 +43,7 @@ class FairseqModel(nn.Module):
    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        encoder_out = self.encoder(src_tokens, src_lengths)
        decoder_out, _ = self.decoder(prev_output_tokens, encoder_out)
-       return decoder_out.view(-1, decoder_out.size(-1))
+       return decoder_out

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""


@@ -123,7 +123,8 @@ class SequenceGenerator(object):
            encoder_outs.append(encoder_out)

        # initialize buffers
-       scores = encoder_outs[0][0].data.new(bsz * beam_size).fill_(0)
+       scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0)
+       scores_buf = scores.clone()
        tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
@@ -133,7 +134,7 @@ class SequenceGenerator(object):
        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
-       worst_finalized = [{'idx': None, 'score': float('Inf')} for i in range(bsz)]
+       worst_finalized = [{'idx': None, 'score': -math.inf} for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
@@ -150,7 +151,7 @@ class SequenceGenerator(object):
                buffers[name] = type_of.new()
            return buffers[name]

-       def is_finished(sent):
+       def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
@@ -158,19 +159,18 @@ class SequenceGenerator(object):
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
-               if self.stop_early:
+               if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
-               bbsz = sent*beam_size
-               best_unfinalized_score = scores[bbsz:bbsz+beam_size].max()
+               best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= maxlen
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

-       def finalize_hypos(step, bbsz_idx, scores):
+       def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.
@@ -183,34 +183,51 @@ class SequenceGenerator(object):
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
-               scores: A vector of the same size as bbsz_idx containing scores
-                   for each hypothesis
+               eos_scores: A vector of the same size as bbsz_idx containing
+                   scores for each hypothesis
+               unfinalized_scores: A vector containing scores for all
+                   unfinalized hypotheses
            """
-           assert bbsz_idx.numel() == scores.numel()
-           norm_scores = scores/math.pow(step+1, self.len_penalty) if self.normalize_scores else scores
+           assert bbsz_idx.numel() == eos_scores.numel()
+
+           # clone relevant token and attention tensors
+           tokens_clone = tokens.index_select(0, bbsz_idx)
+           tokens_clone = tokens_clone[:, 1:step+2]  # skip the first index, which is EOS
+           tokens_clone[:, step] = self.eos
+           attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]
+
+           # compute scores per token position
+           pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
+           pos_scores[:, step] = eos_scores
+           # convert from cumulative to per-position scores
+           pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
+
+           # normalize sentence-level scores
+           if self.normalize_scores:
+               eos_scores /= (step+1)**self.len_penalty
+
            sents_seen = set()
-           for idx, score in zip(bbsz_idx.cpu(), norm_scores.cpu()):
+           for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                sent = idx // beam_size
                sents_seen.add(sent)

                def get_hypo():
-                   hypo = tokens[idx, 1:step+2].clone()  # skip the first index, which is EOS
-                   hypo[step] = self.eos
-                   attention = attn[idx, :, 1:step+2].clone()
-                   _, alignment = attention.max(dim=0)
+                   _, alignment = attn_clone[i].max(dim=0)
                    return {
-                       'tokens': hypo,
+                       'tokens': tokens_clone[i],
                        'score': score,
-                       'attention': attention,
+                       'attention': attn_clone[i],  # src_len x tgt_len
                        'alignment': alignment,
+                       'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
-               elif score > worst_finalized[sent]['score']:
+               elif not self.stop_early and score > worst_finalized[sent]['score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
-                   finalized[sent][worst_idx] = get_hypo()
+                   if worst_idx is not None:
+                       finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
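The new pos_scores bookkeeping converts the cumulative beam scores back into per-position log-probs, which is what each hypothesis exposes as 'positional_scores'. The conversion is easy to sanity-check in isolation:

    import torch

    cum = torch.tensor([[-0.1, -0.4, -1.0]])  # cumulative log-probs after steps 0..2
    pos = cum.clone()
    pos[:, 1:] = pos[:, 1:] - pos[:, :-1]     # same op as in finalize_hypos
    # pos is now tensor([[-0.1, -0.3, -0.6]]): the per-step contributions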
@@ -223,7 +240,7 @@ class SequenceGenerator(object):
            num_finished = 0
            for sent in sents_seen:
                # check termination conditions for this sentence
-               if not finished[sent] and is_finished(sent):
+               if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                    finished[sent] = True
                    num_finished += 1
            return num_finished
@@ -243,23 +260,38 @@ class SequenceGenerator(object):
                probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
            else:
                # make probs contain cumulative scores for each hypothesis
-               probs.add_(scores.view(-1, 1))
+               probs.add_(scores[:, step-1].view(-1, 1))

            probs[:, self.pad] = -math.inf  # never select pad
            probs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # Record attention scores
            attn[:, :, step+1].copy_(avg_attn_scores)

-           # take the best 2 x beam_size predictions. We'll choose the first
-           # beam_size of these which don't predict eos to continue with.
            cand_scores = buffer('cand_scores', type_of=scores)
            cand_indices = buffer('cand_indices')
            cand_beams = buffer('cand_beams')
-           probs.view(bsz, -1).topk(
-               min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
-               out=(cand_scores, cand_indices))
-           torch.div(cand_indices, self.vocab_size, out=cand_beams)
-           cand_indices.fmod_(self.vocab_size)
+           if step < maxlen:
+               # take the best 2 x beam_size predictions. We'll choose the first
+               # beam_size of these which don't predict eos to continue with.
+               torch.topk(
+                   probs.view(bsz, -1),
+                   k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
+                   out=(cand_scores, cand_indices),
+               )
+               torch.div(cand_indices, self.vocab_size, out=cand_beams)
+               cand_indices.fmod_(self.vocab_size)
+           else:
+               # finalize all active hypotheses once we hit maxlen
+               # pick the hypothesis with the highest prob of EOS right now
+               torch.sort(
+                   probs[:, self.eos],
+                   descending=True,
+                   out=(eos_scores, eos_bbsz_idx),
+               )
+               num_remaining_sent -= finalize_hypos(
+                   step, eos_bbsz_idx, eos_scores)
+               assert num_remaining_sent == 0
+               break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
@@ -271,56 +303,87 @@ class SequenceGenerator(object):
            if step >= self.minlen:
                eos_bbsz_idx = buffer('eos_bbsz_idx')
                # only consider eos when it's among the top beam_size indices
-               cand_bbsz_idx[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_bbsz_idx)
+               torch.masked_select(
+                   cand_bbsz_idx[:, :beam_size],
+                   mask=eos_mask[:, :beam_size],
+                   out=eos_bbsz_idx,
+               )
                if eos_bbsz_idx.numel() > 0:
                    eos_scores = buffer('eos_scores', type_of=scores)
-                   cand_scores[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_scores)
-                   num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores)
+                   torch.masked_select(
+                       cand_scores[:, :beam_size],
+                       mask=eos_mask[:, :beam_size],
+                       out=eos_scores,
+                   )
+                   num_remaining_sent -= finalize_hypos(
+                       step, eos_bbsz_idx, eos_scores, cand_scores)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
+           assert step < maxlen

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer('active_mask')
-           torch.add(eos_mask.type_as(cand_offsets)*cand_size, cand_offsets[:eos_mask.size(1)],
-                     out=active_mask)
+           torch.add(
+               eos_mask.type_as(cand_offsets)*cand_size,
+               cand_offsets[:eos_mask.size(1)],
+               out=active_mask,
+           )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
-           active_mask.topk(beam_size, 1, largest=False, out=(_ignore, active_hypos))
+           active_mask.topk(
+               k=beam_size, dim=1, largest=False,
+               out=(_ignore, active_hypos),
+           )

            active_bbsz_idx = buffer('active_bbsz_idx')
-           cand_bbsz_idx.gather(1, active_hypos, out=active_bbsz_idx)
-           active_scores = cand_scores.gather(1, active_hypos,
-                                              out=scores.view(bsz, beam_size))
+           cand_bbsz_idx.gather(
+               dim=1, index=active_hypos,
+               out=active_bbsz_idx,
+           )
+           active_scores = cand_scores.gather(
+               dim=1, index=active_hypos,
+               out=scores[:, step].view(bsz, beam_size),
+           )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

-           # finalize all active hypotheses once we hit maxlen
-           # finalize_hypos will take care of adding the EOS markers
-           if step == maxlen:
-               num_remaining_sent -= finalize_hypos(step, active_bbsz_idx, active_scores)
-               assert num_remaining_sent == 0
-               break

-           # copy tokens for active hypotheses
-           torch.index_select(tokens[:, :step+1], dim=0, index=active_bbsz_idx,
-                              out=tokens_buf[:, :step+1])
-           cand_indices.gather(1, active_hypos,
-                               out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1])
+           # copy tokens and scores for active hypotheses
+           torch.index_select(
+               tokens[:, :step+1], dim=0, index=active_bbsz_idx,
+               out=tokens_buf[:, :step+1],
+           )
+           torch.gather(
+               cand_indices, dim=1, index=active_hypos,
+               out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1],
+           )
+           if step > 0:
+               torch.index_select(
+                   scores[:, :step], dim=0, index=active_bbsz_idx,
+                   out=scores_buf[:, :step],
+               )
+           torch.gather(
+               cand_scores, dim=1, index=active_hypos,
+               out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
+           )

            # copy attention for active hypotheses
-           torch.index_select(attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
-                              out=attn_buf[:, :, :step+2])
+           torch.index_select(
+               attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
+               out=attn_buf[:, :, :step+2],
+           )

            # swap buffers
            old_tokens = tokens
            tokens = tokens_buf
            tokens_buf = old_tokens
+           old_scores = scores
+           scores = scores_buf
+           scores_buf = old_scores
            old_attn = attn
            attn = attn_buf
            attn_buf = old_attn
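The swap at the end now ping-pongs three buffer pairs so each step writes into preallocated storage rather than allocating fresh tensors. A behavior-equivalent sketch using tuple assignment, which avoids the old_* temporaries:

    tokens, tokens_buf = tokens_buf, tokens
    scores, scores_buf = scores_buf, scores
    attn, attn_buf = attn_buf, attn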


@@ -61,6 +61,7 @@ setup(
    install_requires=reqs.strip().split('\n'),
    packages=find_packages(),
    ext_modules=[bleu],
+   test_suite='tests',
    # build and install PyTorch extensions
    package_data={
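With test_suite='tests' registered, the new modules under tests/ can be run with python setup.py test, or directly via unittest discovery (e.g. python -m unittest discover tests).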

tests/__init__.py (new, empty file)


@@ -0,0 +1,215 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import argparse
import unittest
import torch
from torch.autograd import Variable
from fairseq.sequence_generator import SequenceGenerator
import tests.utils as test_utils
class TestSequenceGenerator(unittest.TestCase):
def setUp(self):
# construct dummy dictionary
d = test_utils.dummy_dictionary(vocab_size=2)
self.assertEqual(d.pad(), 1)
self.assertEqual(d.eos(), 2)
self.assertEqual(d.unk(), 3)
self.eos = d.eos()
self.w1 = 4
self.w2 = 5
# construct source data
self.src_tokens = Variable(torch.LongTensor([
[ self.w1, self.w2, self.eos ],
[ self.w1, self.w2, self.eos ],
]))
self.src_lengths = Variable(torch.LongTensor([2, 2]))
args = argparse.Namespace()
unk = 0.
args.beam_probs = [
# step 0:
torch.FloatTensor([
# eos w1 w2
# sentence 1:
[ 0.0, unk, 0.9, 0.1 ], # beam 1
[ 0.0, unk, 0.9, 0.1 ], # beam 2
# sentence 2:
[ 0.0, unk, 0.7, 0.3 ],
[ 0.0, unk, 0.7, 0.3 ],
]),
# step 1:
torch.FloatTensor([
# eos w1 w2 prefix
# sentence 1:
[ 1.0, unk, 0.0, 0.0 ], # w1: 0.9 (emit: w1 <eos>: 0.9*1.0)
[ 0.0, unk, 0.9, 0.1 ], # w2: 0.1
# sentence 2:
[ 0.25, unk, 0.35, 0.4 ], # w1: 0.7 (don't emit: w1 <eos>: 0.7*0.25)
[ 0.00, unk, 0.10, 0.9 ], # w2: 0.3
]),
# step 2:
torch.FloatTensor([
# eos w1 w2 prefix
# sentence 1:
[ 0.0, unk, 0.1, 0.9 ], # w2 w1: 0.1*0.9
[ 0.6, unk, 0.2, 0.2 ], # w2 w2: 0.1*0.1 (emit: w2 w2 <eos>: 0.1*0.1*0.6)
# sentence 2:
[ 0.60, unk, 0.4, 0.00 ], # w1 w2: 0.7*0.4 (emit: w1 w2 <eos>: 0.7*0.4*0.6)
[ 0.01, unk, 0.0, 0.99 ], # w2 w2: 0.3*0.9
]),
# step 3:
torch.FloatTensor([
# eos w1 w2 prefix
# sentence 1:
[ 1.0, unk, 0.0, 0.0 ], # w2 w1 w2: 0.1*0.9*0.9 (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
[ 1.0, unk, 0.0, 0.0 ], # w2 w1 w1: 0.1*0.9*0.1 (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
# sentence 2:
[ 0.1, unk, 0.5, 0.4 ], # w2 w2 w2: 0.3*0.9*0.99 (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
[ 1.0, unk, 0.0, 0.0 ], # w1 w2 w1: 0.7*0.4*0.4 (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
]),
]
self.model = test_utils.TestModel.build_model(args, d, d)
def test_with_normalization(self):
generator = SequenceGenerator([self.model])
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0])
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0])
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6])
def test_without_normalization(self):
# Sentence 1: unchanged from the normalized case
# Sentence 2: beams swap order
generator = SequenceGenerator([self.model], normalize_scores=False)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0], normalized=False)
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], normalized=False)
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], normalized=False)
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], normalized=False)
def test_with_lenpen_favoring_short_hypos(self):
lenpen = 0.6
generator = SequenceGenerator([self.model], len_penalty=lenpen)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0], lenpen=lenpen)
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], lenpen=lenpen)
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)
def test_with_lenpen_favoring_long_hypos(self):
lenpen = 5.0
generator = SequenceGenerator([self.model], len_penalty=lenpen)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][0], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w1, eos])
self.assertHypoScore(hypos[0][1], [0.9, 1.0], lenpen=lenpen)
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)
def test_maxlen(self):
generator = SequenceGenerator([self.model], maxlen=2)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0])
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w2, w2, eos])
self.assertHypoScore(hypos[0][1], [0.1, 0.1, 0.6])
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6])
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w2, w2, eos])
self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])
def test_no_stop_early(self):
generator = SequenceGenerator([self.model], stop_early=False)
hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0])
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w2, w2, w2, w2, eos])
self.assertHypoScore(hypos[1][0], [0.3, 0.9, 0.99, 0.4, 1.0])
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0])
def assertHypoTokens(self, hypo, tokens):
self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))
def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
pos_scores = torch.FloatTensor(pos_probs).log()
self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
score = pos_scores.sum()
if normalized:
score /= pos_scores.numel()**lenpen
self.assertLess(abs(score - hypo['score']), 1e-6)
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-4)
def assertTensorEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertEqual(t1.ne(t2).long().sum(), 0)
if __name__ == '__main__':
unittest.main()
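As a worked check of the scoring convention these assertions encode: a hypothesis score is the sum of its per-position log-probs, divided by length**lenpen when normalization is on. For sentence 1 that is why [w1, eos] outranks [w2, w1, w2, eos]:

    import math

    best = (math.log(0.9) + math.log(1.0)) / 2 ** 1.0  # [w1, eos] -> ~ -0.053
    runner_up = (math.log(0.1) + math.log(0.9) + math.log(0.9) + math.log(1.0)) / 4 ** 1.0  # ~ -0.628
    assert best > runner_up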

tests/utils.py (new file, 101 lines)

@@ -0,0 +1,101 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
from torch.autograd import Variable
from fairseq import dictionary
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqModel,
)
def dummy_dictionary(vocab_size, prefix='token_'):
d = dictionary.Dictionary()
for i in range(vocab_size):
token = prefix + str(i)
d.add_symbol(token)
d.finalize()
return d
class TestModel(FairseqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@classmethod
def build_model(cls, args, src_dict, dst_dict):
encoder = TestEncoder(args, src_dict)
decoder = TestIncrementalDecoder(args, dst_dict)
return cls(encoder, decoder)
class TestEncoder(FairseqEncoder):
def __init__(self, args, dictionary):
super().__init__(dictionary)
self.args = args
def forward(self, src_tokens, src_lengths):
return src_tokens
class TestIncrementalDecoder(FairseqIncrementalDecoder):
def __init__(self, args, dictionary):
super().__init__(dictionary)
assert hasattr(args, 'beam_probs')
args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
self.args = args
def forward(self, prev_output_tokens, encoder_out):
if self._is_incremental_eval:
prev_output_tokens = prev_output_tokens[:, -1:]
return self._forward(prev_output_tokens, encoder_out)
def _forward(self, prev_output_tokens, encoder_out):
bbsz = prev_output_tokens.size(0)
vocab = len(self.dictionary)
src_len = encoder_out.size(1)
tgt_len = prev_output_tokens.size(1)
# determine number of steps
if self._is_incremental_eval:
# cache step number
step = self.get_incremental_state('step')
if step is None:
step = 0
self.set_incremental_state('step', step + 1)
steps = [step]
else:
steps = list(range(tgt_len))
# define output in terms of raw probs
probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
for i, step in enumerate(steps):
# args.beam_probs gives the probability for every vocab element,
# starting with eos, then unknown, and then the rest of the vocab
if step < len(self.args.beam_probs):
probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
else:
probs[:, i, self.dictionary.eos()] = 1.0
# random attention
attn = torch.rand(bbsz, src_len, tgt_len)
return Variable(probs), Variable(attn)
def get_normalized_probs(self, net_output, log_probs):
# the decoder returns probabilities directly
if log_probs:
return net_output.log()
else:
return net_output
def max_positions(self):
return self.args.max_decoder_positions
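Putting the helpers together, a test builds a model whose decoder emits scripted distributions; a minimal usage sketch (shapes are illustrative: each beam_probs entry has bsz*beam_size rows and one column per symbol from eos onward, i.e. eos, unk, w1, w2):

    import argparse
    import torch

    d = dummy_dictionary(vocab_size=2)
    args = argparse.Namespace()
    args.beam_probs = [
        # one (bsz*beam_size) x 4 tensor per step; columns: eos, unk, w1, w2
        torch.FloatTensor([[0.0, 0.0, 0.6, 0.4]] * 4),
        torch.FloatTensor([[1.0, 0.0, 0.0, 0.0]] * 4),
    ]
    model = TestModel.build_model(args, d, d)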