Created setGrad() function.

This commit is contained in:
Andre Martins 2016-09-22 17:48:20 +01:00
parent 55fa57c762
commit 12295b9a72
6 changed files with 23 additions and 7 deletions

View File

@ -49,10 +49,13 @@ struct Chainable {
virtual const Shape& shape() = 0;
virtual const DataType &val() = 0;
virtual DataType grad() = 0;
virtual const DataType &grad() = 0;
virtual void setVal(DataType t) {
UTIL_THROW2("Tensors can only be assigned to input and parameter nodes");
};
virtual void setGrad(DataType t) {
UTIL_THROW2("Gradients can only be assigned to parameter nodes");
};
};
// XXX Marcin, is ChainableStack the most appropriate name?

View File

@ -29,7 +29,7 @@ Expr::Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable)
graph_->stack()->push_back(chainable);
}
Tensor Expr::val() {
const Tensor &Expr::val() {
return pimpl_->val();
}
@ -37,8 +37,12 @@ void Expr::setVal(const Tensor &val) {
pimpl_->setVal(val);
}
Tensor Expr::grad() {
return pimpl_->grad();
const Tensor &Expr::grad() {
return pimpl_->grad();
}
void Expr::setGrad(const Tensor &grad) {
pimpl_->setGrad(grad);
}
ChainPtr Expr::node() {

View File

@ -45,10 +45,11 @@ class Expr {
return *this;
}
Tensor val();
Tensor grad();
const Tensor &val();
const Tensor &grad();
void setVal(const Tensor &val);
void setGrad(const Tensor &grad);
ExpressionGraphPtr graph();

View File

@ -80,7 +80,7 @@ class Node : public Chainable<Tensor>,
return val_;
};
virtual Tensor grad() {
virtual const Tensor &grad() {
UTIL_THROW_IF2(!adj_, "Tensor has not been allocated");
return adj_;
};

View File

@ -92,6 +92,12 @@ struct ParamNode : public Node {
//@todo, shape checking
};
virtual void setGrad(Tensor t) {
adj_ = t;
shape_ = t.shape();
//@todo, shape checking
};
void forward() {}
void backward() {}

View File

@ -224,9 +224,11 @@ int main(int argc, char** argv) {
//ExpressionGraph g = graphs[b];
ExpressionGraph g = graphs[b0];
// Share the parameters.
std::cerr << j << std::endl;
if (false && b != b0) {
for (int i = 0; i < g.params().size(); ++i) {
g.params()[i].setVal(graphs[b0].params()[i].val());
g.params()[i].setGrad(graphs[b0].params()[i].grad());
}
}