mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merge ../Marian.hieu
This commit is contained in:
commit
f0340804f5
@ -34,7 +34,7 @@ struct Chainable {
|
||||
virtual ~Chainable() { }
|
||||
virtual void forward() { }
|
||||
virtual void backward() { }
|
||||
virtual void backward_numeric(Float delta) { }
|
||||
virtual void backward_debug(Float delta) { }
|
||||
|
||||
virtual void check() { }
|
||||
virtual void init_dependent() { }
|
||||
|
@ -127,7 +127,7 @@ class ExpressionGraph {
|
||||
(*it)->backward();
|
||||
}
|
||||
|
||||
void backward_numeric(Float delta) {
|
||||
void backward_debug(Float delta) {
|
||||
for(auto&& v : *stack_)
|
||||
v->set_zero_adjoint();
|
||||
|
||||
@ -136,7 +136,7 @@ class ExpressionGraph {
|
||||
for(It it = stack_->rbegin(); it != stack_->rend(); ++it) {
|
||||
Chainable<Tensor> *chainable = *it;
|
||||
//chainable->backward();
|
||||
chainable->backward_numeric(delta);
|
||||
chainable->backward_debug(delta);
|
||||
}
|
||||
}
|
||||
|
||||
|
36
src/node.h
36
src/node.h
@ -120,11 +120,21 @@ class Node : public Chainable<Tensor>,
|
||||
const std::vector<float> &prevCalcGrad
|
||||
)
|
||||
{
|
||||
size_t totSize = GetTotalSize(input.shape());
|
||||
using namespace std;
|
||||
|
||||
std::vector<float> diffGrad(totSize);
|
||||
size_t inputSize = GetTotalSize(input.shape());
|
||||
size_t gradSize = GetTotalSize(grad.shape());
|
||||
size_t adjSize = GetTotalSize(adj_.shape());
|
||||
cerr << "sizes: "
|
||||
<< Debug(input.shape())<< "=" << inputSize << " "
|
||||
<< Debug(grad.shape()) << "=" << gradSize << " "
|
||||
<< Debug(adj_.shape()) << "=" << adjSize
|
||||
<< endl;
|
||||
|
||||
std::vector<float> diffGrad(gradSize);
|
||||
thrust::copy(grad.begin(), grad.end(), diffGrad.begin());
|
||||
output("diffGrad", diffGrad);
|
||||
cerr << "diffGrad=" << grad.Debug() << endl;
|
||||
//output("diffGrad", diffGrad);
|
||||
|
||||
// reset grad
|
||||
thrust::copy(prevCalcGrad.begin(), prevCalcGrad.end(), grad.begin());
|
||||
@ -138,7 +148,7 @@ class Node : public Chainable<Tensor>,
|
||||
//cerr << "input=" << input.Debug() << endl;
|
||||
//cerr << "val_=" << val_.Debug() << endl;
|
||||
|
||||
std::vector<float> newVal(totSize);
|
||||
std::vector<float> newVal(inputSize);
|
||||
thrust::copy(val_.begin(), val_.end(), newVal.begin());
|
||||
//output("newVal", newVal);
|
||||
|
||||
@ -149,31 +159,31 @@ class Node : public Chainable<Tensor>,
|
||||
//cerr << "input=" << input.Debug() << endl;
|
||||
//cerr << "val_=" << val_.Debug() << endl;
|
||||
|
||||
std::vector<float> origVal(totSize);
|
||||
std::vector<float> origVal(inputSize);
|
||||
thrust::copy(val_.begin(), val_.end(), origVal.begin());
|
||||
//output("origVal", origVal);
|
||||
|
||||
// calc gradient
|
||||
//cerr << "adj_=" << adj_.Debug() << endl;
|
||||
std::vector<float> adjVec(totSize);
|
||||
std::vector<float> adjVec(adjSize);
|
||||
thrust::copy(adj_.begin(), adj_.end(), adjVec.begin());
|
||||
|
||||
std::vector<float> numericalGrad(totSize);
|
||||
for (size_t i = 0; i < totSize; ++i) {
|
||||
std::vector<float> numericalGrad(gradSize);
|
||||
for (size_t i = 0; i < numericalGrad.size(); ++i) {
|
||||
numericalGrad[i] = prevCalcGrad[i] + (adjVec[i] * (newVal[i] - origVal[i]) / delta);
|
||||
}
|
||||
output("numericalGrad", numericalGrad);
|
||||
//cerr << "numeric a_->grad()=" << a_->grad().Debug() << endl;
|
||||
|
||||
// set grad results
|
||||
thrust::copy(numericalGrad.begin(), numericalGrad.end(), grad.begin());
|
||||
cerr << "numericalGrad=" << grad.Debug() << endl;
|
||||
//output("numericalGrad", numericalGrad);
|
||||
|
||||
// print out diff between diffGrad and numericalGrad
|
||||
std::vector<float> origGrad(totSize);
|
||||
std::vector<float> diff(totSize);
|
||||
std::vector<float> origGrad(gradSize);
|
||||
std::vector<float> diff(gradSize);
|
||||
|
||||
thrust::copy(grad.begin(), grad.end(), origGrad.begin());
|
||||
for (size_t i = 0; i < totSize; ++i) {
|
||||
for (size_t i = 0; i < diff.size(); ++i) {
|
||||
diff[i] = (diffGrad[i] - numericalGrad[i]) ;
|
||||
}
|
||||
output("diff", diff);
|
||||
|
@ -12,7 +12,7 @@ struct BinaryNodeOp : public Node {
|
||||
: Node(args...), a_(a), b_(b) {}
|
||||
|
||||
|
||||
void backward_numeric(Float delta) {
|
||||
void backward_debug(Float delta) {
|
||||
using namespace std;
|
||||
|
||||
cerr << "BinaryNodeOp::" << typeid(*this).name() << "::backward_numeric()" << endl;
|
||||
|
@ -11,7 +11,7 @@ struct UnaryNodeOp : public Node {
|
||||
: Node(keywords::shape=a->shape(), //@TODO: Check keywords?
|
||||
args...), a_(a) {}
|
||||
|
||||
void backward_numeric(Float delta) {
|
||||
void backward_debug(Float delta) {
|
||||
using namespace std;
|
||||
|
||||
cerr << "UnaryNodeOp::" << typeid(*this).name() << "::backward_numeric()" << endl;
|
||||
|
@ -55,7 +55,7 @@ int main(int argc, char** argv)
|
||||
// train
|
||||
g.forward(batch_size);
|
||||
//g.backward();
|
||||
g.backward_numeric(0.00001);
|
||||
g.backward_debug(0.00001);
|
||||
|
||||
std::cout << g.graphviz() << std::endl;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user