mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
small changes to test.cu
This commit is contained in:
parent
803a562d4b
commit
3d177ccc28
@ -19,11 +19,11 @@ struct Chainable {
|
|||||||
virtual const Shape& shape() = 0;
|
virtual const Shape& shape() = 0;
|
||||||
virtual DataType val() = 0;
|
virtual DataType val() = 0;
|
||||||
virtual DataType grad() = 0;
|
virtual DataType grad() = 0;
|
||||||
virtual void setVal(Tensor t) {
|
virtual void setVal(DataType t) {
|
||||||
UTIL_THROW2("Tensors can only be assigned to input nodes");
|
UTIL_THROW2("Tensors can only be assigned to input nodes");
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef std::vector<Chainable<Tensor>*> ChainableStack;
|
typedef std::vector<Chainable<DataType>*> ChainableStack;
|
||||||
static ChainableStack stack;
|
static ChainableStack stack;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -111,7 +111,7 @@ class TensorImpl {
|
|||||||
value_type operator[](size_t i) const {
|
value_type operator[](size_t i) const {
|
||||||
return data_[i];
|
return data_[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
auto begin() -> decltype( data_.begin() ) {
|
auto begin() -> decltype( data_.begin() ) {
|
||||||
return data_.begin();
|
return data_.begin();
|
||||||
}
|
}
|
||||||
|
53
src/test.cu
53
src/test.cu
@ -8,20 +8,21 @@ int main(int argc, char** argv) {
|
|||||||
using namespace marian;
|
using namespace marian;
|
||||||
using namespace keywords;
|
using namespace keywords;
|
||||||
|
|
||||||
Expr x = input(shape={whatevs, 784}, name="X");
|
Expr x = input(name="X");
|
||||||
Expr y = input(shape={whatevs, 10}, name="Y");
|
Expr y = input(name="Y");
|
||||||
|
|
||||||
Expr w = param(shape={784, 10}, name="W0");
|
Expr w = param(shape={784, 10}, name="W0");
|
||||||
Expr b = param(shape={1, 10}, name="b0");
|
Expr b = param(shape={1, 10}, name="b0");
|
||||||
|
|
||||||
auto scores = dot(x, w) + b;
|
Expr pred = softmax(dot(x, w) + b, axis=1);
|
||||||
auto lr = softmax(scores, axis=1, name="pred");
|
cerr << "lr=" << pred.Debug() << endl;
|
||||||
auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
|
||||||
cerr << "lr=" << lr.Debug() << endl;
|
|
||||||
|
|
||||||
|
Expr graph = -mean(sum(y * log(pred), axis=1),
|
||||||
|
axis=0, name="cost");
|
||||||
|
|
||||||
Tensor tx({500, 784}, 1);
|
Tensor tx({500, 784}, 1);
|
||||||
Tensor ty({500, 10}, 1);
|
Tensor ty({500, 10}, 1);
|
||||||
|
|
||||||
cerr << "tx=" << tx.Debug() << endl;
|
cerr << "tx=" << tx.Debug() << endl;
|
||||||
cerr << "ty=" << ty.Debug() << endl;
|
cerr << "ty=" << ty.Debug() << endl;
|
||||||
|
|
||||||
@ -29,47 +30,7 @@ int main(int argc, char** argv) {
|
|||||||
y = ty;
|
y = ty;
|
||||||
|
|
||||||
graph.forward(500);
|
graph.forward(500);
|
||||||
|
|
||||||
std::cerr << "Result: ";
|
|
||||||
for (auto val : scores.val().shape()) {
|
|
||||||
std::cerr << val << " ";
|
|
||||||
}
|
|
||||||
std::cerr << std::endl;
|
|
||||||
std::cerr << "Result: ";
|
|
||||||
for (auto val : lr.val().shape()) {
|
|
||||||
std::cerr << val << " ";
|
|
||||||
}
|
|
||||||
std::cerr << std::endl;
|
|
||||||
std::cerr << "Log-likelihood: ";
|
|
||||||
for (auto val : graph.val().shape()) {
|
|
||||||
std::cerr << val << " ";
|
|
||||||
}
|
|
||||||
std::cerr << std::endl;
|
|
||||||
|
|
||||||
graph.backward();
|
graph.backward();
|
||||||
|
|
||||||
//std::cerr << graph["pred"].val()[0] << std::endl;
|
|
||||||
|
|
||||||
#if 0
|
|
||||||
hook0(graph);
|
|
||||||
graph.autodiff();
|
|
||||||
std::cerr << graph["cost"].val()[0] << std::endl;
|
|
||||||
//hook1(graph);
|
|
||||||
for(auto p : graph.params()) {
|
|
||||||
auto update = _1 = _1 - alpha * _2;
|
|
||||||
Element(update, p.val(), p.grad());
|
|
||||||
}
|
|
||||||
hook2(graph);
|
|
||||||
|
|
||||||
auto opt = adadelta(cost_function=cost,
|
|
||||||
eta=0.9, gamma=0.1,
|
|
||||||
set_batch=set,
|
|
||||||
before_update=before,
|
|
||||||
after_update=after,
|
|
||||||
set_valid=valid,
|
|
||||||
validation_freq=100,
|
|
||||||
verbose=1, epochs=3, early_stopping=10);
|
|
||||||
opt.run();
|
|
||||||
#endif
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user