Alignment train optimization (#2200)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2200

Computing the expected alignment from p_choose is a performance bottleneck. The solution is to implement a custom operator that reduces kernel launch overhead and optimizes the implementation of several of the underlying operations.

Some key optimizations:

* Use a contiguous alpha array to avoid array concatenation. The original version created a separate array for each slice of alpha and concatenated them at the end.
* Implement the exclusive cumprod with a direct product operation; the previous version used log-cumsum-exp. (See the sketch after this list.)
* On GPU, implement the cumprod with the CUDA CUB library, which is more efficient than the scan operation in PyTorch.
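
To make the second and third points concrete, here is a minimal PyTorch sketch of the two exclusive-cumprod formulations. It is illustrative only; the function names and eps value are mine, not the fairseq helper.

import torch

def exclusive_cumprod_logcumsumexp(x, eps=1e-10):
    # Old-style formulation: exponentiate a cumulative sum of logs,
    # then shift right by one and seed with 1.
    inclusive = torch.exp(torch.cumsum(torch.log(x.clamp(min=eps)), dim=-1))
    return torch.cat(
        [torch.ones_like(inclusive[..., :1]), inclusive[..., :-1]], dim=-1
    )

def exclusive_cumprod_direct(x):
    # Direct product scan: cumprod, then shift right and seed with 1.
    inclusive = torch.cumprod(x, dim=-1)
    return torch.cat(
        [torch.ones_like(inclusive[..., :1]), inclusive[..., :-1]], dim=-1
    )

x = torch.rand(2, 3, 5)  # e.g. (bsz, tgt_len, src_len) of 1 - p_choose
assert torch.allclose(
    exclusive_cumprod_logcumsumexp(x), exclusive_cumprod_direct(x), atol=1e-5
)

Both produce the same result; the direct form avoids the log/exp round trip, and on GPU the patch additionally replaces PyTorch's scan with a CUB-based kernel.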

Reviewed By: cndn

Differential Revision: D30033767

fbshipit-source-id: 853c1c2d366838d6bcfa0863999f217a394e46a7
Rengan Xu 2021-10-06 16:48:08 -07:00 committed by Facebook GitHub Bot
parent dd3bd3c049
commit ecea95c063
8 changed files with 695 additions and 26 deletions

View File

@@ -0,0 +1,166 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h> // @manual=//caffe2:torch_extension
#include <algorithm>
namespace {
template <typename T>
void exclusiveCumprod(
const T* p_choose,
T* cumprod_1mp,
uint32_t bsz,
uint32_t tgt_len,
uint32_t src_len) {
// cumprod_1mp = 1 - p_choose
for (uint32_t b = 0; b < bsz; b++) {
for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
for (uint32_t src = 0; src < src_len; src++) {
uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
cumprod_1mp[idx] = 1 - p_choose[idx];
}
}
}
// Implementing exclusive cumprod in the innermost dimension
// cumprod_1mp = cumprod(1 - p_choose)
// PyTorch provides cumprod, but it has no exclusive mode.
// cumprod(x) = [x1, x1*x2, x1*x2*x3, ..., prod_{i=1}^{n} x_i]
// exclusive means
// cumprod(x) = [1, x1, x1*x2, x1*x2*x3, ..., prod_{i=1}^{n-1} x_i]
for (uint32_t b = 0; b < bsz; b++) {
for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
uint32_t idx_offset = b * tgt_len * src_len + tgt * src_len;
T prev = cumprod_1mp[idx_offset];
// index [b][tgt][0]
cumprod_1mp[idx_offset] = (T)1.0;
T curr;
for (uint32_t src = 1; src < src_len; src++) {
uint32_t idx = idx_offset + src;
curr = cumprod_1mp[idx];
cumprod_1mp[idx] = cumprod_1mp[idx - 1] * prev;
prev = curr;
}
}
}
}
template <typename T>
void clamp(
const T* cumprod_1mp,
T* cumprod_1mp_clamp,
uint32_t bsz,
uint32_t tgt_len,
uint32_t src_len,
T min_val,
T max_val) {
for (uint32_t b = 0; b < bsz; b++) {
for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
for (uint32_t src = 0; src < src_len; src++) {
uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
if (cumprod_1mp[idx] < min_val) {
cumprod_1mp_clamp[idx] = min_val;
} else if (cumprod_1mp[idx] > max_val) {
cumprod_1mp_clamp[idx] = max_val;
} else {
cumprod_1mp_clamp[idx] = cumprod_1mp[idx];
}
}
}
}
}
template <typename T>
void alignmentTrainCPUImpl(
const T* p_choose,
T* alpha,
uint32_t bsz,
uint32_t tgt_len,
uint32_t src_len,
float eps) {
// p_choose: bsz, tgt_len, src_len
// cumprod_1mp: bsz, tgt_len, src_len
// cumprod_1mp_clamp: bsz, tgt_len, src_len
// alpha: bsz, tgt_len, src_len
uint32_t elements = bsz * tgt_len * src_len;
T* cumprod_1mp = new T[elements];
T* cumprod_1mp_clamp = new T[elements];
exclusiveCumprod<T>(p_choose, cumprod_1mp, bsz, tgt_len, src_len);
clamp<T>(
cumprod_1mp, cumprod_1mp_clamp, bsz, tgt_len, src_len, (T)eps, (T)1.0);
// a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / clamp(cumprod(1 - p_i)))
// Initialize alpha [:, 0, 0]
for (uint32_t b = 0; b < bsz; b++) {
alpha[b * tgt_len * src_len] = 1.0;
}
for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
for (uint32_t b = 0; b < bsz; b++) {
uint32_t alpha_idx, inout_idx;
T prev_scan = 0, curr_scan, out;
for (uint32_t src = 0; src < src_len; src++) {
// Apply scan/cumsum
if (tgt == 0) {
// alpha index is [b][tgt][src]
alpha_idx = b * tgt_len * src_len + src;
} else {
// alpha index is [b][tgt-1][src]
alpha_idx = b * tgt_len * src_len + (tgt - 1) * src_len + src;
}
// input index is [b][tgt][src]
inout_idx = b * tgt_len * src_len + tgt * src_len + src;
curr_scan = prev_scan + alpha[alpha_idx] / cumprod_1mp_clamp[inout_idx];
out = curr_scan * p_choose[inout_idx] * cumprod_1mp[inout_idx];
alpha[inout_idx] = std::min<T>(std::max<T>(out, 0), 1.0);
prev_scan = curr_scan;
}
}
}
delete[] cumprod_1mp;
delete[] cumprod_1mp_clamp;
}
void alignmentTrainCPU(
const torch::Tensor& p_choose,
torch::Tensor& alpha,
float eps) {
uint32_t bsz = p_choose.size(0);
uint32_t tgt_len = p_choose.size(1);
uint32_t src_len = p_choose.size(2);
AT_DISPATCH_FLOATING_TYPES_AND2(
torch::ScalarType::Half,
torch::ScalarType::BFloat16,
p_choose.scalar_type(),
"alignmentCPUImpl",
[&]() {
alignmentTrainCPUImpl<scalar_t>(
p_choose.data_ptr<scalar_t>(),
alpha.data_ptr<scalar_t>(),
bsz,
tgt_len,
src_len,
eps);
});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"alignment_train_cpu",
&alignmentTrainCPU,
"expected_alignment_from_p_choose (CPU)");
}
} // namespace
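
As a reading aid for the loops above, this is roughly the same recurrence in NumPy: an exclusive cumprod of 1 - p_choose along the source axis, then one cumsum per target step seeded by the previous alpha row. It is a sketch with invented variable names, not part of the patch; the CPU operator above is the authoritative implementation.

import numpy as np

def alignment_train_ref(p_choose, eps=1e-6):
    # p_choose: (bsz, tgt_len, src_len)
    bsz, tgt_len, src_len = p_choose.shape
    # Exclusive cumprod of (1 - p_choose) along the source dimension.
    one_minus = 1.0 - p_choose
    cumprod_1mp = np.ones_like(one_minus)
    cumprod_1mp[..., 1:] = np.cumprod(one_minus, axis=-1)[..., :-1]
    cumprod_1mp_clamp = np.clip(cumprod_1mp, eps, 1.0)

    alpha = np.zeros_like(p_choose)
    # The "previous" alpha for the first target step is a one-hot
    # on the first source position, matching the init loop above.
    prev = np.zeros((bsz, src_len), dtype=p_choose.dtype)
    prev[:, 0] = 1.0
    for t in range(tgt_len):
        cumsum = np.cumsum(prev / cumprod_1mp_clamp[:, t], axis=-1)
        alpha[:, t] = np.clip(p_choose[:, t] * cumprod_1mp[:, t] * cumsum, 0.0, 1.0)
        prev = alpha[:, t]
    return alpha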

View File

@@ -0,0 +1,31 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "alignment_train_cuda.h"
#include "utils.h"
namespace {
void alignmentTrainCUDA(
const torch::Tensor& p_choose,
torch::Tensor& alpha,
float eps) {
CHECK_INPUT(p_choose);
CHECK_INPUT(alpha);
alignmentTrainCUDAWrapper(p_choose, alpha, eps);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"alignment_train_cuda",
&alignmentTrainCUDA,
"expected_alignment_from_p_choose (CUDA)");
}
} // namespace

View File

@@ -0,0 +1,16 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h> // @manual=//caffe2:torch_extension
void alignmentTrainCUDAWrapper(
const torch::Tensor& p_choose,
torch::Tensor& alpha,
float eps);

View File

@@ -0,0 +1,354 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> // @manual=//caffe2/aten:ATen-cu
#include <cuda_runtime.h>
#include <algorithm> // std::min/max
#include <cub/cub.cuh>
#include "alignment_train_cuda.h"
#include "utils.h"
namespace {
// The thread block length in threads along the X dimension
constexpr int BLOCK_DIM_X = 128;
// The thread block length in threads along the Y dimension
constexpr int BLOCK_DIM_Y = 8;
// The thread block length in threads for scan operation
constexpr int SCAN_BLOCK = 512;
#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
inline void
gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) {
if (code != cudaSuccess) {
fprintf(
stderr,
"\nGPUassert: %s %s %d\n",
cudaGetErrorString(code),
file,
line);
if (abort)
exit(code);
}
}
template <typename T>
struct Prod {
/// prod operator, returns <tt>a * b</tt>
__host__ __device__ __forceinline__ T
operator()(const T& a, const T& b) const {
return a * b;
}
};
template <typename T>
struct BlockPrefixProdCallbackOp {
// Running prefix
T running_total;
// Constructor
__device__ BlockPrefixProdCallbackOp(T running_total)
: running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide
// scan.
__device__ T operator()(const T block_aggregate) {
T old_prefix = running_total;
running_total *= block_aggregate;
return old_prefix;
}
};
template <typename T>
struct BlockPrefixSumCallbackOp {
// Running prefix
T running_total;
// Constructor
__device__ BlockPrefixSumCallbackOp(T running_total)
: running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide
// scan.
__device__ T operator()(const T block_aggregate) {
T old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
template <typename T>
__global__ void oneMinusPKernel(
const T* __restrict__ p_choose,
T* __restrict__ cumprod_1mp,
uint32_t bsz,
uint32_t tgt_len,
uint32_t src_len) {
for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
for (uint32_t tgt = threadIdx.y; tgt < tgt_len; tgt += blockDim.y) {
for (uint32_t src = threadIdx.x; src < src_len; src += blockDim.x) {
uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
cumprod_1mp[idx] = 1 - p_choose[idx];
}
}
}
}
template <typename T, int TPB>
__global__ void innermostScanKernel(
T* __restrict__ cumprod_1mp,
uint32_t bsz,
uint32_t tgt_len,
uint32_t src_len) {
for (uint32_t b = blockIdx.y; b < bsz; b += gridDim.y) {
for (uint32_t tgt = blockIdx.x; tgt < tgt_len; tgt += gridDim.x) {
// Specialize BlockScan for a 1D block of TPB threads on type T
typedef cub::BlockScan<T, TPB> BlockScan;
// Allocate shared memory for BlockScan
__shared__ typename BlockScan::TempStorage temp_storage;
// Initialize running total
BlockPrefixProdCallbackOp<T> prefix_op(1);
const uint32_t tid = threadIdx.x;
for (uint32_t block_src = 0; block_src < src_len;
block_src += blockDim.x) {
uint32_t src = block_src + tid;
uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
T thread_data = (src < src_len) ? cumprod_1mp[idx] : (T)0;
// Collectively compute the block-wide exclusive prefix product
BlockScan(temp_storage)
.ExclusiveScan(thread_data, thread_data, Prod<T>(), prefix_op);
__syncthreads();
// write the scanned value to output
if (src < src_len) {
cumprod_1mp[idx] = thread_data;
}
}
}
}
}
template <typename T>
__global__ void clampKernel(
const T* __restrict__ cumprod_1mp,
T* __restrict__ cumprod_1mp_clamp,
uint32_t bsz,
uint32_t tgt_len,
uint32_t src_len,
T min_val,
T max_val) {
for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
for (uint32_t tgt = threadIdx.y; tgt < tgt_len; tgt += blockDim.y) {
for (uint32_t src = threadIdx.x; src < src_len; src += blockDim.x) {
uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
if (cumprod_1mp[idx] < min_val) {
cumprod_1mp_clamp[idx] = min_val;
} else if (cumprod_1mp[idx] > max_val) {
cumprod_1mp_clamp[idx] = max_val;
} else {
cumprod_1mp_clamp[idx] = cumprod_1mp[idx];
}
}
}
}
}
template <typename T>
__global__ void initAlphaCUDAKernel(
T* alpha,
uint32_t bsz,
uint32_t tgt_len,
uint32_t src_len) {
// alpha[:, 0, 0] = 1.0
for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
alpha[b * tgt_len * src_len] = (T)1.0;
}
}
template <typename T, int TPB>
__global__ void alignmentTrainCUDAKernel(
const T* __restrict__ p_choose,
const T* __restrict__ cumprod_1mp,
const T* __restrict__ cumprod_1mp_clamp,
T* __restrict__ alpha,
uint32_t bsz,
uint32_t tgt_len,
uint32_t src_len,
uint32_t tgt) {
for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) {
// Specialize BlockScan for a 1D block of TPB threads on type T
typedef cub::BlockScan<T, TPB> BlockScan;
// Allocate shared memory for BlockScan
__shared__ typename BlockScan::TempStorage temp_storage;
// Initialize running total
BlockPrefixSumCallbackOp<T> prefix_op(0);
uint32_t b_offset = b * tgt_len * src_len;
const uint32_t tid = threadIdx.x;
for (uint32_t block_src = 0; block_src < src_len; block_src += blockDim.x) {
uint32_t src = block_src + tid;
// Obtain a segment of consecutive items that are blocked across threads
uint32_t inout_idx, alpha_idx;
if (tgt == 0) {
// both alpha and other input index is [b][0][src]
alpha_idx = b_offset + src;
} else {
// alpha index is [b][tgt-1][src]
alpha_idx = b_offset + (tgt - 1) * src_len + src;
}
inout_idx = b_offset + tgt * src_len + src;
T thread_data = (T)0;
if (src < src_len) {
thread_data = alpha[alpha_idx] / cumprod_1mp_clamp[inout_idx];
}
// Collectively compute the block-wide inclusive prefix sum
BlockScan(temp_storage).InclusiveSum(thread_data, thread_data, prefix_op);
__syncthreads();
if (src < src_len) {
T out = thread_data * p_choose[inout_idx] * cumprod_1mp[inout_idx];
// Clamps all elements into the range [ 0, 1.0 ]
alpha[inout_idx] = std::min<T>(std::max<T>(out, 0), (T)1.0);
}
}
}
}
template <typename T>
void exclusiveCumprod(
const T* p_choose,
T* cumprod_1mp,
uint32_t bsz,
uint32_t tgt_len,
uint32_t src_len,
uint32_t max_grid_x,
uint32_t max_grid_y,
cudaStream_t& stream) {
// cumprod_1mp = 1 - p_choose
dim3 grid(std::min<uint32_t>(max_grid_x, bsz), 1, 1);
dim3 block(BLOCK_DIM_X, BLOCK_DIM_Y, 1);
oneMinusPKernel<T><<<grid, block, 0, stream>>>(
p_choose, cumprod_1mp, bsz, tgt_len, src_len);
gpuErrchk(cudaGetLastError());
// scan on the innermost dimension of cumprod_1mp
// cumprod_1mp = cumprod(cumprod_1mp)
dim3 grid_scan(
std::min<uint32_t>(max_grid_x, tgt_len), std::min<uint32_t>(max_grid_y, bsz), 1);
innermostScanKernel<T, SCAN_BLOCK><<<grid_scan, SCAN_BLOCK, 0, stream>>>(
cumprod_1mp, bsz, tgt_len, src_len);
gpuErrchk(cudaGetLastError());
}
template <typename T>
void alignmentTrainCUDAImpl(
const T* p_choose,
T* alpha,
uint32_t bsz,
uint32_t tgt_len,
uint32_t src_len,
float eps) {
// p_choose: bsz , tgt_len, src_len
// cumprod_1mp: bsz , tgt_len, src_len
// cumprod_1mp_clamp : bsz, tgt_len, src_len
// alpha: bsz, tgt_len, src_len
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
uint32_t max_grid_x = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
uint32_t max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
// Implementing exclusive cumprod.
// cumprod_1mp = cumprod(1 - p_choose)
// PyTorch provides cumprod, but it has no exclusive mode.
// cumprod(x) = [x1, x1*x2, x1*x2*x3, ..., prod_{i=1}^{n} x_i]
// exclusive means
// cumprod(x) = [1, x1, x1*x2, x1*x2*x3, ..., prod_{i=1}^{n-1} x_i]
uint32_t elements = bsz * tgt_len * src_len;
T* cumprod_1mp;
gpuErrchk(cudaMalloc(&cumprod_1mp, elements * sizeof(T)));
exclusiveCumprod<T>(
p_choose,
cumprod_1mp,
bsz,
tgt_len,
src_len,
max_grid_x,
max_grid_y,
stream);
// clamp cumprod_1mp to the range [eps, 1.0]
T* cumprod_1mp_clamp;
gpuErrchk(cudaMalloc(&cumprod_1mp_clamp, elements * sizeof(T)));
dim3 grid_clamp(std::min<uint32_t>(max_grid_x, bsz), 1, 1);
dim3 block_clamp(BLOCK_DIM_X, BLOCK_DIM_Y, 1);
clampKernel<T><<<grid_clamp, block_clamp, 0, stream>>>(
cumprod_1mp, cumprod_1mp_clamp, bsz, tgt_len, src_len, (T)eps, (T)1.0);
gpuErrchk(cudaGetLastError());
// a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / clamp(cumprod(1 - p_i)))
dim3 grid_init(std::min<int>(max_grid_x, bsz), 1, 1);
initAlphaCUDAKernel<T>
<<<grid_init, 1, 0, stream>>>(alpha, bsz, tgt_len, src_len);
gpuErrchk(cudaGetLastError());
const int grid = std::min(bsz, max_grid_x);
for (uint32_t i = 0; i < tgt_len; i++) {
alignmentTrainCUDAKernel<T, SCAN_BLOCK><<<grid, SCAN_BLOCK, 0, stream>>>(
p_choose,
cumprod_1mp,
cumprod_1mp_clamp,
alpha,
bsz,
tgt_len,
src_len,
i);
gpuErrchk(cudaGetLastError());
}
gpuErrchk(cudaFree(cumprod_1mp));
gpuErrchk(cudaFree(cumprod_1mp_clamp));
}
} // namespace
void alignmentTrainCUDAWrapper(
const torch::Tensor& p_choose,
torch::Tensor& alpha,
float eps) {
// p_choose dimension: bsz, tgt_len, src_len
uint32_t bsz = p_choose.size(0);
uint32_t tgt_len = p_choose.size(1);
uint32_t src_len = p_choose.size(2);
cudaSetDevice(p_choose.get_device());
AT_DISPATCH_FLOATING_TYPES_AND2(
torch::ScalarType::Half,
torch::ScalarType::BFloat16,
p_choose.scalar_type(),
"alignmentTrainCUDAImpl",
[&]() {
alignmentTrainCUDAImpl<scalar_t>(
p_choose.data_ptr<scalar_t>(),
alpha.data_ptr<scalar_t>(),
bsz,
tgt_len,
src_len,
eps);
});
}
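
The kernels above process the source dimension in tiles of SCAN_BLOCK threads and carry each tile's aggregate forward through the BlockPrefix*CallbackOp functors. The following NumPy sketch (prefix-sum case, hypothetical tile size) mirrors that control flow on the CPU, purely as an illustration:

import numpy as np

def chunked_cumsum(x, tile=512):
    # Emulate a block-wide inclusive scan whose running prefix is carried
    # across tiles, like BlockPrefixSumCallbackOp does on the GPU.
    out = np.empty_like(x)
    running = 0.0
    for start in range(0, len(x), tile):
        chunk = x[start:start + tile]
        scanned = np.cumsum(chunk) + running  # seed the tile with the running prefix
        out[start:start + tile] = scanned
        running = scanned[-1]  # the tile aggregate becomes the next tile's prefix
    return out

x = np.random.rand(1500)
assert np.allclose(chunked_cumsum(x), np.cumsum(x))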

View File

@@ -0,0 +1,19 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h> // @manual=//caffe2:torch_extension
#define CHECK_CUDA(x) \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)

View File

@@ -0,0 +1,88 @@
import unittest
import numpy as np
import torch
import hypothesis.strategies as st
from hypothesis import assume, given, settings
from torch.testing._internal.common_utils import TestCase
from examples.simultaneous_translation.utils.functions import exclusive_cumprod
TEST_CUDA = torch.cuda.is_available()
class AlignmentTrainTest(TestCase):
def _test_custom_alignment_train_ref(self, p_choose, eps):
cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
cumprod_1mp_clamp = torch.clamp(cumprod_1mp, eps, 1.0)
bsz = p_choose.size(0)
tgt_len = p_choose.size(1)
src_len = p_choose.size(2)
alpha_0 = p_choose.new_zeros([bsz, 1, src_len])
alpha_0[:, :, 0] = 1.0
previous_alpha = [alpha_0]
for i in range(tgt_len):
# p_choose: bsz , tgt_len, src_len
# cumprod_1mp_clamp : bsz, tgt_len, src_len
# previous_alpha[i]: bsz, 1, src_len
# alpha_i: bsz, src_len
alpha_i = (
p_choose[:, i]
* cumprod_1mp[:, i]
* torch.cumsum(
previous_alpha[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1
)
).clamp(0, 1.0)
previous_alpha.append(alpha_i.unsqueeze(1))
# alpha: bsz, tgt_len, src_len
alpha = torch.cat(previous_alpha[1:], dim=1)
return alpha
def _test_custom_alignment_train_impl(self, p_choose, alpha, eps):
if p_choose.is_cuda:
from alignment_train_cuda_binding import alignment_train_cuda # @manual=//deeplearning/projects/fairseq-py:alignment_train_cuda_binding
alignment_train_cuda(p_choose, alpha, eps)
else:
from alignment_train_cpu_binding import alignment_train_cpu # @manual=//deeplearning/projects/fairseq-py:alignment_train_cpu_binding
alignment_train_cpu(p_choose, alpha, eps)
@settings(deadline=None)
@given(
bsz=st.integers(1, 100),
tgt_len=st.integers(1, 100),
src_len=st.integers(1, 550),
device=st.sampled_from(["cpu", "cuda"]),
)
def test_alignment_train(self, bsz, tgt_len, src_len, device):
eps = 1e-6
assume(device == "cpu" or TEST_CUDA)
p_choose = torch.rand(bsz, tgt_len, src_len, device=device)
# run the alignment with the custom operator
alpha_act = p_choose.new_zeros([bsz, tgt_len, src_len])
self._test_custom_alignment_train_impl(p_choose, alpha_act, eps)
# run the alignment with the ref implementation
alpha_ref = self._test_custom_alignment_train_ref(p_choose, eps)
# verify the results
alpha_act = alpha_act.cpu().detach().numpy()
alpha_ref = alpha_ref.cpu().detach().numpy()
np.testing.assert_allclose(
alpha_act,
alpha_ref,
atol=1e-3,
rtol=1e-3,
)
if __name__ == "__main__":
unittest.main()

View File

@@ -42,32 +42,14 @@ def expected_alignment_from_p_choose(
if padding_mask is not None:
p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0.0)
# cumprod_1mp : bsz, tgt_len, src_len
cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
cumprod_1mp_clamp = torch.clamp(cumprod_1mp, eps, 1.0)
if p_choose.is_cuda:
p_choose = p_choose.contiguous()
from alignment_train_cuda_binding import alignment_train_cuda as alignment_train
else:
from alignment_train_cpu_binding import alignment_train_cpu as alignment_train
alpha_0 = p_choose.new_zeros([bsz, 1, src_len])
alpha_0[:, :, 0] = 1.0
previous_alpha = [alpha_0]
for i in range(tgt_len):
# p_choose: bsz , tgt_len, src_len
# cumprod_1mp_clamp : bsz, tgt_len, src_len
# previous_alpha[i]: bsz, 1, src_len
# alpha_i: bsz, src_len
alpha_i = (
p_choose[:, i]
* cumprod_1mp[:, i]
* torch.cumsum(
previous_alpha[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1
)
).clamp(0, 1.0)
previous_alpha.append(alpha_i.unsqueeze(1))
# alpha: bsz * num_heads, tgt_len, src_len
alpha = torch.cat(previous_alpha[1:], dim=1)
alpha = p_choose.new_zeros([bsz, tgt_len, src_len])
alignment_train(p_choose, alpha, eps)
# Mix precision to prevent overflow for fp16
alpha = alpha.type(dtype)

View File

@@ -117,7 +117,13 @@ try:
sources=[
"fairseq/clib/libnat/edit_dist.cpp",
],
)
),
cpp_extension.CppExtension(
"alignment_train_cpu_binding",
sources=[
"examples/operators/alignment_train_cpu.cpp",
],
),
]
)
if "CUDA_HOME" in os.environ:
@@ -137,6 +143,13 @@ try:
"fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu",
],
),
cpp_extension.CppExtension(
"alignment_train_cuda_binding",
sources=[
"examples/operators/alignment_train_kernel.cu",
"examples/operators/alignment_train_cuda.cpp",
],
),
]
)
cmdclass["build_ext"] = cpp_extension.BuildExtension