This commit is contained in:
Tomasz Dwojak 2016-09-14 16:00:56 +01:00
commit 74626d347f
3 changed files with 36 additions and 3 deletions

View File

@ -19,11 +19,11 @@ struct Chainable {
virtual const Shape& shape() = 0;
virtual DataType &val() = 0;
virtual DataType grad() = 0;
virtual void setVal(Tensor t) {
virtual void setVal(DataType t) {
UTIL_THROW2("Tensors can only be assigned to input nodes");
};
typedef std::vector<Chainable<Tensor>*> ChainableStack;
typedef std::vector<Chainable<DataType>*> ChainableStack;
static ChainableStack stack;
};

View File

@ -101,6 +101,39 @@ struct TanhNodeOp : public UnaryNodeOp {
}
};
struct ArgmaxOp : public UnaryNodeOp {
template <typename ...Args>
ArgmaxOp(ChainPtr a, Args ...args)
: UnaryNodeOp(a, keywords::shape=newShape(a, -1), args...),
axis_(-1) { }
Shape newShape(ChainPtr a, int axis) {
Shape shape1 = a->shape();
UTIL_THROW_IF2(shape1.size() > 2,
"Tensors with more than 2 dimensions not supported yet");
if(axis == 0) {
shape1[0] = 1;
}
else if(axis == 1) {
shape1[1] = 1;
}
else {
shape1 = {1, 1};
}
return shape1;
}
void forward() {
//val_ = Argmax(a_->val(), axis_);
}
void backward() {}
private:
int axis_;
};
struct SoftmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
SoftmaxNodeOp(ChainPtr a, Args ...args)

View File

@ -68,7 +68,7 @@ int main(int argc, char** argv) {
std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ;
//std::cerr << "scores=" << scores.val().Debug() << endl;
std::cerr << "lr=" << lr.val().Debug() << endl;
//std::cerr << "lr=" << lr.val().Debug() << endl;
graph.backward();