mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
batches
This commit is contained in:
parent
a573eecf5c
commit
f0f0dbe9ee
@ -80,12 +80,12 @@ void Tensor::Load(const std::string &path)
|
||||
}
|
||||
strm.close();
|
||||
|
||||
Load(hostData);
|
||||
Load(hostData.begin(), hostData.begin());
|
||||
}
|
||||
|
||||
void Tensor::Load(const std::vector<float> &values)
|
||||
void Tensor::Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end)
|
||||
{
|
||||
pimpl_->set(values);
|
||||
pimpl_->set(begin, end);
|
||||
}
|
||||
|
||||
}
|
||||
|
10
src/tensor.h
10
src/tensor.h
@ -158,11 +158,11 @@ class TensorImpl {
|
||||
thrust::fill(data_.begin(), data_.end(), value);
|
||||
}
|
||||
|
||||
void set(const std::vector<Float> &values) {
|
||||
void set(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end) {
|
||||
size_t totSize = GetTotalSize(shape());
|
||||
std::cerr << "tensor size=" << totSize << " vector size=" << values.size() << std::endl;
|
||||
assert(totSize == values.size());
|
||||
thrust::copy(values.begin(), values.end(), data_.begin());
|
||||
//std::cerr << "tensor size=" << totSize << " vector size=" << values.size() << std::endl;
|
||||
//assert(totSize == values.size());
|
||||
thrust::copy(begin, end, data_.begin());
|
||||
}
|
||||
|
||||
std::string Debug() const
|
||||
@ -275,7 +275,7 @@ class Tensor {
|
||||
}
|
||||
|
||||
void Load(const std::string &path);
|
||||
void Load(const std::vector<float> &values);
|
||||
void Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end);
|
||||
|
||||
};
|
||||
|
||||
|
21
src/test.cu
21
src/test.cu
@ -12,6 +12,7 @@ int main(int argc, char** argv) {
|
||||
using namespace marian;
|
||||
using namespace keywords;
|
||||
|
||||
const size_t BATCH_SIZE = 500;
|
||||
const size_t IMAGE_SIZE = 784;
|
||||
const size_t LABEL_SIZE = 10;
|
||||
|
||||
@ -31,16 +32,22 @@ int main(int argc, char** argv) {
|
||||
//vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
||||
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
||||
vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
||||
|
||||
|
||||
cerr << "images=" << images.size() << " labels=" << labels.size() << endl;
|
||||
cerr << "numofdata=" << numofdata << endl;
|
||||
|
||||
size_t startInd = 0;
|
||||
size_t startIndData = 0;
|
||||
while (startInd < numofdata) {
|
||||
size_t batchSize = (startInd + BATCH_SIZE < numofdata) ? BATCH_SIZE : numofdata - startInd;
|
||||
cerr << "startInd=" << startInd
|
||||
<< " startIndData=" << startIndData
|
||||
<< " batchSize=" << batchSize << endl;
|
||||
|
||||
Tensor tx({numofdata, IMAGE_SIZE}, 1);
|
||||
Tensor ty({numofdata, LABEL_SIZE}, 1);
|
||||
|
||||
tx.Load(images);
|
||||
ty.Load(labels);
|
||||
tx.Load(images.begin() + startIndData, images.begin() + startIndData + batchSize * IMAGE_SIZE);
|
||||
ty.Load(labels.begin() + startInd, labels.begin() + startInd + batchSize);
|
||||
|
||||
//cerr << "tx=" << Debug(tx.shape()) << endl;
|
||||
//cerr << "ty=" << Debug(ty.shape()) << endl;
|
||||
@ -52,7 +59,7 @@ int main(int argc, char** argv) {
|
||||
cerr << "y=" << Debug(y.val().shape()) << endl;
|
||||
|
||||
|
||||
graph.forward(500);
|
||||
graph.forward(batchSize);
|
||||
|
||||
cerr << "w=" << Debug(w.val().shape()) << endl;
|
||||
cerr << "b=" << Debug(b.val().shape()) << endl;
|
||||
@ -67,6 +74,10 @@ int main(int argc, char** argv) {
|
||||
|
||||
//std::cerr << graph["pred"].val()[0] << std::endl;
|
||||
|
||||
startInd += batchSize;
|
||||
startIndData += batchSize * IMAGE_SIZE;
|
||||
}
|
||||
|
||||
|
||||
// XOR
|
||||
/*
|
||||
|
Loading…
Reference in New Issue
Block a user