mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
fixed conflict
This commit is contained in:
commit
aa22b47fcc
@ -10,7 +10,7 @@ Expr::Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
|
||||
Expr::Expr(Float v) : pimpl_(new ConstantNode(keywords::value=v,
|
||||
keywords::shape={1,1})) {}
|
||||
|
||||
Tensor Expr::val() {
|
||||
Tensor &Expr::val() {
|
||||
return pimpl_->val();
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,7 @@ class Expr {
|
||||
return *this;
|
||||
}
|
||||
|
||||
Tensor val();
|
||||
Tensor &val();
|
||||
Tensor grad();
|
||||
|
||||
void forward(size_t batchSize);
|
||||
|
@ -17,7 +17,7 @@ struct Chainable {
|
||||
virtual void allocate(size_t) = 0;
|
||||
|
||||
virtual const Shape& shape() = 0;
|
||||
virtual DataType val() = 0;
|
||||
virtual DataType &val() = 0;
|
||||
virtual DataType grad() = 0;
|
||||
virtual void setVal(DataType t) {
|
||||
UTIL_THROW2("Tensors can only be assigned to input nodes");
|
||||
@ -82,7 +82,7 @@ class Node : public Chainable<Tensor>,
|
||||
}
|
||||
}
|
||||
|
||||
virtual Tensor val() {
|
||||
virtual Tensor &val() {
|
||||
UTIL_THROW_IF2(!val_, "Tensor has not been allocated");
|
||||
return val_;
|
||||
};
|
||||
|
@ -80,8 +80,8 @@ struct SigmoidNodeOp : public UnaryNodeOp {
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * Sigma(_3) * (1 - Sigma(_3)),
|
||||
a_->grad(), adj_, a_->val());
|
||||
Element(_1 += _2 * _3 * (1 - _3),
|
||||
a_->grad(), adj_, val_);
|
||||
}
|
||||
};
|
||||
|
||||
@ -96,8 +96,8 @@ struct TanhNodeOp : public UnaryNodeOp {
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * (1 - Tanh(_3) * Tanh(_3)),
|
||||
a_->grad(), adj_, a_->val());
|
||||
Element(_1 += _2 * (1 - _3 * _3),
|
||||
a_->grad(), adj_, val_);
|
||||
}
|
||||
};
|
||||
|
||||
@ -139,7 +139,6 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
SoftmaxNodeOp(ChainPtr a, Args ...args)
|
||||
: UnaryNodeOp(a, keywords::shape=newShape(a),
|
||||
args...) { }
|
||||
|
||||
Shape newShape(ChainPtr a) {
|
||||
Shape shape = a->shape();
|
||||
return shape;
|
||||
@ -152,9 +151,14 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
}
|
||||
|
||||
void backward() {
|
||||
// TODO
|
||||
Element(_1 += _2 * Exp(_3),
|
||||
a_->grad(), adj_, a_->val());
|
||||
// For each row, the Jacobian times vector is given by:
|
||||
// J * dy = p .* (dy - avg*1)
|
||||
// where avg = p'*dy and p is the softmax output (probabilities).
|
||||
Tensor result = adj_;
|
||||
SubtractMean(&result, val_);
|
||||
// beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
|
||||
// to sum gradients from different graph parts.
|
||||
Prod(a_->grad(), adj_, result, false, false, 1.0);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -59,8 +59,7 @@ inline std::vector<T> Tokenize( const std::string &input
|
||||
|
||||
void Tensor::Load(const std::string &path)
|
||||
{
|
||||
size_t totSize = std::accumulate(pimpl_->shape().begin(), pimpl_->shape().end(),
|
||||
1, std::multiplies<int>());
|
||||
size_t totSize = GetTotalSize(pimpl_->shape());
|
||||
cerr << "totSize=" << totSize << endl;
|
||||
std::vector<float> hostData(totSize);
|
||||
|
||||
@ -81,12 +80,12 @@ void Tensor::Load(const std::string &path)
|
||||
}
|
||||
strm.close();
|
||||
|
||||
Load(hostData);
|
||||
Load(hostData.begin(), hostData.begin());
|
||||
}
|
||||
|
||||
void Tensor::Load(const std::vector<float> &values)
|
||||
void Tensor::Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end)
|
||||
{
|
||||
pimpl_->set(values);
|
||||
pimpl_->set(begin, end);
|
||||
}
|
||||
|
||||
}
|
||||
|
39
src/tensor.h
39
src/tensor.h
@ -48,6 +48,13 @@ inline std::string Debug(const Shape &shape)
|
||||
return strm.str();
|
||||
}
|
||||
|
||||
inline size_t GetTotalSize(const Shape &shape)
|
||||
{
|
||||
size_t ret = std::accumulate(shape.begin(), shape.end(),
|
||||
1, std::multiplies<int>());
|
||||
return ret;
|
||||
}
|
||||
|
||||
template<class Float>
|
||||
class TensorImpl {
|
||||
private:
|
||||
@ -81,8 +88,7 @@ class TensorImpl {
|
||||
|
||||
std::cerr << "Allocating : " << shape[0] << " " << shape[1] << std::endl;
|
||||
|
||||
int size = std::accumulate(shape_.begin(), shape_.end(),
|
||||
1, std::multiplies<int>());
|
||||
int size = GetTotalSize(shape_);
|
||||
data_.resize(size, value);
|
||||
cudnnCreateTensorDescriptor(&desc_);
|
||||
switch (shape_.size()) {
|
||||
@ -152,19 +158,32 @@ class TensorImpl {
|
||||
thrust::fill(data_.begin(), data_.end(), value);
|
||||
}
|
||||
|
||||
void set(const std::vector<Float> &values) {
|
||||
size_t totSize = std::accumulate(shape().begin(), shape().end(),
|
||||
1, std::multiplies<int>());
|
||||
std::cerr << "tensor size=" << totSize << " vector size=" << values.size() << std::endl;
|
||||
assert(totSize == values.size());
|
||||
thrust::copy(values.begin(), values.end(), data_.begin());
|
||||
void set(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end) {
|
||||
size_t totSize = GetTotalSize(shape());
|
||||
//std::cerr << "tensor size=" << totSize << " vector size=" << values.size() << std::endl;
|
||||
//assert(totSize == values.size());
|
||||
thrust::copy(begin, end, data_.begin());
|
||||
}
|
||||
|
||||
std::string Debug() const
|
||||
{
|
||||
std::stringstream strm;
|
||||
assert(shape_.size());
|
||||
strm << "shape=" << marian::Debug(shape_);
|
||||
strm << "shape=" << marian::Debug(shape_) << std::endl;
|
||||
|
||||
// values
|
||||
size_t totSize = GetTotalSize(shape());
|
||||
std::vector<Float> values(totSize);
|
||||
thrust::copy(data_.begin(), data_.end(), values.begin());
|
||||
|
||||
size_t ind = 0;
|
||||
for (size_t i = 0; i < shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < shape()[1]; ++j) {
|
||||
strm << values[ind] << " ";
|
||||
++ind;
|
||||
}
|
||||
strm << std::endl;
|
||||
}
|
||||
return strm.str();
|
||||
}
|
||||
};
|
||||
@ -256,7 +275,7 @@ class Tensor {
|
||||
}
|
||||
|
||||
void Load(const std::string &path);
|
||||
void Load(const std::vector<float> &values);
|
||||
void Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end);
|
||||
|
||||
};
|
||||
|
||||
|
@ -2,7 +2,57 @@
|
||||
|
||||
namespace marian {
|
||||
|
||||
// TODO: implement this.
|
||||
__global__ void gSubtractMean(float* out, float* weights,
|
||||
size_t rows, size_t cols) {
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
extern __shared__ float _share[];
|
||||
float* _sum = _share + blockDim.x;
|
||||
float* sp = out + j * cols;
|
||||
float* w = weights + j * cols;
|
||||
_sum[threadIdx.x] = 0.0;
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int id = tid + threadIdx.x;
|
||||
if(id < cols) {
|
||||
_sum[threadIdx.x] += w[id] * sp[id];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
int len = blockDim.x;
|
||||
while(len != 1) {
|
||||
__syncthreads();
|
||||
int skip = (len + 1) >> 1;
|
||||
if(threadIdx.x < (len >> 1))
|
||||
_sum[threadIdx.x] += _sum[threadIdx.x + skip];
|
||||
len = (len + 1) >> 1;
|
||||
}
|
||||
__syncthreads();
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x){
|
||||
int id = tid + threadIdx.x;
|
||||
if(id < cols)
|
||||
sp[id] -= _sum[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SubtractMean(Tensor* Out, Tensor &Weights) {
|
||||
// Out and Weights are both m-by-k matrices, passed as input.
|
||||
// A weighted average of each row of Out (according to the weights
|
||||
// specified in Weights) is computed and subtracted from Out.
|
||||
// Out is both input and output.
|
||||
size_t m = Out->shape()[0];
|
||||
size_t k = Out->shape()[1];
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, (int) m);
|
||||
int threads = std::min(MAX_THREADS, (int) k);
|
||||
int shared = sizeof(float) * threads * 2;
|
||||
gSubtractMean<<<blocks, threads, shared>>>(Out->data(), Weights.data(),
|
||||
m, k);
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
@ -37,7 +87,6 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: implement this.
|
||||
void Softmax(Tensor* Out) {
|
||||
size_t m = Out->shape()[0];
|
||||
size_t k = Out->shape()[1];
|
||||
|
@ -142,6 +142,11 @@ void Element(Functor functor,
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
__global__ void gSubtractMean(float* out, float* weights,
|
||||
size_t rows, size_t cols);
|
||||
|
||||
void SubtractMean(Tensor* Out, Tensor &Weights);
|
||||
|
||||
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
|
||||
|
||||
void Softmax(Tensor* Out);
|
||||
|
77
src/test.cu
77
src/test.cu
@ -12,6 +12,7 @@ int main(int argc, char** argv) {
|
||||
using namespace marian;
|
||||
using namespace keywords;
|
||||
|
||||
const size_t BATCH_SIZE = 500;
|
||||
const size_t IMAGE_SIZE = 784;
|
||||
const size_t LABEL_SIZE = 10;
|
||||
|
||||
@ -21,62 +22,62 @@ int main(int argc, char** argv) {
|
||||
Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
|
||||
Expr b = param(shape={1, LABEL_SIZE}, name="b0");
|
||||
|
||||
auto z = dot(x, w) + b;
|
||||
auto pred = softmax(z);
|
||||
//auto decision = argmax(pred, axis=1);
|
||||
Expr z = dot(x, w) + b;
|
||||
Expr lr = softmax(z, axis=1, name="pred");
|
||||
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
||||
//cerr << "x=" << Debug(lr.val().shape()) << endl;
|
||||
|
||||
auto cost = -mean(sum(y * log(pred), axis=1),
|
||||
axis=0);
|
||||
|
||||
cerr << "pred=" << pred.Debug() << endl;
|
||||
|
||||
#if 0
|
||||
int numofdata;
|
||||
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
||||
vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
||||
//vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
||||
//vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
||||
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
||||
vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
||||
cerr << "images=" << images.size() << " labels=" << labels.size() << endl;
|
||||
cerr << "numofdata=" << numofdata << endl;
|
||||
|
||||
size_t startInd = 0;
|
||||
size_t startIndData = 0;
|
||||
while (startInd < numofdata) {
|
||||
size_t batchSize = (startInd + BATCH_SIZE < numofdata) ? BATCH_SIZE : numofdata - startInd;
|
||||
cerr << "startInd=" << startInd
|
||||
<< " startIndData=" << startIndData
|
||||
<< " batchSize=" << batchSize << endl;
|
||||
|
||||
Tensor tx({numofdata, IMAGE_SIZE}, 1);
|
||||
Tensor ty({numofdata, LABEL_SIZE}, 1);
|
||||
|
||||
tx.Load(images);
|
||||
ty.Load(labels);
|
||||
tx.Load(images.begin() + startIndData, images.begin() + startIndData + batchSize * IMAGE_SIZE);
|
||||
ty.Load(labels.begin() + startInd, labels.begin() + startInd + batchSize);
|
||||
|
||||
cerr << "tx=" << tx.Debug() << endl;
|
||||
cerr << "ty=" << ty.Debug() << endl;
|
||||
#else
|
||||
Tensor tx({500, 784}, 1);
|
||||
Tensor ty({500, 10}, 1);
|
||||
#endif
|
||||
//cerr << "tx=" << Debug(tx.shape()) << endl;
|
||||
//cerr << "ty=" << Debug(ty.shape()) << endl;
|
||||
|
||||
x = tx;
|
||||
y = ty;
|
||||
|
||||
cost.forward(500);
|
||||
cerr << "x=" << Debug(x.val().shape()) << endl;
|
||||
cerr << "y=" << Debug(y.val().shape()) << endl;
|
||||
|
||||
std::cerr << "Result: ";
|
||||
for (auto val : pred.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
std::cerr << "Result: ";
|
||||
for (auto val : pred.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
pred.val().Print();
|
||||
std::cerr << "Log-likelihood: ";
|
||||
for (auto val : cost.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
cost.val().Print();
|
||||
|
||||
cost.backward();
|
||||
graph.forward(batchSize);
|
||||
|
||||
cerr << "w=" << Debug(w.val().shape()) << endl;
|
||||
cerr << "b=" << Debug(b.val().shape()) << endl;
|
||||
std::cerr << "z: " << Debug(z.val().shape()) << endl;
|
||||
std::cerr << "lr: " << Debug(lr.val().shape()) << endl;
|
||||
std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ;
|
||||
|
||||
//std::cerr << "scores=" << scores.val().Debug() << endl;
|
||||
//std::cerr << "lr=" << lr.val().Debug() << endl;
|
||||
|
||||
graph.backward();
|
||||
|
||||
//std::cerr << graph["pred"].val()[0] << std::endl;
|
||||
|
||||
startInd += batchSize;
|
||||
startIndData += batchSize * IMAGE_SIZE;
|
||||
}
|
||||
|
||||
|
||||
// XOR
|
||||
/*
|
||||
|
Loading…
Reference in New Issue
Block a user