mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
Merge ../Marian
This commit is contained in:
commit
98298c52ab
@ -5,4 +5,5 @@
|
||||
#include "graph_operators.h"
|
||||
#include "expressions.h"
|
||||
#include "expression_operators.h"
|
||||
#include "param_initializers.h"
|
||||
|
||||
|
34
src/param_initializers.h
Normal file
34
src/param_initializers.h
Normal file
@ -0,0 +1,34 @@
|
||||
#pragma once
|
||||
|
||||
#include <random>
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <functional>
|
||||
|
||||
#include "tensor.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
void zeros(Tensor t) {
|
||||
std::vector<float> vals(t.size(), 0.0f);
|
||||
thrust::copy(vals.begin(), vals.end(), t.begin());
|
||||
}
|
||||
|
||||
void ones(Tensor t) {
|
||||
std::vector<float> vals(t.size(), 1.0f);
|
||||
thrust::copy(vals.begin(), vals.end(), t.begin());
|
||||
}
|
||||
|
||||
void randreal(Tensor t) {
|
||||
std::random_device device;
|
||||
std::default_random_engine engine(device());
|
||||
std::uniform_real_distribution<> dist(0, 1);
|
||||
auto gen = std::bind(dist, engine);
|
||||
|
||||
std::vector<float> vals(t.size());
|
||||
std::generate(begin(vals), end(vals), gen);
|
||||
|
||||
thrust::copy(vals.begin(), vals.end(), t.begin());
|
||||
}
|
||||
|
||||
} // namespace marian
|
@ -20,6 +20,7 @@ int main(int argc, char** argv) {
|
||||
Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y");
|
||||
|
||||
Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
|
||||
// Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=randreal);
|
||||
Expr b = param(shape={1, LABEL_SIZE}, name="b0");
|
||||
|
||||
Expr z = dot(x, w) + b;
|
||||
|
Loading…
Reference in New Issue
Block a user