This commit is contained in:
Hieu Hoang 2016-09-18 19:15:12 +01:00
commit 1c0fdfca8d
11 changed files with 260 additions and 92 deletions

View File

@ -17,7 +17,7 @@
</extensions>
</storageModule>
<storageModule moduleId="cdtBuildSystem" version="4.0.0">
<configuration artifactName="${ProjName}" buildArtefactType="org.eclipse.cdt.build.core.buildArtefactType.exe" buildProperties="org.eclipse.cdt.build.core.buildArtefactType=org.eclipse.cdt.build.core.buildArtefactType.exe,org.eclipse.cdt.build.core.buildType=org.eclipse.cdt.build.core.buildType.debug" cleanCommand="rm -rf" description="" id="com.nvidia.cuda.ide.seven_five.configuration.debug.1479727693" name="Debug" parent="com.nvidia.cuda.ide.seven_five.configuration.debug">
<configuration artifactName="${ProjName}" buildArtefactType="org.eclipse.cdt.build.core.buildArtefactType.exe" buildProperties="org.eclipse.cdt.build.core.buildType=org.eclipse.cdt.build.core.buildType.debug,org.eclipse.cdt.build.core.buildArtefactType=org.eclipse.cdt.build.core.buildArtefactType.exe" cleanCommand="rm -rf" description="" id="com.nvidia.cuda.ide.seven_five.configuration.debug.1479727693" name="Debug" parent="com.nvidia.cuda.ide.seven_five.configuration.debug">
<folderInfo id="com.nvidia.cuda.ide.seven_five.configuration.debug.1479727693." name="/" resourcePath="">
<toolChain id="com.nvidia.cuda.tools.toolchain.seven_five.exe.debug.1735809242" name="CUDA Toolkit 8.0" superClass="com.nvidia.cuda.tools.toolchain.seven_five.exe.debug">
<targetPlatform archList="all" binaryParser="com.nvidia.cuda.ide.elf;com.nvidia.cuda.ide.macho;com.nvidia.cuda.ide.cubin" id="com.nvidia.cuda.ide.targetPlatform.1814841241" isAbstract="false" name="Debug Platform" osList="linux,macosx" superClass="com.nvidia.cuda.ide.targetPlatform"/>
@ -37,16 +37,12 @@
</tool>
<tool id="nvcc.linker.base.635344589" name="NVCC Linker" superClass="nvcc.linker.base">
<option id="nvcc.linker.option.libs.1878015233" name="Libraries (-l)" superClass="nvcc.linker.option.libs" valueType="libs">
<listOptionValue builtIn="false" value="boost_chrono"/>
<listOptionValue builtIn="false" value="boost_system"/>
<listOptionValue builtIn="false" value="boost_timer"/>
<listOptionValue builtIn="false" value="cudnn"/>
<listOptionValue builtIn="false" value="cuda"/>
<listOptionValue builtIn="false" value="cublas"/>
</option>
<option id="nvcc.linker.option.paths.1326041662" name="Library search path (-L)" superClass="nvcc.linker.option.paths" valueType="libPaths">
<listOptionValue builtIn="false" value="/usr/local/cuda/lib"/>
<listOptionValue builtIn="false" value="&quot;${workspace_loc:/}/boost/lib64&quot;"/>
<listOptionValue builtIn="false" value="/usr/lib"/>
</option>
<inputType id="nvcc.linker.input.1742167733" superClass="nvcc.linker.input">
@ -60,11 +56,11 @@
</tool>
</toolChain>
</folderInfo>
<fileInfo id="com.nvidia.cuda.ide.seven_five.configuration.debug.1479727693.1232903988" name="mnist_benchmark.cu" rcbsApplicability="disable" resourcePath="src/mnist_benchmark.cu" toolsToInvoke="nvcc.compiler.base.1979453423.1712466877">
<tool id="nvcc.compiler.base.1979453423.1712466877" name="NVCC Compiler" superClass="nvcc.compiler.base.1979453423"/>
<fileInfo id="com.nvidia.cuda.ide.seven_five.configuration.debug.1479727693.1303731853" name="test.cu" rcbsApplicability="disable" resourcePath="src/test.cu" toolsToInvoke="nvcc.compiler.base.1979453423.1311284147">
<tool id="nvcc.compiler.base.1979453423.1311284147" name="NVCC Compiler" superClass="nvcc.compiler.base.1979453423"/>
</fileInfo>
<sourceEntries>
<entry excluding="src/mnist_benchmark.cu|src/validate_encoder_decoder.cu|src/validate_mnist_batch.cu|src/train_mnist.cu|src/validate_mnist.cu|src/npz_converter.cpp" flags="VALUE_WORKSPACE_PATH|RESOLVED" kind="sourcePath" name=""/>
<entry excluding="src/test.cu|src/validate_mnist_batch.cu|src/train_mnist.cu|src/validate_mnist.cu|src/npz_converter.cpp" flags="VALUE_WORKSPACE_PATH|RESOLVED" kind="sourcePath" name=""/>
</sourceEntries>
</configuration>
</storageModule>
@ -147,10 +143,10 @@
</storageModule>
<storageModule moduleId="org.eclipse.cdt.core.LanguageSettingsProviders"/>
<storageModule moduleId="refreshScope" versionNumber="2">
<configuration configurationName="Debug">
<configuration configurationName="Release">
<resource resourceType="PROJECT" workspacePath="/marian"/>
</configuration>
<configuration configurationName="Release">
<configuration configurationName="Debug">
<resource resourceType="PROJECT" workspacePath="/marian"/>
</configuration>
</storageModule>

