Implementation of the paper "Jointly Learning to Align and Translate with Transformer Models" (#877)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/877

This PR implements guided alignment training as described in "Jointly Learning to Align and Translate with Transformer Models" (https://arxiv.org/abs/1909.02074).

In summary, it allows training selected attention heads of the Transformer model with external alignments computed by statistical alignment toolkits such as Giza++ or FastAlign. During inference, the attention probabilities from the supervised heads can be used to extract reliable alignments. In our experiments, guided alignment training did not cause any regression in translation performance.
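
For orientation (not part of the patch itself), here is a minimal sketch of the supervised attention objective: the criterion added below sums the usual label-smoothed cross-entropy with `alignment_lambda` times an alignment loss, the negative log of the attention mass placed on the externally aligned source positions. Shapes follow the `label_smoothed_cross_entropy_with_alignment` criterion in this PR; the function name is illustrative.

```python
import torch

def guided_alignment_loss(attn_probs, align, align_weights):
    """Sketch of the alignment term in label_smoothed_cross_entropy_with_alignment.

    attn_probs:    (batch * tgt_len, src_len) cross-attention probabilities of the
                   supervised head(s), flattened over the batch.
    align:         (num_pairs, 2) long tensor of (src_idx, tgt_idx) pairs, with
                   indices already offset into the flattened batch.
    align_weights: (num_pairs,) inverse frequency of each target index, used to
                   normalize one-to-many alignments.
    """
    if align.numel() == 0:
        return None
    # Probability mass that the supervised head(s) place on the aligned source tokens.
    p = attn_probs[align[:, 1], align[:, 0]]
    return -(p.log() * align_weights).sum()
```
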
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1095

Differential Revision: D17170337

Pulled By: myleott

fbshipit-source-id: daa418bef70324d7088dbb30aa2adf9f95774859
This commit is contained in:
Sarthak Garg 2019-09-30 06:56:15 -07:00 committed by Facebook Github Bot
parent acb6fba005
commit 1c66792948
20 changed files with 899 additions and 61 deletions

View File

@ -33,6 +33,7 @@ Fairseq provides reference implementations of various sequence-to-sequence model
- [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
- **Non-autoregressive Transformers**
- Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
- Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
@ -100,6 +101,7 @@ as well as example training and evaluation commands.
- [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
We also have more detailed READMEs to reproduce results from specific papers:
- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
- [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)

View File

@ -0,0 +1,89 @@
# Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)
This page includes instructions for training models described in [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](https://arxiv.org/abs/1909.02074).
## Training a joint alignment-translation model on WMT'18 En-De
##### 1. Extract and preprocess the WMT'18 En-De data
```bash
./prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
```
##### 2. Generate alignments from a statistical alignment toolkit, e.g. Giza++ or FastAlign
In this example, we use FastAlign.
```bash
git clone git@github.com:clab/fast_align.git
pushd fast_align
mkdir build
cd build
cmake ..
make
popd
ALIGN=fast_align/build/fast_align
paste bpe.32k/train.en bpe.32k/train.de | awk -F '\t' '{print $1 " ||| " $2}' > bpe.32k/train.en-de
$ALIGN -i bpe.32k/train.en-de -d -o -v > bpe.32k/train.align
```
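Optionally (not part of the recipe above), you can also compute reverse-direction alignments and symmetrize them with fast_align's `atools`; `grow-diag-final-and` is one common heuristic, and the paper's own symmetrization scripts are linked under "Other resources" below. To use the symmetrized alignments, rename the output to `bpe.32k/train.align` before step 3.
```bash
# Sketch only: reverse alignments (-r) and symmetrization with atools.
$ALIGN -i bpe.32k/train.en-de -d -o -v -r > bpe.32k/train.align.rev
fast_align/build/atools -i bpe.32k/train.align -j bpe.32k/train.align.rev -c grow-diag-final-and > bpe.32k/train.align.sym
```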
##### 3. Preprocess the dataset with the alignments generated above
```bash
fairseq-preprocess \
--source-lang en --target-lang de \
--trainpref bpe.32k/train \
--validpref bpe.32k/valid \
--testpref bpe.32k/test \
--align-suffix align \
--destdir binarized/ \
--joined-dictionary \
--workers 32
```
##### 4. Train a model
```bash
fairseq-train \
binarized \
--arch transformer_wmt_en_de_big_align --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu \
--lr 0.0002 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
--max-tokens 3500 --label-smoothing 0.1 \
--save-dir ./checkpoints --log-interval 1000 --max-update 60000 \
--keep-interval-updates -1 --save-interval-updates 0 \
--load-alignments --criterion label_smoothed_cross_entropy_with_alignment \
--fp16
```
Note that the `--fp16` flag requires CUDA 9.1 or greater and a Volta GPU or newer.
If you want to train the above model with big batches (assuming your machine has 8 GPUs):
- add `--update-freq 8` to simulate training on 8x8=64 GPUs
- increase the learning rate; 0.0007 works well for big batches
##### 5. Evaluate and generate the alignments (BPE level)
```bash
fairseq-generate \
binarized --gen-subset test --print-alignment \
--source-lang en --target-lang de \
--path checkpoints/checkpoint_best.pt --beam 5 --nbest 1
```
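With `--print-alignment`, each hypothesis is followed by an `A-<id>` line containing space-separated `srcidx-tgtidx` pairs at the BPE level (see the `generate.py` change in this patch). For orientation, a minimal sketch (not part of the patch) for collecting them from a saved generation log:
```python
def read_alignments(path):
    """Collect {sample_id: [(src_idx, tgt_idx), ...]} from fairseq-generate output."""
    alignments = {}
    with open(path) as f:
        for line in f:
            if line.startswith('A-'):
                key, _, pairs = line.rstrip('\n').partition('\t')
                alignments[int(key[2:])] = [
                    tuple(map(int, p.split('-'))) for p in pairs.split()
                ]
    return alignments
```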
##### 6. Other resources
The code for:
1. preparing alignment test sets,
2. converting BPE-level alignments to token-level alignments (a minimal sketch follows below),
3. symmetrizing bidirectional alignments, and
4. evaluating alignments using the AER metric
can be found [here](https://github.com/lilt/alignment-scripts).
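The sketch below is not the lilt/alignment-scripts implementation; it assumes `@@`-suffixed subword markers as produced by fastBPE, and the function names are illustrative.
```python
def bpe_to_word_map(bpe_tokens):
    """Map each BPE token position to the index of the word it belongs to."""
    word_map, word_idx = [], 0
    for tok in bpe_tokens:
        word_map.append(word_idx)
        if not tok.endswith("@@"):  # this token ends the current word
            word_idx += 1
    return word_map

def bpe_align_to_word_align(bpe_align_line, src_bpe_line, tgt_bpe_line):
    """Convert one line of "src-tgt" BPE-level pairs to word-level pairs."""
    src_map = bpe_to_word_map(src_bpe_line.split())
    tgt_map = bpe_to_word_map(tgt_bpe_line.split())
    pairs = sorted({
        (src_map[int(s)], tgt_map[int(t)])
        for s, t in (p.split("-") for p in bpe_align_line.split())
    })
    return " ".join("{}-{}".format(s, t) for s, t in pairs)
```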
## Citation
```bibtex
@inproceedings{garg2019jointly,
title = {Jointly Learning to Align and Translate with Transformer Models},
author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
address = {Hong Kong},
month = {November},
url = {https://arxiv.org/abs/1909.02074},
year = {2019},
}
```

View File

@ -0,0 +1,118 @@
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git
SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
URLS=(
"http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
"http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
"http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
"http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
"http://data.statmt.org/wmt17/translation-task/dev.tgz"
"http://statmt.org/wmt14/test-full.tgz"
)
CORPORA=(
"training/europarl-v7.de-en"
"commoncrawl.de-en"
"training-parallel-nc-v13/news-commentary-v13.de-en"
"rapid2016.de-en"
)
if [ ! -d "$SCRIPTS" ]; then
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
exit 1
fi
src=en
tgt=de
lang=en-de
prep=wmt18_en_de
tmp=$prep/tmp
orig=orig
dev=dev/newstest2012
codes=32000
bpe=bpe.32k
mkdir -p $orig $tmp $prep $bpe
cd $orig
for ((i=0;i<${#URLS[@]};++i)); do
url=${URLS[i]}
file=$(basename $url)
if [ -f $file ]; then
echo "$file already exists, skipping download"
else
wget "$url"
if [ -f $file ]; then
echo "$url successfully downloaded."
else
echo "$url not successfully downloaded."
exit 1
fi
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
fi
fi
done
cd ..
echo "pre-processing train data..."
for l in $src $tgt; do
rm -rf $tmp/train.tags.$lang.tok.$l
for f in "${CORPORA[@]}"; do
cat $orig/$f.$l | \
perl $REM_NON_PRINT_CHAR | \
perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/train.tags.$lang.tok.$l
done
done
echo "pre-processing test data..."
for l in $src $tgt; do
if [ "$l" == "$src" ]; then
t="src"
else
t="ref"
fi
grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
sed -e 's/<seg id="[0-9]*">\s*//g' | \
sed -e 's/\s*<\/seg>\s*//g' | \
sed -e "s/\/\'/g" | \
perl $TOKENIZER -threads 8 -l $l -no-escape > $tmp/test.$l
echo ""
done
# apply length filtering before BPE
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 100
# use newstest2012 for valid
echo "pre-processing valid data..."
for l in $src $tgt; do
rm -rf $tmp/valid.$l
cat $orig/$dev.$l | \
perl $REM_NON_PRINT_CHAR | \
perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/valid.$l
done
mkdir output
mv $tmp/{train,valid,test}.{$src,$tgt} output
#BPE
git clone git@github.com:glample/fastBPE.git
pushd fastBPE
g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
popd
fastBPE/fast learnbpe $codes output/train.$src output/train.$tgt > $bpe/codes
for split in {train,valid,test}; do for lang in {en,de}; do fastBPE/fast applybpe $bpe/$split.$lang output/$split.$lang $bpe/codes; done; done

View File

@ -52,6 +52,22 @@ class Binarizer:
line = f.readline()
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
@staticmethod
def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1):
nseq = 0
with open(filename, 'r') as f:
f.seek(offset)
line = safe_readline(f)
while line:
if end > 0 and f.tell() > end:
break
ids = alignment_parser(line)
nseq += 1
consumer(ids)
line = f.readline()
return {'nseq': nseq}
@staticmethod
def find_offsets(filename, num_chunks):
with open(filename, 'r', encoding='utf-8') as f:

View File

@ -0,0 +1,90 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from fairseq import utils
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from . import register_criterion
@register_criterion('label_smoothed_cross_entropy_with_alignment')
class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion):
def __init__(self, args, task):
super().__init__(args, task)
self.alignment_lambda = args.alignment_lambda
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
super(LabelSmoothedCrossEntropyCriterionWithAlignment,
LabelSmoothedCrossEntropyCriterionWithAlignment).add_args(parser)
parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D',
help='weight for the alignment loss')
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
alignment_loss = None
# Compute the alignment loss only for the training set and for non-dummy batches.
if 'alignments' in sample and sample['alignments'] is not None:
alignment_loss = self.compute_alignment_loss(sample, net_output)
if alignment_loss is not None:
logging_output['alignment_loss'] = utils.item(alignment_loss.data)
loss += self.alignment_lambda * alignment_loss
return loss, sample_size, logging_output
def compute_alignment_loss(self, sample, net_output):
attn_prob = net_output[1]['attn']
bsz, tgt_sz, src_sz = attn_prob.shape
attn = attn_prob.view(bsz * tgt_sz, src_sz)
align = sample['alignments']
align_weights = sample['align_weights'].float()
if len(align) > 0:
# Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to
# the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing.
loss = -((attn[align[:, 1][:, None], align[:, 0][:, None]]).log() * align_weights[:, None]).sum()
else:
return None
return loss
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0.,
'alignment_loss': sum(log.get('alignment_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}

View File

@ -22,6 +22,28 @@ def collate(
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
)
def check_alignment(alignment, src_len, tgt_len):
if alignment is None or len(alignment) == 0:
return False
if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
print("| alignment size mismatch found, skipping alignment!")
return False
return True
def compute_alignment_weights(alignments):
"""
Given a tensor of shape [:, 2] containing the source-target indices
corresponding to the alignments, a weight vector containing the
inverse frequency of each target index is computed.
For example, if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
a tensor containing [1., 0.5, 0.5, 1.] is returned (since target
index 3 appears twice).
"""
align_tgt = alignments[:, 1]
_, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True)
align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
return 1. / align_weights.float()
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=left_pad_source)
# sort by descending source length
@ -35,6 +57,7 @@ def collate(
if samples[0].get('target', None) is not None:
target = merge('target', left_pad=left_pad_target)
target = target.index_select(0, sort_order)
tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
if input_feeding:
@ -61,6 +84,32 @@ def collate(
}
if prev_output_tokens is not None:
batch['net_input']['prev_output_tokens'] = prev_output_tokens
if samples[0].get('alignment', None) is not None:
bsz, tgt_sz = batch['target'].shape
src_sz = batch['net_input']['src_tokens'].shape[1]
offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz)
if left_pad_source:
offsets[:, 0] += (src_sz - src_lengths)
if left_pad_target:
offsets[:, 1] += (tgt_sz - tgt_lengths)
alignments = [
alignment + offset
for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths)
for alignment in [samples[align_idx]['alignment'].view(-1, 2)]
if check_alignment(alignment, src_len, tgt_len)
]
if len(alignments) > 0:
alignments = torch.cat(alignments, dim=0)
align_weights = compute_alignment_weights(alignments)
batch['alignments'] = alignments
batch['align_weights'] = align_weights
return batch
@ -91,6 +140,8 @@ class LanguagePairDataset(FairseqDataset):
of source if it's present (default: False).
append_eos_to_target (bool, optional): if set, appends eos to end of
target if it's absent (default: False).
align_dataset (torch.utils.data.Dataset, optional): dataset
containing alignments.
"""
def __init__(
@ -98,7 +149,9 @@ class LanguagePairDataset(FairseqDataset):
tgt=None, tgt_sizes=None, tgt_dict=None,
left_pad_source=True, left_pad_target=False,
max_source_positions=1024, max_target_positions=1024,
shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False,
shuffle=True, input_feeding=True,
remove_eos_from_source=False, append_eos_to_target=False,
align_dataset=None,
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
@ -118,6 +171,9 @@ class LanguagePairDataset(FairseqDataset):
self.input_feeding = input_feeding
self.remove_eos_from_source = remove_eos_from_source
self.append_eos_to_target = append_eos_to_target
self.align_dataset = align_dataset
if self.align_dataset is not None:
assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
@ -136,11 +192,14 @@ class LanguagePairDataset(FairseqDataset):
if self.src[index][-1] == eos:
src_item = self.src[index][:-1]
return {
example = {
'id': index,
'source': src_item,
'target': tgt_item,
}
if self.align_dataset is not None:
example['alignment'] = self.align_dataset[index]
return example
def __len__(self):
return len(self.src)
@ -212,3 +271,5 @@ class LanguagePairDataset(FairseqDataset):
self.src.prefetch(indices)
if self.tgt is not None:
self.tgt.prefetch(indices)
if self.align_dataset is not None:
self.align_dataset.prefetch(indices)

View File

@ -222,6 +222,9 @@ class FairseqEncoderDecoderModel(BaseFairseqModel):
decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
return decoder_out
def forward_decoder(self, prev_output_tokens, **kwargs):
return self.decoder(prev_output_tokens, **kwargs)
def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
"""
Similar to *forward* but only return features.

View File

@ -68,6 +68,7 @@ class TransformerModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.supports_align_args = True
@staticmethod
def add_args(parser):
@ -195,6 +196,69 @@ class TransformerModel(FairseqEncoderDecoderModel):
)
@register_model('transformer_align')
class TransformerAlignModel(TransformerModel):
"""
See "Jointly Learning to Align and Translate with Transformer
Models" (Garg et al., EMNLP 2019).
"""
def __init__(self, encoder, decoder, args):
super().__init__(encoder, decoder)
self.alignment_heads = args.alignment_heads
self.alignment_layer = args.alignment_layer
self.full_context_alignment = args.full_context_alignment
@staticmethod
def add_args(parser):
# fmt: off
super(TransformerAlignModel, TransformerAlignModel).add_args(parser)
parser.add_argument('--alignment-heads', type=int, metavar='D',
help='Number of cross-attention heads per layer to supervise with alignments')
parser.add_argument('--alignment-layer', type=int, metavar='D',
help='Layer number to be supervised; 0 corresponds to the bottom-most layer.')
parser.add_argument('--full-context-alignment', type=bool, metavar='D',
help='Whether or not alignment is supervised conditioned on the full target context.')
# fmt: on
@classmethod
def build_model(cls, args, task):
# set any default arguments
transformer_align(args)
transformer_model = TransformerModel.build_model(args, task)
return TransformerAlignModel(transformer_model.encoder, transformer_model.decoder, args)
def forward(self, src_tokens, src_lengths, prev_output_tokens):
encoder_out = self.encoder(src_tokens, src_lengths)
return self.forward_decoder(prev_output_tokens, encoder_out)
def forward_decoder(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
features_only=False,
**extra_args,
):
attn_args = {'alignment_layer': self.alignment_layer, 'alignment_heads': self.alignment_heads}
decoder_out = self.decoder(
prev_output_tokens,
encoder_out,
**attn_args,
**extra_args,
)
if self.full_context_alignment:
attn_args['full_context_alignment'] = self.full_context_alignment
_, alignment_out = self.decoder(
prev_output_tokens, encoder_out, features_only=True, **attn_args, **extra_args,
)
decoder_out[1]['attn'] = alignment_out['attn']
return decoder_out
class TransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
@ -423,7 +487,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
else:
self.layer_norm = None
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
def forward(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
features_only=False,
**extra_args,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
@ -432,25 +503,53 @@ class TransformerDecoder(FairseqIncrementalDecoder):
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(prev_output_tokens, encoder_out, incremental_state)
x = self.output_layer(x)
x, extra = self.extract_features(
prev_output_tokens, encoder_out, incremental_state, **extra_args,
)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
def extract_features(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
full_context_alignment=False,
alignment_layer=None,
alignment_heads=None,
**unused,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
if alignment_layer is None:
alignment_layer = len(self.layers) - 1
# embed positions
positions = self.embed_positions(
prev_output_tokens,
@ -474,15 +573,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
# B x T x C -> T x B x C
x = x.transpose(0, 1)
attn = None
inner_states = [x]
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
if not self_attn_padding_mask.any() and not self.cross_self_attention:
self_attn_padding_mask = None
# decoder layers
attn = None
inner_states = [x]
for idx, layer in enumerate(self.layers):
encoder_state = None
if encoder_out is not None:
@ -491,15 +589,32 @@ class TransformerDecoder(FairseqIncrementalDecoder):
else:
encoder_state = encoder_out['encoder_out']
x, attn = layer(
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn = layer(
x,
encoder_state,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
incremental_state,
self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=(idx == alignment_layer),
need_head_weights=(idx == alignment_layer),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float()
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.layer_norm:
x = self.layer_norm(x)
@ -531,7 +646,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device or self._future_mask.size(0) < dim:
if (
not hasattr(self, '_future_mask')
or self._future_mask is None
or self._future_mask.device != tensor.device
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
return self._future_mask[:dim, :dim]
@ -668,3 +788,18 @@ def transformer_wmt_en_de_big_t2t(args):
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.activation_dropout = getattr(args, 'activation_dropout', 0.1)
transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture('transformer_align', 'transformer_align')
def transformer_align(args):
args.alignment_heads = getattr(args, 'alignment_heads', 1)
args.alignment_layer = getattr(args, 'alignment_layer', 4)
args.full_context_alignment = getattr(args, 'full_context_alignment', False)
base_architecture(args)
@register_model_architecture('transformer_align', 'transformer_wmt_en_de_big_align')
def transformer_wmt_en_de_big_align(args):
args.alignment_heads = getattr(args, 'alignment_heads', 1)
args.alignment_layer = getattr(args, 'alignment_layer', 4)
transformer_wmt_en_de_big(args)

View File

@ -90,15 +90,37 @@ class MultiheadAttention(nn.Module):
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
need_weights=True, static_kv=False, attn_mask=None, before_softmax=False):
def forward(
self,
query, key, value,
key_padding_mask=None,
incremental_state=None,
need_weights=True,
static_kv=False,
attn_mask=None,
before_softmax=False,
need_head_weights=False,
):
"""Input shape: Time x Batch x Channel
Timesteps can be masked by supplying a T x T mask in the
`attn_mask` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: True).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
@ -249,12 +271,11 @@ class MultiheadAttention(nn.Module):
if before_softmax:
return attn_weights, v
attn_weights = utils.softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace,
).type_as(attn_weights)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
attn = torch.bmm(attn_weights, v)
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if (self.onnx_trace and attn.size(1) == 1):
# when ONNX tracing a single decoder step (sequence length == 1)
@ -265,9 +286,10 @@ class MultiheadAttention(nn.Module):
attn = self.out_proj(attn)
if need_weights:
# average attention weights over heads
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.sum(dim=1) / self.num_heads
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
else:
attn_weights = None

View File

@ -195,16 +195,25 @@ class TransformerDecoderLayer(nn.Module):
prev_attn_state=None,
self_attn_mask=None,
self_attn_padding_mask=None,
need_attn=False,
need_head_weights=False,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
if prev_self_attn_state is not None:
@ -259,7 +268,8 @@ class TransformerDecoderLayer(nn.Module):
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=(not self.training and self.need_attn),
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x

View File

@ -224,6 +224,8 @@ def add_preprocess_args(parser):
help="comma separated, valid file prefixes")
group.add_argument("--testpref", metavar="FP", default=None,
help="comma separated, test file prefixes")
group.add_argument("--align-suffix", metavar="FP", default=None,
help="alignment file suffix")
group.add_argument("--destdir", metavar="DIR", default="data-bin",
help="destination dir")
group.add_argument("--thresholdtgt", metavar="N", default=0, type=int,

View File

@ -7,7 +7,8 @@ import math
import torch
from fairseq import search
from fairseq import search, utils
from fairseq.data import data_utils
from fairseq.models import FairseqIncrementalDecoder
@ -81,7 +82,6 @@ class SequenceGenerator(object):
self.temperature = temperature
self.match_source_len = match_source_len
self.no_repeat_ngram_size = no_repeat_ngram_size
assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'
assert temperature > 0, '--temperature must be greater than 0'
@ -98,14 +98,7 @@ class SequenceGenerator(object):
self.search = search.BeamSearch(tgt_dict)
@torch.no_grad()
def generate(
self,
models,
sample,
prefix_tokens=None,
bos_token=None,
**kwargs
):
def generate(self, models, sample, **kwargs):
"""Generate a batch of translations.
Args:
@ -113,8 +106,21 @@ class SequenceGenerator(object):
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
bos_token (int, optional): beginning of sentence token
(default: self.eos)
"""
model = EnsembleModel(models)
return self._generate(model, sample, **kwargs)
@torch.no_grad()
def _generate(
self,
model,
sample,
prefix_tokens=None,
bos_token=None,
**kwargs
):
if not self.retain_dropout:
model.eval()
@ -155,7 +161,6 @@ class SequenceGenerator(object):
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos if bos_token is None else bos_token
attn, attn_buf = None, None
nonpad_idxs = None
# The blacklist indicates candidates that should be ignored.
# For example, suppose we're sampling and have already finalized 2/5
@ -251,17 +256,15 @@ class SequenceGenerator(object):
if attn_clone is not None:
# remove padding tokens from attn scores
hypo_attn = attn_clone[i][nonpad_idxs[sent]]
_, alignment = hypo_attn.max(dim=0)
hypo_attn = attn_clone[i]
else:
hypo_attn = None
alignment = None
return {
'tokens': tokens_clone[i],
'score': score,
'attention': hypo_attn, # src_len x tgt_len
'alignment': alignment,
'alignment': None,
'positional_scores': pos_scores[i],
}
@ -345,7 +348,6 @@ class SequenceGenerator(object):
if attn is None:
attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
attn_buf = attn.clone()
nonpad_idxs = src_tokens.ne(self.pad)
attn[:, :, step + 1].copy_(avg_attn_scores)
scores = scores.type_as(lprobs)
@ -512,7 +514,6 @@ class SequenceGenerator(object):
# sort by score descending
for sent in range(len(finalized)):
finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
return finalized
@ -577,9 +578,11 @@ class EnsembleModel(torch.nn.Module):
temperature=1.,
):
if self.incremental_states is not None:
decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=self.incremental_states[model]))
decoder_out = list(model.forward_decoder(
tokens, encoder_out=encoder_out, incremental_state=self.incremental_states[model],
))
else:
decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out))
decoder_out[0] = decoder_out[0][:, -1:, :]
if temperature != 1.:
decoder_out[0].div_(temperature)
@ -605,3 +608,104 @@ class EnsembleModel(torch.nn.Module):
return
for model in self.models:
model.decoder.reorder_incremental_state(self.incremental_states[model], new_order)
class SequenceGeneratorWithAlignment(SequenceGenerator):
def __init__(self, tgt_dict, left_pad_target=False, **kwargs):
"""Generates translations of a given source sentence.
Produces alignments following "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
left_pad_target (bool, optional): Whether the hypotheses
should be left-padded when they are teacher-forced for
generating alignments.
"""
super().__init__(tgt_dict, **kwargs)
self.left_pad_target = left_pad_target
@torch.no_grad()
def generate(self, models, sample, **kwargs):
model = EnsembleModelWithAlignment(models)
finalized = super()._generate(model, sample, **kwargs)
src_tokens = sample['net_input']['src_tokens']
bsz = src_tokens.shape[0]
beam_size = self.beam_size
src_tokens, src_lengths, prev_output_tokens, tgt_tokens = \
self._prepare_batch_for_alignment(sample, finalized)
if any(getattr(m, 'full_context_alignment', False) for m in model.models):
attn = model.forward_align(src_tokens, src_lengths, prev_output_tokens)
else:
attn = [
finalized[i // beam_size][i % beam_size]['attention'].transpose(1, 0)
for i in range(bsz * beam_size)
]
# Process the attn matrix to extract hard alignments.
for i in range(bsz * beam_size):
alignment = utils.extract_hard_alignment(attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos)
finalized[i // beam_size][i % beam_size]['alignment'] = alignment
return finalized
def _prepare_batch_for_alignment(self, sample, hypothesis):
src_tokens = sample['net_input']['src_tokens']
bsz = src_tokens.shape[0]
src_tokens = src_tokens[:, None, :].expand(-1, self.beam_size, -1).contiguous().view(bsz * self.beam_size, -1)
src_lengths = sample['net_input']['src_lengths']
src_lengths = src_lengths[:, None].expand(-1, self.beam_size).contiguous().view(bsz * self.beam_size)
prev_output_tokens = data_utils.collate_tokens(
[beam['tokens'] for example in hypothesis for beam in example],
self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=True,
)
tgt_tokens = data_utils.collate_tokens(
[beam['tokens'] for example in hypothesis for beam in example],
self.pad, self.eos, self.left_pad_target, move_eos_to_beginning=False,
)
return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
class EnsembleModelWithAlignment(EnsembleModel):
"""A wrapper around an ensemble of models."""
def __init__(self, models):
super().__init__(models)
def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
avg_attn = None
for model in self.models:
decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
attn = decoder_out[1]['attn']
if avg_attn is None:
avg_attn = attn
else:
avg_attn.add_(attn)
if len(self.models) > 1:
avg_attn.div_(len(self.models))
return avg_attn
def _decode_one(
self, tokens, model, encoder_out, incremental_states, log_probs,
temperature=1.,
):
if self.incremental_states is not None:
decoder_out = list(model.forward_decoder(
tokens,
encoder_out=encoder_out,
incremental_state=self.incremental_states[model],
))
else:
decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out))
decoder_out[0] = decoder_out[0][:, -1:, :]
if temperature != 1.:
decoder_out[0].div_(temperature)
attn = decoder_out[1]
if type(attn) is dict:
attn = attn.get('attn', None)
if attn is not None:
attn = attn[:, -1, :]
probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
probs = probs[:, -1, :]
return probs, attn

View File

@ -14,6 +14,7 @@ class SequenceScorer(object):
def __init__(self, tgt_dict, softmax_batch=None):
self.pad = tgt_dict.pad()
self.eos = tgt_dict.eos()
self.softmax_batch = softmax_batch or sys.maxsize
assert self.softmax_batch > 0
@ -44,6 +45,7 @@ class SequenceScorer(object):
)
return probs
orig_target = sample['target']
# compute scores for each model in the ensemble
@ -53,6 +55,8 @@ class SequenceScorer(object):
model.eval()
decoder_out = model.forward(**net_input)
attn = decoder_out[1]
if type(attn) is dict:
attn = attn.get('attn', None)
batched = batch_for_softmax(decoder_out, orig_target)
probs, idx = None, 0
@ -100,8 +104,9 @@ class SequenceScorer(object):
avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len]
score_i = avg_probs_i.sum() / tgt_len
if avg_attn is not None:
avg_attn_i = avg_attn[i, start_idxs[i]:]
_, alignment = avg_attn_i.max(dim=0)
avg_attn_i = avg_attn[i]
alignment = utils.extract_hard_alignment(avg_attn_i, sample['net_input']['src_tokens'][i],
sample['target'][i], self.pad, self.eos)
else:
avg_attn_i = alignment = None
hypos.append([{

View File

@ -198,8 +198,12 @@ class FairseqTask(object):
from fairseq.sequence_scorer import SequenceScorer
return SequenceScorer(self.target_dictionary)
else:
from fairseq.sequence_generator import SequenceGenerator
return SequenceGenerator(
from fairseq.sequence_generator import SequenceGenerator, SequenceGeneratorWithAlignment
if getattr(args, 'print_alignment', False):
seq_gen_cls = SequenceGeneratorWithAlignment
else:
seq_gen_cls = SequenceGenerator
return seq_gen_cls(
self.target_dictionary,
beam_size=getattr(args, 'beam', 5),
max_len_a=getattr(args, 'max_len_a', 0),

View File

@ -24,7 +24,7 @@ def load_langpair_dataset(
tgt, tgt_dict,
combine, dataset_impl, upsample_primary,
left_pad_source, left_pad_target, max_source_positions,
max_target_positions, prepend_bos=False,
max_target_positions, prepend_bos=False, load_alignments=False,
):
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
@ -74,6 +74,12 @@ def load_langpair_dataset(
src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
align_dataset = None
if load_alignments:
align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)
return LanguagePairDataset(
src_dataset, src_dataset.sizes, src_dict,
tgt_dataset, tgt_dataset.sizes, tgt_dict,
@ -81,6 +87,7 @@ def load_langpair_dataset(
left_pad_target=left_pad_target,
max_source_positions=max_source_positions,
max_target_positions=max_target_positions,
align_dataset=align_dataset,
)
@ -120,6 +127,8 @@ class TranslationTask(FairseqTask):
help='load the dataset lazily')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--load-alignments', action='store_true',
help='load the binarized alignments')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
help='pad the source on the left')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
@ -193,6 +202,7 @@ class TranslationTask(FairseqTask):
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
load_alignments=self.args.load_alignments,
)
def build_dataset_for_inference(self, src_tokens, src_lengths):

View File

@ -16,6 +16,7 @@ import warnings
import torch
import torch.nn.functional as F
from itertools import accumulate
from fairseq.modules import gelu, gelu_accurate
@ -367,3 +368,47 @@ def set_torch_seed(seed):
assert isinstance(seed, int)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def parse_alignment(line):
"""
Parses a single line from the alignment file.
Args:
line (str): String containing the alignment of the format:
<src_idx_1>-<tgt_idx_1> <src_idx_2>-<tgt_idx_2> ..
<src_idx_m>-<tgt_idx_m>. All indices are 0 indexed.
Returns:
torch.IntTensor: packed alignments of shape (2 * m).
"""
alignments = line.strip().split()
parsed_alignment = torch.IntTensor(2 * len(alignments))
for idx, alignment in enumerate(alignments):
src_idx, tgt_idx = alignment.split('-')
parsed_alignment[2 * idx] = int(src_idx)
parsed_alignment[2 * idx + 1] = int(tgt_idx)
return parsed_alignment
def get_token_to_word_mapping(tokens, exclude_list):
n = len(tokens)
word_start = [int(token not in exclude_list) for token in tokens]
word_idx = list(accumulate(word_start))
token_to_word = {i: word_idx[i] for i in range(n)}
return token_to_word
def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
tgt_valid = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero().squeeze(dim=-1)
src_invalid = ((src_sent == pad) | (src_sent == eos)).nonzero().squeeze(dim=-1)
src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad])
tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad])
alignment = []
if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent):
attn_valid = attn[tgt_valid]
attn_valid[:, src_invalid] = float('-inf')
_, src_indices = attn_valid.max(dim=1)
for tgt_idx, src_idx in zip(tgt_valid, src_indices):
alignment.append((src_token_to_word[src_idx.item()] - 1, tgt_token_to_word[tgt_idx.item()] - 1))
return alignment

View File

@ -137,7 +137,7 @@ def main(args):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
alignment=hypo['alignment'],
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
@ -156,7 +156,7 @@ def main(args):
if args.print_alignment:
print('A-{}\t{}'.format(
sample_id,
' '.join(map(lambda x: str(utils.item(x)), alignment))
' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment])
))
if args.print_step:
@ -180,6 +180,7 @@ def main(args):
num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
if has_target:
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
return scorer

View File

@ -162,7 +162,7 @@ def main(args):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
alignment=hypo['alignment'],
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
@ -174,9 +174,10 @@ def main(args):
' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
))
if args.print_alignment:
alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment])
print('A-{}\t{}'.format(
id,
' '.join(map(lambda x: str(utils.item(x)), alignment))
alignment_str
))
# update running id counter

View File

@ -157,6 +157,60 @@ def main(args):
)
)
def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers):
nseq = [0]
def merge_result(worker_result):
nseq[0] += worker_result['nseq']
input_file = input_prefix
offsets = Binarizer.find_offsets(input_file, num_workers)
pool = None
if num_workers > 1:
pool = Pool(processes=num_workers - 1)
for worker_id in range(1, num_workers):
prefix = "{}{}".format(output_prefix, worker_id)
pool.apply_async(
binarize_alignments,
(
args,
input_file,
utils.parse_alignment,
prefix,
offsets[worker_id],
offsets[worker_id + 1]
),
callback=merge_result
)
pool.close()
ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"),
impl=args.dataset_impl)
merge_result(
Binarizer.binarize_alignments(
input_file, utils.parse_alignment, lambda t: ds.add_item(t),
offset=0, end=offsets[1]
)
)
if num_workers > 1:
pool.join()
for worker_id in range(1, num_workers):
prefix = "{}{}".format(output_prefix, worker_id)
temp_file_path = dataset_dest_prefix(args, prefix, None)
ds.merge_file_(temp_file_path)
os.remove(indexed_dataset.data_file_path(temp_file_path))
os.remove(indexed_dataset.index_file_path(temp_file_path))
ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
print(
"| [alignments] {}: parsed {} alignments".format(
input_file,
nseq[0]
)
)
def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
if args.dataset_impl == "raw":
# Copy original text file to destination folder
@ -180,9 +234,19 @@ def main(args):
outprefix = "test{}".format(k) if k > 0 else "test"
make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers)
def make_all_alignments():
if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
make_binary_alignment_dataset(args.trainpref + "." + args.align_suffix, "train.align", num_workers=args.workers)
if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
make_binary_alignment_dataset(args.validpref + "." + args.align_suffix, "valid.align", num_workers=args.workers)
if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
make_binary_alignment_dataset(args.testpref + "." + args.align_suffix, "test.align", num_workers=args.workers)
make_all(args.source_lang, src_dict)
if target:
make_all(args.target_lang, tgt_dict)
if args.align_suffix:
make_all_alignments()
print("| Wrote preprocessed data to {}".format(args.destdir))
@ -242,11 +306,28 @@ def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos
return res
def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end):
ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"),
impl=args.dataset_impl, vocab_size=None)
def consumer(tensor):
ds.add_item(tensor)
res = Binarizer.binarize_alignments(filename, parse_alignment, consumer, offset=offset,
end=end)
ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
return res
def dataset_dest_prefix(args, output_prefix, lang):
base = "{}/{}".format(args.destdir, output_prefix)
lang_part = (
".{}-{}.{}".format(args.source_lang, args.target_lang, lang) if lang is not None else ""
)
if lang is not None:
lang_part = ".{}-{}.{}".format(args.source_lang, args.target_lang, lang)
elif args.only_source:
lang_part = ""
else:
lang_part = ".{}-{}".format(args.source_lang, args.target_lang)
return "{}{}".format(base, lang_part)

View File

@ -266,6 +266,27 @@ class TestTranslation(unittest.TestCase):
'--gen-expert', '0'
])
def test_alignment(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_alignment') as data_dir:
create_dummy_data(data_dir, alignment=True)
preprocess_translation_data(data_dir, ['--align-suffix', 'align'])
train_translation_model(
data_dir,
'transformer_align',
[
'--encoder-layers', '2',
'--decoder-layers', '2',
'--encoder-embed-dim', '8',
'--decoder-embed-dim', '8',
'--load-alignments',
'--alignment-layer', '1',
'--criterion', 'label_smoothed_cross_entropy_with_alignment'
],
run_validation=True,
)
generate_main(data_dir)
class TestStories(unittest.TestCase):
@ -484,7 +505,7 @@ class TestCommonOptions(unittest.TestCase):
generate_main(data_dir)
def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
def create_dummy_data(data_dir, num_examples=1000, maxlen=20, alignment=False):
def _create_dummy_data(filename):
data = torch.rand(num_examples * maxlen)
@ -497,6 +518,20 @@ def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
print(ex_str, file=h)
offset += ex_len
def _create_dummy_alignment_data(filename_src, filename_tgt, filename):
with open(os.path.join(data_dir, filename_src), 'r') as src_f, \
open(os.path.join(data_dir, filename_tgt), 'r') as tgt_f, \
open(os.path.join(data_dir, filename), 'w') as h:
for src, tgt in zip(src_f, tgt_f):
src_len = len(src.split())
tgt_len = len(tgt.split())
avg_len = (src_len + tgt_len) // 2
num_alignments = random.randint(avg_len // 2, 2 * avg_len)
src_indices = torch.floor(torch.rand(num_alignments) * src_len).int()
tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int()
ex_str = ' '.join(["{}-{}".format(src, tgt) for src, tgt in zip(src_indices, tgt_indices)])
print(ex_str, file=h)
_create_dummy_data('train.in')
_create_dummy_data('train.out')
_create_dummy_data('valid.in')
@ -504,6 +539,10 @@ def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
_create_dummy_data('test.in')
_create_dummy_data('test.out')
if alignment:
_create_dummy_alignment_data('train.in', 'train.out', 'train.align')
_create_dummy_alignment_data('valid.in', 'valid.out', 'valid.align')
_create_dummy_alignment_data('test.in', 'test.out', 'test.align')
def preprocess_translation_data(data_dir, extra_flags=None):
preprocess_parser = options.get_preprocessing_parser()