Combined the elementwise functions for the softmax gradient into a single kernel call; no temporary objects are needed, and it is slightly faster.
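
The softmax backward pass computes, per row, grad_i += val_i * (adj_i - sum_j val_j * adj_j), where val is the softmax output and adj the incoming adjoint. The new gSoftmaxGrad kernel does this in one pass instead of copy + SubtractMean + Element. A minimal CPU sketch of the fused row update (the helper name rowSoftmaxGrad and its raw-pointer interface are illustrative only, not part of this commit):

// Reference (CPU) version of the fused per-row softmax gradient.
// grad, adj and val each point to one contiguous row of length n.
void rowSoftmaxGrad(float* grad, const float* adj, const float* val, int n) {
  float sum = 0.f;
  for(int i = 0; i < n; ++i)
    sum += val[i] * adj[i];             // val-weighted sum of the adjoint
  for(int i = 0; i < n; ++i)
    grad[i] += val[i] * (adj[i] - sum); // accumulate the gradient in place
}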

Marcin Junczys-Dowmunt 2016-09-18 16:21:15 +02:00
parent 089f76751b
commit ed684e25b6
6 changed files with 57 additions and 64 deletions

View File

@@ -13,6 +13,7 @@ struct Chainable {
     virtual ~Chainable() { }
     virtual void forward() { }
     virtual void backward() { }
+    virtual void check() { }
     virtual void init_dependent() { }
     virtual void set_zero_adjoint() { }

View File

@@ -40,7 +40,6 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
     biases.emplace_back(
       g.param(shape={1, out},
               init=normal()));
   }
   auto probs = named(

View File

@@ -92,8 +92,7 @@ struct UnaryNodeOp : public Node {
   template <typename ...Args>
   UnaryNodeOp(ChainPtr a, Args ...args)
   : Node(keywords::shape=a->shape(), //@TODO: Check keywords?
-         args...),
-         a_(a) {}
+         args...), a_(a) {}
 };
 struct LogitNodeOp : public UnaryNodeOp {
@@ -111,6 +110,10 @@ struct LogitNodeOp : public UnaryNodeOp {
             a_->grad(), adj_, val_);
   }
+  void check() {
+  }
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
@@ -171,10 +174,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
     // Classification." ICML 2016.
     // http://jmlr.org/proceedings/papers/v48/martins16.pdf
-    Tensor result(adj_.shape());
-    thrust::copy(adj_.begin(), adj_.end(), result.begin());
-    SubtractMean(&result, val_);
-    Element(_1 += _2 * _3, a_->grad(), val_, result);
+    SoftmaxGrad(a_->grad(), adj_, val_);
   }
   virtual std::string graphviz() {
@@ -183,7 +183,6 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
     ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 };
 struct ArgmaxNodeOp : public UnaryNodeOp {
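
For reference, the removed backward() sequence built a temporary copy of adj_, subtracted the val_-weighted row sum from it (SubtractMean), and then accumulated a_->grad() += val_ * result (Element). The single SoftmaxGrad call performs the same update,

  grad_i += val_i * (adj_i - sum_j val_j * adj_j),

in one kernel launch and without the temporary tensor.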

View File

@@ -12,20 +12,22 @@ static cublasHandle_t create_handle() {
 }
 cublasHandle_t cublasHandle = create_handle();
-__global__ void gSubtractMean(float* out, float* weights,
-                              size_t rows, size_t cols) {
+__global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
+                             const int rows, const int cols) {
   for(int bid = 0; bid < rows; bid += gridDim.x) {
     int j = bid + blockIdx.x;
     if(j < rows) {
       extern __shared__ float _share[];
       float* _sum = _share + blockDim.x;
-      float* sp = out + j * cols;
-      float* w = weights + j * cols;
+      float* gradRow = grad + j * cols;
+      const float* adjRow = adj + j * cols;
+      const float* valRow = val + j * cols;
       _sum[threadIdx.x] = 0.0;
       for(int tid = 0; tid < cols; tid += blockDim.x) {
         int id = tid + threadIdx.x;
         if(id < cols) {
-          _sum[threadIdx.x] += w[id] * sp[id];
+          _sum[threadIdx.x] += valRow[id] * adjRow[id];
         }
       }
       __syncthreads();
@@ -41,25 +43,25 @@ __global__ void gSubtractMean(float* out, float* weights,
       for(int tid = 0; tid < cols; tid += blockDim.x){
         int id = tid + threadIdx.x;
         if(id < cols)
-          sp[id] -= _sum[0];
+          gradRow[id] += valRow[id] * (adjRow[id] - _sum[0]);
       }
     }
   }
 }
-void SubtractMean(Tensor* Out, Tensor &Weights) {
-  // Out and Weights are both m-by-k matrices, passed as input.
-  // A weighted average of each row of Out (according to the weights
-  // specified in Weights) is computed and subtracted from Out.
-  // Out is both input and output.
-  size_t m = Out->shape()[0];
-  size_t k = Out->shape()[1];
-  int blocks = std::min(MAX_BLOCKS, (int) m);
-  int threads = std::min(MAX_THREADS, (int) k);
+void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
+  // grad, adj and val are all m-by-k matrices.
+  // For each row, the val-weighted sum of adj (sum_j val_j * adj_j) is
+  // subtracted from adj; the result, multiplied elementwise by val, is
+  // added to grad (the backward step of softmax in autodiff).
+  int m = grad.shape()[0];
+  int k = grad.shape()[1];
+  int blocks = std::min(MAX_BLOCKS, m);
+  int threads = std::min(MAX_THREADS, k);
   int shared = sizeof(float) * threads * 2;
-  gSubtractMean<<<blocks, threads, shared>>>(Out->data(), Weights.data(),
-                                             m, k);
+  gSoftmaxGrad<<<blocks, threads, shared>>>(grad.data(), adj.data(), val.data(),
+                                            m, k);
   cudaStreamSynchronize(0);
 }
@@ -158,8 +160,9 @@ void Softmax(Tensor* Out) {
   gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
   cudaStreamSynchronize(0);
 }
 ///////////////////////////////////////////////////////
-__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols) {
+__global__ void gArgmax(float *out, const float *data, size_t rows, size_t cols) {
   size_t row = blockIdx.x;
   size_t startInd = row * cols;
   float maxScore = -99999;
@@ -182,7 +185,7 @@ void Argmax(Tensor* Out, const Tensor* In) {
   int blocks = m; //std::min(MAX_BLOCKS, (int) m);
   int threads = k; //std::min(MAX_THREADS, (int) k);
   //int shared = sizeof(float) * threads * 2;
-  gArgMax<<<blocks, threads>>>(Out->data(), In->data(), m, k);
+  gArgmax<<<blocks, threads>>>(Out->data(), In->data(), m, k);
   cudaStreamSynchronize(0);
 }
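
An illustrative way to exercise the new kernel directly, in the style of the (now commented-out) testArgMax below; this is a sketch, not part of the commit, and it assumes gSoftmaxGrad is declared wherever this code lives (the __global__ declarations were dropped from the header):

#include <thrust/device_vector.h>
#include <thrust/copy.h>
#include <vector>
#include <iostream>

void sketchSoftmaxGrad() {
  // One row (rows = 1, cols = 4): val holds a softmax output, adj the incoming adjoint.
  std::vector<float> hVal = {0.1f, 0.2f, 0.3f, 0.4f};
  std::vector<float> hAdj = {1.0f, 0.0f, 0.0f, 0.0f};
  thrust::device_vector<float> dVal(hVal.begin(), hVal.end());
  thrust::device_vector<float> dAdj(hAdj.begin(), hAdj.end());
  thrust::device_vector<float> dGrad(4, 0.f);        // gradient accumulator, starts at zero

  int threads = 4;
  int shared = sizeof(float) * threads * 2;          // same shared-memory sizing as SoftmaxGrad
  gSoftmaxGrad<<<1, threads, shared>>>(
      thrust::raw_pointer_cast(dGrad.data()),
      thrust::raw_pointer_cast(dAdj.data()),
      thrust::raw_pointer_cast(dVal.data()),
      1, 4);
  cudaStreamSynchronize(0);

  // sum_j val_j * adj_j = 0.1, so grad should come back as {0.09, -0.02, -0.03, -0.04}.
  std::vector<float> hGrad(4);
  thrust::copy(dGrad.begin(), dGrad.end(), hGrad.begin());
  for(float g : hGrad)
    std::cerr << g << " ";
  std::cerr << std::endl;
}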

View File

@@ -142,20 +142,11 @@ void Element(Functor functor,
   cudaStreamSynchronize(0);
 }
-__global__ void gSubtractMean(float* out, float* weights,
-                              size_t rows, size_t cols);
-void SubtractMean(Tensor* Out, Tensor &Weights);
-__global__ void gSubtractMax(float* out, size_t rows, size_t cols);
 void SubtractMax(Tensor* Out);
-__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
 void Softmax(Tensor* Out);
-__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols);
+void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
 void Argmax(Tensor* Out, const Tensor* In);

View File

@@ -17,33 +17,33 @@ string output(const std::vector<float> &vec)
   return strm.str();
 }
-void testArgMax()
-{
-  using namespace std;
-  using namespace marian;
-
-  std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
-  cerr << "hVec =" << output(hVec) << endl;
-
-  thrust::device_vector<float> dVec(8);
-  thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
-  float *data = thrust::raw_pointer_cast(dVec.data());
-
-  thrust::device_vector<float> dLabel(4);
-  float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
-
-  gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
-
-  std::vector<float> hVec2(8);
-  thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
-  cerr << "hVec2=" << output(hVec2) << endl;
-
-  std::vector<float> hLabel(4);
-  thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
-  cerr << "hLabel=" << output(hLabel) << endl;
-
-  exit(0);
-}
+//void testArgMax()
+//{
+//  using namespace std;
+//  using namespace marian;
+//
+//  std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
+//  cerr << "hVec =" << output(hVec) << endl;
+//
+//  thrust::device_vector<float> dVec(8);
+//  thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
+//  float *data = thrust::raw_pointer_cast(dVec.data());
+//
+//  thrust::device_vector<float> dLabel(4);
+//  float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
+//
+//  gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
+//
+//  std::vector<float> hVec2(8);
+//  thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
+//  cerr << "hVec2=" << output(hVec2) << endl;
+//
+//  std::vector<float> hLabel(4);
+//  thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
+//  cerr << "hLabel=" << output(hLabel) << endl;
+//
+//  exit(0);
+//}
 ///////////////////////////////////////////////////////
 int main(int argc, char** argv) {