toy training with initialized data

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-15 00:55:24 +02:00
parent 8ea40b587b
commit 4bafa6c360
2 changed files with 12 additions and 13 deletions

View File

@ -142,13 +142,8 @@ struct ArgmaxOp : public UnaryNodeOp {
struct SoftmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
SoftmaxNodeOp(ChainPtr a, Args ...args)
: UnaryNodeOp(a, keywords::shape=newShape(a),
: UnaryNodeOp(a, keywords::shape=a->shape(),
args...) { }
Shape newShape(ChainPtr a) {
Shape shape = a->shape();
return shape;
}
void forward() {
// B = softmax(A).

View File

@ -8,7 +8,7 @@ using namespace keywords;
int main(int argc, char** argv) {
cudaSetDevice(1);
cudaSetDevice(0);
const size_t IMAGE_SIZE = 784;
const size_t LABEL_SIZE = 10;
@ -55,14 +55,18 @@ int main(int argc, char** argv) {
y = yt << testLabels;
graph.forward(BATCH_SIZE);
for(size_t i = 0; i < 1000; ++i) {
graph.backward();
auto update_rule = _1 -= 0.1 * _2;
Element(update_rule, w.val(), w.grad());
Element(update_rule, b.val(), b.grad());
for (size_t j = 0; j < 10; ++j) {
for(size_t i = 0; i < 60; ++i) {
graph.backward();
graph.forward(BATCH_SIZE);
auto update_rule = _1 -= 0.1 * _2;
Element(update_rule, w.val(), w.grad());
Element(update_rule, b.val(), b.grad());
graph.forward(BATCH_SIZE);
}
std::cerr << "Epoch: " << j << std::endl;
}
auto results = predict.val();