CUDA implementation of Levenshtein distance for NAT training (#960)

Summary:
## What does this PR do?
CUDA implementation for Levenshtein distance for NAT and other potential application.
It will make training Levenshtein Transformer slightly faster and clean the functions.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/960

Test Plan: Imported from GitHub. Tested locally.

Reviewed By: cndn

Differential Revision: D19207096

Pulled By: MultiPath

fbshipit-source-id: 4890bbaa851ffd302648c0d949173158dc3167e2
This commit is contained in:
Jiatao Gu 2019-12-21 02:43:47 -08:00 committed by Facebook Github Bot
parent 9ad6b5a967
commit a316bd99b7
6 changed files with 725 additions and 268 deletions

View File

@ -0,0 +1,60 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
/*
This code is partially adpoted from https://github.com/1ytic/pytorch-edit-distance
*/
#include "edit_dist.h"
#include <torch/types.h>
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor LevenshteinDistance(
torch::Tensor source,
torch::Tensor target,
torch::Tensor source_length,
torch::Tensor target_length) {
CHECK_INPUT(source);
CHECK_INPUT(target);
CHECK_INPUT(source_length);
CHECK_INPUT(target_length);
return LevenshteinDistanceCuda(source, target, source_length, target_length);
}
torch::Tensor GenerateDeletionLabel(
torch::Tensor source,
torch::Tensor operations) {
CHECK_INPUT(source);
CHECK_INPUT(operations);
return GenerateDeletionLabelCuda(source, operations);
}
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
torch::Tensor target,
torch::Tensor operations) {
CHECK_INPUT(target);
CHECK_INPUT(operations);
return GenerateInsertionLabelCuda(target, operations);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
m.def("generate_deletion_labels", &GenerateDeletionLabel, "Generate Deletion Label");
m.def("generate_insertion_labels", &GenerateInsertionLabel, "Generate Insertion Label");
}

View File