View File

@ -18,8 +18,6 @@ cuda_add_executable(
test.cu
)
target_link_libraries(marian marian_lib)
cuda_add_executable(
mnist_benchmark
mnist_benchmark.cu
@ -35,11 +33,18 @@ cuda_add_executable(
validate_encoder_decoder.cu
)
cuda_add_executable(
test_nodes
test_nodes.cu
)
target_link_libraries(marian marian_lib)
target_link_libraries(mnist_benchmark marian_lib)
target_link_libraries(validate_mnist_batch marian_lib)
target_link_libraries(validate_encoder_decoder marian_lib)
target_link_libraries(test_nodes marian_lib)
foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder)
foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder test_nodes)
target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn)
cuda_add_cublas_to_target(${exec})
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")

View File

@ -13,6 +13,7 @@ struct Chainable {
virtual ~Chainable() { }
virtual void forward() { }
virtual void backward() { }
virtual void check() { }
virtual void init_dependent() { }
virtual void set_zero_adjoint() { }

View File

@ -125,4 +125,8 @@ Expr dot(Expr a, Expr b) {
return Expr(a.graph(), new DotNodeOp(a, b));
}
Expr cross_entropy(Expr a, Expr b) {
return Expr(a.graph(), new CrossEntropyNodeOp(a, b));
}
}

View File

@ -112,4 +112,6 @@ inline Expr mean(Expr a, Args ...args) {
}
}
Expr cross_entropy(Expr a, Expr b);
}

View File

@ -40,19 +40,20 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
biases.emplace_back(
g.param(shape={1, out},
init=normal()));
}
auto probs = named(
softmax(dot(layers.back(), weights.back()) + biases.back()),
"probs"
);
auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
auto scores = named(dot(layers.back(), weights.back()) + biases.back(),
"scores");
auto cost = mean(cross_entropy(scores, y), axis=0);
//auto cost = mean(-sum(y * log(softmax(scores)), axis=1), axis=0);
auto costreg = named(
cost, "cost"
);
// If we uncomment the line below, this will just horribly diverge.
// auto dummy_probs = named(softmax(scores), "dummy_probs");
std::cerr << timer.format(5, "%ws") << std::endl;
return g;
}
@ -142,7 +143,7 @@ int main(int argc, char** argv) {
g.forward(BATCH_SIZE);
std::vector<float> bResults;
bResults << g["probs"].val();
bResults << g["scores"].val();
results.insert(results.end(), bResults.begin(), bResults.end());
}

