mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
tidied up initializers
This commit is contained in:
parent
2583e0fa42
commit
499faceb8e
@ -10,25 +10,42 @@
|
||||
namespace marian {
|
||||
|
||||
void zeros(Tensor t) {
|
||||
std::vector<float> vals(t.size(), 0.0f);
|
||||
thrust::copy(vals.begin(), vals.end(), t.begin());
|
||||
t.set(0.f);
|
||||
}
|
||||
|
||||
void ones(Tensor t) {
|
||||
std::vector<float> vals(t.size(), 1.0f);
|
||||
thrust::copy(vals.begin(), vals.end(), t.begin());
|
||||
t.set(1.0f);
|
||||
}
|
||||
|
||||
void randreal(Tensor t) {
|
||||
template <class Distribution>
|
||||
void distribution(Tensor t, float a=0.0, float b=0.1) {
|
||||
std::random_device device;
|
||||
std::default_random_engine engine(device());
|
||||
std::uniform_real_distribution<> dist(0, 0.1);
|
||||
Distribution dist(a, b);
|
||||
auto gen = std::bind(dist, engine);
|
||||
|
||||
std::vector<float> vals(t.size());
|
||||
std::generate(begin(vals), end(vals), gen);
|
||||
|
||||
thrust::copy(vals.begin(), vals.end(), t.begin());
|
||||
t << vals;
|
||||
}
|
||||
|
||||
std::function<void(Tensor)> normal(float mean = 0.0, float std = 0.1) {
|
||||
return [mean, std](Tensor t) {
|
||||
distribution<std::normal_distribution<float>>(t, mean, std);
|
||||
};
|
||||
}
|
||||
|
||||
std::function<void(Tensor)> uniform(float a = 0.0, float b = 0.1) {
|
||||
return [a, b](Tensor t) {
|
||||
distribution<std::uniform_real_distribution<float>>(t, a, b);
|
||||
};
|
||||
}
|
||||
|
||||
std::function<void(Tensor)> from_vector(const std::vector<float>& v) {
|
||||
return [&v](Tensor t) {
|
||||
t << v;
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace marian
|
||||
|
@ -9,7 +9,7 @@ using namespace keywords;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
||||
cudaSetDevice(0);
|
||||
cudaSetDevice(1);
|
||||
|
||||
const size_t IMAGE_SIZE = 784;
|
||||
const size_t LABEL_SIZE = 10;
|
||||
@ -20,7 +20,6 @@ int main(int argc, char** argv) {
|
||||
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
|
||||
std::cerr << "Loading model params...";
|
||||
NpzConverter converter("../scripts/test_model/model.npz");
|
||||
|
||||
@ -36,9 +35,9 @@ int main(int argc, char** argv) {
|
||||
auto y = input(shape={whatevs, LABEL_SIZE});
|
||||
|
||||
auto w = param(shape={IMAGE_SIZE, LABEL_SIZE},
|
||||
init=[wData](Tensor t) { t.set(wData); });
|
||||
init=from_vector(wData));
|
||||
auto b = param(shape={1, LABEL_SIZE},
|
||||
init=[bData](Tensor t) { t.set(bData); });
|
||||
init=from_vector(bData));
|
||||
|
||||
auto probs = softmax(dot(x, w) + b, axis=1);
|
||||
auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
|
||||
|
Loading…
Reference in New Issue
Block a user