mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Created setGrad() function.
This commit is contained in:
parent
55fa57c762
commit
12295b9a72
@ -49,10 +49,13 @@ struct Chainable {
|
||||
|
||||
virtual const Shape& shape() = 0;
|
||||
virtual const DataType &val() = 0;
|
||||
virtual DataType grad() = 0;
|
||||
virtual const DataType &grad() = 0;
|
||||
virtual void setVal(DataType t) {
|
||||
UTIL_THROW2("Tensors can only be assigned to input and parameter nodes");
|
||||
};
|
||||
virtual void setGrad(DataType t) {
|
||||
UTIL_THROW2("Gradients can only be assigned to parameter nodes");
|
||||
};
|
||||
};
|
||||
|
||||
// XXX Marcin, is ChainableStack the most appropriate name?
|
||||
|
@ -29,7 +29,7 @@ Expr::Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable)
|
||||
graph_->stack()->push_back(chainable);
|
||||
}
|
||||
|
||||
Tensor Expr::val() {
|
||||
const Tensor &Expr::val() {
|
||||
return pimpl_->val();
|
||||
}
|
||||
|
||||
@ -37,8 +37,12 @@ void Expr::setVal(const Tensor &val) {
|
||||
pimpl_->setVal(val);
|
||||
}
|
||||
|
||||
Tensor Expr::grad() {
|
||||
return pimpl_->grad();
|
||||
const Tensor &Expr::grad() {
|
||||
return pimpl_->grad();
|
||||
}
|
||||
|
||||
void Expr::setGrad(const Tensor &grad) {
|
||||
pimpl_->setGrad(grad);
|
||||
}
|
||||
|
||||
ChainPtr Expr::node() {
|
||||
|
@ -45,10 +45,11 @@ class Expr {
|
||||
return *this;
|
||||
}
|
||||
|
||||
Tensor val();
|
||||
Tensor grad();
|
||||
const Tensor &val();
|
||||
const Tensor &grad();
|
||||
|
||||
void setVal(const Tensor &val);
|
||||
void setGrad(const Tensor &grad);
|
||||
|
||||
ExpressionGraphPtr graph();
|
||||
|
||||
|
@ -80,7 +80,7 @@ class Node : public Chainable<Tensor>,
|
||||
return val_;
|
||||
};
|
||||
|
||||
virtual Tensor grad() {
|
||||
virtual const Tensor &grad() {
|
||||
UTIL_THROW_IF2(!adj_, "Tensor has not been allocated");
|
||||
return adj_;
|
||||
};
|
||||
|
@ -92,6 +92,12 @@ struct ParamNode : public Node {
|
||||
//@todo, shape checking
|
||||
};
|
||||
|
||||
virtual void setGrad(Tensor t) {
|
||||
adj_ = t;
|
||||
shape_ = t.shape();
|
||||
//@todo, shape checking
|
||||
};
|
||||
|
||||
void forward() {}
|
||||
void backward() {}
|
||||
|
||||
|
@ -224,9 +224,11 @@ int main(int argc, char** argv) {
|
||||
//ExpressionGraph g = graphs[b];
|
||||
ExpressionGraph g = graphs[b0];
|
||||
// Share the parameters.
|
||||
std::cerr << j << std::endl;
|
||||
if (false && b != b0) {
|
||||
for (int i = 0; i < g.params().size(); ++i) {
|
||||
g.params()[i].setVal(graphs[b0].params()[i].val());
|
||||
g.params()[i].setGrad(graphs[b0].params()[i].grad());
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user