mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-19 02:37:14 +03:00
more tests
This commit is contained in:
parent
83f84ae4ee
commit
e46210b58b
@ -93,6 +93,33 @@ void construct(ExpressionGraphPtr g,
|
||||
return RNN<GRUFast>(encParams);
|
||||
};
|
||||
|
||||
auto buildDecoderGRU = [=](){
|
||||
std::string prefix = "decoder";
|
||||
auto U = g->param(prefix + "_U", {dimEncState, 2 * dimEncState},
|
||||
init=glorot_uniform);
|
||||
|
||||
auto W = g->param(prefix + "_W", {dimEncState, 2 * dimEncState},
|
||||
init=glorot_uniform);
|
||||
|
||||
auto b = g->param(prefix + "_b", {1, 2 * dimEncState}, init=zeros);
|
||||
|
||||
auto Ux = g->param(prefix + "_Ux", {dimEncState, dimEncState},
|
||||
init=glorot_uniform);
|
||||
|
||||
auto Wx = g->param(prefix + "_Wx", {dimEncState, dimEncState},
|
||||
init=glorot_uniform);
|
||||
|
||||
auto bx = g->param(prefix + "_bx", {1, dimEncState}, init=zeros);
|
||||
|
||||
ParametersGRUFast encParams;
|
||||
encParams.U = concatenate({U, Ux}, 1);
|
||||
encParams.W = concatenate({W, Wx}, 1);
|
||||
encParams.b = concatenate({b, bx}, 1);
|
||||
|
||||
return RNN<GRUFast>(encParams);
|
||||
};
|
||||
|
||||
|
||||
auto encStartState = name(g->zeros(shape={dimBatch, dimEncState}), "start");
|
||||
|
||||
auto encForward = buildEncoderGRU("encoder");
|
||||
@ -103,17 +130,27 @@ void construct(ExpressionGraphPtr g,
|
||||
auto statesBackward = encBackward.apply(inputs.rbegin(), inputs.rend(),
|
||||
encStartState);
|
||||
|
||||
std::vector<Expr> joinedStates;
|
||||
auto itFw = statesForward.begin();
|
||||
auto itBw = statesBackward.rbegin();
|
||||
while(itFw != statesForward.end()) {
|
||||
// add proper axes
|
||||
joinedStates.push_back(concatenate({*itFw++, *itBw++}, 1));
|
||||
}
|
||||
//std::vector<Expr> joinedStates;
|
||||
//auto itFw = statesForward.begin();
|
||||
//auto itBw = statesBackward.rbegin();
|
||||
//while(itFw != statesForward.end()) {
|
||||
// // add proper axes
|
||||
// joinedStates.push_back(concatenate({*itFw++, *itBw++}, 1));
|
||||
//}
|
||||
//
|
||||
//// add proper axes and make this a 3D tensor
|
||||
//auto encContext = name(concatenate(joinedStates, 2), "context");
|
||||
//
|
||||
//auto decStartState = mean(encContext, axis=2);
|
||||
|
||||
auto newStart = statesForward.back() + statesBackward.back();
|
||||
|
||||
auto dec1 = buildEncoderGRU("decoder1");
|
||||
auto states1 = dec1.apply(inputs.begin(), inputs.end(),
|
||||
encStartState);
|
||||
newStart);
|
||||
auto dec2 = buildDecoderGRU();
|
||||
auto states2 = dec2.apply(states1.begin(), states1.end(),
|
||||
newStart);
|
||||
|
||||
auto Wi = g->param("Wi", {dimEncState, 85000},
|
||||
init=glorot_uniform);
|
||||
@ -121,7 +158,7 @@ void construct(ExpressionGraphPtr g,
|
||||
init=zeros);
|
||||
|
||||
Expr total;
|
||||
for(auto h : states1) {
|
||||
for(auto h : states2) {
|
||||
auto cost = mean(sum(softmax(dot(h, Wi) + bi), axis=1), axis=0);
|
||||
if(total)
|
||||
total = total + cost;
|
||||
@ -129,10 +166,7 @@ void construct(ExpressionGraphPtr g,
|
||||
total = cost;
|
||||
}
|
||||
|
||||
// add proper axes and make this a 3D tensor
|
||||
//auto encContext = name(concatenate(joinedStates, 2), "context");
|
||||
|
||||
//auto decStartState = mean(encContext, axis=2);
|
||||
}
|
||||
|
||||
SentBatch generateBatch(size_t batchSize) {
|
||||
@ -156,7 +190,7 @@ int main(int argc, char** argv) {
|
||||
auto g = New<ExpressionGraph>();
|
||||
load(g, "/home/marcinj/Badania/amunmt/test2/model.npz");
|
||||
|
||||
size_t batchSize = 40;
|
||||
size_t batchSize = 20;
|
||||
|
||||
boost::timer::cpu_timer timer;
|
||||
for(int i = 1; i <= 1000; ++i) {
|
||||
@ -167,8 +201,7 @@ int main(int argc, char** argv) {
|
||||
construct(g, batch);
|
||||
|
||||
g->forward();
|
||||
//exit(0);
|
||||
//g->backward();
|
||||
g->backward();
|
||||
if(i % 100 == 0)
|
||||
std::cout << i << std::endl;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user