Add graph operations documentation (#801)

* Doxygen structure for expression graph operators
* Document arithmetic expression operations
* Document comparison expression operations
* Document exp/log and trig operations
* Add missing implementation for cos/tan
* Document expression manipulation operations
* Document misc math operations
* Overview of operators
* Document activation functions
* Document element-wise min/max
* Document debugging/checkpoint operators
* Document topk/argmin/argmax operations
* Document index-based operations
* Document reduction operations
* Document lambda expression operators
* Document product operations
* Document softmax, cross-entropy, unlikelihood operations
* Document dropout operations
* Document scalar product and weighted average operations
* Document layer normalization, highway and pooling operations
* Document shift expression operator
* Extra details on rules for adding specializations to .inc files
* Add SinNodeOp example for specialization documentation
* Additional details in tensor operator documentation
* Remove brief command from doxygen comments
* Prefer @ style doxygen functions to \
* Document n-ary function macros
* Enable .cu and .inc files in documentation
* Add a comment about ONNX mapping
* Remove empty lines in doxygen
* Update CHANGELOG

Co-authored-by: Roman Grundkiewicz <rgrundkiewicz@gmail.com>
Commit ac71ee8518 (parent 2a9c0bb377)
@ -11,8 +11,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

### Added

- Developer documentation framework based on Sphinx+Doxygen+Breathe+Exhale

### Fixed

- Fix building server with Boost 1.75
- Missing implementation for cos/tan expression operator

## [1.10.0] - 2021-02-06
@ -85,7 +85,9 @@ doxygen_config = """

INPUT = ../src
EXCLUDE += ../src/3rd_party
EXCLUDE += ../src/tests
EXCLUDE_PATTERNS = *.md *.txt
FILE_PATTERNS += *.cu
EXTENSION_MAPPING += cu=C++ inc=C++
ENABLE_PREPROCESSING = YES
JAVADOC_AUTOBRIEF = YES
WARN_IF_UNDOCUMENTED = NO
doc/operators.md

# Operations in the Expression Graph
Operations are responsible for manipulating the elements of an expression graph.
In Marian, many useful operations have already been implemented and can be found
in the code documentation. The provided operations cover simple arithmetic,
logical comparisons and common mathematical functions, as well as tensor
manipulation, for example `slice` or `reshape`, and aggregations such as `sum`
or `minimum`. Finally, other routines useful in building neural networks, such
as activation functions, are also available.

There are several components required to implement an operation in Marian's
expression graph. The highest-level component is the Expression Operator, which
is responsible for setting up the Node Operator and adding it to the graph.
This Node Operator, in turn, describes the forward and backward operations to
be performed. These operations are implemented using some combination of
Functional Operators (element-wise) and Tensor Operators.

This overview explains what each of the operator components does, how they fit
together and where to go to make changes, so that, equipped with this
knowledge, you are able to add new functionality to Marian.
## Operator Structure

The central component in the graph is the `Chainable<Tensor>` object. This
object provides the abstract interface necessary to interact with elements in
the computation graph. The details of this interface can be found in
[/src/graph/chainable.h](api/file_src_graph_chainable.h.html). Note that the
template parameter corresponds to the underlying data structure, which in Marian
is the `Tensor`. Therefore, for convenience, the type `Expr` is defined:

```cpp
typedef IPtr<Chainable<Tensor>> Expr;
```

The implementation of the different operator components is divided across
several files:
- Expression Operator
  - [/src/graph/expression_operators.h](api/file_src_graph_expression_operators.h.html)
  - [/src/graph/expression_operators.cpp](api/file_src_graph_expression_operators.cpp.html)
- Node Operator
  - [/src/graph/node_operators_unary.h](api/file_src_graph_node_operators_unary.h.html)
  - [/src/graph/node_operators_binary.h](api/file_src_graph_node_operators_binary.h.html)
  - [/src/graph/node_operators_tuple.h](api/file_src_graph_node_operators_tuple.h.html)
- Functional Operator
  - [/src/functional/operators.h](api/file_src_functional_operators.h.html)
- Tensor Operator
  - [/src/tensors/tensor_operators.h](api/file_src_tensors_tensor_operators.h.html)
  - [/src/tensors/cpu/tensor_operators.cpp](api/file_src_tensors_cpu_tensor_operators.cpp.html)
  - [/src/tensors/gpu/tensor_operators.cu](api/file_src_tensors_gpu_tensor_operators.cu.html)
- Declared Specialization
  - [/src/tensors/gpu/element.inc](api/program_listing_file_src_tensors_gpu_element.inc.html)
  - [/src/tensors/gpu/add.inc](api/program_listing_file_src_tensors_gpu_add.inc.html)
  - [/src/tensors/gpu/add_all.inc](api/program_listing_file_src_tensors_gpu_add_all.inc.html)

To understand how the different components are inter-linked, we'll look at each
of them in turn.
## Expression Operator

The expression operator is the user-facing method used when building a graph. It
is responsible for constructing the corresponding Node Operator and inserting it
into the expression graph. To accommodate these core requirements, the function
`Expression` performs both actions generically:

```cpp
template <class T, typename... Args>
Expr Expression(Args&&... args) {
  auto e = Expr(new T(std::forward<Args>(args)...));
  return e->graph()->add(e);
}
```

This helper function simplifies the definition of many expression operators. For
example, the implementation of the expression operator `sin(x)` is simply:

```cpp
// src/graph/expression_operators.h
Expr sin(Expr x);

// src/graph/expression_operators.cpp
Expr sin(Expr x) {
  return Expression<SinNodeOp>(x);
}
```
However, implementations may perform actions beyond this core functionality
alone. Taking `sum` as an example:

```cpp
Expr sum(Expr a, int ax) {
  if(a->shape()[ax] == 1) {
    return a;
  }
  return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::sum);
}
```

The trivial case is handled without needing to construct a node operator. This
example also demonstrates a non-trivial construction of `ReduceNodeOp`, which is
capable of performing different reduction operations depending on its
instantiation.

Going further, an expression operator may be defined in terms of existing
expressions. Operators such as `weighted_average` are composed of three
different expression operator calls: `scalar_product`, `sum`, and `operator/`.

```cpp
Expr weighted_average(Expr in, Expr weights, int ax) {
  auto p = scalar_product(in, weights, ax);
  auto s = sum(weights, ax);
  return p / s;
}
```

While useful, composition at this level may be less efficient than a
lower-level implementation.
## Node Operator

The `Node` subclass of `Chainable<Tensor>` provides concrete implementations for
much of the abstract interface, while subclasses of `Node` enable different node
behaviours. In the context of operations, the relevant derived class is
`NaryNodeOp`, which is the base class used for Node Operators. This subclass
provides an implementation focused on performing general N-ary operations.
However, many common operations are unary and, for convenience, a further
specialization, `UnaryNodeOp`, exists to simplify their definition.

The purpose of the Node Operator is to define the forward and backward behaviour
of the operation. The forward operation performs the desired computation while
the backward operation updates the gradients. These behaviours are written in
terms of `NodeOps`, where a `NodeOp` is a wrapper that defines a capturing
lambda function. Explicitly, these are defined as:
```cpp
// src/graph/chainable.h
#define NodeOp(op) [=]() { op; }
typedef std::vector<std::function<void()>> NodeOps;
```

Each `NodeOp` is written as a function in terms of the value (`val_`) and
gradient (`adj_`) of the current node, and of its children, via `child()`. The
values and gradients of the n<sup>th</sup> child node are accessed via the
interfaces `child(n)->val()` and `child(n)->grad()`, respectively. NodeOps are
executed in order when running the graph forwards and backwards, as this snippet
from `Node` demonstrates:

```cpp
// Node in src/graph/node.h
virtual void runForward(const NodeOps& ops) {
  for(auto&& op : ops)
    op();
}

virtual void runBackward(const NodeOps& ops) {
  size_t i = 0;
  for(auto&& op : ops)
    if(child(i++)->trainable())
      op();
}
```
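The mechanics of these snippets can be exercised standalone: each NodeOp is a lambda capturing its surroundings by value, and running forwards simply executes the lambdas in order (a minimal sketch without the graph machinery; the free function `runForward` mirrors the member function above):

```cpp
#include <functional>
#include <vector>

// As in src/graph/chainable.h: a NodeOp is a capturing lambda, and NodeOps
// is an ordered vector of them.
#define NodeOp(op) [=]() { op; }
typedef std::vector<std::function<void()>> NodeOps;

// Execute the ops in order, mirroring Node::runForward.
void runForward(const NodeOps& ops) {
  for(auto&& op : ops)
    op();
}
```

Since the lambdas capture by value, mutable state is reached through a captured pointer, much as the real NodeOps reach tensors through `val_` and `adj_`.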
In the backward operation it is **crucial** that the `NodeOp` responsible for
propagating a gradient to `child(i)` is the i<sup>th</sup> element of the
NodeOps vector. Since a NodeOp only runs when its associated child is trainable,
an out-of-position NodeOp may not be run at all. To represent no operation, a
`nullptr` can be passed as a NodeOp.

A typical node operator has the functionality demonstrated in the following
snippet.
```cpp
// outline of a node op
struct MyNodeOp : public NaryNodeOp {
  MyNodeOp(Expr a)
      : NaryNodeOp({a}, newShape(...), newType(...)) {}

  Shape newShape(...) {}  // optional
  Type newType(...) {}    // optional

  const std::string type() override { return "my_node_op"; }
  virtual size_t hash() override {}          // potentially required
  virtual bool equal(Expr node) override {}  // potentially required

  NodeOps forwardOps() override {}
  NodeOps backwardOps() override {}
};
```
This outline describes a node operator that takes a single argument `a`. The
shape and type of the node are determined by the result of `newShape` and
`newType` when constructing the `NaryNodeOp`. These functions represent any
custom logic used to determine the shape and type of the node. As indicated in
the example code, they are optional; when omitted, calling `NaryNodeOp({a})`
results in a node with the same shape and type as `a`. The `type()` method
returns the friendly name for the node. Note that the
[ONNX](https://onnx.ai)
[interface](api/program_listing_file_src_onnx_expression_graph_onnx_serialization.cpp.html)
maintains a mapping of these friendly names to their ONNX representation. In the
absence of any member variables, the `hash()` and `equal()` methods can be
omitted, deferring to their `NaryNodeOp` definitions. However, if such variables
exist, then `hash()` should implement a hashed representation and `equal()`
should provide the conditions necessary to consider nodes equivalent. Finally,
the operations of the node are defined in `forwardOps()` and `backwardOps()`.

Continuing with the example of `sin(x)`, the code responsible for implementing
the behaviour is:
```cpp
// src/graph/node_operators_unary.h
struct SinNodeOp : public UnaryNodeOp {
  SinNodeOp(Expr x) : UnaryNodeOp(x) {}

  NodeOps forwardOps() override {
    using namespace functional;
    return {NodeOp(Element(_1 = sin(_2), val_, child(0)->val()))};
  }

  NodeOps backwardOps() override {
    using namespace functional;
    return {NodeOp(Add(_1 * cos(_2), child(0)->grad(), adj_, child(0)->val()))};
  }

  const std::string type() override { return "sin"; }
};
```

In this code, the constructor trivially initialises the `UnaryNodeOp`, passing
the expression `x` as its input. This propagates up to `NaryNodeOp` and becomes
`child(0)` of the node. The shape and type of the `SinNodeOp` are equivalent to
those of `x`. The lack of any member variables allows the `hash()` and `equal()`
methods to be omitted. The friendly name for this node is the string `sin`. The
forward and backward implementations are each accomplished with a single NodeOp.
### Forward operation

The forward NodeOp calls the tensor operator `Element`, which executes the
element-wise operation described by the functor:

```cpp
_1 = sin(_2)
```

The placeholders `_1`, `_2` are enabled by code in
[/src/functional](api/dir_src_functional.html) and interoperate with the
functional operators. In the call to `Element`, `val_` is assigned to `_1` and
`child(0)->val()` to `_2`. Therefore, this has the action of setting the
elements of this node to the result obtained by applying `sin` to the elements
of `child(0)`.
### Backward operation

The backward NodeOp is responsible for backpropagation of the gradients via
reverse-mode automatic differentiation. In this example, where `y = sin(x)`,
this corresponds to evaluating

```
dJ/dx += dJ/dy * dy/dx,    dy/dx = cos(x)
```

This is realised using the tensor operator `Add` with the functor

```cpp
_1 * cos(_2)
```

In the call to `Add`, `adj_` is assigned to `_1` and `child(0)->val()` to `_2`.
Therefore, this functor represents `dJ/dy * dy/dx`: the product of the gradient
at the current node and the gradient of the operation. This value is then added
to the gradient of the child, `child(0)->grad()`, as required.
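This chain-rule step can be checked numerically: the analytic derivative `cos(x)` should agree with a central finite difference of `sin` (a standalone sanity check, not Marian code):

```cpp
#include <cmath>

// Analytic gradient of y = sin(x), as used in the backward NodeOp.
double analytic_grad(double x) { return std::cos(x); }

// Central finite-difference approximation of dy/dx.
double numeric_grad(double x, double eps = 1e-6) {
  return (std::sin(x + eps) - std::sin(x - eps)) / (2.0 * eps);
}
```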
### Shape and Type Changes

The `newShape` and `newType` methods are just a suggestion of how custom logic
may be encapsulated where needed. In practice, however, many operations do not
require a change in shape or type. In these instances, the node inherits the
broadcasted shape of its children as well as their common type. An important
feature of the type deduction in `NaryNodeOp::commonType()` is that it
guarantees that all child nodes are of the same type.
Few operations in Marian require a type specification. Where they do exist, they
are often simple, as the desired type is either explicitly provided or trivially
deduced. An example of this is `CastNodeOp`:

```cpp
// CastNodeOp in src/graph/node_operators_unary.h
CastNodeOp(Expr a, Type type) : UnaryNodeOp(a, type) {}
```

The desired type is set explicitly at construction. A slightly different example
is that of `CSRDotNodeOp`. It has several child nodes which are a mixture of
`DataType` and `IndexType`, and therefore do not share a common type. The
solution is to explicitly pass the relevant children to
`NaryNodeOp::commonType({...})`.
Shape-modifying operations are more common. A simple example is the class of
operations performed by `ReduceNodeOp`, which involve an aggregation along one
axis of the Tensor. The output shape is determined by

```cpp
// ReduceNodeOp in src/graph/node_operators_unary.h
Shape newShape(Expr a, int axis) {
  Shape shape = a->shape();
  axis_ = shape.axis(axis);

  shape.set(axis_, 1);
  return shape;
}
```

The output shape is the same as the input, but with the processed axis reduced
to a single element. Other use cases include transpose and slicing operations,
as well as tensor products.
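The shape rule can be sketched on a plain vector of dimensions (a hypothetical `reduceShape` helper, not Marian's `Shape` class, assuming the negative-axis normalisation performed by `Shape::axis`):

```cpp
#include <vector>

// The reduced axis collapses to a single element; negative axes count
// from the end, mirroring the behaviour of Shape::axis in Marian.
std::vector<int> reduceShape(std::vector<int> shape, int axis) {
  if(axis < 0)
    axis += static_cast<int>(shape.size());
  shape[axis] = 1;
  return shape;
}
```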
## Functional Operator

As the NodeOps are evaluated, they encounter the underlying datatype of the
`Tensor`. At this stage, type-specific intrinsic functions are required. These
intrinsics are implemented in the templated struct `Ops<ElementType>`, with a
specialization required for each type. The currently required types are:

- float
- double
- float32x4 (see `src/3rd_party/sse_mathfun.h`)
- float32x8 (see `src/3rd_party/avx_mathfun.h`)
- half (see `cuda_fp16.h` in the CUDA Math API)

Further details are available in
[/src/common/types.h](api/file_src_common_types.h.html).
Returning to the example of `sin(x)`, the specializations for `float` and
`double` require:

```cpp
// src/functional/operators.h
// in namespace marian::functional
template <typename T>
struct Ops {
  static HOST_DEVICE_INLINE T sin(const T&) { ABORT("Unknown type"); }
};

// Specialization for float
template <>
struct Ops<float> {
  static HOST_DEVICE_INLINE float sin(const float& x) { return sinf(x); }
};

// Specialization for double
template <>
struct Ops<double> {
  static HOST_DEVICE_INLINE double sin(const double& x) { return std::sin(x); }
};
```

The remaining specializations can be seen in
[/src/functional/operators.h](api/file_src_functional_operators.h.html). Note
that the general template must produce a runtime abort.
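The pattern of a failing general template plus explicit specializations can be reproduced standalone (a sketch: a `throw` stands in for Marian's `ABORT`, and only the `float` case is shown):

```cpp
#include <cmath>
#include <stdexcept>

// General template: no implementation for unknown element types.
template <typename T>
struct Ops {
  static T sin(const T&) { throw std::runtime_error("Unknown type"); }
};

// Specialization for float provides the real implementation.
template <>
struct Ops<float> {
  static float sin(const float& x) { return std::sin(x); }
};
```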
The final component of the functional operator is to call the macro that enables
interoperability with the framework of
[/src/functional](api/dir_src_functional.html). For a unary operator, this is
the macro `UNARY`:

```cpp
UNARY(Sin, sin, Ops<ElementType>::sin(x));
```

where the template parameter `ElementType` **must** be used. There are
equivalent macros `BINARY` and `TERNARY` for binary and ternary Ops.
## Tensor Operator

Tensor operations use less abstracted interfaces to interact with tensors, often
working with the tensor data directly. They also rely on BLAS (Basic Linear
Algebra Subprograms) libraries to accelerate these operations, as well as on
libraries containing device-specific optimisations. These libraries include:

- CPU
  - CBLAS / OpenBLAS
  - FBGEMM
  - INTGEMM
  - MKL
- GPU
  - CUDA (cuBLAS)

An important subtlety is that while the CPU-focused libraries use a row-major
representation, the cuBLAS library (GPU) instead uses a column-major
representation.

Furthermore, the OpenMPI and OpenMP libraries are employed for parallelisation,
while macros provided in
[/src/common/definitions.h](api/file_src_common_definitions.h.html) locally
enable faster floating-point math in supported compilers:
```cpp
MARIAN_FFAST_MATH_BEGIN
// ffmath code
MARIAN_FFAST_MATH_END
```

The usual caveats apply when enabling `fast_math`; they are described in
[/src/common/definitions.h](api/file_src_common_definitions.h.html).
Tensor operators are declared in
[/src/tensors/tensor_operators.h](api/file_src_tensors_tensor_operators.h.html).
These are device-agnostic functions that call the relevant device-specific
implementation. The CPU- and GPU-specific implementations are defined in the
`cpu` namespace in [/src/tensors/cpu/](api/dir_src_tensors_cpu.html) and the
`gpu` namespace in [/src/tensors/gpu/](api/dir_src_tensors_gpu.html). Therefore
a typical operator defers to an implementation in the device-specific namespace:

```cpp
void TensorOp(marian::Tensor out, marian::Tensor in) {
#ifdef CUDA_FOUND
  if(out->getBackend()->getDeviceId().type == DeviceType::gpu)
    gpu::TensorOp(out, in);
  else
#endif
    cpu::TensorOp(out, in);
}
```

When compiled with GPU support, this function dispatches a call to the
implementation that corresponds to the backend device type configured in the
graph (either GPU or CPU). Without GPU support, only the CPU implementation is
available.
Many operations are covered by three general tensor operators: `Element`,
`Aggregate` and `Prod`. The `Element` operator applies a function element-wise
across an arbitrary number of input tensors and stores the result in the output
tensor. The `Aggregate` operator also applies a function element-wise across its
inputs, but instead aggregates the results in the output via a given aggregation
function. A commonly used aggregation function is addition, which is the basis
of the `Add` and `Reduce` operators. Finally, `Prod` deals with products of
tensors. This operator performs a general matrix multiplication, with the
underlying implementation relying on the libraries mentioned above.

Specialized operators exist to manipulate tensors beyond the cases covered
above, such as transposition and concatenation. These operators may even be
expressed in terms of existing tensor operators.

Furthermore, for complicated multi-operation computations, performance gains and
memory improvements may be realised by implementing a tensor operator for that
specific purpose. An example of this is `softmax`, which could be implemented
using multiple expression operators (`exp`, `sum`), but is instead implemented
directly as a tensor operator. These optimized implementations may be
device-specific.
## Declared Specialization

The operations performed in the forward and backward methods of a NodeOp require
their GPU templates to be explicitly declared. When a new specialization is
introduced without being explicitly instantiated, it causes a link error at
compilation:

```
.../src/tensors/tensor_operators.h:41: undefined reference to `void marian::gpu::Element<marian::functional::Assign< ... > ( ... )'
```

To fix these undefined references, we must explicitly add the specialization to
the `.inc` files of [/src/tensors/gpu/](api/dir_src_tensors_gpu.html). Each
`.inc` file is included at the end of its corresponding `.cu` file, ensuring
that the specialization is compiled.

The undefined references should be added to the `.inc` file that corresponds to
the header file containing the declaration of the missing functions.
The file [element.inc](api/file_src_tensors_gpu_element.inc.html) contains the
specializations of the function defined in
[element.h](api/file_src_tensors_gpu_element.h.html):

```cpp
// src/tensors/gpu/element.h
template <class Functor, class... Tensors>
void Element(Functor functor, Tensor out, Tensors... tensors);
```

Similarly, [add.inc](api/file_src_tensors_gpu_add.inc.html) contains the
specializations for functions matching either of the two signatures in
[add.h](api/file_src_tensors_gpu_add.h.html):

```cpp
// src/tensors/gpu/add.h
template <class Functor, class... Tensors>
void Add(Functor functor, float scale, marian::Tensor out, Tensors... tensors);

template <class Functor, class AggFunctor, class... Tensors>
void Aggregate(Functor functor, float initAgg, AggFunctor aggFunctor, float scale, marian::Tensor out, Tensors... tensors);
```
Finally, [add_all.inc](api/file_src_tensors_gpu_add_all.inc.html) contains the
specializations for [add_all.h](api/file_src_tensors_gpu_add_all.h.html), which
are several versions of:

```cpp
// src/tensors/gpu/add_all.h
template <typename T, typename AccType, class Functor, class AggFunctor>
void AggregateAll(Ptr<Allocator> allocator,
                  Functor functor,
                  AccType aggInit,
                  AggFunctor aggFunctor,
                  AccType scale,
                  Tensor out,
                  const Tensor in1);
```

However, for [add_all.h](api/file_src_tensors_gpu_add_all.h.html), there is an
additional type dependence in the first template parameter, which requires two
entries:

```cpp
marian::gpu::AggregateAll< float, ... >( ... );
marian::gpu::AggregateAll< __half, ... >( ... ); // for COMPILE_FP16
```

where the `__half` specialization relates to half-precision floats and should be
added inside the `COMPILE_FP16` preprocessor block.
The simplest method to add the correct specialization is to take the compilation
error output and extract the needed signature. To extract the signature:

1. Replace up to, and including, "undefined reference to `" with "template"
2. Replace the final ' with a semi-colon
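Applied mechanically, the two steps above amount to a small text transformation (a hypothetical `extractSignature` helper, shown for illustration only):

```cpp
#include <string>

// Step 1: drop everything up to and including the marker, prefix "template ".
// Step 2: turn the trailing quote into a semi-colon.
std::string extractSignature(std::string line) {
  const std::string marker = "undefined reference to `";
  size_t pos = line.find(marker);
  if(pos != std::string::npos)
    line = "template " + line.substr(pos + marker.size());
  if(!line.empty() && line.back() == '\'')
    line.back() = ';';
  return line;
}
```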
To conform with definitions in the codebase, we should replace
`IntrusivePtr<marian::TensorBase>` with its typedef `marian::Tensor`. Note that
as these files are included in the `marian::gpu` namespace, and explicitly use
the `marian::functional` namespace, it is also possible to omit both of these
prefixes. Typically, the namespace prefix of the specialized function is removed
as well. Following these rules for the example of `SinNodeOp` results in the
following entries:

**element**
```cpp
template void Element<Assign<Var<1>, UnaryFunctor<elem::Sin, Assignee<2> > >, marian::Tensor >(Assign<Var<1>, UnaryFunctor<elem::Sin, Assignee<2> > >, marian::Tensor, marian::Tensor);
```

**add**
```cpp
template void Add<BinaryFunctor<elem::Mult,Assignee<1>,UnaryFunctor<elem::Cos,Assignee<2> > >,class marian::Tensor,class marian::Tensor >(BinaryFunctor<elem::Mult,Assignee<1>,UnaryFunctor<elem::Cos,Assignee<2> > >,float,class marian::Tensor,class marian::Tensor,class marian::Tensor);
```

**add_all**
```cpp
template void AggregateAll<float,float,BinaryFunctor<elem::Mult,Assignee<1>,UnaryFunctor<elem::Cos,Assignee<2> > >,BinaryFunctor<elem::Plus,Assignee<1>,Assignee<2> > >(std::shared_ptr<marian::Allocator>,BinaryFunctor<elem::Mult,Assignee<1>,UnaryFunctor<elem::Cos,Assignee<2> > >,float,BinaryFunctor<elem::Plus,Assignee<1>,Assignee<2> >,float,marian::Tensor,marian::Tensor,marian::Tensor);
#if COMPILE_FP16
template void AggregateAll<__half,float,BinaryFunctor<elem::Mult,Assignee<1>,UnaryFunctor<elem::Cos,Assignee<2> > >,BinaryFunctor<elem::Plus,Assignee<1>,Assignee<2> > >(std::shared_ptr<marian::Allocator>,BinaryFunctor<elem::Mult,Assignee<1>,UnaryFunctor<elem::Cos,Assignee<2> > >,float,BinaryFunctor<elem::Plus,Assignee<1>,Assignee<2> >,float,marian::Tensor,marian::Tensor,marian::Tensor);
#endif
```
@ -39,6 +39,12 @@ struct BinaryFunctor {
  }
};

/**
 * Macro to set up unary-functions from marian::functional::Ops.
 * @param name name for the struct
 * @param name2 callable typedef
 * @param func function wrapped
 */
#define UNARY(name, name2, func) \
  namespace elem { \
  struct name { \

@ -55,6 +61,12 @@ struct BinaryFunctor {
  } \
  static inline name<Capture> name2(Capture x) { return name<Capture>(x); }

/**
 * Macro to set up binary-functions from marian::functional::Ops.
 * @param name name for the struct
 * @param name2 callable typedef
 * @param func function wrapped
 */
#define BINARY(name, name2, func) \
  namespace elem { \
  struct name { \

@ -95,6 +107,12 @@ struct TernaryFunctor {
  }
};

/**
 * Macro to set up ternary-functions from marian::functional::Ops.
 * @param name name for the struct
 * @param name2 callable typedef
 * @param func function wrapped
 */
#define TERNARY(name, name2, func) \
  namespace elem { \
  struct name { \

@ -72,6 +72,14 @@ Expr sin(Expr a) {
  return Expression<SinNodeOp>(a);
};

Expr cos(Expr a) {
  return Expression<CosNodeOp>(a);
};

Expr tan(Expr a) {
  return Expression<TanNodeOp>(a);
};

Expr swish(Expr a) {
  return Expression<SwishNodeOp>(a);
}
@ -646,7 +646,7 @@ struct CosNodeOp : public UnaryNodeOp {
    return {NodeOp(Add(_1 * -sin(_2), child(0)->grad(), adj_, child(0)->val()))};
  }

  const std::string type() override { return "cos"; }  // was incorrectly "sin"
};

struct TanNodeOp : public UnaryNodeOp {

@ -662,7 +662,7 @@ struct TanNodeOp : public UnaryNodeOp {
    return {NodeOp(Add(_1 / sqr(cos(_2)), child(0)->grad(), adj_, child(0)->val()))};
  }

  const std::string type() override { return "tan"; }  // was incorrectly "sin"
};

struct SqrtNodeOp : public UnaryNodeOp {
@ -37,3 +37,5 @@ template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functio
template void marian::gpu::Add<marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase> >(marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Aggregate<marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, IntrusivePtr<marian::TensorBase> >(marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase>,class IntrusivePtr<class marian::TensorBase>);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, marian::Tensor, marian::Tensor >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, marian::Tensor, marian::Tensor >(marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, float, marian::Tensor, marian::Tensor, marian::Tensor);
@@ -37,6 +37,9 @@ template void marian::AggregateAll<float, float, marian::functional::BinaryFunct
template void marian::AggregateAll<float, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<float, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::AggregateAll<float, float, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
#if COMPILE_FP16
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
@@ -75,4 +78,6 @@ template void marian::AggregateAll<__half, float, marian::functional::BinaryFunc
template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half,float,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>,marian::functional::BinaryFunctor<marian::functional::elem::Mult,marian::functional::Assignee<1>,marian::functional::UnaryFunctor<marian::functional::elem::Cos,marian::functional::Assignee<2> > >,float,marian::functional::BinaryFunctor<marian::functional::elem::Plus,marian::functional::Assignee<1>,marian::functional::Assignee<2> >,float,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>,IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::Assignee<1> >, float, marian::functional::BinaryFunctor<marian::functional::elem::Max, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Neg, marian::functional::UnaryFunctor<marian::functional::elem::Sin, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> > >(std::shared_ptr<marian::Allocator>, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqr, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > > >, float, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, float, marian::Tensor, marian::Tensor, marian::Tensor);
#endif
@@ -68,6 +68,8 @@ template void marian::gpu::Element<marian::functional::Assign<marian::functional
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<2> > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<2> > > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Clip, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > >, marian::functional::Capture> > > >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Clip, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > >, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > >, marian::Tensor >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > >, marian::Tensor, marian::Tensor);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Tan, marian::functional::Assignee<2> > >, marian::Tensor >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Tan, marian::functional::Assignee<2> > >, marian::Tensor, marian::Tensor);
// How to add new specializations:
// When you use a new specialization, it will cause a link error of this form (example):
// .../src/tensors/tensor_operators.h:41: undefined reference to `void marian::gpu::Element<marian::functional::Assign< ... > ( ... )'