Mirror of https://github.com/marian-nmt/marian.git, synced 2024-11-04 14:04:24 +03:00
fixed backprop error, TODO: check shapes
This commit is contained in:
parent cc4e24b3a6
commit ffb21bb699
@@ -24,22 +24,17 @@ ChainPtr Expr::node() {
 
 void Expr::forward(size_t batchSize) {
   UTIL_THROW_IF2(pimpl_.get() != Chainable<Tensor>::stack.back(),
-                 "Trying to call forward on non-root of computation graph");
-  std::cerr << "forward:" << std::endl;
-
+                 "Trying to call forward on non-root of computation graph");
   for(auto&& v : Chainable<Tensor>::stack) {
     v->allocate(batchSize);
   }
 
   for(auto&& v : Chainable<Tensor>::stack)
     v->forward();
 }
 
 void Expr::backward() {
   UTIL_THROW_IF2(pimpl_.get() != Chainable<Tensor>::stack.back(),
-                 "Trying to call backward on non-root of computation graph");
-  std::cerr << "backward:" << std::endl;
-
+                 "Trying to call backward on non-root of computation graph");
   for(auto&& v : Chainable<Tensor>::stack)
     v->set_zero_adjoint();
 
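For orientation: Expr::forward and Expr::backward simply walk Chainable<Tensor>::stack, a tape of nodes recorded in topological order, front to back for values and back to front for adjoints. A minimal self-contained sketch of that tape pattern (Node, Mul, tape, and the run_* functions are illustrative stand-ins, not marian's actual types):

#include <cstdio>
#include <vector>

// Illustrative stand-ins for Chainable<Tensor> and its static stack: every
// node registers itself on a shared tape in construction (topological) order.
struct Node {
  float val = 0.0f, adj = 0.0f;      // value and adjoint (gradient)
  virtual void forward() {}
  virtual void backward() {}
  virtual ~Node() {}
};

std::vector<Node*> tape;             // plays the role of Chainable<Tensor>::stack

struct Mul : Node {
  Node *a, *b;
  Mul(Node* x, Node* y) : a(x), b(y) { tape.push_back(this); }
  void forward() override { val = a->val * b->val; }
  void backward() override { a->adj += adj * b->val; b->adj += adj * a->val; }
};

// Forward walks the tape front to back; backward zeroes the adjoints, seeds
// the root with 1 and walks back to front -- as Expr::forward/backward do.
void run_forward() { for (Node* n : tape) n->forward(); }
void run_backward() {
  for (Node* n : tape) n->adj = 0.0f;     // cf. set_zero_adjoint()
  tape.back()->adj = 1.0f;
  for (auto it = tape.rbegin(); it != tape.rend(); ++it) (*it)->backward();
}

int main() {
  Node x, y;
  x.val = 3.0f; y.val = 4.0f;
  tape.push_back(&x); tape.push_back(&y);
  Mul z(&x, &y);                          // z = x * y
  run_forward(); run_backward();
  std::printf("z=%g dz/dx=%g dz/dy=%g\n", z.val, x.adj, y.adj);  // 12 4 3
  return 0;
}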
@@ -56,7 +51,6 @@ Expr::operator ChainPtr() {
 std::string Expr::Debug() const
 {
   stringstream strm;
-  //const Chainable<Tensor> &ct = *pimpl_;
   const Shape &shape = pimpl_->shape();
   strm << marian::Debug(shape);
   return strm.str();
@@ -42,7 +42,8 @@ struct ParamNode : public Node {
   template <typename ...Args>
   ParamNode(Args ...args)
   : Node(args...),
-    init_(Get<std::function<void(Tensor)>>(keywords::init, [](Tensor){ }))
+    init_(Get<std::function<void(Tensor)>>(keywords::init, [](Tensor){ })),
+    initialized_(false)
   {
     UTIL_THROW_IF2(!Has(keywords::shape) &&
                    !Has(keywords::lazy_shape),
@@ -51,14 +52,18 @@ struct ParamNode : public Node {
 
   void forward() {}
   void backward() {}
 
   virtual void allocate(size_t batchSize) {
     val_.allocate(shape_);
-    init_(val_);
+    if(!initialized_) {
+      init_(val_);
+      initialized_ = true;
+    }
   }
 
  private:
   std::function<void(Tensor)> init_;
+  bool initialized_;
 };
 
 struct UnaryNodeOp : public Node {
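The initialized_ flag is the substance of the fix. Expr::forward calls allocate on every node of the graph each time it runs, and the new training loop below calls graph.forward once per iteration, so an unconditional init_(val_) would re-run the initializer on every pass and silently wipe out every gradient update. A standalone sketch of the guard (Param and its members are illustrative names, not marian's API):

#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

struct Param {
  std::vector<float> val;
  std::function<void(std::vector<float>&)> init;
  bool initialized = false;

  // Mirrors ParamNode::allocate: allocation may happen before every forward
  // pass, but the user-supplied initializer must fire exactly once.
  void allocate(std::size_t size) {
    val.resize(size);
    if (!initialized) {   // without the guard, the next forward pass would
      init(val);          // overwrite every gradient update made so far
      initialized = true;
    }
  }
};

int main() {
  Param w;
  w.init = [](std::vector<float>& v) { for (auto& x : v) x = 0.5f; };
  w.allocate(4);            // first pass: initializer fills the values
  w.val[0] = 0.25f;         // stand-in for an SGD update
  w.allocate(4);            // later pass: the update must survive
  assert(w.val[0] == 0.25f);
  return 0;
}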
@@ -139,6 +144,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
   SoftmaxNodeOp(ChainPtr a, Args ...args)
   : UnaryNodeOp(a, keywords::shape=newShape(a),
                 args...) { }
 
   Shape newShape(ChainPtr a) {
     Shape shape = a->shape();
     return shape;
@@ -164,8 +170,8 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
 
 struct LogNodeOp : public UnaryNodeOp {
   template <typename ...Args>
-  LogNodeOp(Args ...args)
-  : UnaryNodeOp(args...) {}
+  LogNodeOp(ChainPtr a, Args ...args)
+  : UnaryNodeOp(a, keywords::shape=a->shape(), args...) {}
 
   void forward() {
     Element(_1 = Log(_2), val_, a_->val());
@@ -180,14 +186,9 @@ struct LogNodeOp : public UnaryNodeOp {
 struct ExpNodeOp : public UnaryNodeOp {
   template <typename ...Args>
   ExpNodeOp(ChainPtr a, Args ...args)
-  : UnaryNodeOp(a, keywords::shape=newShape(a),
+  : UnaryNodeOp(a, keywords::shape=a->shape(),
                 args...) { }
 
-  Shape newShape(ChainPtr a) {
-    Shape shape = a->shape();
-    return shape;
-  }
-
   void forward() {
     Element(_1 = Exp(_2), val_, a_->val());
   }
 
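Both unary-op changes exploit the same invariant: element-wise ops such as log and exp return a tensor of exactly their input's shape, so the newShape helper that merely returned a->shape() can be inlined away. A sketch of that invariant, with std::transform standing in for the GPU Element kernel (elementwise is our name):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Element-wise unary application: the output always has the input's size,
// which is why LogNodeOp/ExpNodeOp can pass keywords::shape=a->shape().
template <typename F>
std::vector<float> elementwise(F f, const std::vector<float>& in) {
  std::vector<float> out(in.size());
  std::transform(in.begin(), in.end(), out.begin(), f);
  return out;
}

int main() {
  std::vector<float> a = {1.0f, 2.0f, 4.0f};
  auto l = elementwise([](float v) { return std::log(v); }, a);
  auto e = elementwise([](float v) { return std::exp(v); }, l);
  assert(l.size() == a.size() && e.size() == a.size());
  assert(std::abs(e[2] - a[2]) < 1e-4f);   // exp(log(v)) == v
  return 0;
}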
@@ -8,7 +8,7 @@ using namespace keywords;
 
 int main(int argc, char** argv) {
 
-  cudaSetDevice(0);
+  cudaSetDevice(1);
 
   const size_t IMAGE_SIZE = 784;
   const size_t LABEL_SIZE = 10;
@@ -30,18 +30,22 @@ int main(int argc, char** argv) {
   std::cerr << "Done." << std::endl;
 
   std::cerr << "Building model...";
-  auto x = input(shape={whatevs, IMAGE_SIZE}, name="X");
-  auto y = input(shape={whatevs, LABEL_SIZE}, name="Y");
-
-  auto w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0",
+  auto x = input(shape={whatevs, IMAGE_SIZE});
+  auto y = input(shape={whatevs, LABEL_SIZE});
+
+  auto w = param(shape={IMAGE_SIZE, LABEL_SIZE},
                  init=[wData](Tensor t) { t.set(wData); });
-  auto b = param(shape={1, LABEL_SIZE}, name="b0",
-                 init=[bData](Tensor t) {t.set(bData); });
+  auto b = param(shape={1, LABEL_SIZE},
+                 init=[bData](Tensor t) { t.set(bData); });
 
-  auto predict = softmax(dot(x, w) + b,
-                         axis=1, name="pred");
-  auto graph = -mean(sum(y * log(predict), axis=1),
-                     axis=0, name="cost");
+  auto zd = dot(x, w);
+  auto z = zd + b;
+  auto predict = softmax(z, axis=1);
+  auto logp = log(predict);
+  auto cost = sum(y * logp, axis=1);
+  auto graph = -mean(cost, axis=0);
 
   std::cerr << "Done." << std::endl;
 
   Tensor xt({BATCH_SIZE, IMAGE_SIZE});
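The rewritten model section builds the same cost as before, just one named node per step: z = dot(x, w) + b, predict = softmax(z, axis=1), and graph = -mean(sum(y * log(predict), axis=1), axis=0), i.e. mean cross-entropy against one-hot labels. A plain-CPU sketch of what those nodes compute for a single row (the max-subtraction is the usual numerical stabilization, an assumption on our part, not necessarily what marian's kernel does):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// One row of the graph's cost: predict = softmax(z), then
// sum(y * log(predict)) negated -- cross-entropy against a one-hot label.
float softmax_xent(std::vector<float> z, const std::vector<float>& y) {
  float mx = *std::max_element(z.begin(), z.end());   // stabilization (assumed)
  float sum = 0.0f;
  for (float& v : z) { v = std::exp(v - mx); sum += v; }
  float cost = 0.0f;
  for (std::size_t i = 0; i < z.size(); ++i)
    cost += y[i] * std::log(z[i] / sum);              // y picks the true class
  return -cost;
}

int main() {
  std::vector<float> z = {2.0f, 0.5f, -1.0f};   // one row of dot(x, w) + b
  std::vector<float> y = {1.0f, 0.0f, 0.0f};    // one-hot label
  std::printf("cost = %f\n", softmax_xent(z, y));
  return 0;
}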
@@ -51,14 +55,20 @@ int main(int argc, char** argv) {
   y = yt << testLabels;
 
   graph.forward(BATCH_SIZE);
-  graph.backward();
+  for(size_t i = 0; i < 1000; ++i) {
+    graph.backward();
+
+    auto update_rule = _1 -= 0.1 * _2;
+    Element(update_rule, w.val(), w.grad());
+    Element(update_rule, b.val(), b.grad());
+
+    graph.forward(BATCH_SIZE);
+  }
 
   auto results = predict.val();
   std::vector<float> resultsv(results.size());
   resultsv << results;
 
+  std::cerr << b.grad().Debug() << std::endl;
+
   size_t acc = 0;
   for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
     size_t correct = 0;
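_1 -= 0.1 * _2 is a placeholder expression applied element-wise to each (value, gradient) pair, i.e. vanilla SGD with a fixed learning rate of 0.1. Written out as an explicit CPU loop (sgd_update and lr are our names):

#include <cassert>
#include <cstddef>
#include <vector>

// Element(_1 -= 0.1 * _2, w.val(), w.grad()) written as an explicit loop:
// one fused element-wise pass of vanilla gradient descent.
void sgd_update(std::vector<float>& val, const std::vector<float>& grad,
                float lr = 0.1f) {
  for (std::size_t i = 0; i < val.size(); ++i)
    val[i] -= lr * grad[i];
}

int main() {
  std::vector<float> w = {1.0f, -2.0f};
  std::vector<float> g = {0.5f, 0.5f};
  sgd_update(w, g);                 // w becomes {0.95, -2.05}
  assert(w[0] == 1.0f - 0.1f * 0.5f);
  return 0;
}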
@@ -69,7 +79,7 @@ int main(int argc, char** argv) {
     }
     acc += (correct == predicted);
   }
-  std::cerr << "Accuracy: " << float(acc)/BATCH_SIZE << std::endl;
+  std::cerr << "Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
 
   return 0;
 }
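For reference, the accuracy loop around this hunk compares the argmax of each LABEL_SIZE-wide prediction row with the argmax of the corresponding one-hot label row. An isolated sketch (accuracy is our name; rows are flattened back to back, as resultsv and testLabels are in the test):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Fraction of rows whose predicted argmax matches the label argmax; each row
// holds labelSize entries, flattened row-major.
float accuracy(const std::vector<float>& pred,
               const std::vector<float>& labels, std::size_t labelSize) {
  std::size_t acc = 0;
  for (std::size_t i = 0; i < labels.size(); i += labelSize) {
    auto p = std::max_element(pred.begin() + i, pred.begin() + i + labelSize);
    auto l = std::max_element(labels.begin() + i,
                              labels.begin() + i + labelSize);
    acc += (p - pred.begin()) == (l - labels.begin());
  }
  return float(acc) / (labels.size() / labelSize);
}

int main() {
  std::vector<float> pred   = {0.9f, 0.1f,  0.2f, 0.8f};     // two 2-way rows
  std::vector<float> labels = {1.0f, 0.0f,  1.0f, 0.0f};
  std::printf("accuracy = %f\n", accuracy(pred, labels, 2));  // 0.5
  return 0;
}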