fixed conflict

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-14 16:51:42 +02:00
commit aa22b47fcc
9 changed files with 173 additions and 96 deletions

View File

@ -10,7 +10,7 @@ Expr::Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
Expr::Expr(Float v) : pimpl_(new ConstantNode(keywords::value=v,
keywords::shape={1,1})) {}
Tensor Expr::val() {
Tensor &Expr::val() {
return pimpl_->val();
}

View File

@ -15,7 +15,7 @@ class Expr {
return *this;
}
Tensor val();
Tensor &val();
Tensor grad();
void forward(size_t batchSize);

View File

@ -17,7 +17,7 @@ struct Chainable {
virtual void allocate(size_t) = 0;
virtual const Shape& shape() = 0;
virtual DataType val() = 0;
virtual DataType &val() = 0;
virtual DataType grad() = 0;
virtual void setVal(DataType t) {
UTIL_THROW2("Tensors can only be assigned to input nodes");
@ -82,7 +82,7 @@ class Node : public Chainable<Tensor>,
}
}
virtual Tensor val() {
virtual Tensor &val() {
UTIL_THROW_IF2(!val_, "Tensor has not been allocated");
return val_;
};

View File

@ -80,8 +80,8 @@ struct SigmoidNodeOp : public UnaryNodeOp {
}
void backward() {
Element(_1 += _2 * Sigma(_3) * (1 - Sigma(_3)),
a_->grad(), adj_, a_->val());
Element(_1 += _2 * _3 * (1 - _3),
a_->grad(), adj_, val_);
}
};
@ -96,8 +96,8 @@ struct TanhNodeOp : public UnaryNodeOp {
}
void backward() {
Element(_1 += _2 * (1 - Tanh(_3) * Tanh(_3)),
a_->grad(), adj_, a_->val());
Element(_1 += _2 * (1 - _3 * _3),
a_->grad(), adj_, val_);
}
};
@ -139,7 +139,6 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
SoftmaxNodeOp(ChainPtr a, Args ...args)
: UnaryNodeOp(a, keywords::shape=newShape(a),
args...) { }
Shape newShape(ChainPtr a) {
Shape shape = a->shape();
return shape;
@ -152,9 +151,14 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
}
void backward() {
// TODO
Element(_1 += _2 * Exp(_3),
a_->grad(), adj_, a_->val());
// For each row, the Jacobian times vector is given by:
// J * dy = p .* (dy - avg*1)
// where avg = p'*dy and p is the softmax output (probabilities).
Tensor result = adj_;
SubtractMean(&result, val_);
// beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
// to sum gradients from different graph parts.
Prod(a_->grad(), adj_, result, false, false, 1.0);
}
};

View File

@ -59,8 +59,7 @@ inline std::vector<T> Tokenize( const std::string &input
void Tensor::Load(const std::string &path)
{
size_t totSize = std::accumulate(pimpl_->shape().begin(), pimpl_->shape().end(),
1, std::multiplies<int>());
size_t totSize = GetTotalSize(pimpl_->shape());
cerr << "totSize=" << totSize << endl;
std::vector<float> hostData(totSize);
@ -81,12 +80,12 @@ void Tensor::Load(const std::string &path)
}
strm.close();
Load(hostData);
Load(hostData.begin(), hostData.begin());
}
void Tensor::Load(const std::vector<float> &values)
void Tensor::Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end)
{
pimpl_->set(values);
pimpl_->set(begin, end);
}
}

View File

@ -48,6 +48,13 @@ inline std::string Debug(const Shape &shape)
return strm.str();
}
inline size_t GetTotalSize(const Shape &shape)
{
size_t ret = std::accumulate(shape.begin(), shape.end(),
1, std::multiplies<int>());
return ret;
}
template<class Float>
class TensorImpl {
private:
@ -81,8 +88,7 @@ class TensorImpl {
std::cerr << "Allocating : " << shape[0] << " " << shape[1] << std::endl;
int size = std::accumulate(shape_.begin(), shape_.end(),
1, std::multiplies<int>());
int size = GetTotalSize(shape_);
data_.resize(size, value);
cudnnCreateTensorDescriptor(&desc_);
switch (shape_.size()) {
@ -152,19 +158,32 @@ class TensorImpl {
thrust::fill(data_.begin(), data_.end(), value);
}
void set(const std::vector<Float> &values) {
size_t totSize = std::accumulate(shape().begin(), shape().end(),
1, std::multiplies<int>());
std::cerr << "tensor size=" << totSize << " vector size=" << values.size() << std::endl;
assert(totSize == values.size());
thrust::copy(values.begin(), values.end(), data_.begin());
void set(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end) {
size_t totSize = GetTotalSize(shape());
//std::cerr << "tensor size=" << totSize << " vector size=" << values.size() << std::endl;
//assert(totSize == values.size());
thrust::copy(begin, end, data_.begin());
}
std::string Debug() const
{
std::stringstream strm;
assert(shape_.size());
strm << "shape=" << marian::Debug(shape_);
strm << "shape=" << marian::Debug(shape_) << std::endl;
// values
size_t totSize = GetTotalSize(shape());
std::vector<Float> values(totSize);
thrust::copy(data_.begin(), data_.end(), values.begin());
size_t ind = 0;
for (size_t i = 0; i < shape()[0]; ++i) {
for (size_t j = 0; j < shape()[1]; ++j) {
strm << values[ind] << " ";
++ind;
}
strm << std::endl;
}
return strm.str();
}
};
@ -256,7 +275,7 @@ class Tensor {
}
void Load(const std::string &path);
void Load(const std::vector<float> &values);
void Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end);
};

View File

@ -2,7 +2,57 @@
namespace marian {
// TODO: implement this.
__global__ void gSubtractMean(float* out, float* weights,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
extern __shared__ float _share[];
float* _sum = _share + blockDim.x;
float* sp = out + j * cols;
float* w = weights + j * cols;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
_sum[threadIdx.x] += w[id] * sp[id];
}
}
__syncthreads();
int len = blockDim.x;
while(len != 1) {
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1))
_sum[threadIdx.x] += _sum[threadIdx.x + skip];
len = (len + 1) >> 1;
}
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] -= _sum[0];
}
}
}
}
void SubtractMean(Tensor* Out, Tensor &Weights) {
// Out and Weights are both m-by-k matrices, passed as input.
// A weighted average of each row of Out (according to the weights
// specified in Weights) is computed and subtracted from Out.
// Out is both input and output.
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
gSubtractMean<<<blocks, threads, shared>>>(Out->data(), Weights.data(),
m, k);
cudaStreamSynchronize(0);
}
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
@ -37,7 +87,6 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
}
}
// TODO: implement this.
void Softmax(Tensor* Out) {
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];

