operators

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-14 23:31:11 +02:00
parent 7ef5061d8c
commit aab15d66e6
3 changed files with 22 additions and 10 deletions

View File

@ -15,5 +15,16 @@ void Tensor::set(const std::vector<float>::const_iterator &begin, const std::vec
pimpl_->set(begin, end);
}
Tensor& operator<<(Tensor& t, const std::vector<float> &vec) {
t.set(vec);
return t;
}
std::vector<float>& operator<<(std::vector<float> &vec, const Tensor& t) {
t.get(vec);
return vec;
}
}

View File

@ -218,13 +218,17 @@ class Tensor {
void set(const std::vector<float>& data);
void set(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end);
void get(std::vector<float>::iterator out) {
void get(std::vector<float>::iterator out) const {
pimpl_->get(out);
}
void get(std::vector<float> &vout) {
void get(std::vector<float> &vout) const {
pimpl_->get(vout.begin());
}
};
Tensor& operator<<(Tensor& t, const std::vector<float> &vec);
std::vector<float>& operator<<(std::vector<float> &vec, const Tensor& t);
}

View File

@ -31,11 +31,11 @@ int main(int argc, char** argv) {
converter.Load("bias", bData, bShape);
auto initW = [wData](Tensor t) {
t.set(wData.begin(), wData.end());
t.set(wData);
};
auto initB = [bData](Tensor t) {
t.set(bData.begin(), bData.end());
t.set(bData);
};
std::cerr << "\tDone." << std::endl;
@ -56,20 +56,17 @@ int main(int argc, char** argv) {
std::cerr << "Done." << std::endl;
Tensor xt({numofdata, IMAGE_SIZE});
xt.set(testImages);
Tensor yt({numofdata, LABEL_SIZE});
yt.set(testLabels);
x = xt;
y = yt;
x = xt << testImages;
y = yt << testLabels;
graph.forward(numofdata);
auto results = predict.val();
graph.backward();
std::vector<float> resultsv(results.size());
results.get(resultsv);
resultsv << results;
std::cerr << b.grad().Debug() << std::endl;