mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merge
This commit is contained in:
commit
74626d347f
@ -19,11 +19,11 @@ struct Chainable {
|
||||
virtual const Shape& shape() = 0;
|
||||
virtual DataType &val() = 0;
|
||||
virtual DataType grad() = 0;
|
||||
virtual void setVal(Tensor t) {
|
||||
virtual void setVal(DataType t) {
|
||||
UTIL_THROW2("Tensors can only be assigned to input nodes");
|
||||
};
|
||||
|
||||
typedef std::vector<Chainable<Tensor>*> ChainableStack;
|
||||
typedef std::vector<Chainable<DataType>*> ChainableStack;
|
||||
static ChainableStack stack;
|
||||
};
|
||||
|
||||
|
@ -101,6 +101,39 @@ struct TanhNodeOp : public UnaryNodeOp {
|
||||
}
|
||||
};
|
||||
|
||||
struct ArgmaxOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
ArgmaxOp(ChainPtr a, Args ...args)
|
||||
: UnaryNodeOp(a, keywords::shape=newShape(a, -1), args...),
|
||||
axis_(-1) { }
|
||||
|
||||
Shape newShape(ChainPtr a, int axis) {
|
||||
Shape shape1 = a->shape();
|
||||
UTIL_THROW_IF2(shape1.size() > 2,
|
||||
"Tensors with more than 2 dimensions not supported yet");
|
||||
if(axis == 0) {
|
||||
shape1[0] = 1;
|
||||
}
|
||||
else if(axis == 1) {
|
||||
shape1[1] = 1;
|
||||
}
|
||||
else {
|
||||
shape1 = {1, 1};
|
||||
}
|
||||
return shape1;
|
||||
}
|
||||
|
||||
void forward() {
|
||||
//val_ = Argmax(a_->val(), axis_);
|
||||
}
|
||||
|
||||
void backward() {}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
};
|
||||
|
||||
|
||||
struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
SoftmaxNodeOp(ChainPtr a, Args ...args)
|
||||
|
@ -68,7 +68,7 @@ int main(int argc, char** argv) {
|
||||
std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ;
|
||||
|
||||
//std::cerr << "scores=" << scores.val().Debug() << endl;
|
||||
std::cerr << "lr=" << lr.val().Debug() << endl;
|
||||
//std::cerr << "lr=" << lr.val().Debug() << endl;
|
||||
|
||||
graph.backward();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user