View File

@ -142,6 +142,11 @@ void Element(Functor functor,
cudaStreamSynchronize(0);
}
__global__ void gSubtractMean(float* out, float* weights,
size_t rows, size_t cols);
void SubtractMean(Tensor* Out, Tensor &Weights);
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
void Softmax(Tensor* Out);

View File

@ -12,6 +12,7 @@ int main(int argc, char** argv) {
using namespace marian;
using namespace keywords;
const size_t BATCH_SIZE = 500;
const size_t IMAGE_SIZE = 784;
const size_t LABEL_SIZE = 10;
@ -21,62 +22,62 @@ int main(int argc, char** argv) {
Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
Expr b = param(shape={1, LABEL_SIZE}, name="b0");
auto z = dot(x, w) + b;
auto pred = softmax(z);
//auto decision = argmax(pred, axis=1);
Expr z = dot(x, w) + b;
Expr lr = softmax(z, axis=1, name="pred");
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
//cerr << "x=" << Debug(lr.val().shape()) << endl;
auto cost = -mean(sum(y * log(pred), axis=1),
axis=0);
cerr << "pred=" << pred.Debug() << endl;
#if 0
int numofdata;
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
//vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
//vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", numofdata, IMAGE_SIZE);
vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", numofdata, LABEL_SIZE);
cerr << "images=" << images.size() << " labels=" << labels.size() << endl;
cerr << "numofdata=" << numofdata << endl;
size_t startInd = 0;
size_t startIndData = 0;
while (startInd < numofdata) {
size_t batchSize = (startInd + BATCH_SIZE < numofdata) ? BATCH_SIZE : numofdata - startInd;
cerr << "startInd=" << startInd
<< " startIndData=" << startIndData
<< " batchSize=" << batchSize << endl;
Tensor tx({numofdata, IMAGE_SIZE}, 1);
Tensor ty({numofdata, LABEL_SIZE}, 1);
tx.Load(images);
ty.Load(labels);
tx.Load(images.begin() + startIndData, images.begin() + startIndData + batchSize * IMAGE_SIZE);
ty.Load(labels.begin() + startInd, labels.begin() + startInd + batchSize);
cerr << "tx=" << tx.Debug() << endl;
cerr << "ty=" << ty.Debug() << endl;
#else
Tensor tx({500, 784}, 1);
Tensor ty({500, 10}, 1);
#endif
//cerr << "tx=" << Debug(tx.shape()) << endl;
//cerr << "ty=" << Debug(ty.shape()) << endl;
x = tx;
y = ty;
cost.forward(500);
cerr << "x=" << Debug(x.val().shape()) << endl;
cerr << "y=" << Debug(y.val().shape()) << endl;
std::cerr << "Result: ";
for (auto val : pred.val().shape()) {
std::cerr << val << " ";
}
std::cerr << std::endl;
std::cerr << "Result: ";
for (auto val : pred.val().shape()) {
std::cerr << val << " ";
}
std::cerr << std::endl;
pred.val().Print();
std::cerr << "Log-likelihood: ";
for (auto val : cost.val().shape()) {
std::cerr << val << " ";
}
std::cerr << std::endl;
cost.val().Print();
cost.backward();
graph.forward(batchSize);
cerr << "w=" << Debug(w.val().shape()) << endl;
cerr << "b=" << Debug(b.val().shape()) << endl;
std::cerr << "z: " << Debug(z.val().shape()) << endl;
std::cerr << "lr: " << Debug(lr.val().shape()) << endl;
std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ;
//std::cerr << "scores=" << scores.val().Debug() << endl;
//std::cerr << "lr=" << lr.val().Debug() << endl;
graph.backward();
//std::cerr << graph["pred"].val()[0] << std::endl;
startInd += batchSize;
startIndData += batchSize * IMAGE_SIZE;
}
// XOR
/*