diff --git a/src/tensor.cu b/src/tensor.cu
index c5619b99..398b696a 100644
--- a/src/tensor.cu
+++ b/src/tensor.cu
@@ -80,12 +80,12 @@ void Tensor::Load(const std::string &path)
   }
   strm.close();
 
-  Load(hostData);
+  Load(hostData.begin(), hostData.end());  // whole buffer; begin()..begin() would be an empty range
 }
 
-void Tensor::Load(const std::vector &values)
+void Tensor::Load(const std::vector::const_iterator &begin, const std::vector::const_iterator &end)
 {
-  pimpl_->set(values);
+  pimpl_->set(begin, end);
 }
 
 }
diff --git a/src/tensor.h b/src/tensor.h
index 83965508..d6acea11 100644
--- a/src/tensor.h
+++ b/src/tensor.h
@@ -158,11 +158,11 @@ class TensorImpl {
       thrust::fill(data_.begin(), data_.end(), value);
     }
 
-    void set(const std::vector &values) {
+    void set(const std::vector::const_iterator &begin, const std::vector::const_iterator &end) {
       size_t totSize = GetTotalSize(shape());
-      std::cerr << "tensor size=" << totSize << " vector size=" << values.size() << std::endl;
-      assert(totSize == values.size());
-      thrust::copy(values.begin(), values.end(), data_.begin());
+      // Callers may pass fewer values than the tensor holds; reject a range longer than the tensor.
+      assert(static_cast<size_t>(end - begin) <= totSize);
+      thrust::copy(begin, end, data_.begin());
     }
 
     std::string Debug() const
@@ -275,7 +275,7 @@ class Tensor {
   }
 
   void Load(const std::string &path);
-  void Load(const std::vector &values);
+  void Load(const std::vector::const_iterator &begin, const std::vector::const_iterator &end);
 
 };
 
diff --git a/src/test.cu b/src/test.cu
index 777b4b39..a78e182f 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -12,6 +12,7 @@ int main(int argc, char** argv) {
 
   using namespace marian;
   using namespace keywords;
+  const size_t BATCH_SIZE = 500;
   const size_t IMAGE_SIZE = 784;
   const size_t LABEL_SIZE = 10;
 
@@ -31,41 +32,51 @@ int main(int argc, char** argv) {
   //vector labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
   vector images = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", numofdata, IMAGE_SIZE);
   vector labels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", numofdata, LABEL_SIZE);
-
-  cerr << "images=" << images.size() << " labels=" << labels.size() << endl;
   cerr << "numofdata=" << numofdata << endl;
 
-  Tensor tx({numofdata, IMAGE_SIZE}, 1);
-  Tensor ty({numofdata, LABEL_SIZE}, 1);
+  size_t startInd = 0;
+  size_t startIndData = 0;
+  while (startInd < numofdata) {
+    size_t batchSize = (startInd + BATCH_SIZE < numofdata) ? BATCH_SIZE : numofdata - startInd;
+    cerr << "startInd=" << startInd
+         << " startIndData=" << startIndData
+         << " batchSize=" << batchSize << endl;
 
-  tx.Load(images);
-  ty.Load(labels);
+    Tensor tx({numofdata, IMAGE_SIZE}, 1);  // NOTE(review): still numofdata rows although only batchSize rows are loaded - confirm
+    Tensor ty({numofdata, LABEL_SIZE}, 1);
 
-  //cerr << "tx=" << Debug(tx.shape()) << endl;
-  //cerr << "ty=" << Debug(ty.shape()) << endl;
+    tx.Load(images.begin() + startIndData, images.begin() + startIndData + batchSize * IMAGE_SIZE);
+    ty.Load(labels.begin() + startInd * LABEL_SIZE, labels.begin() + (startInd + batchSize) * LABEL_SIZE);  // labels used to fill a {numofdata, LABEL_SIZE} tensor, so each example spans LABEL_SIZE values
 
-  x = tx;
-  y = ty;
+    //cerr << "tx=" << Debug(tx.shape()) << endl;
+    //cerr << "ty=" << Debug(ty.shape()) << endl;
 
-  cerr << "x=" << Debug(x.val().shape()) << endl;
-  cerr << "y=" << Debug(y.val().shape()) << endl;
+    x = tx;
+    y = ty;
+
+    cerr << "x=" << Debug(x.val().shape()) << endl;
+    cerr << "y=" << Debug(y.val().shape()) << endl;
 
-  graph.forward(500);
+    graph.forward(batchSize);
 
-  cerr << "w=" << Debug(w.val().shape()) << endl;
-  cerr << "b=" << Debug(b.val().shape()) << endl;
-  std::cerr << "z: " << Debug(z.val().shape()) << endl;
-  std::cerr << "lr: " << Debug(lr.val().shape()) << endl;
-  std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ;
+    cerr << "w=" << Debug(w.val().shape()) << endl;
+    cerr << "b=" << Debug(b.val().shape()) << endl;
+    std::cerr << "z: " << Debug(z.val().shape()) << endl;
+    std::cerr << "lr: " << Debug(lr.val().shape()) << endl;
+    std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ;
 
-  //std::cerr << "scores=" << scores.val().Debug() << endl;
-  std::cerr << "lr=" << lr.val().Debug() << endl;
+    //std::cerr << "scores=" << scores.val().Debug() << endl;
+    std::cerr << "lr=" << lr.val().Debug() << endl;
 
-  graph.backward();
-
-  //std::cerr << graph["pred"].val()[0] << std::endl;
+    graph.backward();
+
+    //std::cerr << graph["pred"].val()[0] << std::endl;
+
+    startInd += batchSize;
+    startIndData += batchSize * IMAGE_SIZE;
+  }
 
   // XOR