More commenting

This commit is contained in:
Lane Schwartz 2016-09-18 15:55:04 -05:00
parent 10ee031e61
commit 2c24cb6827

View File

@ -60,15 +60,43 @@ class Expr {
ChainPtr pimpl_;
};
/**
* @brief Represents a computation graph of expressions, over which algorithmic differentiation may be performed.
*/
class ExpressionGraph {
public:
/** @brief Constructs a new expression graph */
ExpressionGraph() : stack_(new ChainableStack) {}
/**
* @brief Performs backpropogation on this expression graph.
*
* Backpropogation is implemented by performing first the forward pass
* and then the backward pass of algorithmic differentiation (AD) on the nodes of the graph.
*
* @param batchSize XXX Marcin, could you provide a description of this param?
*/
void backprop(int batchSize) {
forward(batchSize);
backward();
}
/**
* @brief Perform the forward pass of algorithmic differentiation (AD) on this graph.
*
* This pass traverses the nodes of this graph in the order they were created;
* as each node is traversed, its <code>allocate()</code> method is called.
*
* Once allocation is complete for all nodes, this pass again traverses the nodes, in creation order;
* as each node is traversed, its <code>forward()</code> method is called.
*
* After this method has successfully completed,
* it is guaranteed that all node allocation has been completed,
* and that all forward pass computations have been performed.
*
* @param batchSize XXX Marcin, could you provide a description of this param?
*/
void forward(int batchSize) {
for(auto&& v : *stack_) {
v->allocate(batchSize);
@ -76,7 +104,19 @@ class ExpressionGraph {
for(auto&& v : *stack_)
v->forward();
}
/**
* @brief Perform the backward pass of algorithmic differentiation (AD) on this graph.
*
* This pass traverses the nodes of this graph in reverse of the order they were created;
* as each node is traversed, its <code>set_zero_adjoint()</code> method is called.
*
* Once this has been performed for all nodes, this pass again traverses the nodes, again in reverse creation order;
* as each node is traversed, its <code>backward()</code> method is called.
*
* After this method has successfully completed,
* and that all backward pass computations have been performed.
*/
void backward() {
for(auto&& v : *stack_)
v->set_zero_adjoint();
@ -86,7 +126,14 @@ class ExpressionGraph {
for(It it = stack_->rbegin(); it != stack_->rend(); ++it)
(*it)->backward();
}
/**
* @brief Returns a string representing this expression graph in <code>graphviz</code> notation.
*
* This string can be used by <code>graphviz</code> tools to visualize the expression graph.
*
* @return a string representing this expression graph in <code>graphviz</code> notation
*/
std::string graphviz() {
std::stringstream ss;
ss << "digraph ExpressionGraph {" << std::endl;
@ -248,6 +295,8 @@ class ExpressionGraph {
}
private:
/** @brief Pointer to the list of nodes */
ChainableStackPtr stack_;
/** @brief Maps from name to expression node. */