mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
toy training with initialized data
This commit is contained in:
parent
8ea40b587b
commit
4bafa6c360
@ -142,13 +142,8 @@ struct ArgmaxOp : public UnaryNodeOp {
|
|||||||
struct SoftmaxNodeOp : public UnaryNodeOp {
|
struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||||
template <typename ...Args>
|
template <typename ...Args>
|
||||||
SoftmaxNodeOp(ChainPtr a, Args ...args)
|
SoftmaxNodeOp(ChainPtr a, Args ...args)
|
||||||
: UnaryNodeOp(a, keywords::shape=newShape(a),
|
: UnaryNodeOp(a, keywords::shape=a->shape(),
|
||||||
args...) { }
|
args...) { }
|
||||||
|
|
||||||
Shape newShape(ChainPtr a) {
|
|
||||||
Shape shape = a->shape();
|
|
||||||
return shape;
|
|
||||||
}
|
|
||||||
|
|
||||||
void forward() {
|
void forward() {
|
||||||
// B = softmax(A).
|
// B = softmax(A).
|
||||||
|
@ -8,7 +8,7 @@ using namespace keywords;
|
|||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
|
|
||||||
cudaSetDevice(1);
|
cudaSetDevice(0);
|
||||||
|
|
||||||
const size_t IMAGE_SIZE = 784;
|
const size_t IMAGE_SIZE = 784;
|
||||||
const size_t LABEL_SIZE = 10;
|
const size_t LABEL_SIZE = 10;
|
||||||
@ -55,14 +55,18 @@ int main(int argc, char** argv) {
|
|||||||
y = yt << testLabels;
|
y = yt << testLabels;
|
||||||
|
|
||||||
graph.forward(BATCH_SIZE);
|
graph.forward(BATCH_SIZE);
|
||||||
for(size_t i = 0; i < 1000; ++i) {
|
|
||||||
graph.backward();
|
|
||||||
|
|
||||||
auto update_rule = _1 -= 0.1 * _2;
|
for (size_t j = 0; j < 10; ++j) {
|
||||||
Element(update_rule, w.val(), w.grad());
|
for(size_t i = 0; i < 60; ++i) {
|
||||||
Element(update_rule, b.val(), b.grad());
|
graph.backward();
|
||||||
|
|
||||||
graph.forward(BATCH_SIZE);
|
auto update_rule = _1 -= 0.1 * _2;
|
||||||
|
Element(update_rule, w.val(), w.grad());
|
||||||
|
Element(update_rule, b.val(), b.grad());
|
||||||
|
|
||||||
|
graph.forward(BATCH_SIZE);
|
||||||
|
}
|
||||||
|
std::cerr << "Epoch: " << j << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto results = predict.val();
|
auto results = predict.val();
|
||||||
|
Loading…
Reference in New Issue
Block a user