Add Marian-style tied embeddings (tie-embs)

This commit is contained in:
Tomasz Dwojak 2017-11-28 16:25:45 +00:00
parent 9aee648fa1
commit 3cce98c356
6 changed files with 19 additions and 19 deletions

View File

@ -6,10 +6,6 @@ namespace amunmt {
namespace CPU {
namespace dl4mt {
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key)
: E_(model[key])
{}
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys)
: E_(model.getFirstOfMany(keys))
{}
@ -67,8 +63,10 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
B2_(model("ff_logit_prev_b", true)),
W3_(model["ff_logit_ctx_W"]),
B3_(model("ff_logit_ctx_b", true)),
W4_(model.getFirstOfMany({std::pair<std::string, bool>(std::string("ff_logit_W"), false),
std::make_pair(std::string("Wemb_dec"), true)})),
W4_(model.getFirstOfMany({std::pair<std::string, bool>(
std::string("ff_logit_W"), false),
std::make_pair(std::string("Wemb_dec"), true),
std::make_pair(std::string("Wemb"), true)})),
B4_(model("ff_logit_b", true)),
Gamma_0_(model["ff_logit_l1_gamma0"]),
Gamma_1_(model["ff_logit_l1_gamma1"]),
@ -78,12 +76,15 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
//////////////////////////////////////////////////////////////////////////////
Weights::Weights(const NpzConverter& model, size_t)
: encEmbeddings_(model, "Wemb"),
: encEmbeddings_(model, std::vector<std::pair<std::string, bool>>({
std::make_pair(std::string("Wemb"), false),
std::make_pair(std::string("Wemb_dec"), false)})),
encForwardGRU_(model, {"encoder_W", "encoder_b", "encoder_U", "encoder_Wx", "encoder_bx",
"encoder_Ux", "encoder_gamma1", "encoder_gamma2"}),
encBackwardGRU_(model, {"encoder_r_W", "encoder_r_b", "encoder_r_U", "encoder_r_Wx",
"encoder_r_bx", "encoder_r_Ux", "encoder_r_gamma1", "encoder_r_gamma2"}),
decEmbeddings_(model, std::vector<std::pair<std::string, bool>>({std::make_pair(std::string("Wemb_dec"), false),
decEmbeddings_(model, std::vector<std::pair<std::string, bool>>({
std::make_pair(std::string("Wemb_dec"), false),
std::make_pair(std::string("Wemb"), false)})),
decInit_(model),
decGru1_(model, {"decoder_W", "decoder_b", "decoder_U", "decoder_Wx", "decoder_bx", "decoder_Ux",

View File

@ -16,7 +16,6 @@ struct Weights {
//////////////////////////////////////////////////////////////////////////////
struct Embeddings {
Embeddings(const NpzConverter& model, const std::string &key);
Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys);
const mblas::Matrix E_;

View File

@ -60,10 +60,6 @@ std::string Weights::Transition::name(const std::string& prefix, std::string nam
return prefix + name + infix + "_drt_" + std::to_string(index) + suffix;
}
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key)
: E_(model[key])
{}
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys)
: E_(model.getFirstOfMany(keys))
{}
@ -143,7 +139,8 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
W3_(model["ff_logit_ctx_W"]),
B3_(model("ff_logit_ctx_b", true)),
W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
std::make_pair(std::string("Wemb_dec"), true)})),
std::make_pair(std::string("Wemb_dec"), true),
std::make_pair(std::string("Wemb"), true)})),
B4_(model("ff_logit_b", true)),
lns_1_(model["ff_logit_lstm_ln_s"]),
lns_2_(model["ff_logit_prev_ln_s"]),
@ -156,7 +153,9 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
//////////////////////////////////////////////////////////////////////////////
Weights::Weights(const NpzConverter& model, size_t)
: encEmbeddings_(model, "Wemb"),
: encEmbeddings_(model, std::vector<std::pair<std::string, bool>>(
{std::make_pair(std::string("Wemb"), false),
std::make_pair(std::string("Wemb_dec"), false)})),
decEmbeddings_(model, std::vector<std::pair<std::string, bool>>(
{std::make_pair(std::string("Wemb_dec"), false),
std::make_pair(std::string("Wemb"), false)})),

View File

@ -49,7 +49,6 @@ struct Weights {
};
struct Embeddings {
Embeddings(const NpzConverter& model, const std::string &key);
Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys);
const mblas::Matrix E_;

View File

@ -5,7 +5,8 @@ namespace GPU {
////////////////////////////////////////////////////////////////////////////////////////////////////
Weights::EncEmbeddings::EncEmbeddings(const NpzConverter& model)
: E_(model.get("Wemb", true))
: E_(model.getFirstOfMany({std::make_pair("Wemb", false),
std::make_pair("Wemb_dec", false)}, true))
{}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -141,7 +142,8 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
W3_(model.get("ff_logit_ctx_W", true)),
B3_(model.get("ff_logit_ctx_b", true, true)),
W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
std::make_pair(std::string("Wemb_dec"), true)}, true)),
std::make_pair(std::string("Wemb_dec"), true),
std::make_pair(std::string("Wemb"), true)}, true)),
B4_(model.get("ff_logit_b", true, true)),
Gamma_0_(model.get("ff_logit_l1_gamma0", false)),
Gamma_1_(model.get("ff_logit_l1_gamma1", false)),

@ -1 +1 @@
Subproject commit f0f06468e1633af503b81eaa8b1fbca881b95257
Subproject commit 7f3802e97f08e91382c6609cd772c4e80488b4d8