@ -0,0 +1,338 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "edit_dist.h"
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <utility> // std::pair
template <typename scalar_t>
__global__ void generate_deletion_label_kernel(
const scalar_t* __restrict__ source,
const size_t source_size,
const size_t operation_size,
int* __restrict__ operations,
int* __restrict__ labels) {
const int index = blockIdx.x;
const int offset = index * operation_size;
const int offset_label = index * source_size;
for (int i = 0; i < source_size; i++) {
labels[offset_label + i] = 0;
}
int k = 0;
for (int i = 0; i < operation_size; i++){
if (operations[offset + i] == 0){
break;
} else if (operations[offset + i] == 1){
continue;
} else {
labels[offset_label + k] = 3 - operations[offset + i];
k++;
}
}
}
template <typename scalar_t>
__global__ void generate_insertion_label_kernel(
const scalar_t* __restrict__ target,
const size_t target_size,
const size_t operation_size,
int* __restrict__ operations,
int* __restrict__ labels,
int* __restrict__ masks) {
const int index = blockIdx.x;
const int offset = index * operation_size;
const int offset_label = index * target_size;
int k = 0;
int u = 0;
int m = 0;
for (int i = 0; i < target_size; i++) {
labels[offset_label + i] = 0;
masks[offset_label + i] = 0;
}
for (int i = 0; i < operation_size-1; i++){
if (operations[offset + i] == 0){
break;
} else if (operations[offset + i] == 2){
continue;
} else if (operations[offset + i] == 1){
masks[offset_label + m] = 1;
u++; m++;
} else {
labels[offset_label + k] = u;
masks[offset_label + m] = 0;
k++; m++;
u = 0;
}
}
}
template <typename scalar_t>
__global__ void levenshtein_distance_kernel(
const scalar_t* __restrict__ source,
const scalar_t* __restrict__ target,
const int* __restrict__ source_length,
const int* __restrict__ target_length,
const size_t source_size,
const size_t target_size,
int* __restrict__ operations,
int* __restrict__ errors_curr) {
const int index = blockIdx.x;
const int offset = index * (source_size + target_size);
const int d = index * (source_size + 1) * (target_size + 1);
const int t = target_size + 1;
auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
auto opt_idx = [offset](int k) { return offset + k; };
const int hyp_len = source_length[index];
const int ref_len = target_length[index];
const scalar_t* hyp_begin = source + index * source_size;
const scalar_t* ref_begin = target + index * target_size;
// dynamic programming
for (int i = 0; i <= hyp_len; i++){
errors_curr[err_idx(i, 0)] = i;
}
for (int j = 0; j <= ref_len; j++){
errors_curr[err_idx(0, j)] = j;
}
for (int i = 1; i <= hyp_len; i++){
for (int j = 1; j <= ref_len; j++){
errors_curr[err_idx(i, j)] = min(
min(
errors_curr[err_idx(i-1, j)],
errors_curr[err_idx(i, j-1)]
) + 1,
errors_curr[err_idx(i-1, j-1)] + 2 * (
*(hyp_begin+i-1) == *(ref_begin+j-1) ? 0 : 1
)
);
}
}
// back-tracing
int i = hyp_len;
int j = ref_len;
int o = hyp_len + ref_len;
for (int k = 0; k < source_size + target_size; k++) {
operations[opt_idx(k)] = 0;
}
while ((i >= 0) && (j >= 0)) {
if ((i == 0) && (j == 0)) {
break;
}
if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) {
o--; operations[opt_idx(o)] = 1; j--; // insertion
} else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) {
o--; operations[opt_idx(o)] = 2; i--; // deletion
} else {
o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing
}
}
// moving to the left
for (int k = 0; k < hyp_len + ref_len; k++) {
if (k + o < hyp_len + ref_len){
operations[opt_idx(k)] = operations[opt_idx(k+o)];
} else{
operations[opt_idx(k)] = 0; // padding
}
}
}
template <typename scalar_t>
__global__ void faster_levenshtein_distance_kernel(
const scalar_t* __restrict__ source,
const scalar_t* __restrict__ target,
const int* __restrict__ source_length,
const int* __restrict__ target_length,
const size_t source_size,
const size_t target_size,
int* __restrict__ operations) {
extern __shared__ short errors[];
auto errors_curr = errors;
const int index = blockIdx.x;
const int offset = index * (source_size + target_size);
const int t = target_size + 1;
auto err_idx = [t](int i, int j) { return i * t + j; };
auto opt_idx = [offset](int k) { return offset + k; };
const int hyp_len = source_length[index];
const int ref_len = target_length[index];
const scalar_t* hyp_begin = source + index * source_size;
const scalar_t* ref_begin = target + index * target_size;
// dynamic programming
for (int i = 0; i <= hyp_len; i++){
errors_curr[err_idx(i, 0)] = i;
}
for (int j = 0; j <= ref_len; j++){
errors_curr[err_idx(0, j)] = j;
}
for (int i = 1; i <= hyp_len; i++){
for (int j = 1; j <= ref_len; j++){
errors_curr[err_idx(i, j)] = min(
min(
errors_curr[err_idx(i-1, j)],
errors_curr[err_idx(i, j-1)]
) + 1,
errors_curr[err_idx(i-1, j-1)] + 2 * (
*(hyp_begin+i-1) == *(ref_begin+j-1) ? 0 : 1
)
);
}
}
// back-tracing
int i = hyp_len;
int j = ref_len;
int o = hyp_len + ref_len;
for (int k = 0; k < source_size + target_size; k++) {
operations[opt_idx(k)] = 0;
}
while ((i >= 0) && (j >= 0)) {
if ((i == 0) && (j == 0)) {
break;
}
if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) {
o--; operations[opt_idx(o)] = 1; j--; // insertion
} else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) {
o--; operations[opt_idx(o)] = 2; i--; // deletion
} else {
o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing
}
}
// moving to the left
for (int k = 0; k < hyp_len + ref_len; k++) {
if (k + o < hyp_len + ref_len){
operations[opt_idx(k)] = operations[opt_idx(k+o)];
} else{
operations[opt_idx(k)] = 0; // padding
}
}
}
torch::Tensor GenerateDeletionLabelCuda(
torch::Tensor source,
torch::Tensor operations) {
const auto batch_size = source.size(0);
// const auto shared_size = (source.size(1) + 1) * 2 * sizeof(short);
at::TensorOptions options(source.device());
options = options.dtype(at::ScalarType::Int);
auto labels = torch::empty({batch_size, source.size(1)}, options);
auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
generate_deletion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
source.data<scalar_t>(),
source.size(1),
operations.size(1),
operations.data<int>(),
labels.data<int>());
}));
return labels;
}
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
torch::Tensor target,
torch::Tensor operations) {
const auto batch_size = target.size(0);
at::TensorOptions options(target.device());
options = options.dtype(at::ScalarType::Int);
auto labels = torch::empty({batch_size, target.size(1)}, options);
auto masks = torch::empty({batch_size, target.size(1)}, options);
auto stream = at::cuda::getCurrentCUDAStream(target.device().index());
AT_DISPATCH_ALL_TYPES(target.scalar_type(), "generate_insertion_labels", ([&] {
generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
target.data<scalar_t>(),
target.size(1),
operations.size(1),
operations.data<int>(),
labels.data<int>(),
masks.data<int>());
}));
return std::make_pair(labels, masks);
}
torch::Tensor LevenshteinDistanceCuda(
torch::Tensor source,
torch::Tensor target,
torch::Tensor source_length,
torch::Tensor target_length) {
const auto batch_size = source.size(0);
at::TensorOptions options(source.device());
options = options.dtype(at::ScalarType::Int);
auto operations = torch::empty({batch_size, source.size(1) + target.size(1)}, options);
auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
if (shared_size > 40000) {
auto distances = torch::empty({batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
levenshtein_distance_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
source.data<scalar_t>(),
target.data<scalar_t>(),
source_length.data<int>(),
target_length.data<int>(),
source.size(1),
target.size(1),
operations.data<int>(),
distances.data<int>());
}));
} else {
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "faster_levenshtein_distance", ([&] {
faster_levenshtein_distance_kernel<scalar_t><<<batch_size, 1, shared_size, stream>>>(
source.data<scalar_t>(),
target.data<scalar_t>(),
source_length.data<int>(),
target_length.data<int>(),
source.size(1),
target.size(1),
operations.data<int>());
}));
}
return operations;
}
}

