mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
changes names
This commit is contained in:
parent
b40512cb7f
commit
37a73e20be
@ -101,6 +101,39 @@ struct TanhNodeOp : public UnaryNodeOp {
|
||||
}
|
||||
};
|
||||
|
||||
struct ArgmaxOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
ArgmaxOp(ChainPtr a, Args ...args)
|
||||
: UnaryNodeOp(a, keywords::shape=newShape(a, -1), args...),
|
||||
axis_(-1) { }
|
||||
|
||||
Shape newShape(ChainPtr a, int axis) {
|
||||
Shape shape1 = a->shape();
|
||||
UTIL_THROW_IF2(shape1.size() > 2,
|
||||
"Tensors with more than 2 dimensions not supported yet");
|
||||
if(axis == 0) {
|
||||
shape1[0] = 1;
|
||||
}
|
||||
else if(axis == 1) {
|
||||
shape1[1] = 1;
|
||||
}
|
||||
else {
|
||||
shape1 = {1, 1};
|
||||
}
|
||||
return shape1;
|
||||
}
|
||||
|
||||
void forward() {
|
||||
//val_ = Argmax(a_->val(), axis_);
|
||||
}
|
||||
|
||||
void backward() {}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
};
|
||||
|
||||
|
||||
struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
SoftmaxNodeOp(ChainPtr a, Args ...args)
|
||||
|
26
src/test.cu
26
src/test.cu
@ -21,10 +21,14 @@ int main(int argc, char** argv) {
|
||||
Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
|
||||
Expr b = param(shape={1, LABEL_SIZE}, name="b0");
|
||||
|
||||
auto scores = dot(x, w) + b;
|
||||
auto lr = softmax_fast(scores, axis=1, name="pred");
|
||||
auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
||||
cerr << "lr=" << lr.Debug() << endl;
|
||||
auto z = dot(x, w) + b;
|
||||
auto pred = softmax(z);
|
||||
//auto decision = argmax(pred, axis=1);
|
||||
|
||||
auto cost = -mean(sum(y * log(pred), axis=1),
|
||||
axis=0);
|
||||
|
||||
cerr << "pred=" << pred.Debug() << endl;
|
||||
|
||||
#if 0
|
||||
int numofdata;
|
||||
@ -49,27 +53,27 @@ int main(int argc, char** argv) {
|
||||
x = tx;
|
||||
y = ty;
|
||||
|
||||
graph.forward(500);
|
||||
cost.forward(500);
|
||||
|
||||
std::cerr << "Result: ";
|
||||
for (auto val : scores.val().shape()) {
|
||||
for (auto val : pred.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
std::cerr << "Result: ";
|
||||
for (auto val : lr.val().shape()) {
|
||||
for (auto val : pred.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
lr.val().Print();
|
||||
pred.val().Print();
|
||||
std::cerr << "Log-likelihood: ";
|
||||
for (auto val : graph.val().shape()) {
|
||||
for (auto val : cost.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
graph.val().Print();
|
||||
cost.val().Print();
|
||||
|
||||
graph.backward();
|
||||
cost.backward();
|
||||
|
||||
//std::cerr << graph["pred"].val()[0] << std::endl;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user