diff --git a/src/param_initializers.h b/src/param_initializers.h index 5a04a25c..04c6b48e 100644 --- a/src/param_initializers.h +++ b/src/param_initializers.h @@ -10,25 +10,42 @@ namespace marian { void zeros(Tensor t) { - std::vector vals(t.size(), 0.0f); - thrust::copy(vals.begin(), vals.end(), t.begin()); + t.set(0.f); } void ones(Tensor t) { - std::vector vals(t.size(), 1.0f); - thrust::copy(vals.begin(), vals.end(), t.begin()); + t.set(1.0f); } -void randreal(Tensor t) { +template +void distribution(Tensor t, float a=0.0, float b=0.1) { std::random_device device; std::default_random_engine engine(device()); - std::uniform_real_distribution<> dist(0, 0.1); + Distribution dist(a, b); auto gen = std::bind(dist, engine); std::vector vals(t.size()); std::generate(begin(vals), end(vals), gen); - thrust::copy(vals.begin(), vals.end(), t.begin()); + t << vals; } +std::function normal(float mean = 0.0, float std = 0.1) { + return [mean, std](Tensor t) { + distribution>(t, mean, std); + }; +} + +std::function uniform(float a = 0.0, float b = 0.1) { + return [a, b](Tensor t) { + distribution>(t, a, b); + }; +} + +std::function from_vector(const std::vector& v) { + return [&v](Tensor t) { + t << v; + }; +} + } // namespace marian diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index e9b5735d..7d812e36 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -9,7 +9,7 @@ using namespace keywords; int main(int argc, char** argv) { - cudaSetDevice(0); + cudaSetDevice(1); const size_t IMAGE_SIZE = 784; const size_t LABEL_SIZE = 10; @@ -20,7 +20,6 @@ int main(int argc, char** argv) { std::vector testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE); std::cerr << "Done." << std::endl; - std::cerr << "Loading model params..."; NpzConverter converter("../scripts/test_model/model.npz"); @@ -36,9 +35,9 @@ int main(int argc, char** argv) { auto y = input(shape={whatevs, LABEL_SIZE}); auto w = param(shape={IMAGE_SIZE, LABEL_SIZE}, - init=[wData](Tensor t) { t.set(wData); }); + init=from_vector(wData)); auto b = param(shape={1, LABEL_SIZE}, - init=[bData](Tensor t) { t.set(bData); }); + init=from_vector(bData)); auto probs = softmax(dot(x, w) + b, axis=1); auto cost = -mean(sum(y * log(probs), axis=1), axis=0);