View File

@ -0,0 +1,25 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
torch::Tensor LevenshteinDistanceCuda(
torch::Tensor source,
torch::Tensor target,
torch::Tensor source_length,
torch::Tensor target_length);
torch::Tensor GenerateDeletionLabelCuda(
torch::Tensor source,
torch::Tensor operations);
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
torch::Tensor source,
torch::Tensor operations);

View File

@ -19,276 +19,15 @@ from fairseq.models.nat import (
FairseqNATDecoder,
ensemble_decoder
)
from fairseq.utils import new_arange
from fairseq.modules.transformer_sentence_encoder import init_bert_params
# -------------- Helper Functions --------------------------------------------------- #
def _skip(x, mask):
"""
Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors.
"""
if isinstance(x, int):
return x
if x is None:
return None
if isinstance(x, torch.Tensor):
if x.size(0) == mask.size(0):
return x[mask]
elif x.size(1) == mask.size(0):
return x[:, mask]
if isinstance(x, list):
return [_skip(x_i, mask) for x_i in x]
if isinstance(x, dict):
return {k: _skip(v, mask) for k, v in x.items()}
raise NotImplementedError
def _skip_encoder_out(encoder, encoder_out, mask):
if not mask.any():
return encoder_out
else:
return encoder.reorder_encoder_out(encoder_out, mask.nonzero().squeeze())
def _fill(x, mask, y, padding_idx):
"""
Filling tensor x with y at masked positions (dim=0).
"""
if x is None:
return y
assert x.dim() == y.dim() and mask.size(0) == x.size(0)
assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
n_selected = mask.sum()
assert n_selected == y.size(0)
if n_selected == x.size(0):
return y
if x.size(1) < y.size(1):
dims = [x.size(0), y.size(1) - x.size(1)]
if x.dim() == 3:
dims.append(x.size(2))
x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
x[mask] = y
elif x.size(1) > y.size(1):
x[mask] = padding_idx
if x.dim() == 2:
x[mask, :y.size(1)] = y
else:
x[mask, :y.size(1), :] = y
else:
x[mask] = y
return x
def load_libnat():
try:
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
raise e
return libnat
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
libnat = load_libnat()
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
]
full_labels = libnat.suggested_ed2_path(
in_tokens_list, out_tokens_list, padding_idx
)
mask_inputs = [
[len(c) if c[0] != padding_idx else 0 for c in a[:-1]] for a in full_labels
]
# generate labels
masked_tgt_masks = []
for mask_input in mask_inputs:
mask_label = []
for beam_size in mask_input[1:-1]: # HACK 1:-1
mask_label += [0] + [1 for _ in range(beam_size)]
masked_tgt_masks.append(
mask_label + [0 for _ in range(out_seq_len - len(mask_label))]
)
mask_ins_targets = [
mask_input[1:-1] + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))]
for mask_input in mask_inputs
]
# transform to tensor
masked_tgt_masks = torch.tensor(
masked_tgt_masks, device=out_tokens.device
).bool()
mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
def _get_del_targets(in_tokens, out_tokens, padding_idx):
libnat = load_libnat()
out_seq_len = out_tokens.size(1)
with torch.cuda.device_of(in_tokens):
in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
]
full_labels = libnat.suggested_ed2_path(
in_tokens_list, out_tokens_list, padding_idx
)
word_del_targets = [b[-1] for b in full_labels]
word_del_targets = [
labels + [0 for _ in range(out_seq_len - len(labels))]
for labels in word_del_targets
]
# transform to tensor
word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
return word_del_targets
def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
libnat = load_libnat()
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
with torch.cuda.device_of(in_tokens):
in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
]
full_labels = libnat.suggested_ed2_path(
in_tokens_list, out_tokens_list, padding_idx
)
word_del_targets = [b[-1] for b in full_labels]
word_del_targets = [
labels + [0 for _ in range(out_seq_len - len(labels))]
for labels in word_del_targets
]
mask_inputs = [
[len(c) if c[0] != padding_idx else 0 for c in a[:-1]] for a in full_labels
]
mask_ins_targets = [
mask_input[1:-1] + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))]
for mask_input in mask_inputs
]
# transform to tensor
mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
return word_del_targets, mask_ins_targets
def _apply_ins_masks(
in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
):
in_masks = in_tokens.ne(padding_idx)
in_lengths = in_masks.sum(1)
# HACK: hacky way to shift all the paddings to eos first.
in_tokens.masked_fill_(~in_masks, eos_idx)
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
out_lengths = in_lengths + mask_ins_pred.sum(1)
out_max_len = out_lengths.max()
out_masks = (
new_arange(out_lengths, out_max_len)[None, :]
< out_lengths[:, None]
)
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
out_tokens = (
in_tokens.new_zeros(in_tokens.size(0), out_max_len)
.fill_(padding_idx)
.masked_fill_(out_masks, unk_idx)
)
out_tokens[:, 0] = in_tokens[:, 0]
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
out_scores = None
if in_scores is not None:
in_scores.masked_fill_(~in_masks, 0)
out_scores = in_scores.new_zeros(*out_tokens.size())
out_scores[:, 0] = in_scores[:, 0]
out_scores.scatter_(1, reordering, in_scores[:, 1:])
return out_tokens, out_scores
def _apply_ins_words(
in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx
):
word_ins_masks = in_tokens.eq(unk_idx)
out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
if in_scores is not None:
out_scores = in_scores.masked_scatter(
word_ins_masks, word_ins_scores[word_ins_masks]
)
else:
out_scores = None
return out_tokens, out_scores
def _apply_del_words(
in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
):
# apply deletion to a tensor
in_masks = in_tokens.ne(padding_idx)
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
max_len = in_tokens.size(1)
word_del_pred.masked_fill_(~in_masks, 1)
word_del_pred.masked_fill_(bos_eos_masks, 0)
reordering = (
new_arange(in_tokens)
.masked_fill_(word_del_pred, max_len)
.sort(1)[1]
)
out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
out_scores = None
if in_scores is not None:
out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
out_attn = None
if in_attn is not None:
_mask = word_del_pred[:, :, None].expand_as(in_attn)
_reordering = reordering[:, :, None].expand_as(in_attn)
out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
return out_tokens, out_scores, out_attn
from .levenshtein_utils import (
_skip, _skip_encoder_out, _fill,
_get_ins_targets, _get_del_targets,
_apply_ins_masks, _apply_ins_words, _apply_del_words
)
@register_model("levenshtein_transformer")

