return reference

This commit is contained in:
Hieu Hoang 2016-09-14 15:16:12 +02:00
parent f05d17e7ae
commit 6974ceb9d1
3 changed files with 5 additions and 5 deletions

View File

@ -10,7 +10,7 @@ Expr::Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
Expr::Expr(Float v) : pimpl_(new ConstantNode(keywords::value=v, Expr::Expr(Float v) : pimpl_(new ConstantNode(keywords::value=v,
keywords::shape={1,1})) {} keywords::shape={1,1})) {}
Tensor Expr::val() { Tensor &Expr::val() {
return pimpl_->val(); return pimpl_->val();
} }

View File

@ -15,7 +15,7 @@ class Expr {
return *this; return *this;
} }
Tensor val(); Tensor &val();
Tensor grad(); Tensor grad();
void forward(size_t batchSize); void forward(size_t batchSize);

View File

@ -17,7 +17,7 @@ struct Chainable {
virtual void allocate(size_t) = 0; virtual void allocate(size_t) = 0;
virtual const Shape& shape() = 0; virtual const Shape& shape() = 0;
virtual DataType val() = 0; virtual DataType &val() = 0;
virtual DataType grad() = 0; virtual DataType grad() = 0;
virtual void setVal(Tensor t) { virtual void setVal(Tensor t) {
UTIL_THROW2("Tensors can only be assigned to input nodes"); UTIL_THROW2("Tensors can only be assigned to input nodes");
@ -82,7 +82,7 @@ class Node : public Chainable<Tensor>,
} }
} }
virtual Tensor val() { virtual Tensor &val() {
UTIL_THROW_IF2(!val_, "Tensor has not been allocated"); UTIL_THROW_IF2(!val_, "Tensor has not been allocated");
return val_; return val_;
}; };
@ -104,4 +104,4 @@ class Node : public Chainable<Tensor>,
Tensor adj_; Tensor adj_;
}; };
} }