more tests

This commit is contained in:
Marcin Junczys-Dowmunt 2016-11-18 01:07:06 +01:00
parent 83f84ae4ee
commit e46210b58b

View File

@ -93,6 +93,33 @@ void construct(ExpressionGraphPtr g,
return RNN<GRUFast>(encParams);
};
auto buildDecoderGRU = [=](){
std::string prefix = "decoder";
auto U = g->param(prefix + "_U", {dimEncState, 2 * dimEncState},
init=glorot_uniform);
auto W = g->param(prefix + "_W", {dimEncState, 2 * dimEncState},
init=glorot_uniform);
auto b = g->param(prefix + "_b", {1, 2 * dimEncState}, init=zeros);
auto Ux = g->param(prefix + "_Ux", {dimEncState, dimEncState},
init=glorot_uniform);
auto Wx = g->param(prefix + "_Wx", {dimEncState, dimEncState},
init=glorot_uniform);
auto bx = g->param(prefix + "_bx", {1, dimEncState}, init=zeros);
ParametersGRUFast encParams;
encParams.U = concatenate({U, Ux}, 1);
encParams.W = concatenate({W, Wx}, 1);
encParams.b = concatenate({b, bx}, 1);
return RNN<GRUFast>(encParams);
};
auto encStartState = name(g->zeros(shape={dimBatch, dimEncState}), "start");
auto encForward = buildEncoderGRU("encoder");
@ -103,17 +130,27 @@ void construct(ExpressionGraphPtr g,
auto statesBackward = encBackward.apply(inputs.rbegin(), inputs.rend(),
encStartState);
std::vector<Expr> joinedStates;
auto itFw = statesForward.begin();
auto itBw = statesBackward.rbegin();
while(itFw != statesForward.end()) {
// add proper axes
joinedStates.push_back(concatenate({*itFw++, *itBw++}, 1));
}
//std::vector<Expr> joinedStates;
//auto itFw = statesForward.begin();
//auto itBw = statesBackward.rbegin();
//while(itFw != statesForward.end()) {
// // add proper axes
// joinedStates.push_back(concatenate({*itFw++, *itBw++}, 1));
//}
//
//// add proper axes and make this a 3D tensor
//auto encContext = name(concatenate(joinedStates, 2), "context");
//
//auto decStartState = mean(encContext, axis=2);
auto newStart = statesForward.back() + statesBackward.back();
auto dec1 = buildEncoderGRU("decoder1");
auto states1 = dec1.apply(inputs.begin(), inputs.end(),
encStartState);
newStart);
auto dec2 = buildDecoderGRU();
auto states2 = dec2.apply(states1.begin(), states1.end(),
newStart);
auto Wi = g->param("Wi", {dimEncState, 85000},
init=glorot_uniform);
@ -121,7 +158,7 @@ void construct(ExpressionGraphPtr g,
init=zeros);
Expr total;
for(auto h : states1) {
for(auto h : states2) {
auto cost = mean(sum(softmax(dot(h, Wi) + bi), axis=1), axis=0);
if(total)
total = total + cost;
@ -129,10 +166,7 @@ void construct(ExpressionGraphPtr g,
total = cost;
}
// add proper axes and make this a 3D tensor
//auto encContext = name(concatenate(joinedStates, 2), "context");
//auto decStartState = mean(encContext, axis=2);
}
SentBatch generateBatch(size_t batchSize) {
@ -156,7 +190,7 @@ int main(int argc, char** argv) {
auto g = New<ExpressionGraph>();
load(g, "/home/marcinj/Badania/amunmt/test2/model.npz");
size_t batchSize = 40;
size_t batchSize = 20;
boost::timer::cpu_timer timer;
for(int i = 1; i <= 1000; ++i) {
@ -167,8 +201,7 @@ int main(int argc, char** argv) {
construct(g, batch);
g->forward();
//exit(0);
//g->backward();
g->backward();
if(i % 100 == 0)
std::cout << i << std::endl;
}