mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
More commenting
This commit is contained in:
parent
10ee031e61
commit
2c24cb6827
@ -60,15 +60,43 @@ class Expr {
|
||||
ChainPtr pimpl_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Represents a computation graph of expressions, over which algorithmic differentiation may be performed.
|
||||
*/
|
||||
class ExpressionGraph {
|
||||
public:
|
||||
|
||||
/** @brief Constructs a new expression graph */
|
||||
ExpressionGraph() : stack_(new ChainableStack) {}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Performs backpropogation on this expression graph.
|
||||
*
|
||||
* Backpropogation is implemented by performing first the forward pass
|
||||
* and then the backward pass of algorithmic differentiation (AD) on the nodes of the graph.
|
||||
*
|
||||
* @param batchSize XXX Marcin, could you provide a description of this param?
|
||||
*/
|
||||
void backprop(int batchSize) {
|
||||
forward(batchSize);
|
||||
backward();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Perform the forward pass of algorithmic differentiation (AD) on this graph.
|
||||
*
|
||||
* This pass traverses the nodes of this graph in the order they were created;
|
||||
* as each node is traversed, its <code>allocate()</code> method is called.
|
||||
*
|
||||
* Once allocation is complete for all nodes, this pass again traverses the nodes, in creation order;
|
||||
* as each node is traversed, its <code>forward()</code> method is called.
|
||||
*
|
||||
* After this method has successfully completed,
|
||||
* it is guaranteed that all node allocation has been completed,
|
||||
* and that all forward pass computations have been performed.
|
||||
*
|
||||
* @param batchSize XXX Marcin, could you provide a description of this param?
|
||||
*/
|
||||
void forward(int batchSize) {
|
||||
for(auto&& v : *stack_) {
|
||||
v->allocate(batchSize);
|
||||
@ -76,7 +104,19 @@ class ExpressionGraph {
|
||||
for(auto&& v : *stack_)
|
||||
v->forward();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Perform the backward pass of algorithmic differentiation (AD) on this graph.
|
||||
*
|
||||
* This pass traverses the nodes of this graph in reverse of the order they were created;
|
||||
* as each node is traversed, its <code>set_zero_adjoint()</code> method is called.
|
||||
*
|
||||
* Once this has been performed for all nodes, this pass again traverses the nodes, again in reverse creation order;
|
||||
* as each node is traversed, its <code>backward()</code> method is called.
|
||||
*
|
||||
* After this method has successfully completed,
|
||||
* and that all backward pass computations have been performed.
|
||||
*/
|
||||
void backward() {
|
||||
for(auto&& v : *stack_)
|
||||
v->set_zero_adjoint();
|
||||
@ -86,7 +126,14 @@ class ExpressionGraph {
|
||||
for(It it = stack_->rbegin(); it != stack_->rend(); ++it)
|
||||
(*it)->backward();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Returns a string representing this expression graph in <code>graphviz</code> notation.
|
||||
*
|
||||
* This string can be used by <code>graphviz</code> tools to visualize the expression graph.
|
||||
*
|
||||
* @return a string representing this expression graph in <code>graphviz</code> notation
|
||||
*/
|
||||
std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "digraph ExpressionGraph {" << std::endl;
|
||||
@ -248,6 +295,8 @@ class ExpressionGraph {
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
/** @brief Pointer to the list of nodes */
|
||||
ChainableStackPtr stack_;
|
||||
|
||||
/** @brief Maps from name to expression node. */
|
||||
|
Loading…
Reference in New Issue
Block a user