Merged PR 10553: Fix multiple problems in reduce kernels that occurred during back-prop

This fixes a number of bugs in our GPU reduce kernels that would manifest mainly for larger matrices and during back-prop. We also drop support for CUDA 8.0 in order to take advantage of new GPU primitives introduced by NVidia in CUDA 9.0.
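As context, the CUDA 9.0 primitives in question are cooperative groups and warp-level shuffle reductions, which the rewritten kernel uses in place of the old volatile shared-memory warp reduction. A minimal stand-alone sketch (illustrative only, not part of this commit; kernel and buffer names are made up):

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Sum-reduce the values held by one 32-thread tile using register shuffles;
// no volatile shared memory or __syncthreads() is needed within the warp.
__device__ float warpReduceSum(cg::thread_block_tile<32> tile, float val) {
  for (int offset = tile.size() / 2; offset > 0; offset /= 2)
    val += tile.shfl_down(val, offset); // add the value held by the lane `offset` positions away
  return val;                           // lane 0 of the tile now holds the warp-wide sum
}

__global__ void sumAll(const float* in, float* out, int n) {
  cg::thread_block cta = cg::this_thread_block();
  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  float v = (i < n) ? in[i] : 0.f;
  v = warpReduceSum(tile, v);
  if (tile.thread_rank() == 0)
    atomicAdd(out, v); // combine the per-warp partial sums into the single output value
}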
This commit is contained in:
Martin Junczys-Dowmunt 2019-11-26 01:48:07 +00:00
parent b19820c8ba
commit 93b7ed80fe
12 changed files with 581 additions and 482 deletions

View File

@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Gradient-checkpointing
### Fixed
- Fixed multiple reduction kernels on GPU
- Replace IntrusivePtr with std::unique_ptr in FastOpt, fixes random segfaults
due to non-thread-safe reference counting.
- Make sure that items are 256-byte aligned during saving
@ -38,6 +39,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Compilation with CUDA 10.1
### Changed
- Dropped support for CUDA 8.0, CUDA 9.0 is now the minimal requirement
- Removed autotuner for now, will be switched back on later
- Boost dependency is now optional and only required for marian_server
- Dropped support for g++-4.9

View File

@ -191,7 +191,7 @@ if(USE_STATIC_LIBS)
endif()
endif()
find_package(CUDA "8.0") # TODO: only enable FP16-related options for compute_70 and higher.
find_package(CUDA "9.0") # TODO: only enable FP16-related options for compute_70 and higher.
if(CUDA_FOUND)
# CUDA >= 10.0 requires CMake >= 3.12.2
if((CUDA_VERSION VERSION_EQUAL "10.0" OR CUDA_VERSION VERSION_GREATER "10.0") AND (CMAKE_VERSION VERSION_LESS "3.12.2"))

View File

@ -1 +1 @@
v1.8.21
v1.8.22

View File

@ -1,347 +1,248 @@
// This software contains source code provided by NVIDIA Corporation.
/*
* Copyright 1993-2015 NVIDIA Corporation. All rights reserved.
/* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Please refer to the NVIDIA end user license agreement (EULA) associated
* with this source code for terms and conditions that govern your use of
* this software. Any use, reproduction, disclosure, or distribution of
* this software and related documentation outside the terms of the EULA
* is strictly prohibited.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of NVIDIA CORPORATION nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*
MJD: Relevant text from the NVIDIA EULA:
2.1 Sample Source Code Modification, Ownership and Distribution
Subject to the terms of the SLA and this Supplement, NVIDIA hereby grants you a non-
exclusive, non-transferable license, without the right to sublicense, during the applicable
license term unless earlier terminated pursuant to the SLA, to have Authorized Users
modify and create derivative works of CUDA Licensed Software that constitutes sample
source code, when provided to you by NVIDIA in source code form. You hold all rights,
title and interest in and to your modifications and derivative works of the sample source
code software that you create as permitted hereunder (collectively, "Derivatives"), subject
to NVIDIA's underlying Intellectual Property Rights in and to the CUDA Licensed
Software; provided, however that you grant NVIDIA and its Affiliates an irrevocable,
perpetual, nonexclusive, worldwide, royalty-free paid-up license to make, have made,
use, have used, reproduce, license, distribute, sublicense, transfer and otherwise
commercialize Derivatives including (without limitation) with the CUDA Licensed
Software or other NVIDIA products, technologies or materials. You may distribute the
sample source code as delivered by NVIDIA and/or your Derivatives,
provided that all NVIDIA copyright notices and trademarks are maintained and used properly
and the sample source code includes the following notice: This software contains source code
provided by NVIDIA Corporation.
*/
#pragma once
#include "tensors/tensor.h"
#include <cuda_runtime.h>
#include "functional/tmp.h"
#include <cooperative_groups.h>
namespace marian {
template <unsigned int blockSize, typename AccType>
__device__ void
reduceBlock(volatile AccType *sdata, AccType mySum, const unsigned int tid)
{
sdata[tid] = mySum;
__syncthreads();
namespace cg = cooperative_groups;
// do reduction in shared mem
if (blockSize >= 512)
{
if (tid < 256)
{
sdata[tid] = mySum = mySum + sdata[tid + 256];
}
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
template <class T>
struct SharedMemory {
__device__ inline operator T *() {
extern __shared__ int __smem[];
return (T *)__smem;
}
__syncthreads();
__device__ inline operator const T *() const {
extern __shared__ int __smem[];
return (T *)__smem;
}
};
// specialize for double to avoid unaligned memory
// access compile errors
template <>
struct SharedMemory<double> {
__device__ inline operator double *() {
extern __shared__ double __smem_d[];
return (double *)__smem_d;
}
__device__ inline operator const double *() const {
extern __shared__ double __smem_d[];
return (double *)__smem_d;
}
};
/*
This version adds multiple elements per thread sequentially. This reduces
the overall cost of the algorithm while keeping the work complexity O(n) and
the step complexity O(log n). (Brent's Theorem optimization)
Note, this kernel needs a minimum of 64*sizeof(T) bytes of shared memory.
In other words if blockSize <= 32, allocate 64*sizeof(T) bytes.
If blockSize > 32, allocate blockSize*sizeof(T) bytes.
*/
template <typename T, typename AccType, unsigned int blockSize, bool nIsPow2Greater1, size_t K, class Functor, class AggFunctor>
__global__ void reduceSinglePass(Functor functor, AccType aggInit, AggFunctor aggFunctor, AccType scale,
const functional::Shape full,
functional::Tensor<AccType> out,
functional::Array<functional::Tensor<T>, K> ins) {
int n = full.elements();
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
AccType *sdata = SharedMemory<AccType>();
// perform first level of reduction,
// reading from global memory, writing to shared memory
unsigned int tid = threadIdx.x;
unsigned int i = blockIdx.x * blockSize * 2 + threadIdx.x;
unsigned int gridSize = blockSize * 2 * gridDim.x;
AccType mySum = aggInit;
// we reduceSinglePass multiple elements per thread. The number is determined by the
// number of active thread blocks (via gridDim). More blocks will result
// in a larger gridSize and therefore fewer elements per thread
while (i < n) {
mySum = aggFunctor(mySum, functional::applyWithCast<AccType>(functor, ins, i));
// ensure we don't read out of bounds -- this is optimized away for powerOf2
// sized arrays
if (nIsPow2Greater1 || i + blockSize < n)
mySum = aggFunctor(mySum, functional::applyWithCast<AccType>(functor, ins, i + blockSize));
i += gridSize;
}
// each thread puts its local sum into shared memory
sdata[tid] = mySum;
cg::sync(cta);
// do reduction in shared mem
if ((blockSize >= 512) && (tid < 256)) {
sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 256]);
}
cg::sync(cta);
if ((blockSize >= 256) && (tid < 128)) {
sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 128]);
}
cg::sync(cta);
if ((blockSize >= 128) && (tid < 64)) {
sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 64]);
}
cg::sync(cta);
cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);
if (cta.thread_rank() < 32) {
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64)
mySum = aggFunctor(mySum, sdata[tid + 32]);
// reduce final warp using shuffle
for (int offset = tile32.size() / 2; offset > 0; offset /= 2) {
mySum = aggFunctor(mySum, tile32.shfl_down(mySum, offset));
}
}
if (blockSize >= 256)
{
if (tid < 128)
{
sdata[tid] = mySum = mySum + sdata[tid + 128];
}
__syncthreads();
}
if (blockSize >= 128)
{
if (tid < 64)
{
sdata[tid] = mySum = mySum + sdata[tid + 64];
}
__syncthreads();
}
if (tid < 32)
{
if (blockSize >= 64)
{
sdata[tid] = mySum = mySum + sdata[tid + 32];
}
if (blockSize >= 32)
{
sdata[tid] = mySum = mySum + sdata[tid + 16];
}
if (blockSize >= 16)
{
sdata[tid] = mySum = mySum + sdata[tid + 8];
}
if (blockSize >= 8)
{
sdata[tid] = mySum = mySum + sdata[tid + 4];
}
if (blockSize >= 4)
{
sdata[tid] = mySum = mySum + sdata[tid + 2];
}
if (blockSize >= 2)
{
sdata[tid] = mySum = mySum + sdata[tid + 1];
}
}
// write result for this block to global mem
if (cta.thread_rank() == 0)
out[blockIdx.x] = aggFunctor(out[blockIdx.x], mySum * scale); // aggFunctor?
}
template <unsigned int blockSize, bool nIsPow2, typename T, typename AccType, class Functor>
__device__ void
reduceBlocks(Functor f, T *g_idata, AccType *g_odata, unsigned int n)
{
extern __shared__ AccType sdata[];
static inline bool isPow2Greater1(unsigned int x) { // is x a power of two and also larger than 1? Otherwise an out-of-bounds read would occur in the power-of-2 code path
return x > 1 && ((x & (x - 1)) == 0);
}
// perform first level of reduction,
// reading from global memory, writing to shared memory
unsigned int tid = threadIdx.x;
unsigned int i = blockIdx.x*(blockSize*2) + threadIdx.x;
unsigned int gridSize = blockSize*2*gridDim.x;
AccType mySum = 0;
static inline unsigned int nextPow2(unsigned int x) {
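// round x up to the next power of two by smearing the highest set bit into all lower bits,
// e.g. nextPow2(600) == 1024 and nextPow2(512) == 512; used below to choose the thread count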
--x;
x |= x >> 1;
x |= x >> 2;
x |= x >> 4;
x |= x >> 8;
x |= x >> 16;
return ++x;
}
// we reduce multiple elements per thread. The number is determined by the
// number of active thread blocks (via gridDim). More blocks will result
// in a larger gridSize and therefore fewer elements per thread
while (i < n)
{
mySum += f((AccType)g_idata[i]);
////////////////////////////////////////////////////////////////////////////////
// Wrapper function for kernel launch
////////////////////////////////////////////////////////////////////////////////
template <typename T, typename AccType, size_t K, class Functor, class AggFunctor>
void reduceSinglePass(Functor functor, AccType aggInit, AggFunctor aggFunctor, AccType scale,
const functional::Shape full,
functional::Tensor<AccType> out,
functional::Array<functional::Tensor<T>, K> ins,
int threads, int blocks) {
int size = full.elements();
// when there is only one warp per block, we need to allocate two warps
// worth of shared memory so that we don't index shared memory out of bounds
int smemSize = (threads <= 32) ? 2 * threads * sizeof(AccType) : threads * sizeof(AccType);
dim3 dimBlock(threads, 1, 1);
dim3 dimGrid(blocks, 1, 1);
// ensure we don't read out of bounds -- this is optimized away for powerOf2 sized arrays
if (nIsPow2 || i + blockSize < n)
mySum += f((AccType)g_idata[i+blockSize]);
i += gridSize;
if (isPow2Greater1(size)) {
switch (threads) {
case 512:
reduceSinglePass<T, AccType, 512, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 256:
reduceSinglePass<T, AccType, 256, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 128:
reduceSinglePass<T, AccType, 128, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 64:
reduceSinglePass<T, AccType, 64, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 32:
reduceSinglePass<T, AccType, 32, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 16:
reduceSinglePass<T, AccType, 16, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 8:
reduceSinglePass<T, AccType, 8, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 4:
reduceSinglePass<T, AccType, 4, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 2:
reduceSinglePass<T, AccType, 2, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 1:
reduceSinglePass<T, AccType, 1, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
}
// do reduction in shared mem
reduceBlock<blockSize>(sdata, mySum, tid);
// write result for this block to global mem
if (tid == 0) g_odata[blockIdx.x] = sdata[0];
}
// Global variable used by reduceSinglePass to count how many blocks have finished
__device__ unsigned int retirementCount = 0;
cudaError_t setRetirementCount(int retCnt)
{
return cudaMemcpyToSymbol(retirementCount, &retCnt, sizeof(unsigned int), 0, cudaMemcpyHostToDevice);
}
// This reduction kernel reduces an arbitrary size array in a single kernel invocation
// It does so by keeping track of how many blocks have finished. After each thread
// block completes the reduction of its own block of data, it "takes a ticket" by
// atomically incrementing a global counter. If the ticket value is equal to the number
// of thread blocks, then the block holding the ticket knows that it is the last block
// to finish. This last block is responsible for summing the results of all the other
// blocks.
//
// In order for this to work, we must be sure that before a block takes a ticket, all
// of its memory transactions have completed. This is what __threadfence() does -- it
// blocks until the results of all outstanding memory transactions within the
// calling thread are visible to all other threads.
//
// For more details on the reduction algorithm (notably the multi-pass approach), see
// the "reduction" sample in the CUDA SDK.
template <unsigned int blockSize, bool nIsPow2, typename T, typename AccType, class Functor>
__global__ void reduceSinglePass(Functor f, T *g_idata, AccType *g_odata, unsigned int n)
{
//
// PHASE 1: Process all inputs assigned to this block
//
reduceBlocks<blockSize, nIsPow2, T, AccType>(f, g_idata, g_odata, n);
//
// PHASE 2: Last block finished will process all partial sums
//
if (gridDim.x > 1)
{
const unsigned int tid = threadIdx.x;
__shared__ bool amLast;
extern AccType __shared__ smem[];
// wait until all outstanding memory instructions in this thread are finished
__threadfence();
// Thread 0 takes a ticket
if (tid==0)
{
unsigned int ticket = atomicInc(&retirementCount, gridDim.x);
// If the ticket ID is equal to the number of blocks, we are the last block!
amLast = (ticket == gridDim.x-1);
}
__syncthreads();
// The last block sums the results of all other blocks
if (amLast)
{
int i = tid;
AccType mySum = 0;
while (i < gridDim.x)
{
mySum += g_odata[i];
i += blockSize;
}
reduceBlock<blockSize>(smem, mySum, tid);
if (tid==0)
{
g_odata[0] = smem[0];
// reset retirement count so that next run succeeds
retirementCount = 0;
}
}
} else {
switch (threads) {
case 512:
reduceSinglePass<T, AccType, 512, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 256:
reduceSinglePass<T, AccType, 256, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 128:
reduceSinglePass<T, AccType, 128, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 64:
reduceSinglePass<T, AccType, 64, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 32:
reduceSinglePass<T, AccType, 32, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 16:
reduceSinglePass<T, AccType, 16, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 8:
reduceSinglePass<T, AccType, 8, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 4:
reduceSinglePass<T, AccType, 4, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 2:
reduceSinglePass<T, AccType, 2, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
case 1:
reduceSinglePass<T, AccType, 1, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
break;
}
}
}
bool isPow2(unsigned int x)
{
return ((x&(x-1))==0);
}
template <typename T, typename AccType, class Functor>
void ReduceAll(Functor f, Tensor blockMem, Tensor in)
{
cudaSetDevice(in->getDeviceId().no);
int size = in->shape().elements();
int threads = std::min(MAX_THREADS, size);
int blocks = std::min(MAX_BLOCKS, size / threads + (size % threads != 0));
dim3 dimBlock(threads, 1, 1);
dim3 dimGrid(blocks, 1, 1);
int smemSize = threads * sizeof(AccType);
T* d_idata = in->data<T>();
AccType* d_odata = blockMem->data<AccType>();
// choose which of the optimized versions of reduction to launch
if (isPow2(size))
{
switch (threads)
{
case 512:
reduceSinglePass<512, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 256:
reduceSinglePass<256, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 128:
reduceSinglePass<128, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 64:
reduceSinglePass< 64, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 32:
reduceSinglePass< 32, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 16:
reduceSinglePass< 16, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 8:
reduceSinglePass< 8, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 4:
reduceSinglePass< 4, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 2:
reduceSinglePass< 2, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 1:
reduceSinglePass< 1, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
}
}
else
{
switch (threads)
{
case 512:
reduceSinglePass<512, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 256:
reduceSinglePass<256, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 128:
reduceSinglePass<128, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 64:
reduceSinglePass< 64, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 32:
reduceSinglePass< 32, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 16:
reduceSinglePass< 16, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 8:
reduceSinglePass< 8, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 4:
reduceSinglePass< 4, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 2:
reduceSinglePass< 2, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
case 1:
reduceSinglePass< 1, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
break;
}
}
}
}
}

View File

@ -9,136 +9,155 @@
namespace marian {
namespace functional {
template <size_t K, class Functor>
// This struct and its specializations are never used directly, only through apply and applyWithCast below.
template <size_t K, class Functor, typename AccType> // K-ary application of Functor, elements are cast to AccType before application of Functor
struct FApply {};
template <class Functor>
struct FApply<1, Functor> {
template <class Functor, typename AccType>
struct FApply<1, Functor, AccType> {
template <typename ElementType>
HOST_DEVICE_INLINE static ElementType apply(
HOST_DEVICE_INLINE static AccType apply(
Functor functor,
functional::Array<functional::Tensor<ElementType>, 1>& in,
const functional::Array<int, 1>& indices) {
return functor(in[0].data()[indices[0]]);
return functor((AccType)in[0].data()[indices[0]]); // indices is an array of offsets into the K input tensors; indices[i] corresponds to in[i]
}
template <typename ElementType>
HOST_DEVICE_INLINE static ElementType apply(
HOST_DEVICE_INLINE static AccType apply(
Functor functor,
functional::Array<functional::Tensor<ElementType>, 1>& in,
int index) {
return functor(in[0].data()[index]);
return functor((AccType)in[0].data()[index]);
}
};
template <class Functor>
struct FApply<2, Functor> {
template <class Functor, typename AccType>
struct FApply<2, Functor, AccType> {
template <typename ElementType>
HOST_DEVICE_INLINE static ElementType apply(
HOST_DEVICE_INLINE static AccType apply(
Functor functor,
functional::Array<functional::Tensor<ElementType>, 2>& in,
const functional::Array<int, 2>& indices) {
return functor(in[0].data()[indices[0]],
in[1].data()[indices[1]]);
return functor((AccType)in[0].data()[indices[0]],
(AccType)in[1].data()[indices[1]]);
}
template <typename ElementType>
HOST_DEVICE_INLINE static ElementType apply(
HOST_DEVICE_INLINE static AccType apply(
Functor functor,
functional::Array<functional::Tensor<ElementType>, 2>& in,
int index) {
return functor(in[0].data()[index],
in[1].data()[index]);
return functor((AccType)in[0].data()[index],
(AccType)in[1].data()[index]);
}
};
template <class Functor>
struct FApply<3, Functor> {
template <class Functor, typename AccType>
struct FApply<3, Functor, AccType> {
template <typename ElementType>
HOST_DEVICE_INLINE static ElementType apply(
HOST_DEVICE_INLINE static AccType apply(
Functor functor,
functional::Array<functional::Tensor<ElementType>, 3>& in,
const functional::Array<int, 3>& indices) {
return functor(in[0].data()[indices[0]],
in[1].data()[indices[1]],
in[2].data()[indices[2]]);
return functor((AccType)in[0].data()[indices[0]],
(AccType)in[1].data()[indices[1]],
(AccType)in[2].data()[indices[2]]);
}
template <typename ElementType>
HOST_DEVICE_INLINE static ElementType apply(
HOST_DEVICE_INLINE static AccType apply(
Functor functor,
functional::Array<functional::Tensor<ElementType>, 3>& in,
int index) {
return functor(in[0].data()[index],
in[1].data()[index],
in[2].data()[index]);
return functor((AccType)in[0].data()[index],
(AccType)in[1].data()[index],
(AccType)in[2].data()[index]);
}
};
template <class Functor>
struct FApply<4, Functor> {
template <class Functor, typename AccType>
struct FApply<4, Functor, AccType> {
template <typename ElementType>
HOST_DEVICE_INLINE static ElementType apply(
HOST_DEVICE_INLINE static AccType apply(
Functor functor,
functional::Array<functional::Tensor<ElementType>, 4>& in,
const functional::Array<int, 4>& indices) {
return functor(in[0].data()[indices[0]],
in[1].data()[indices[1]],
in[2].data()[indices[2]],
in[3].data()[indices[3]]);
return functor((AccType)in[0].data()[indices[0]],
(AccType)in[1].data()[indices[1]],
(AccType)in[2].data()[indices[2]],
(AccType)in[3].data()[indices[3]]);
}
template <typename ElementType>
HOST_DEVICE_INLINE static ElementType apply(
HOST_DEVICE_INLINE static AccType apply(
Functor functor,
functional::Array<functional::Tensor<ElementType>, 4>& in,
int index) {
return functor(in[0].data()[index],
in[1].data()[index],
in[2].data()[index],
in[3].data()[index]);
return functor((AccType)in[0].data()[index],
(AccType)in[1].data()[index],
(AccType)in[2].data()[index],
(AccType)in[3].data()[index]);
}
};
template <class Functor>
struct FApply<5, Functor> {
template <class Functor, typename AccType>
struct FApply<5, Functor, AccType> {
template <typename ElementType>
HOST_DEVICE_INLINE static ElementType apply(
HOST_DEVICE_INLINE static AccType apply(
Functor functor,
functional::Array<functional::Tensor<ElementType>, 5>& in,
const functional::Array<int, 5>& indices) {
return functor(in[0].data()[indices[0]],
in[1].data()[indices[1]],
in[2].data()[indices[2]],
in[3].data()[indices[3]],
in[4].data()[indices[4]]);
return functor((AccType)in[0].data()[indices[0]],
(AccType)in[1].data()[indices[1]],
(AccType)in[2].data()[indices[2]],
(AccType)in[3].data()[indices[3]],
(AccType)in[4].data()[indices[4]]);
}
template <typename ElementType>
HOST_DEVICE_INLINE static ElementType apply(
HOST_DEVICE_INLINE static AccType apply(
Functor functor,
functional::Array<functional::Tensor<ElementType>, 5>& in,
int index) {
return functor(in[0].data()[index],
in[1].data()[index],
in[2].data()[index],
in[3].data()[index],
in[4].data()[index]);
return functor((AccType)in[0].data()[index],
(AccType)in[1].data()[index],
(AccType)in[2].data()[index],
(AccType)in[3].data()[index],
(AccType)in[4].data()[index]);
}
};
template <size_t K, class Functor, typename ElementType>
/******************************************************************************/
// Applying functor to sets of K tensors
template <typename ElementType, size_t K, class Functor>
HOST_DEVICE_INLINE ElementType apply(Functor functor,
functional::Array<functional::Tensor<ElementType>, K>& in,
const functional::Array<int, K>& indices) {
return FApply<K, Functor>::apply(functor, in, indices);
return FApply<K, Functor, ElementType>::apply(functor, in, indices); // functor is applied to same type as input ElementType, no casting required
}
template <size_t K, class Functor, typename ElementType>
template <typename ElementType, size_t K, class Functor>
HOST_DEVICE_INLINE ElementType apply(Functor functor,
functional::Array<functional::Tensor<ElementType>, K>& in,
int index) {
return FApply<K, Functor>::apply(functor, in, index);
return FApply<K, Functor, ElementType>::apply(functor, in, index); // functor is applied to same type as input ElementType, no casting required
}
template <typename AccType, typename ElementType, size_t K, class Functor>
HOST_DEVICE_INLINE AccType applyWithCast(Functor functor,
functional::Array<functional::Tensor<ElementType>, K>& in,
const functional::Array<int, K>& indices) {
return FApply<K, Functor, AccType>::apply(functor, in, indices); // ElementType and AccType are potentially different, cast to AccType before applying functor.
// This is useful when accumulating e.g. 16-bit values into a 32-bit accumulator and we want to cast to 32-bit before
// the functor is applied. L2-Norm is a good use case since the square can be large.
}
template <typename AccType, typename ElementType, size_t K, class Functor>
HOST_DEVICE_INLINE AccType applyWithCast(Functor functor,
functional::Array<functional::Tensor<ElementType>, K>& in,
int index) {
return FApply<K, Functor, AccType>::apply(functor, in, index); // ElementType and AccType are potentially different, cast to AccType before applying functor
}
/******************************************************************************/
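To make the overflow point concrete, here is a small stand-alone sketch (illustrative only, not part of this header; assumes compute capability 5.3+ for half arithmetic): squaring 300 in half overflows, while casting to float before the multiplication, as applyWithCast<float> does, gives the exact result.

#include <cuda_fp16.h>
#include <cstdio>

__global__ void squareDemo(float* out) {
  __half h = __float2half(300.f);
  out[0] = __half2float(__hmul(h, h));        // squared in half: overflows to inf (half max is ~65504)
  out[1] = __half2float(h) * __half2float(h); // cast to float first, then square: exactly 90000
}

int main() {
  float* d;
  cudaMalloc((void**)&d, 2 * sizeof(float));
  squareDemo<<<1, 1>>>(d);
  float h[2];
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  printf("squared in half: %f, squared in float: %f\n", h[0], h[1]);
  cudaFree(d);
  return 0;
}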
@ -180,7 +199,7 @@ struct Loop<1, N, K> {
for(size_t j = 0; j < K; ++j) {
acc[j] = pAcc[j] + (dim[N - 1] + i) * in[j].shape().bstride(N - 1);
}
agg = aggFunctor(agg, (AccType)apply<K>(functor, in, acc));
agg = aggFunctor(agg, applyWithCast<AccType>(functor, in, acc));
}
return agg;
}

View File

@ -355,35 +355,51 @@ Expr slice(Expr a, int axis, Slice slice) { // numpy __getslice__ semantics, but
}
Expr sum(Expr a, int ax) {
if(a->shape()[ax] == 1) // nothing to reduce, sum of itself is a
return a;
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::sum);
}
Expr mean(Expr a, int ax) {
if(a->shape()[ax] == 1) // nothing to reduce, mean of itself is a
return a;
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::mean);
}
Expr std(Expr a, int ax) {
return Expression<ReduceNodeOp>(a - mean(a,ax), ax, ReduceNodeOpCode::rms);
if(a->shape()[ax] == 1) // nothing to reduce, std(a) = 0
return a - a;
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::rms);
}
Expr var(Expr a, int ax) {
Expr var(Expr a, int ax) {
if(a->shape()[ax] == 1) // nothing to reduce, var(a) = 0
return a - a;
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::meanSqr);
}
Expr max(Expr a, int ax) {
if(a->shape()[ax] == 1) // nothing to reduce, max of itself is a
return a;
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::max);
}
Expr min(Expr a, int ax) {
if(a->shape()[ax] == 1) // nothing to reduce, min of itself is a
return a;
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::min);
}
Expr prod(Expr a, int ax) {
if(a->shape()[ax] == 1) // nothing to reduce, prod of itself is a
return a;
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::prod);
}
// log(sum(exp(a)))
Expr logsumexp(Expr a, int ax) {
if(a->shape()[ax] == 1) // nothing to reduce, log(sum(exp(a))) = log(exp(a)) = a
return a;
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::logSumExp);
}

View File

@ -1,4 +1,5 @@
#include "tensors/gpu/add.h"
#include "tensors/gpu/add_all.h"
#include "tensors/gpu/cuda_helpers.h"
@ -12,11 +13,13 @@ namespace marian {
namespace gpu {
template <size_t K, class Functor, class AggFunctor, typename T, typename AccType>
__global__ void gAggregateGeneric(Functor functor, AccType aggInit, AggFunctor aggFunctor,
const functional::Shape full,
functional::Tensor<T> out,
functional::Array<functional::Tensor<T>, K> ins,
AccType scale = 1.0) {
__global__ void gAggregateGeneric(Functor functor, // functor applied to corresponding single elements of the input tensors (via broadcasting),
AccType aggInit, // aggInit is starting value of accumulation (e.g. 0 for sum),
AggFunctor aggFunctor, // aggFunctor is used to accumulate values (e.g. sum),
const functional::Shape full, // maximal combined shape of all tensors via broadcasting
functional::Tensor<T> out, // output tensor
functional::Array<functional::Tensor<T>, K> ins, // input tensors
AccType scale = 1.0) { // scale accumulation result by scale. e.g. used for computing mean from sum over N elements with scale 1/N
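// e.g. a plain mean over the reduced elements corresponds to functor = _1 (identity), aggInit = 0,
// aggFunctor = _1 + _2 (sum) and scale = 1/N, where N is the number of reduced elements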
int outLength = out.shape().elements();
bool same = outLength == full.elements();
for(int i = 0; i < K; ++i)
@ -32,10 +35,10 @@ __global__ void gAggregateGeneric(Functor functor, AccType aggInit, AggFunctor a
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < outLength) {
if(same) {
out[index] = aggFunctor(out[index], functional::apply(functor, ins, index) * (T)scale);
out[index] = (T)aggFunctor((AccType)out[index], functional::applyWithCast<AccType>(functor, ins, index) * scale); // apply functor with arguments cast to AccType
} else {
out.shape().dims(index, dims);
out[index] = aggFunctor(out[index], (T)(functional::loops(functor, aggInit, aggFunctor, ins, len, dims) * scale));
out[index] = (T)aggFunctor((AccType)out[index], functional::loops(functor, aggInit, aggFunctor, ins, len, dims) * scale); // apply functors with arguments cast to AccType
}
}
}
@ -62,7 +65,7 @@ __global__ void gAggregateEqual(Functor functor, AggFunctor aggFunctor,
indices[i] = ins[i].shape().bindex(dims);
}
out[index] = aggFunctor(out[index], functional::apply(functor, ins, indices) * (T)scale);
out[index] = (T)aggFunctor((AccType)out[index], functional::applyWithCast<AccType>(functor, ins, indices) * scale);
}
}
}
@ -76,7 +79,7 @@ __global__ void gAggregateReduce(Functor functor, AccType aggInit, AggFunctor ag
int rows = full.elements() / full.back();
int cols = full.back();
bool same = true;
bool same = true; // do all inputs have the same number of elements as the full broadcast shape (i.e. no broadcasting needed)?
for(int i = 0; i < K; ++i)
same = same && ins[i].shape().elements() == full.elements();
@ -93,7 +96,7 @@ __global__ void gAggregateReduce(Functor functor, AccType aggInit, AggFunctor ag
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols)
_sum[threadIdx.x] = aggFunctor(_sum[threadIdx.x], (AccType)functional::apply(functor, ins, j * cols + id));
_sum[threadIdx.x] = aggFunctor(_sum[threadIdx.x], functional::applyWithCast<AccType>(functor, ins, j * cols + id)); // casts to AccType before applying functor which then performs operation in AccType
}
} else {
functional::Array<int, functional::Shape::size()> dims;
@ -106,7 +109,7 @@ __global__ void gAggregateReduce(Functor functor, AccType aggInit, AggFunctor ag
functional::Array<int, K> indices;
for(int i = 0; i < K; ++i)
indices[i] = ins[i].shape().bindex(dims);
_sum[threadIdx.x] = aggFunctor(_sum[threadIdx.x], (AccType)functional::apply(functor, ins, indices));
_sum[threadIdx.x] = aggFunctor(_sum[threadIdx.x], functional::applyWithCast<AccType>(functor, ins, indices));// casts to AccType before applying functor which then performs operation in AccType
}
}
}
@ -121,7 +124,8 @@ __global__ void gAggregateReduce(Functor functor, AccType aggInit, AggFunctor ag
len = (len + 1) >> 1;
}
__syncthreads();
out[j] = aggFunctor(out[j], (T)(_sum[0] * scale));
if(threadIdx.x == 0) // only thread 0 of the block writes the reduced value
out[j] = aggFunctor(out[j], (T)(_sum[0] * scale));
}
__syncthreads();
}
@ -140,16 +144,16 @@ void AggregateTyped(Functor functor, AccType aggInit, AggFunctor aggFunctor, Acc
functional::Tensor<T> gOut = out;
functional::Array<functional::Tensor<T>, K> gIns = {tensors...};
if(full.back() != 1 && out->shape().back() == 1) {
size_t m = full.elements() / length;
size_t k = full.back();
if(out->shape().elements() == 1) { // reduce everything into a single element
AggregateAll<T, AccType>(nullptr, functor, aggInit, aggFunctor, scale, out, tensors...); // @TODO: pass allocator in here, currently uses cudaMalloc
} else if(full.back() != 1 && out->shape().back() == 1 && full.elements() / full.back() == length) { // the number of elements in out must match the number of elements of the full shape along the non-reduced axes
size_t m = full.elements() / full.back(); // how many rows are we iterating over?
size_t k = full.back(); // how many columns are being reduced to 1 in each row?
int blocks = std::min(MAX_BLOCKS, (int)m);
int blocks = std::min(MAX_BLOCKS, (int)m);
int threads = std::min(MAX_THREADS, (int)k);
int shared = sizeof(AccType) * threads;
int shared = sizeof(AccType) * threads;
gAggregateReduce<K, Functor, AggFunctor, T, AccType><<<blocks, threads, shared>>>(functor, aggInit, aggFunctor, full, gOut, gIns, scale);
} else if(out->shape() == full) {
int threads = std::min(MAX_THREADS, length);
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));

128
src/tensors/gpu/add_all.h Normal file
View File

@ -0,0 +1,128 @@
#pragma once
// This header file provides wrappers around NVidia's reduce_all kernel with our custom aggregation functionality.
// This kernel reduces a tensor into a single value. We have modified it to allow for different types of aggregations,
// such as summation or max.
#include "tensors/gpu/cuda_helpers.h"
#include "tensors/tensor.h"
#include "tensors/allocator.h"
#include "functional/functional.h"
#include "functional/tensor.h"
#include "tensors/tensor_operators.h"
#include "3rd_party/reduce_all.h" // only works with CUDA >= 9.0; we are dropping CUDA 8.0 support, also changed in CMakeLists.txt
#include <cuda.h>
namespace marian {
#if COMPILE_FP16
// local overload to determine tensor type
template <> inline Type typeId<half>() { return Type::float16; }
#endif
template <typename T, typename AccType, class Functor, class AggFunctor, class... Tensors>
void AggregateAll(Ptr<Allocator> allocator,
Functor functor,
AccType aggInit,
AggFunctor aggFunctor,
AccType scale,
Tensor out,
const Tensors... tensors) {
cudaSetDevice(out->getDeviceId().no);
static_assert(CUDA_VERSION >= 9000, "Marian requires CUDA_VERSION >= 9000 (9.0)");
constexpr size_t K = sizeof...(Tensors); // obtain arity K of tensors...
functional::Array<functional::Tensor<T>, K> gIns = {tensors...}; // convert to array of K objects of type functional::Tensor<T>
functional::Shape full = marian::Shape::broadcast({tensors...}); // compute maximal broadcasted shape
int size = full.elements();
int threads = (size < MAX_THREADS * 2) ? nextPow2((size + 1) / 2) : MAX_THREADS; // suggested in NVidia example for the all_reduce kernel
int blocks = std::min(MAX_BLOCKS, (size + (threads * 2 - 1)) / (threads * 2)); // suggested in NVidia example for the all_reduce kernel
// The all_reduce kernel by NVidia needs to perform multiple passes if the number of blocks needed to perform the reduction is larger than 1.
// Here we allocate the memory for the intermediate reductions for each block.
Tensor blockMem;
if(blocks > 1 || out->type() != typeId<AccType>()) { // if the out tensor does not have elementType AccType we need to allocate and convert later
MemoryPiece::PtrType temporaryMemory;
if(allocator) {
temporaryMemory = allocator->alloc<AccType>(blocks);
} else { // @TODO: get rid of this branch
uint8_t* temporaryMemoryPtr = 0;
CUDA_CHECK(cudaMalloc(&temporaryMemoryPtr, sizeof(AccType) * blocks));
temporaryMemory = MemoryPiece::New(temporaryMemoryPtr, sizeof(AccType) * blocks); // @TODO: consider implementing MemoryPiece::cudaMalloc<T>(size) for managed memory
}
blockMem = TensorBase::New(temporaryMemory,
Shape({blocks}),
typeId<AccType>(),
out->getBackend());
blockMem->set(aggInit); // set temporary memory to aggInit
}
else { // we are reducing into a single element now and the type matches, just use out as memory
blockMem = out; // do not set final output memory as we might be summing gradients... needs to be handled outside this function
}
functional::Tensor<AccType> gBlockMem = blockMem;
reduceSinglePass<T, AccType>(functor, aggInit, aggFunctor, scale, full, /*out=*/gBlockMem, /*in=*/gIns, threads, blocks); // First pass reduction into intermediate memory
// If we actually needed more than one block to perform the first pass reduction, recursively run a second pass reduction over block memory until block memory has size 1.
if(blocks > 1) {
using namespace functional;
auto identity = _1; // transformation was done in first pass, hence only identity
AggregateAll<AccType, AccType>(allocator, identity, aggInit, aggFunctor, scale, out, /*in=*/blockMem); // Reducing AccType in AccType now (meta-reduction)
} else if(out->type() != typeId<AccType>()) { // it's only a single block, but we need to convert to different type, as mentioned above
CopyCast(out, blockMem);
}
if(blockMem != out) {
// Free temporary memory whether allocated in allocator or via cudaMalloc
if(allocator)
allocator->free(blockMem->memory());
else if(blockMem->memory()->data())
CUDA_CHECK(cudaFree(blockMem->memory()->data()));
}
}
// Aggregates all values into a single tensor and returns the value of that tensor as a float
// This does a GPU to CPU memory copy via TensorBase::scalar().
// Used currently only for L2Norm computation
template <typename T, typename AccType, class Functor, class AggFunctor, class... Tensors>
AccType AggregateAllAndReturn(Ptr<Allocator> allocator,
Functor functor,
AccType aggInit,
AggFunctor aggFunctor,
AccType scale,
const Tensors... tensors) {
MemoryPiece::PtrType temporaryMemory;
if(allocator) {
temporaryMemory = allocator->alloc<AccType>(1);
} else { // @TODO: get rid of this branch
uint8_t* temporaryMemoryPtr = 0;
CUDA_CHECK(cudaMalloc(&temporaryMemoryPtr, sizeof(AccType)));
temporaryMemory = MemoryPiece::New(temporaryMemoryPtr, sizeof(AccType));
}
std::tuple<Tensors...> in(tensors...);
// Create a temporary tensor of size 1 to reduce into
auto out = TensorBase::New(temporaryMemory,
Shape({1}),
typeId<AccType>(),
std::get<0>(in)->getBackend());
out->set(aggInit); // init to aggInit
AggregateAll<T, AccType>(allocator, functor, aggInit, aggFunctor, scale, out, tensors...);
AccType outScalar = out->template scalar<AccType>(); // convert to float also if other underlying type
if(allocator)
allocator->free(out->memory());
else if(out->memory()->data()) // @TODO: get rid of this branch
CUDA_CHECK(cudaFree(out->memory()->data()));
return outScalar;
}
}
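A minimal call-site sketch (hypothetical; it mirrors the L2Norm use further below and assumes an existing fp32 Tensor t and its Ptr<Allocator> allocator) showing how the wrapper can sum all elements of a tensor:

using namespace marian;
using namespace functional;
// Sum all elements of t: identity functor _1, aggregate with +, start at 0, no scaling.
float total = AggregateAllAndReturn</*ElementType=*/float, /*AccType=*/float>(
    allocator, /*functor=*/_1, /*aggInit=*/0.f, /*aggFunctor=*/_1 + _2, /*scale=*/1.f, t);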

View File

@ -1,5 +1,3 @@
//#include <thrust/transform_reduce.h>
#include "common/types.h"
#include "tensors/tensor_operators.h"
@ -9,7 +7,7 @@
#include "tensors/gpu/backend.h"
#include "tensors/gpu/cuda_helpers.h"
#include "3rd_party/reduce_all.h"
#include "tensors/gpu/add_all.h"
namespace marian {
@ -588,6 +586,8 @@ __global__ void gSoftmax(T* out,
// determine max (used below to improve numeric stability)
T* _max = _share;
// @TODO: what's going on here with fp16?
_max[threadIdx.x] = -CUDA_FLT_MAX; // mask
// find max over column indices that have the same relative column index (=threadIdx.x) across all blocks of columns
for(int tid = 0; tid < cols; tid += blockDim.x) {
@ -1661,53 +1661,27 @@ void CrossEntropyPickBackward(Tensor out, Tensor adj, Tensor a, Tensor indices)
}
}
float L2Norm(Tensor in, Ptr<Allocator> allocator) {
// computes the L2 norm of a tensor and returns the value as a float on the CPU;
// this is mostly used for diagnostic purposes and gradient clipping
float L2Norm(Tensor in, Ptr<Allocator> allocator) { // @TODO: reverse order of arguments
cudaSetDevice(in->getDeviceId().no);
int size = in->shape().elements();
int threads = std::min(MAX_THREADS, size);
int blocks = std::min(MAX_BLOCKS, size / threads + (size % threads != 0));
int blocks = std::min(MAX_BLOCKS, size / threads + (size % threads != 0));
if(allocator) {
auto memoryPiece = allocator->alloc<float>(blocks);
auto blockMem = TensorBase::New(memoryPiece, Shape({1, blocks}), Type::float32, in->getBackend());
using namespace functional;
if(in->type() == Type::float32) {
ReduceAll<float, float>(_1 * _1, blockMem, in);
using namespace functional;
float l2Norm;
if(in->type() == Type::float32) {
l2Norm = std::sqrt(AggregateAllAndReturn</*ElementType=*/float, /*AccType=*/float>(allocator, /*functor=*/_1 * _1, /*aggInit=*/0.f, /*aggFunctor=*/_1 + _2, /*scale=*/1.f, in));
#if COMPILE_FP16
} else if(in->type() == Type::float16) {
ReduceAll<half, float>(_1 * _1, blockMem, in);
} else if(in->type() == Type::float16) {
l2Norm = std::sqrt(AggregateAllAndReturn</*ElementType=*/half, /*AccType=*/float>(allocator, /*functor=*/_1 * _1, /*aggInit=*/0.f, /*aggFunctor=*/_1 + _2, /*scale=*/1.f, in));
#endif
} else {
ABORT("L2Norm not implemented for type {}", in->type());
}
float dataCpu = sqrtf(blockMem->get<float>(0));
allocator->free(memoryPiece);
return dataCpu;
} else { // @TODO: this branch is to be removed with next PR, old version
uint8_t* data;
cudaMalloc(&data, blocks * sizeof(float));
Tensor out(TensorBase::New(MemoryPiece::New(data, blocks * sizeof(float)),
Shape({1, blocks}),
Type::float32,
in->getBackend()));
using namespace functional;
if(in->type() == Type::float32) {
ReduceAll<float, float>(_1 * _1, out, in);
#if COMPILE_FP16
} else if(in->type() == Type::float16) {
ReduceAll<half, float>(_1 * _1, out, in);
#endif
} else {
ABORT("L2Norm not implemented for type {}", in->type());
}
float dataCpu = sqrtf(out->get<float>(0));
out.reset();
cudaFree(data);
return dataCpu;
} else {
ABORT("L2Norm not implemented for type {}", in->type());
}
return l2Norm;
}
template <typename T, typename AccType = float>
@ -1761,22 +1735,22 @@ __global__ void gAtt(T* out,
void Att(Tensor out, Tensor va, Tensor context, Tensor state) {
cudaSetDevice(out->getDeviceId().no);
size_t m = out->shape().elements() / out->shape().back();
size_t k = context->shape()[-1];
size_t b = context->shape()[-2];
size_t t = context->shape()[-3];
size_t totalRows = out->shape().elements() / out->shape().back(); // number of rows
size_t modelDim = context->shape()[-1]; // number of cols
size_t batchDim = context->shape()[-2];
size_t contextWordsDim = context->shape()[-3];
int blocks = std::min(MAX_BLOCKS, (int)m);
int threads = std::min(MAX_THREADS, (int)k);
int blocks = std::min(MAX_BLOCKS, (int)totalRows);
int threads = std::min(MAX_THREADS, (int)modelDim);
int shared = sizeof(float) * threads;
if(out->type() == Type::float32) {
gAtt<float, float><<<blocks, threads, shared>>>(
out->data<float>(), va->data<float>(), context->data<float>(), state->data<float>(), m, k, b, t);
out->data<float>(), va->data<float>(), context->data<float>(), state->data<float>(), totalRows, modelDim, batchDim, contextWordsDim);
#if COMPILE_FP16
} else if (out->type() == Type::float16) {
gAtt<half, float><<<blocks, threads, shared>>>(
out->data<half>(), va->data<half>(), context->data<half>(), state->data<half>(), m, k, b, t);
out->data<half>(), va->data<half>(), context->data<half>(), state->data<half>(), totalRows, modelDim, batchDim, contextWordsDim);
#endif
} else {
ABORT("gAtt not implemented for type {}", out->type());
@ -2005,10 +1979,10 @@ __global__ void gLayerNormalizationGrad(T* gradX,
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
AccType* sum_adj = shared;
AccType* sum_adj_x = shared + blockDim.x;
AccType* sum_x = shared + 2 * blockDim.x;
AccType* sum_sqr = shared + 3 * blockDim.x;
AccType* sum_adj = shared; // sum of incoming gradient
AccType* sum_adj_l = shared + blockDim.x; // sum of incoming gradient times layer-norm value l (recovered from y)
AccType* sum_x = shared + 2 * blockDim.x; // sum of input value x
AccType* sum_sqr = shared + 3 * blockDim.x; // sum of (x - mean)^2
const T* xRow = x + j * cols;
const T* yRow = y + j * cols;
@ -2016,7 +1990,7 @@ __global__ void gLayerNormalizationGrad(T* gradX,
sum_x[threadIdx.x] = (AccType)0.0f;
sum_adj[threadIdx.x] = (AccType)0.0f;
sum_adj_x[threadIdx.x] = (AccType)0.0f;
sum_adj_l[threadIdx.x] = (AccType)0.0f;
sum_sqr[threadIdx.x] = (AccType)0.0f;
for(int tid = 0; tid < cols; tid += blockDim.x) {
@ -2030,7 +2004,7 @@ __global__ void gLayerNormalizationGrad(T* gradX,
AccType lv = (yv - betav) / (gammav + eps); // go back to LN(x) from scaled and shifted version for accumulation
sum_x[threadIdx.x] += xv;
sum_adj_x[threadIdx.x] += adjv * lv;
sum_adj_l[threadIdx.x] += adjv * lv;
sum_adj[threadIdx.x] += adjv;
}
}
@ -2042,7 +2016,7 @@ __global__ void gLayerNormalizationGrad(T* gradX,
if(threadIdx.x < (len >> 1)) {
sum_x[threadIdx.x] += sum_x[threadIdx.x + skip]; // Accumulates in AccType
sum_adj[threadIdx.x] += sum_adj[threadIdx.x + skip]; // Accumulates in AccType
sum_adj_x[threadIdx.x] += sum_adj_x[threadIdx.x + skip]; // Accumulates in AccType
sum_adj_l[threadIdx.x] += sum_adj_l[threadIdx.x + skip]; // Accumulates in AccType
}
len = (len + 1) >> 1;
}
@ -2074,28 +2048,27 @@ __global__ void gLayerNormalizationGrad(T* gradX,
// Jacobian of layer norm
// J = [ \frac{1}{N\sigma} (N\delta_{ij} - l_i l_j - 1) ]_{ij}
// J * a = dC/dx_i = ( N v_i - l_i \sum_j l_j a_j - \sum_j a_j ) / (N \sigma)
// J * a = dC/dx_i = ( N a_i - l_i \sum_j l_j a_j - \sum_j a_j ) / (N \sigma)
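// where l_i = (x_i - mean) / sigma is the normalized value, a_j is the incoming gradient adjRow[j],
// N is the row length, sum_adj_l[0] holds \sum_j l_j a_j and sum_adj[0] holds \sum_j a_j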
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
AccType xv = xRow[id];
//AccType yv = yRow[id];
//AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
AccType gammav = (AccType)gamma[id];
AccType adjv = adjRow[id];
AccType lv = (xv - mean) / (sigma + eps);
AccType gradLv = N * adjv - lv * sum_adj_x[0] - sum_adj[0];
AccType gradLv = N * adjv - lv * sum_adj_l[0] - sum_adj[0];
gradLv /= N * (sigma + eps); // eps has to be inside parentheses for correct gradient
AccType gradXv = gammav * gradLv;
// Keep LN gradient between [-10, 10]
// AccType sign = functional::Ops<AccType>::sgn(gradXv);
// AccType cutoff = (AccType)10.f;
// gradXv = functional::Ops<AccType>::abs(gradXv) > cutoff ? sign * cutoff : gradXv;
// Keep LN gradient between [-1000, 1000] for TensorOps; this is currently used to make values fit into fp16. @TODO: to be fixed and removed.
AccType sign = functional::Ops<AccType>::sgn(gradXv);
AccType cutoff = (AccType)1000.f; // @TODO: expose this somehow as an option?
// or better: make obsolete.
gradXv = functional::Ops<AccType>::abs(gradXv) > cutoff ? sign * cutoff : gradXv;
T* gradXRow = gradX + j * cols;
gradXRow[id] += (T)(gradXv);

View File

@ -28,18 +28,20 @@ std::string TensorBase::debug(int precision, int dispCols) {
else
strm << std::fixed << std::setprecision(0) << std::setfill(' ');
// double maxv = std::numeric_limits<double>::lowest();
// double minv = std::numeric_limits<double>::max();
// double l2Norm = 0.0;
double maxv = std::numeric_limits<double>::lowest();
double minv = std::numeric_limits<double>::max();
double l2Sum = 0.0;
for(int i = 0; i < values.size(); ++i) {
if((double)values[i] > maxv) maxv = values[i];
if((double)values[i] < minv) minv = values[i];
l2Sum += (double)values[i] * (double)values[i];
}
strm << "min: " << minv << " max: " << maxv << " l2-norm: " << sqrt(l2Sum) << std::endl;
for(int i = 0; i < values.size(); ++i) {
std::vector<int> dims;
shape().dims(i, dims);
// if((double)values[i] > maxv) maxv = values[i];
// if((double)values[i] < minv) minv = values[i];
// l2Norm += (double)values[i] * (double)values[i];
bool disp = true;
for(int j = 0; j < dims.size(); ++j)
disp = disp && (dims[j] < dispCols || dims[j] >= shape()[j] - dispCols);
@ -95,8 +97,6 @@ std::string TensorBase::debug(int precision, int dispCols) {
}
}
strm << std::endl;
//strm << "min: " << minv << " max: " << maxv << " l2-norm: " << sqrt(l2Norm);
return strm.str();
}

View File

@ -54,12 +54,12 @@ void Add(Functor functor, float scale, marian::Tensor out, Tensors... tensors) {
gpu::Add(functor, scale, out, tensors...);
else
#endif
cpu::Aggregate(functor, 0.0f, functional::_1 + functional::_2, scale, out, tensors...);
cpu::Aggregate(functor, /*aggInit=*/0.0f, functional::_1 + functional::_2, scale, out, tensors...);
}
template <class Functor, class... Tensors>
void Add(Functor functor, marian::Tensor out, Tensors... tensors) {
Add(functor, 1, out, tensors...);
Add(functor, /*scale=*/1.f, out, tensors...);
}
template <class Functor, class AggFunctor, class... Tensors>

View File

@ -789,3 +789,59 @@ TEST_CASE("Expression graph supports basic math operations (cpu)", "[operator]")
tests<float>(DeviceType::cpu);
}
#endif
#ifdef BLAS_FOUND
#ifdef CUDA_FOUND
TEST_CASE("Compare aggregate operator", "[graph]") {
auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).epsilon(0.001); };
Config::seed = 1234;
std::vector<float> initc;
std::vector<float> inita;
{
auto graph = New<ExpressionGraph>();
graph->setDevice({0, DeviceType::cpu});
graph->reserveWorkspaceMB(40);
auto chl = graph->param("1x10x512x2048", {1, 10, 512, 2048}, inits::normal());
auto adj = graph->param("1x1x512x2048", {1, 1, 512, 2048}, inits::normal());
graph->forward();
chl->val()->get(initc);
adj->val()->get(inita);
}
SECTION("initializing with zero (cpu)") {
std::vector<float> values1;
std::vector<float> values2;
auto graph1 = New<ExpressionGraph>();
graph1->setDevice({0, DeviceType::cpu});
graph1->reserveWorkspaceMB(40);
auto graph2 = New<ExpressionGraph>();
graph2->setDevice({0, DeviceType::gpu});
graph2->reserveWorkspaceMB(40);
auto chl1 = graph1->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc));
auto adj1 = graph1->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita));
auto prod1 = scalar_product(chl1, adj1, -1);
graph1->forward();
auto chl2 = graph2->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc));
auto adj2 = graph2->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita));
auto prod2 = scalar_product(chl2, adj2, -1);
graph2->forward();
prod1->val()->get(values1);
prod2->val()->get(values2);
CHECK( std::equal(values1.begin(), values1.end(), values2.begin(), floatApprox) );
}
}
#endif
#endif