Mirror of https://github.com/marian-nmt/marian.git, synced 2024-09-19 02:37:14 +03:00
first attempts at re-implementing nematus, encoder
This commit is contained in: parent 7adade04e2, commit dd2475bce6
@@ -26,6 +26,8 @@
#include <functional>
#include <memory>
#include <cuda.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include "shape.h"

@@ -36,9 +38,14 @@ namespace marian {
return std::shared_ptr<T>(new T(std::forward<Args>(args)...));
}

typedef float Float;

template<class T>
using DeviceVector = thrust::device_vector<T>;

template<class T>
using HostVector = thrust::host_vector<T>;

/** @brief A placeholder that represents the size of a dimension, the actual value of which is to be specified at some later point.
 *
 * For example, in certain cases the value of one dimension in a Shape object may be used to represent batch size.
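Note: a minimal usage sketch for the new aliases, assuming only the definitions above; the row-index use case anticipates the rows()/CopyRows() machinery added later in this commit.

  // Stage row indices on the CPU, then mirror them on the GPU where the
  // row-gather kernels expect them.
  HostVector<size_t> hIdx(3);
  hIdx[0] = 0; hIdx[1] = 2; hIdx[2] = 5;
  DeviceVector<size_t> dIdx = hIdx;   // thrust copies host -> device
  Float one = 1.0f;                   // Float is a plain alias for float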
@@ -112,6 +112,18 @@ class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
v->forward();
}

void forward() {
params_.allocateForward();

for(auto&& v : tape_)
if(!v->skipped_training())
v->allocate(0);

for(auto&& v : tape_)
if(!v->skipped_training())
v->forward();
}

void inference(data::BatchPtr batch) {
for(auto&& v : tape_)
if(!v->skipped_inference())
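Note: a minimal sketch of the two-pass tape walk that the new forward() above performs; the Node interface here is a stand-in assumed for illustration, not the full chainable node class from this commit.

  #include <memory>
  #include <vector>

  // Stand-in for the graph nodes stored on the tape.
  struct Node {
    virtual bool skipped_training() = 0;
    virtual void allocate(size_t batchSize) = 0;
    virtual void forward() = 0;
    virtual ~Node() {}
  };

  // Mirrors ExpressionGraph::forward(): allocate every trainable node first,
  // then execute them in tape (topological) order.
  void runForward(std::vector<std::shared_ptr<Node>>& tape) {
    for(auto&& v : tape)
      if(!v->skipped_training())
        v->allocate(0);
    for(auto&& v : tape)
      if(!v->skipped_training())
        v->forward();
  }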
@@ -40,6 +40,10 @@ Expr name(Expr a, const std::string& name) {
return a;
}

Expr rows(Expr a, const DeviceVector<size_t>& indeces) {
return Expression<RowsNodeOp>(a, indeces);
}

Expr logit(Expr a) {
return Expression<LogitNodeOp>(a);
}
@@ -38,6 +38,8 @@ Expr inference(Expr a);
*/
Expr name(Expr a, const std::string& name);

Expr rows(Expr a, const DeviceVector<size_t>& indeces);

Expr logit(Expr a);

Expr tanh(Expr a);
@@ -13,53 +13,89 @@ using namespace marian;
using namespace keywords;
using namespace data;

void construct(ExpressionGraphPtr g, size_t length) {
typedef DeviceVector<size_t> WordBatch;
typedef std::vector<WordBatch> SentBatch;

void construct(ExpressionGraphPtr g,
const SentBatch& srcSentenceBatch) {
g->clear();

int dim_i = 500;
int dim_h = 1024;
int dimSrcVoc = 30000;
int dimSrcEmb = 512;
int dimEncState = 1024;
int dimBatch = 1;

ParametersGRU pGRU;
pGRU.Uz = g->param("Uz", {dim_h, dim_h}, init=uniform());
pGRU.Wz = g->param("Wz", {dim_i, dim_h}, init=uniform());
//pGRU.bz = nullptr; // g->param("bz", {1, dim_h}, init=zeros);
auto Wemb = g->param("Wemb", {dimSrcVoc, dimSrcEmb}, init=uniform());

pGRU.Ur = g->param("Ur", {dim_h, dim_h}, init=uniform());
pGRU.Wr = g->param("Wr", {dim_i, dim_h}, init=uniform());
//pGRU.br = nullptr; //g->param("br", {1, dim_h}, init=zeros);

pGRU.Uh = g->param("Uh", {dim_h, dim_h}, init=uniform());
pGRU.Wh = g->param("Wh", {dim_i, dim_h}, init=uniform());
//pGRU.bh = nullptr; //g->param("bh", {1, dim_h}, init=zeros);

pGRU.dropout = 0.2;

auto start = name(g->zeros(shape={whatevs, dim_h}), "s_0");
std::vector<Expr> inputs;
for(int i = 0; i < length; ++i) {
auto x = name(g->input(shape={whatevs, dim_i}),
"x_" + std::to_string(i));
for(auto& srcWordBatch : srcSentenceBatch) {
auto x = rows(Wemb, srcWordBatch);
inputs.push_back(x);
dimBatch = srcWordBatch.size();
}

RNN<GRU> gru(pGRU);
auto outputs = gru.apply(inputs, start);
auto encoder = [=](const std::string& prefix){
ParametersGRU encParams;
encParams.Uz = g->param(prefix + "_Uz", {dimEncState, dimEncState},
init=uniform());
encParams.Ur = g->param(prefix + "_Ur", {dimEncState, dimEncState},
init=uniform());

encParams.Wz = g->param(prefix + "_Wz", {dimSrcEmb, dimEncState},
init=uniform());
encParams.Wr = g->param(prefix + "_Wr", {dimSrcEmb, dimEncState},
init=uniform());

encParams.bz = g->param(prefix + "_bz", {1, dimEncState}, init=zeros);
encParams.br = g->param(prefix + "_br", {1, dimEncState}, init=zeros);

encParams.Ux = g->param(prefix + "_Ux", {dimEncState, dimEncState},
init=uniform());
encParams.Wx = g->param(prefix + "_Wx", {dimSrcEmb, dimEncState},
init=uniform());
encParams.bx = g->param(prefix + "_bx", {1, dimEncState}, init=zeros);

return RNN<GRU>(encParams);
};

auto encStartState = g->zeros(shape={dimBatch, dimEncState});

auto encForward = encoder("encoder");
auto statesForward = encForward.apply(inputs.begin(), inputs.end(),
encStartState);

/*
auto encBackward = encoder("encoder_r");
auto statesBackward = encBackward.apply(inputs.rbegin(), inputs.rend(),
encStartState);

std::vector<Expr> joinedStates;
for(auto itFw = statesForward.begin(), auto itBw = statesBackward.rbegin();
itFw != statesForward.end(); itFw++, itBw++)
joinedStates.push_back(concatenate({*itFw, *itBw}, axis=1));

auto encoder = concatenate(joinedStates, axis=2)
auto decStartState = mean(encoder, axis=2);
*/
}

SentBatch generateBatch(size_t batchSize) {
size_t length = rand() % 40 + 10;
return SentBatch(length, WordBatch(batchSize));
}

int main(int argc, char** argv) {
auto g = New<ExpressionGraph>();

size_t batchSize = 80;

boost::timer::cpu_timer timer;
for(int i = 1; i <= 1000; ++i) {
size_t length = rand() % 40 + 10; // random length from [10, 50)
g->clear();
construct(g, length);
auto batch = generateBatch(batchSize);
construct(g, batch);

BatchPtr batch(new Batch());
for(int j = 0; j < length; ++j)
batch->push_back(Input({80, 500}));

g->forward(batch);
g->forward();
if(i % 100 == 0)
std::cout << i << std::endl;
}
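Note: a shape walkthrough of construct() above under the dimensions it sets (dimSrcVoc=30000, dimSrcEmb=512, dimEncState=1024); the annotations are editorial, not code from the commit.

  // Per source position t of the sentence batch:
  //   Wemb                      : {30000, 512}       embedding table
  //   srcWordBatch              : dimBatch word ids
  //   x_t = rows(Wemb, ids)     : {dimBatch, 512}     gathered embeddings
  //   s_t = GRU(x_t, s_{t-1})   : {dimBatch, 1024}    encoder state
  // statesForward therefore holds one {dimBatch, 1024} expression per source
  // position; the commented-out backward pass would add the reversed states.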
@@ -111,10 +111,10 @@ struct TanhNodeOp : public UnaryNodeOp {
*
* This node implements the <a href="https://en.wikipedia.org/wiki/Activation_function">activation function</a>
* \f$f(x) = \max(0, x)\f$ and its derivative:
*
*
\f[
f^\prime(x) =
\begin{cases}
f^\prime(x) =
\begin{cases}
0 & \text{if } x \leq 0 \\
1 & \text{if } x > 0
\end{cases}
@@ -145,10 +145,10 @@ struct ReLUNodeOp : public UnaryNodeOp {

};

/**
* @brief Represents a <a href="https://en.wikipedia.org/wiki/Dropout_(neural_networks)">dropout</a> node
/**
* @brief Represents a <a href="https://en.wikipedia.org/wiki/Dropout_(neural_networks)">dropout</a> node
* in an expression graph.
*
*
* @see \cite dropout
* @see \cite cudnn
*/
@@ -366,8 +366,7 @@ struct MeanNodeOp : public UnaryNodeOp {
<< label("mean") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};

}
};

@@ -439,8 +438,39 @@ struct NegNodeOp : public UnaryNodeOp {
<< label("-") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
}
};

struct RowsNodeOp : public UnaryNodeOp {
template <typename ...Args>
RowsNodeOp(Expr a, const DeviceVector<size_t>& indeces, Args ...args)
: UnaryNodeOp(a, keywords::shape=newShape(a, indeces), args...),
indeces_(indeces) { }

void forward() {
CopyRows(val_, a_->val(), indeces_);
}

void backward() {
PasteRows(a_->grad(), adj_, indeces_);
}

template <class ...Args>
Shape newShape(Expr a, const DeviceVector<size_t>& indeces) {
Shape shape = a->shape();
shape[0] = indeces.size();
return shape;
}

virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label="
<< label("rows") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
}

const DeviceVector<size_t> &indeces_;
};
src/rnn.h (21 changed lines)
@@ -21,7 +21,7 @@ class Tanh {

Expr apply(Expr input, Expr state) {
using namespace keywords;

Expr output = dot(input, params_.W) + dot(state, params_.U);
if(params_.b)
output += params_.b;
@@ -40,7 +40,7 @@ class Tanh {
struct ParametersGRU {
Expr Uz, Wz, bz;
Expr Ur, Wr, br;
Expr Uh, Wh, bh;
Expr Ux, Wx, bx;
float dropout = 0;
};

@@ -62,9 +62,9 @@ class GRU {
r += params_.br;
r = logit(r);

Expr h = dot(input, params_.Wh) + dot(state, params_.Uh) * r;
if(params_.bh)
h += params_.bh;
Expr h = dot(input, params_.Wx) + dot(state, params_.Ux) * r;
if(params_.bx)
h += params_.bx;
h = tanh(h);

// constant 1 in (1-z)*h+z*s
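Note: read together with ParametersGRU above, the cell in this hunk computes, when the optional biases are present (the z gate is computed analogously just above the shown lines; \odot is the element-wise product and \sigma the logistic sigmoid that logit() applies):

\f[
\begin{aligned}
z_t &= \sigma(x_t W_z + s_{t-1} U_z + b_z) \\
r_t &= \sigma(x_t W_r + s_{t-1} U_r + b_r) \\
\tilde{h}_t &= \tanh\big(x_t W_x + (s_{t-1} U_x) \odot r_t + b_x\big) \\
s_t &= (1 - z_t) \odot \tilde{h}_t + z_t \odot s_{t-1}
\end{aligned}
\f]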
@@ -92,10 +92,17 @@ class RNN {

std::vector<Expr> apply(const std::vector<Expr>& inputs,
const Expr initialState) {
return apply(inputs.begin(), inputs.end(),
initialState);
}

template <class Iterator>
std::vector<Expr> apply(Iterator it, Iterator end,
const Expr initialState) {
std::vector<Expr> outputs;
auto state = initialState;
for(auto input : inputs) {
state = cell_.apply(input, state);
while(it != end) {
state = cell_.apply(*it++, state);
outputs.push_back(state);
}
return outputs;
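Note: a usage sketch of the new iterator-based apply(); inputs, encParams, g, dimBatch and dimEncState are meant as in the encoder example earlier in this commit, and the reverse pass mirrors its commented-out bidirectional code.

  // Apply one RNN over the already-embedded inputs in both directions.
  RNN<GRU> rnn(encParams);
  auto s0 = g->zeros(shape={dimBatch, dimEncState});
  auto statesFwd = rnn.apply(inputs.begin(),  inputs.end(),  s0);   // left to right
  auto statesBwd = rnn.apply(inputs.rbegin(), inputs.rend(), s0);   // right to left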
@@ -39,7 +39,7 @@ static cudnnHandle_t create_handle_dnn() {
cublasHandle_t cublasHandle = create_handle();
cudnnHandle_t cudnnHandle = create_handle_dnn();

void CudnnSoftmax(Tensor& out, Tensor& in) {
void CudnnSoftmax(Tensor out, Tensor in) {
float alpha = 1, beta = 0;
auto inGpu = static_cast<TensorGPU*>(in.get());
auto outGpu = static_cast<TensorGPU*>(out.get());
@@ -55,7 +55,7 @@ void CudnnSoftmax(Tensor& out, Tensor& in) {
cudaDeviceSynchronize();
}

void CudnnLogSoftmax(Tensor& out, Tensor& in) {
void CudnnLogSoftmax(Tensor out, Tensor in) {
float alpha = 1, beta = 0;
auto inGpu = static_cast<TensorGPU*>(in.get());
auto outGpu = static_cast<TensorGPU*>(out.get());
@@ -71,7 +71,7 @@ void CudnnLogSoftmax(Tensor& out, Tensor& in) {
cudaDeviceSynchronize();
}

void CudnnSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val) {
void CudnnSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
float alpha = 1, beta = 0;
auto valGpu = static_cast<TensorGPU*>(val.get());
auto adjGpu = static_cast<TensorGPU*>(adj.get());
@@ -90,7 +90,7 @@ void CudnnSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val) {
cudaDeviceSynchronize();
}

void CudnnLogSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val) {
void CudnnLogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
float alpha = 1, beta = 0;
auto valGpu = static_cast<TensorGPU*>(val.get());
auto adjGpu = static_cast<TensorGPU*>(adj.get());
@@ -147,7 +147,7 @@ __global__ void gSubtractMax(float* out, const float* in,
}
}

void SubtractMax(Tensor& out, Tensor& in) {
void SubtractMax(Tensor out, Tensor in) {
// Out is a m-by-k matrix, passed as input.
// The max element of each row of Out is computed and subtracted from Out.
// Out is both input and output.
@@ -196,7 +196,7 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
}
}

void Softmax(Tensor& out, Tensor& in) {
void Softmax(Tensor out, Tensor in) {
size_t m = out->shape()[0];
size_t k = out->shape()[1];

@@ -250,7 +250,7 @@ __global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
}
}

void SoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val) {
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
// grad and val are both m-by-k matrices, passed as input.
// A weighted average of each row of grad (according to the weights
// specified in val) is computed and subtracted from Out.
@@ -303,7 +303,7 @@ __global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
}
}

void LogSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val) {
void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
// grad and val are both m-by-k matrices, passed as input.
// A weighted average of each row of grad (according to the weights
// specified in val) is computed and subtracted from Out.
@@ -350,7 +350,7 @@ __global__ void gArgmax(float *out, const float *data, size_t rows, size_t cols)

///////////////////////////////////////////////////////

void Prod(cublasHandle_t handle, Tensor& C, const Tensor& A, const Tensor& B,
void Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, Float beta) {
Float alpha = 1.0;

@@ -378,13 +378,13 @@ void Prod(cublasHandle_t handle, Tensor& C, const Tensor& A, const Tensor& B,
n, m, k, &alpha, B->data(), ldb, A->data(), lda, &beta, C->data(), ldc);
}

void Prod(Tensor& C, const Tensor& A, const Tensor& B,
void Prod(Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, Float beta) {

Prod(cublasHandle, C, A, B, transA, transB, beta);
}

void Sum(Tensor& out, const Tensor& in, int axis, bool mean) {
void Sum(Tensor out, const Tensor in, int axis, bool mean) {
int rows = in->shape()[0];
int cols = in->shape()[1];

@@ -430,7 +430,7 @@ void Sum(Tensor& out, const Tensor& in, int axis, bool mean) {
}
}

void SumBackward(Tensor& out, const Tensor& in, int axis, bool mean) {
void SumBackward(Tensor out, const Tensor in, int axis, bool mean) {
int rows = out->shape()[0];
int cols = out->shape()[1];

@@ -476,7 +476,7 @@ void SumBackward(Tensor& out, const Tensor& in, int axis, bool mean) {
}
}

void CudnnDropoutPrepare(Tensor& in, float p,
void CudnnDropoutPrepare(Tensor in, float p,
cudnnDropoutDescriptor_t* dropDesc,
void** space, size_t* spaceSize,
void** states, size_t seed) {
@@ -506,7 +506,7 @@ void CudnnDropoutDestroy(cudnnDropoutDescriptor_t dropDesc,

void CudnnDropoutForward(cudnnDropoutDescriptor_t dropoutDesc,
void* space, size_t spaceSize,
Tensor& out, Tensor& in) {
Tensor out, Tensor in) {
auto inGpu = static_cast<TensorGPU*>(in.get());
auto outGpu = static_cast<TensorGPU*>(out.get());
cudnnDropoutForward(cudnnHandle,
@@ -521,7 +521,7 @@ void CudnnDropoutForward(cudnnDropoutDescriptor_t dropoutDesc,

void CudnnDropoutBackward(cudnnDropoutDescriptor_t dropoutDesc,
void* space, size_t spaceSize,
Tensor& out, Tensor& in) {
Tensor out, Tensor in) {
auto inGpu = static_cast<TensorGPU*>(in.get());
auto outGpu = static_cast<TensorGPU*>(out.get());
cudnnDropoutBackward(cudnnHandle,
@@ -534,4 +534,72 @@ void CudnnDropoutBackward(cudnnDropoutDescriptor_t dropoutDesc,
spaceSize);
}

__global__ void gCopyRows(float* out, const float* in, size_t cols,
const size_t* sourceRowIdx, size_t rows) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
size_t dstId = j;
size_t srcId = sourceRowIdx[j];

float* rowOut = out + dstId * cols;
const float* rowIn = in + srcId * cols;

for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols)
rowOut[i] = rowIn[i];
}
}
}
}

void CopyRows(Tensor out, const Tensor in, const DeviceVector<size_t>& indeces) {
size_t cols = in->shape()[1];
size_t rowsToCopy = indeces.size();

int threads = std::min(MAX_THREADS, (int)cols);
int blocks = std::min(MAX_BLOCKS, (int)rowsToCopy);

gCopyRows<<<blocks, threads>>>(out->data(), in->data(), cols,
thrust::raw_pointer_cast(indeces.data()),
rowsToCopy);
cudaStreamSynchronize(0);
}

__global__ void gPasteRows(float* out, const float* in, size_t cols,
const size_t* targetRowIdx, size_t rows) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
size_t dstId = targetRowIdx[j];
size_t srcId = j;

float* rowOut = out + dstId * cols;
const float* rowIn = in + srcId * cols;

for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols)
rowOut[i] = rowIn[i];
}
}
}
}

void PasteRows(Tensor out, const Tensor in, const DeviceVector<size_t>& indeces) {
size_t cols = in->shape()[1];
size_t rowsToCopy = indeces.size();

int threads = std::min(MAX_THREADS, (int)cols);
int blocks = std::min(MAX_BLOCKS, (int)rowsToCopy);

gPasteRows<<<blocks, threads>>>(out->data(), in->data(), cols,
thrust::raw_pointer_cast(indeces.data()),
rowsToCopy);

cudaStreamSynchronize(0);
}

}
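Note: a host-side sketch of driving the new CopyRows/PasteRows entry points; src and dst stand for already-allocated Tensors with the same number of columns, and only the index handling is specific to this commit.

  // Gather rows 4, 1 and 7 of tensor 'src' into the first three rows of 'dst'.
  HostVector<size_t> hIdx(3);
  hIdx[0] = 4; hIdx[1] = 1; hIdx[2] = 7;
  DeviceVector<size_t> dIdx = hIdx;   // raw_pointer_cast(dIdx.data()) is what the kernels receive
  CopyRows(dst, src, dIdx);           // dst[j]       = src[dIdx[j]]  (gather)
  PasteRows(src, dst, dIdx);          // src[dIdx[j]] = dst[j]        (scatter back)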
@@ -23,6 +23,9 @@

#include <cublas_v2.h>
#include <thrust/functional.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/pair.h>

#include "tensors/tensor_gpu.h"

@@ -257,33 +260,40 @@ void Element(Functor functor,
cudaStreamSynchronize(0);
}

void ClipNorm(Tensor& out, float threshold);
void ClipNorm(Tensor out, float threshold);

void SubtractMax(Tensor& out, Tensor& in);
void SubtractMax(Tensor out, Tensor in);

void Softmax(Tensor& out, Tensor& in);
void Softmax(Tensor out, Tensor in);

void SoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val);
void LogSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val);
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);

void CudnnSoftmax(Tensor& out, Tensor& in);
void CudnnSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val);
void CudnnSoftmax(Tensor out, Tensor in);
void CudnnSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);

void CudnnLogSoftmax(Tensor& out, Tensor& in);
void CudnnLogSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val);
void CudnnLogSoftmax(Tensor out, Tensor in);
void CudnnLogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);

void Argmax(Tensor& Out, const Tensor& In);
void Argmax(Tensor Out, const Tensor In);

void Prod(cublasHandle_t handle, Tensor& C, const Tensor& A, const Tensor& B,
void Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, Float beta);

void Prod(Tensor& C, const Tensor& A, const Tensor& B,
void Prod(Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, Float beta = 0);

void Sum(Tensor& out, const Tensor& in, int axis=-1, bool mean=false);
void SumBackward(Tensor& out, const Tensor& in, int axis=-1, bool mean=false);
void Sum(Tensor out, const Tensor in, int axis=-1, bool mean=false);
void SumBackward(Tensor out, const Tensor in, int axis=-1, bool mean=false);

void CudnnDropoutPrepare(Tensor& in, float p,
void CopyRowsByIndex(Tensor out, const Tensor in,
thrust::pair<size_t, size_t>* ipair, size_t length);

void CopyRows(Tensor out, const Tensor in, const DeviceVector<size_t>& indeces);

void PasteRows(Tensor out, const Tensor in, const DeviceVector<size_t>& indeces);

void CudnnDropoutPrepare(Tensor in, float p,
cudnnDropoutDescriptor_t* dropDesc,
void** space, size_t* spaceSize,
void** states, size_t seed);
@@ -293,11 +303,11 @@ void CudnnDropoutDestroy(cudnnDropoutDescriptor_t dropDesc,

void CudnnDropoutForward(cudnnDropoutDescriptor_t dropoutDesc,
void* space, size_t spaceSize,
Tensor& out, Tensor& in);
Tensor out, Tensor in);

void CudnnDropoutBackward(cudnnDropoutDescriptor_t dropoutDesc,
void* space, size_t spaceSize,
Tensor& out, Tensor& in);
Tensor out, Tensor in);

}