tidied up initializers

Marcin Junczys-Dowmunt 2016-09-15 12:23:31 +02:00
parent 2583e0fa42
commit 499faceb8e
2 changed files with 27 additions and 11 deletions

View File

@@ -10,25 +10,42 @@
 namespace marian {
 
 void zeros(Tensor t) {
-  std::vector<float> vals(t.size(), 0.0f);
-  thrust::copy(vals.begin(), vals.end(), t.begin());
+  t.set(0.f);
 }
 
 void ones(Tensor t) {
-  std::vector<float> vals(t.size(), 1.0f);
-  thrust::copy(vals.begin(), vals.end(), t.begin());
+  t.set(1.0f);
 }
 
-void randreal(Tensor t) {
+template <class Distribution>
+void distribution(Tensor t, float a=0.0, float b=0.1) {
   std::random_device device;
   std::default_random_engine engine(device());
-  std::uniform_real_distribution<> dist(0, 0.1);
+  Distribution dist(a, b);
   auto gen = std::bind(dist, engine);
 
   std::vector<float> vals(t.size());
   std::generate(begin(vals), end(vals), gen);
 
-  thrust::copy(vals.begin(), vals.end(), t.begin());
+  t << vals;
 }
 
+std::function<void(Tensor)> normal(float mean = 0.0, float std = 0.1) {
+  return [mean, std](Tensor t) {
+    distribution<std::normal_distribution<float>>(t, mean, std);
+  };
+}
+
+std::function<void(Tensor)> uniform(float a = 0.0, float b = 0.1) {
+  return [a, b](Tensor t) {
+    distribution<std::uniform_real_distribution<float>>(t, a, b);
+  };
+}
+
+std::function<void(Tensor)> from_vector(const std::vector<float>& v) {
+  return [&v](Tensor t) {
+    t << v;
+  };
+}
+
 } // namespace marian
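
The refactored initializers are factories: normal, uniform and from_vector each return a std::function<void(Tensor)> that can be passed to a parameter's init= keyword instead of a hand-written lambda, while zeros and ones now delegate to Tensor::set and both random variants share the single distribution<> template. Note that from_vector captures its vector by reference, so the source vector has to outlive the returned initializer. The generator pattern itself is plain C++11; below is a self-contained sketch of the same idea that fills a std::vector stand-in rather than a marian Tensor (fill_with and the vector stand-in are illustrative, not part of the code base):

#include <algorithm>
#include <functional>
#include <iostream>
#include <random>
#include <vector>

// Same shape as distribution<Distribution> in the diff above, but writing
// into a plain std::vector so the sketch compiles without the Tensor class.
template <class Distribution>
void fill_with(std::vector<float>& t, float a = 0.0f, float b = 0.1f) {
  std::random_device device;
  std::default_random_engine engine(device());
  Distribution dist(a, b);
  auto gen = std::bind(dist, engine);      // gen() draws one sample
  std::generate(t.begin(), t.end(), gen);
}

int main() {
  std::vector<float> weights(5);
  // Mirrors normal(0.0, 0.05): Gaussian with mean 0 and stddev 0.05.
  fill_with<std::normal_distribution<float>>(weights, 0.0f, 0.05f);
  // Mirrors uniform(0.0, 0.1): uniform draws from [0, 0.1).
  fill_with<std::uniform_real_distribution<float>>(weights, 0.0f, 0.1f);
  for (float w : weights)
    std::cout << w << " ";
  std::cout << std::endl;
}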

View File

@@ -9,7 +9,7 @@ using namespace keywords;
 int main(int argc, char** argv) {
-  cudaSetDevice(0);
+  cudaSetDevice(1);
 
   const size_t IMAGE_SIZE = 784;
   const size_t LABEL_SIZE = 10;
@@ -20,7 +20,6 @@ int main(int argc, char** argv) {
   std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
   std::cerr << "Done." << std::endl;
 
   std::cerr << "Loading model params...";
   NpzConverter converter("../scripts/test_model/model.npz");
@@ -36,9 +35,9 @@ int main(int argc, char** argv) {
   auto y = input(shape={whatevs, LABEL_SIZE});
 
   auto w = param(shape={IMAGE_SIZE, LABEL_SIZE},
-                 init=[wData](Tensor t) { t.set(wData); });
+                 init=from_vector(wData));
   auto b = param(shape={1, LABEL_SIZE},
-                 init=[bData](Tensor t) { t.set(bData); });
+                 init=from_vector(bData));
 
   auto probs = softmax(dot(x, w) + b, axis=1);
   auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
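
With initializers expressed as std::function factories, the same init= keyword should also be able to take the random initializers from this commit when parameters are not restored from model.npz. A hypothetical variant of the two param calls above (only normal and zeros come from the initializer file; whether init= accepts a plain void(Tensor) function like zeros directly is an assumption):

// Hypothetical: train from scratch instead of loading wData/bData from the .npz file.
auto w = param(shape={IMAGE_SIZE, LABEL_SIZE},
               init=normal(0.0, 0.05));   // Gaussian init, mean 0, stddev 0.05
auto b = param(shape={1, LABEL_SIZE},
               init=zeros);               // assumption: a plain void(Tensor) function is accepted here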