View File

@ -0,0 +1,284 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.utils import new_arange
# -------------- Helper Functions --------------------------------------------------- #
def load_libnat():
try:
from fairseq import libnat_cuda
return libnat_cuda, True
except ImportError as e:
print(e + '... fall back to CPU version')
try:
from fairseq import libnat
return libnat, False
except ImportError as e:
import sys
sys.stderr.write("ERROR: missing libnat_cuda. run `python setup.py build_ext --inplace`\n")
raise e
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
libnat, use_cuda = load_libnat()
def _get_ins_targets_cuda(in_tokens, out_tokens, padding_idx, unk_idx):
in_masks = in_tokens.ne(padding_idx)
out_masks = out_tokens.ne(padding_idx)
mask_ins_targets, masked_tgt_masks = libnat.generate_insertion_labels(
out_tokens.int(), libnat.levenshtein_distance(
in_tokens.int(), out_tokens.int(),
in_masks.sum(1).int(), out_masks.sum(1).int()
)
)
masked_tgt_masks = masked_tgt_masks.bool() & out_masks
mask_ins_targets = mask_ins_targets.type_as(
in_tokens)[:, 1:in_masks.size(1)].masked_fill_(~in_masks[:, 1:], 0)
masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
def _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx):
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
]
full_labels = libnat.suggested_ed2_path(
in_tokens_list, out_tokens_list, padding_idx
)
mask_inputs = [
[len(c) if c[0] != padding_idx else 0 for c in a[:-1]] for a in full_labels
]
# generate labels
masked_tgt_masks = []
for mask_input in mask_inputs:
mask_label = []
for beam_size in mask_input[1:-1]: # HACK 1:-1
mask_label += [0] + [1 for _ in range(beam_size)]
masked_tgt_masks.append(
mask_label + [0 for _ in range(out_seq_len - len(mask_label))]
)
mask_ins_targets = [
mask_input[1:-1] + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))]
for mask_input in mask_inputs
]
# transform to tensor
masked_tgt_masks = torch.tensor(
masked_tgt_masks, device=out_tokens.device
).bool()
mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
if use_cuda:
return _get_ins_targets_cuda(in_tokens, out_tokens, padding_idx, unk_idx)
return _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx)
def _get_del_targets(in_tokens, out_tokens, padding_idx):
libnat, use_cuda = load_libnat()
def _get_del_targets_cuda(in_tokens, out_tokens, padding_idx):
in_masks = in_tokens.ne(padding_idx)
out_masks = out_tokens.ne(padding_idx)
word_del_targets = libnat.generate_deletion_labels(
in_tokens.int(),
libnat.levenshtein_distance(
in_tokens.int(), out_tokens.int(),
in_masks.sum(1).int(), out_masks.sum(1).int()
)
)
word_del_targets = word_del_targets.type_as(in_tokens).masked_fill_(~in_masks, 0)
return word_del_targets
def _get_del_targets_cpu(in_tokens, out_tokens, padding_idx):
out_seq_len = out_tokens.size(1)
with torch.cuda.device_of(in_tokens):
in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
]
full_labels = libnat.suggested_ed2_path(
in_tokens_list, out_tokens_list, padding_idx
)
word_del_targets = [b[-1] for b in full_labels]
word_del_targets = [
labels + [0 for _ in range(out_seq_len - len(labels))]
for labels in word_del_targets
]
# transform to tensor
word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
return word_del_targets
if use_cuda:
return _get_del_targets_cuda(in_tokens, out_tokens, padding_idx)
return _get_del_targets_cpu(in_tokens, out_tokens, padding_idx)
def _apply_ins_masks(
in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
):
in_masks = in_tokens.ne(padding_idx)
in_lengths = in_masks.sum(1)
# HACK: hacky way to shift all the paddings to eos first.
in_tokens.masked_fill_(~in_masks, eos_idx)
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
out_lengths = in_lengths + mask_ins_pred.sum(1)
out_max_len = out_lengths.max()
out_masks = (
new_arange(out_lengths, out_max_len)[None, :]
< out_lengths[:, None]
)
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
out_tokens = (
in_tokens.new_zeros(in_tokens.size(0), out_max_len)
.fill_(padding_idx)
.masked_fill_(out_masks, unk_idx)
)
out_tokens[:, 0] = in_tokens[:, 0]
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
out_scores = None
if in_scores is not None:
in_scores.masked_fill_(~in_masks, 0)
out_scores = in_scores.new_zeros(*out_tokens.size())
out_scores[:, 0] = in_scores[:, 0]
out_scores.scatter_(1, reordering, in_scores[:, 1:])
return out_tokens, out_scores
def _apply_ins_words(
in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx
):
word_ins_masks = in_tokens.eq(unk_idx)
out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
if in_scores is not None:
out_scores = in_scores.masked_scatter(
word_ins_masks, word_ins_scores[word_ins_masks]
)
else:
out_scores = None
return out_tokens, out_scores
def _apply_del_words(
in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
):
# apply deletion to a tensor
in_masks = in_tokens.ne(padding_idx)
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
max_len = in_tokens.size(1)
word_del_pred.masked_fill_(~in_masks, 1)
word_del_pred.masked_fill_(bos_eos_masks, 0)
reordering = (
new_arange(in_tokens)
.masked_fill_(word_del_pred, max_len)
.sort(1)[1]
)
out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
out_scores = None
if in_scores is not None:
out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
out_attn = None
if in_attn is not None:
_mask = word_del_pred[:, :, None].expand_as(in_attn)
_reordering = reordering[:, :, None].expand_as(in_attn)
out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
return out_tokens, out_scores, out_attn
def _skip(x, mask):
"""
Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors.
"""
if isinstance(x, int):
return x
if x is None:
return None
if isinstance(x, torch.Tensor):
if x.size(0) == mask.size(0):
return x[mask]
elif x.size(1) == mask.size(0):
return x[:, mask]
if isinstance(x, list):
return [_skip(x_i, mask) for x_i in x]
if isinstance(x, dict):
return {k: _skip(v, mask) for k, v in x.items()}
raise NotImplementedError
def _skip_encoder_out(encoder, encoder_out, mask):
if not mask.any():
return encoder_out
else:
return encoder.reorder_encoder_out(encoder_out, mask.nonzero().squeeze())
def _fill(x, mask, y, padding_idx):
"""
Filling tensor x with y at masked positions (dim=0).
"""
if x is None:
return y
assert x.dim() == y.dim() and mask.size(0) == x.size(0)
assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
n_selected = mask.sum()
assert n_selected == y.size(0)
if n_selected == x.size(0):
return y
if x.size(1) < y.size(1):
dims = [x.size(0), y.size(1) - x.size(1)]
if x.dim() == 3:
dims.append(x.size(2))
x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
x[mask] = y
elif x.size(1) > y.size(1):
x[mask] = padding_idx
if x.dim() == 2:
x[mask, :y.size(1)] = y
else:
x[mask, :y.size(1), :] = y
else:
x[mask] = y
return x

View File

@ -76,9 +76,20 @@ try:
sources=[
'fairseq/clib/libnat/edit_dist.cpp',
],
),
)
])
if 'CUDA_HOME' in os.environ:
extensions.extend([
cpp_extension.CppExtension(
'fairseq.libnat_cuda',
sources=[
'fairseq/clib/libnat_cuda/edit_dist.cu',
'fairseq/clib/libnat_cuda/binding.cpp'
],
)])
cmdclass['build_ext'] = cpp_extension.BuildExtension
except ImportError:
pass