Merge ../Marian

This commit is contained in:
Hieu Hoang 2016-09-21 18:28:48 +01:00
commit 3a8659c924
12 changed files with 414 additions and 183 deletions

View File

@ -37,3 +37,10 @@ Compilation with `cmake > 3.5`:
To compile API documentation using Doxygen, first cd to the build directory, and then:
make doc
To test, first compile, then:
cd examples/mnist
make
cd ../../build
./mnist_benchmark

View File

@ -20,6 +20,11 @@ cuda_add_executable(
test.cu
)
cuda_add_executable(
test_cudnn
cudnn.cu
)
cuda_add_executable(
mnist_benchmark
mnist_benchmark.cu
@ -41,12 +46,13 @@ cuda_add_executable(
)
target_link_libraries(marian marian_lib)
target_link_libraries(test_cudnn marian_lib)
target_link_libraries(mnist_benchmark marian_lib)
target_link_libraries(validate_mnist_batch marian_lib)
target_link_libraries(validate_encoder_decoder marian_lib)
target_link_libraries(test_nodes marian_lib)
foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder test_nodes)
foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder test_nodes test_cudnn)
target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn curand)
cuda_add_cublas_to_target(${exec})
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")

53
src/cudnn.cu Normal file
View File

@ -0,0 +1,53 @@
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <cudnn.h>
#include <boost/timer/timer.hpp>
#include "tensor.h"
#include "tensor_operators.h"
#include "param_initializers.h"
using namespace marian;
int main() {
int d = 4;
Tensor in({d, d});
Tensor out({d, d});
Tensor adj({d, d}, 1);
auto f = uniform(-5, 5);
f(in);
std::cerr << in.Debug() << std::endl;
{
boost::timer::cpu_timer timer;
for(int i = 0; i < 1; ++i) {
Tensor grad({d, d});
CudnnLogSoftmax(out, in);
CudnnLogSoftmaxGrad(grad, adj, in);
std::cerr << in.Debug() << std::endl;
std::cerr << adj.Debug() << std::endl;
std::cerr << grad.Debug() << std::endl;
}
std::cerr << timer.format(5, "%ws") << std::endl;
}
{
boost::timer::cpu_timer timer;
for(int i = 0; i < 1; ++i) {
Tensor grad({d, d});
CudnnLogSoftmax(out, in);
LogSoftmaxGrad(grad, adj, in);
std::cerr << in.Debug() << std::endl;
std::cerr << adj.Debug() << std::endl;
std::cerr << grad.Debug() << std::endl;
}
std::cerr << timer.format(5, "%ws") << std::endl;
}
return 0;
}

View File

@ -58,6 +58,10 @@ Expr softmax(Expr a) {
return Expr(a.graph(), new SoftmaxNodeOp(a));
}
Expr logsoftmax(Expr a) {
return Expr(a.graph(), new LogSoftmaxNodeOp(a));
}
Expr argmax(Expr a) {
return Expr(a.graph(), new ArgmaxNodeOp(a));
}

View File