View File

@ -92,8 +92,7 @@ struct UnaryNodeOp : public Node {
template <typename ...Args>
UnaryNodeOp(ChainPtr a, Args ...args)
: Node(keywords::shape=a->shape(), //@TODO: Check keywords?
args...),
a_(a) {}
args...), a_(a) {}
};
struct LogitNodeOp : public UnaryNodeOp {
@ -111,6 +110,10 @@ struct LogitNodeOp : public UnaryNodeOp {
a_->grad(), adj_, val_);
}
void check() {
}
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
@ -144,10 +147,6 @@ struct TanhNodeOp : public UnaryNodeOp {
};
// @TODO, make this numerically safe(r):
// softmax(X) = softmax_safe(X - max(X, axis=1))
// Probably best to do this directly in Softmax
// function.
struct SoftmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
SoftmaxNodeOp(Args ...args)
@ -155,8 +154,8 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
void forward() {
// B = softmax(A).
val_ = a_->val();
SubtractMax(&val_); // Safe version of softmax.
thrust::copy(a_->val().begin(), a_->val().end(), val_.begin());
// Safe version of softmax.
Softmax(&val_);
}
@ -171,10 +170,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
// Classification." ICML 2016.
// http://jmlr.org/proceedings/papers/v48/martins16.pdf
Tensor result(adj_.shape());
thrust::copy(adj_.begin(), adj_.end(), result.begin());
SubtractMean(&result, val_);
Element(_1 += _2 * _3, a_->grad(), val_, result);
SoftmaxGrad(a_->grad(), adj_, val_);
}
virtual std::string graphviz() {
@ -183,7 +179,6 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
};
struct ArgmaxNodeOp : public UnaryNodeOp {
@ -445,4 +440,71 @@ struct DivNodeOp : public BinaryNodeOp {
};
// Cross-entropy node. It computes -b*log(softmax(a)), summing rowwise.
struct CrossEntropyNodeOp : public BinaryNodeOp {
template <typename ...Args>
CrossEntropyNodeOp(ChainPtr a, ChainPtr b, Args ...args)
: BinaryNodeOp(a, b,
keywords::shape=newShape(a, b),
args...) { }
Shape newShape(ChainPtr a, ChainPtr b) {
Shape shape1 = a->shape();
Shape shape2 = b->shape();
UTIL_THROW_IF2(shape1[0] != shape2[0] || shape1[1] != shape2[1],
"cross entropy requires dimensions to match");
shape1[1] = 1;
return shape1;
}
// We're caching the softmax probabilities here because we'll need them for
// the backward computation.
void forward() {
// C = -dot(B, log(softmax(A))).
if (probs_) {
probs_.set(0.0);
} else {
probs_.allocate(a_->val().shape(), 0.0);
}
thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
Softmax(&probs_); // Safe version of softmax.
Tensor result(a_->val().shape());
Element(_1 = -_2 * Log(_3), result, b_->val(), probs_);
SumRowwise(result, val_);
}
// @TODO: In most cases it's wasteful to compute the derivative with respect
// to the second input which is typically an input node in the computation
// graph. In general the backward functions can skip the computation of
// gradients wrt input nodes.
void backward() {
// For each row, the first input derivative is given by adj * (p - y),
// where y is the gold label distribution (e.g. one hot vector) and
// p is the softmax output (probabilities).
// The second input derivative is -adj*log(p).
Tensor result(probs_.shape());
// Compute first input derivative.
Element(_1 = _2 - _3, result, probs_, b_->val());
ScaleRowwise(result, adj_);
Element(_1 += _2, a_->grad(), result);
// Compute second input derivative.
Element(_1 = -Log(_2), result, probs_); // @TODO: use a cached log here.
ScaleRowwise(result, adj_);
Element(_1 += _2, b_->grad(), result);
}
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"cross_entropy\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
protected:
Tensor probs_;
};
}

