Merged PR 10999: Splitting up add_all.h into *.h, *.cu and *.inc

Splitting up header file into header and *.cu, comes with the price of having to include specializations for combinations of types as for element.inc and add.inc. No code changes otherwise.

Add CMake options to disable specific compute capabilities.

When run with `make -j16` this compiles in about 6 minutes instead of 7 minutes. Selecting only SM70 during compilation brings down the time to 3 minutes.
This commit is contained in:
Martin Junczys-Dowmunt 2020-01-06 19:14:00 +00:00
parent 88d9980589
commit 164d26cc36
9 changed files with 261 additions and 87 deletions

View File

@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
- Add CMAKE options to disable compilation for specific GPU SM types
- An option to print word-level translation scores
- An option to turn off automatic detokenization from SentencePiece
- Separate quantization types for 8-bit FBGEMM for AVX2 and AVX512

View File

@ -13,6 +13,10 @@ set(BUILD_ARCH native CACHE STRING "Compile for this CPU architecture.")
# Custom CMake options
option(COMPILE_CPU "Compile CPU version" ON)
option(COMPILE_CUDA "Compile GPU version" ON)
option(COMPILE_CUDA_SM35 "Compile GPU version with SM35 support" ON)
option(COMPILE_CUDA_SM50 "Compile GPU version with SM50 support" ON)
option(COMPILE_CUDA_SM60 "Compile GPU version with SM60 support" ON)
option(COMPILE_CUDA_SM70 "Compile GPU version with SM70 support" ON)
option(COMPILE_EXAMPLES "Compile examples" OFF)
option(COMPILE_SERVER "Compile marian-server" OFF)
option(COMPILE_TESTS "Compile tests" OFF)
@ -181,8 +185,6 @@ set(EXT_LIBS ${EXT_LIBS} ${CMAKE_DL_LIBS})
if(COMPILE_CUDA)
LIST(APPEND COMPUTE -arch=sm_35; -gencode=arch=compute_35,code=sm_35; -gencode=arch=compute_50,code=sm_50; -gencode=arch=compute_52,code=sm_52; -gencode=arch=compute_60,code=sm_60; -gencode=arch=compute_61,code=sm_61;)
if(USE_STATIC_LIBS)
# link statically to stdlib libraries
set(CMAKE_EXE_LINKER_FLAGS "-static-libgcc -static-libstdc++")
@ -202,10 +204,19 @@ if(CUDA_FOUND)
if((CUDA_VERSION VERSION_EQUAL "10.0" OR CUDA_VERSION VERSION_GREATER "10.0") AND (CMAKE_VERSION VERSION_LESS "3.12.2"))
message(WARNING "On some Unix systems CUDA 10.0+ requires CMake 3.12.2+; you use CMake ${CMAKE_VERSION}")
endif()
if(CUDA_VERSION VERSION_GREATER "8.0")
LIST(APPEND COMPUTE -gencode=arch=compute_70,code=sm_70; -gencode=arch=compute_70,code=compute_70)
endif()
if(COMPILE_CUDA_SM35)
LIST(APPEND COMPUTE -arch=sm_35; -gencode=arch=compute_35,code=sm_35;) # Tesla K40 and above
endif(COMPILE_CUDA_SM35)
if(COMPILE_CUDA_SM50)
LIST(APPEND COMPUTE -gencode=arch=compute_50,code=sm_50; -gencode=arch=compute_52,code=sm_52;) # Maxwell GPUs
endif(COMPILE_CUDA_SM50)
if(COMPILE_CUDA_SM60)
LIST(APPEND COMPUTE -gencode=arch=compute_60,code=sm_60; -gencode=arch=compute_61,code=sm_61;) # Pascal GPUs
endif(COMPILE_CUDA_SM60)
if(COMPILE_CUDA_SM70)
LIST(APPEND COMPUTE -gencode=arch=compute_70,code=sm_70; -gencode=arch=compute_70,code=compute_70) # Volta GPUs
endif(COMPILE_CUDA_SM70)
if(USE_STATIC_LIBS)
find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)

