From f840564da943f3f95bf80e93fd49be6f93d98348 Mon Sep 17 00:00:00 2001 From: Nathan Ng Date: Wed, 14 Aug 2019 10:45:52 -0700 Subject: [PATCH] initial light and dynamic convolution kernels (#547) Summary: CUDA code for light/dynamicconv kernels, including pytorch modules. Modules can be built by running setup.py in each respective folder, and can then be imported and used like any other module. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/547 Reviewed By: myleott, shubho Differential Revision: D15703660 Pulled By: nng555 fbshipit-source-id: e9c913753be3a1cd571965f7200df6678b644520 --- .gitignore | 2 + examples/pay_less_attention_paper/README.md | 16 +- fairseq/models/lightconv.py | 38 +- fairseq/modules/__init__.py | 10 +- fairseq/modules/cuda_utils.cu | 202 ++++++++++ fairseq/modules/dynamic_convolution.py | 21 +- fairseq/modules/dynamicconv_layer/__init__.py | 8 + .../dynamicconv_layer/cuda_function_gen.py | 223 +++++++++++ .../dynamicconv_layer/dynamicconv_cuda.cpp | 49 +++ .../dynamicconv_layer/dynamicconv_cuda.cuh | 49 +++ .../dynamicconv_cuda_kernel.cu | 167 ++++++++ .../dynamicconv_layer/dynamicconv_layer.py | 205 ++++++++++ .../dynamicconv_layer/dynamiconv_cpu.cpp | 35 ++ fairseq/modules/dynamicconv_layer/setup.py | 17 + fairseq/modules/lightconv_layer/__init__.py | 8 + .../lightconv_layer/cuda_function_gen.py | 289 ++++++++++++++ .../lightconv_layer/lightconv_cuda.cpp | 47 +++ .../lightconv_layer/lightconv_cuda.cuh | 82 ++++ .../lightconv_layer/lightconv_cuda_kernel.cu | 374 ++++++++++++++++++ .../lightconv_layer/lightconv_layer.py | 113 ++++++ fairseq/modules/lightconv_layer/setup.py | 14 + fairseq/modules/lightweight_convolution.py | 15 + fairseq/modules/unfold.py | 1 - 23 files changed, 1958 insertions(+), 27 deletions(-) create mode 100644 fairseq/modules/cuda_utils.cu create mode 100644 fairseq/modules/dynamicconv_layer/__init__.py create mode 100644 fairseq/modules/dynamicconv_layer/cuda_function_gen.py create mode 100644 fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp create mode 100644 fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh create mode 100644 fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu create mode 100644 fairseq/modules/dynamicconv_layer/dynamicconv_layer.py create mode 100644 fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp create mode 100644 fairseq/modules/dynamicconv_layer/setup.py create mode 100644 fairseq/modules/lightconv_layer/__init__.py create mode 100644 fairseq/modules/lightconv_layer/cuda_function_gen.py create mode 100644 fairseq/modules/lightconv_layer/lightconv_cuda.cpp create mode 100644 fairseq/modules/lightconv_layer/lightconv_cuda.cuh create mode 100644 fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu create mode 100644 fairseq/modules/lightconv_layer/lightconv_layer.py create mode 100644 fairseq/modules/lightconv_layer/setup.py diff --git a/.gitignore b/.gitignore index ec0572d4d..7e4a2d412 100644 --- a/.gitignore +++ b/.gitignore @@ -111,6 +111,8 @@ ENV/ # Generated files fairseq/temporal_convolution_tbc +fairseq/modules/*_layer/*_forward.cu +fairseq/modules/*_layer/*_backward.cu # data data-bin/ diff --git a/examples/pay_less_attention_paper/README.md b/examples/pay_less_attention_paper/README.md index bd66d705b..97ab847fc 100644 --- a/examples/pay_less_attention_paper/README.md +++ b/examples/pay_less_attention_paper/README.md @@ -1,5 +1,5 @@ # Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019) -This page contains pointers to pre-trained models 
as well as instructions on how to train new models for [our paper](https://openreview.net/pdf?id=SkVhlh09tX) +This page contains pointers to pre-trained models as well as instructions on how to train new models for [our paper](https://arxiv.org/abs/1901.10430) ## Citation: ```bibtex @@ -8,7 +8,7 @@ This page contains pointers to pre-trained models as well as instructions on how author = {Felix Wu and Angela Fan and Alexei Baevski and Yann Dauphin and Michael Auli}, booktitle = {International Conference on Learning Representations}, year = {2019}, - url = {https://openreview.net/forum?id=SkVhlh09tX}, + url = {https://arxiv.org/abs/1901.10430}, } ``` @@ -39,6 +39,18 @@ To use the model without GLU, please set `--encoder-glu 0 --decoder-glu 0`. For LightConv, please use `--encoder-conv-type lightweight --decoder-conv-type lightweight`, otherwise the default is DynamicConv. For best BLEU results, lenpen may need to be manually tuned. +To use the CUDA kernels, first install the PyTorch modules using the commands below +```sh +# to install lightconv +python fairseq/modules/lightconv_layer/cuda_function_gen.py +python fairseq/modules/lightconv_layer/setup.py install + +# to install dynamicconv +python fairseq/modules/dynamicconv_layer/cuda_function_gen.py +python fairseq/modules/dynamicconv_layer/setup.py install +``` +Once the CUDA modules are installed, they will automatically be used instead of the PyTorch modules. + ### IWSLT14 De-En Training and evaluating DynamicConv (without GLU) on a GPU: ```sh diff --git a/fairseq/models/lightconv.py b/fairseq/models/lightconv.py index 20ff4f0e6..44d52dcd8 100644 --- a/fairseq/models/lightconv.py +++ b/fairseq/models/lightconv.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import math +import sys import torch import torch.nn as nn @@ -19,10 +20,10 @@ from fairseq.models import ( ) from fairseq.modules import ( AdaptiveSoftmax, - DynamicConv1dTBC, + DynamicConv, LayerNorm, PositionalEmbedding, - LightweightConv1dTBC, + LightweightConv, MultiheadAttention, ) @@ -173,7 +174,6 @@ class LightConvModel(FairseqEncoderDecoderModel): decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens) return LightConvModel(encoder, decoder) - class LightConvEncoder(FairseqEncoder): """ LightConv encoder consisting of *args.encoder_layers* layers. 
Each layer @@ -447,15 +447,15 @@ class LightConvEncoderLayer(nn.Module): self.linear1 = Linear(self.embed_dim, self.conv_dim) self.act = None if args.encoder_conv_type == 'lightweight': - self.conv = LightweightConv1dTBC(self.conv_dim, kernel_size, padding_l=padding_l, - weight_softmax=args.weight_softmax, - num_heads=args.encoder_attention_heads, - weight_dropout=args.weight_dropout) + self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=padding_l, + weight_softmax=args.weight_softmax, + num_heads=args.encoder_attention_heads, + weight_dropout=args.weight_dropout) elif args.encoder_conv_type == 'dynamic': - self.conv = DynamicConv1dTBC(self.conv_dim, kernel_size, padding_l=padding_l, - weight_softmax=args.weight_softmax, - num_heads=args.encoder_attention_heads, - weight_dropout=args.weight_dropout) + self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=padding_l, + weight_softmax=args.weight_softmax, + num_heads=args.encoder_attention_heads, + weight_dropout=args.weight_dropout) else: raise NotImplementedError self.linear2 = Linear(self.conv_dim, self.embed_dim) @@ -535,15 +535,15 @@ class LightConvDecoderLayer(nn.Module): self.linear1 = Linear(self.embed_dim, self.conv_dim) self.act = None if args.decoder_conv_type == 'lightweight': - self.conv = LightweightConv1dTBC(self.conv_dim, kernel_size, padding_l=kernel_size-1, - weight_softmax=args.weight_softmax, - num_heads=args.decoder_attention_heads, - weight_dropout=args.weight_dropout) + self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=kernel_size-1, + weight_softmax=args.weight_softmax, + num_heads=args.decoder_attention_heads, + weight_dropout=args.weight_dropout) elif args.decoder_conv_type == 'dynamic': - self.conv = DynamicConv1dTBC(self.conv_dim, kernel_size, padding_l=kernel_size-1, - weight_softmax=args.weight_softmax, - num_heads=args.decoder_attention_heads, - weight_dropout=args.weight_dropout) + self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=kernel_size-1, + weight_softmax=args.weight_softmax, + num_heads=args.decoder_attention_heads, + weight_dropout=args.weight_dropout) else: raise NotImplementedError self.linear2 = Linear(self.conv_dim, self.embed_dim) diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index 6458f7d02..ecfdc3d69 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -9,13 +9,15 @@ from .beamable_mm import BeamableMM from .character_token_embedder import CharacterTokenEmbedder from .conv_tbc import ConvTBC from .downsampled_multihead_attention import DownsampledMultiHeadAttention -from .dynamic_convolution import DynamicConv1dTBC +from .dynamic_convolution import DynamicConv, DynamicConv1dTBC +#from .dynamicconv_layer import DynamicconvLayer from .gelu import gelu, gelu_accurate from .grad_multiply import GradMultiply from .highway import Highway from .layer_norm import LayerNorm from .learned_positional_embedding import LearnedPositionalEmbedding -from .lightweight_convolution import LightweightConv1dTBC +from .lightweight_convolution import LightweightConv, LightweightConv1dTBC +#from .lightconv_layer import LightconvLayer from .linearized_convolution import LinearizedConvolution from .logsumexp_moe import LogSumExpMoE from .mean_pool_gating_network import MeanPoolGatingNetwork @@ -36,14 +38,18 @@ __all__ = [ 'CharacterTokenEmbedder', 'ConvTBC', 'DownsampledMultiHeadAttention', +# 'DyamicconvLayer', 'DynamicConv1dTBC', + 'DynamicConv', 'gelu', 'gelu_accurate', 'GradMultiply', 'Highway', 'LayerNorm', 
'LearnedPositionalEmbedding',
+#   'LightconvLayer',
     'LightweightConv1dTBC',
+    'LightweightConv',
     'LinearizedConvolution',
     'LogSumExpMoE',
     'MeanPoolGatingNetwork',
diff --git a/fairseq/modules/cuda_utils.cu b/fairseq/modules/cuda_utils.cu
new file mode 100644
index 000000000..596ff125f
--- /dev/null
+++ b/fairseq/modules/cuda_utils.cu
@@ -0,0 +1,202 @@
+/**
+ * Copyright (c) 2018-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ */
+
+
+template<typename U, typename V>
+constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
+  return (a + b - 1) / b;
+}
+
+
+template<int FS, int SB, int padding_l, typename scalar_t>
+__inline__ __device__
+void zeroSharedMem(scalar_t* data) {
+  /*
+  Given an array of length FS + SB, zero out the first padding_l and last
+  (FS - padding_l) values in the array
+  */
+
+  int tid = threadIdx.x;
+
+  if (FS < SB) {
+
+    // zero all if we have enough threads in a block to do all of them
+    if (tid < padding_l || tid > SB - FS + padding_l - 1) {
+      data[tid] = scalar_t(0.0);
+    }
+  } else {
+
+    // otherwise zero out one block at a time
+    const int numIterations = divUp<int, int>(FS, SB);
+    for (int i = 0; i < numIterations; i++) {
+      int offset = i * SB;
+      if (tid + offset < padding_l) {
+        data[tid + offset] = scalar_t(0.0);
+      } else if (tid + offset < FS) {
+        data[SB + tid + offset] = scalar_t(0.0);
+      }
+    }
+  }
+}
+
+template<typename scalar_t>
+__inline__ __device__
+scalar_t warpReduce(scalar_t data) {
+  /*
+  Reduce a value within each warp. After processing, all threads in the warp
+  will contain the sum of all original values in that warp.
+
+  data - value to reduce
+  */
+  data += __shfl_xor_sync(SHFL_MASK, data, 16);
+  data += __shfl_xor_sync(SHFL_MASK, data, 8);
+  data += __shfl_xor_sync(SHFL_MASK, data, 4);
+  data += __shfl_xor_sync(SHFL_MASK, data, 2);
+  data += __shfl_xor_sync(SHFL_MASK, data, 1);
+  return data;
+}
+
+template<typename scalar_t>
+__inline__ __device__
+scalar_t blockReduce(scalar_t data) {
+  /*
+  Reduce a value across the whole block. After processing, thread 0 of the
+  block holds the reduced sum.
+
+  data - value to reduce
+  */
+
+  static __shared__ scalar_t warpSum[32];
+  const int tid = threadIdx.x;
+  int wid = tid / 32;
+  int lane = tid % 32;
+
+  __syncthreads();
+
+  // reduce each warp then write to shared memory
+  scalar_t sum = warpReduce(data);
+  if (lane == 0) {
+    warpSum[wid] = sum;
+  }
+
+  __syncthreads();
+
+  scalar_t v;
+  // perform final sum of partial warp sums
+  if (tid < blockDim.x / 32) {
+    v = warpSum[lane];
+  } else {
+    v = scalar_t(0.0);
+  }
+
+  if (wid == 0) {
+    v = warpReduce(v);
+  }
+  __syncthreads();
+
+  return v;
+}
+
+void checkCudaStatus(cudaError_t status, int lineNumber = -1) {
+
+  if (status != cudaSuccess) {
+    std::cout << cudaGetErrorString(status)
+              << " at line " << lineNumber << std::endl;
+    std::cout << "Exiting" << std::endl;
+    exit(1);
+  }
+}
+
+template<int FS, int SB, int padding_l, typename scalar_t>
+__device__
+void load_input_to_shared(const scalar_t* input,  // global memory
+                          int inputOffset, int sequenceLength,
+                          int iteration, int numIterations,
+                          bool no_prev, scalar_t* output /* shared memory */) {
+  /*
+  Load a block size of input into shared memory with
+  right and left overhang of total size FS.
If previously + loaded memory, overlap will be shifted over to reduce + global memory access + + input - pointer to start of channel sequence + inputOffset - how far in the sequence to start loading + sequenceLength - total length of sequence + iteration - which block of sequence we are loading + numIterations - total number of blocks to load + no_prev - whether to load the whole block if the previous block + wasn't loaded + output - shared memory to write input to + */ + + const int tid = threadIdx.x; + + // Load the left "overhang" of input + if (iteration > 0) { + if (padding_l < SB) { + + // load all at once + if (tid < padding_l) { + output[tid] = (no_prev) ? input[inputOffset - padding_l + tid] : output[tid + SB]; + } + } else { + + // load in chunks of size SB + int numIterations = divUp(padding_l, SB); + for (int i = 0; i < numIterations; i++) { + int offset = i * SB; + if ((tid + offset) < padding_l) { + output[tid + offset] = (no_prev) ? input[inputOffset - padding_l + tid + offset] : output[tid + offset + SB]; + } + } + } + } + + // Load the right "overhang" of input + if (iteration < (numIterations - 1)) { + const int elementsLeft = sequenceLength - (iteration+1) * SB; + + if ((FS - padding_l) < SB) { + + // load all at once + if (tid < (FS - padding_l)) { + output[padding_l + SB + tid] = (tid < elementsLeft) ? input[inputOffset + SB + tid] : scalar_t(0.0); + } + } else { + + // load in chunks of size SB + int numIterations = divUp(FS - padding_l, SB); + for (int i = 0; i < numIterations; i++) { + int offset = i * SB; + if ((tid + offset) < (FS - padding_l)) { + output[padding_l + SB + tid + offset] = ((tid + offset) < elementsLeft) ? input[inputOffset + SB + tid + offset] : scalar_t(0.0); + } + } + } + } + + // We should also clear out the right "overhang" + if (iteration == (numIterations - 1)) { + if ((FS - padding_l) < SB) { + + // clear out all at once + if (tid < (FS - padding_l)) { + output[padding_l + SB + tid] = scalar_t(0.0); + } + } else { + + // clear in chunks of size SB + int numIterations = divUp(FS - padding_l, SB); + for (int i = 0; i < numIterations; i++) { + int offset = i * SB; + if ((tid + offset) < (FS - padding_l)) { + output[padding_l + SB + tid + offset] = scalar_t(0.0); + } + } + } + } + output[tid + padding_l] = ((inputOffset + tid) < sequenceLength) ? 
input[inputOffset + tid] : scalar_t(0.0); +} diff --git a/fairseq/modules/dynamic_convolution.py b/fairseq/modules/dynamic_convolution.py index a8fa47225..7fbd3f37e 100644 --- a/fairseq/modules/dynamic_convolution.py +++ b/fairseq/modules/dynamic_convolution.py @@ -10,6 +10,23 @@ import torch.nn.functional as F from fairseq import utils from .unfold import unfold1d +def DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1, + weight_dropout=0., weight_softmax=False, + renorm_padding=False, bias=False, conv_bias=False, + query_size=None, in_proj=False): + if torch.cuda.is_available(): + try: + from fairseq.modules.dynamicconv_layer import DynamicconvLayer + return DynamicconvLayer(input_size, kernel_size=kernel_size, + padding_l=padding_l, num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, bias=bias) + except ImportError as e: + print(e) + return DynamicConv1dTBC(input_size, kernel_size=kernel_size, + padding_l=padding_l, num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, bias=bias) def Linear(in_features, out_features, bias=True): m = nn.Linear(in_features, out_features, bias) @@ -90,7 +107,6 @@ class DynamicConv1dTBC(nn.Module): if query is None: query = x - if unfold: output = self._forward_unfolded(x, incremental_state, query) else: @@ -193,8 +209,7 @@ class DynamicConv1dTBC(nn.Module): # turn the convolution filters into band matrices weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False) weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) - weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T - + weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T output = torch.bmm(weight_expanded, x) output = output.transpose(0, 1).contiguous().view(T, B, C) return output diff --git a/fairseq/modules/dynamicconv_layer/__init__.py b/fairseq/modules/dynamicconv_layer/__init__.py new file mode 100644 index 000000000..c62ffac86 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from .dynamicconv_layer import DynamicconvLayer diff --git a/fairseq/modules/dynamicconv_layer/cuda_function_gen.py b/fairseq/modules/dynamicconv_layer/cuda_function_gen.py new file mode 100644 index 000000000..caf151e4a --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/cuda_function_gen.py @@ -0,0 +1,223 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + + +def gen_forward(): + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + blocks = [32, 64, 128, 256] + + head = """ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. 
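+ *
+ * NOTE: this header is emitted into dynamicconv_cuda_forward.cu by
+ * cuda_function_gen.py (the generated file is git-ignored). The function
+ * below simply dispatches on (filterSize, padding_l): each supported pair
+ * maps to one template instantiation of dynamicconv_forward_kernel, launched
+ * with one thread block per (batch, feature) pair.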
+ * + */ + +#include "dynamicconv_cuda.cuh" + +std::vector dynamicconv_cuda_forward(at::Tensor input, at::Tensor weight, int padding_l) { + + at::DeviceGuard g(input.device()); + const auto minibatch = input.size(0); + const auto numFeatures = input.size(1); + const auto sequenceLength = input.size(2); + + const auto numHeads = weight.size(1); + const auto filterSize = weight.size(2); + + const auto numFiltersInBlock = numFeatures / numHeads; + const dim3 blocks(minibatch, numFeatures); + + auto output = at::zeros_like(input); + auto stream = at::cuda::getCurrentCUDAStream(); +""" + + switch = """ + switch(filterSize) { +""" + + case_k = """ + case {k}: +""" + + main_block = """ + if (padding_l == {pad}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "dynamicconv_forward", ([&] {{ + dynamicconv_forward_kernel<{k}, {b_size}, {pad}, scalar_t> + <<>>( + input.data(), + weight.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + numHeads, + output.data()); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping forward pass" << std::endl; + } + break;\n +""" + + end = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping forward pass" << std::endl; + } + + return {output}; +} +""" + + with open("dynamicconv_cuda_forward.cu", 'w') as forward: + forward.write(head) + forward.write(switch) + for k in kernels: + b_size = 32 + for b in blocks: + if b > k: + b_size = b + break + forward.write(case_k.format(k=k)) + for pad in [k // 2, k - 1]: + forward.write(main_block.format(k=k, b_size=b_size, pad=pad)) + forward.write(bad_padding) + forward.write(end) + + +def gen_backward(): + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + thresh = [512, 512, 512, 512, 512, 380, 256, 256] + min_block = [64, 64, 64, 64, 64, 64, 128, 256] + seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]] + + head = """ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. 
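+ *
+ * NOTE: emitted into dynamicconv_cuda_backward.cu by cuda_function_gen.py.
+ * The backward pass additionally tiles the sequence into numChunks chunks of
+ * the chosen block size, so the grid is (minibatch, numHeads, numChunks);
+ * the block size for each filter length comes from the thresh/min_block
+ * tables in gen_backward().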
+ * + */ + +#include "dynamicconv_cuda.cuh" + +std::vector dynamicconv_cuda_backward(at::Tensor gradOutput, int padding_l, at::Tensor input, at::Tensor weight) { + + at::DeviceGuard g(input.device()); + const auto minibatch = input.size(0); + const auto numFeatures = input.size(1); + const auto sequenceLength = input.size(2); + + const auto numHeads = weight.size(1); + const auto filterSize = weight.size(2); + + const auto numFiltersInBlock = numFeatures / numHeads; + auto numChunks = 1; + + auto gradInput = at::zeros_like(input); + auto gradWeight = at::zeros_like(weight); + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(minibatch, numHeads, numChunks); +""" + + sequence_if = """ + if (sequenceLength < {seq}) {{ + switch(filterSize) {{ +""" + + case_k = """ + case {k}: +""" + + chunks_reset = """ + numChunks = int(ceilf(sequenceLength/float({b_size}))); + blocks = dim3(minibatch, numHeads, numChunks); +""" + + main_block = """ + if (padding_l == {p}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradOutput.scalar_type(), "dynamicconv_backward", ([&] {{ + dynamicconv_backward_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + gradOutput.data(), + input.data(), + weight.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + numHeads, + gradWeight.data(), + gradInput.data()); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping backward pass" << std::endl; + } + break;\n +""" + + bad_filter = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping backward pass" << std::endl; + } +""" + + con_else = """ + } else +""" + + final_else = """ + { + switch(filterSize) { +""" + + last_return = """ + } + return {gradInput, gradWeight}; +} +""" + + with open("dynamicconv_cuda_backward.cu", 'w') as backward: + backward.write(head) + for seq in seqs: + backward.write(sequence_if.format(seq=seq)) + for k, t, m in zip(kernels, thresh, min_block): + backward.write(case_k.format(k=k)) + if seq <= t: + b_size = seq + else: + b_size = m + backward.write(chunks_reset.format(b_size=b_size)) + for p in [k // 2, k - 1]: + backward.write(main_block.format(k=k, b_size=b_size, p=p)) + backward.write(bad_padding) + backward.write(bad_filter) + backward.write(con_else) + backward.write(final_else) + for k, m in zip(kernels, min_block): + backward.write(case_k.format(k=k)) + backward.write(chunks_reset.format(b_size=m)) + for p in [k // 2, k - 1]: + backward.write(main_block.format(k=k, b_size=m, p=p)) + backward.write(bad_padding) + backward.write(bad_filter) + backward.write(last_return) + + +if __name__ == "__main__": + gen_forward() + gen_backward() diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp new file mode 100644 index 000000000..b76c9e7fe --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp @@ -0,0 +1,49 @@ +#include +#include + +std::vector dynamicconv_cuda_forward( + at::Tensor input, + at::Tensor filters, + int padding_l); + +std::vector dynamicconv_cuda_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters); + + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector dynamicconv_forward( + at::Tensor input, + at::Tensor filters, + int padding_l) { + + CHECK_INPUT(input); + 
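+  // filters arrive as B x H x K x T; DynamicconvLayer.forward() permutes the
+  // weights and calls .contiguous() before invoking this binding, so the
+  // checks below are cheap sanity guards rather than copies.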
CHECK_INPUT(filters);
+
+  return dynamicconv_cuda_forward(input, filters,
+                                  padding_l);
+}
+
+std::vector<at::Tensor> dynamicconv_backward(
+    at::Tensor gradOutput,
+    int padding_l,
+    at::Tensor input,
+    at::Tensor filters) {
+
+  CHECK_INPUT(gradOutput);
+  CHECK_INPUT(input);
+  CHECK_INPUT(filters);
+
+  return dynamicconv_cuda_backward(gradOutput, padding_l,
+                                   input, filters);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)");
+  m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)");
+}
diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh
new file mode 100644
index 000000000..5d6ed575f
--- /dev/null
+++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh
@@ -0,0 +1,49 @@
+/**
+ * Copyright (c) 2018-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ */
+#include <ATen/ATen.h>
+#include <c10/cuda/CUDAStream.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+#include <algorithm>
+#include <functional>
+#include <iostream>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+#include <assert.h>
+#include <math.h>
+#include <stdlib.h>
+
+#define SHFL_MASK 0xffffffff
+
+template<int FS, int SB, int padding_l, typename scalar_t>
+__global__
+void dynamicconv_forward_kernel(const scalar_t* input,
+                                const scalar_t* weight,
+                                int minibatch,
+                                int sequenceLength,
+                                int numFeatures,
+                                int numFiltersInBlock,
+                                int numHeads,
+                                scalar_t* output);
+
+template<int FS, int SB, int padding_l, typename scalar_t>
+__global__
+void dynamicconv_backward_kernel(
+    const scalar_t* gradOutput, // B * C * T
+    const scalar_t* input, // B * C * T
+    const scalar_t* weight,
+    int minibatch,
+    int sequenceLength,
+    int numFeatures,
+    int numFiltersInBlock,
+    int numHeads,
+    scalar_t* gradWeight,
+    scalar_t* gradInput); // B * H * k * T
diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu
new file mode 100644
index 000000000..f29e6ded0
--- /dev/null
+++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu
@@ -0,0 +1,167 @@
+/**
+ * Copyright (c) 2018-present, Facebook, Inc.
+ * All rights reserved.
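+ *
+ * Layout notes: input and output are B x C x T, and the dynamic filters are
+ * B x H x FS x T (one filter per head and time step). Each thread block
+ * handles one (batch, channel) pair and slides an SB-wide window over the
+ * sequence, staging SB + FS elements in shared memory per iteration.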
+ *
+ */
+
+#include "dynamicconv_cuda.cuh"
+#include "dynamicconv_cuda_forward.cu"
+#include "dynamicconv_cuda_backward.cu"
+#include "../cuda_utils.cu"
+
+// FS is filter size and kernels are specialized for filter sizes
+template<int FS, int SB, int padding_l, typename scalar_t>
+__global__
+void dynamicconv_forward_kernel(const scalar_t* input,
+                                const scalar_t* weight,
+                                int minibatch,
+                                int sequenceLength,
+                                int numFeatures,
+                                int numFiltersInBlock,
+                                int numHeads,
+                                scalar_t* output) {
+  assert(blockDim.x == SB);
+
+  const int tid = threadIdx.x;
+  const int batchIdx = blockIdx.x;
+  const int featureIdx = blockIdx.y;
+  const int head = featureIdx / numFiltersInBlock;
+
+  const int IOOffset = batchIdx * numFeatures * sequenceLength
+                       + featureIdx * sequenceLength;
+  const scalar_t* inputFeature = &input[IOOffset];
+  scalar_t* outputFeature = &output[IOOffset];
+
+  scalar_t filter[FS];
+
+  __shared__ scalar_t tempInput[SB + FS];
+  zeroSharedMem<FS, SB, padding_l>(tempInput);
+
+  const int numIterations = divUp<int, int>(sequenceLength, SB);
+
+  for (int i = 0; i < numIterations; ++i) {
+    __syncthreads();
+    const int inputOffset = i * SB;
+    load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset,
+                                            sequenceLength, i,
+                                            numIterations, false, tempInput);
+    __syncthreads();
+    if (inputOffset + tid < sequenceLength) {
+
+      #pragma unroll
+      for (int k = 0; k < FS; ++k) {
+        const int filterOffset = batchIdx * numHeads * FS * sequenceLength
+                                 + head * FS * sequenceLength
+                                 + k * sequenceLength
+                                 + i * SB + tid;
+        filter[k] = weight[filterOffset];
+      }
+
+      scalar_t out = scalar_t(0.0);
+      #pragma unroll
+      for (int k = 0; k < FS; ++k) {
+        out += filter[k] * tempInput[tid + k];
+      }
+
+      outputFeature[inputOffset + tid] = out;
+
+    }
+  }
+}
+
+template<int FS, int SB, int padding_l, typename scalar_t>
+__global__
+void dynamicconv_backward_kernel(
+    const scalar_t* gradOutput, // B * C * T
+    const scalar_t* input, // B * C * T
+    const scalar_t* weight,
+    int minibatch,
+    int sequenceLength,
+    int numFeatures,
+    int numFiltersInBlock,
+    int numHeads,
+    scalar_t* gradWeight,
+    scalar_t* gradInput) { // B * H * k * T
+
+  assert(blockDim.x == SB);
+
+  // each block operates on a single batch and filter head
+  const int tid = threadIdx.x;
+  const int batchIdx = blockIdx.x;
+  const int headIdx = blockIdx.y;
+  const int chunkIdx = blockIdx.z;
+
+  const int numChunks = divUp<int, int>(sequenceLength, SB);
+  const int inputOffset = chunkIdx * SB;
+
+  // initialize shared memory for output gradient and input
+  __shared__ scalar_t tempGradOutput[SB + FS];
+  __shared__ scalar_t tempInput[SB + FS];
+  const int padding = FS - padding_l - 1;
+
+  zeroSharedMem<FS, SB, padding>(tempGradOutput);
+  zeroSharedMem<FS, SB, padding_l>(tempInput);
+
+  // initialize local filter and weight gradient sum arrays
+  scalar_t tempGradSum[FS];
+  scalar_t bfilter[FS];
+  for (int k = 0; k < FS; ++k) {
+    tempGradSum[k] = scalar_t(0.0);
+
+    int idxOffset = inputOffset + tid + k - padding;
+    if (idxOffset >= 0 && idxOffset < sequenceLength) {
+      int bfilterOffset = batchIdx * numHeads * FS * sequenceLength
+                          + headIdx * FS * sequenceLength
+                          + (FS - k - 1) * sequenceLength
+                          + idxOffset;
+      bfilter[k] = weight[bfilterOffset];
+    } else {
+      bfilter[k] = scalar_t(0.0);
+    }
+  }
+
+
+  // iterate over filter block
+  for (int featureIdx = 0; featureIdx < numFiltersInBlock; ++featureIdx) {
+    __syncthreads();
+
+    // load input and output gradient for this channel and chunk
+    const int IOOffset = batchIdx * numFeatures * sequenceLength
+                         + (headIdx * numFiltersInBlock + featureIdx) * sequenceLength;
+    const scalar_t* inputFeature = &input[IOOffset];
+    const scalar_t* gradOutputFeature = &gradOutput[IOOffset];
+    scalar_t*
gradInputFeature = &gradInput[IOOffset];
+
+    load_input_to_shared<FS, SB, padding>(gradOutputFeature, inputOffset,
+                                          sequenceLength, chunkIdx,
+                                          numChunks, true, tempGradOutput);
+    load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset,
+                                            sequenceLength, chunkIdx,
+                                            numChunks, true, tempInput);
+    __syncthreads();
+
+    // sum input and weight gradients
+    scalar_t out = scalar_t(0.0);
+    #pragma unroll
+    for (int k = 0; k < FS; ++k) {
+      tempGradSum[k] += tempInput[tid + k] * tempGradOutput[tid + padding];
+      out += bfilter[k] * tempGradOutput[tid + k];
+    }
+
+    if (inputOffset + tid < sequenceLength) {
+      gradInputFeature[inputOffset + tid] = out;
+    }
+  }
+
+  const int gradOffset = batchIdx * numHeads * FS * sequenceLength
+                         + headIdx * FS * sequenceLength;
+  scalar_t *gradWeightFeature = &gradWeight[gradOffset];
+
+  // write weight gradient
+  if (inputOffset + tid < sequenceLength) {
+    for (int k = 0; k < FS; ++k) {
+      const int outputOffset = k * sequenceLength + inputOffset + tid;
+      gradWeightFeature[outputOffset] = tempGradSum[k];
+    }
+  }
+}
diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py
new file mode 100644
index 000000000..d50e13c0d
--- /dev/null
+++ b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py
@@ -0,0 +1,205 @@
+import torch
+from torch import nn
+from torch.autograd import Function
+import torch.nn.functional as F
+import dynamicconv_cuda
+from fairseq import utils
+from fairseq.modules.unfold import unfold1d
+
+
+class dynamicconvFunction(Function):
+
+    @staticmethod
+    def forward(ctx, x, weights, padding_l):
+        ctx.padding_l = padding_l
+        outputs = dynamicconv_cuda.forward(x, weights, padding_l)
+        variables = [x, weights]
+        ctx.save_for_backward(*variables)
+        return outputs[0]
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        outputs = dynamicconv_cuda.backward(
+            grad_output.contiguous(),
+            ctx.padding_l,
+            *ctx.saved_variables)
+        grad_input, grad_weights = outputs
+        return grad_input, grad_weights, None
+
+
+class DynamicconvLayer(nn.Module):
+    def __init__(
+            self,
+            input_size,
+            kernel_size=1,
+            padding_l=None,
+            weight_softmax=False,
+            num_heads=1,
+            weight_dropout=0.,
+            bias=False,
+            renorm_padding=False,
+            conv_bias=False,
+            query_size=None):
+
+        super(DynamicconvLayer, self).__init__()
+        self.input_size = input_size
+        self.query_size = input_size if query_size is None else query_size
+        self.kernel_size = kernel_size
+        self.padding_l = padding_l
+        self.num_heads = num_heads
+        self.weight_softmax = weight_softmax
+        self.weight_dropout = weight_dropout
+        self.renorm_padding = renorm_padding
+        self.bias = bias
+
+        self.weight_linear = nn.Linear(input_size, num_heads * kernel_size, bias)
+        if conv_bias:
+            self.conv_bias = nn.Parameter(torch.Tensor(input_size))
+        else:
+            self.conv_bias = None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.weight_linear.weight)
+        if self.conv_bias is not None:
+            nn.init.constant_(self.conv_bias, 0.)
+            nn.init.constant_(self.weight_linear.bias, 0.)
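+    # Usage sketch (shapes are illustrative, not part of the module API;
+    # assumes the extension was built via dynamicconv_layer/setup.py):
+    #
+    #   layer = DynamicconvLayer(512, kernel_size=7, padding_l=3,
+    #                            num_heads=8, weight_softmax=True).cuda()
+    #   x = torch.randn(100, 32, 512).cuda()  # T x B x C, like the TBC modules
+    #   y = layer(x)                          # T x B x C; CUDA kernel path
+    #   # passing incremental_state instead selects the unfold/BMM decode path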
+ + def forward(self, x, incremental_state=None, query=None, unfold=None): + + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + + # during inference time, incremental BMM is faster + if incremental_state is not None: + unfold = x.size(0) > 512 if unfold is None else unfold # use unfold mode as default for long sequence to save memory + unfold = unfold or (incremental_state is not None) + assert query is None + + if query is None: + query = x + if unfold: + output = self._forward_unfolded(x, incremental_state, query) + else: + output = self._forward_expanded(x, incremental_state, query) + + if self.conv_bias is not None: + output = output + self.conv_bias.view(1, 1, -1) + + return output + + # during training time, use CUDA kernel + else: + weight = self.weight_linear(x).view(T, B, H, K) + if self.weight_softmax: + weight = F.softmax(weight, dim=-1) + if self.weight_dropout: + weight = F.dropout(weight, self.weight_dropout, training=self.training) + + weight = weight.permute(1, 2, 3, 0).contiguous() + self.filters = weight + x = x.permute(1, 2, 0).contiguous() + output = dynamicconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1) + if self.conv_bias is not None: + output = output + self.conv_bias.view(1, 1, -1) + return output + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, 'input_buffer') + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) + + def _forward_unfolded(self, x, incremental_state, query): + '''The conventional implementation of convolutions. 
+        Unfolding the input by having a window shifting to the right.'''
+        T, B, C = x.size()
+        K, H = self.kernel_size, self.num_heads
+        R = C // H
+        assert R * H == C == self.input_size
+
+        weight = self.weight_linear(query).view(T*B*H, -1)
+
+        # renorm_padding is only implemented in _forward_expanded
+        assert not self.renorm_padding or incremental_state is not None
+
+        if incremental_state is not None:
+            input_buffer = self._get_input_buffer(incremental_state)
+            if input_buffer is None:
+                input_buffer = x.new()
+            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
+            if self.kernel_size > 1:
+                self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
+            x_unfold = x_unfold.view(T*B*H, R, -1)
+        else:
+            padding_l = self.padding_l
+            if K > T and padding_l == K-1:
+                weight = weight.narrow(1, K-T, T)
+                K, padding_l = T, T-1
+            # unfold the input: T x B x C --> T' x B x C x K
+            x_unfold = unfold1d(x, K, padding_l, 0)
+            x_unfold = x_unfold.view(T*B*H, R, K)
+
+        if self.weight_softmax and not self.renorm_padding:
+            weight = F.softmax(weight, dim=1)
+        weight = weight.narrow(1, 0, K)
+
+        if incremental_state is not None:
+            weight = weight[:, -x_unfold.size(2):]
+            K = weight.size(1)
+
+        if self.weight_softmax and self.renorm_padding:
+            weight = F.softmax(weight, dim=1)
+
+        weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False)
+
+        output = torch.bmm(x_unfold, weight.unsqueeze(2))  # T*B*H x R x 1
+        output = output.view(T, B, C)
+        return output
+
+    def _forward_expanded(self, x, incremental_state, query):
+        '''Turn the convolution filters into band matrices and do matrix multiplication.
+        This is faster when the sequence is short, but less memory efficient.
+        This is not used in the decoder during inference.
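+        Concretely, the K filter taps per position are copied into a strided
+        (B*H, T, T+K-1) buffer that is narrowed to a (B*H, T, T) band matrix,
+        so a single bmm() against the (B*H, T, R) input computes every output
+        position at once, at the cost of O(T^2) memory per head.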
+ ''' + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + weight = self.weight_linear(query).view(T*B*H, -1) + + if not self.renorm_padding: + if self.weight_softmax: + weight = F.softmax(weight, dim=1) + weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False) + weight = weight.narrow(1, 0, K).contiguous() + weight = weight.view(T, B*H, K).transpose(0, 1) + + x = x.view(T, B*H, R).transpose(0, 1) + if self.weight_softmax and self.renorm_padding: + # turn the convolution filters into band matrices + weight_expanded = weight.new(B*H, T, T+K-1).fill_(float('-inf')) + weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) + weight_expanded = weight_expanded.narrow(2, self.padding_l, T) + # normalize the weight over valid positions like self-attention + weight_expanded = F.softmax(weight_expanded, dim=2) + weight_expanded = F.dropout(weight_expanded, self.weight_dropout, training=self.training, inplace=False) + else: + P = self.padding_l + # For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length + if K > T and P == K-1: + weight = weight.narrow(2, K-T, T) + K, P = T, T-1 + # turn the convolution filters into band matrices + weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False) + weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) + weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T + output = torch.bmm(weight_expanded, x) + output = output.transpose(0, 1).contiguous().view(T, B, C) + return output diff --git a/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp b/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp new file mode 100644 index 000000000..8a6af4285 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp @@ -0,0 +1,35 @@ +#include +#include + +std::vector dynamicconv_cpu_forward( + float* input, + float* filters, + int padding_l); + +std::vector dynamicconv_cpu_backward( + float* gradOutput, + int padding_l, + float* input, + float* filters); + +std::vector dynamicconv_forward( + float* input, + float* filters, + int padding_l) { + + return dynamicconv_cpu_forward(input, filters, padding_l); +} + +std::vector dynamicconv_backward( + float* gradOutput, + int padding_l, + float* input, + float* filters) { + + return dynamicconv_cpu_backward(gradOutput, padding_l, input, filters); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)"); + m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)"); +} diff --git a/fairseq/modules/dynamicconv_layer/setup.py b/fairseq/modules/dynamicconv_layer/setup.py new file mode 100644 index 000000000..00ce29bc7 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup +from torch.utils.cpp_extension import CUDAExtension, BuildExtension + +setup( + name='dynamicconv_layer', + ext_modules=[ + CUDAExtension( + name='dynamicconv_cuda', + sources=[ + 'dynamicconv_cuda.cpp', + 'dynamicconv_cuda_kernel.cu', + ], + ), + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/fairseq/modules/lightconv_layer/__init__.py b/fairseq/modules/lightconv_layer/__init__.py new file mode 100644 index 000000000..95fe76c7c --- /dev/null +++ b/fairseq/modules/lightconv_layer/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. 
+# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from .lightconv_layer import LightconvLayer diff --git a/fairseq/modules/lightconv_layer/cuda_function_gen.py b/fairseq/modules/lightconv_layer/cuda_function_gen.py new file mode 100644 index 000000000..1bb3a1a0d --- /dev/null +++ b/fairseq/modules/lightconv_layer/cuda_function_gen.py @@ -0,0 +1,289 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + + +def gen_forward(): + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]] + + head = """ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + */ + +#include "lightconv_cuda.cuh" + +std::vector lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l) { + + at::DeviceGuard g(input.device()); + const auto minibatch = input.size(0); + const auto numFeatures = input.size(1); + const auto sequenceLength = input.size(2); + + const auto numHeads = filters.size(0); + const auto filterSize = filters.size(1); + + const auto numFiltersInBlock = numFeatures / numHeads; + + const dim3 blocks(minibatch, numFeatures); + + auto output = at::zeros_like(input); + auto stream = at::cuda::getCurrentCUDAStream(); +""" + + sequence_if = """ + if (sequenceLength <= {seq}) {{ + switch(filterSize) {{ +""" + + case_k = """ + case {k}: +""" + + main_block = """ + if (padding_l == {pad}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_forward", ([&] {{ + lightconv_forward_kernel<{k}, {b_size}, {pad}, scalar_t> + <<>>( + input.data(), + filters.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + output.data()); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping forward pass" << std::endl; + } + break; +""" + + bad_filter = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping forward pass" << std::endl; + } +""" + + con_else = """ + } else +""" + + final_else = """ + { + switch(filterSize) { +""" + + final_return = """ + } + + return {output}; +} +""" + + with open("lightconv_cuda_forward.cu", 'w') as forward: + forward.write(head) + for seq in seqs: + forward.write(sequence_if.format(seq=seq)) + for k in kernels: + forward.write(case_k.format(k=k)) + for pad in [k // 2, k - 1]: + forward.write(main_block.format(k=k, b_size=seq, pad=pad)) + forward.write(bad_padding) + forward.write(bad_filter) + forward.write(con_else) + + forward.write(final_else) + for k in kernels: + forward.write(case_k.format(k=k)) + for pad in [k // 2, k - 1]: + forward.write(main_block.format(k=k, b_size=seq, pad=pad)) + forward.write(bad_padding) + forward.write(bad_filter) + forward.write(final_return) + + +def gen_backward(): + + head = """ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. 
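+ *
+ * NOTE: emitted into lightconv_cuda_backward.cu by cuda_function_gen.py.
+ * The filter gradient is computed in two passes: a first pass accumulates
+ * per-(batch, head) partial sums into a float staging tensor, and a second
+ * pass block-reduces that tensor over the minibatch into gradFilters. Short
+ * sequences use the _short kernel variants, whose block size covers the
+ * whole sequence.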
+ * + */ + +#include "lightconv_cuda.cuh" + +std::vector lightconv_cuda_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters) { + + // gradWrtInput + const int minibatch = input.size(0); + const int numFeatures = input.size(1); + const int sequenceLength = input.size(2); + + const int numHeads = filters.size(0); + const int filterSize = filters.size(1); + + const dim3 gradBlocks(minibatch, numFeatures); + const dim3 weightGradFirstpassShortBlocks(minibatch, numHeads); + const dim3 weightGradSecondpassBlocks(numHeads, filterSize); + + const int numFiltersInBlock = numFeatures / numHeads; + + auto gradInput = at::zeros_like(input); + auto gradFilters = at::zeros_like(filters); + + at::DeviceGuard g(input.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + + switch(filterSize) { +""" + + sequence_if = """ + if (sequenceLength <= {seq}) {{ +""" + + case_k = """ + case {k}: +""" + + main_block = """ + if (padding_l == {p}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_backward", ([&] {{ + lightconv_grad_wrt_input_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + gradOutput.data(), + filters.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + gradInput.data()); + +""" + + weight_grad_short = """ + at::Tensor tempSumGradFilters = at::zeros({{minibatch, numHeads, filterSize}}, input.options().dtype(at::kFloat)); + lightconv_grad_wrt_weights_firstpass_short_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + input.data(), + gradOutput.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + numHeads, + tempSumGradFilters.data() + ); + + lightconv_grad_wrt_weights_secondpass_short_kernel<{k}, {b_size}, scalar_t> + <<>>( + tempSumGradFilters.data(), + minibatch, + numFiltersInBlock, + gradFilters.data() + ); + }})); + }} else +""" + + weight_grad = """ + at::Tensor tempSumGradFilters = at::zeros({{minibatch, numFeatures, filterSize}}, input.options().dtype(at::kFloat)); + lightconv_grad_wrt_weights_firstpass_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + input.data(), + gradOutput.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + tempSumGradFilters.data() + ); + + lightconv_grad_wrt_weights_secondpass_kernel<{k}, {b_size}, scalar_t> + <<>>( + tempSumGradFilters.data(), + minibatch, + numFiltersInBlock, + gradFilters.data() + ); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping backward pass" << std::endl; + } +""" + + breakout = """ + break; +""" + + bad_filter = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping backward pass" << std::endl; +""" + + con_else = """ + } else +""" + + final_else = """ + { + switch(filterSize) { +""" + + last_return = """ + } + return {gradInput, gradFilters}; +} +""" + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]] + thresh = [32, 32, 64, 128, 256, -1, -1, -1] + max_mem = [-1, -1, -1, -1, -1, 192, 96, 64] + + with open("lightconv_cuda_backward.cu", 'w') as backward: + backward.write(head) + for (k, t, mem) in zip(kernels, thresh, max_mem): + backward.write(case_k.format(k=k)) + for seq in seqs: + if (t == -1 or seq <= t) and (mem == -1 or seq < mem): + backward.write(sequence_if.format(seq=seq)) + for p in [k // 2, k - 1]: + backward.write(main_block.format(k=k, b_size=seq, p=p)) + backward.write(weight_grad_short.format(k=k, b_size=seq, p=p)) + 
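+                    # both supported paddings (k // 2 for symmetric filters,
+                    # k - 1 for causal ones) have been emitted; close the
+                    # chain with a warning branch for any other padding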
backward.write(bad_padding)
+                else:
+                    for p in [k // 2, k - 1]:
+                        backward.write(main_block.format(k=k, b_size=32, p=p))
+                        backward.write(weight_grad.format(k=k, b_size=32, p=p))
+                    backward.write(bad_padding)
+                    backward.write(breakout)
+                    break
+                backward.write(con_else)
+        backward.write(bad_filter)
+        backward.write(last_return)
+
+
+if __name__ == "__main__":
+    gen_forward()
+    gen_backward()
diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda.cpp b/fairseq/modules/lightconv_layer/lightconv_cuda.cpp
new file mode 100644
index 000000000..3dc1765bf
--- /dev/null
+++ b/fairseq/modules/lightconv_layer/lightconv_cuda.cpp
@@ -0,0 +1,47 @@
+#include <torch/extension.h>
+#include <vector>
+
+std::vector<at::Tensor> lightconv_cuda_forward(
+    at::Tensor input,
+    at::Tensor filters,
+    int padding_l);
+
+std::vector<at::Tensor> lightconv_cuda_backward(
+    at::Tensor gradOutput,
+    int padding_l,
+    at::Tensor input,
+    at::Tensor filters);
+
+
+#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+std::vector<at::Tensor> lightconv_forward(
+    at::Tensor input,
+    at::Tensor filters,
+    int padding_l) {
+
+  CHECK_INPUT(input);
+  CHECK_INPUT(filters);
+
+  return lightconv_cuda_forward(input, filters, padding_l);
+}
+
+std::vector<at::Tensor> lightconv_backward(
+    at::Tensor gradOutput,
+    int padding_l,
+    at::Tensor input,
+    at::Tensor filters) {
+
+  CHECK_INPUT(gradOutput);
+  CHECK_INPUT(input);
+  CHECK_INPUT(filters);
+
+  return lightconv_cuda_backward(gradOutput, padding_l, input, filters);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &lightconv_forward, "lightconv forward (CUDA)");
+  m.def("backward", &lightconv_backward, "lightconv backward (CUDA)");
+}
diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda.cuh b/fairseq/modules/lightconv_layer/lightconv_cuda.cuh
new file mode 100644
index 000000000..f4c5fec43
--- /dev/null
+++ b/fairseq/modules/lightconv_layer/lightconv_cuda.cuh
@@ -0,0 +1,82 @@
+/**
+ * Copyright (c) 2018-present, Facebook, Inc.
+ * All rights reserved.
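+ *
+ * Declares the kernels used by the lightconv extension: the forward kernel,
+ * the input-gradient kernel (a correlation with the reversed filter), and
+ * the two-pass weight-gradient kernels in both short-sequence and general
+ * variants. All are specialized on filter size FS and block size SB.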
+ * + */ + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#define SHFL_MASK 0xffffffff + +template +__global__ +void lightconv_forward_kernel(const scalar_t* input, + const scalar_t* filters, + int minibatch, int sequenceLength, + int numFeatures, int numFiltersInBlock, + scalar_t* output); + +template +__global__ +void lightconv_grad_wrt_input_kernel( + const scalar_t* input, + const scalar_t* filters, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + scalar_t* output); + +template +__global__ +void lightconv_grad_wrt_weights_firstpass_short_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + float* output); + +template +__global__ +void lightconv_grad_wrt_weights_secondpass_short_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output); + +template +__global__ +void lightconv_grad_wrt_weights_firstpass_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + float* output); + +template +__global__ +void lightconv_grad_wrt_weights_secondpass_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output); + diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu b/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu new file mode 100644 index 000000000..8e17e27af --- /dev/null +++ b/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu @@ -0,0 +1,374 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * All rights reserved. + * + */ + +#include "lightconv_cuda.cuh" +#include "lightconv_cuda_forward.cu" +#include "lightconv_cuda_backward.cu" +#include "../cuda_utils.cu" + +template +__global__ +void lightconv_forward_kernel(const scalar_t* input, + const scalar_t* filters, + int minibatch, int sequenceLength, + int numFeatures, int numFiltersInBlock, + scalar_t* output) { + + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int featureIdx = blockIdx.y; + const int filterIdx = featureIdx / numFiltersInBlock; + + const int IOOffset = numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + scalar_t* outputFeature = &output[IOOffset]; + const scalar_t* inputFilter = &filters[filterIdx * FS]; + + assert(blockDim.x == SB); + + scalar_t filter[FS]; + #pragma unroll + for (int i = 0; i < FS; ++i) { + filter[i] = inputFilter[i]; + } + + __shared__ scalar_t temp[SB + FS]; + zeroSharedMem(temp); + + const int numIterations = divUp(sequenceLength, SB); + + for (int i = 0; i < numIterations; ++i) { + // Read input into shared memory + const int inputOffset = i * SB; + + load_input_to_shared(inputFeature, inputOffset, sequenceLength, + i, numIterations, (numIterations == 1), temp); + + __syncthreads(); + + scalar_t out = 0; + #pragma unroll + for (int j = 0; j < FS; ++j) { + out += filter[j] * temp[tid + j]; + } + + // Write output + const int outputOffset = inputOffset; + if ((outputOffset + tid) < sequenceLength) { + outputFeature[outputOffset + tid] = out; + } + + __syncthreads(); + } +} + +template +__global__ +void lightconv_grad_wrt_input_kernel( + const scalar_t* input, + const scalar_t* filters, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + 
scalar_t* output) { + + // input grad kernel is similar to forward kernel + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int featureIdx = blockIdx.y; + const int filterIdx = featureIdx / numFiltersInBlock; + + const int IOOffset = numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + scalar_t* outputFeature = &output[IOOffset]; + const scalar_t* inputFilter = &filters[filterIdx * FS]; + + assert(blockDim.x == SB); + + scalar_t filter[FS]; + + // The only change is loading the filter in reverse + #pragma unroll + for (int i = 0; i < FS; ++i) { + filter[i] = inputFilter[FS - i - 1]; + } + + __shared__ scalar_t temp[SB + FS]; + const int padding = FS - padding_l - 1; + zeroSharedMem(temp); + + __syncthreads(); + + const int numIterations = divUp(sequenceLength, SB); + + for (int i = 0; i < numIterations; ++i) { + // Read input into shared memory + const int inputOffset = i * SB; + + load_input_to_shared(inputFeature, inputOffset, sequenceLength, + i, numIterations, false, temp); + + __syncthreads(); + + scalar_t out = 0; + #pragma unroll + for (int j = 0; j < FS; ++j) { + out += filter[j] * temp[tid + j]; + } + + // Write output + const int outputOffset = inputOffset; + if ((outputOffset + tid) < sequenceLength) { + outputFeature[outputOffset + tid] = out; + } + + __syncthreads(); + } +} + +// This is by far the most expensive kernel in terms of time taken. +// Can be 16x slower than the forward or grad_wrt_input when filter size is 31 +template +__global__ +void lightconv_grad_wrt_weights_firstpass_short_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + float* output) { + + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int filterIdx = blockIdx.y; + + const int numIterations = divUp(sequenceLength, SB); + + float* tempOutputGradWeight = &output[filterIdx * FS * minibatch]; + + assert(blockDim.x == SB); + + __shared__ scalar_t tempInput[SB + FS]; + __shared__ scalar_t tempGradInput[SB + FS]; + + // local weight accumulation + float accumWeights[FS]; + + // Initialize memory + for (int i = 0; i < FS; ++i) { + accumWeights[i] = float(0.0); + } + + + // loop over each sequence within filterblock + for (int idxInFilterBlock = 0; idxInFilterBlock < numFiltersInBlock; ++idxInFilterBlock) { + + const int featureOffset = batchIdx * numFeatures * sequenceLength + (filterIdx * numFiltersInBlock + idxInFilterBlock) * sequenceLength; + const scalar_t* inputFeature = &input[featureOffset]; + const scalar_t* gradInputFeature = &gradInput[featureOffset]; + + zeroSharedMem(tempInput); + zeroSharedMem(tempGradInput); + __syncthreads(); + + for (int i = 0; i < numIterations; ++i) { + + const int inputOffset = i * SB; + + load_input_to_shared(inputFeature, inputOffset, sequenceLength, + i, numIterations, false, tempInput); + load_input_to_shared(gradInputFeature, inputOffset, sequenceLength, + i, numIterations, false, tempGradInput); + + __syncthreads(); + + const int gradIndex = (FS/2) + tid; + scalar_t tempGrad = tempGradInput[gradIndex]; + + #pragma unroll + for (int j = 0; j < FS; j++) { + const int inputIndex = tid + j; + accumWeights[j] += tempInput[inputIndex] * tempGrad; + } + + __syncthreads(); + + } + + } + + // Row-major sum + for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) { + + float temp; + if (tid < sequenceLength) { + temp = 
accumWeights[filterWeightIdx]; + } else { + temp = float(0.0); + } + + const int outputOffset = filterWeightIdx * minibatch + batchIdx; + + temp = blockReduce(temp); + + if (tid == 0) { + tempOutputGradWeight[outputOffset] = temp; + } + } +} + +template +__global__ +void lightconv_grad_wrt_weights_secondpass_short_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output) { + + assert(blockDim.x == SB); + + const int tid = threadIdx.x; + + const int filterIdx = blockIdx.x; + const int filterWeightIdx = blockIdx.y; + + const int inputOffset = filterIdx * FS * minibatch + + filterWeightIdx * minibatch; + const float* tempInput = &input[inputOffset]; + + // read into shared memory for reduction + int readIndex = tid; + + float sum = 0.0; + while (readIndex < minibatch) { + sum += tempInput[readIndex]; + readIndex += SB; + } + + float temp = blockReduce(sum); + + if (tid == 0) { + output[blockIdx.x * FS + blockIdx.y] = temp; + } +} + +// This is by far the most expensive kernel in terms of time taken. +// Can be 16x slower than the forward or grad_wrt_input when filter size is 31 +template +__global__ +void lightconv_grad_wrt_weights_firstpass_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + float* output) { + + assert(blockDim.x == SB); + + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int featureIdx = blockIdx.y; + const int filterIdx = featureIdx / numFiltersInBlock; + const int idxInFilterBlock = featureIdx % numFiltersInBlock; + + const int numIterations = divUp(sequenceLength, SB); + + float temp; + + __shared__ scalar_t tempInput[SB + FS]; + __shared__ scalar_t tempGradInput[SB + FS]; + zeroSharedMem(tempInput); + zeroSharedMem(tempGradInput); + __syncthreads(); + + float accumWeights[FS]; + + for (int i = 0; i < FS; ++i) { + accumWeights[i] = float(0.0); + } + + const int IOOffset = batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + const scalar_t* gradInputFeature = &gradInput[IOOffset]; + float* tempOutputGradWeight = &output[filterIdx * FS * minibatch * numFiltersInBlock]; + + for (int i = 0; i < numIterations; ++i) { + const int inputOffset = i * SB; + + load_input_to_shared(inputFeature, inputOffset, sequenceLength, + i, numIterations, false, tempInput); + load_input_to_shared(gradInputFeature, inputOffset, sequenceLength, + i, numIterations, false, tempGradInput); + __syncthreads(); + + #pragma unroll + for (int j = 0; j < FS; ++j) { + accumWeights[j] += tempInput[tid + j] * tempGradInput[tid + (FS/2)]; + } + + __syncthreads(); + } + + // Row-major sum + for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) { + + // Write to shared memory before reduction + if (tid < sequenceLength) { + temp = accumWeights[filterWeightIdx]; + } else { + temp = float(0.0); + } + + temp = blockReduce(temp); + + const int outputOffset = filterWeightIdx * minibatch * numFiltersInBlock + + batchIdx * numFiltersInBlock + + idxInFilterBlock; + + if (tid == 0) { + tempOutputGradWeight[outputOffset] = temp; + } + } +} + +template +__global__ +void lightconv_grad_wrt_weights_secondpass_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output) { + + assert(blockDim.x == SB); + const int tid = threadIdx.x; + + // What is the id within a minibatch + const int filterIdx = blockIdx.x; + const int filterWeightIdx 
+template<int FS, int SB, typename scalar_t>
+__global__
+void lightconv_grad_wrt_weights_secondpass_kernel(
+    const float* input,
+    const int minibatch,
+    const int numFiltersInBlock,
+    scalar_t* output) {
+
+  assert(blockDim.x == SB);
+  const int tid = threadIdx.x;
+
+  // each block reduces one (filter, filter weight) pair
+  const int filterIdx = blockIdx.x;
+  const int filterWeightIdx = blockIdx.y;
+
+  const int inputOffset = filterIdx * FS * minibatch * numFiltersInBlock +
+                          filterWeightIdx * minibatch * numFiltersInBlock;
+  const float* tempInput = &input[inputOffset];
+
+  int readIndex = tid;
+
+  float sum = float(0.0);
+  while (readIndex < (minibatch * numFiltersInBlock)) {
+    sum += tempInput[readIndex];
+    readIndex += SB;
+  }
+
+  float temp = blockReduce<SB>(sum);
+
+  if (tid == 0) {
+    output[blockIdx.x * FS + blockIdx.y] = temp;
+  }
+}
diff --git a/fairseq/modules/lightconv_layer/lightconv_layer.py b/fairseq/modules/lightconv_layer/lightconv_layer.py
new file mode 100644
index 000000000..872812827
--- /dev/null
+++ b/fairseq/modules/lightconv_layer/lightconv_layer.py
@@ -0,0 +1,113 @@
+import torch
+from torch import nn
+from torch.autograd import Function
+import torch.nn.functional as F
+
+import lightconv_cuda
+from fairseq import utils
+
+
+class lightconvFunction(Function):
+
+    @staticmethod
+    def forward(ctx, x, weights, padding_l):
+        ctx.padding_l = padding_l
+        outputs = lightconv_cuda.forward(x, weights, padding_l)
+        variables = [x, weights]
+        ctx.save_for_backward(*variables)
+        return outputs[0]
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        outputs = lightconv_cuda.backward(
+            grad_output.contiguous(),
+            ctx.padding_l,
+            *ctx.saved_variables)
+        grad_input, grad_weights = outputs
+        # padding_l is a non-tensor argument to forward(), so it receives no gradient
+        return grad_input, grad_weights, None
+
+
+class LightconvLayer(nn.Module):
+    def __init__(
+            self,
+            input_size,
+            kernel_size=1,
+            padding_l=None,
+            weight_softmax=False,
+            num_heads=1,
+            weight_dropout=0.,
+            bias=False):
+        super(LightconvLayer, self).__init__()
+        self.input_size = input_size
+        self.kernel_size = kernel_size
+        self.padding_l = padding_l
+        self.num_heads = num_heads
+        self.weight_softmax = weight_softmax
+        self.weight_dropout = weight_dropout
+
+        self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(input_size))
+        else:
+            self.bias = None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.weight)
+        if self.bias is not None:
+            nn.init.constant_(self.bias, 0.)
+
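+    # x is handled in T x B x C layout throughout. The incremental path in
+    # forward() unfolds the most recent kernel_size steps into a buffer and
+    # applies the (optionally softmax-normalized) filters as a batched matrix
+    # multiply of shape (T*B*H, R, K) x (T*B*H, K, 1); the full-sequence path
+    # hands the whole batch to the fused lightconv_cuda kernel instead.
+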
+    def forward(self, x, incremental_state=None):
+
+        # during inference time, incremental BMM is faster
+        if incremental_state is not None:
+            T, B, C = x.size()
+            K, H = self.kernel_size, self.num_heads
+            R = C // H
+            input_buffer = self._get_input_buffer(incremental_state)
+            if input_buffer is None:
+                input_buffer = x.new()
+            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
+            if self.kernel_size > 1:
+                self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
+            x_unfold = x_unfold.view(T*B*H, R, -1)
+
+            weight = self.weight
+            if self.weight_softmax:
+                weight = F.softmax(weight.float(), dim=1).type_as(weight)
+
+            # keep only the taps that overlap the buffered time steps
+            weight = weight[:, -x_unfold.size(2):]
+
+            K = weight.size(1)
+
+            weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1)
+
+            weight = F.dropout(weight, self.weight_dropout, training=self.training)
+            output = torch.bmm(x_unfold, weight)  # T*B*H x R x 1
+            output = output.view(T, B, C)
+            return output
+
+        # during training time, use CUDA kernel
+        else:
+            x = x.permute(1, 2, 0).contiguous()
+            weight = self.weight
+            if self.weight_softmax:
+                weight = F.softmax(self.weight, -1)
+            if self.weight_dropout:
+                weight = F.dropout(weight, self.weight_dropout, training=self.training)
+            return lightconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1)
+
+    def reorder_incremental_state(self, incremental_state, new_order):
+        input_buffer = self._get_input_buffer(incremental_state)
+        if input_buffer is not None:
+            input_buffer = input_buffer.index_select(1, new_order)
+            self._set_input_buffer(incremental_state, input_buffer)
+
+    def _get_input_buffer(self, incremental_state):
+        return utils.get_incremental_state(self, incremental_state, 'input_buffer')
+
+    def _set_input_buffer(self, incremental_state, new_buffer):
+        return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
+
+    def half(self):
+        return self._apply(lambda t: t.half() if t.is_floating_point() else t)
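+
+# Minimal usage sketch (illustrative only; sizes are arbitrary):
+#   conv = LightconvLayer(512, kernel_size=31, padding_l=30,
+#                         num_heads=8, weight_softmax=True).cuda()
+#   x = torch.randn(100, 8, 512).cuda()  # T x B x C
+#   y = conv(x)                          # output keeps the T x B x C shape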
diff --git a/fairseq/modules/lightconv_layer/setup.py b/fairseq/modules/lightconv_layer/setup.py
new file mode 100644
index 000000000..c2a928ed8
--- /dev/null
+++ b/fairseq/modules/lightconv_layer/setup.py
@@ -0,0 +1,14 @@
+from setuptools import setup
+from torch.utils.cpp_extension import CUDAExtension, BuildExtension
+
+setup(
+    name='lightconv_layer',
+    ext_modules=[
+        CUDAExtension('lightconv_cuda', [
+            'lightconv_cuda.cpp',
+            'lightconv_cuda_kernel.cu',
+        ]),
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/fairseq/modules/lightweight_convolution.py b/fairseq/modules/lightweight_convolution.py
index 6191d4950..95d0418af 100644
--- a/fairseq/modules/lightweight_convolution.py
+++ b/fairseq/modules/lightweight_convolution.py
@@ -10,6 +10,21 @@ import torch.nn.functional as F
 
 from fairseq import utils
 from fairseq.modules.unfold import unfold1d
 
+def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
+                    weight_dropout=0., weight_softmax=False, bias=False):
+    # prefer the fused CUDA kernel; fall back to the pure-PyTorch TBC
+    # implementation when the extension has not been built
+    if torch.cuda.is_available():
+        try:
+            from fairseq.modules.lightconv_layer import LightconvLayer
+            return LightconvLayer(input_size, kernel_size=kernel_size,
+                                  padding_l=padding_l, num_heads=num_heads,
+                                  weight_dropout=weight_dropout,
+                                  weight_softmax=weight_softmax, bias=bias)
+        except ImportError as e:
+            print(e)
+    return LightweightConv1dTBC(input_size, kernel_size=kernel_size,
+                                padding_l=padding_l, num_heads=num_heads,
+                                weight_dropout=weight_dropout,
+                                weight_softmax=weight_softmax, bias=bias)
 
 class LightweightConv1d(nn.Module):
     '''Lightweight Convolution assuming the input is BxCxT
diff --git a/fairseq/modules/unfold.py b/fairseq/modules/unfold.py
index 3a142db69..eff6ab575 100644
--- a/fairseq/modules/unfold.py
+++ b/fairseq/modules/unfold.py
@@ -5,7 +5,6 @@
 
 import torch.nn.functional as F
 
-
 def unfold1d(x, kernel_size, padding_l, pad_value=0):
     '''unfold T x B x C to T x B x C x K'''
     if kernel_size > 1:
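
As a quick end-to-end check, the LightweightConv factory added above can be
exercised on either code path. The sketch below is illustrative only: the
sizes are arbitrary, and the CUDA path additionally assumes the
lightconv_cuda extension has been built with the setup.py introduced in this
patch; otherwise the factory falls back to LightweightConv1dTBC.

    import torch
    from fairseq.modules.lightweight_convolution import LightweightConv

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    conv = LightweightConv(512, kernel_size=3, padding_l=2,
                           num_heads=8, weight_softmax=True).to(device)
    x = torch.randn(20, 2, 512, device=device)  # T x B x C
    y = conv(x)
    assert y.shape == x.shape  # the layer preserves the input shape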