changed to fast softmax

Marcin Junczys-Dowmunt 2016-09-16 00:55:54 +02:00
commit 02629ca13b
14 changed files with 421 additions and 359 deletions

CMakeLists.txt

@@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.5.1)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
project(marian CXX)
-SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O0 -funroll-loops -Wno-unused-result -Wno-deprecated")
-LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O0; -arch=sm_35; -lineinfo; --use_fast_math; -Xcompiler '-fPIC')
+SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
+LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math; -Xcompiler '-fPIC')
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)

src/CMakeLists.txt

@@ -4,10 +4,10 @@ include_directories(.)
cuda_add_library(marian_lib
cnpy/cnpy.cpp
exception.cpp
-expressions.cu
sgd.cu
-tensor.cu
+expression_graph.cu
+tensor.cu
tensor_operators.cu
+expression_operators.cu
)
target_link_libraries(marian_lib)

src/chainable.h Normal file (34 additions)

@@ -0,0 +1,34 @@
#pragma once
#include <vector>
#include <memory>
#include "exception.h"
namespace marian {
template <class DataType>
struct Chainable {
Chainable() { }
virtual ~Chainable() { }
virtual void forward() { }
virtual void backward() { }
virtual void init_dependent() { }
virtual void set_zero_adjoint() { }
virtual void allocate(size_t) = 0;
virtual const Shape& shape() = 0;
virtual DataType &val() = 0;
virtual DataType grad() = 0;
virtual void setVal(DataType t) {
UTIL_THROW2("Tensors can only be assigned to input nodes");
};
};
typedef std::vector<Chainable<Tensor>*> ChainableStack;
typedef std::shared_ptr<ChainableStack> ChainableStackPtr;
typedef std::shared_ptr<Chainable<Tensor>> ChainPtr;
}
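The Chainable interface above is the whole autodiff contract: forward() computes a node's value from its inputs, backward() pushes the node's adjoint into its inputs' gradients, and init_dependent()/set_zero_adjoint() seed and reset the reverse sweep. A minimal CPU sketch of a node satisfying that contract, with std::vector<float> standing in for marian's GPU Tensor (TanhSketch and everything around it is hypothetical illustration, not marian code):

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    using Buffer = std::vector<float>;  // stand-in for marian's GPU Tensor

    // y = tanh(x): forward fills val_; backward accumulates
    // dL/dx += (1 - y*y) * dL/dy into the input's gradient buffer.
    struct TanhSketch {
      TanhSketch(const Buffer* x, Buffer* xGrad) : x_(x), xGrad_(xGrad) {}
      void allocate(size_t n) { val_.assign(n, 0.f); adj_.assign(n, 0.f); }
      void forward() {
        for (size_t i = 0; i < val_.size(); ++i) val_[i] = std::tanh((*x_)[i]);
      }
      void backward() {
        for (size_t i = 0; i < adj_.size(); ++i)
          (*xGrad_)[i] += (1.f - val_[i] * val_[i]) * adj_[i];
      }
      const Buffer* x_; Buffer* xGrad_;
      Buffer val_, adj_;  // value and adjoint, as in Node below
    };

    int main() {
      Buffer x = {0.f, 1.f}, xGrad(2, 0.f);
      TanhSketch t(&x, &xGrad);
      t.allocate(2);
      t.forward();
      t.adj_.assign(2, 1.f);  // init_dependent(): seed dL/dy = 1
      t.backward();
      std::printf("y = %f %f, dx = %f %f\n", t.val_[0], t.val_[1], xGrad[0], xGrad[1]);
    }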

src/expression_graph.cu Normal file (41 additions)

@@ -0,0 +1,41 @@
#include <sstream>
#include "expression_graph.h"
using namespace std;
namespace marian {
Expr::Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable)
: graph_(g), pimpl_(chainable) {
graph_->stack()->push_back(chainable);
}
Tensor Expr::val() {
return pimpl_->val();
}
Tensor Expr::grad() {
return pimpl_->grad();
}
ChainPtr Expr::node() {
return pimpl_;
}
ExpressionGraphPtr Expr::graph() {
return graph_;
}
Expr::operator ChainPtr() {
return pimpl_;
}
std::string Expr::Debug() const
{
stringstream strm;
const Shape &shape = pimpl_->shape();
strm << marian::Debug(shape);
return strm.str();
}
}

src/expression_graph.h Normal file (120 additions)

@@ -0,0 +1,120 @@
#pragma once
#include <map>
#include "definitions.h"
#include "chainable.h"
#include "node_operators.h"
#include "tensor.h"
namespace marian {
class ExpressionGraph;
typedef ExpressionGraph* ExpressionGraphPtr;
class Expr {
public:
Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable);
Expr operator=(Tensor t) {
pimpl_->setVal(t);
return *this;
}
Tensor val();
Tensor grad();
ExpressionGraphPtr graph();
ChainPtr node();
operator ChainPtr();
std::string Debug() const;
private:
ExpressionGraphPtr graph_;
ChainPtr pimpl_;
};
class ExpressionGraph {
public:
ExpressionGraph()
: stack_(new ChainableStack)
{}
void forward(size_t batchSize) {
for(auto&& v : *stack_) {
v->allocate(batchSize);
}
for(auto&& v : *stack_)
v->forward();
}
void backward() {
for(auto&& v : *stack_)
v->set_zero_adjoint();
typedef typename ChainableStack::reverse_iterator It;
stack_->back()->init_dependent();
for(It it = stack_->rbegin(); it != stack_->rend(); ++it)
(*it)->backward();
}
template <typename ...Args>
inline Expr input(Args ...args) {
return Expr(this, new InputNode(args...));
}
template <typename ...Args>
inline Expr param(Args ...args) {
Expr e(this, new ParamNode(args...));
params_.emplace_back(e);
return e;
}
template <typename ...Args>
inline Expr constant(Args ...args) {
return Expr(this, new ConstantNode(args...));
}
template <typename ...Args>
inline Expr ones(Args ...args) {
return Expr(this, new ConstantNode(keywords::value=1, args...));
}
template <typename ...Args>
inline Expr zeroes(Args ...args) {
return Expr(this, new ConstantNode(keywords::value=0, args...));
}
/*********************************************************/
ChainableStackPtr stack() {
return stack_;
}
Expr& operator[](const std::string& name) {
auto it = named_.find(name);
UTIL_THROW_IF2(it == named_.end(), "No such named node in graph: " << name);
return it->second;
}
bool has_node(const std::string& name) const {
return named_.count(name) > 0;
}
void add_named_node(Expr e, const std::string& name) {
named_.emplace(name, e);
}
std::vector<Expr>& params() {
return params_;
}
private:
ChainableStackPtr stack_;
std::map<std::string, Expr> named_;
std::vector<Expr> params_;
};
}
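Note the execution model this header sets up: every node appends itself to the graph's stack at construction time, so construction order is already a topological order; forward() is one left-to-right pass over that vector and backward() seeds the last node (the root) with init_dependent() and sweeps right-to-left. A hypothetical usage sketch against the API above (shapes and data are invented; it assumes the headers from this commit and is not a standalone program — compare validate_mnist.cu at the end of this commit):

    ExpressionGraph g;
    auto x = named(g.input(shape={whatevs, 4}), "x");
    auto w = named(g.param(shape={4, 2}, init=uniform()), "w");
    auto probs = named(softmax_fast(dot(x, w)), "probs");
    auto cost  = named(-mean(sum(log(probs), axis=1), axis=0), "cost");  // built last = root

    Tensor xt({8, 4});               // batch of 8 rows
    g["x"] = (xt << someData);       // bind data; someData: a std::vector<float>

    g.forward(8);                    // allocate(8), then forward() in stack order
    g.backward();                    // seed root adjoint, sweep the stack in reverse
    auto update = _1 -= 0.1f * _2;   // gradient step on every registered parameter
    for (auto p : g.params())
      Element(update, p.val(), p.grad());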

src/expression_operators.cu Normal file (124 additions)

@@ -0,0 +1,124 @@
#include "expression_operators.h"
#include "node_operators.h"
namespace marian {
Expr named(Expr a, const std::string& name) {
a.graph()->add_named_node(a, name);
return a;
}
Expr logit(Expr a) {
return Expr(a.graph(), new LogitNodeOp(a));
}
Expr tanh(Expr a) {
return Expr(a.graph(), new TanhNodeOp(a));
}
Expr log(Expr a) {
return Expr(a.graph(), new LogNodeOp(a));
};
Expr exp(Expr a) {
return Expr(a.graph(), new ExpNodeOp(a));
};
Expr operator-(Expr a) {
return Expr(a.graph(), new NegNodeOp(a));
};
Expr softmax_fast(Expr a) {
return Expr(a.graph(), new SoftmaxNodeOp(a));
}
/*********************************************************/
static Shape newShape(ChainPtr a, ChainPtr b) {
size_t dimsA = a->shape().size();
size_t dimsB = b->shape().size();
UTIL_THROW_IF2(dimsA != dimsB,
"Tensors have different numbers of dimensions");
Shape shape(dimsA);
for(size_t i = 0; i < dimsA; ++i) {
int dimA = a->shape()[i];
int dimB = b->shape()[i];
bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1);
UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
<< "operation cannot be broadcasted: " << dimA << " != " << dimB);
shape[i] = std::max(dimA, dimB);
if(dimA == whatevs || dimB == whatevs)
shape[i] = whatevs;
}
return shape;
}
Expr broadcast(Shape bShape, Expr a) {
const Shape& aShape = a.node()->shape();
if(aShape == bShape) {
return a;
}
else {
size_t dimsA = aShape.size();
size_t dimsB = bShape.size();
UTIL_THROW_IF2(dimsA != dimsB,
"Tensor and shape have different number of dimensions");
for(size_t i = 0; i < dimsA; ++i) {
int dimA = aShape[i];
int dimB = bShape[i];
bool broadcastable = (dimA == dimB || dimA == 1);
UTIL_THROW_IF2(!broadcastable,
"Cannot broadcast tensor dimension "
<< dimA << " to " << dimB);
if(dimA == 1 && dimB != 1) {
if(i == 0) {
Expr one = a.graph()->ones(keywords::shape={bShape[0], 1});
a = dot(one, a);
}
else if(i == 1) {
Expr one = a.graph()->ones(keywords::shape={1, bShape[1]});
a = dot(a, one);
}
else {
UTIL_THROW2("Not implemented");
}
}
}
return a;
}
}
Expr operator+(Expr a, Expr b) {
Shape shape = newShape(a, b);
Expr cast_a = broadcast(shape, a);
Expr cast_b = broadcast(shape, b);
return Expr(a.graph(), new PlusNodeOp(cast_a, cast_b));
}
Expr operator-(Expr a, Expr b) {
Shape shape = newShape(a, b);
Expr cast_a = broadcast(shape, a);
Expr cast_b = broadcast(shape, b);
return Expr(a.graph(), new MinusNodeOp(cast_a, cast_b));
}
Expr operator*(Expr a, Expr b) {
Shape shape = newShape(a, b);
Expr cast_a = broadcast(shape, a);
Expr cast_b = broadcast(shape, b);
return Expr(a.graph(), new MultNodeOp(cast_a, cast_b));
}
Expr operator/(Expr a, Expr b) {
Shape shape = newShape(a, b);
Expr cast_a = broadcast(shape, a);
Expr cast_b = broadcast(shape, b);
return Expr(a.graph(), new DivNodeOp(cast_a, cast_b));
}
Expr dot(Expr a, Expr b) {
return Expr(a.graph(), new DotNodeOp(a, b));
}
}
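broadcast() above needs no dedicated kernel because replication is a matrix product: left-multiplying by a column of ones copies a {1, N} row into every row of an {M, N} result (and, since dot's backward multiplies by the transposed ones, the gradient flows back as a sum over the replicated rows, which is exactly what broadcasting requires). A self-contained numeric illustration of the trick (plain C++, not marian code):

    #include <cstdio>

    int main() {
      // b is 1x3; replicate it to 3x3 by left-multiplying with ones(3,1):
      // (3x1 ones) * (1x3 b) = 3x3, every row a copy of b.
      float b[1][3] = {{0.5f, -1.0f, 2.0f}};
      float ones[3][1] = {{1.f}, {1.f}, {1.f}};
      float out[3][3];
      for (int i = 0; i < 3; ++i)
        for (int j = 0; j < 3; ++j)
          out[i][j] = ones[i][0] * b[0][j];   // inner dimension is 1
      for (int i = 0; i < 3; ++i)
        std::printf("%5.1f %5.1f %5.1f\n", out[i][0], out[i][1], out[i][2]);
    }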

src/expression_operators.h

@@ -1,115 +1,36 @@
#pragma once
-#include "graph.h"
-#include "graph_operators.h"
-#include "expressions.h"
+#include "expression_graph.h"
namespace marian {
-template <typename ...Args>
-inline Expr input(Args ...args) {
-return Expr(new InputNode(args...));
-}
+Expr named(Expr a, const std::string& name);
-template <typename ...Args>
-inline Expr param(Args ...args) {
-return Expr(new ParamNode(args...));
-}
-template <typename ...Args>
-inline Expr constant(Args ...args) {
-return Expr(new ConstantNode(args...));
-}
+Expr logit(Expr a);
-template <typename ...Args>
-inline Expr ones(Args ...args) {
-return Expr(new ConstantNode(keywords::value=1, args...));
-}
+Expr tanh(Expr a);
-template <typename ...Args>
-inline Expr zeroes(Args ...args) {
-return Expr(new ConstantNode(keywords::value=0, args...));
-}
+Expr log(Expr a);
+Expr exp(Expr a);
+Expr operator-(Expr a);
/*********************************************************/
-inline Expr logit(Expr a) {
-return Expr(new LogitNodeOp(a));
-}
+Expr operator+(Expr a, Expr b);
-inline Expr tanh(Expr a) {
-return Expr(new TanhNodeOp(a));
-}
+Expr operator-(Expr a, Expr b);
-inline Expr log(Expr a) {
-return Expr(new LogNodeOp(a));
-};
+Expr operator*(Expr a, Expr b);
-inline Expr exp(Expr a) {
-return Expr(new ExpNodeOp(a));
-};
+Expr operator/(Expr a, Expr b);
-inline Expr operator-(Expr a) {
-return Expr(new NegNodeOp(a));
-};
/*********************************************************/
-inline Expr operator+(Expr a, Expr b) {
-return Expr(new PlusNodeOp(a, b));
-}
-inline Expr operator-(Expr a, Expr b) {
-return Expr(new MinusNodeOp(a, b));
-}
-inline Expr operator*(Expr a, Expr b) {
-return Expr(new MultNodeOp(a, b));
-}
-inline Expr operator/(Expr a, Expr b) {
-return Expr(new DivNodeOp(a, b));
-}
-inline Expr dot(Expr a, Expr b) {
-return Expr(new DotNodeOp(a, b));
-}
+Expr dot(Expr a, Expr b);
/******************************************************/
-Expr broadcast(Shape bShape, Expr a) {
-const Shape& aShape = a.node()->shape();
-if(aShape == bShape) {
-return a;
-}
-else {
-size_t dimsA = aShape.size();
-size_t dimsB = bShape.size();
-UTIL_THROW_IF2(dimsA != dimsB,
-"Tensor and shape have different number of dimensions");
-for(size_t i = 0; i < dimsA; ++i) {
-int dimA = aShape[i];
-int dimB = bShape[i];
-bool broadcastable = (dimA == dimB || dimA == 1);
-UTIL_THROW_IF2(!broadcastable,
-"Cannot broadcast tensor dimension "
-<< dimA << " to " << dimB);
-if(dimA == 1 && dimB != 1) {
-if(i == 0) {
-Expr one = ones(keywords::shape={bShape[0], 1});
-a = dot(one, a);
-}
-else if(i == 1) {
-Expr one = ones(keywords::shape={1, bShape[1]});
-a = dot(a, one);
-}
-else {
-UTIL_THROW2("Not implemented");
-}
-}
-}
-return a;
-}
-}
+Expr broadcast(Shape bShape, Expr a);
/*********************************************************/
@@ -126,7 +47,7 @@ inline Expr sum(Expr a, Args ...args) {
int rows = n->val().shape()[0];
return {1, rows};
};
-Expr one = ones(shape={1, n->shape()[0]},
+Expr one = a.graph()->ones(shape={1, n->shape()[0]},
lazy_shape=lshape);
return dot(one, a);
}
@@ -136,8 +57,8 @@ inline Expr sum(Expr a, Args ...args) {
//std::cerr << "Shape will be " << cols << " by 1." << std::endl;
return {cols, 1};
};
-Expr one = ones(shape={n->shape()[1], 1},
-lazy_shape=lshape);
+Expr one = a.graph()->ones(shape={n->shape()[1], 1},
+lazy_shape=lshape);
return dot(a, one);
}
else if(ax == 2) {
@@ -151,17 +72,12 @@ inline Expr sum(Expr a, Args ...args) {
// inefficient
template <typename ...Args>
-inline Expr softmax(Expr a, Args ...args) {
+Expr softmax(Expr a, Args ...args) {
Expr e = exp(a);
return e / sum(e, args...);
}
-template <typename ...Args>
-inline Expr softmax_fast(Expr a, Args ...args) {
-Expr e = Expr(new SoftmaxNodeOp(a, args...));
-return e;
-}
+Expr softmax_fast(Expr a);
// inefficient
template <typename ...Args>
@@ -173,12 +89,12 @@ inline Expr mean(Expr a, Args ...args) {
ChainPtr n = a.node();
switch (ax) {
case 0:
-return sum(a, axis=0) / constant(shape={1, 1},
+return sum(a, axis=0) / a.graph()->constant(shape={1, 1},
lazy_value=[n]() -> Float {
return n->val().shape()[0];
});
case 1:
-return sum(a, axis=1) / constant(shape={1, 1},
+return sum(a, axis=1) / a.graph()->constant(shape={1, 1},
lazy_value=[n]() -> Float {
return n->val().shape()[1];
});
@@ -187,7 +103,7 @@ inline Expr mean(Expr a, Args ...args) {
case 3:
UTIL_THROW2("Not implemented");
default:
-return sum(a) / constant(shape={1, 1},
+return sum(a) / a.graph()->constant(shape={1, 1},
lazy_value=[n]() -> Float {
return n->val().size();
});
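Two remarks on the softmax pair above. The composed softmax() materializes an exp node and a sum node and leans on broadcasting for the division, which is why it is flagged inefficient; softmax_fast() builds a single fused SoftmaxNodeOp instead. Neither subtracts the row maximum first (see the @TODO added in node_operators.h below), so large logits overflow exp. A self-contained sketch of the safe variant that @TODO suggests (plain CPU C++, not marian code):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // softmax_safe(x) = softmax(x - max(x)): subtracting the row max leaves
    // the result unchanged (exp(x-m)/sum exp(x-m) == exp(x)/sum exp(x)) but
    // keeps every exponent <= 0, so exp() cannot overflow.
    std::vector<float> softmax_safe(std::vector<float> x) {
      float m = *std::max_element(x.begin(), x.end());
      float sum = 0.f;
      for (float& v : x) { v = std::exp(v - m); sum += v; }
      for (float& v : x) v /= sum;
      return x;
    }

    int main() {
      for (float p : softmax_safe({1000.f, 1001.f, 1002.f}))
        std::printf("%f\n", p);   // finite: ~0.090, 0.245, 0.665; naive exp() gives inf/inf
    }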

src/expressions.cu (deleted)

@@ -1,59 +0,0 @@
#include <sstream>
#include "expressions.h"
#include "graph_operators.h"
using namespace std;
namespace marian {
Expr::Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
Expr::Expr(Float v) : pimpl_(new ConstantNode(keywords::value=v,
keywords::shape={1,1})) {}
Tensor Expr::val() {
return pimpl_->val();
}
Tensor Expr::grad() {
return pimpl_->grad();
}
ChainPtr Expr::node() {
return pimpl_;
}
void Expr::forward(size_t batchSize) {
UTIL_THROW_IF2(pimpl_.get() != Chainable<Tensor>::stack.back(),
"Trying to call forward on non-root of computation graph");
for(auto&& v : Chainable<Tensor>::stack) {
v->allocate(batchSize);
}
for(auto&& v : Chainable<Tensor>::stack)
v->forward();
}
void Expr::backward() {
UTIL_THROW_IF2(pimpl_.get() != Chainable<Tensor>::stack.back(),
"Trying to call backward on non-root of computation graph");
for(auto&& v : Chainable<Tensor>::stack)
v->set_zero_adjoint();
typedef typename Chainable<Tensor>::ChainableStack::reverse_iterator It;
pimpl_->init_dependent();
for(It it = Chainable<Tensor>::stack.rbegin(); it != Chainable<Tensor>::stack.rend(); ++it)
(*it)->backward();
}
Expr::operator ChainPtr() {
return pimpl_;
}
std::string Expr::Debug() const
{
stringstream strm;
const Shape &shape = pimpl_->shape();
strm << marian::Debug(shape);
return strm.str();
}
}

src/expressions.h (deleted)

@@ -1,33 +0,0 @@
#pragma once
#include "definitions.h"
#include "graph.h"
namespace marian {
class Expr {
public:
Expr(Chainable<Tensor>* chainable);
Expr(Float v);
Expr operator=(Tensor t) {
pimpl_->setVal(t);
return *this;
}
Tensor val();
Tensor grad();
void forward(size_t batchSize);
void backward();
ChainPtr node();
operator ChainPtr();
std::string Debug() const;
private:
ChainPtr pimpl_;
};
}

src/marian.h

@@ -1,9 +1,7 @@
#pragma once
#include "definitions.h"
-#include "graph.h"
-#include "graph_operators.h"
-#include "expressions.h"
-#include "expression_operators.h"
+#include "expression_graph.h"
#include "param_initializers.h"
+#include "expression_operators.h"

src/node.h

@@ -2,36 +2,10 @@
#include "keywords.h"
#include "tensor.h"
+#include "chainable.h"
namespace marian {
-template <class DataType>
-struct Chainable {
-Chainable() { }
-virtual ~Chainable() { }
-virtual void forward() { }
-virtual void backward() { }
-virtual void init_dependent() { }
-virtual void set_zero_adjoint() { }
-virtual void allocate(size_t) = 0;
-virtual const Shape& shape() = 0;
-virtual DataType &val() = 0;
-virtual DataType grad() = 0;
-virtual void setVal(DataType t) {
-UTIL_THROW2("Tensors can only be assigned to input nodes");
-};
-typedef std::vector<Chainable<DataType>*> ChainableStack;
-static ChainableStack stack;
-};
-template <class DataType>
-typename Chainable<DataType>::ChainableStack Chainable<DataType>::stack;
-typedef std::shared_ptr<Chainable<Tensor>> ChainPtr;
class Node : public Chainable<Tensor>,
public keywords::Keywords {
public:
@@ -40,9 +14,7 @@ class Node : public Chainable<Tensor>,
: Keywords(args...),
shape_(Get<Shape>(keywords::shape, {1, 1})),
name_(Get<std::string>(keywords::name, "none"))
-{
-stack.push_back(this);
-}
+{ }
virtual ~Node() {};

src/node_operators.h

@@ -1,7 +1,6 @@
#pragma once
-#include "expressions.h"
-#include "graph.h"
+#include "node.h"
#include "tensor_operators.h"
namespace marian {
@@ -108,49 +107,14 @@ struct TanhNodeOp : public UnaryNodeOp {
}
};
-struct ArgmaxOp : public UnaryNodeOp {
-template <typename ...Args>
-ArgmaxOp(ChainPtr a, Args ...args)
-: UnaryNodeOp(a, keywords::shape=newShape(a, -1), args...),
-axis_(-1) { }
-Shape newShape(ChainPtr a, int axis) {
-Shape shape1 = a->shape();
-UTIL_THROW_IF2(shape1.size() > 2,
-"Tensors with more than 2 dimensions not supported yet");
-if(axis == 0) {
-shape1[0] = 1;
-}
-else if(axis == 1) {
-shape1[1] = 1;
-}
-else {
-shape1 = {1, 1};
-}
-return shape1;
-}
-void forward() {
-//val_ = Argmax(a_->val(), axis_);
-UTIL_THROW2("Not implemented");
-}
-void backward() {
-UTIL_THROW2("Not implemented");
-}
-private:
-int axis_;
-};
+// @TODO, make this numerically safe(r):
+// softmax(X) = softmax_safe(X - max(X, axis=1))
+// Probably best to do this directly in Softmax
+// function.
struct SoftmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
-SoftmaxNodeOp(ChainPtr a, Args ...args)
-: UnaryNodeOp(a, args...) { }
+SoftmaxNodeOp(Args ...args)
+: UnaryNodeOp(args...) { }
void forward() {
// B = softmax(A).
@@ -171,8 +135,8 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
struct LogNodeOp : public UnaryNodeOp {
template <typename ...Args>
-LogNodeOp(ChainPtr a, Args ...args)
-: UnaryNodeOp(a, args...) {}
+LogNodeOp(Args ...args)
+: UnaryNodeOp(args...) {}
void forward() {
Element(_1 = Log(_2), val_, a_->val());
@@ -186,8 +150,8 @@ struct LogNodeOp : public UnaryNodeOp {
struct ExpNodeOp : public UnaryNodeOp {
template <typename ...Args>
-ExpNodeOp(ChainPtr a, Args ...args)
-: UnaryNodeOp(a, args...) { }
+ExpNodeOp(Args ...args)
+: UnaryNodeOp(args...) { }
void forward() {
Element(_1 = Exp(_2), val_, a_->val());
@@ -230,7 +194,7 @@ struct DotNodeOp : public BinaryNodeOp {
template <typename ...Args>
DotNodeOp(ChainPtr a, ChainPtr b, Args ...args)
: BinaryNodeOp(a, b,
-keywords::shape=newShape(a,b),
+keywords::shape=newShape(a, b),
args...) { }
Shape newShape(ChainPtr a, ChainPtr b) {
@@ -258,41 +222,11 @@
}
};
-Expr broadcast(Shape shape, Expr a);
-struct BroadcastingNodeOp : public BinaryNodeOp {
+struct PlusNodeOp : public BinaryNodeOp {
template <typename ...Args>
-BroadcastingNodeOp(Expr a, Expr b, Args ...args)
-: BinaryNodeOp(broadcast(newShape(a ,b), a),
-broadcast(newShape(a ,b), b),
-keywords::shape=newShape(a, b),
-args...) {}
-static Shape newShape(ChainPtr a, ChainPtr b) {
-size_t dimsA = a->shape().size();
-size_t dimsB = b->shape().size();
-UTIL_THROW_IF2(dimsA != dimsB,
-"Tensors have different numbers of dimensions");
-Shape shape(dimsA);
-for(size_t i = 0; i < dimsA; ++i) {
-int dimA = a->shape()[i];
-int dimB = b->shape()[i];
-bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1);
-UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
-<< "operation cannot be broadcasted: " << dimA << " != " << dimB);
-shape[i] = std::max(dimA, dimB);
-if(dimA == whatevs || dimB == whatevs)
-shape[i] = whatevs;
-}
-return shape;
-}
-};
-struct PlusNodeOp : public BroadcastingNodeOp {
-template <typename ...Args>
-PlusNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
+PlusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+: BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
void forward() {
Element(_1 = _2 + _3,
val_, a_->val(), b_->val());
@@ -306,10 +240,11 @@ struct PlusNodeOp : public BroadcastingNodeOp {
}
};
-struct MinusNodeOp : public BroadcastingNodeOp {
+struct MinusNodeOp : public BinaryNodeOp {
template <typename ...Args>
-MinusNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
+MinusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+: BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
void forward() {
Element(_1 = _2 - _3,
val_, a_->val(), b_->val());
@@ -323,10 +258,11 @@ struct MinusNodeOp : public BroadcastingNodeOp {
}
};
-struct MultNodeOp : public BroadcastingNodeOp {
+struct MultNodeOp : public BinaryNodeOp {
template <typename ...Args>
-MultNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
+MultNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+: BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
void forward() {
Element(_1 = _2 * _3,
val_, a_->val(), b_->val());
@@ -340,9 +276,10 @@ struct MultNodeOp : public BroadcastingNodeOp {
}
};
-struct DivNodeOp : public BroadcastingNodeOp {
+struct DivNodeOp : public BinaryNodeOp {
template <typename ...Args>
-DivNodeOp(Args ...args) : BroadcastingNodeOp(args...) { }
+DivNodeOp(ChainPtr a, ChainPtr b, Args ...args)
+: BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
void forward() {
Element(_1 = _2 / _3,

src/param_initializers.h

@@ -18,7 +18,7 @@ void ones(Tensor t) {
}
template <class Distribution>
-void distribution(Tensor t, float a=0.0, float b=0.1) {
+void distribution(Tensor t, float a, float b) {
std::random_device device;
std::default_random_engine engine(device());
Distribution dist(a, b);
@@ -43,7 +43,7 @@ std::function<void(Tensor)> uniform(float a = 0.0, float b = 0.1) {
}
std::function<void(Tensor)> from_vector(const std::vector<float>& v) {
-return [&v](Tensor t) {
+return [v](Tensor t) {
t << v;
};
}
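The [&v] → [v] change above is a real lifetime fix, not a style tweak: from_vector returns the lambda, so a capture by reference would dangle as soon as the caller's vector argument goes out of scope, while capture by value stores a copy inside the std::function. A distilled illustration of the difference (hypothetical, not marian code):

    #include <functional>
    #include <vector>

    std::function<float()> first_by_ref(const std::vector<float>& v) {
      return [&v]() { return v[0]; };  // dangles: v may die before the call
    }
    std::function<float()> first_by_val(const std::vector<float>& v) {
      return [v]() { return v[0]; };   // safe: the lambda owns a copy
    }

    int main() {
      auto f = first_by_val(std::vector<float>{42.f});  // temporary dies here
      return static_cast<int>(f());    // fine; the by-ref version would be UB
    }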

src/validate_mnist.cu

@@ -2,24 +2,15 @@
#include "marian.h"
#include "mnist.h"
#include "npz_converter.h"
+#include "param_initializers.h"
using namespace marian;
using namespace keywords;
-int main(int argc, char** argv) {
-cudaSetDevice(1);
-const size_t IMAGE_SIZE = 784;
-const size_t LABEL_SIZE = 10;
-int BATCH_SIZE = 10000;
-std::cerr << "Loading test set...";
-std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE);
-std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
-std::cerr << "Done." << std::endl;
+const size_t IMAGE_SIZE = 784;
+const size_t LABEL_SIZE = 10;
+int BATCH_SIZE = 10000;
+ExpressionGraph build_graph() {
std::cerr << "Loading model params...";
NpzConverter converter("../scripts/test_model_single/model.npz");
@@ -31,29 +22,50 @@ int main(int argc, char** argv) {
std::cerr << "Building model...";
-auto x = input(shape={whatevs, IMAGE_SIZE});
-auto y = input(shape={whatevs, LABEL_SIZE});
+ExpressionGraph g;
+auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
+auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
-auto w = param(shape={IMAGE_SIZE, LABEL_SIZE},
-init=from_vector(wData));
-auto b = param(shape={1, LABEL_SIZE},
-init=from_vector(bData));
+auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
+init=from_vector(wData)), "w");
+auto b = named(g.param(shape={1, LABEL_SIZE},
+init=from_vector(bData)), "b");
-auto probs = softmax_fast(dot(x, w) + b, axis=1);
-auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
+auto probs = named(
+softmax_fast(dot(x, w) + b), //, axis=1),
+"probs"
+);
+auto cost = named(
+-mean(sum(y * log(probs), axis=1), axis=0),
+"cost"
+);
std::cerr << "Done." << std::endl;
+return g;
+}
+int main(int argc, char** argv) {
+cudaSetDevice(1);
+std::cerr << "Loading test set...";
+std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE);
+std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
+std::cerr << "Done." << std::endl;
+ExpressionGraph g = build_graph();
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
Tensor yt({BATCH_SIZE, LABEL_SIZE});
-x = xt << testImages;
-y = yt << testLabels;
+g["x"] = (xt << testImages);
+g["y"] = (yt << testLabels);
-cost.forward(BATCH_SIZE);
+g.forward(BATCH_SIZE);
std::vector<float> results;
-results << probs.val();
+results << g["probs"].val();
size_t acc = 0;
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
@@ -65,22 +77,22 @@ int main(int argc, char** argv) {
}
acc += (correct == proposed);
}
-std::cerr << "Cost: " << cost.val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
+std::cerr << "Cost: " << g["cost"].val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
float eta = 0.1;
for (size_t j = 0; j < 10; ++j) {
for(size_t i = 0; i < 60; ++i) {
-cost.backward();
+g.backward();
auto update_rule = _1 -= eta * _2;
-Element(update_rule, w.val(), w.grad());
-Element(update_rule, b.val(), b.grad());
+for(auto param : g.params())
+Element(update_rule, param.val(), param.grad());
-cost.forward(BATCH_SIZE);
+g.forward(BATCH_SIZE);
}
std::cerr << "Epoch: " << j << std::endl;
std::vector<float> results;
-results << probs.val();
+results << g["probs"].val();
size_t acc = 0;
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
@@ -92,7 +104,7 @@ int main(int argc, char** argv) {
}
acc += (correct == proposed);
}
-std::cerr << "Cost: " << cost.val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
+std::cerr << "Cost: " << g["cost"].val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
}
return 0;
}
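The inner loop above is plain full-batch gradient descent: every parameter is updated as w ← w − η·∂cost/∂w, and Element(update_rule, ...) evaluates that elementwise expression on the GPU with _1 bound to the parameter's value and _2 to its gradient. A scalar illustration of the same rule (hypothetical, CPU-only):

    #include <cstdio>

    int main() {
      float w = 2.0f, eta = 0.1f;
      // Minimize cost(w) = w^2; the gradient is 2w. Mirrors "_1 -= eta * _2"
      // with _1 = param.val() and _2 = param.grad().
      for (int i = 0; i < 5; ++i) {
        float grad = 2.f * w;
        w -= eta * grad;
        std::printf("step %d: w = %f\n", i, w);  // 1.6, 1.28, 1.024, ...
      }
    }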