View File

@ -12,20 +12,22 @@ static cublasHandle_t create_handle() {
}
cublasHandle_t cublasHandle = create_handle();
__global__ void gSubtractMean(float* out, float* weights,
size_t rows, size_t cols) {
__global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
const int rows, const int cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
extern __shared__ float _share[];
float* _sum = _share + blockDim.x;
float* sp = out + j * cols;
float* w = weights + j * cols;
float* gradRow = grad + j * cols;
const float* adjRow = adj + j * cols;
const float* valRow = val + j * cols;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
_sum[threadIdx.x] += w[id] * sp[id];
_sum[threadIdx.x] += valRow[id] * adjRow[id];
}
}
__syncthreads();
@ -41,25 +43,25 @@ __global__ void gSubtractMean(float* out, float* weights,
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] -= _sum[0];
gradRow[id] += valRow[id] * (adjRow[id] - _sum[0]);
}
}
}
}
void SubtractMean(Tensor* Out, Tensor &Weights) {
// Out and Weights are both m-by-k matrices, passed as input.
// A weighted average of each row of Out (according to the weights
// specified in Weights) is computed and subtracted from Out.
// Out is both input and output.
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
// grad and val are both m-by-k matrices, passed as input.
// A weighted average of each row of grad (according to the weights
// specified in val) is computed and subtracted from Out.
// adj is multiplied for each element to get backward step in autodiff
int m = grad.shape()[0];
int k = grad.shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int blocks = std::min(MAX_BLOCKS, m);
int threads = std::min(MAX_THREADS, k);
int shared = sizeof(float) * threads * 2;
gSubtractMean<<<blocks, threads, shared>>>(Out->data(), Weights.data(),
m, k);
gSoftmaxGrad<<<blocks, threads, shared>>>(grad.data(), adj.data(), val.data(),
m, k);
cudaStreamSynchronize(0);
}
@ -155,11 +157,15 @@ void Softmax(Tensor* Out) {
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
// Subtract the max rowwise for numerical stability (safe softmax).
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
}
///////////////////////////////////////////////////////
__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols) {
__global__ void gArgmax(float *out, const float *data, size_t rows, size_t cols) {
size_t row = blockIdx.x;
size_t startInd = row * cols;
float maxScore = -99999;
@ -182,7 +188,7 @@ void Argmax(Tensor* Out, const Tensor* In) {
int blocks = m; //std::min(MAX_BLOCKS, (int) m);
int threads = k; //std::min(MAX_THREADS, (int) k);
//int shared = sizeof(float) * threads * 2;
gArgMax<<<blocks, threads>>>(Out->data(), In->data(), m, k);
gArgmax<<<blocks, threads>>>(Out->data(), In->data(), m, k);
cudaStreamSynchronize(0);
}
@ -224,4 +230,48 @@ Tensor Prod(Tensor C, const Tensor A, const Tensor B,
return temp;
}
Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result) {
size_t rows = A.shape()[0];
size_t cols = A.shape()[1];
thrust::device_vector<float> d_ones(cols, 1.f);
Float alpha = 1.f;
Float beta = 0.f;
cublasSgemv(handle, CUBLAS_OP_T, cols, rows, &alpha,
A.data(), cols,
thrust::raw_pointer_cast(d_ones.data()), 1, &beta,
result.data(), 1);
return result;
}
Tensor SumRowwise(const Tensor A, Tensor result) {
Tensor temp = SumRowwise(cublasHandle, A, result);
return temp;
}
// @TODO: replace this by something else when broadcast elementwise operations
// are ready.
__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
Float* rowOut = out + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols) rowOut[i] *= scalingFactors[j];
}
}
}
}
void ScaleRowwise(Tensor Out, const Tensor ScalingFactors) {
Float* d_out = Out.data();
const Float* d_in = ScalingFactors.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gScaleRowwise<<<blocks, threads>>>(d_out, d_in,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
}

View File

@ -142,20 +142,11 @@ void Element(Functor functor,
cudaStreamSynchronize(0);
}
__global__ void gSubtractMean(float* out, float* weights,
size_t rows, size_t cols);
void SubtractMean(Tensor* Out, Tensor &Weights);
__global__ void gSubtractMax(float* out, size_t rows, size_t cols);
void SubtractMax(Tensor* Out);
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
void Softmax(Tensor* Out);
__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols);
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void Argmax(Tensor* Out, const Tensor* In);
@ -165,4 +156,13 @@ Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
Tensor Prod(Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, Float beta = 0);
Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result);
Tensor SumRowwise(const Tensor A, Tensor result);
__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
size_t rows, size_t cols);
void ScaleRowwise(Tensor Out, const Tensor ScalingFactors);
}

