Added constrained decoding (#1536) (#2402)

Summary:
# Before submitting

- [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?

This PR implements constrained decoding ([Hokamp & Liu, 2017](https://www.aclweb.org/anthology/P17-1141/); [Post & Vilar, 2018](https://www.aclweb.org/anthology/N18-1119/)) with vectorization for batching ([Hu et al., 2019](https://www.aclweb.org/anthology/N19-1090/)). In addition, it adds *ordered constraints*, where the constraints are generated on the target side in order, with zero or more unconstrained tokens in between. This variant allows for optimizations that increase speed and BLEU scores (when testing with random scraps from the references).

### Usage and quick start

It works with `fairseq-interactive` via a new command-line option: `fairseq-interactive --constraints [ordered,unordered]`, defaulting to `ordered` if nothing is provided. When active, each line read from STDIN is split on `\t`: the first field is the source sentence, and each remaining tab-separated field is a constraint. For example (after downloading the [Fairseq WMT19 German--English model](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md)):

```bash
echo -e "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\tinfluence" \
  | [normalize.py](https://gist.github.com/mjpost/4c54446b7030d7c64b57461d27090650) \
  | [tok.py](https://gist.github.com/mjpost/ed7456f6a987c533102fc121678ed302) \
  | PYTHONPATH=$HOME/code/fairseq-constraints fairseq-interactive $modeldir \
  --bpe fastbpe \
  --bpe-codes $modeldir/bpecodes \
  --constraints \
  --constraints-both
  -s de -t en \
  --path $modeldir/model1.pt \
  --max-tokens 1000 \
  --beam 5 \
```

Adding the `--constraints-both` option causes it to batch-decode the input sentence both with and without the constraints. When run with the Fairseq WMT19 German--English model, the following results are produced (here run on a CPU, so don't be alarmed by the times!):

```text
S-0     Die masch@@ in@@ elle Über@@ setzung ist schwer zu kontrollieren .
W-0     1.844   seconds
C-0     hard
C-0     influence
H-0     -1.5333266258239746     Mach@@ ine trans@@ lation is hard to influence .
D-0     -1.5333266258239746     Machine translation is hard to influence .
P-0     -0.5434 -0.1423 -0.1930 -0.1415 -0.2346 -1.8031 -0.1701 -11.7727 -0.1815 -0.1511
S-0     Die masch@@ in@@ elle Über@@ setzung ist schwer zu kontrollieren .
W-0     1.844   seconds
H-0     -0.3731671869754791     Mach@@ ine trans@@ lation is difficult to control .
D-0     -0.3731671869754791     Machine translation is difficult to control .
P-0     -0.5434 -0.1423 -0.1930 -0.1415 -0.2346 -1.1430 -0.1665 -0.8482 -0.1678 -0.1514
2020-07-31 12:17:55 | INFO | fairseq_cli.interactive | Total time: 12.803 seconds; translation time: 3.688
```

Note the new tags present in the output (a small parsing sketch follows the list):

* `C-#` records active constraints (after applying preprocessing) for a sentence
* `W-#` reports the sentence-level translation time (a useful unrelated feature I hope you'll accept)
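
As a quick illustration of consuming these tags, here is a hypothetical post-processing script (not part of this PR; it assumes the tag and field layout shown in the output above):

```python
import sys
from collections import defaultdict

# Hypothetical sketch: collect the new C-# (constraint) and W-# (timing)
# tags from fairseq-interactive output piped to STDIN.
constraints, times = defaultdict(list), {}
for line in sys.stdin:
    fields = line.rstrip("\n").split(None, 1)
    if len(fields) != 2:
        continue
    tag, rest = fields
    if tag.startswith("C-"):
        constraints[int(tag[2:])].append(rest)
    elif tag.startswith("W-"):
        times[int(tag[2:])] = float(rest.split()[0])

print(dict(constraints))  # e.g., {0: ['hard', 'influence']}
print(times)              # e.g., {0: 1.844}
```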

Some unit tests are written (`fairseq/test_constraints.py`) but not yet integrated. Advice on where to place them is welcome. I also have not run this through lint; if someone can tell me the command to run, I'd appreciate it.

### Implementation notes

This is largely self-contained, implemented in a new `LexicallyConstrainedBeamSearch` class in `search.py`. It does require a few minimal hooks from `_generate()` in `sequence_generator.py`, to ensure that constraints are updated at each timestep. (Edit: most changes in that file are documentation clarifications, corrections, and updates). Unconstrained sentences that are intermingled with constrained ones will not incur any time penalty, so long as they do not occur in the same batch.
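
For reviewers who want to poke at the new class directly, here is a toy sketch (not part of the PR; it assumes this branch is installed and uses made-up scores) that exercises the same hook-and-step cycle that `_generate()` drives:

```python
import torch
from fairseq.data import Dictionary
from fairseq.search import LexicallyConstrainedBeamSearch
from fairseq.token_generation_constraints import pack_constraints

# A tiny target dictionary; a real one comes from the trained model
tgt_dict = Dictionary()
for word in ["hard", "to", "influence"]:
    tgt_dict.add_symbol(word)

beam_size = 2
search = LexicallyConstrainedBeamSearch(tgt_dict, "ordered")

# One sentence with the single-token constraint "hard", in packed form
batch_constraints = pack_constraints([[torch.tensor([tgt_dict.index("hard")])]])
search.init_constraints(batch_constraints, beam_size)  # called once by _generate()

# Fake model output for one decoding step: (batch, beam, vocab)
lprobs = torch.randn(1, beam_size, len(tgt_dict)).log_softmax(dim=-1)
scores, indices, beams = search.step(0, lprobs, None)
print(indices)  # the constraint token is among the candidates
```

In `_generate()` itself, `prune_sentences()` is then called whenever finished sentences are dropped from the batch, and `update_constraints()` after each beam reduction; both are no-ops for search strategies that don't support constraints.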

Addresses https://github.com/pytorch/fairseq/issues/1536.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/pytorch/fairseq/pull/2402

Reviewed By: alexeib

Differential Revision: D23188945

Pulled By: myleott

fbshipit-source-id: 9f5ed855f7a1dcf535b091c0ccf98b07fb9cbdd6


@ -32,6 +32,7 @@ We provide reference implementations of various sequence modeling papers:
- [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
- [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
- [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/transformer_lm/README.md)
- [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
- [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
@ -51,6 +52,7 @@ We provide reference implementations of various sequence modeling papers:
### What's New:
- August 2020: [Lexically constrained decoding](examples/constrained_decoding/README.md)
- August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
- July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
- May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)


@ -0,0 +1,124 @@
# (Vectorized) Lexically constrained decoding with dynamic beam allocation
This page provides instructions for how to use lexically constrained decoding in Fairseq.
Fairseq implements the code described in the following papers:
* [Fast Lexically Constrained Decoding With Dynamic Beam Allocation](https://www.aclweb.org/anthology/N18-1119/)
* [Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting](https://www.aclweb.org/anthology/N19-1090/)
## Quick start
Constrained search is enabled by adding the command-line argument `--constraints` to `fairseq-interactive`.
Constraints are appended to each line of input, separated by tabs. Each constraint (one or more tokens)
is a separate field.
The following command, using [Fairseq's WMT19 German--English model](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md),
translates the sentence *Die maschinelle Übersetzung ist schwer zu kontrollieren.* with the constraints
"hard" and "to influence".
echo -e "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\ttoinfluence" \
| normalize.py | tok.py \
| fairseq-interactive /path/to/model \
--path /path/to/model/model1.pt \
--bpe fastbpe \
--bpe-codes /path/to/model/bpecodes \
--constraints \
-s de -t en \
--beam 10
(`tok.py` and `normalize.py` can be found in the same directory as this README; they are just shortcuts around Fairseq's WMT19 preprocessing).
This will generate the following output:
[snip]
S-0 Die masch@@ in@@ elle Über@@ setzung ist schwer zu kontrollieren .
W-0 1.844 seconds
C-0 hard
C-0 influence
H-0 -1.5333266258239746 Mach@@ ine trans@@ lation is hard to influence .
D-0 -1.5333266258239746 Machine translation is hard to influence .
P-0 -0.5434 -0.1423 -0.1930 -0.1415 -0.2346 -1.8031 -0.1701 -11.7727 -0.1815 -0.1511
By default, constraints are generated in the order supplied, with any number (zero or more) of tokens generated
between constraints. If you wish to let the decoder choose the order in which the constraints are generated, use `--constraints unordered`.
Note that you may want to use a larger beam.
## Implementation details
The heart of the implementation is in `fairseq/search.py`, which adds a `LexicallyConstrainedBeamSearch` instance.
This instance of beam search tracks the progress of each hypothesis in the beam through the set of constraints
provided for each input sentence. It does this using one of two classes, both found in `fairseq/token_generation_constraints.py` (a short usage sketch follows the list):
* `OrderedConstraintState`: assumes the C input constraints will be generated in the provided order
* `UnorderedConstraintState`: tries to apply the C (phrasal) constraints in all C! orders
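
The following minimal sketch (illustrative token IDs; in practice they come from encoding the constraint strings with the target dictionary and BPE) shows how the two states are built from a packed constraint tensor and advanced token by token:

```python
import torch
from fairseq.token_generation_constraints import (
    OrderedConstraintState,
    UnorderedConstraintState,
    pack_constraints,
)

# One sentence with two constraints: the token sequences [3, 4] and [7]
batch = [[torch.tensor([3, 4]), torch.tensor([7])]]
packed = pack_constraints(batch)  # shape (1, packed_len), zero-delimited

# Ordered: constraints must appear in the given order, with gaps allowed
state = OrderedConstraintState.create(packed[0])
for token in [5, 3, 4, 9, 7]:  # a pretend decoder output, one token at a time
    state = state.advance(token)
print(state.finished, state.num_completed)  # True 2

# Unordered: progress is tracked through a trie, so either constraint
# may be started first
state = UnorderedConstraintState.create(packed[0])
print(sorted(state.next_tokens()))  # [3, 7]
```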
## Differences from Sockeye
There are a number of [differences from Sockeye's implementation](https://awslabs.github.io/sockeye/inference.html#lexical-constraints).
* Generating constraints in the order supplied (the default option here) is not available in Sockeye.
* Due to an improved beam allocation method, there is no need to prune the beam.
* Again due to better allocation, beam sizes as low as 10 or even 5 are often sufficient.
* [The extensions described in Hu et al.](https://github.com/edwardjhu/sockeye/tree/trie_constraints) (NAACL 2019) were never merged
into the main branch.
* Sockeye 2, released in July 2020, no longer supports constrained decoding.
## Citation
The paper first describing lexical constraints for seq2seq decoding is:
```bibtex
@inproceedings{hokamp-liu-2017-lexically,
title = "Lexically Constrained Decoding for Sequence Generation Using Grid Beam Search",
author = "Hokamp, Chris and
Liu, Qun",
booktitle = "Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
month = jul,
year = "2017",
address = "Vancouver, Canada",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/P17-1141",
doi = "10.18653/v1/P17-1141",
pages = "1535--1546",
}
```
The fairseq implementation uses the extensions described in
```bibtex
@inproceedings{post-vilar-2018-fast,
title = "Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation",
author = "Post, Matt and
Vilar, David",
booktitle = "Proceedings of the 2018 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers)",
month = jun,
year = "2018",
address = "New Orleans, Louisiana",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/N18-1119",
doi = "10.18653/v1/N18-1119",
pages = "1314--1324",
}
```
and
```bibtex
@inproceedings{hu-etal-2019-improved,
title = "Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting",
author = "Hu, J. Edward and
Khayrallah, Huda and
Culkin, Ryan and
Xia, Patrick and
Chen, Tongfei and
Post, Matt and
Van Durme, Benjamin",
booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)",
month = jun,
year = "2019",
address = "Minneapolis, Minnesota",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/N19-1090",
doi = "10.18653/v1/N19-1090",
pages = "839--850",
}
```


@ -0,0 +1,26 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
from sacremoses.normalize import MosesPunctNormalizer
def main(args):
normalizer = MosesPunctNormalizer(lang=args.lang, penn=args.penn)
for line in sys.stdin:
print(normalizer.normalize(line.rstrip()), flush=True)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--lang', '-l', default='en')
parser.add_argument('--penn', '-p', action='store_true')
args = parser.parse_args()
main(args)


@ -0,0 +1,31 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import sacremoses
def main(args):
"""Tokenizes, preserving tabs"""
mt = sacremoses.MosesTokenizer(lang=args.lang)
def tok(s):
return mt.tokenize(s, return_str=True)
for line in sys.stdin:
parts = list(map(tok, line.split("\t")))
print(*parts, sep="\t", flush=True)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--lang', '-l', default='en')
parser.add_argument('--penn', '-p', action='store_true')
parser.add_argument('--fields', '-f', help="fields to tokenize")
args = parser.parse_args()
main(args)


@ -201,13 +201,14 @@ class TranslationMoETask(TranslationTask):
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
return loss, sample_size, logging_output
def inference_step(self, generator, models, sample, prefix_tokens=None, expert=None):
def inference_step(self, generator, models, sample, prefix_tokens=None, expert=None, constraints=None):
expert = expert or self.args.gen_expert
with torch.no_grad():
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
constraints=constraints,
bos_token=self.expert_index(expert),
)


@ -22,6 +22,7 @@ import fairseq.optim.lr_scheduler # noqa
import fairseq.pdb # noqa
import fairseq.scoring # noqa
import fairseq.tasks # noqa
import fairseq.token_generation_constraints # noqa
import fairseq.benchmark # noqa
import fairseq.model_parallel # noqa


@ -21,10 +21,10 @@ class Dictionary(object):
def __init__(
self,
*, # begin keyword-only arguments
bos="<s>",
pad="<pad>",
eos="</s>",
unk="<unk>",
bos="<s>",
extra_special_symbols=None,
):
self.unk_word, self.pad_word, self.eos_word = unk, pad, eos


@ -133,6 +133,16 @@ def collate(
batch['alignments'] = alignments
batch['align_weights'] = align_weights
if samples[0].get("constraints", None) is not None:
# Collate the packed constraints across the samples, padding to
# the length of the longest sample.
lens = [sample.get("constraints").size(0) for sample in samples]
max_len = max(lens)
constraints = torch.zeros((len(samples), max(lens))).long()
for i, sample in enumerate(samples):
constraints[i, 0:lens[i]] = samples[i].get("constraints")
batch["constraints"] = constraints
return batch
@ -161,6 +171,8 @@ class LanguagePairDataset(FairseqDataset):
target if it's absent (default: False).
align_dataset (torch.utils.data.Dataset, optional): dataset
containing alignments.
constraints (Tensor, optional): 2d tensor with a concatenated, zero-
delimited list of constraints for each sentence.
append_bos (bool, optional): if set, appends bos to the beginning of
source/target sentence.
num_buckets (int, optional): if set to a value greater than 0, then
@ -180,6 +192,7 @@ class LanguagePairDataset(FairseqDataset):
shuffle=True, input_feeding=True,
remove_eos_from_source=False, append_eos_to_target=False,
align_dataset=None,
constraints=None,
append_bos=False, eos=None,
num_buckets=0,
src_lang_id=None,
@ -206,6 +219,7 @@ class LanguagePairDataset(FairseqDataset):
self.align_dataset = align_dataset
if self.align_dataset is not None:
assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
self.constraints = constraints
self.append_bos = append_bos
self.eos = (eos if eos is not None else src_dict.eos())
self.src_lang_id = src_lang_id
@ -279,6 +293,8 @@ class LanguagePairDataset(FairseqDataset):
}
if self.align_dataset is not None:
example['alignment'] = self.align_dataset[index]
if self.constraints is not None:
example["constraints"] = self.constraints[index]
return example
def __len__(self):


@ -105,7 +105,9 @@ class IterativeRefinementGenerator(object):
@torch.no_grad()
def generate(self, models, sample, prefix_tokens=None):
def generate(self, models, sample, prefix_tokens=None, constraints=None):
if constraints is not None:
raise NotImplementedError("Constrained decoding with the IterativeRefinementGenerator is not supported")
# TODO: iterative refinement generator does not support ensemble for now.
if not self.retain_dropout:


@ -609,6 +609,8 @@ def add_generation_args(parser):
help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-topp', default=-1.0, type=float, metavar='PS',
help='sample from the smallest set whose cumulative probability mass exceeds p for next words')
group.add_argument('--constraints', const="ordered", nargs="?", choices=["ordered", "unordered"],
help='enables lexically constrained decoding')
group.add_argument('--temperature', default=1., type=float, metavar='N',
help='temperature for generation')
group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',


@ -10,6 +10,8 @@ import torch
import torch.nn as nn
from torch import Tensor
from fairseq.token_generation_constraints import ConstraintState, UnorderedConstraintState, OrderedConstraintState
class Search(nn.Module):
def __init__(self, tgt_dict):
@ -19,6 +21,7 @@ class Search(nn.Module):
self.eos = tgt_dict.eos()
self.vocab_size = len(tgt_dict)
self.src_lengths = torch.tensor(-1)
self.supports_constraints = False
def step(self, step, lprobs, scores):
"""Take a single search step.
@ -46,10 +49,49 @@ class Search(nn.Module):
def set_src_lengths(self, src_lengths):
self.src_lengths = src_lengths
@torch.jit.export
def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int):
"""Initialize constraint states for constrained decoding (if supported).
Args:
batch_constraints: (torch.Tensor, optional)
the list of constraints, in packed form
beam_size: (int)
the beam size
"""
pass
def prune_sentences(self, batch_idxs: Tensor):
"""
Removes constraint states for completed sentences (if supported).
This is called from sequence_generator._generate() when sentences are
deleted from the batch.
Args:
batch_idxs: Indices of *sentences* whose constraint state should be *kept*.
"""
pass
def update_constraints(self, active_hypos: Tensor):
"""
Updates the constraint states by selecting the beam items that are retained.
This is called at each time step of sequence_generator._generate() when
the set of 2 * {beam_size} candidate hypotheses is reduced to the beam size.
Args:
active_hypos: (batch size, beam size)
list of integers denoting, for each sentence, which beam candidate items
should be kept.
"""
pass
class BeamSearch(Search):
def __init__(self, tgt_dict):
super().__init__(tgt_dict)
self.constraint_states = None
@torch.jit.export
def step(self, step: int, lprobs, scores: Optional[Tensor]):
@ -75,11 +117,306 @@ class BeamSearch(Search):
)
scores_buf = top_prediction[0]
indices_buf = top_prediction[1]
# Project back into relative indices and beams
beams_buf = indices_buf // vocab_size
indices_buf = indices_buf.fmod(vocab_size)
# At this point, beams_buf and indices_buf are single-dim and contain relative indices
return scores_buf, indices_buf, beams_buf
class LexicallyConstrainedBeamSearch(Search):
"""Implements lexically constrained beam search as described in
Fast Lexically Constrained Decoding with Dynamic Beam
Allocation for Neural Machine Translation. Post & Vilar,
NAACL 2018. https://www.aclweb.org/anthology/N18-1119/
and
Improved Lexically Constrained Decoding for Translation and
Monolingual Rewriting. Hu et al, NAACL
2019. https://www.aclweb.org/anthology/N19-1090/
This is accomplished by maintaining, for each beam hypothesis, a
ConstraintState object (see token_generation_constraints.py) that tracks which
constraints have been generated and using this information to
shape the beam for each input sentence.
"""
def __init__(self, tgt_dict, representation):
super().__init__(tgt_dict)
self.representation = representation
self.vocab_size = len(tgt_dict)
self.num_cands = 0
self.supports_constraints = True
@torch.jit.export
def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int):
self.constraint_states = []
for constraint_tensor in batch_constraints:
if self.representation == "ordered":
constraint_state = OrderedConstraintState.create(constraint_tensor)
elif self.representation == "unordered":
constraint_state = UnorderedConstraintState.create(constraint_tensor)
self.constraint_states.append([constraint_state for i in range(beam_size)])
@torch.jit.export
def prune_sentences(self, batch_idxs: Tensor):
self.constraint_states = [self.constraint_states[i] for i in batch_idxs.tolist()]
@torch.jit.export
def update_constraints(self, active_hypos: Tensor):
if self.constraint_states:
batch_size = active_hypos.size(0)
for sentid in range(batch_size):
self.constraint_states[sentid] = [self.constraint_states[sentid][i] for i in active_hypos[sentid]]
@torch.jit.export
def step(self, step: int, lprobs: Tensor, scores: Optional[Tensor]):
"""
A constrained step builds a large candidates list from the following:
- the top 2 * {beam_size} items over the whole beam
- for each item in the beam
- the top {each_k} (default 1)
- all next constraints
We then compute the constraint state of each beam item, and assign
stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so
on. We then sort by (stripe, score), and truncate the list at
2 * beam size.
Args:
step: the decoder step
lprobs: (batch size, beam size, target vocab)
the target-vocab distributions for each item in the beam.
Returns: A tuple of (scores, indices, beams, constraints) where:
scores: (batch, output beam size)
the scores of the chosen elements
indices: (batch, output beam size)
the target vocab indices of the chosen elements
beams: (batch, output beam size)
the 0-indexed hypothesis ids of the chosen elements
constraints: (batch, output beam size)
the new constraint states
"""
each_k = 1
device = lprobs.device
batch_size, beam_size, vocab_size = lprobs.size()
self.num_cands = min(
# Just take the k-best. We'll get another k from the 1-best from each
# row, plus more from the constraints
beam_size * 2,
lprobs.view(batch_size, -1).size(1) - 1, # -1 so we never select pad
)
# STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items
constraint_states = self.constraint_states
if constraint_states and step > 0:
not_finished_indices = []
for sentno, sent_constraints in enumerate(constraint_states):
for beamno, state in enumerate(sent_constraints):
index = sentno * beam_size + beamno
if not state.finished:
not_finished_indices.append(index)
not_finished_indices = torch.tensor(not_finished_indices)
if not_finished_indices.numel() > 0:
lprobs.view(batch_size * beam_size, -1)[not_finished_indices, self.eos] = -math.inf
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam entry for each batch item
lprobs = lprobs[:, ::beam_size, :].contiguous()
else:
# make probs contain cumulative scores for each hypothesis
assert scores is not None
lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
top_prediction = torch.topk(
lprobs.view(batch_size, -1),
self.num_cands,
)
scores_buf, indices_buf = top_prediction
# Project back into relative indices and beams
beams_buf = indices_buf // vocab_size
indices_buf = indices_buf.fmod(vocab_size)
# Short circuit if there are no constraints in this batch
if not constraint_states:
return scores_buf, indices_buf, beams_buf
# STEP 1: get top-1 from each hypothesis across all sentences in the batch
if step > 0:
top_scores, top_indices = torch.topk(
lprobs.view(batch_size * beam_size, -1),
k=each_k,
dim=1,
)
top_scores = top_scores.view(batch_size, -1)
top_indices = top_indices.view(batch_size, -1)
scores_buf = torch.cat((scores_buf, top_scores), dim=1)
indices_buf = torch.cat((indices_buf, top_indices), dim=1)
new_beams = torch.arange(0, beam_size, device=device).repeat(batch_size, 1)
beams_buf = torch.cat((beams_buf, new_beams), dim=1)
# Now, process sentences in the batch one by one.
new_scores_buf = torch.zeros((batch_size, 2 * beam_size), device=device)
new_indices_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long()
new_beams_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long()
for sentno, states in enumerate(constraint_states):
scores, indices, beams, new_states = self.step_sentence(step,
sentno,
lprobs[sentno],
constraint_states[sentno],
beams_buf[sentno].clone(),
indices_buf[sentno].clone(),
scores_buf[sentno].clone())
new_scores_buf[sentno] = scores
new_indices_buf[sentno] = indices
new_beams_buf[sentno] = beams
self.constraint_states[sentno] = new_states
return new_scores_buf, new_indices_buf, new_beams_buf
@torch.jit.export
def step_sentence(self,
step: int,
sentno: int,
lprobs: Tensor,
constraint_states: List[List[ConstraintState]],
beams_buf: Tensor,
indices_buf: Tensor,
scores_buf: Tensor):
"""Does per-sentence processing. Adds all constraints for each
hypothesis to the list of candidates; then removes duplicates,
sorts, and dynamically stripes across the banks. All tensor inputs
are collapsed to those pertaining to a single input sentence.
"""
device = lprobs.device
# STEP 2: Add all constraints for each beam item
for beamno, state in enumerate(constraint_states):
next_tokens = torch.tensor(list(state.next_tokens()), device=device).long()
if next_tokens.numel() != 0:
indices_buf = torch.cat((indices_buf, next_tokens))
next_beams = torch.tensor(beamno, device=device).repeat(next_tokens.size(0)).long()
beams_buf = torch.cat((beams_buf, next_beams))
next_values = lprobs[beamno].take(next_tokens.view(-1))
scores_buf = torch.cat((scores_buf, next_values))
# At the 0th time step, there is just one beam item
if step == 0:
break
# STEP 3: Compute the "bank" for each candidate. This is the
# number of constraints it's generated. We need this so that
# we can do round-robin allocation of the beam across these
# banks. If C is the number of constraints, we select the best
# item in bank C, then the best in bank C-1, etc, followed by
# the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so
# on, until the maximum beam size. We accomplish this by
# creating a sort key and striping across the banks.
# Compute the new states for all candidates
cands_size = indices_buf.size(0)
constraint_states = [constraint_states[beams_buf[i]].advance(indices_buf[i])
for i in range(cands_size)]
banks = torch.tensor([state.bank for state in constraint_states], device=device)
# STEP 4: Sort
num_constraint_tokens = len(state.tokens)
# Sort by keys (bank, score) (i.e., sort banks together, and scores
# within banks). AFAIK pytorch doesn't support either stable sort or
# multi-key sorting, so we have to hack this.
MAX_SCORE = -100
sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf
sort_values, sort_indices = sort_key.sort(dim=0, descending=True)
scores_buf = scores_buf[sort_indices]
indices_buf = indices_buf[sort_indices]
beams_buf = beams_buf[sort_indices]
banks = banks[sort_indices]
# Sort the constraints to follow suit
constraint_states = [constraint_states[i] for i in sort_indices]
# STEP 5: Remove duplicates. The topk calls (overall and
# per-row) plus the per-row generation of constraints will
# produce duplicates. Here we remove them.
def roll(t):
"""Rolls a 1d tensor left by 1.
[0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3]
"""
return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0)
# We map candidates (beam, token_id) to a single dimension.
# This is then shifted by 1. We can then easily identify
# duplicates and create a mask that identifies unique
# extensions.
uniques_mask = (beams_buf * (self.vocab_size + 1) + indices_buf)
uniques_mask = roll(uniques_mask) != uniques_mask
# Use the mask to pare down the data structures
scores_buf = torch.masked_select(scores_buf, uniques_mask)
indices_buf = torch.masked_select(indices_buf, uniques_mask)
beams_buf = torch.masked_select(beams_buf, uniques_mask)
banks = torch.masked_select(banks, uniques_mask)
i = 1
for mask in uniques_mask[1:]:
if not mask:
constraint_states.pop(i)
i += mask
# STEP 6: Assign IDs round-robin across banks, sort, and
# truncate. Now that the candidates are sorted by (bank,
# score) and uniqed, we dynamically allocate the {beam_size}
# beam by striping across the candidates. These stripes will
# be used as sort keys to do round-robin selection. This is
# accomplished in a single pass with offsets. Sorting by
# highest-banks (furthest-along hypotheses) first ensures
# progress through the constraints.
#
# e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0
# OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1
# NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7
# = 0 5 10 1 6 11 13 2 7 12 3 8
#
# Sorting by this then gives the following banks:
#
# 3 2 1 0 3 2 1 0 3 2 1 2
#
# We'll take the top {beam_size} of these.
stripe_offsets = [offset * (len(banks) + 1) for offset in range(len(banks) + 1)]
stripes = torch.zeros_like(banks)
cur_bank_count = -1
cur_bank = banks[0]
for i, bank in enumerate(banks):
if bank != cur_bank:
cur_bank_count = 0
cur_bank = bank
else:
cur_bank_count += 1
stripes[i] = num_constraint_tokens - bank + stripe_offsets[cur_bank_count]
# STEP 7: Sort by the stripes values
sort_values, sort_indices = stripes.sort(dim=0)
scores_buf = scores_buf[sort_indices]
indices_buf = indices_buf[sort_indices]
beams_buf = beams_buf[sort_indices]
constraint_states = [constraint_states[i] for i in sort_indices]
# STEP 8: Truncate to the candidates size!
scores_buf = scores_buf[:self.num_cands]
indices_buf = indices_buf[:self.num_cands]
beams_buf = beams_buf[:self.num_cands]
return scores_buf, indices_buf, beams_buf, constraint_states
class LengthConstrainedBeamSearch(Search):
def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b):
super().__init__(tgt_dict)


@ -61,6 +61,7 @@ class SequenceGenerator(nn.Module):
self.model = models
else:
self.model = EnsembleModel(models)
self.tgt_dict = tgt_dict
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos() if eos is None else eos
@ -113,7 +114,7 @@ class SequenceGenerator(nn.Module):
bos_token (int, optional): beginning of sentence token
(default: self.eos)
"""
return self._generate(sample, prefix_tokens, bos_token)
return self._generate(sample, prefix_tokens, bos_token=bos_token)
# TODO(myleott): unused, deprecate after pytorch-translate migration
def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
@ -157,6 +158,8 @@ class SequenceGenerator(nn.Module):
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
constraints (torch.LongTensor, optional): force decoder to include
the list of constraints
bos_token (int, optional): beginning of sentence token
(default: self.eos)
"""
@ -166,6 +169,7 @@ class SequenceGenerator(nn.Module):
self,
sample: Dict[str, Dict[str, Tensor]],
prefix_tokens: Optional[Tensor] = None,
constraints: Optional[Tensor] = None,
bos_token: Optional[int] = None,
):
incremental_states = torch.jit.annotate(
@ -192,10 +196,15 @@ class SequenceGenerator(nn.Module):
raise Exception('expected src_tokens or source in net input')
# bsz: total number of sentences in beam
input_size = src_tokens.size()
bsz, src_len = input_size[0], input_size[1]
bsz, src_len = src_tokens.size()
beam_size = self.beam_size
if constraints is not None and not self.search.supports_constraints:
raise NotImplementedError("Target-side constraints were provided, but search method doesn't support them")
# Initialize constraints, when active
self.search.init_constraints(constraints, beam_size)
max_len: int = -1
if self.match_source_len:
max_len = src_lengths.max().item()
@ -221,7 +230,7 @@ class SequenceGenerator(nn.Module):
# initialize buffers
scores = (
torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
) # +1 for eos; pad is never choosed for scoring
) # +1 for eos; pad is never chosen for scoring
tokens = (
torch.zeros(bsz * beam_size, max_len + 2)
.to(src_tokens)
@ -327,6 +336,7 @@ class SequenceGenerator(nn.Module):
if self.no_repeat_ngram_size > 0:
lprobs = self._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step)
# Shape: (batch, cand_size)
cand_scores, cand_indices, cand_beams = self.search.step(
step,
lprobs.view(bsz, -1, self.vocab_size),
@ -339,10 +349,13 @@ class SequenceGenerator(nn.Module):
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
# finalize hypotheses that end in eos
# Shape of eos_mask: (batch size, beam size)
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
# only consider eos when it's among the top beam_size indices
# Now we know what beam item(s) to finish
# Shape: 1d list of absolute-numbered beam item indices (into the flattened bsz * beam_size dimension)
eos_bbsz_idx = torch.masked_select(
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
)
@ -352,6 +365,7 @@ class SequenceGenerator(nn.Module):
eos_scores = torch.masked_select(
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
)
finalized_sents = self.finalize_hypos(
step,
eos_bbsz_idx,
@ -372,6 +386,8 @@ class SequenceGenerator(nn.Module):
break
assert step < max_len
# Remove finalized sentences (ones for which {beam_size}
# finished hypotheses have been generated) from the batch.
if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents)
@ -381,6 +397,9 @@ class SequenceGenerator(nn.Module):
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
batch_idxs = torch.arange(bsz, device=cand_indices.device).masked_select(batch_mask)
# Choose the subset of the hypothesized constraints that will continue
self.search.prune_sentences(batch_idxs)
eos_mask = eos_mask[batch_idxs]
cand_beams = cand_beams[batch_idxs]
bbsz_offsets.resize_(new_bsz, 1)
@ -402,7 +421,8 @@ class SequenceGenerator(nn.Module):
bsz = new_bsz
else:
batch_idxs = None
# set active_mask so that values > cand_size indicate eos hypos
# Set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
@ -414,16 +434,24 @@ class SequenceGenerator(nn.Module):
cand_offsets[: eos_mask.size(1)],
)
# get the top beam_size active hypotheses, which are just the hypos
# with the smallest values in active_mask
# get the top beam_size active hypotheses, which are just
# the hypos with the smallest values in active_mask.
# {active_hypos} indicates which {beam_size} hypotheses
# from the list of {2 * beam_size} candidates were
# selected. Shapes: (batch size, beam size)
new_cands_to_ignore, active_hypos = torch.topk(
active_mask, k=beam_size, dim=1, largest=False
)
# update cands_to_ignore to ignore any finalized hypos
# update cands_to_ignore to ignore any finalized hypos.
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
# Make sure there is at least one active item for each sentence in the batch.
assert (~cands_to_ignore).any(dim=1).all()
# update cands_to_ignore to ignore any finalized hypos
# {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
# can be selected more than once).
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
@ -431,9 +459,12 @@ class SequenceGenerator(nn.Module):
active_scores = active_scores.view(-1)
# copy tokens and scores for active hypotheses
# Set the tokens for each beam (can select the same row more than once)
tokens[:, : step + 1] = torch.index_select(
tokens[:, : step + 1], dim=0, index=active_bbsz_idx
)
# Select the next token for each of them
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
cand_indices, dim=1, index=active_hypos
)
@ -445,6 +476,9 @@ class SequenceGenerator(nn.Module):
cand_scores, dim=1, index=active_hypos
)
# Update constraints based on which candidates were selected for the next beam
self.search.update_constraints(active_hypos)
# copy attention for active hypotheses
if attn is not None:
attn[:, :, : step + 2] = torch.index_select(
@ -517,13 +551,18 @@ class SequenceGenerator(nn.Module):
max_len: int,
):
"""Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
Returns number of sentences being finalized.
A sentence is finalized when {beam_size} finished items have been collected for it.
Returns number of sentences (not beam items) being finalized.
These will be removed from the batch and not processed further.
Args:
bbsz_idx (Tensor):
"""
assert bbsz_idx.numel() == eos_scores.numel()
# clone relevant token and attention tensors
# clone relevant token and attention tensors.
# tokens is (batch * beam, max_len). So the index_select
# gets the newly EOS rows, then selects cols 1..{step + 2}
tokens_clone = tokens.index_select(0, bbsz_idx)[
:, 1 : step + 2
] # skip the first index, which is EOS
@ -545,6 +584,10 @@ class SequenceGenerator(nn.Module):
if self.normalize_scores:
eos_scores /= (step + 1) ** self.len_penalty
# cum_unfin records which sentences in the batch are finished.
# It helps match indexing between (a) the original sentences
# in the batch and (b) the current, possibly-reduced set of
# sentences.
cum_unfin: List[int] = []
prev = 0
for f in finished:
@ -554,12 +597,22 @@ class SequenceGenerator(nn.Module):
cum_unfin.append(prev)
# set() is not supported in script export
# The keys here are of the form "{sent}_{unfin_idx}", where
# "unfin_idx" is the index in the current (possibly reduced)
# list of sentences, and "sent" is the index in the original,
# unreduced batch
sents_seen: Dict[str, Optional[Tensor]] = {}
# For every finished beam item
for i in range(bbsz_idx.size()[0]):
idx = bbsz_idx[i]
score = eos_scores[i]
# sentence index in the current (possibly reduced) batch
unfin_idx = idx // beam_size
# sentence index in the original (unreduced) batch
sent = unfin_idx + cum_unfin[unfin_idx]
# print(f"{step} FINISHED {idx} {score} {sent}={unfin_idx} {cum_unfin}")
# Cannot create dict for key type '(int, int)' in torchscript.
# The workaround is to cast int to string
seen = str(sent.item()) + "_" + str(unfin_idx.item())
@ -569,12 +622,15 @@ class SequenceGenerator(nn.Module):
if self.match_source_len and step > src_lengths[unfin_idx]:
score = torch.tensor(-math.inf).to(score)
# An input sentence (among those in a batch) is finished when
# beam_size hypotheses have been collected for it
if len(finalized[sent]) < beam_size:
if attn_clone is not None:
# remove padding tokens from attn scores
hypo_attn = attn_clone[i]
else:
hypo_attn = torch.empty(0)
finalized[sent].append(
{
"tokens": tokens_clone[i],
@ -586,15 +642,18 @@ class SequenceGenerator(nn.Module):
)
newly_finished: List[int] = []
for seen in sents_seen.keys():
# check termination conditions for this sentence
sent: int = int(float(seen.split("_")[0]))
unfin_idx: int = int(float(seen.split("_")[1]))
if not finished[sent] and self.is_finished(
step, unfin_idx, max_len, len(finalized[sent]), beam_size
):
finished[sent] = True
newly_finished.append(unfin_idx)
return newly_finished
def is_finished(
@ -606,9 +665,9 @@ class SequenceGenerator(nn.Module):
beam_size: int,
):
"""
Check whether we've finished generation for a given sentence, by
comparing the worst score among finalized hypotheses to the best
possible score among unfinalized hypotheses.
Check whether decoding for a sentence is finished, which
occurs when the list of finalized sentences has reached the
beam size, or when we reach the maximum length.
"""
assert finalized_sent_len <= beam_size
if finalized_sent_len == beam_size or step == max_len:


@ -291,6 +291,7 @@ class FairseqTask(object):
diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
match_source_len = getattr(args, "match_source_len", False)
diversity_rate = getattr(args, "diversity_rate", -1)
constrained = getattr(args, "constraints", False)
if (
sum(
int(cond)
@ -330,6 +331,8 @@ class FairseqTask(object):
search_strategy = search.DiverseSiblingsSearch(
self.target_dictionary, diversity_rate
)
elif constrained:
search_strategy = search.LexicallyConstrainedBeamSearch(self.target_dictionary, args.constraints)
else:
search_strategy = search.BeamSearch(self.target_dictionary)
@ -395,9 +398,9 @@ class FairseqTask(object):
loss, sample_size, logging_output = criterion(model, sample)
return loss, sample_size, logging_output
def inference_step(self, generator, models, sample, prefix_tokens=None):
def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None):
with torch.no_grad():
return generator.generate(models, sample, prefix_tokens=prefix_tokens)
return generator.generate(models, sample, prefix_tokens=prefix_tokens, constraints=constraints)
def begin_epoch(self, epoch, model):
"""Hook function called before the start of each epoch."""


@ -258,7 +258,7 @@ class LanguageModelingTask(FairseqTask):
sizes=[np.array(src_lengths)],
)
def inference_step(self, generator, models, sample, prefix_tokens=None):
def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None):
with torch.no_grad():
# Generation will always be conditioned on bos_token
if getattr(self.args, "add_bos_token", False):
@ -266,6 +266,9 @@ class LanguageModelingTask(FairseqTask):
else:
bos_token = self.source_dictionary.eos()
if constraints is not None:
raise NotImplementedError("Constrained decoding with the language_modeling task is not supported")
# SequenceGenerator doesn't use src_tokens directly, we need to
# pass the `prefix_tokens` argument instead
if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():


@ -221,7 +221,10 @@ class MultilingualTranslationTask(FairseqTask):
eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),
)
def build_dataset_for_inference(self, src_tokens, src_lengths):
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
if constraints is not None:
raise NotImplementedError("Constrained decoding with the multilingual_translation task is not supported")
lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
return RoundRobinZipDatasets(
OrderedDict([(
@ -312,7 +315,7 @@ class MultilingualTranslationTask(FairseqTask):
agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k]
return agg_loss, agg_sample_size, agg_logging_output
def inference_step(self, generator, models, sample, prefix_tokens=None):
def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None):
with torch.no_grad():
if self.args.decoder_langtok:
bos_token = _lang_token_index(self.target_dictionary, self.args.target_lang)
@ -322,6 +325,7 @@ class MultilingualTranslationTask(FairseqTask):
models,
sample,
prefix_tokens=prefix_tokens,
constraints=constraints,
bos_token=bos_token,
)


@ -270,8 +270,10 @@ class TranslationTask(FairseqTask):
shuffle=(split != 'test'),
)
def build_dataset_for_inference(self, src_tokens, src_lengths):
return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary,
tgt_dict=self.target_dictionary,
constraints=constraints)
def build_model(self, args):
model = super().build_model(args)
@ -377,7 +379,7 @@ class TranslationTask(FairseqTask):
s = self.tokenizer.decode(s)
return s
gen_out = self.inference_step(generator, [model], sample, None)
gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
hyps, refs = [], []
for i in range(len(gen_out)):
hyps.append(decode(gen_out[i][0]['tokens']))


@ -109,11 +109,13 @@ class TranslationFromPretrainedBARTTask(TranslationTask):
eos=self.tgt_dict.index('[{}]'.format(self.args.target_lang))
)
def build_dataset_for_inference(self, src_tokens, src_lengths):
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
src_lang_id = self.source_dictionary.index('[{}]'.format(self.args.source_lang))
source_tokens = []
for s_t in src_tokens:
s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
source_tokens.append(s_t)
dataset = LanguagePairDataset(source_tokens, src_lengths, self.source_dictionary)
dataset = LanguagePairDataset(source_tokens, src_lengths, self.source_dictionary,
tgt_dict=self.target_dictionary,
constraints=constraints)
return dataset


@ -141,7 +141,11 @@ class TranslationLevenshteinTask(TranslationTask):
adaptive=not getattr(args, 'iter_decode_force_max_iter', False),
retain_history=getattr(args, 'retain_iter_history', False))
def build_dataset_for_inference(self, src_tokens, src_lengths):
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
if constraints is not None:
# Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/
raise NotImplementedError("Constrained decoding with the translation_lev task is not supported")
return LanguagePairDataset(
src_tokens, src_lengths, self.source_dictionary, append_bos=True
)


@ -126,7 +126,10 @@ class TranslationMultiSimpleEpochTask(FairseqTask):
epoch=epoch, combine=combine, shard_epoch=shard_epoch, **kwargs
)
def build_dataset_for_inference(self, src_tokens, src_lengths):
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
if constraints is not None:
raise NotImplementedError("Constrained decoding with the multilingual_translation task is not supported")
src_data = ListDataset(src_tokens, src_lengths)
dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
src_langtok_spec, tgt_langtok_spec = self.args.langtoks['main']
@ -173,7 +176,7 @@ class TranslationMultiSimpleEpochTask(FairseqTask):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
return loss, sample_size, logging_output
def inference_step(self, generator, models, sample, prefix_tokens=None):
def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None):
with torch.no_grad():
_, tgt_langtok_spec = self.args.langtoks['main']
if not self.args.lang_tok_replacing_bos_eos:
@ -188,6 +191,7 @@ class TranslationMultiSimpleEpochTask(FairseqTask):
models,
sample,
prefix_tokens=prefix_tokens,
constraints=constraints,
)
else:
return generator.generate(


@ -0,0 +1,500 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Implements tracking of constraints for a beam item.
A list of constraints is given as a list of one or more token
sequences, each of length at least one token. For example, for an input sentence
> Die maschinelle Übersetzung ist schwer zu kontrollieren.
We could have the constraints:
* to influence
* hard
There are two implementations:
* OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints.
* UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints.
The difference is that in the first, the constraints are assumed to be
in order; the algorithm will permit zero or more tokens between them.
In the second, the constraints are not ordered, so many orderings will
be explored.
The same sequence can be present any number of times, and will appear
that many times in the output.
"""
import torch
from collections import Counter
from typing import Tuple, List, Optional, Set
class ConstraintState:
def __init__(self):
pass
def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor:
"""Takes a list of list of constraints in tensor form (a list of
tensor constraints for each sentence) and transforms it into a
packed Tensor. For example, here is a batch of size 3 with 3, 0,
and 1 constraints:
[ [ [3 1 2], [3], [4 5 6 7], ]
[],
[ [1 8 9 10 1 4 11 12], ]
]
Its corresponding packed structure is:
[ [ 3 3 1 2 0 3 0 4 5 6 7 0],
[ 0 0 0 0 0 0 0 0 0 0 0 0],
[ 1 1 8 9 10 1 4 11 12 0 0 0] ]
The packed tensor has shape (batch size, maxlen), where
maxlen is defined below. Each row contains concatenated
constraint tokens for that sentence, with 0 appended after
each constraint. The first item in each row is the number
of constraints for that sentence. So maxlen is the maximum
of
(number of constraints) + (sum length of constraints) + 1
across all sentences in the batch.
"""
# The maximum word length of concatenated constraints for any sentence
max_constraints_len = 1
for sentence_constraints in batch_constraints:
if len(sentence_constraints):
# number of constraints, plus sum of constraint lengths, plus a zero after each
constraints_len = 1 + sum([c.size(0) for c in sentence_constraints]) + len(sentence_constraints)
max_constraints_len = max(max_constraints_len, constraints_len)
batch_size = len(batch_constraints)
constraints_tensor = torch.zeros((batch_size, max_constraints_len)).long()
for i, sentence_constraints in enumerate(batch_constraints):
constraints_tensor[i, 0] = len(sentence_constraints)
offset = 1
for j, constraint in enumerate(sentence_constraints):
this_len = constraint.size(0)
constraints_tensor[i, offset:offset+this_len] = constraint
offset += this_len + 1
return constraints_tensor.long()
def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]:
"""
Transforms *one row* of a packed constraint tensor (e.g., for one
sentence in the batch) into a list of constraint tensors.
"""
constraint_list = []
num_constraints = constraint_tensor[0]
constraints = constraint_tensor.tolist()
offset = 1
for i in range(num_constraints):
where = constraints.index(0, offset)
constraint_list.append(constraint_tensor[offset:where])
offset = where + 1
return constraint_list
class ConstraintNode:
"""
Represents a node in a trie managing unordered constraints.
"""
def __init__(self, token: int = None, parent=None):
# The token associated with this node (None for the root)
self.token = int(token) if token is not None else None
# The parent (None at the root)
self.parent = parent
# The number of constraints that end at this node
self.terminal = 0
# List of child nodes
self.children = {}
# The cumulative number of constraints from this point in the
# trie forward
self.num_constraints = 0
@property
def id(self):
return self.token
def __str__(self):
term = self.terminal != 0
return f"[{self.token}].{term}#{self.num_constraints}"
def __getitem__(self, key: int):
return self.children.get(key, None)
def next_tokens(self) -> Set[int]:
"""The set of child labels."""
return set(self.children.keys())
@staticmethod
def create(constraints: List[List[int]]):
root = ConstraintNode()
for sequence in constraints:
root.add_sequence(sequence)
return root
@staticmethod
def print_graph(node: "ConstraintNode"):
if len(node.children) == 0:
return str(node)
else:
s = f"({node}"
for child in node.children.values():
s += " " + ConstraintNode.print_graph(child)
s += ")"
return s
def token_counts(self) -> Counter:
"""Returns a counter of the number of times each token is used
in a constraint.
"""
token_counts = Counter()
kids = list(self.children.values())
while len(kids) > 0:
kid = kids.pop()
token_counts[kid.id] += kid.num_constraints
kids += list(kid.children.values())
return token_counts
def tokens(self) -> Set[int]:
"""Returns the set of tokens in constraints."""
return set(self.token_counts().keys())
def add_sequence(self, sequence: List[int]):
"""Adds a constraint, represented as a list of integers, to
the trie."""
assert len(sequence) > 0
token = int(sequence[0])
if token not in self.children:
self.children[token] = ConstraintNode(token, parent=self)
node = self.children[token]
if len(sequence) == 1:
node.terminal += 1
node.num_constraints += 1
parent = node.parent
while parent is not None:
parent.num_constraints += 1
parent = parent.parent
else:
node.add_sequence(sequence[1:])
class UnorderedConstraintState(ConstraintState):
"""
Records progress through the set of constraints for each item in the beam
using a trie.
"""
def __init__(self,
node: ConstraintNode,
copy_from: "ConstraintState" = None):
self.node = node
if copy_from is None:
# The root node
self.root = node
# The set of states in the graph that have been completed
self.completed = Counter()
# The number of times each node has been visited (generated) so far
self.generated = Counter()
# The list of tokens we need to generate
self.needed_tokens = self.root.tokens()
else:
self.completed = Counter(copy_from.completed)
self.generated = Counter(copy_from.generated)
self.root = copy_from.root
# Mark the node as generated
if self.node != self.root:
self.generated[node] += 1
@staticmethod
def create(constraint_tensor: torch.Tensor):
constraint_list = unpack_constraints(constraint_tensor)
constraint_trie_root = ConstraintNode.create(constraint_list)
return UnorderedConstraintState(constraint_trie_root)
def __str__(self):
gen_str = ",".join([str(node) for node in self.generated])
return f"{self.name}/{self.bank}({gen_str})x{self.num_completed}"
def __copy__(self):
copied_state = UnorderedConstraintState(self.node, copy_from=self)
return copied_state
def copy(self):
return self.__copy__()
@property
def name(self):
if self.node.id is None:
return "ROOT"
else:
return str(self.node.id)
@property
def is_root(self):
return self.node == self.root
@property
def bank(self):
return sum(self.generated.values())
@property
def num_completed(self):
"""The number of constraints (not constraint tokens) that are completed.
In addition to the already-completed states, we need to account for the
current state, which might get marked as completed when another token
is generated.
"""
in_final = self.node.terminal and self.completed[self.node] < self.node.terminal
return sum(self.completed.values()) + in_final
@property
def finished(self):
return self.root.num_constraints - self.num_completed == 0
@property
def token_counts(self):
return self.root.token_counts()
@property
def tokens(self):
return self.root.tokens()
@property
def num_constraint_tokens(self):
return sum(self.token_counts.values())
def next_tokens(self) -> Set[int]:
"""Returns the list of tokens that could come next.
These are (a) all tokens extending the root state and, for
non-root states, additionally all tokens extending the current
state."""
if self.node != self.root:
return self.root.next_tokens().union(self.node.next_tokens())
else:
return self.root.next_tokens()
def advance(self, token: int):
"""Reads in a token and advances the state. Here's how it works.
We can advance to the next state if:
- there is a matching child
- its path isn't blocked
A path is blocked when all constraints that are descendants of
that node have already been generated, in the current state.
If we are not able to advance from the current state, we "fall
off the graph" and return to the root state. There, we again
try to advance, checking the same criteria.
In any case, when falling off the graph, we need to do some
bookkeeping. We:
- check whether any constraints were met (all prefixes of
current state)
- if one is found, mark it as completed
- adjust visited nodes accordingly
"""
token = int(token)
next_state = None
child = self.node[token]
if child is not None and self.generated[child] < child.num_constraints:
next_state = UnorderedConstraintState(child, copy_from=self)
def rewind():
"""If we're mid-trie and an "illegal" token is chosen next, we need
to reset our state to the root state. However, along the way, we need
to check whether a prefix of the current trie state represents a state
we could mark as completed.
"""
node = self.node
while node != self.root:
if node.terminal and self.completed[node] < node.terminal:
next_state.completed[node] += 1
return
next_state.generated[node] -= 1
node = node.parent
# Fall off the graph, check the root
if next_state is None and token in self.root.next_tokens():
child = self.root[token]
# We can only traverse this edge if it's not saturated
if self.generated[child] < child.num_constraints:
next_state = UnorderedConstraintState(child, copy_from=self)
else:
next_state = UnorderedConstraintState(self.root, copy_from=self)
# Rewind
rewind()
elif next_state is None:
next_state = UnorderedConstraintState(self.root, copy_from=self)
# Rewind
rewind()
return next_state


class ConstraintSequence:
    def __init__(self, sequences: List[List[int]]):
        """Represents a set of possibly multitoken constraints by
        concatenating them and internally recording the end points.
        """
        self.sequences = []
        self.endpoints = []
        self.num_tokens = 0
        self.tokens = set()
        for sequence in sequences:
            for token in sequence:
                self.tokens.add(token)
            self.num_tokens += len(sequence)
            self.endpoints += [False for x in range(len(sequence) - 1)] + [True]
            self.sequences += sequence

    def __getitem__(self, key: int):
        return self.sequences[key]

    def __len__(self):
        return len(self.sequences)

    def __str__(self):
        return str(self.sequences)


class OrderedConstraintState(ConstraintState):
    """
    Records progress through the set of linear nonbranching constraints with gaps.
    """
    def __init__(self,
                 sequence: ConstraintSequence,
                 state: int = -1):
        self.sequence = sequence
        self.state = state

    @staticmethod
    def create(constraint_tensor: torch.Tensor):
        constraint_list = unpack_constraints(constraint_tensor)
        return OrderedConstraintState(ConstraintSequence(constraint_list), -1)

    def __str__(self):
        return f"{self.state}/{self.bank}x{self.num_completed}"

    def __copy__(self):
        return OrderedConstraintState(self.sequence, self.state)

    def copy(self):
        return self.__copy__()

    @property
    def num_completed(self):
        if self.state == -1:
            return 0
        count = len(list(filter(lambda x: x, self.sequence.endpoints[0:self.state + 1])))
        return count

    @property
    def is_root(self):
        return self.state == -1

    @property
    def name(self):
        if self.state == -1:
            return "ROOT"
        else:
            return str(self.sequence[self.state])

    @property
    def bank(self) -> int:
        return self.state + 1

    @property
    def finished(self):
        return self.state + 1 == len(self.sequence)

    @property
    def token_counts(self):
        return self.sequence.token_counts()

    @property
    def tokens(self):
        return self.sequence.tokens

    @property
    def num_constraint_tokens(self):
        return sum(self.token_counts.values())
    def next_tokens(self) -> Set[int]:
        """Returns the set of tokens that could come next: the next
        unmatched token of the flattened sequence and, when partway
        through (state > 0), additionally the sequence's first token,
        which permits restarting a match."""
        tokens = set()
        if self.state > 0:
            tokens.add(self.sequence[0])
        if not self.finished:
            tokens.add(self.sequence[self.state + 1])
        return tokens
    def advance(self, token: int):
        """Reads in a token and advances the state.

        Since the constraints are flattened into a single ordered
        sequence, the cases are simpler than in the unordered variant:
        - if all constraints are finished, any token is accepted;
        - if the token matches the next constraint token, advance one
          position;
        - if we are sitting at the end of a constraint (an endpoint),
          any token is accepted without moving, since gaps between
          constraints are permitted;
        - otherwise the partial match is abandoned: restart at position
          0 if the token matches the first constraint token, else
          return to the root state.
        """
        token = int(token)
        # print(f"{self} ADVANCE({token}) {self.sequence} -> ", end="")

        if self.finished:
            # Accept anything
            next_state = self.copy()

        elif self.sequence[self.state + 1] == token:
            # Advance to the next token
            next_state = OrderedConstraintState(self.sequence, self.state + 1)

        elif self.sequence.endpoints[self.state]:
            # Accept anything between constraints (*)
            next_state = self.copy()

        elif token == self.sequence[0]:
            # Start over having generated the first token
            next_state = OrderedConstraintState(self.sequence, 0)
        else:
            # Start over from the root
            next_state = OrderedConstraintState(self.sequence, -1)

        return next_state
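
To make the two state machines concrete, here is a small usage sketch in the spirit of `tests/test_constraints.py` below; the token IDs are invented for illustration:

```python
# Sketch only: mirrors how sequence_generator.py and the unit tests drive
# the constraint states. Token IDs here are arbitrary.
import torch

from fairseq.token_generation_constraints import (
    OrderedConstraintState,
    UnorderedConstraintState,
    pack_constraints,
)

# Two constraints for one sentence: the bigram [3, 4] and the unigram [5].
# pack_constraints flattens each sentence to [num_constraints, c1..., 0, c2..., 0].
packed = pack_constraints([[torch.tensor([3, 4]), torch.tensor([5])]])[0]

state = UnorderedConstraintState.create(packed)
for token in [9, 3, 4, 7, 5]:        # a hypothetical partial hypothesis
    state = state.advance(token)     # advance() returns a new state
assert (state.bank, state.num_completed, state.finished) == (3, 2, True)

# The ordered variant accepts the same stream, since [3, 4] precedes [5].
ostate = OrderedConstraintState.create(packed)
for token in [9, 3, 4, 7, 5]:
    ostate = ostate.advance(token)
assert ostate.finished
```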

fairseq_cli/generate.py

@@ -150,8 +150,12 @@ def _main(args, output_file):
         if args.prefix_size > 0:
             prefix_tokens = sample['target'][:, :args.prefix_size]
 
+        constraints = None
+        if "constraints" in sample:
+            constraints = sample["constraints"]
+
         gen_timer.start()
-        hypos = task.inference_step(generator, models, sample, prefix_tokens)
+        hypos = task.inference_step(generator, models, sample, prefix_tokens=prefix_tokens, constraints=constraints)
         num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
         gen_timer.stop(num_generated_tokens)
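
The generator only needs the packed tensor to travel with the batch; here is a hedged sketch of wiring this up by hand. The helper function below is hypothetical, but `pack_constraints` and the `inference_step` signature are the ones used in the hunk above:

```python
# Hedged sketch, not code from this PR: task, generator, models, src_tokens,
# and src_lengths are assumed to be set up as elsewhere in fairseq_cli.
from fairseq.token_generation_constraints import pack_constraints


def translate_with_constraints(task, generator, models, src_tokens, src_lengths,
                               batch_constraints):
    """Attach packed constraints to a sample and decode it.

    batch_constraints is a List[List[Tensor]]: one (possibly empty) list of
    constraint-token tensors per sentence in the batch.
    """
    sample = {
        "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths},
        "constraints": pack_constraints(batch_constraints),
    }
    return task.inference_step(
        generator, models, sample,
        prefix_tokens=None, constraints=sample["constraints"],
    )
```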

fairseq_cli/interactive.py

@@ -12,6 +12,7 @@ import fileinput
 import logging
 import math
 import sys
+import time
 import os
 
 import numpy as np
@@ -20,9 +21,9 @@ import torch
 
 from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
 from fairseq.data import encoders
+from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
 from .generate import get_symbols_to_strip_from_output
 
 logging.basicConfig(
     format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
     datefmt='%Y-%m-%d %H:%M:%S',
@@ -32,7 +33,7 @@ logging.basicConfig(
 logger = logging.getLogger('fairseq_cli.interactive')
 
-Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
+Batch = namedtuple('Batch', 'ids src_tokens src_lengths constraints')
 Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
@@ -50,28 +51,63 @@ def buffered_read(input, buffer_size):
 def make_batches(lines, args, task, max_positions, encode_fn):
+    def encode_fn_target(x):
+        return encode_fn(x)
+
+    if args.constraints:
+        # Strip (tab-delimited) constraints, if present, from input lines,
+        # store them in batch_constraints
+        batch_constraints = [list() for _ in lines]
+        for i, line in enumerate(lines):
+            if "\t" in line:
+                lines[i], *batch_constraints[i] = line.split("\t")
+
+        # Convert each List[str] to List[Tensor]
+        for i, constraint_list in enumerate(batch_constraints):
+            batch_constraints[i] = [task.target_dictionary.encode_line(
+                encode_fn_target(constraint),
+                append_eos=False,
+                add_if_not_exist=False,
+            ) for constraint in constraint_list]
+
     tokens = [
         task.source_dictionary.encode_line(
             encode_fn(src_str), add_if_not_exist=False
         ).long()
         for src_str in lines
     ]
+
+    if args.constraints:
+        constraints_tensor = pack_constraints(batch_constraints)
+    else:
+        constraints_tensor = None
+
     lengths = [t.numel() for t in tokens]
     itr = task.get_batch_iterator(
-        dataset=task.build_dataset_for_inference(tokens, lengths),
+        dataset=task.build_dataset_for_inference(tokens, lengths, constraints=constraints_tensor),
         max_tokens=args.max_tokens,
         max_sentences=args.max_sentences,
         max_positions=max_positions,
         ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test
     ).next_epoch_itr(shuffle=False)
     for batch in itr:
+        ids = batch['id']
+        src_tokens = batch['net_input']['src_tokens']
+        src_lengths = batch['net_input']['src_lengths']
+        constraints = batch.get("constraints", None)
         yield Batch(
-            ids=batch['id'],
-            src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'],
+            ids=ids,
+            src_tokens=src_tokens,
+            src_lengths=src_lengths,
+            constraints=constraints,
         )
 
 
 def main(args):
+    start_time = time.time()
+    total_translate_time = 0
+
     utils.import_user_module(args)
 
     if args.buffer_size < 1:
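
For reference, the tab-splitting above is plain `str.split`, matching the usage example at the top of this description:

```python
# Mirrors the line-splitting in make_batches above.
line = "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\tinfluence"
src, *constraints = line.split("\t")
assert src == "Die maschinelle Übersetzung ist schwer zu kontrollieren."
assert constraints == ["hard", "influence"]
```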
@@ -147,6 +183,9 @@ def main(args):
         *[model.max_positions() for model in models]
     )
 
+    if args.constraints:
+        logger.warning("NOTE: Constrained decoding currently assumes a shared subword vocabulary.")
+
     if args.buffer_size > 1:
         logger.info('Sentence buffer size: %s', args.buffer_size)
     logger.info('NOTE: hypothesis and token scores are output in base 2')
@@ -155,11 +194,15 @@ def main(args):
     for inputs in buffered_read(args.input, args.buffer_size):
         results = []
         for batch in make_batches(inputs, args, task, max_positions, encode_fn):
+            bsz = batch.src_tokens.size(0)
             src_tokens = batch.src_tokens
             src_lengths = batch.src_lengths
+            constraints = batch.constraints
             if use_cuda:
                 src_tokens = src_tokens.cuda()
                 src_lengths = src_lengths.cuda()
+                if constraints is not None:
+                    constraints = constraints.cuda()
 
             sample = {
                 'net_input': {
@@ -167,16 +210,29 @@ def main(args):
                     'src_lengths': src_lengths,
                 },
             }
-            translations = task.inference_step(generator, models, sample)
+            translate_start_time = time.time()
+            translations = task.inference_step(generator, models, sample, constraints=constraints)
+            translate_time = time.time() - translate_start_time
+            total_translate_time += translate_time
+
+            list_constraints = [[] for _ in range(bsz)]
+            if args.constraints:
+                list_constraints = [unpack_constraints(c) for c in constraints]
+
             for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                 src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
-                results.append((start_id + id, src_tokens_i, hypos))
+                constraints = list_constraints[i]
+                results.append((start_id + id, src_tokens_i, hypos,
+                                {"constraints": constraints,
+                                 "time": translate_time / len(translations)}
+                                ))
 
         # sort output to match input order
-        for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
+        for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
             if src_dict is not None:
                 src_str = src_dict.string(src_tokens, args.remove_bpe)
-                print('S-{}\t{}'.format(id, src_str))
+                print('S-{}\t{}'.format(id_, src_str))
+                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
+                for constraint in info["constraints"]:
+                    print("C-{}\t{}".format(id_, tgt_dict.string(constraint, args.remove_bpe)))
 
             # Process top predictions
             for hypo in hypos[:min(len(hypos), args.nbest)]:
@@ -192,11 +248,11 @@ def main(args):
                 detok_hypo_str = decode_fn(hypo_str)
                 score = hypo['score'] / math.log(2)  # convert to base 2
                 # original hypothesis (after tokenization and BPE)
-                print('H-{}\t{}\t{}'.format(id, score, hypo_str))
+                print('H-{}\t{}\t{}'.format(id_, score, hypo_str))
                 # detokenized hypothesis
-                print('D-{}\t{}\t{}'.format(id, score, detok_hypo_str))
+                print('D-{}\t{}\t{}'.format(id_, score, detok_hypo_str))
                 print('P-{}\t{}'.format(
-                    id,
+                    id_,
                     ' '.join(map(
                         lambda x: '{:.4f}'.format(x),
                         # convert from base e to base 2
@@ -206,13 +262,14 @@ def main(args):
                 if args.print_alignment:
                     alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment])
                     print('A-{}\t{}'.format(
-                        id,
+                        id_,
                         alignment_str
                     ))
 
-        # update running id counter
+        # update running id_ counter
         start_id += len(inputs)
 
+    logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format(time.time() - start_time, total_translate_time))
+
 
 def cli_main():
     parser = options.get_interactive_generation_parser()
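
The `C-` lines above are produced by reversing the packing step with `unpack_constraints`; a quick round-trip illustration (token IDs invented):

```python
import torch

from fairseq.token_generation_constraints import pack_constraints, unpack_constraints

per_sentence = [torch.tensor([7, 8]), torch.tensor([9])]  # arbitrary token IDs
packed = pack_constraints([per_sentence])                 # batch of one sentence
unpacked = unpack_constraints(packed[0])                  # -> [tensor([7, 8]), tensor([9])]
assert all(torch.equal(a, b) for a, b in zip(per_sentence, unpacked))
```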

scripts/constraints/extract.py Executable file

@@ -0,0 +1,83 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Extracts random constraints from reference files."""

import argparse
import random
import sys

from sacrebleu import extract_ngrams


def get_phrase(words, index, length):
    assert index < len(words) - length + 1
    phr = ' '.join(words[index:index + length])
    for i in range(index, index + length):
        words.pop(index)
    return phr


def main(args):
    if args.seed:
        random.seed(args.seed)

    for line in sys.stdin:
        constraints = []

        def add_constraint(constraint):
            constraints.append(constraint)

        source = line.rstrip()
        if '\t' in line:
            source, target = line.split('\t')
            if args.add_sos:
                target = f"<s> {target}"
            if args.add_eos:
                target = f"{target} </s>"

            if len(target.split()) >= args.len:
                words = [target]

                num = args.number
                choices = {}
                for i in range(num):
                    if len(words) == 0:
                        break
                    segmentno = random.choice(range(len(words)))
                    segment = words.pop(segmentno)
                    tokens = segment.split()
                    phrase_index = random.choice(range(len(tokens)))
                    choice = " ".join(tokens[phrase_index:min(len(tokens), phrase_index + args.len)])
                    for j in range(phrase_index, min(len(tokens), phrase_index + args.len)):
                        tokens.pop(phrase_index)
                    if phrase_index > 0:
                        words.append(" ".join(tokens[0:phrase_index]))
                    if phrase_index + 1 < len(tokens):
                        words.append(" ".join(tokens[phrase_index:]))
                    choices[target.find(choice)] = choice

                    # mask out with spaces
                    target = target.replace(choice, " " * len(choice), 1)

                for key in sorted(choices.keys()):
                    add_constraint(choices[key])

        print(source, *constraints, sep="\t")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--number', '-n', type=int, default=1, help="number of phrases")
    parser.add_argument('--len', '-l', type=int, default=1, help="phrase length")
    parser.add_argument('--add-sos', default=False, action='store_true', help='add <s> token')
    parser.add_argument('--add-eos', default=False, action='store_true', help='add </s> token')
    parser.add_argument('--seed', "-s", default=0, type=int)
    args = parser.parse_args()

    main(args)
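
One thing worth noting about `get_phrase` is that it mutates its `words` argument in place while returning the joined span; a tiny illustration with invented input:

```python
# Illustration only: get_phrase removes words[index:index + length] in place
# and returns the extracted span as a single string.
words = "the quick brown fox".split()
assert get_phrase(words, 1, 2) == "quick brown"
assert words == ["the", "fox"]
```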

scripts/constraints/validate.py Executable file

@@ -0,0 +1,33 @@
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Reads in a fairseq output file, and verifies that the constraints
(C- lines) are present in the output (the first H- line). Assumes that
constraints are listed prior to the first hypothesis.
"""

import sys

constraints = []
found = 0
total = 0
for line in sys.stdin:
    if line.startswith("C-"):
        constraints.append(line.rstrip().split("\t")[1])
    elif line.startswith("H-"):
        text = line.split("\t")[2]

        for constraint in constraints:
            total += 1
            if constraint in text:
                found += 1
            else:
                print(f"No {constraint} in {text}", file=sys.stderr)

        constraints = []

print(f"Found {found} / {total} = {100 * found / total:.1f}%")

tests/test_constraints.py Executable file

@@ -0,0 +1,254 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
import unittest
from typing import List

import torch

from fairseq.token_generation_constraints import *


def tensorize(constraints: List[List[int]]) -> List[torch.Tensor]:
    return [torch.tensor(x) for x in constraints]


class TestHelperRoutines(unittest.TestCase):
    def setUp(self):
        self.examples = [
            (
                [[]],
                torch.tensor([[0]])
            ),
            (
                [[], []],
                torch.tensor([[0], [0]])
            ),
            (
                [[torch.tensor([1, 2])], []],
                torch.tensor([[1, 1, 2, 0], [0, 0, 0, 0]])
            ),
            (
                [[torch.tensor([3, 1, 2]), torch.tensor([3]), torch.tensor([4, 5, 6, 7])],
                 [],
                 [torch.tensor([1, 8, 9, 10, 1, 4, 11, 12])]],
                torch.tensor([[3, 3, 1, 2, 0, 3, 0, 4, 5, 6, 7, 0],
                              [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                              [1, 1, 8, 9, 10, 1, 4, 11, 12, 0, 0, 0]])
            )
        ]

    def test_packing(self):
        """Ensures the list of lists of tensors gets packed correctly."""
        for batch_constraints, expected_tensor in self.examples:
            packed = pack_constraints(batch_constraints)
            assert torch.equal(packed, expected_tensor)


class TestUnorderedConstraintState(unittest.TestCase):
    def setUp(self):
        # Tuples of (constraint set, expected printed graph, token counts per node)
        self.examples = [
            (
                tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
                "([None].False#6 ([1].True#4 ([2].False#1 [3].True#1) [3].True#1 [4].True#1) ([4].False#2 ([5].True#2 ([6].False#1 [7].True#1))))",
                {1: 4, 2: 1, 3: 2, 4: 3, 5: 2, 6: 1, 7: 1}
            ),
            ([], "[None].False#0", {}),
            (tensorize([[0]]), "([None].False#1 [0].True#1)", {0: 1}),
            (
                tensorize([[100000, 1, 2, 3, 4, 5]]),
                "([None].False#1 ([100000].False#1 ([1].False#1 ([2].False#1 ([3].False#1 ([4].False#1 [5].True#1))))))",
                {100000: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1}
            ),
            (
                tensorize([[1, 2], [1, 2]]),
                "([None].False#2 ([1].False#2 [2].True#2))",
                {1: 2, 2: 2},
            ),
            (
                tensorize([[1, 2], [3, 4]]),
                "([None].False#2 ([1].False#1 [2].True#1) ([3].False#1 [4].True#1))",
                {1: 1, 2: 1, 3: 1, 4: 1},
            ),
        ]

        self.sequences = [
            (
                self.examples[0][0],
                [],
                {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
            ),
            (
                self.examples[0][0],
                [1, 2],
                {"bank": 2, "num_completed": 0, "finished": False, "is_root": False},
            ),
            (
                self.examples[0][0],
                [1, 2, 94],
                {"bank": 1, "num_completed": 1, "finished": False, "is_root": True},
            ),
            (
                self.examples[0][0],
                [1, 3, 999, 1, 4],
                {"bank": 4, "num_completed": 2, "finished": False, "is_root": False},
            ),
            (
                self.examples[0][0],
                [1, 3, 999, 1, 4, 999],
                {"bank": 4, "num_completed": 2, "finished": False, "is_root": True},
            ),
            (
                self.examples[0][0],
                [4, 5, 6, 8],
                {"bank": 2, "num_completed": 1, "finished": False, "is_root": True},
            ),
            (
                self.examples[0][0],
                # Tricky, because in last three, goes down [1->4] branch, could miss [1] and [4->5]
                # [[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]],
                [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5],
                {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
            ),
            (
                self.examples[0][0],
                [1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117],
                {"bank": 14, "num_completed": 6, "finished": True, "is_root": True},
            ),
            (
                tensorize([[1], [2, 3]]),
                # Should not be able to get credit for entering 1 a second time
                [1, 1],
                {"bank": 1, "num_completed": 1, "finished": False, "is_root": True},
            ),
            (
                self.examples[4][0],
                [1, 2, 1, 2],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
            ),
            (
                self.examples[4][0],
                [1, 2, 1, 2, 1],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": True},
            ),
            (
                self.examples[5][0],
                [1, 2, 3, 4, 5],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": True},
            ),
        ]

    def test_graphs(self):
        """
        Tests whether unordered constraint graphs are created correctly.
        """
        for example in self.examples:
            constraints, expected, gold_counts = example
            c = ConstraintNode.create(constraints)
            assert ConstraintNode.print_graph(c) == expected, f"got {ConstraintNode.print_graph(c)}, expected {expected}"
            assert c.token_counts() == gold_counts, f"{c} got {c.token_counts()} wanted {gold_counts}"

    def test_next_tokens(self):
        """
        Tests that the set of next tokens is correct.
        """
        for example in self.examples:
            constraints, expected, gold_counts = example
            root = ConstraintNode.create(constraints)

            root_tokens = set(root.children.keys())
            for sequence in constraints:
                state = UnorderedConstraintState(root)
                for token in sequence:
                    all_tokens = root_tokens.union(state.node.children.keys())
                    assert all_tokens == state.next_tokens(), f"ALL {all_tokens} NEXT {state.next_tokens()}"
                    state = state.advance(token)

    def test_sequences(self):
        for constraints, tokens, expected in self.sequences:
            state = UnorderedConstraintState.create(pack_constraints([constraints])[0])
            for token in tokens:
                state = state.advance(token)
            result = {}
            for attr in expected.keys():
                result[attr] = getattr(state, attr)
            assert result == expected, f"TEST({tokens}) GOT: {result} WANTED: {expected}"


class TestOrderedConstraintState(unittest.TestCase):
    def setUp(self):
        self.sequences = [
            (
                tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
                [],
                {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
            ),
            (
                tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
                [1, 2],
                {"bank": 2, "num_completed": 0, "finished": False, "is_root": False},
            ),
            (
                tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
                [1, 2, 94],
                {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
            ),
            (
                tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
                [1, 3, 999, 1, 4],
                {"bank": 0, "num_completed": 0, "finished": False, "is_root": True},
            ),
            (
                tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
                [1, 2, 3, 999, 999],
                {"bank": 3, "num_completed": 1, "finished": False, "is_root": False},
            ),
            (
                tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
                [1, 2, 3, 77, 1, 3, 1],
                {"bank": 6, "num_completed": 2, "finished": False, "is_root": False},
            ),
            (
                tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
                [1, 2, 3, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5],
                {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
            ),
            (
                tensorize([[1, 2, 3], [1, 3], [1, 4], [4, 5, 6, 7], [1], [4, 5]]),
                [1, 2, 999, 1, 2, 3, 999, 1, 3, 1, 4, 4, 5, 6, 7, 1, 4, 5, 117],
                {"bank": 14, "num_completed": 6, "finished": True, "is_root": False},
            ),
            (
                tensorize([[1], [2, 3]]),
                [1, 1],
                {"bank": 1, "num_completed": 1, "finished": False, "is_root": False},
            ),
            (
                tensorize([[1, 2], [1, 2]]),
                [1, 2, 1, 2],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
            ),
            (
                tensorize([[1, 2], [1, 2]]),
                [1, 2, 1, 2, 1],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
            ),
            (
                tensorize([[1, 2], [3, 4]]),
                [1, 2, 3, 4, 5],
                {"bank": 4, "num_completed": 2, "finished": True, "is_root": False},
            ),
        ]

    def test_sequences(self):
        for i, (constraints, tokens, expected) in enumerate(self.sequences):
            state = OrderedConstraintState.create(pack_constraints([constraints])[0])
            for token in tokens:
                state = state.advance(token)
            result = {}
            for attr in expected.keys():
                result[attr] = getattr(state, attr)
            assert result == expected, f"TEST({tokens}) GOT: {result} WANTED: {expected}"


if __name__ == "__main__":
    unittest.main()