mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
fixed begin/end in Tensor class
This commit is contained in:
parent
73e1d5f96a
commit
cc4e24b3a6
@ -178,12 +178,12 @@ class Tensor {
|
|||||||
return pimpl_->begin();
|
return pimpl_->begin();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto end() -> decltype( pimpl_->begin() ) {
|
auto end() -> decltype( pimpl_->end() ) {
|
||||||
return pimpl_->begin();
|
return pimpl_->end();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto end() const -> decltype( pimpl_->begin() ) {
|
auto end() const -> decltype( pimpl_->end() ) {
|
||||||
return pimpl_->begin();
|
return pimpl_->end();
|
||||||
}
|
}
|
||||||
|
|
||||||
const Shape& shape() const {
|
const Shape& shape() const {
|
||||||
|
@ -27,9 +27,9 @@ int main(int argc, char** argv) {
|
|||||||
Shape wShape, bShape;
|
Shape wShape, bShape;
|
||||||
converter.Load("weights", wData, wShape);
|
converter.Load("weights", wData, wShape);
|
||||||
converter.Load("bias", bData, bShape);
|
converter.Load("bias", bData, bShape);
|
||||||
|
|
||||||
std::cerr << "Done." << std::endl;
|
std::cerr << "Done." << std::endl;
|
||||||
|
|
||||||
|
std::cerr << "Building model...";
|
||||||
auto x = input(shape={whatevs, IMAGE_SIZE}, name="X");
|
auto x = input(shape={whatevs, IMAGE_SIZE}, name="X");
|
||||||
auto y = input(shape={whatevs, LABEL_SIZE}, name="Y");
|
auto y = input(shape={whatevs, LABEL_SIZE}, name="Y");
|
||||||
|
|
||||||
@ -38,7 +38,6 @@ int main(int argc, char** argv) {
|
|||||||
auto b = param(shape={1, LABEL_SIZE}, name="b0",
|
auto b = param(shape={1, LABEL_SIZE}, name="b0",
|
||||||
init=[bData](Tensor t) {t.set(bData); });
|
init=[bData](Tensor t) {t.set(bData); });
|
||||||
|
|
||||||
std::cerr << "Building model...";
|
|
||||||
auto predict = softmax(dot(x, w) + b,
|
auto predict = softmax(dot(x, w) + b,
|
||||||
axis=1, name="pred");
|
axis=1, name="pred");
|
||||||
auto graph = -mean(sum(y * log(predict), axis=1),
|
auto graph = -mean(sum(y * log(predict), axis=1),
|
||||||
@ -69,11 +68,6 @@ int main(int argc, char** argv) {
|
|||||||
if (resultsv[i + j] > resultsv[i + predicted]) predicted = j;
|
if (resultsv[i + j] > resultsv[i + predicted]) predicted = j;
|
||||||
}
|
}
|
||||||
acc += (correct == predicted);
|
acc += (correct == predicted);
|
||||||
//std::cerr << correct << " | " << predicted << " ( ";
|
|
||||||
//for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
|
||||||
// std::cerr << resultsv[i+j] << " ";
|
|
||||||
//}
|
|
||||||
//std::cerr << ")" << std::endl;
|
|
||||||
}
|
}
|
||||||
std::cerr << "Accuracy: " << float(acc)/BATCH_SIZE << std::endl;
|
std::cerr << "Accuracy: " << float(acc)/BATCH_SIZE << std::endl;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user