View File

@ -17,33 +17,33 @@ string output(const std::vector<float> &vec)
return strm.str();
}
void testArgMax()
{
using namespace std;
using namespace marian;
std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
cerr << "hVec =" << output(hVec) << endl;
thrust::device_vector<float> dVec(8);
thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
float *data = thrust::raw_pointer_cast(dVec.data());
thrust::device_vector<float> dLabel(4);
float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
std::vector<float> hVec2(8);
thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
cerr << "hVec2=" << output(hVec2) << endl;
std::vector<float> hLabel(4);
thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
cerr << "hLabel=" << output(hLabel) << endl;
exit(0);
}
//void testArgMax()
//{
// using namespace std;
// using namespace marian;
//
// std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
// cerr << "hVec =" << output(hVec) << endl;
//
// thrust::device_vector<float> dVec(8);
// thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
// float *data = thrust::raw_pointer_cast(dVec.data());
//
// thrust::device_vector<float> dLabel(4);
// float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
//
// gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
//
// std::vector<float> hVec2(8);
// thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
// cerr << "hVec2=" << output(hVec2) << endl;
//
// std::vector<float> hLabel(4);
// thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
// cerr << "hLabel=" << output(hLabel) << endl;
//
// exit(0);
//}
///////////////////////////////////////////////////////
int main(int argc, char** argv) {
@ -106,7 +106,7 @@ int main(int argc, char** argv) {
Yp.emplace_back(softmax(dot(H[t], Why) + by));
cross_entropy = cross_entropy + sum(Y[t] * log(Yp[t]), axis=1);
}
auto graph = -mean(cross_entropy, axis=0, name="cost");
Expr graph = -mean(cross_entropy, axis=0, name="cost");
for (int t = 0; t < num_inputs; ++t) {
Tensor Xt({batch_size, input_size});

47
src/test_nodes.cu Normal file
View File

@ -0,0 +1,47 @@
#include <vector>
#include <random>
#include "marian.h"
#include "expression_graph.h"
#include "keywords.h"
#include "definitions.h"
int main(int argc, char** argv)
{
using namespace std;
using namespace marian;
using namespace keywords;
int input_size = 10;
int batch_size = 25;
// define graph
ExpressionGraph g;
Expr inputExpr = g.input(shape={batch_size, input_size});
// create data
random_device rnd_device;
mt19937 mersenne_engine(rnd_device());
uniform_real_distribution<float> dist(1, 52);
auto gen = std::bind(dist, mersenne_engine);
std::vector<float> values(batch_size * input_size);
generate(begin(values), end(values), gen);
Tensor inputTensor({batch_size, input_size});
thrust::copy(values.begin(), values.end(), inputTensor.begin());
inputExpr = inputTensor;
Expr softMaxExpr = softmax(inputExpr);
g.forward(batch_size);
g.backward();
std::cout << g.graphviz() << std::endl;
std::cerr << "inputTensor=" << inputTensor.Debug() << std::endl;
Tensor softMaxTensor = softMaxExpr.val();
std::cerr << "softMaxTensor=" << softMaxTensor.Debug() << std::endl;
}