diff --git a/src/expressions.cu b/src/expressions.cu index a95b1bef..2d656ce1 100644 --- a/src/expressions.cu +++ b/src/expressions.cu @@ -10,7 +10,7 @@ Expr::Expr(Chainable* chainable) : pimpl_(chainable) {} Expr::Expr(Float v) : pimpl_(new ConstantNode(keywords::value=v, keywords::shape={1,1})) {} -Tensor Expr::val() { +Tensor &Expr::val() { return pimpl_->val(); } diff --git a/src/expressions.h b/src/expressions.h index d7945f07..09d0edfa 100644 --- a/src/expressions.h +++ b/src/expressions.h @@ -15,7 +15,7 @@ class Expr { return *this; } - Tensor val(); + Tensor &val(); Tensor grad(); void forward(size_t batchSize); diff --git a/src/graph.h b/src/graph.h index 15b4721d..33de8a5e 100644 --- a/src/graph.h +++ b/src/graph.h @@ -17,7 +17,7 @@ struct Chainable { virtual void allocate(size_t) = 0; virtual const Shape& shape() = 0; - virtual DataType val() = 0; + virtual DataType &val() = 0; virtual DataType grad() = 0; virtual void setVal(Tensor t) { UTIL_THROW2("Tensors can only be assigned to input nodes"); @@ -82,7 +82,7 @@ class Node : public Chainable, } } - virtual Tensor val() { + virtual Tensor &val() { UTIL_THROW_IF2(!val_, "Tensor has not been allocated"); return val_; }; @@ -104,4 +104,4 @@ class Node : public Chainable, Tensor adj_; }; -} \ No newline at end of file +}