mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
UnaryNodeOp inherits shape from child by default
This commit is contained in:
parent
4bafa6c360
commit
3ccdc15263
@ -71,7 +71,9 @@ struct UnaryNodeOp : public Node {
|
|||||||
|
|
||||||
template <typename ...Args>
|
template <typename ...Args>
|
||||||
UnaryNodeOp(ChainPtr a, Args ...args)
|
UnaryNodeOp(ChainPtr a, Args ...args)
|
||||||
: Node(args...), a_(a) {}
|
: Node(keywords::shape=a->shape(), //@TODO: Check keywords?
|
||||||
|
args...),
|
||||||
|
a_(a) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SigmoidNodeOp : public UnaryNodeOp {
|
struct SigmoidNodeOp : public UnaryNodeOp {
|
||||||
@ -142,8 +144,7 @@ struct ArgmaxOp : public UnaryNodeOp {
|
|||||||
struct SoftmaxNodeOp : public UnaryNodeOp {
|
struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||||
template <typename ...Args>
|
template <typename ...Args>
|
||||||
SoftmaxNodeOp(ChainPtr a, Args ...args)
|
SoftmaxNodeOp(ChainPtr a, Args ...args)
|
||||||
: UnaryNodeOp(a, keywords::shape=a->shape(),
|
: UnaryNodeOp(a, args...) { }
|
||||||
args...) { }
|
|
||||||
|
|
||||||
void forward() {
|
void forward() {
|
||||||
// B = softmax(A).
|
// B = softmax(A).
|
||||||
@ -166,7 +167,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
|
|||||||
struct LogNodeOp : public UnaryNodeOp {
|
struct LogNodeOp : public UnaryNodeOp {
|
||||||
template <typename ...Args>
|
template <typename ...Args>
|
||||||
LogNodeOp(ChainPtr a, Args ...args)
|
LogNodeOp(ChainPtr a, Args ...args)
|
||||||
: UnaryNodeOp(a, keywords::shape=a->shape(), args...) {}
|
: UnaryNodeOp(a, args...) {}
|
||||||
|
|
||||||
void forward() {
|
void forward() {
|
||||||
Element(_1 = Log(_2), val_, a_->val());
|
Element(_1 = Log(_2), val_, a_->val());
|
||||||
@ -181,8 +182,7 @@ struct LogNodeOp : public UnaryNodeOp {
|
|||||||
struct ExpNodeOp : public UnaryNodeOp {
|
struct ExpNodeOp : public UnaryNodeOp {
|
||||||
template <typename ...Args>
|
template <typename ...Args>
|
||||||
ExpNodeOp(ChainPtr a, Args ...args)
|
ExpNodeOp(ChainPtr a, Args ...args)
|
||||||
: UnaryNodeOp(a, keywords::shape=a->shape(),
|
: UnaryNodeOp(a, args...) { }
|
||||||
args...) { }
|
|
||||||
|
|
||||||
void forward() {
|
void forward() {
|
||||||
Element(_1 = Exp(_2), val_, a_->val());
|
Element(_1 = Exp(_2), val_, a_->val());
|
||||||
|
@ -22,7 +22,7 @@ void ones(Tensor t) {
|
|||||||
void randreal(Tensor t) {
|
void randreal(Tensor t) {
|
||||||
std::random_device device;
|
std::random_device device;
|
||||||
std::default_random_engine engine(device());
|
std::default_random_engine engine(device());
|
||||||
std::uniform_real_distribution<> dist(0, 1);
|
std::uniform_real_distribution<> dist(0, 0.1);
|
||||||
auto gen = std::bind(dist, engine);
|
auto gen = std::bind(dist, engine);
|
||||||
|
|
||||||
std::vector<float> vals(t.size());
|
std::vector<float> vals(t.size());
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#include "marian.h"
|
#include "marian.h"
|
||||||
#include "mnist.h"
|
#include "mnist.h"
|
||||||
#include "npz_converter.h"
|
#include "npz_converter.h"
|
||||||
|
#include "param_initializers.h"
|
||||||
|
|
||||||
using namespace marian;
|
using namespace marian;
|
||||||
using namespace keywords;
|
using namespace keywords;
|
||||||
@ -39,12 +40,8 @@ int main(int argc, char** argv) {
|
|||||||
auto b = param(shape={1, LABEL_SIZE},
|
auto b = param(shape={1, LABEL_SIZE},
|
||||||
init=[bData](Tensor t) { t.set(bData); });
|
init=[bData](Tensor t) { t.set(bData); });
|
||||||
|
|
||||||
auto zd = dot(x, w);
|
auto predict = softmax(dot(x, w) + b, axis=1);
|
||||||
auto z = zd + b;
|
auto graph = -mean(sum(y * log(predict), axis=1), axis=0);
|
||||||
auto predict = softmax(z, axis=1);
|
|
||||||
auto logp = log(predict);
|
|
||||||
auto cost = sum(y * logp, axis=1);
|
|
||||||
auto graph = -mean(cost, axis=0);
|
|
||||||
|
|
||||||
std::cerr << "Done." << std::endl;
|
std::cerr << "Done." << std::endl;
|
||||||
|
|
||||||
@ -56,34 +53,33 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
graph.forward(BATCH_SIZE);
|
graph.forward(BATCH_SIZE);
|
||||||
|
|
||||||
for (size_t j = 0; j < 10; ++j) {
|
float eta = 0.1;
|
||||||
|
for (size_t j = 0; j < 100; ++j) {
|
||||||
for(size_t i = 0; i < 60; ++i) {
|
for(size_t i = 0; i < 60; ++i) {
|
||||||
graph.backward();
|
graph.backward();
|
||||||
|
|
||||||
auto update_rule = _1 -= 0.1 * _2;
|
auto update_rule = _1 -= eta * _2;
|
||||||
Element(update_rule, w.val(), w.grad());
|
Element(update_rule, w.val(), w.grad());
|
||||||
Element(update_rule, b.val(), b.grad());
|
Element(update_rule, b.val(), b.grad());
|
||||||
|
|
||||||
graph.forward(BATCH_SIZE);
|
graph.forward(BATCH_SIZE);
|
||||||
}
|
}
|
||||||
std::cerr << "Epoch: " << j << std::endl;
|
std::cerr << "Epoch: " << j << std::endl;
|
||||||
}
|
auto results = predict.val();
|
||||||
|
std::vector<float> resultsv(results.size());
|
||||||
auto results = predict.val();
|
resultsv << results;
|
||||||
std::vector<float> resultsv(results.size());
|
|
||||||
resultsv << results;
|
size_t acc = 0;
|
||||||
|
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
|
||||||
size_t acc = 0;
|
size_t correct = 0;
|
||||||
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
|
size_t predicted = 0;
|
||||||
size_t correct = 0;
|
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
||||||
size_t predicted = 0;
|
if (testLabels[i+j]) correct = j;
|
||||||
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
if (resultsv[i + j] > resultsv[i + predicted]) predicted = j;
|
||||||
if (testLabels[i+j]) correct = j;
|
}
|
||||||
if (resultsv[i + j] > resultsv[i + predicted]) predicted = j;
|
acc += (correct == predicted);
|
||||||
}
|
}
|
||||||
acc += (correct == predicted);
|
std::cerr << "Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
|
||||||
}
|
}
|
||||||
std::cerr << "Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user