mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-27 10:33:14 +03:00
Add marian style tie-embs
This commit is contained in:
parent
9aee648fa1
commit
3cce98c356
@ -6,10 +6,6 @@ namespace amunmt {
|
|||||||
namespace CPU {
|
namespace CPU {
|
||||||
namespace dl4mt {
|
namespace dl4mt {
|
||||||
|
|
||||||
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key)
|
|
||||||
: E_(model[key])
|
|
||||||
{}
|
|
||||||
|
|
||||||
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys)
|
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys)
|
||||||
: E_(model.getFirstOfMany(keys))
|
: E_(model.getFirstOfMany(keys))
|
||||||
{}
|
{}
|
||||||
@ -67,8 +63,10 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
|||||||
B2_(model("ff_logit_prev_b", true)),
|
B2_(model("ff_logit_prev_b", true)),
|
||||||
W3_(model["ff_logit_ctx_W"]),
|
W3_(model["ff_logit_ctx_W"]),
|
||||||
B3_(model("ff_logit_ctx_b", true)),
|
B3_(model("ff_logit_ctx_b", true)),
|
||||||
W4_(model.getFirstOfMany({std::pair<std::string, bool>(std::string("ff_logit_W"), false),
|
W4_(model.getFirstOfMany({std::pair<std::string, bool>(
|
||||||
std::make_pair(std::string("Wemb_dec"), true)})),
|
std::string("ff_logit_W"), false),
|
||||||
|
std::make_pair(std::string("Wemb_dec"), true),
|
||||||
|
std::make_pair(std::string("Wemb"), true)})),
|
||||||
B4_(model("ff_logit_b", true)),
|
B4_(model("ff_logit_b", true)),
|
||||||
Gamma_0_(model["ff_logit_l1_gamma0"]),
|
Gamma_0_(model["ff_logit_l1_gamma0"]),
|
||||||
Gamma_1_(model["ff_logit_l1_gamma1"]),
|
Gamma_1_(model["ff_logit_l1_gamma1"]),
|
||||||
@ -78,12 +76,15 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
|||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
Weights::Weights(const NpzConverter& model, size_t)
|
Weights::Weights(const NpzConverter& model, size_t)
|
||||||
: encEmbeddings_(model, "Wemb"),
|
: encEmbeddings_(model, std::vector<std::pair<std::string, bool>>({
|
||||||
|
std::make_pair(std::string("Wemb"), false),
|
||||||
|
std::make_pair(std::string("Wemb_dec"), false)})),
|
||||||
encForwardGRU_(model, {"encoder_W", "encoder_b", "encoder_U", "encoder_Wx", "encoder_bx",
|
encForwardGRU_(model, {"encoder_W", "encoder_b", "encoder_U", "encoder_Wx", "encoder_bx",
|
||||||
"encoder_Ux", "encoder_gamma1", "encoder_gamma2"}),
|
"encoder_Ux", "encoder_gamma1", "encoder_gamma2"}),
|
||||||
encBackwardGRU_(model, {"encoder_r_W", "encoder_r_b", "encoder_r_U", "encoder_r_Wx",
|
encBackwardGRU_(model, {"encoder_r_W", "encoder_r_b", "encoder_r_U", "encoder_r_Wx",
|
||||||
"encoder_r_bx", "encoder_r_Ux", "encoder_r_gamma1", "encoder_r_gamma2"}),
|
"encoder_r_bx", "encoder_r_Ux", "encoder_r_gamma1", "encoder_r_gamma2"}),
|
||||||
decEmbeddings_(model, std::vector<std::pair<std::string, bool>>({std::make_pair(std::string("Wemb_dec"), false),
|
decEmbeddings_(model, std::vector<std::pair<std::string, bool>>({
|
||||||
|
std::make_pair(std::string("Wemb_dec"), false),
|
||||||
std::make_pair(std::string("Wemb"), false)})),
|
std::make_pair(std::string("Wemb"), false)})),
|
||||||
decInit_(model),
|
decInit_(model),
|
||||||
decGru1_(model, {"decoder_W", "decoder_b", "decoder_U", "decoder_Wx", "decoder_bx", "decoder_Ux",
|
decGru1_(model, {"decoder_W", "decoder_b", "decoder_U", "decoder_Wx", "decoder_bx", "decoder_Ux",
|
||||||
|
@ -16,7 +16,6 @@ struct Weights {
|
|||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
struct Embeddings {
|
struct Embeddings {
|
||||||
Embeddings(const NpzConverter& model, const std::string &key);
|
|
||||||
Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys);
|
Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys);
|
||||||
|
|
||||||
const mblas::Matrix E_;
|
const mblas::Matrix E_;
|
||||||
|
@ -60,10 +60,6 @@ std::string Weights::Transition::name(const std::string& prefix, std::string nam
|
|||||||
return prefix + name + infix + "_drt_" + std::to_string(index) + suffix;
|
return prefix + name + infix + "_drt_" + std::to_string(index) + suffix;
|
||||||
}
|
}
|
||||||
|
|
||||||
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key)
|
|
||||||
: E_(model[key])
|
|
||||||
{}
|
|
||||||
|
|
||||||
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys)
|
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys)
|
||||||
: E_(model.getFirstOfMany(keys))
|
: E_(model.getFirstOfMany(keys))
|
||||||
{}
|
{}
|
||||||
@ -143,7 +139,8 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
|||||||
W3_(model["ff_logit_ctx_W"]),
|
W3_(model["ff_logit_ctx_W"]),
|
||||||
B3_(model("ff_logit_ctx_b", true)),
|
B3_(model("ff_logit_ctx_b", true)),
|
||||||
W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
|
W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
|
||||||
std::make_pair(std::string("Wemb_dec"), true)})),
|
std::make_pair(std::string("Wemb_dec"), true),
|
||||||
|
std::make_pair(std::string("Wemb"), true)})),
|
||||||
B4_(model("ff_logit_b", true)),
|
B4_(model("ff_logit_b", true)),
|
||||||
lns_1_(model["ff_logit_lstm_ln_s"]),
|
lns_1_(model["ff_logit_lstm_ln_s"]),
|
||||||
lns_2_(model["ff_logit_prev_ln_s"]),
|
lns_2_(model["ff_logit_prev_ln_s"]),
|
||||||
@ -156,7 +153,9 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
|||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
Weights::Weights(const NpzConverter& model, size_t)
|
Weights::Weights(const NpzConverter& model, size_t)
|
||||||
: encEmbeddings_(model, "Wemb"),
|
: encEmbeddings_(model, std::vector<std::pair<std::string, bool>>(
|
||||||
|
{std::make_pair(std::string("Wemb"), false),
|
||||||
|
std::make_pair(std::string("Wemb_dec"), false)})),
|
||||||
decEmbeddings_(model, std::vector<std::pair<std::string, bool>>(
|
decEmbeddings_(model, std::vector<std::pair<std::string, bool>>(
|
||||||
{std::make_pair(std::string("Wemb_dec"), false),
|
{std::make_pair(std::string("Wemb_dec"), false),
|
||||||
std::make_pair(std::string("Wemb"), false)})),
|
std::make_pair(std::string("Wemb"), false)})),
|
||||||
|
@ -49,7 +49,6 @@ struct Weights {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct Embeddings {
|
struct Embeddings {
|
||||||
Embeddings(const NpzConverter& model, const std::string &key);
|
|
||||||
Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys);
|
Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys);
|
||||||
|
|
||||||
const mblas::Matrix E_;
|
const mblas::Matrix E_;
|
||||||
|
@ -5,7 +5,8 @@ namespace GPU {
|
|||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
Weights::EncEmbeddings::EncEmbeddings(const NpzConverter& model)
|
Weights::EncEmbeddings::EncEmbeddings(const NpzConverter& model)
|
||||||
: E_(model.get("Wemb", true))
|
: E_(model.getFirstOfMany({std::make_pair("Wemb", false),
|
||||||
|
std::make_pair("Wemb_dec", false)}, true))
|
||||||
{}
|
{}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@ -141,7 +142,8 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
|||||||
W3_(model.get("ff_logit_ctx_W", true)),
|
W3_(model.get("ff_logit_ctx_W", true)),
|
||||||
B3_(model.get("ff_logit_ctx_b", true, true)),
|
B3_(model.get("ff_logit_ctx_b", true, true)),
|
||||||
W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
|
W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
|
||||||
std::make_pair(std::string("Wemb_dec"), true)}, true)),
|
std::make_pair(std::string("Wemb_dec"), true),
|
||||||
|
std::make_pair(std::string("Wemb"), true)}, true)),
|
||||||
B4_(model.get("ff_logit_b", true, true)),
|
B4_(model.get("ff_logit_b", true, true)),
|
||||||
Gamma_0_(model.get("ff_logit_l1_gamma0", false)),
|
Gamma_0_(model.get("ff_logit_l1_gamma0", false)),
|
||||||
Gamma_1_(model.get("ff_logit_l1_gamma1", false)),
|
Gamma_1_(model.get("ff_logit_l1_gamma1", false)),
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit f0f06468e1633af503b81eaa8b1fbca881b95257
|
Subproject commit 7f3802e97f08e91382c6609cd772c4e80488b4d8
|
Loading…
Reference in New Issue
Block a user