fixed begin/end in Tensor class

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-14 23:41:44 +02:00
parent 73e1d5f96a
commit cc4e24b3a6
2 changed files with 5 additions and 11 deletions

View File

@ -178,12 +178,12 @@ class Tensor {
return pimpl_->begin(); return pimpl_->begin();
} }
auto end() -> decltype( pimpl_->begin() ) { auto end() -> decltype( pimpl_->end() ) {
return pimpl_->begin(); return pimpl_->end();
} }
auto end() const -> decltype( pimpl_->begin() ) { auto end() const -> decltype( pimpl_->end() ) {
return pimpl_->begin(); return pimpl_->end();
} }
const Shape& shape() const { const Shape& shape() const {

View File

@ -27,9 +27,9 @@ int main(int argc, char** argv) {
Shape wShape, bShape; Shape wShape, bShape;
converter.Load("weights", wData, wShape); converter.Load("weights", wData, wShape);
converter.Load("bias", bData, bShape); converter.Load("bias", bData, bShape);
std::cerr << "Done." << std::endl; std::cerr << "Done." << std::endl;
std::cerr << "Building model...";
auto x = input(shape={whatevs, IMAGE_SIZE}, name="X"); auto x = input(shape={whatevs, IMAGE_SIZE}, name="X");
auto y = input(shape={whatevs, LABEL_SIZE}, name="Y"); auto y = input(shape={whatevs, LABEL_SIZE}, name="Y");
@ -38,7 +38,6 @@ int main(int argc, char** argv) {
auto b = param(shape={1, LABEL_SIZE}, name="b0", auto b = param(shape={1, LABEL_SIZE}, name="b0",
init=[bData](Tensor t) {t.set(bData); }); init=[bData](Tensor t) {t.set(bData); });
std::cerr << "Building model...";
auto predict = softmax(dot(x, w) + b, auto predict = softmax(dot(x, w) + b,
axis=1, name="pred"); axis=1, name="pred");
auto graph = -mean(sum(y * log(predict), axis=1), auto graph = -mean(sum(y * log(predict), axis=1),
@ -69,11 +68,6 @@ int main(int argc, char** argv) {
if (resultsv[i + j] > resultsv[i + predicted]) predicted = j; if (resultsv[i + j] > resultsv[i + predicted]) predicted = j;
} }
acc += (correct == predicted); acc += (correct == predicted);
//std::cerr << correct << " | " << predicted << " ( ";
//for (size_t j = 0; j < LABEL_SIZE; ++j) {
// std::cerr << resultsv[i+j] << " ";
//}
//std::cerr << ")" << std::endl;
} }
std::cerr << "Accuracy: " << float(acc)/BATCH_SIZE << std::endl; std::cerr << "Accuracy: " << float(acc)/BATCH_SIZE << std::endl;