mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
toy training with initialized data
This commit is contained in:
parent
8ea40b587b
commit
4bafa6c360
@ -142,14 +142,9 @@ struct ArgmaxOp : public UnaryNodeOp {
|
||||
struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
SoftmaxNodeOp(ChainPtr a, Args ...args)
|
||||
: UnaryNodeOp(a, keywords::shape=newShape(a),
|
||||
: UnaryNodeOp(a, keywords::shape=a->shape(),
|
||||
args...) { }
|
||||
|
||||
Shape newShape(ChainPtr a) {
|
||||
Shape shape = a->shape();
|
||||
return shape;
|
||||
}
|
||||
|
||||
void forward() {
|
||||
// B = softmax(A).
|
||||
val_ = a_->val();
|
||||
|
@ -8,7 +8,7 @@ using namespace keywords;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
||||
cudaSetDevice(1);
|
||||
cudaSetDevice(0);
|
||||
|
||||
const size_t IMAGE_SIZE = 784;
|
||||
const size_t LABEL_SIZE = 10;
|
||||
@ -55,14 +55,18 @@ int main(int argc, char** argv) {
|
||||
y = yt << testLabels;
|
||||
|
||||
graph.forward(BATCH_SIZE);
|
||||
for(size_t i = 0; i < 1000; ++i) {
|
||||
graph.backward();
|
||||
|
||||
auto update_rule = _1 -= 0.1 * _2;
|
||||
Element(update_rule, w.val(), w.grad());
|
||||
Element(update_rule, b.val(), b.grad());
|
||||
for (size_t j = 0; j < 10; ++j) {
|
||||
for(size_t i = 0; i < 60; ++i) {
|
||||
graph.backward();
|
||||
|
||||
graph.forward(BATCH_SIZE);
|
||||
auto update_rule = _1 -= 0.1 * _2;
|
||||
Element(update_rule, w.val(), w.grad());
|
||||
Element(update_rule, b.val(), b.grad());
|
||||
|
||||
graph.forward(BATCH_SIZE);
|
||||
}
|
||||
std::cerr << "Epoch: " << j << std::endl;
|
||||
}
|
||||
|
||||
auto results = predict.val();
|
||||
|
Loading…
Reference in New Issue
Block a user