This commit is contained in:
Hieu Hoang 2016-09-21 12:22:53 +01:00
parent 0fe591f64b
commit b5c4895683

View File

@ -15,10 +15,14 @@ void Node::calc_numeric_grad(
size_t inputSize = GetTotalSize(input.shape());
size_t gradSize = GetTotalSize(grad.shape());
size_t adjSize = GetTotalSize(adj_.shape());
size_t valSize = GetTotalSize(val_.shape());
assert(adjSize == valSize);
cerr << "sizes: "
<< Debug(input.shape())<< "=" << inputSize << " "
<< Debug(grad.shape()) << "=" << gradSize << " "
<< Debug(adj_.shape()) << "=" << adjSize
<< Debug(adj_.shape()) << "=" << adjSize << " "
<< Debug(val_.shape()) << "=" << valSize
<< endl;
std::vector<float> diffGrad(gradSize);
@ -30,6 +34,7 @@ void Node::calc_numeric_grad(
thrust::copy(prevCalcGrad.begin(), prevCalcGrad.end(), grad.begin());
//cerr << "reset a_->grad()=" << a_->grad().Debug() << endl;
// START CALC of numerical gradient
// new values
input.incr(delta);
@ -49,7 +54,7 @@ void Node::calc_numeric_grad(
//cerr << "input=" << input.Debug() << endl;
//cerr << "val_=" << val_.Debug() << endl;
std::vector<float> origVal(inputSize);
std::vector<float> origVal(valSize);
thrust::copy(val_.begin(), val_.end(), origVal.begin());
//output("origVal", origVal);