@ -106,6 +106,8 @@ Expr softmax_slow(Expr a, Args ...args) {
Expr softmax(Expr a);
Expr logsoftmax(Expr a);
Expr argmax(Expr a);
// inefficient

View File

@ -45,8 +45,8 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
auto scores = named(dot(layers.back(), weights.back()) + biases.back(),
"scores");
auto cost = mean(cross_entropy(scores, y), axis=0);
//auto cost = mean(-sum(y * log(softmax(scores)), axis=1), axis=0);
//auto cost = mean(cross_entropy(scores, y), axis=0);
auto cost = mean(-sum(y * logsoftmax(scores), axis=1), axis=0);
auto costreg = named(
cost, "cost"
);
@ -115,14 +115,14 @@ int main(int argc, char** argv) {
std::cerr << "Done." << std::endl;
ExpressionGraph g = build_graph({IMAGE_SIZE, 2048, 2048, LABEL_SIZE});
std::cout << g.graphviz() << std::endl;
//std::cout << g.graphviz() << std::endl;
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
Tensor yt({BATCH_SIZE, LABEL_SIZE});
boost::timer::cpu_timer total;
Adam opt(0.0002);
for(int i = 1; i <= 30; ++i) {
for(int i = 1; i <= 50; ++i) {
boost::timer::cpu_timer timer;
shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);
float cost = 0;

View File

@ -239,20 +239,20 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
return shape1;
}
// We're caching the softmax probabilities here because we'll need them for
// We're caching the logsoftmax probabilities here because we'll need them for
// the backward computation.
void forward() {
// C = -dot(B, log(softmax(A))).
// C = -dot(B, logsoftmax(A)).
if (probs_) {
probs_.set(0.0);
} else {
probs_.allocate(a_->val().shape(), 0.0);
}
thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
Softmax(&probs_); // Safe version of softmax.
CudnnLogSoftmax(probs_, a_->val());
if(!result_)
result_.allocate(a_->val().shape());
Element(_1 = -_2 * Log(_3), result_, b_->val(), probs_);
Element(_1 = -_2 * _3, result_, b_->val(), probs_);
SumRowwise(result_, val_);
}
@ -261,22 +261,19 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
// graph. In general the backward functions can skip the computation of
// gradients wrt input nodes.
void backward() {
// For each row, the first input derivative is given by adj * (p - y),
// We are using logsoftmax for this and cached probs are logs.
// For each row, the first input derivative is given by adj * (exp(p) - y),
// where y is the gold label distribution (e.g. one hot vector) and
// p is the softmax output (probabilities).
// The second input derivative is -adj*log(p).
if(!result_)
result_.allocate(probs_.shape());
// The second input derivative is -adj*p.
// Compute first input derivative.
Element(_1 = _2 - _3, result_, probs_, b_->val());
ScaleRowwise(result_, adj_);
Element(_1 += _2, a_->grad(), result_);
Element(_1 += _2 * (Exp(_3) - _4),
a_->grad(), adj_, probs_, b_->val());
// Compute second input derivative.
Element(_1 = -Log(_2), result_, probs_); // @TODO: use a cached log here.
ScaleRowwise(result_, adj_);
Element(_1 += _2, b_->grad(), result_);
Element(_1 -= _2 * _3, b_->grad(),
adj_, probs_);
}
virtual std::string graphviz() {

View File

@ -22,7 +22,6 @@ struct UnaryNodeOp : public Node {
// use df/dx to calc grad
backward();
//cerr << "orig a_->val()=" << a_->val().Debug() << endl;
//cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
calc_numeric_grad(delta, a_->val(), a_->grad(), preCalcGradA);
@ -167,10 +166,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
: UnaryNodeOp(args...) { }
void forward() {
// B = softmax(A).
thrust::copy(a_->val().begin(), a_->val().end(), val_.begin());
// Safe version of softmax.
Softmax(&val_);
CudnnSoftmax(val_, a_->val());
}
void backward() {
@ -196,6 +192,33 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
};
};
struct LogSoftmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
LogSoftmaxNodeOp(Args ...args)
: UnaryNodeOp(args...) { }
void forward() {
CudnnLogSoftmax(val_, a_->val());
}
void backward() {
// Based on the description for softmax, we have logsoftmax:
// J * dy = dy - avg*1
// where avg = exp(p)'*dy and p is the softmax output (probabilities).
CudnnLogSoftmaxGrad(a_->grad(), adj_, val_);
//LogSoftmaxGrad(a_->grad(), adj_, val_);
}
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("log-softmax")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
};
struct ArgmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
ArgmaxNodeOp(ChainPtr a, Args ...args)
@ -285,8 +308,6 @@ struct NegNodeOp : public UnaryNodeOp {
void backward() {
Element(_1 += -_2, a_->grad(), adj_);
//std::cerr << "a_->grad=" << a_->grad().Debug() << std::endl;
}
virtual std::string graphviz() {

View File

@ -22,6 +22,8 @@
// SOFTWARE.
#include <cublas_v2.h>
#include <cudnn.h>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <numeric>
@ -77,6 +79,9 @@ class TensorImpl {
thrust::device_vector<Float> data_; /*< Vector of data that Tensor is managing on GPU. */
size_t tno_; /*< Tensor number */
static size_t tensorCounter; /*< Static counter of created Tensors */
// cuDNN stuff
cudnnTensorDescriptor_t cudnnDesc_;
public:
typedef Float value_type; /*< Tensor value type */
@ -100,11 +105,20 @@ class TensorImpl {
int size = GetTotalSize(shape_);
data_.resize(size, value);
cudnnCreateTensorDescriptor(&cudnnDesc_);
cudnnSetTensor4dDescriptorEx(cudnnDesc_, CUDNN_DATA_FLOAT,
shape_[0], shape_[1], 1, 1,
shape_[1], 1, 1, 1);
}
TensorImpl(const TensorImpl&) = delete;
TensorImpl(TensorImpl&&) = delete;
~TensorImpl() {
cudnnDestroyTensorDescriptor(cudnnDesc_);
}
/**
* @brief Get the i-th element of Tensor vector.
*
@ -243,6 +257,11 @@ class TensorImpl {
}
return strm.str();
}
cudnnTensorDescriptor_t cudnn() {
return cudnnDesc_;
}
};
template <typename Type>
@ -483,6 +502,10 @@ class Tensor {
TensorView gpu() {
return TensorView(*this);
}
cudnnTensorDescriptor_t cudnn() {
return pimpl_->cudnn();
}
};
/**

View File

@ -29,7 +29,175 @@ static cublasHandle_t create_handle() {
cublasCreate(&cublasHandle);
return cublasHandle;
}
static cudnnHandle_t create_handle_dnn() {
cudnnHandle_t cudnnHandle;
cudnnCreate(&cudnnHandle);
return cudnnHandle;
}
cublasHandle_t cublasHandle = create_handle();
cudnnHandle_t cudnnHandle = create_handle_dnn();
void CudnnSoftmax(Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxForward(cudnnHandle,
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
void CudnnLogSoftmax(Tensor out, Tensor in) {
float alpha = 1, beta = 0;
cudnnSoftmaxForward(cudnnHandle,
CUDNN_SOFTMAX_LOG,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
in.cudnn(),
in.data(),
&beta,
out.cudnn(),
out.data());
cudaDeviceSynchronize();
}
void CudnnSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
float alpha = 1, beta = 0;
cudnnSoftmaxBackward(cudnnHandle,
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
val.cudnn(),
val.data(),
adj.cudnn(),
adj.data(),
&beta,
grad.cudnn(),
grad.data());
cudaDeviceSynchronize();
}
void CudnnLogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
float alpha = 1, beta = 0;
cudnnSoftmaxBackward(cudnnHandle,
CUDNN_SOFTMAX_LOG,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
val.cudnn(),
val.data(),
adj.cudnn(),
adj.data(),
&beta,
grad.cudnn(),
grad.data());
cudaDeviceSynchronize();
}
__global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if (j < rows) {
extern __shared__ float _share[];
float* _max = _share + blockDim.x;
float* sp = out + j * cols;
_max[threadIdx.x] = sp[threadIdx.x];
for(int tid = 1; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if (id < cols) {
if (sp[id] > _max[threadIdx.x]) _max[threadIdx.x] = sp[id];
}
}
__syncthreads();
int len = blockDim.x;
while(len != 1) {
__syncthreads();
int skip = (len + 1) >> 1;
if (threadIdx.x < (len >> 1)) {
if (_max[threadIdx.x + skip] > _max[threadIdx.x]) {
_max[threadIdx.x] = _max[threadIdx.x + skip];
}
}
len = (len + 1) >> 1;
}
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] -= _max[0];
}
}
}
}
void SubtractMax(Tensor* Out) {
// Out is a m-by-k matrix, passed as input.
// The max element of each row of Out is computed and subtracted from Out.
// Out is both input and output.
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
}
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
extern __shared__ float _share[];
float* _sum = _share + blockDim.x;
float* sp = softMaxP + j * cols;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
sp[id] = __expf(sp[id]);
_sum[threadIdx.x] += sp[id];
}
}
__syncthreads();
int len = blockDim.x;
while(len != 1) {
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1))
_sum[threadIdx.x] += _sum[threadIdx.x + skip];
len = (len + 1) >> 1;
}
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] /= _sum[0];
}
}
}
}
void Softmax(Tensor* Out) {
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
// Subtract the max rowwise for numerical stability (safe softmax).
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
}
///////////////////////////////////////////////////////
__global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
const int rows, const int cols) {
@ -84,123 +252,22 @@ void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
cudaStreamSynchronize(0);
}
__global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if (j < rows) {
extern __shared__ float _share[];
float* _max = _share + blockDim.x;
float* sp = out + j * cols;
_max[threadIdx.x] = sp[threadIdx.x];
for(int tid = 1; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if (id < cols) {
if (sp[id] > _max[threadIdx.x]) _max[threadIdx.x] = sp[id];
}
}
__syncthreads();
int len = blockDim.x;
while(len != 1) {
__syncthreads();
int skip = (len + 1) >> 1;
if (threadIdx.x < (len >> 1)) {
if (_max[threadIdx.x + skip] > _max[threadIdx.x]) {
_max[threadIdx.x] = _max[threadIdx.x + skip];
}
}
len = (len + 1) >> 1;
}
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] -= _max[0];
}
}
}
}
void SubtractMax(Tensor* Out) {
// Out is a m-by-k matrix, passed as input.
// The max element of each row of Out is computed and subtracted from Out.
// Out is both input and output.
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
}
///////////////////////////////////////////////////////
//template <class T>
//__global__ void gClipNorm(T t) {
// int rows = t.rows();
// int cols = t.cols();
//
// for(int bid = 0; bid < rows; bid += gridDim.x) {
// int i = bid + blockIdx.x;
// if(i < rows) {
// extern __shared__ float _share[];
// float* _sum = _share + blockDim.x;
// _sum[threadIdx.x] = 0.0;
// for(int tid = 0; tid < cols; tid += blockDim.x) {
// int j = tid + threadIdx.x;
// if(j < cols)
// _sum[threadIdx.x] += powf(t(i,j), 2.0f);
// }
// __syncthreads();
// int len = blockDim.x;
// while(len != 1) {
// __syncthreads();
// int skip = (len + 1) >> 1;
// if(threadIdx.x < (len >> 1))
// _sum[threadIdx.x] += _sum[threadIdx.x + skip];
// len = (len + 1) >> 1;
// }
// __syncthreads();
// float total = 0;
// if(j == 0) {
// for()
// }
// for(int tid = 0; tid < cols; tid += blockDim.x){
// int j = tid + threadIdx.x;
// if(j < cols)
// sp[j] /= _sum[0];
// }
// }
// }
//}
//
//void ClipNorm(Tensor out, float threshold);
// size_t m = out.shape()[0];
// size_t k = out.shape()[1];
//
// int blocks = std::min(MAX_BLOCKS, (int) m);
// int threads = std::min(MAX_THREADS, (int) k);
// int shared = sizeof(float) * threads * 2;
// gClipNorm<<<blocks, threads, shared>>>(out.gpu());
// cudaStreamSynchronize(0);
//}
///////////////////////////////////////////////////////
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
__global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
const int rows, const int cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
extern __shared__ float _share[];
float* _sum = _share + blockDim.x;
float* sp = softMaxP + j * cols;
float* gradRow = grad + j * cols;
const float* adjRow = adj + j * cols;
const float* valRow = val + j * cols;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
sp[id] = __expf(sp[id]);
_sum[threadIdx.x] += sp[id];
_sum[threadIdx.x] += expf(valRow[id]) * adjRow[id]; // exp because we chached logsoftmax
}
}
__syncthreads();
@ -216,23 +283,26 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] /= _sum[0];
gradRow[id] += adjRow[id] - _sum[0];
}
}
}
}
void Softmax(Tensor* Out) {
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
// grad and val are both m-by-k matrices, passed as input.
// A weighted average of each row of grad (according to the weights
// specified in val) is computed and subtracted from Out.
// adj is multiplied for each element to get backward step in autodiff
int m = grad.shape()[0];
int k = grad.shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int blocks = std::min(MAX_BLOCKS, m);
int threads = std::min(MAX_THREADS, k);
int shared = sizeof(float) * threads * 2;
// Subtract the max rowwise for numerical stability (safe softmax).
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
gLogSoftmaxGrad<<<blocks, threads, shared>>>(grad.data(),
adj.data(), val.data(),
m, k);
cudaStreamSynchronize(0);
}
@ -320,30 +390,5 @@ Tensor SumRowwise(const Tensor A, Tensor result) {
return temp;
}
// @TODO: replace this by something else when broadcast elementwise operations
// are ready.
__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
Float* rowOut = out + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols) rowOut[i] *= scalingFactors[j];
}
}
}
}
void ScaleRowwise(Tensor Out, const Tensor ScalingFactors) {
Float* d_out = Out.data();
const Float* d_in = ScalingFactors.data();
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
gScaleRowwise<<<blocks, threads>>>(d_out, d_in,
Out.shape()[0], Out.shape()[1]);
cudaStreamSynchronize(0);
}
}

View File

@ -158,6 +158,13 @@ void SubtractMax(Tensor* Out);
void Softmax(Tensor* Out);
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void CudnnSoftmax(Tensor out, Tensor in);
void CudnnSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void CudnnLogSoftmax(Tensor out, Tensor in);
void CudnnLogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void Argmax(Tensor* Out, const Tensor* In);

View File

@ -19,6 +19,9 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#include <chrono>
#include <boost/timer/timer.hpp>
#include "marian.h"
#include "mnist.h"
#include "vocab.h"
@ -27,6 +30,15 @@
using namespace marian;
using namespace keywords;
void random_permutation(int n, std::vector<size_t> *indices) {
std::srand(std::time(0));
indices->clear();
for(size_t i = 0; i < n; ++i) {
indices->push_back(i);
}
std::random_shuffle(indices->begin(), indices->end());
}
ExpressionGraph build_graph(int source_vocabulary_size,
int target_vocabulary_size,
int embedding_size,
@ -101,14 +113,17 @@ ExpressionGraph build_graph(int source_vocabulary_size,
std::cerr << "Building output layer..." << std::endl;
// Softmax layer and cost function.
std::vector<Expr> Yp;
Yp.emplace_back(named(softmax(dot(h0_d, Why) + by), "pred"));
Expr cross_entropy = sum(Y[0] * log(Yp[0]), axis=1);
//std::vector<Expr> Yp;
//Yp.emplace_back(named(softmax(dot(h0_d, Why) + by), "pred"));
//Expr word_cost = sum(Y[0] * log(Yp[0]), axis=1);
Expr word_cost = cross_entropy(dot(h0_d, Why) + by, Y[0]);
for (int t = 1; t <= num_outputs; ++t) {
Yp.emplace_back(named(softmax(dot(S[t-1], Why) + by), "pred"));
cross_entropy = cross_entropy + sum(Y[t] * log(Yp[t]), axis=1);
//Yp.emplace_back(named(softmax(dot(S[t-1], Why) + by), "pred"));
//word_cost = word_cost + sum(Y[t] * log(Yp[t]), axis=1);
word_cost = word_cost + cross_entropy(dot(S[t-1], Why) + by, Y[t]);
}
auto cost = named(-mean(cross_entropy, axis=0), "cost");
//auto cost = named(-mean(word_cost, axis=0), "cost");
auto cost = named(mean(word_cost, axis=0), "cost");
std::cerr << "Done." << std::endl;
@ -125,7 +140,7 @@ int main(int argc, char** argv) {
// Right now we're only reading the first few sentence pairs, and defining
// that as the step size.
int batch_size = 64;
int batch_size = 100;
int num_source_tokens = -1;
int num_target_tokens = -1;
std::vector<std::vector<size_t> > source_sentences, target_sentences;
@ -142,7 +157,7 @@ int main(int argc, char** argv) {
if (num_target_tokens < 0 || target_ids.size() > num_target_tokens) {
num_target_tokens = target_ids.size();
}
if (source_sentences.size() == batch_size) break;
//if (source_sentences.size() == 1000) break;
}
std::cerr << "Done." << std::endl;
std::cerr << source_sentences.size()
@ -214,14 +229,65 @@ int main(int argc, char** argv) {
std::cout << g.graphviz() << std::endl;
std::cerr << "Training..." << std::endl;
int num_training_examples = source_sentences.size();
std::cerr << num_training_examples << " training examples." << std::endl;
boost::timer::cpu_timer total;
Adam opt;
int num_epochs = 20;
for(size_t epoch = 0; epoch < num_epochs; ++epoch) {
opt(g, batch_size); // Full batch for now.
std::cerr << "Epoch " << epoch << ": "
<< "Loss = " << g["cost"].val()[0]
<< std::endl;
for(int epoch = 1; epoch <= num_epochs; ++epoch) {
boost::timer::cpu_timer timer;
// TODO: shuffle the batches.
// shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);
std::vector<size_t> indices;
int num_batches = num_training_examples / batch_size;
random_permutation(num_batches, &indices);
float cost = 0;
for(int j = 0; j < num_batches; j++) {
int b = indices[j]; // Batch index.
// Attaching the data to the computation graph...
// Convert the data to dense one-hot vectors.
// TODO: make the graph handle sparse indices with a proper lookup layer.
// TODO: use different sentence lengths for the batches.
for (int t = 0; t < num_source_tokens; ++t) {
Tensor Xt({batch_size, static_cast<int>(source_vocab.Size())});
std::vector<float> values(batch_size * source_vocab.Size(), 0.0);
int k = 0;
for (int i = 0; i < batch_size; ++i) {
values[k + source_sentences[i + b*batch_size][t]] = 1.0;
k += source_vocab.Size();
}
thrust::copy(values.begin(), values.end(), Xt.begin());
// Attach this slice to the graph.
std::stringstream ss;
ss << "X" << t;
g[ss.str()] = Xt;
}
for (int t = 0; t < num_target_tokens; ++t) {
Tensor Yt({batch_size, static_cast<int>(target_vocab.Size())});
std::vector<float> values(batch_size * target_vocab.Size(), 0.0);
int k = 0;
for (int i = 0; i < batch_size; ++i) {
values[k + target_sentences[i + b*batch_size][t]] = 1.0;
k += target_vocab.Size();
}
thrust::copy(values.begin(), values.end(), Yt.begin());
// Attach this slice to the graph.
std::stringstream ss;
ss << "Y" << t;
g[ss.str()] = Yt;
}
opt(g, batch_size);
cost += g["cost"].val()[0];
}
std::cerr << "Epoch: " << epoch << " - Cost: "
<< cost / num_training_examples * batch_size
<< " - " << timer.format(3, "%ws") << std::endl;
}
std::cerr << "Total: " << total.format(3, "%ws") << std::endl;
return 0;
}