mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
tidied up initializers
This commit is contained in:
parent
2583e0fa42
commit
499faceb8e
@ -10,25 +10,42 @@
|
|||||||
namespace marian {
|
namespace marian {
|
||||||
|
|
||||||
void zeros(Tensor t) {
|
void zeros(Tensor t) {
|
||||||
std::vector<float> vals(t.size(), 0.0f);
|
t.set(0.f);
|
||||||
thrust::copy(vals.begin(), vals.end(), t.begin());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ones(Tensor t) {
|
void ones(Tensor t) {
|
||||||
std::vector<float> vals(t.size(), 1.0f);
|
t.set(1.0f);
|
||||||
thrust::copy(vals.begin(), vals.end(), t.begin());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void randreal(Tensor t) {
|
template <class Distribution>
|
||||||
|
void distribution(Tensor t, float a=0.0, float b=0.1) {
|
||||||
std::random_device device;
|
std::random_device device;
|
||||||
std::default_random_engine engine(device());
|
std::default_random_engine engine(device());
|
||||||
std::uniform_real_distribution<> dist(0, 0.1);
|
Distribution dist(a, b);
|
||||||
auto gen = std::bind(dist, engine);
|
auto gen = std::bind(dist, engine);
|
||||||
|
|
||||||
std::vector<float> vals(t.size());
|
std::vector<float> vals(t.size());
|
||||||
std::generate(begin(vals), end(vals), gen);
|
std::generate(begin(vals), end(vals), gen);
|
||||||
|
|
||||||
thrust::copy(vals.begin(), vals.end(), t.begin());
|
t << vals;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::function<void(Tensor)> normal(float mean = 0.0, float std = 0.1) {
|
||||||
|
return [mean, std](Tensor t) {
|
||||||
|
distribution<std::normal_distribution<float>>(t, mean, std);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<void(Tensor)> uniform(float a = 0.0, float b = 0.1) {
|
||||||
|
return [a, b](Tensor t) {
|
||||||
|
distribution<std::uniform_real_distribution<float>>(t, a, b);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<void(Tensor)> from_vector(const std::vector<float>& v) {
|
||||||
|
return [&v](Tensor t) {
|
||||||
|
t << v;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace marian
|
} // namespace marian
|
||||||
|
@ -9,7 +9,7 @@ using namespace keywords;
|
|||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
|
|
||||||
cudaSetDevice(0);
|
cudaSetDevice(1);
|
||||||
|
|
||||||
const size_t IMAGE_SIZE = 784;
|
const size_t IMAGE_SIZE = 784;
|
||||||
const size_t LABEL_SIZE = 10;
|
const size_t LABEL_SIZE = 10;
|
||||||
@ -20,7 +20,6 @@ int main(int argc, char** argv) {
|
|||||||
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
|
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
|
||||||
std::cerr << "Done." << std::endl;
|
std::cerr << "Done." << std::endl;
|
||||||
|
|
||||||
|
|
||||||
std::cerr << "Loading model params...";
|
std::cerr << "Loading model params...";
|
||||||
NpzConverter converter("../scripts/test_model/model.npz");
|
NpzConverter converter("../scripts/test_model/model.npz");
|
||||||
|
|
||||||
@ -36,9 +35,9 @@ int main(int argc, char** argv) {
|
|||||||
auto y = input(shape={whatevs, LABEL_SIZE});
|
auto y = input(shape={whatevs, LABEL_SIZE});
|
||||||
|
|
||||||
auto w = param(shape={IMAGE_SIZE, LABEL_SIZE},
|
auto w = param(shape={IMAGE_SIZE, LABEL_SIZE},
|
||||||
init=[wData](Tensor t) { t.set(wData); });
|
init=from_vector(wData));
|
||||||
auto b = param(shape={1, LABEL_SIZE},
|
auto b = param(shape={1, LABEL_SIZE},
|
||||||
init=[bData](Tensor t) { t.set(bData); });
|
init=from_vector(bData));
|
||||||
|
|
||||||
auto probs = softmax(dot(x, w) + b, axis=1);
|
auto probs = softmax(dot(x, w) + b, axis=1);
|
||||||
auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
|
auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
|
||||||
|
Loading…
Reference in New Issue
Block a user