View File

@ -1 +1 @@
v1.8.37
v1.8.38

View File

@ -71,8 +71,23 @@ set(INSTALLS "") # this will contain a list of 3rd part dependencies that we ins
if(CUDA_FOUND)
if(USE_NCCL)
# disables compilation for sm_30 to avoid ptxas warning... that's general Kepler support. But K80s are supported for instance by sm_35
set(GENCODE "-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61")
# disables compilation for sm_30 to avoid ptxas warning... that is general Kepler support. But K80s are supported for instance by sm_35
set(GENCODE "")
if(COMPILE_CUDA_SM35)
set(GENCODE "${GENCODE} -gencode=arch=compute_35,code=sm_35")
endif(COMPILE_CUDA_SM35)
if(COMPILE_CUDA_SM50)
set(GENCODE "${GENCODE} -gencode=arch=compute_50,code=sm_50")
endif(COMPILE_CUDA_SM50)
if(COMPILE_CUDA_SM60)
set(GENCODE "${GENCODE} -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61")
endif(COMPILE_CUDA_SM60)
if(COMPILE_CUDA_SM70)
set(GENCODE "${GENCODE} -gencode=arch=compute_70,code=sm_70")
endif(COMPILE_CUDA_SM70)
message(${GENCODE})
# install nccl in ${CMAKE_BINARY_DIR}/local similar to /usr/local linux installation
ExternalProject_Add(nccl_install

View File

@ -138,7 +138,8 @@ cuda_add_library(marian_cuda
tensors/gpu/algorithm.cu
tensors/gpu/prod.cpp
tensors/gpu/element.cu
tensors/gpu/add.cu
tensors/gpu/add.cu
tensors/gpu/add_all.cu
tensors/gpu/tensor_operators.cu
tensors/gpu/cudnn_wrappers.cu
translator/nth_element.cu

View File

@ -15,21 +15,21 @@ template void Add<BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Div
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Div, Assignee<1>, Assignee<2> >, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Div, Assignee<1>, Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<1> >, marian::Tensor >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<1> >, float, marian::Tensor, marian::Tensor);
template void marian::gpu::Aggregate<marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Min, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::Tensor >(marian::functional::Assignee<1>, float, marian::functional::BinaryFunctor<marian::functional::elem::Min, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor);
template void marian::gpu::Aggregate<marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::Tensor >(marian::functional::Assignee<1>, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor);
template void marian::gpu::Aggregate<marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::Tensor >(marian::functional::Assignee<1>, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor);
template void marian::gpu::Aggregate<marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::LogAddExp, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::Tensor >(marian::functional::Assignee<1>, float, marian::functional::BinaryFunctor<marian::functional::elem::LogAddExp, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Eq, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::functional::Assignee<3> >, marian::Tensor, marian::Tensor, marian::Tensor >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Eq, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::functional::Assignee<3> >, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Exp, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<2>, marian::functional::Assignee<3> > > >, marian::Tensor, marian::Tensor, marian::Tensor >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Exp, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<2>, marian::functional::Assignee<3> > > >, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::functional::Assignee<3> >, marian::Tensor, marian::Tensor, marian::Tensor >(marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::functional::Assignee<3> >, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Assignee<2> >, marian::Tensor, marian::Tensor >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> > > > > >, marian::Tensor, marian::Tensor, marian::Tensor >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> > > > > >, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, marian::Tensor >(BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, float, marian::Tensor, marian::Tensor);
template void Aggregate<Assignee<1>, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>, marian::Tensor >(Assignee<1>, float, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void Aggregate<Assignee<1>, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>, marian::Tensor >(Assignee<1>, float, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void Aggregate<Assignee<1>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, marian::Tensor >(Assignee<1>, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void Aggregate<Assignee<1>, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>, marian::Tensor >(Assignee<1>, float, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);

116
src/tensors/gpu/add_all.cu Normal file
View File

@ -0,0 +1,116 @@
#include "tensors/gpu/add_all.h"
#include "tensors/gpu/cuda_helpers.h"
#include "functional/functional.h"
#include "tensors/tensor_operators.h"
#include "3rd_party/reduce_all.h" // only works with CUDA >9.0, we are dropping CUDA 8.0 support, also changed in CMakeLists.txt
namespace marian {
#if COMPILE_FP16
// local overload to determine tensor type
template <> inline Type typeId<half>() { return Type::float16; }
#endif
// Version with variadic template arguments, called by version with explicit arguments below
template <typename T, typename AccType, class Functor, class AggFunctor, class... Tensors>
void AggregateAllVar(Ptr<Allocator> allocator,
Functor functor,
AccType aggInit,
AggFunctor aggFunctor,
AccType scale,
Tensor out,
const Tensors... tensors) {
cudaSetDevice(out->getDeviceId().no);
static_assert(CUDA_VERSION >= 9000, "Marian requires CUDA_VERSION >= 9000 (9.0)");
constexpr size_t K = sizeof...(Tensors); // obtain arity K of tensors...
functional::Array<functional::Tensor<T>, K> gIns = {tensors...}; // convert to array of K objects of type functional::Tensor<T>
functional::Shape full = marian::Shape::broadcast({tensors...}); // compute maximal broadcasted shape
int size = full.elements();
int threads = (size < MAX_THREADS * 2) ? nextPow2((size + 1) / 2) : MAX_THREADS; // suggested in NVidia example for the all_reduce kernel
int blocks = std::min(MAX_BLOCKS, (size + (threads * 2 - 1)) / (threads * 2)); // suggested in NVidia example for the all_reduce kernel
// The all_reduce kernel by nivida needs to perform multiple passes if the number of blocks needed to perform the reduction is larger than 1.
// Here we allocate the memory for the intermediate reductions for each block.
Tensor blockMem;
if(blocks > 1 || out->type() != typeId<AccType>()) { // if the out tensor does not have elementType AccType we need to allocate and convert later
MemoryPiece::PtrType temporaryMemory;
if(allocator) {
temporaryMemory = allocator->alloc<AccType>(blocks);
} else { // @TODO: get rid of this branch
uint8_t* temporaryMemoryPtr = 0;
CUDA_CHECK(cudaMalloc(&temporaryMemoryPtr, sizeof(AccType) * blocks));
temporaryMemory = MemoryPiece::New(temporaryMemoryPtr, sizeof(AccType) * blocks); // @TODO: consider implementing MemoryPiece::cudaMalloc<T>(size) for managed memory
}
blockMem = TensorBase::New(temporaryMemory,
Shape({blocks}),
typeId<AccType>(),
out->getBackend());
blockMem->set(aggInit); // set temporary memory to aggInit
}
else { // we are reducing into a single element now and the type matches, just use out as memory
blockMem = out; // do not set final output memory as we might be summing gradients... needs to be handled outside this function
}
functional::Tensor<AccType> gBlockMem = blockMem;
reduceSinglePass<T, AccType>(functor, aggInit, aggFunctor, scale, full, /*out=*/gBlockMem, /*in=*/gIns, threads, blocks); // First pass reduction into intermediate memory
// If we actually needed more than one block to perform the first pass reduction, recursively run a second pass reduction over block memory until block memory has size 1.
if(blocks > 1) {
using namespace functional;
auto identity = _1; // transformation was done in first pass, hence only identity
AggregateAll<AccType, AccType>(allocator, identity, aggInit, aggFunctor, scale, out, /*in=*/blockMem); // Reducing AccType in AccType now (meta-reduction)
} else if(out->type() != typeId<AccType>()) { // it's only a single block, but we need to convert to different type, as mentioned above
CopyCast(out, blockMem);
}
if(blockMem != out) {
// Free temporary memory whether allocated in allocator or via cudaMalloc
if(allocator)
allocator->free(blockMem->memory());
else if(blockMem->memory()->data())
CUDA_CHECK(cudaFree(blockMem->memory()->data()));
}
}
template <typename T, typename AccType, class Functor, class AggFunctor>
void AggregateAll(Ptr<Allocator> allocator,
Functor functor,
AccType aggInit,
AggFunctor aggFunctor,
AccType scale,
Tensor out,
const Tensor in1) {
AggregateAllVar<T, AccType>(allocator, functor, aggInit, aggFunctor, scale, out, in1);
}
template <typename T, typename AccType, class Functor, class AggFunctor>
void AggregateAll(Ptr<Allocator> allocator,
Functor functor,
AccType aggInit,
AggFunctor aggFunctor,
AccType scale,
Tensor out,
const Tensor in1,
const Tensor in2) {
AggregateAllVar<T, AccType>(allocator, functor, aggInit, aggFunctor, scale, out, in1, in2);
}
template <typename T, typename AccType, class Functor, class AggFunctor>
void AggregateAll(Ptr<Allocator> allocator,
Functor functor,
AccType aggInit,
AggFunctor aggFunctor,
AccType scale,
Tensor out,
const Tensor in1,
const Tensor in2,
const Tensor in3) {
AggregateAllVar<T, AccType>(allocator, functor, aggInit, aggFunctor, scale, out, in1, in2, in3);
}
#include "tensors/gpu/add_all.inc"
}

View File

@ -3,87 +3,46 @@
// This header file provides wrappers around NVidia's reduce_all kernel with our custom aggregation functionality
// This kernel reduces a tensor into a single value. We have modified it to allow for different types of aggregations
// like summing or max etc.
#include "tensors/gpu/cuda_helpers.h"
#include "tensors/tensor.h"
#include "tensors/allocator.h"
#include "functional/functional.h"
#include "functional/tensor.h"
#include "tensors/tensor_operators.h"
#include "3rd_party/reduce_all.h" // only works with CUDA >9.0, we are dropping CUDA 8.0 support, also changed in CMakeLists.txt
#include <cuda.h>
namespace marian {
#if COMPILE_FP16
// local overload to determine tensor type
template <> inline Type typeId<half>() { return Type::float16; }
#endif
template <typename T, typename AccType, class Functor, class AggFunctor, class... Tensors>
// These function declarations are repeated as template specialization with variadic template arguments does not seem to work.
// Here I am just creating version for 1, 2, and 3 arguments. To be extended if required.
template <typename T, typename AccType, class Functor, class AggFunctor>
void AggregateAll(Ptr<Allocator> allocator,
Functor functor,
AccType aggInit,
AggFunctor aggFunctor,
AccType scale,
Tensor out,
const Tensors... tensors) {
cudaSetDevice(out->getDeviceId().no);
const Tensor in1);
static_assert(CUDA_VERSION >= 9000, "Marian requires CUDA_VERSION >= 9000 (9.0)");
template <typename T, typename AccType, class Functor, class AggFunctor>
void AggregateAll(Ptr<Allocator> allocator,
Functor functor,
AccType aggInit,
AggFunctor aggFunctor,
AccType scale,
Tensor out,
const Tensor in1,
const Tensor in2);
constexpr size_t K = sizeof...(Tensors); // obtain arity K of tensors...
functional::Array<functional::Tensor<T>, K> gIns = {tensors...}; // convert to array of K objects of type functional::Tensor<T>
functional::Shape full = marian::Shape::broadcast({tensors...}); // compute maximal broadcasted shape
int size = full.elements();
int threads = (size < MAX_THREADS * 2) ? nextPow2((size + 1) / 2) : MAX_THREADS; // suggested in NVidia example for the all_reduce kernel
int blocks = std::min(MAX_BLOCKS, (size + (threads * 2 - 1)) / (threads * 2)); // suggested in NVidia example for the all_reduce kernel
// The all_reduce kernel by nivida needs to perform multiple passes if the number of blocks needed to perform the reduction is larger than 1.
// Here we allocate the memory for the intermediate reductions for each block.
Tensor blockMem;
if(blocks > 1 || out->type() != typeId<AccType>()) { // if the out tensor does not have elementType AccType we need to allocate and convert later
MemoryPiece::PtrType temporaryMemory;
if(allocator) {
temporaryMemory = allocator->alloc<AccType>(blocks);
} else { // @TODO: get rid of this branch
uint8_t* temporaryMemoryPtr = 0;
CUDA_CHECK(cudaMalloc(&temporaryMemoryPtr, sizeof(AccType) * blocks));
temporaryMemory = MemoryPiece::New(temporaryMemoryPtr, sizeof(AccType) * blocks); // @TODO: consider implementing MemoryPiece::cudaMalloc<T>(size) for managed memory
}
blockMem = TensorBase::New(temporaryMemory,
Shape({blocks}),
typeId<AccType>(),
out->getBackend());
blockMem->set(aggInit); // set temporary memory to aggInit
}
else { // we are reducing into a single element now and the type matches, just use out as memory
blockMem = out; // do not set final output memory as we might be summing gradients... needs to be handled outside this function
}
functional::Tensor<AccType> gBlockMem = blockMem;
reduceSinglePass<T, AccType>(functor, aggInit, aggFunctor, scale, full, /*out=*/gBlockMem, /*in=*/gIns, threads, blocks); // First pass reduction into intermediate memory
// If we actually needed more than one block to perform the first pass reduction, recursively run a second pass reduction over block memory until block memory has size 1.
if(blocks > 1) {
using namespace functional;
auto identity = _1; // transformation was done in first pass, hence only identity
AggregateAll<AccType, AccType>(allocator, identity, aggInit, aggFunctor, scale, out, /*in=*/blockMem); // Reducing AccType in AccType now (meta-reduction)
} else if(out->type() != typeId<AccType>()) { // it's only a single block, but we need to convert to different type, as mentioned above
CopyCast(out, blockMem);
}
if(blockMem != out) {
// Free temporary memory whether allocated in allocator or via cudaMalloc
if(allocator)
allocator->free(blockMem->memory());
else if(blockMem->memory()->data())
CUDA_CHECK(cudaFree(blockMem->memory()->data()));
}
}
template <typename T, typename AccType, class Functor, class AggFunctor>
void AggregateAll(Ptr<Allocator> allocator,
Functor functor,
AccType aggInit,
AggFunctor aggFunctor,
AccType scale,
Tensor out,
const Tensor in1,
const Tensor in2,
const Tensor in3);
// Aggregates all values into a single tensor and returns the value of that tensor as a float
// This does a GPU to CPU memory copy via TensorBase::scalar().

View File

@ -0,0 +1,71 @@
// see element.inc for instructions on how to maintain this
using namespace functional;
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Assignee<2>, Assignee<2>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Assignee<2>, Assignee<2>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Neg, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Mult, Assignee<3>, Assignee<3>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Neg, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Mult, Assignee<3>, Assignee<3>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, UnaryFunctor<elem::Neg, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, UnaryFunctor<elem::Neg, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, Assignee<3>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<3>>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, Assignee<3>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<3>>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Div, Capture, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Div, Capture, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, Assignee<1>, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, Assignee<1>, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, Assignee<1>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, Assignee<1>, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
#if COMPILE_FP16
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Assignee<2>, Assignee<2>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Assignee<2>, Assignee<2>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Neg, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Mult, Assignee<3>, Assignee<3>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Neg, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Mult, Assignee<3>, Assignee<3>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, UnaryFunctor<elem::Neg, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, UnaryFunctor<elem::Neg, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, Assignee<3>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<3>>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, Assignee<3>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<3>>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Div, Capture, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Div, Capture, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
#endif