Mirror of https://github.com/marian-nmt/marian.git, synced 2024-09-17 09:47:34 +03:00
Merged PR 10553: Fix multiple problems in reduce kernels that occurred during back-prop
This fixes a number of bugs in our GPU reduce-kernels that would manifest mainly for larger matrices and during back-prop. We also drop support for CUDA 8.0 to be able to take advantage of new GPU primitives introduced by NVidia in CUDA 9.0.
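For context, the rewritten kernels rely on the cooperative-groups API that NVidia introduced with CUDA 9.0. The sketch below is illustrative only (the kernel name and launch shape are made up, it is not code from this commit); it shows the warp-shuffle reduction pattern the updated reduce kernels use and requires nvcc with CUDA 9.0 or newer.

#include <cuda_runtime.h>
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// Minimal grid-stride sum reduction using warp shuffles (CUDA >= 9.0).
__global__ void sketchReduceSum(const float* in, float* out, int n) {
  cg::thread_block cta = cg::this_thread_block();
  cg::thread_block_tile<32> warp = cg::tiled_partition<32>(cta);

  float sum = 0.f; // each thread first accumulates several elements
  for(int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
    sum += in[i];

  // reduce within the warp via register shuffles, no shared memory needed
  for(int offset = warp.size() / 2; offset > 0; offset /= 2)
    sum += warp.shfl_down(sum, offset);

  if(warp.thread_rank() == 0)
    atomicAdd(out, sum); // one atomic per warp combines the partial sums
}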
This commit is contained in:
parent
b19820c8ba
commit
93b7ed80fe
@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Gradient-checkpointing

### Fixed
- Fixed multiple reduction kernels on GPU
- Replace IntrusivePtr with std::unique_ptr in FastOpt, fixes random segfaults
  due to thread-non-safety of reference counting.
- Make sure that items are 256-byte aligned during saving
@ -38,6 +39,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Compilation with CUDA 10.1

### Changed
- Dropped support for CUDA 8.0, CUDA 9.0 is now the minimal requirement
- Removed autotuner for now, will be switched back on later
- Boost dependency is now optional and only required for marian_server
- Dropped support for g++-4.9
@ -191,7 +191,7 @@ if(USE_STATIC_LIBS)
  endif()
endif()

find_package(CUDA "8.0") # TODO: only enable FP16-related options for compute_70 and higher.
find_package(CUDA "9.0") # TODO: only enable FP16-related options for compute_70 and higher.
if(CUDA_FOUND)
  # CUDA >= 10.0 requires CMake >= 3.12.2
  if((CUDA_VERSION VERSION_EQUAL "10.0" OR CUDA_VERSION VERSION_GREATER "10.0") AND (CMAKE_VERSION VERSION_LESS "3.12.2"))
555 src/3rd_party/reduce_all.h vendored
@ -1,347 +1,248 @@
|
||||
// This software contains source code provided by NVIDIA Corporation.
|
||||
|
||||
/*
|
||||
* Copyright 1993-2015 NVIDIA Corporation. All rights reserved.
|
||||
/* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
||||
* with this source code for terms and conditions that govern your use of
|
||||
* this software. Any use, reproduction, disclosure, or distribution of
|
||||
* this software and related documentation outside the terms of the EULA
|
||||
* is strictly prohibited.
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions
|
||||
* are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of NVIDIA CORPORATION nor the names of its
|
||||
* contributors may be used to endorse or promote products derived
|
||||
* from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
||||
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
/*
|
||||
MJD: Relevant text from the NVIDIA EULA:
|
||||
|
||||
2.1 Sample Source Code Modification, Ownership and Distribution
|
||||
|
||||
Subject to the terms of the SLA and this Supplement, NVIDIA hereby grants you a non-
|
||||
exclusive, non-transferable license, without the right to sublicense, during the applicable
|
||||
license term unless earlier terminated pursuant to the SLA, to have Authorized Users
|
||||
modify and create derivative works of CUDA Licensed Software that constitutes sample
|
||||
source code, when provided to you by NVIDIA in source code form. You hold all rights,
|
||||
title and interest in and to your modifications and derivative works of the sample source
|
||||
code software that you create as permitted hereunder (collectively, “Derivatives”), subject
|
||||
to NVIDIA’s underlying Intellectual Property Rights in and to the CUDA Licensed
|
||||
Software; provided, however that you grant NVIDIA and its Affiliates an irrevocable,
|
||||
perpetual, nonexclusive, worldwide, royalty-free paid-up license to make, have made,
|
||||
use, have used, reproduce, license, distribute, sublicense, transfer and otherwise
|
||||
commercialize Derivatives including (without limitation) with the CUDA Licensed
|
||||
Software or other NVIDIA products, technologies or materials. You may distribute the
sample source code as delivered by NVIDIA and/or your Derivatives,
|
||||
provided that all NVIDIA copyright notices and trademarks are maintained and used properly
|
||||
and the sample source code includes the following notice: “This software contains source code
|
||||
provided by NVIDIA Corporation.”
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensors/tensor.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include "functional/tmp.h"
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace marian {
|
||||
|
||||
template <unsigned int blockSize, typename AccType>
|
||||
__device__ void
|
||||
reduceBlock(volatile AccType *sdata, AccType mySum, const unsigned int tid)
|
||||
{
|
||||
sdata[tid] = mySum;
|
||||
__syncthreads();
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
// do reduction in shared mem
|
||||
if (blockSize >= 512)
|
||||
{
|
||||
if (tid < 256)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 256];
|
||||
}
|
||||
// Utility class used to avoid linker errors with extern
|
||||
// unsized shared memory arrays with templated type
|
||||
template <class T>
|
||||
struct SharedMemory {
|
||||
__device__ inline operator T *() {
|
||||
extern __shared__ int __smem[];
|
||||
return (T *)__smem;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
__device__ inline operator const T *() const {
|
||||
extern __shared__ int __smem[];
|
||||
return (T *)__smem;
|
||||
}
|
||||
};
|
||||
|
||||
// specialize for double to avoid unaligned memory
|
||||
// access compile errors
|
||||
template <>
|
||||
struct SharedMemory<double> {
|
||||
__device__ inline operator double *() {
|
||||
extern __shared__ double __smem_d[];
|
||||
return (double *)__smem_d;
|
||||
}
|
||||
|
||||
__device__ inline operator const double *() const {
|
||||
extern __shared__ double __smem_d[];
|
||||
return (double *)__smem_d;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/*
|
||||
This version adds multiple elements per thread sequentially. This reduces
|
||||
the overall cost of the algorithm while keeping the work complexity O(n) and
|
||||
the step complexity O(log n). (Brent's Theorem optimization)
|
||||
|
||||
Note, this kernel needs a minimum of 64*sizeof(T) bytes of shared memory.
|
||||
In other words if blockSize <= 32, allocate 64*sizeof(T) bytes.
|
||||
If blockSize > 32, allocate blockSize*sizeof(T) bytes.
|
||||
*/
|
||||
template <typename T, typename AccType, unsigned int blockSize, bool nIsPow2Greater1, size_t K, class Functor, class AggFunctor>
|
||||
__global__ void reduceSinglePass(Functor functor, AccType aggInit, AggFunctor aggFunctor, AccType scale,
|
||||
const functional::Shape full,
|
||||
functional::Tensor<AccType> out,
|
||||
functional::Array<functional::Tensor<T>, K> ins) {
|
||||
int n = full.elements();
|
||||
|
||||
// Handle to thread block group
|
||||
cg::thread_block cta = cg::this_thread_block();
|
||||
AccType *sdata = SharedMemory<AccType>();
|
||||
|
||||
// perform first level of reduction,
|
||||
// reading from global memory, writing to shared memory
|
||||
unsigned int tid = threadIdx.x;
|
||||
unsigned int i = blockIdx.x * blockSize * 2 + threadIdx.x;
|
||||
unsigned int gridSize = blockSize * 2 * gridDim.x;
|
||||
|
||||
AccType mySum = aggInit;
|
||||
|
||||
// we reduce multiple elements per thread in reduceSinglePass. The number is determined by the
|
||||
// number of active thread blocks (via gridDim). More blocks will result
|
||||
// in a larger gridSize and therefore fewer elements per thread
|
||||
while (i < n) {
|
||||
mySum = aggFunctor(mySum, functional::applyWithCast<AccType>(functor, ins, i));
|
||||
|
||||
// ensure we don't read out of bounds -- this is optimized away for powerOf2
|
||||
// sized arrays
|
||||
if (nIsPow2Greater1 || i + blockSize < n)
|
||||
mySum = aggFunctor(mySum, functional::applyWithCast<AccType>(functor, ins, i + blockSize));
|
||||
|
||||
i += gridSize;
|
||||
}
|
||||
|
||||
// each thread puts its local sum into shared memory
|
||||
sdata[tid] = mySum;
|
||||
cg::sync(cta);
|
||||
|
||||
// do reduction in shared mem
|
||||
if ((blockSize >= 512) && (tid < 256)) {
|
||||
sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 256]);
|
||||
}
|
||||
|
||||
cg::sync(cta);
|
||||
|
||||
if ((blockSize >= 256) && (tid < 128)) {
|
||||
sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 128]);
|
||||
}
|
||||
|
||||
cg::sync(cta);
|
||||
|
||||
if ((blockSize >= 128) && (tid < 64)) {
|
||||
sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 64]);
|
||||
}
|
||||
|
||||
cg::sync(cta);
|
||||
|
||||
cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);
|
||||
|
||||
if (cta.thread_rank() < 32) {
|
||||
// Fetch final intermediate sum from 2nd warp
|
||||
if (blockSize >= 64)
|
||||
mySum = aggFunctor(mySum, sdata[tid + 32]);
|
||||
// reduce final warp using shuffle
|
||||
for (int offset = tile32.size() / 2; offset > 0; offset /= 2) {
|
||||
mySum = aggFunctor(mySum, tile32.shfl_down(mySum, offset));
|
||||
}
|
||||
}
|
||||
|
||||
if (blockSize >= 256)
|
||||
{
|
||||
if (tid < 128)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 128];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (blockSize >= 128)
|
||||
{
|
||||
if (tid < 64)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 64];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (tid < 32)
|
||||
{
|
||||
if (blockSize >= 64)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 32];
|
||||
}
|
||||
|
||||
if (blockSize >= 32)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 16];
|
||||
}
|
||||
|
||||
if (blockSize >= 16)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 8];
|
||||
}
|
||||
|
||||
if (blockSize >= 8)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 4];
|
||||
}
|
||||
|
||||
if (blockSize >= 4)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 2];
|
||||
}
|
||||
|
||||
if (blockSize >= 2)
|
||||
{
|
||||
sdata[tid] = mySum = mySum + sdata[tid + 1];
|
||||
}
|
||||
}
|
||||
// write result for this block to global mem
|
||||
if (cta.thread_rank() == 0)
|
||||
out[blockIdx.x] = aggFunctor(out[blockIdx.x], mySum * scale); // aggFunctor?
|
||||
}
|
||||
|
||||
template <unsigned int blockSize, bool nIsPow2, typename T, typename AccType, class Functor>
|
||||
__device__ void
|
||||
reduceBlocks(Functor f, T *g_idata, AccType *g_odata, unsigned int n)
|
||||
{
|
||||
extern __shared__ AccType sdata[];
|
||||
static inline bool isPow2Greater1(unsigned int x) { // is power of two but also larger than 1, otherwise an out-of-bounds read occurs
|
||||
return x > 1 && ((x & (x - 1)) == 0);
|
||||
}
|
||||
|
||||
// perform first level of reduction,
|
||||
// reading from global memory, writing to shared memory
|
||||
unsigned int tid = threadIdx.x;
|
||||
unsigned int i = blockIdx.x*(blockSize*2) + threadIdx.x;
|
||||
unsigned int gridSize = blockSize*2*gridDim.x;
|
||||
AccType mySum = 0;
|
||||
static inline unsigned int nextPow2(unsigned int x) {
|
||||
--x;
|
||||
x |= x >> 1;
|
||||
x |= x >> 2;
|
||||
x |= x >> 4;
|
||||
x |= x >> 8;
|
||||
x |= x >> 16;
|
||||
return ++x;
|
||||
}
|
||||
|
||||
// we reduce multiple elements per thread. The number is determined by the
|
||||
// number of active thread blocks (via gridDim). More blocks will result
|
||||
// in a larger gridSize and therefore fewer elements per thread
|
||||
while (i < n)
|
||||
{
|
||||
mySum += f((AccType)g_idata[i]);
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Wrapper function for kernel launch
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T, typename AccType, size_t K, class Functor, class AggFunctor>
|
||||
void reduceSinglePass(Functor functor, AccType aggInit, AggFunctor aggFunctor, AccType scale,
|
||||
const functional::Shape full,
|
||||
functional::Tensor<AccType> out,
|
||||
functional::Array<functional::Tensor<T>, K> ins,
|
||||
int threads, int blocks) {
|
||||
int size = full.elements();
|
||||
// when there is only one warp per block, we need to allocate two warps
|
||||
// worth of shared memory so that we don't index shared memory out of bounds
|
||||
int smemSize = (threads <= 32) ? 2 * threads * sizeof(AccType) : threads * sizeof(AccType);
|
||||
dim3 dimBlock(threads, 1, 1);
|
||||
dim3 dimGrid(blocks, 1, 1);
|
||||
|
||||
// ensure we don't read out of bounds -- this is optimized away for powerOf2 sized arrays
|
||||
if (nIsPow2 || i + blockSize < n)
|
||||
mySum += f((AccType)g_idata[i+blockSize]);
|
||||
|
||||
i += gridSize;
|
||||
if (isPow2Greater1(size)) {
|
||||
switch (threads) {
|
||||
case 512:
|
||||
reduceSinglePass<T, AccType, 512, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 256:
|
||||
reduceSinglePass<T, AccType, 256, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 128:
|
||||
reduceSinglePass<T, AccType, 128, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 64:
|
||||
reduceSinglePass<T, AccType, 64, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 32:
|
||||
reduceSinglePass<T, AccType, 32, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 16:
|
||||
reduceSinglePass<T, AccType, 16, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 8:
|
||||
reduceSinglePass<T, AccType, 8, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 4:
|
||||
reduceSinglePass<T, AccType, 4, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 2:
|
||||
reduceSinglePass<T, AccType, 2, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 1:
|
||||
reduceSinglePass<T, AccType, 1, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
}
|
||||
|
||||
// do reduction in shared mem
|
||||
reduceBlock<blockSize>(sdata, mySum, tid);
|
||||
|
||||
// write result for this block to global mem
|
||||
if (tid == 0) g_odata[blockIdx.x] = sdata[0];
|
||||
}
|
||||
|
||||
// Global variable used by reduceSinglePass to count how many blocks have finished
|
||||
__device__ unsigned int retirementCount = 0;
|
||||
|
||||
cudaError_t setRetirementCount(int retCnt)
|
||||
{
|
||||
return cudaMemcpyToSymbol(retirementCount, &retCnt, sizeof(unsigned int), 0, cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
// This reduction kernel reduces an arbitrary size array in a single kernel invocation
|
||||
// It does so by keeping track of how many blocks have finished. After each thread
|
||||
// block completes the reduction of its own block of data, it "takes a ticket" by
|
||||
// atomically incrementing a global counter. If the ticket value is equal to the number
|
||||
// of thread blocks, then the block holding the ticket knows that it is the last block
|
||||
// to finish. This last block is responsible for summing the results of all the other
|
||||
// blocks.
|
||||
//
|
||||
// In order for this to work, we must be sure that before a block takes a ticket, all
|
||||
// of its memory transactions have completed. This is what __threadfence() does -- it
|
||||
// blocks until the results of all outstanding memory transactions within the
|
||||
// calling thread are visible to all other threads.
|
||||
//
|
||||
// For more details on the reduction algorithm (notably the multi-pass approach), see
|
||||
// the "reduction" sample in the CUDA SDK.
|
||||
|
||||
template <unsigned int blockSize, bool nIsPow2, typename T, typename AccType, class Functor>
|
||||
__global__ void reduceSinglePass(Functor f, T *g_idata, AccType *g_odata, unsigned int n)
|
||||
{
|
||||
|
||||
//
|
||||
// PHASE 1: Process all inputs assigned to this block
|
||||
//
|
||||
|
||||
reduceBlocks<blockSize, nIsPow2, T, AccType>(f, g_idata, g_odata, n);
|
||||
|
||||
//
|
||||
// PHASE 2: Last block finished will process all partial sums
|
||||
//
|
||||
|
||||
if (gridDim.x > 1)
|
||||
{
|
||||
const unsigned int tid = threadIdx.x;
|
||||
__shared__ bool amLast;
|
||||
extern AccType __shared__ smem[];
|
||||
|
||||
// wait until all outstanding memory instructions in this thread are finished
|
||||
__threadfence();
|
||||
|
||||
// Thread 0 takes a ticket
|
||||
if (tid==0)
|
||||
{
|
||||
unsigned int ticket = atomicInc(&retirementCount, gridDim.x);
|
||||
// If the ticket ID is equal to the number of blocks, we are the last block!
|
||||
amLast = (ticket == gridDim.x-1);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// The last block sums the results of all other blocks
|
||||
if (amLast)
|
||||
{
|
||||
int i = tid;
|
||||
AccType mySum = 0;
|
||||
|
||||
while (i < gridDim.x)
|
||||
{
|
||||
mySum += g_odata[i];
|
||||
i += blockSize;
|
||||
}
|
||||
|
||||
reduceBlock<blockSize>(smem, mySum, tid);
|
||||
|
||||
if (tid==0)
|
||||
{
|
||||
g_odata[0] = smem[0];
|
||||
|
||||
// reset retirement count so that next run succeeds
|
||||
retirementCount = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch (threads) {
|
||||
case 512:
|
||||
reduceSinglePass<T, AccType, 512, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 256:
|
||||
reduceSinglePass<T, AccType, 256, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 128:
|
||||
reduceSinglePass<T, AccType, 128, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 64:
|
||||
reduceSinglePass<T, AccType, 64, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 32:
|
||||
reduceSinglePass<T, AccType, 32, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 16:
|
||||
reduceSinglePass<T, AccType, 16, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 8:
|
||||
reduceSinglePass<T, AccType, 8, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 4:
|
||||
reduceSinglePass<T, AccType, 4, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 2:
|
||||
reduceSinglePass<T, AccType, 2, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
case 1:
|
||||
reduceSinglePass<T, AccType, 1, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool isPow2(unsigned int x)
|
||||
{
|
||||
return ((x&(x-1))==0);
|
||||
}
|
||||
|
||||
template <typename T, typename AccType, class Functor>
|
||||
void ReduceAll(Functor f, Tensor blockMem, Tensor in)
|
||||
{
|
||||
cudaSetDevice(in->getDeviceId().no);
|
||||
int size = in->shape().elements();
|
||||
int threads = std::min(MAX_THREADS, size);
|
||||
int blocks = std::min(MAX_BLOCKS, size / threads + (size % threads != 0));
|
||||
|
||||
dim3 dimBlock(threads, 1, 1);
|
||||
dim3 dimGrid(blocks, 1, 1);
|
||||
int smemSize = threads * sizeof(AccType);
|
||||
|
||||
T* d_idata = in->data<T>();
|
||||
AccType* d_odata = blockMem->data<AccType>();
|
||||
|
||||
// choose which of the optimized versions of reduction to launch
|
||||
if (isPow2(size))
|
||||
{
|
||||
switch (threads)
|
||||
{
|
||||
case 512:
|
||||
reduceSinglePass<512, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 256:
|
||||
reduceSinglePass<256, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 128:
|
||||
reduceSinglePass<128, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 64:
|
||||
reduceSinglePass< 64, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 32:
|
||||
reduceSinglePass< 32, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 16:
|
||||
reduceSinglePass< 16, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 8:
|
||||
reduceSinglePass< 8, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 4:
|
||||
reduceSinglePass< 4, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 2:
|
||||
reduceSinglePass< 2, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 1:
|
||||
reduceSinglePass< 1, true, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
switch (threads)
|
||||
{
|
||||
case 512:
|
||||
reduceSinglePass<512, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 256:
|
||||
reduceSinglePass<256, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 128:
|
||||
reduceSinglePass<128, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 64:
|
||||
reduceSinglePass< 64, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 32:
|
||||
reduceSinglePass< 32, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 16:
|
||||
reduceSinglePass< 16, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 8:
|
||||
reduceSinglePass< 8, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 4:
|
||||
reduceSinglePass< 4, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 2:
|
||||
reduceSinglePass< 2, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 1:
|
||||
reduceSinglePass< 1, false, T, AccType><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
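A note on the dispatch idiom used throughout the file above (a hedged sketch, not code from this commit; step and launch are made-up names): the block size is a template parameter, so every if(blockSize >= ...) level of the reduction tree is a compile-time branch, and the host-side switch over the runtime thread count simply selects the matching instantiation.

template <unsigned int blockSize>
__global__ void step(float* data) {
  // blockSize is a compile-time constant here, so dead branches are pruned away
  if(blockSize >= 64) { /* ... */ }
}

void launch(int threads, float* data) {
  switch(threads) { // runtime thread count -> compile-time template argument
    case 128: step<128><<<1, 128>>>(data); break;
    case  64: step<64><<<1, 64>>>(data); break;
    default:  step<32><<<1, 32>>>(data); break;
  }
}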
|
@ -9,136 +9,155 @@
|
||||
namespace marian {
|
||||
namespace functional {
|
||||
|
||||
template <size_t K, class Functor>
|
||||
// This struct and its specializations are never used directly, only through apply and applyWithCast below.
|
||||
template <size_t K, class Functor, typename AccType> // K-ary application of Functor, elements are cast to AccType before application of Functor
|
||||
struct FApply {};
|
||||
|
||||
template <class Functor>
|
||||
struct FApply<1, Functor> {
|
||||
template <class Functor, typename AccType>
|
||||
struct FApply<1, Functor, AccType> {
|
||||
template <typename ElementType>
|
||||
HOST_DEVICE_INLINE static ElementType apply(
|
||||
HOST_DEVICE_INLINE static AccType apply(
|
||||
Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, 1>& in,
|
||||
const functional::Array<int, 1>& indices) {
|
||||
return functor(in[0].data()[indices[0]]);
|
||||
return functor((AccType)in[0].data()[indices[0]]); // indices is an array of offsets into multiple tensors; indices[i] corresponds to in[i], for up to arity K
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
HOST_DEVICE_INLINE static ElementType apply(
|
||||
HOST_DEVICE_INLINE static AccType apply(
|
||||
Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, 1>& in,
|
||||
int index) {
|
||||
return functor(in[0].data()[index]);
|
||||
return functor((AccType)in[0].data()[index]);
|
||||
}
|
||||
};
|
||||
|
||||
template <class Functor>
|
||||
struct FApply<2, Functor> {
|
||||
template <class Functor, typename AccType>
|
||||
struct FApply<2, Functor, AccType> {
|
||||
template <typename ElementType>
|
||||
HOST_DEVICE_INLINE static ElementType apply(
|
||||
HOST_DEVICE_INLINE static AccType apply(
|
||||
Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, 2>& in,
|
||||
const functional::Array<int, 2>& indices) {
|
||||
return functor(in[0].data()[indices[0]],
|
||||
in[1].data()[indices[1]]);
|
||||
return functor((AccType)in[0].data()[indices[0]],
|
||||
(AccType)in[1].data()[indices[1]]);
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
HOST_DEVICE_INLINE static ElementType apply(
|
||||
HOST_DEVICE_INLINE static AccType apply(
|
||||
Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, 2>& in,
|
||||
int index) {
|
||||
return functor(in[0].data()[index],
|
||||
in[1].data()[index]);
|
||||
return functor((AccType)in[0].data()[index],
|
||||
(AccType)in[1].data()[index]);
|
||||
}
|
||||
};
|
||||
|
||||
template <class Functor>
|
||||
struct FApply<3, Functor> {
|
||||
template <class Functor, typename AccType>
|
||||
struct FApply<3, Functor, AccType> {
|
||||
template <typename ElementType>
|
||||
HOST_DEVICE_INLINE static ElementType apply(
|
||||
HOST_DEVICE_INLINE static AccType apply(
|
||||
Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, 3>& in,
|
||||
const functional::Array<int, 3>& indices) {
|
||||
return functor(in[0].data()[indices[0]],
|
||||
in[1].data()[indices[1]],
|
||||
in[2].data()[indices[2]]);
|
||||
return functor((AccType)in[0].data()[indices[0]],
|
||||
(AccType)in[1].data()[indices[1]],
|
||||
(AccType)in[2].data()[indices[2]]);
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
HOST_DEVICE_INLINE static ElementType apply(
|
||||
HOST_DEVICE_INLINE static AccType apply(
|
||||
Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, 3>& in,
|
||||
int index) {
|
||||
return functor(in[0].data()[index],
|
||||
in[1].data()[index],
|
||||
in[2].data()[index]);
|
||||
return functor((AccType)in[0].data()[index],
|
||||
(AccType)in[1].data()[index],
|
||||
(AccType)in[2].data()[index]);
|
||||
}
|
||||
};
|
||||
|
||||
template <class Functor>
|
||||
struct FApply<4, Functor> {
|
||||
template <class Functor, typename AccType>
|
||||
struct FApply<4, Functor, AccType> {
|
||||
template <typename ElementType>
|
||||
HOST_DEVICE_INLINE static ElementType apply(
|
||||
HOST_DEVICE_INLINE static AccType apply(
|
||||
Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, 4>& in,
|
||||
const functional::Array<int, 4>& indices) {
|
||||
return functor(in[0].data()[indices[0]],
|
||||
in[1].data()[indices[1]],
|
||||
in[2].data()[indices[2]],
|
||||
in[3].data()[indices[3]]);
|
||||
return functor((AccType)in[0].data()[indices[0]],
|
||||
(AccType)in[1].data()[indices[1]],
|
||||
(AccType)in[2].data()[indices[2]],
|
||||
(AccType)in[3].data()[indices[3]]);
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
HOST_DEVICE_INLINE static ElementType apply(
|
||||
HOST_DEVICE_INLINE static AccType apply(
|
||||
Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, 4>& in,
|
||||
int index) {
|
||||
return functor(in[0].data()[index],
|
||||
in[1].data()[index],
|
||||
in[2].data()[index],
|
||||
in[3].data()[index]);
|
||||
return functor((AccType)in[0].data()[index],
|
||||
(AccType)in[1].data()[index],
|
||||
(AccType)in[2].data()[index],
|
||||
(AccType)in[3].data()[index]);
|
||||
}
|
||||
};
|
||||
|
||||
template <class Functor>
|
||||
struct FApply<5, Functor> {
|
||||
template <class Functor, typename AccType>
|
||||
struct FApply<5, Functor, AccType> {
|
||||
template <typename ElementType>
|
||||
HOST_DEVICE_INLINE static ElementType apply(
|
||||
HOST_DEVICE_INLINE static AccType apply(
|
||||
Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, 5>& in,
|
||||
const functional::Array<int, 5>& indices) {
|
||||
return functor(in[0].data()[indices[0]],
|
||||
in[1].data()[indices[1]],
|
||||
in[2].data()[indices[2]],
|
||||
in[3].data()[indices[3]],
|
||||
in[4].data()[indices[4]]);
|
||||
return functor((AccType)in[0].data()[indices[0]],
|
||||
(AccType)in[1].data()[indices[1]],
|
||||
(AccType)in[2].data()[indices[2]],
|
||||
(AccType)in[3].data()[indices[3]],
|
||||
(AccType)in[4].data()[indices[4]]);
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
HOST_DEVICE_INLINE static ElementType apply(
|
||||
HOST_DEVICE_INLINE static AccType apply(
|
||||
Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, 5>& in,
|
||||
int index) {
|
||||
return functor(in[0].data()[index],
|
||||
in[1].data()[index],
|
||||
in[2].data()[index],
|
||||
in[3].data()[index],
|
||||
in[4].data()[index]);
|
||||
return functor((AccType)in[0].data()[index],
|
||||
(AccType)in[1].data()[index],
|
||||
(AccType)in[2].data()[index],
|
||||
(AccType)in[3].data()[index],
|
||||
(AccType)in[4].data()[index]);
|
||||
}
|
||||
};
|
||||
|
||||
template <size_t K, class Functor, typename ElementType>
|
||||
/******************************************************************************/
|
||||
// Applying functor to sets of K tensors
|
||||
template <typename ElementType, size_t K, class Functor>
|
||||
HOST_DEVICE_INLINE ElementType apply(Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, K>& in,
|
||||
const functional::Array<int, K>& indices) {
|
||||
return FApply<K, Functor>::apply(functor, in, indices);
|
||||
return FApply<K, Functor, ElementType>::apply(functor, in, indices); // functor is applied to same type as input ElementType, no casting required
|
||||
}
|
||||
|
||||
template <size_t K, class Functor, typename ElementType>
|
||||
template <typename ElementType, size_t K, class Functor>
|
||||
HOST_DEVICE_INLINE ElementType apply(Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, K>& in,
|
||||
int index) {
|
||||
return FApply<K, Functor>::apply(functor, in, index);
|
||||
return FApply<K, Functor, ElementType>::apply(functor, in, index); // functor is applied to same type as input ElementType, no casting required
|
||||
}
|
||||
|
||||
template <typename AccType, typename ElementType, size_t K, class Functor>
|
||||
HOST_DEVICE_INLINE AccType applyWithCast(Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, K>& in,
|
||||
const functional::Array<int, K>& indices) {
|
||||
return FApply<K, Functor, AccType>::apply(functor, in, indices); // ElementType and AccType are potentially different, cast to AccType before applying functor.
|
||||
// This is useful when accumulating e.g. 16-bit into 32-bit and we want to cast to 32-bit before
|
||||
// the functor is applied. L2-Norm is a good use-case since the square can be large.
|
||||
}
|
||||
|
||||
template <typename AccType, typename ElementType, size_t K, class Functor>
|
||||
HOST_DEVICE_INLINE AccType applyWithCast(Functor functor,
|
||||
functional::Array<functional::Tensor<ElementType>, K>& in,
|
||||
int index) {
|
||||
return FApply<K, Functor, AccType>::apply(functor, in, index); // ElementType and AccType are potentially different, cast to AccType before applying functor
|
||||
}
|
||||
|
||||
/******************************************************************************/
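To make the casting comment above concrete, here is a hedged illustration (not part of this commit; squareDemo is a made-up name, and half arithmetic needs compute capability 5.3 or higher): squaring a half directly overflows once the product exceeds the half maximum of roughly 65504, while casting the operands to float first, which is what applyWithCast does, keeps the square exact.

#include <cuda_fp16.h>

__global__ void squareDemo(float* out) {
  __half x = __float2half(300.f);
  out[0] = __half2float(__hmul(x, x));        // half * half: 90000 overflows to inf
  out[1] = __half2float(x) * __half2float(x); // cast first, then square: 90000.f
}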
|
||||
@ -180,7 +199,7 @@ struct Loop<1, N, K> {
|
||||
for(size_t j = 0; j < K; ++j) {
|
||||
acc[j] = pAcc[j] + (dim[N - 1] + i) * in[j].shape().bstride(N - 1);
|
||||
}
|
||||
agg = aggFunctor(agg, (AccType)apply<K>(functor, in, acc));
|
||||
agg = aggFunctor(agg, applyWithCast<AccType>(functor, in, acc));
|
||||
}
|
||||
return agg;
|
||||
}
|
||||
|
@ -355,35 +355,51 @@ Expr slice(Expr a, int axis, Slice slice) { // numpy __getslice__ semantics, but
|
||||
}
|
||||
|
||||
Expr sum(Expr a, int ax) {
|
||||
if(a->shape()[ax] == 1) // nothing to reduce, sum of itself is a
|
||||
return a;
|
||||
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::sum);
|
||||
}
|
||||
|
||||
Expr mean(Expr a, int ax) {
|
||||
if(a->shape()[ax] == 1) // nothing to reduce, mean of itself is a
|
||||
return a;
|
||||
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::mean);
|
||||
}
|
||||
|
||||
Expr std(Expr a, int ax) {
|
||||
return Expression<ReduceNodeOp>(a - mean(a,ax), ax, ReduceNodeOpCode::rms);
|
||||
if(a->shape()[ax] == 1) // nothing to reduce, std(a) = 0
|
||||
return a - a;
|
||||
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::rms);
|
||||
}
|
||||
|
||||
Expr var(Expr a, int ax) {
|
||||
Expr var(Expr a, int ax) {
|
||||
if(a->shape()[ax] == 1) // nothing to reduce, var(a) = 0
|
||||
return a - a;
|
||||
return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::meanSqr);
|
||||
}
|
||||
|
||||
Expr max(Expr a, int ax) {
|
||||
if(a->shape()[ax] == 1) // nothing to reduce, max of itself is a
|
||||
return a;
|
||||
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::max);
|
||||
}
|
||||
|
||||
Expr min(Expr a, int ax) {
|
||||
if(a->shape()[ax] == 1) // nothing to reduce, min of itself is a
|
||||
return a;
|
||||
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::min);
|
||||
}
|
||||
|
||||
Expr prod(Expr a, int ax) {
|
||||
if(a->shape()[ax] == 1) // nothing to reduce, prod of itself is a
|
||||
return a;
|
||||
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::prod);
|
||||
}
|
||||
|
||||
// log(sum(exp(a)))
|
||||
Expr logsumexp(Expr a, int ax) {
|
||||
if(a->shape()[ax] == 1) // nothing to reduce, log(sum(exp(a))) = log(exp(a)) = a
|
||||
return a;
|
||||
return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::logSumExp);
|
||||
}
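In equation form (a hedged restatement using the standard definitions, not anything introduced by this commit), the two reductions used for std and var compute

\operatorname{std}(a) = \sqrt{\tfrac{1}{N}\sum_i (a_i - \bar a)^2}, \qquad
\operatorname{var}(a) = \tfrac{1}{N}\sum_i (a_i - \bar a)^2, \qquad
\bar a = \tfrac{1}{N}\sum_i a_i,

which is why std reduces a - mean(a, ax) with the rms op and var reduces it with meanSqr.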
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include "tensors/gpu/add.h"
|
||||
#include "tensors/gpu/add_all.h"
|
||||
|
||||
#include "tensors/gpu/cuda_helpers.h"
|
||||
|
||||
@ -12,11 +13,13 @@ namespace marian {
|
||||
namespace gpu {
|
||||
|
||||
template <size_t K, class Functor, class AggFunctor, typename T, typename AccType>
|
||||
__global__ void gAggregateGeneric(Functor functor, AccType aggInit, AggFunctor aggFunctor,
|
||||
const functional::Shape full,
|
||||
functional::Tensor<T> out,
|
||||
functional::Array<functional::Tensor<T>, K> ins,
|
||||
AccType scale = 1.0) {
|
||||
__global__ void gAggregateGeneric(Functor functor, // functor applied to single corresponding elements in tensors (via broadcasting),
|
||||
AccType aggInit, // aggInit is starting value of accumulation (e.g. 0 for sum),
|
||||
AggFunctor aggFunctor, // aggFunctor is used to accumulate values (e.g. sum),
|
||||
const functional::Shape full, // maximal combined shape of all tensors via broadcasting
|
||||
functional::Tensor<T> out, // output tensor
|
||||
functional::Array<functional::Tensor<T>, K> ins, // input tensors
|
||||
AccType scale = 1.0) { // scale accumulation result by scale. e.g. used for computing mean from sum over N elements with scale 1/N
|
||||
int outLength = out.shape().elements();
|
||||
bool same = outLength == full.elements();
|
||||
for(int i = 0; i < K; ++i)
|
||||
@ -32,10 +35,10 @@ __global__ void gAggregateGeneric(Functor functor, AccType aggInit, AggFunctor a
|
||||
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if(index < outLength) {
|
||||
if(same) {
|
||||
out[index] = aggFunctor(out[index], functional::apply(functor, ins, index) * (T)scale);
|
||||
out[index] = (T)aggFunctor((AccType)out[index], functional::applyWithCast<AccType>(functor, ins, index) * scale); // apply functor with arguments cast to AccType
|
||||
} else {
|
||||
out.shape().dims(index, dims);
|
||||
out[index] = aggFunctor(out[index], (T)(functional::loops(functor, aggInit, aggFunctor, ins, len, dims) * scale));
|
||||
out[index] = (T)aggFunctor((AccType)out[index], functional::loops(functor, aggInit, aggFunctor, ins, len, dims) * scale); // apply functor with arguments cast to AccType
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -62,7 +65,7 @@ __global__ void gAggregateEqual(Functor functor, AggFunctor aggFunctor,
|
||||
indices[i] = ins[i].shape().bindex(dims);
|
||||
}
|
||||
|
||||
out[index] = aggFunctor(out[index], functional::apply(functor, ins, indices) * (T)scale);
|
||||
out[index] = (T)aggFunctor((AccType)out[index], functional::applyWithCast<AccType>(functor, ins, indices) * scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -76,7 +79,7 @@ __global__ void gAggregateReduce(Functor functor, AccType aggInit, AggFunctor ag
|
||||
int rows = full.elements() / full.back();
|
||||
int cols = full.back();
|
||||
|
||||
bool same = true;
|
||||
bool same = true; // do all inputs have the same number of elements?
|
||||
for(int i = 0; i < K; ++i)
|
||||
same = same && ins[i].shape().elements() == full.elements();
|
||||
|
||||
@ -93,7 +96,7 @@ __global__ void gAggregateReduce(Functor functor, AccType aggInit, AggFunctor ag
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int id = tid + threadIdx.x;
|
||||
if(id < cols)
|
||||
_sum[threadIdx.x] = aggFunctor(_sum[threadIdx.x], (AccType)functional::apply(functor, ins, j * cols + id));
|
||||
_sum[threadIdx.x] = aggFunctor(_sum[threadIdx.x], functional::applyWithCast<AccType>(functor, ins, j * cols + id)); // casts to AccType before applying functor which then performs operation in AccType
|
||||
}
|
||||
} else {
|
||||
functional::Array<int, functional::Shape::size()> dims;
|
||||
@ -106,7 +109,7 @@ __global__ void gAggregateReduce(Functor functor, AccType aggInit, AggFunctor ag
|
||||
functional::Array<int, K> indices;
|
||||
for(int i = 0; i < K; ++i)
|
||||
indices[i] = ins[i].shape().bindex(dims);
|
||||
_sum[threadIdx.x] = aggFunctor(_sum[threadIdx.x], (AccType)functional::apply(functor, ins, indices));
|
||||
_sum[threadIdx.x] = aggFunctor(_sum[threadIdx.x], functional::applyWithCast<AccType>(functor, ins, indices));// casts to AccType before applying functor which then performs operation in AccType
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -121,7 +124,8 @@ __global__ void gAggregateReduce(Functor functor, AccType aggInit, AggFunctor ag
|
||||
len = (len + 1) >> 1;
|
||||
}
|
||||
__syncthreads();
|
||||
out[j] = aggFunctor(out[j], (T)(_sum[0] * scale));
|
||||
if(threadIdx.x == 0) // only write the value from thread 0 of the block
|
||||
out[j] = aggFunctor(out[j], (T)(_sum[0] * scale));
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
@ -140,16 +144,16 @@ void AggregateTyped(Functor functor, AccType aggInit, AggFunctor aggFunctor, Acc
|
||||
functional::Tensor<T> gOut = out;
|
||||
functional::Array<functional::Tensor<T>, K> gIns = {tensors...};
|
||||
|
||||
if(full.back() != 1 && out->shape().back() == 1) {
|
||||
size_t m = full.elements() / length;
|
||||
size_t k = full.back();
|
||||
if(out->shape().elements() == 1) { // reduce everything into a single element
|
||||
AggregateAll<T, AccType>(nullptr, functor, aggInit, aggFunctor, scale, out, tensors...); // @TODO: pass allocator in here, currently uses cudaMalloc
|
||||
} else if(full.back() != 1 && out->shape().back() == 1 && full.elements() / full.back() == length) { // the number of elements in out must match that of the full shape over the non-reduced axes
|
||||
size_t m = full.elements() / full.back(); // how many rows are we iterating over?
|
||||
size_t k = full.back(); // how many columns are being reduced to 1 in each row?
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, (int)m);
|
||||
int blocks = std::min(MAX_BLOCKS, (int)m);
|
||||
int threads = std::min(MAX_THREADS, (int)k);
|
||||
int shared = sizeof(AccType) * threads;
|
||||
|
||||
int shared = sizeof(AccType) * threads;
|
||||
gAggregateReduce<K, Functor, AggFunctor, T, AccType><<<blocks, threads, shared>>>(functor, aggInit, aggFunctor, full, gOut, gIns, scale);
|
||||
|
||||
} else if(out->shape() == full) {
|
||||
int threads = std::min(MAX_THREADS, length);
|
||||
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
|
||||
|
128 src/tensors/gpu/add_all.h Normal file
@ -0,0 +1,128 @@
|
||||
#pragma once
|
||||
|
||||
// This header file provides wrappers around NVidia's reduce_all kernel with our custom aggregation functionality
|
||||
// This kernel reduces a tensor into a single value. We have modified it to allow for different types of aggregations
|
||||
// like summing or max etc.
|
||||
|
||||
#include "tensors/gpu/cuda_helpers.h"
|
||||
#include "tensors/tensor.h"
|
||||
#include "tensors/allocator.h"
|
||||
#include "functional/functional.h"
|
||||
#include "functional/tensor.h"
|
||||
#include "tensors/tensor_operators.h"
|
||||
#include "3rd_party/reduce_all.h" // only works with CUDA >9.0, we are dropping CUDA 8.0 support, also changed in CMakeLists.txt
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
|
||||
namespace marian {
|
||||
|
||||
#if COMPILE_FP16
|
||||
// local overload to determine tensor type
|
||||
template <> inline Type typeId<half>() { return Type::float16; }
|
||||
#endif
|
||||
|
||||
template <typename T, typename AccType, class Functor, class AggFunctor, class... Tensors>
|
||||
void AggregateAll(Ptr<Allocator> allocator,
|
||||
Functor functor,
|
||||
AccType aggInit,
|
||||
AggFunctor aggFunctor,
|
||||
AccType scale,
|
||||
Tensor out,
|
||||
const Tensors... tensors) {
|
||||
cudaSetDevice(out->getDeviceId().no);
|
||||
|
||||
static_assert(CUDA_VERSION >= 9000, "Marian requires CUDA_VERSION >= 9000 (9.0)");
|
||||
|
||||
constexpr size_t K = sizeof...(Tensors); // obtain arity K of tensors...
|
||||
functional::Array<functional::Tensor<T>, K> gIns = {tensors...}; // convert to array of K objects of type functional::Tensor<T>
|
||||
functional::Shape full = marian::Shape::broadcast({tensors...}); // compute maximal broadcasted shape
|
||||
|
||||
int size = full.elements();
|
||||
int threads = (size < MAX_THREADS * 2) ? nextPow2((size + 1) / 2) : MAX_THREADS; // suggested in NVidia example for the all_reduce kernel
|
||||
int blocks = std::min(MAX_BLOCKS, (size + (threads * 2 - 1)) / (threads * 2)); // suggested in NVidia example for the all_reduce kernel
|
||||
|
||||
// The all_reduce kernel by NVidia needs to perform multiple passes if the number of blocks needed to perform the reduction is larger than 1.
|
||||
// Here we allocate the memory for the intermediate reductions for each block.
|
||||
Tensor blockMem;
|
||||
if(blocks > 1 || out->type() != typeId<AccType>()) { // if the out tensor does not have elementType AccType we need to allocate and convert later
|
||||
MemoryPiece::PtrType temporaryMemory;
|
||||
if(allocator) {
|
||||
temporaryMemory = allocator->alloc<AccType>(blocks);
|
||||
} else { // @TODO: get rid of this branch
|
||||
uint8_t* temporaryMemoryPtr = 0;
|
||||
CUDA_CHECK(cudaMalloc(&temporaryMemoryPtr, sizeof(AccType) * blocks));
|
||||
temporaryMemory = MemoryPiece::New(temporaryMemoryPtr, sizeof(AccType) * blocks); // @TODO: consider implementing MemoryPiece::cudaMalloc<T>(size) for managed memory
|
||||
}
|
||||
blockMem = TensorBase::New(temporaryMemory,
|
||||
Shape({blocks}),
|
||||
typeId<AccType>(),
|
||||
out->getBackend());
|
||||
blockMem->set(aggInit); // set temporary memory to aggInit
|
||||
}
|
||||
else { // we are reducing into a single element now and the type matches, just use out as memory
|
||||
blockMem = out; // do not set final output memory as we might be summing gradients... needs to be handled outside this function
|
||||
}
|
||||
|
||||
functional::Tensor<AccType> gBlockMem = blockMem;
|
||||
reduceSinglePass<T, AccType>(functor, aggInit, aggFunctor, scale, full, /*out=*/gBlockMem, /*in=*/gIns, threads, blocks); // First pass reduction into intermediate memory
|
||||
|
||||
// If we actually needed more than one block to perform the first pass reduction, recursively run a second pass reduction over block memory until block memory has size 1.
|
||||
if(blocks > 1) {
|
||||
using namespace functional;
|
||||
auto identity = _1; // transformation was done in first pass, hence only identity
|
||||
AggregateAll<AccType, AccType>(allocator, identity, aggInit, aggFunctor, scale, out, /*in=*/blockMem); // Reducing AccType in AccType now (meta-reduction)
|
||||
} else if(out->type() != typeId<AccType>()) { // it's only a single block, but we need to convert to different type, as mentioned above
|
||||
CopyCast(out, blockMem);
|
||||
}
|
||||
|
||||
if(blockMem != out) {
|
||||
// Free temporary memory whether allocated in allocator or via cudaMalloc
|
||||
if(allocator)
|
||||
allocator->free(blockMem->memory());
|
||||
else if(blockMem->memory()->data())
|
||||
CUDA_CHECK(cudaFree(blockMem->memory()->data()));
|
||||
}
|
||||
}
|
||||
|
||||
// Aggregates all values into a single tensor and returns the value of that tensor as a float
|
||||
// This does a GPU to CPU memory copy via TensorBase::scalar().
|
||||
// Used currently only for L2Norm computation
|
||||
template <typename T, typename AccType, class Functor, class AggFunctor, class... Tensors>
|
||||
AccType AggregateAllAndReturn(Ptr<Allocator> allocator,
|
||||
Functor functor,
|
||||
AccType aggInit,
|
||||
AggFunctor aggFunctor,
|
||||
AccType scale,
|
||||
const Tensors... tensors) {
|
||||
MemoryPiece::PtrType temporaryMemory;
|
||||
if(allocator) {
|
||||
temporaryMemory = allocator->alloc<AccType>(1);
|
||||
} else { // @TODO: get rid of this branch
|
||||
uint8_t* temporaryMemoryPtr = 0;
|
||||
CUDA_CHECK(cudaMalloc(&temporaryMemoryPtr, sizeof(AccType)));
|
||||
temporaryMemory = MemoryPiece::New(temporaryMemoryPtr, sizeof(AccType));
|
||||
}
|
||||
|
||||
std::tuple<Tensors...> in(tensors...);
|
||||
|
||||
// Create a temporary tensor of size 1 to reduce into
|
||||
auto out = TensorBase::New(temporaryMemory,
|
||||
Shape({1}),
|
||||
typeId<AccType>(),
|
||||
std::get<0>(in)->getBackend());
|
||||
out->set(aggInit); // init to aggInit
|
||||
|
||||
AggregateAll<T, AccType>(allocator, functor, aggInit, aggFunctor, scale, out, tensors...);
|
||||
|
||||
AccType outScalar = out->template scalar<AccType>(); // convert to float also if other underlying type
|
||||
|
||||
if(allocator)
|
||||
allocator->free(out->memory());
|
||||
else if(out->memory()->data()) // @TODO: get rid of this branch
|
||||
CUDA_CHECK(cudaFree(out->memory()->data()));
|
||||
|
||||
return outScalar;
|
||||
}
|
||||
|
||||
}
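A hedged usage sketch of the wrapper above (sumOfSquares is a made-up helper name; the call mirrors the L2Norm change further down in this PR): reduce a whole tensor to one scalar with an element functor and an aggregation functor.

using namespace marian;
using namespace functional;

// Sum of squares of all elements, accumulated in float regardless of the tensor's element type.
static float sumOfSquares(Tensor in, Ptr<Allocator> allocator) {
  return AggregateAllAndReturn</*ElementType=*/float, /*AccType=*/float>(
      allocator, /*functor=*/_1 * _1, /*aggInit=*/0.f, /*aggFunctor=*/_1 + _2, /*scale=*/1.f, in);
}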
|
@ -1,5 +1,3 @@
|
||||
//#include <thrust/transform_reduce.h>
|
||||
|
||||
#include "common/types.h"
|
||||
#include "tensors/tensor_operators.h"
|
||||
|
||||
@ -9,7 +7,7 @@
|
||||
#include "tensors/gpu/backend.h"
|
||||
#include "tensors/gpu/cuda_helpers.h"
|
||||
|
||||
#include "3rd_party/reduce_all.h"
|
||||
#include "tensors/gpu/add_all.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
@ -588,6 +586,8 @@ __global__ void gSoftmax(T* out,
|
||||
|
||||
// determine max (used below to improve numeric stability)
|
||||
T* _max = _share;
|
||||
|
||||
// @TODO: what's going on here with fp16?
|
||||
_max[threadIdx.x] = -CUDA_FLT_MAX; // mask
|
||||
// find max over column indices that have the same relative column index (=threadIdx.x) across all blocks of columns
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
@ -1661,53 +1661,27 @@ void CrossEntropyPickBackward(Tensor out, Tensor adj, Tensor a, Tensor indices)
|
||||
}
|
||||
}
|
||||
|
||||
float L2Norm(Tensor in, Ptr<Allocator> allocator) {
|
||||
// computes the L2Norm of a tensor and returns the value as float on the CPU,
|
||||
// this is mostly used for diagnostic purposes and gradient clipping
|
||||
float L2Norm(Tensor in, Ptr<Allocator> allocator) { // @TODO: reverse order of arguments
|
||||
cudaSetDevice(in->getDeviceId().no);
|
||||
|
||||
int size = in->shape().elements();
|
||||
int threads = std::min(MAX_THREADS, size);
|
||||
int blocks = std::min(MAX_BLOCKS, size / threads + (size % threads != 0));
|
||||
int blocks = std::min(MAX_BLOCKS, size / threads + (size % threads != 0));
|
||||
|
||||
if(allocator) {
|
||||
auto memoryPiece = allocator->alloc<float>(blocks);
|
||||
auto blockMem = TensorBase::New(memoryPiece, Shape({1, blocks}), Type::float32, in->getBackend());
|
||||
|
||||
using namespace functional;
|
||||
if(in->type() == Type::float32) {
|
||||
ReduceAll<float, float>(_1 * _1, blockMem, in);
|
||||
using namespace functional;
|
||||
float l2Norm;
|
||||
if(in->type() == Type::float32) {
|
||||
l2Norm = std::sqrt(AggregateAllAndReturn</*ElementType=*/float, /*AccType=*/float>(allocator, /*functor=*/_1 * _1, /*aggInit=*/0.f, /*aggFunctor=*/_1 + _2, /*scale=*/1.f, in));
|
||||
#if COMPILE_FP16
|
||||
} else if(in->type() == Type::float16) {
|
||||
ReduceAll<half, float>(_1 * _1, blockMem, in);
|
||||
} else if(in->type() == Type::float16) {
|
||||
l2Norm = std::sqrt(AggregateAllAndReturn</*ElementType=*/half, /*AccType=*/float>(allocator, /*functor=*/_1 * _1, /*aggInit=*/0.f, /*aggFunctor=*/_1 + _2, /*scale=*/1.f, in));
|
||||
#endif
|
||||
} else {
|
||||
ABORT("L2Norm not implemented for type {}", in->type());
|
||||
}
|
||||
float dataCpu = sqrtf(blockMem->get<float>(0));
|
||||
allocator->free(memoryPiece);
|
||||
return dataCpu;
|
||||
} else { // @TODO: this branch is to be removed with next PR, old version
|
||||
uint8_t* data;
|
||||
cudaMalloc(&data, blocks * sizeof(float));
|
||||
Tensor out(TensorBase::New(MemoryPiece::New(data, blocks * sizeof(float)),
|
||||
Shape({1, blocks}),
|
||||
Type::float32,
|
||||
in->getBackend()));
|
||||
|
||||
using namespace functional;
|
||||
if(in->type() == Type::float32) {
|
||||
ReduceAll<float, float>(_1 * _1, out, in);
|
||||
#if COMPILE_FP16
|
||||
} else if(in->type() == Type::float16) {
|
||||
ReduceAll<half, float>(_1 * _1, out, in);
|
||||
#endif
|
||||
} else {
|
||||
ABORT("L2Norm not implemented for type {}", in->type());
|
||||
}
|
||||
float dataCpu = sqrtf(out->get<float>(0));
|
||||
out.reset();
|
||||
cudaFree(data);
|
||||
return dataCpu;
|
||||
} else {
|
||||
ABORT("L2Norm not implemented for type {}", in->type());
|
||||
}
|
||||
return l2Norm;
|
||||
}
|
||||
|
||||
template <typename T, typename AccType = float>
|
||||
@ -1761,22 +1735,22 @@ __global__ void gAtt(T* out,
|
||||
void Att(Tensor out, Tensor va, Tensor context, Tensor state) {
|
||||
cudaSetDevice(out->getDeviceId().no);
|
||||
|
||||
size_t m = out->shape().elements() / out->shape().back();
|
||||
size_t k = context->shape()[-1];
|
||||
size_t b = context->shape()[-2];
|
||||
size_t t = context->shape()[-3];
|
||||
size_t totalRows = out->shape().elements() / out->shape().back(); // number of rows
|
||||
size_t modelDim = context->shape()[-1]; // number of cols
|
||||
size_t batchDim = context->shape()[-2];
|
||||
size_t contextWordsDim = context->shape()[-3];
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, (int)m);
|
||||
int threads = std::min(MAX_THREADS, (int)k);
|
||||
int blocks = std::min(MAX_BLOCKS, (int)totalRows);
|
||||
int threads = std::min(MAX_THREADS, (int)modelDim);
|
||||
int shared = sizeof(float) * threads;
|
||||
|
||||
if(out->type() == Type::float32) {
|
||||
gAtt<float, float><<<blocks, threads, shared>>>(
|
||||
out->data<float>(), va->data<float>(), context->data<float>(), state->data<float>(), m, k, b, t);
|
||||
out->data<float>(), va->data<float>(), context->data<float>(), state->data<float>(), totalRows, modelDim, batchDim, contextWordsDim);
|
||||
#if COMPILE_FP16
|
||||
} else if (out->type() == Type::float16) {
|
||||
gAtt<half, float><<<blocks, threads, shared>>>(
|
||||
out->data<half>(), va->data<half>(), context->data<half>(), state->data<half>(), m, k, b, t);
|
||||
out->data<half>(), va->data<half>(), context->data<half>(), state->data<half>(), totalRows, modelDim, batchDim, contextWordsDim);
|
||||
#endif
|
||||
} else {
|
||||
ABORT("gAtt not implemented for type {}", out->type());
|
||||
@ -2005,10 +1979,10 @@ __global__ void gLayerNormalizationGrad(T* gradX,
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
AccType* sum_adj = shared;
|
||||
AccType* sum_adj_x = shared + blockDim.x;
|
||||
AccType* sum_x = shared + 2 * blockDim.x;
|
||||
AccType* sum_sqr = shared + 3 * blockDim.x;
|
||||
AccType* sum_adj = shared; // sum of gradient coming in
|
||||
AccType* sum_adj_l = shared + blockDim.x; // sum of gradient coming in times layerNorm from value
|
||||
AccType* sum_x = shared + 2 * blockDim.x; // sum of input value x
|
||||
AccType* sum_sqr = shared + 3 * blockDim.x; // sum of (x - mean)^2
|
||||
|
||||
const T* xRow = x + j * cols;
|
||||
const T* yRow = y + j * cols;
|
||||
@ -2016,7 +1990,7 @@ __global__ void gLayerNormalizationGrad(T* gradX,
|
||||
|
||||
sum_x[threadIdx.x] = (AccType)0.0f;
|
||||
sum_adj[threadIdx.x] = (AccType)0.0f;
|
||||
sum_adj_x[threadIdx.x] = (AccType)0.0f;
|
||||
sum_adj_l[threadIdx.x] = (AccType)0.0f;
|
||||
sum_sqr[threadIdx.x] = (AccType)0.0f;
|
||||
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
@ -2030,7 +2004,7 @@ __global__ void gLayerNormalizationGrad(T* gradX,
|
||||
AccType lv = (yv - betav) / (gammav + eps); // go back to LN(x) from scaled and shifted version for accumulation
|
||||
|
||||
sum_x[threadIdx.x] += xv;
|
||||
sum_adj_x[threadIdx.x] += adjv * lv;
|
||||
sum_adj_l[threadIdx.x] += adjv * lv;
|
||||
sum_adj[threadIdx.x] += adjv;
|
||||
}
|
||||
}
|
||||
@ -2042,7 +2016,7 @@ __global__ void gLayerNormalizationGrad(T* gradX,
|
||||
if(threadIdx.x < (len >> 1)) {
|
||||
sum_x[threadIdx.x] += sum_x[threadIdx.x + skip]; // Accumulates in AccType
|
||||
sum_adj[threadIdx.x] += sum_adj[threadIdx.x + skip]; // Accumulates in AccType
|
||||
sum_adj_x[threadIdx.x] += sum_adj_x[threadIdx.x + skip]; // Accumulates in AccType
|
||||
sum_adj_l[threadIdx.x] += sum_adj_l[threadIdx.x + skip]; // Accumulates in AccType
|
||||
}
|
||||
len = (len + 1) >> 1;
|
||||
}
|
||||
@ -2074,28 +2048,27 @@ __global__ void gLayerNormalizationGrad(T* gradX,
|
||||
|
||||
// Jacobian of layer norm
|
||||
// J = [ \frac{1}{N\sigma} (N\delta_{ij} - l_i l_j - 1) ]_{ij}
|
||||
// J * a = dC/dx_i = ( N v_i - l_i \sum_j l_j a_j - \sum_j a_j ) / (N \sigma)
|
||||
// J * a = dC/dx_i = ( N a_i - l_i \sum_j l_j a_j - \sum_j a_j ) / (N \sigma)
|
||||
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int id = tid + threadIdx.x;
|
||||
if(id < cols) {
|
||||
|
||||
AccType xv = xRow[id];
|
||||
//AccType yv = yRow[id];
|
||||
//AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
|
||||
AccType gammav = (AccType)gamma[id];
|
||||
AccType adjv = adjRow[id];
|
||||
AccType lv = (xv - mean) / (sigma + eps);
|
||||
|
||||
AccType gradLv = N * adjv - lv * sum_adj_x[0] - sum_adj[0];
|
||||
AccType gradLv = N * adjv - lv * sum_adj_l[0] - sum_adj[0];
|
||||
gradLv /= N * (sigma + eps); // eps has to be inside parentheses for correct gradient
|
||||
|
||||
AccType gradXv = gammav * gradLv;
|
||||
|
||||
// Keep LN gradient between [-10, 10]
|
||||
// AccType sign = functional::Ops<AccType>::sgn(gradXv);
|
||||
// AccType cutoff = (AccType)10.f;
|
||||
// gradXv = functional::Ops<AccType>::abs(gradXv) > cutoff ? sign * cutoff : gradXv;
|
||||
// Keep LN gradient between [-1000, 1000] for TensorOps; this is currently used for making values fit into fp16. @TODO: to be fixed and removed.
|
||||
AccType sign = functional::Ops<AccType>::sgn(gradXv);
|
||||
AccType cutoff = (AccType)1000.f; // @TODO: expose this somehow as an option?
|
||||
// or better: make obsolete.
|
||||
gradXv = functional::Ops<AccType>::abs(gradXv) > cutoff ? sign * cutoff : gradXv;
|
||||
|
||||
T* gradXRow = gradX + j * cols;
|
||||
gradXRow[id] += (T)(gradXv);
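Written out (a hedged restatement of the Jacobian comment above, with l_i = (x_i - \mu)/(\sigma + \epsilon) and a_i the incoming adjoint from adjRow), the quantity accumulated into gradX is

\frac{\partial C}{\partial x_i}
  = \frac{\gamma_i}{N(\sigma + \epsilon)} \Big( N a_i - l_i \sum_j l_j a_j - \sum_j a_j \Big),

where \sum_j l_j a_j and \sum_j a_j are the block-wide sums held in sum_adj_l[0] and sum_adj[0].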
|
||||
|
@ -28,18 +28,20 @@ std::string TensorBase::debug(int precision, int dispCols) {
|
||||
else
|
||||
strm << std::fixed << std::setprecision(0) << std::setfill(' ');
|
||||
|
||||
// double maxv = std::numeric_limits<double>::lowest();
|
||||
// double minv = std::numeric_limits<double>::max();
|
||||
// double l2Norm = 0.0;
|
||||
double maxv = std::numeric_limits<double>::lowest();
|
||||
double minv = std::numeric_limits<double>::max();
|
||||
double l2Sum = 0.0;
|
||||
for(int i = 0; i < values.size(); ++i) {
|
||||
if((double)values[i] > maxv) maxv = values[i];
|
||||
if((double)values[i] < minv) minv = values[i];
|
||||
l2Sum += (double)values[i] * (double)values[i];
|
||||
}
|
||||
strm << "min: " << minv << " max: " << maxv << " l2-norm: " << sqrt(l2Sum) << std::endl;
|
||||
|
||||
for(int i = 0; i < values.size(); ++i) {
|
||||
std::vector<int> dims;
|
||||
shape().dims(i, dims);
|
||||
|
||||
// if((double)values[i] > maxv) maxv = values[i];
|
||||
// if((double)values[i] < minv) minv = values[i];
|
||||
// l2Norm += (double)values[i] * (double)values[i];
|
||||
|
||||
bool disp = true;
|
||||
for(int j = 0; j < dims.size(); ++j)
|
||||
disp = disp && (dims[j] < dispCols || dims[j] >= shape()[j] - dispCols);
|
||||
@ -95,8 +97,6 @@ std::string TensorBase::debug(int precision, int dispCols) {
|
||||
}
|
||||
}
|
||||
strm << std::endl;
|
||||
//strm << "min: " << minv << " max: " << maxv << " l2-norm: " << sqrt(l2Norm);
|
||||
|
||||
return strm.str();
|
||||
}
|
||||
|
||||
|
@ -54,12 +54,12 @@ void Add(Functor functor, float scale, marian::Tensor out, Tensors... tensors) {
|
||||
gpu::Add(functor, scale, out, tensors...);
|
||||
else
|
||||
#endif
|
||||
cpu::Aggregate(functor, 0.0f, functional::_1 + functional::_2, scale, out, tensors...);
|
||||
cpu::Aggregate(functor, /*aggInit=*/0.0f, functional::_1 + functional::_2, scale, out, tensors...);
|
||||
}
|
||||
|
||||
template <class Functor, class... Tensors>
|
||||
void Add(Functor functor, marian::Tensor out, Tensors... tensors) {
|
||||
Add(functor, 1, out, tensors...);
|
||||
Add(functor, /*scale=*/1.f, out, tensors...);
|
||||
}
|
||||
|
||||
template <class Functor, class AggFunctor, class... Tensors>
|
||||
|
@ -789,3 +789,59 @@ TEST_CASE("Expression graph supports basic math operations (cpu)", "[operator]")
|
||||
tests<float>(DeviceType::cpu);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef BLAS_FOUND
|
||||
#ifdef CUDA_FOUND
|
||||
|
||||
TEST_CASE("Compare aggregate operator", "[graph]") {
|
||||
auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).epsilon(0.001); };
|
||||
|
||||
Config::seed = 1234;
|
||||
|
||||
std::vector<float> initc;
|
||||
std::vector<float> inita;
|
||||
|
||||
{
|
||||
auto graph = New<ExpressionGraph>();
|
||||
graph->setDevice({0, DeviceType::cpu});
|
||||
graph->reserveWorkspaceMB(40);
|
||||
|
||||
auto chl = graph->param("1x10x512x2048", {1, 10, 512, 2048}, inits::normal());
|
||||
auto adj = graph->param("1x1x512x2048", {1, 1, 512, 2048}, inits::normal());
|
||||
graph->forward();
|
||||
|
||||
chl->val()->get(initc);
|
||||
adj->val()->get(inita);
|
||||
}
|
||||
|
||||
SECTION("initializing with zero (cpu)") {
|
||||
std::vector<float> values1;
|
||||
std::vector<float> values2;
|
||||
|
||||
auto graph1 = New<ExpressionGraph>();
|
||||
graph1->setDevice({0, DeviceType::cpu});
|
||||
graph1->reserveWorkspaceMB(40);
|
||||
|
||||
auto graph2 = New<ExpressionGraph>();
|
||||
graph2->setDevice({0, DeviceType::gpu});
|
||||
graph2->reserveWorkspaceMB(40);
|
||||
|
||||
auto chl1 = graph1->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc));
|
||||
auto adj1 = graph1->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita));
|
||||
auto prod1 = scalar_product(chl1, adj1, -1);
|
||||
graph1->forward();
|
||||
|
||||
auto chl2 = graph2->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc));
|
||||
auto adj2 = graph2->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita));
|
||||
auto prod2 = scalar_product(chl2, adj2, -1);
|
||||
graph2->forward();
|
||||
|
||||
prod1->val()->get(values1);
|
||||
prod2->val()->get(values2);
|
||||
|
||||
CHECK( std::equal(values1.begin(), values1.end(), values2.begin(), floatApprox) );
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|