small changes to test.cu

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-14 14:43:27 +02:00
parent 803a562d4b
commit 3d177ccc28
3 changed files with 10 additions and 49 deletions

View File

@ -19,11 +19,11 @@ struct Chainable {
virtual const Shape& shape() = 0;
virtual DataType val() = 0;
virtual DataType grad() = 0;
virtual void setVal(Tensor t) {
virtual void setVal(DataType t) {
UTIL_THROW2("Tensors can only be assigned to input nodes");
};
typedef std::vector<Chainable<Tensor>*> ChainableStack;
typedef std::vector<Chainable<DataType>*> ChainableStack;
static ChainableStack stack;
};

View File

@ -8,20 +8,21 @@ int main(int argc, char** argv) {
using namespace marian;
using namespace keywords;
Expr x = input(shape={whatevs, 784}, name="X");
Expr y = input(shape={whatevs, 10}, name="Y");
Expr x = input(name="X");
Expr y = input(name="Y");
Expr w = param(shape={784, 10}, name="W0");
Expr b = param(shape={1, 10}, name="b0");
auto scores = dot(x, w) + b;
auto lr = softmax(scores, axis=1, name="pred");
auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
cerr << "lr=" << lr.Debug() << endl;
Expr pred = softmax(dot(x, w) + b, axis=1);
cerr << "lr=" << pred.Debug() << endl;
Expr graph = -mean(sum(y * log(pred), axis=1),
axis=0, name="cost");
Tensor tx({500, 784}, 1);
Tensor ty({500, 10}, 1);
cerr << "tx=" << tx.Debug() << endl;
cerr << "ty=" << ty.Debug() << endl;
@ -29,47 +30,7 @@ int main(int argc, char** argv) {
y = ty;
graph.forward(500);
std::cerr << "Result: ";
for (auto val : scores.val().shape()) {
std::cerr << val << " ";
}
std::cerr << std::endl;
std::cerr << "Result: ";
for (auto val : lr.val().shape()) {
std::cerr << val << " ";
}
std::cerr << std::endl;
std::cerr << "Log-likelihood: ";
for (auto val : graph.val().shape()) {
std::cerr << val << " ";
}
std::cerr << std::endl;
graph.backward();
//std::cerr << graph["pred"].val()[0] << std::endl;
#if 0
hook0(graph);
graph.autodiff();
std::cerr << graph["cost"].val()[0] << std::endl;
//hook1(graph);
for(auto p : graph.params()) {
auto update = _1 = _1 - alpha * _2;
Element(update, p.val(), p.grad());
}
hook2(graph);
auto opt = adadelta(cost_function=cost,
eta=0.9, gamma=0.1,
set_batch=set,
before_update=before,
after_update=after,
set_valid=valid,
validation_freq=100,
verbose=1, epochs=3, early_stopping=10);
opt.run();
#endif
return 0;
}