changes names

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-14 16:48:01 +02:00
parent b40512cb7f
commit 37a73e20be
2 changed files with 48 additions and 11 deletions

View File

@ -101,6 +101,39 @@ struct TanhNodeOp : public UnaryNodeOp {
}
};
struct ArgmaxOp : public UnaryNodeOp {
template <typename ...Args>
ArgmaxOp(ChainPtr a, Args ...args)
: UnaryNodeOp(a, keywords::shape=newShape(a, -1), args...),
axis_(-1) { }
Shape newShape(ChainPtr a, int axis) {
Shape shape1 = a->shape();
UTIL_THROW_IF2(shape1.size() > 2,
"Tensors with more than 2 dimensions not supported yet");
if(axis == 0) {
shape1[0] = 1;
}
else if(axis == 1) {
shape1[1] = 1;
}
else {
shape1 = {1, 1};
}
return shape1;
}
void forward() {
//val_ = Argmax(a_->val(), axis_);
}
void backward() {}
private:
int axis_;
};
struct SoftmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
SoftmaxNodeOp(ChainPtr a, Args ...args)

View File

@ -21,10 +21,14 @@ int main(int argc, char** argv) {
Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
Expr b = param(shape={1, LABEL_SIZE}, name="b0");
auto scores = dot(x, w) + b;
auto lr = softmax_fast(scores, axis=1, name="pred");
auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
cerr << "lr=" << lr.Debug() << endl;
auto z = dot(x, w) + b;
auto pred = softmax(z);
//auto decision = argmax(pred, axis=1);
auto cost = -mean(sum(y * log(pred), axis=1),
axis=0);
cerr << "pred=" << pred.Debug() << endl;
#if 0
int numofdata;
@ -49,27 +53,27 @@ int main(int argc, char** argv) {
x = tx;
y = ty;
graph.forward(500);
cost.forward(500);
std::cerr << "Result: ";
for (auto val : scores.val().shape()) {
for (auto val : pred.val().shape()) {
std::cerr << val << " ";
}
std::cerr << std::endl;
std::cerr << "Result: ";
for (auto val : lr.val().shape()) {
for (auto val : pred.val().shape()) {
std::cerr << val << " ";
}
std::cerr << std::endl;
lr.val().Print();
pred.val().Print();
std::cerr << "Log-likelihood: ";
for (auto val : graph.val().shape()) {
for (auto val : cost.val().shape()) {
std::cerr << val << " ";
}
std::cerr << std::endl;
graph.val().Print();
cost.val().Print();
graph.backward();
cost.backward();
//std::cerr << graph["pred"].val()[0] << std::endl;