Roman Grundkiewicz 2017-06-04 15:07:29 +02:00
commit 3cc34f172c
3 changed files with 55 additions and 74 deletions

View File

@@ -10,6 +10,7 @@
#include "graph/expression_graph.h"
#include "graph/expression_operators.h"
#include "layers/generic.h"
#include "layers/attention.h"
namespace marian {
@@ -91,9 +92,6 @@ class RNN : public Layer {
else
return concatenate(states, keywords::axis=2);
}
else if(direction_ == dir::bidirect) {
UTIL_THROW2("Use BiRNN for bidirectional RNNs");
}
else { // assuming dir::forward
auto states = apply(input, state, mask, false);
if(outputLast_)
@@ -179,63 +177,19 @@ class MLRNN : public Layer {
}
};
template <class Cell>
class BiRNN : public Layer {
public:
int layers_;
int dimState_;
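// a pair of multi-layer RNNs over the same input: rnn1_ runs forward, rnn2_ backward (its parameters live under name + "_r")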
Ptr<RNN<Cell>> rnn1_;
Ptr<RNN<Cell>> rnn2_;
template <typename ...Args>
BiRNN(Ptr<ExpressionGraph> graph,
const std::string& name,
int layers,
int dimInput,
int dimState,
Args ...args)
: Layer(name),
layers_{layers},
dimState_{dimState},
rnn1_(New<MLRNN<Cell>>(graph, name, layers, dimInput, dimState,
keywords::direction=dir::forward,
args...)),
rnn2_(New<MLRNN<Cell>>(graph, name + "_r", layers, dimInput, dimState,
keywords::direction=dir::backward,
args...)) {}
template <typename ...Args>
std::vector<Expr> operator()(Expr input, Args ...args) {
Expr mask = Get(keywords::mask, nullptr, args...);
auto statesfw = (*rnn1_)(input);
auto statesbw = (*rnn2_)(input, keywords::mask=mask);
std::vector<Expr> outStates;
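// concatenate the forward and backward states of each layer along the state dimension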
for(int i = 0; i < layers_; ++i)
outStates.push_back(concatenate({statesfw[i], statesbw[i]},
keywords::axis=1));
return outStates;
}
template <typename ...Args>
std::vector<Expr> operator()(Expr input, std::vector<Expr> states, Args ...args) {
Expr mask = Get(keywords::mask, nullptr, args...);
auto statesfw = (*rnn1_)(input, states);
auto statesbw = (*rnn2_)(input, states, keywords::mask=mask);
std::vector<Expr> outStates;
for(int i = 0; i < layers_; ++i)
outStates.push_back(concatenate({statesfw[i], statesbw[i]},
keywords::axis=1));
return outStates;
}
};
/***************************************************************/
class Tanh {
private:
Expr U_, W_, b_;
Expr gamma1_;
Expr gamma2_;
bool layerNorm_;
float dropout_;
Expr dropMaskX_;
Expr dropMaskS_;
public:
template <typename ...Args>
@@ -250,6 +204,21 @@ class Tanh {
keywords::init=inits::glorot_uniform);
b_ = graph->param(prefix + "_b", {1, dimState},
keywords::init=inits::zeros);
layerNorm_ = Get(keywords::normalize, false, args...);
dropout_ = Get(keywords::dropout_prob, 0.0f, args...);
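// dropout masks are created once here and reused at every step, so the same positions are dropped across the whole sequence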
if(dropout_ > 0.0f) {
dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
dropMaskS_ = graph->dropout(dropout_, {1, dimState});
}
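// per-feature scale parameters for layer normalization of the input projection (gamma1_) and the state projection (gamma2_)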
if(layerNorm_) {
gamma1_ = graph->param(prefix + "_gamma1", {1, 3 * dimState},
keywords::init=inits::from_value(1.f));
gamma2_ = graph->param(prefix + "_gamma2", {1, 3 * dimState},
keywords::init=inits::from_value(1.f));
}
}
std::vector<Expr> apply(std::vector<Expr> inputs, std::vector<Expr> states, Expr mask = nullptr) {
@@ -262,18 +231,37 @@ class Tanh {
input = concatenate(inputs, keywords::axis=1);
else
input = inputs.front();
return { dot(input, W_) };
if(dropMaskX_)
input = dropout(input, keywords::mask=dropMaskX_);
auto xW = dot(input, W_);
if(layerNorm_)
xW = layer_norm(xW, gamma1_);
return { xW };
}
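// second stage of the cell: combines the precomputed input projection(s) xWs with the previous state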
std::vector<Expr> applyState(std::vector<Expr> xWs, std::vector<Expr> states, Expr mask = nullptr) {
std::vector<Expr> applyState(std::vector<Expr> xWs, std::vector<Expr> states,
Expr mask = nullptr) {
Expr state;
if(states.size() > 1)
state = concatenate(states, keywords::axis=1);
else
state = states.front();
auto stateDropped = state;
if(dropMaskS_)
stateDropped = dropout(state, keywords::mask=dropMaskS_);
auto sU = dot(stateDropped, U_);
if(layerNorm_)
sU = layer_norm(sU, gamma2_);
auto xW = xWs.front();
auto output = tanh(xW, dot(state, U_), b_);
auto output = tanh(xW, sU, b_);
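// a given mask zeroes out the states of padded positions in the batch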
if(mask)
return {output * mask};
else
@@ -465,4 +453,6 @@ class AttentionCell {
}
};
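// conditional GRU with attention (cGRU): a GRU step, global attention over the source context, then a second GRU step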
typedef AttentionCell<GRU, GlobalAttention, GRU> CGRU;
}

View File

@@ -74,7 +74,7 @@ class EncoderAmun : public EncoderBase {
int dimSrcEmb = options_->get<int>("dim-emb");
int dimEncState = options_->get<int>("dim-rnn");
bool layerNorm = options_->get<bool>("layer-normalization");
UTIL_THROW_IF2(options_->get<int>("layers-enc") > 1,
"--type amun does not currently support multiple encoder layers, use --type s2s");
UTIL_THROW_IF2(options_->get<bool>("skip"),
@@ -87,7 +87,7 @@ class EncoderAmun : public EncoderBase {
Expr x, xMask;
std::tie(x, xMask) = prepareSource(xEmb, batch, batchIdx);
if(dropoutSrc) {
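// word-level dropout over the source: one mask value per source position, broadcast over the remaining dimensions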
int srcWords = x->shape()[2];
auto srcWordDrop = graph->dropout(dropoutSrc, {1, 1, srcWords});
@@ -153,7 +153,7 @@ class DecoderAmun : public DecoderBase {
"--type amun does not currently support multiple decoder layers, use --type s2s");
UTIL_THROW_IF2(options_->get<bool>("tied-embeddings"),
"--type amun does not currently support tied embeddings, use --type s2s");
float dropoutRnn = inference_ ? 0 : options_->get<float>("dropout-rnn");
float dropoutTrg = inference_ ? 0 : options_->get<float>("dropout-trg");
@@ -199,7 +199,7 @@ class DecoderAmun : public DecoderBase {
return New<DecoderStateAmun>(stateOut, logitsOut,
state->getEncoderState());
}
const std::vector<Expr> getAlignments() {
return attention_->getAlignments();
}
@@ -328,7 +328,7 @@ class Amun : public EncoderDecoder<EncoderAmun, DecoderAmun> {
bool saveTranslatorConfig) {
save(graph, name);
if(saveTranslatorConfig) {
YAML::Node amun;
auto vocabs = options_->get<std::vector<std::string>>("vocabs");
@@ -413,7 +413,7 @@ class Amun : public EncoderDecoder<EncoderAmun, DecoderAmun> {
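// the amun decoder apparently expects a decoder_c_tt entry in the model file; a dummy scalar is saved so the parameter set is complete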
float ctt = 0;
shape[0] = 1;
cnpy::npz_save(name, "decoder_c_tt", &ctt, shape, 1, mode);
options_->saveModelParameters(name);
}
};

View File

@@ -5,8 +5,6 @@
namespace marian {
typedef AttentionCell<GRU, GlobalAttention, GRU> CGRU;
class EncoderStateS2S : public EncoderState {
private:
Expr context_;
@@ -264,13 +262,6 @@ class DecoderS2S : public DecoderBase {
};
class S2S : public EncoderDecoder<EncoderS2S, DecoderS2S> {
public:
template <class ...Args>
S2S(Ptr<Config> options, Args ...args)
: EncoderDecoder(options, args...) {}
};
typedef EncoderDecoder<EncoderS2S, DecoderS2S> S2S;
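// the dedicated S2S wrapper class above is dropped in favour of a plain typedef over EncoderDecoder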
}