mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
operators
This commit is contained in:
parent
7ef5061d8c
commit
aab15d66e6
@ -15,5 +15,16 @@ void Tensor::set(const std::vector<float>::const_iterator &begin, const std::vec
|
||||
pimpl_->set(begin, end);
|
||||
}
|
||||
|
||||
Tensor& operator<<(Tensor& t, const std::vector<float> &vec) {
|
||||
t.set(vec);
|
||||
return t;
|
||||
}
|
||||
|
||||
std::vector<float>& operator<<(std::vector<float> &vec, const Tensor& t) {
|
||||
t.get(vec);
|
||||
return vec;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
@ -218,13 +218,17 @@ class Tensor {
|
||||
void set(const std::vector<float>& data);
|
||||
void set(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end);
|
||||
|
||||
void get(std::vector<float>::iterator out) {
|
||||
void get(std::vector<float>::iterator out) const {
|
||||
pimpl_->get(out);
|
||||
}
|
||||
|
||||
void get(std::vector<float> &vout) {
|
||||
void get(std::vector<float> &vout) const {
|
||||
pimpl_->get(vout.begin());
|
||||
}
|
||||
};
|
||||
|
||||
Tensor& operator<<(Tensor& t, const std::vector<float> &vec);
|
||||
|
||||
std::vector<float>& operator<<(std::vector<float> &vec, const Tensor& t);
|
||||
|
||||
}
|
||||
|
@ -31,11 +31,11 @@ int main(int argc, char** argv) {
|
||||
converter.Load("bias", bData, bShape);
|
||||
|
||||
auto initW = [wData](Tensor t) {
|
||||
t.set(wData.begin(), wData.end());
|
||||
t.set(wData);
|
||||
};
|
||||
|
||||
auto initB = [bData](Tensor t) {
|
||||
t.set(bData.begin(), bData.end());
|
||||
t.set(bData);
|
||||
};
|
||||
|
||||
std::cerr << "\tDone." << std::endl;
|
||||
@ -56,20 +56,17 @@ int main(int argc, char** argv) {
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
Tensor xt({numofdata, IMAGE_SIZE});
|
||||
xt.set(testImages);
|
||||
|
||||
Tensor yt({numofdata, LABEL_SIZE});
|
||||
yt.set(testLabels);
|
||||
|
||||
x = xt;
|
||||
y = yt;
|
||||
x = xt << testImages;
|
||||
y = yt << testLabels;
|
||||
|
||||
graph.forward(numofdata);
|
||||
auto results = predict.val();
|
||||
graph.backward();
|
||||
|
||||
std::vector<float> resultsv(results.size());
|
||||
results.get(resultsv);
|
||||
resultsv << results;
|
||||
|
||||
std::cerr << b.grad().Debug() << std::endl;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user