return reference

This commit is contained in:
Hieu Hoang 2016-09-14 15:16:12 +02:00
parent f05d17e7ae
commit 6974ceb9d1
3 changed files with 5 additions and 5 deletions

View File

@ -10,7 +10,7 @@ Expr::Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
Expr::Expr(Float v) : pimpl_(new ConstantNode(keywords::value=v,
keywords::shape={1,1})) {}
Tensor Expr::val() {
Tensor &Expr::val() {
return pimpl_->val();
}

View File

@ -15,7 +15,7 @@ class Expr {
return *this;
}
Tensor val();
Tensor &val();
Tensor grad();
void forward(size_t batchSize);

View File

@ -17,7 +17,7 @@ struct Chainable {
virtual void allocate(size_t) = 0;
virtual const Shape& shape() = 0;
virtual DataType val() = 0;
virtual DataType &val() = 0;
virtual DataType grad() = 0;
virtual void setVal(Tensor t) {
UTIL_THROW2("Tensors can only be assigned to input nodes");
@ -82,7 +82,7 @@ class Node : public Chainable<Tensor>,
}
}
virtual Tensor val() {
virtual Tensor &val() {
UTIL_THROW_IF2(!val_, "Tensor has not been allocated");
return val_;
};
@ -104,4 +104,4 @@ class Node : public Chainable<Tensor>,
Tensor adj_;
};
}
}