mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
move embedding to its own file
This commit is contained in:
parent
96ed0baf5a
commit
0d8372c590
@ -72,6 +72,7 @@ set(MARIAN_SOURCES
|
||||
layers/loss.cpp
|
||||
layers/weight.cpp
|
||||
layers/lsh.cpp
|
||||
layers/embedding.cpp
|
||||
|
||||
rnn/cells.cpp
|
||||
rnn/attention.cpp
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "layers/factory.h"
|
||||
#include "layers/generic.h"
|
||||
#include "layers/embedding.h"
|
||||
|
||||
namespace marian {
|
||||
namespace mlp {
|
||||
|
175
src/layers/embedding.cpp
Normal file
175
src/layers/embedding.cpp
Normal file
@ -0,0 +1,175 @@
|
||||
#include "embedding.h"
|
||||
#include "data/factored_vocab.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
// Constructs the embedding layer and creates its parameter matrix E_.
// Options read here: "inference", "prefix" (parameter name), "dimVocab", "dimEmb",
// "fixed" (default false), "vocab" (default ""), and optionally "embFile" / "normalization".
Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
    : LayerBase(graph, options), inference_(opt<bool>("inference")) {
  const std::string paramName = opt<std::string>("prefix");
  int vocabDim                = opt<int>("dimVocab");
  const int embDim            = opt<int>("dimEmb");
  const bool isFixed          = opt<bool>("fixed", false);

  // If a factored vocabulary could be loaded, the number of rows of E_ is taken
  // from its factorVocabSize() instead of from "dimVocab".
  factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
  if(factoredVocab_) {
    vocabDim = static_cast<int>(factoredVocab_->factorVocabSize());
    LOG_ONCE(info, "[embedding] Factored embeddings enabled");
  }

  // Embedding layer initialization should depend only on embedding size, hence fanIn=false
  // -> embedding vectors have roughly unit length
  auto initializer = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);

  // A non-empty "embFile" replaces the random initializer with vectors read from that file.
  const std::string embFile = options_->has("embFile") ? opt<std::string>("embFile") : std::string();
  if(!embFile.empty()) {
    const bool normalize = opt<bool>("normalization", false);
    initializer = inits::fromWord2vec(embFile, vocabDim, embDim, normalize);
  }

  E_ = graph_->param(paramName, {vocabDim, embDim}, initializer, isFixed);
}
|
||||
|
||||
// helper to embed a sequence of words (given as indices) via factored embeddings
|
||||
Expr Embedding::multiRows(const Words& data, float dropProb) const {
|
||||
auto graph = E_->graph();
|
||||
auto factoredData = factoredVocab_->csr_rows(data);
|
||||
// multi-hot factor vectors are represented as a sparse CSR matrix
|
||||
// [row index = word position index] -> set of factor indices for word at this position
|
||||
ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??");
|
||||
// the CSR matrix is passed in pieces
|
||||
auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights));
|
||||
auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32);
|
||||
auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32);
|
||||
// apply dropout
|
||||
// We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
|
||||
if(!inference_)
|
||||
weights = dropout(weights, dropProb);
|
||||
// perform the product
|
||||
return csr_dot(factoredData.shape, weights, indices, offsets, E_);
|
||||
}
|
||||
|
||||
// Embeds a complete sub-batch.
// Returns a tuple of
//  - the embeddings of subBatch->data(), reshaped to {batchWidth, batchSize, dimEmb}, and
//  - a graph constant of shape {batchWidth, batchSize, 1} holding subBatch->mask().
std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ {
  auto graph   = E_->graph();
  int dimBatch = (int)subBatch->batchSize();
  int dimEmb   = E_->shape()[-1];
  int dimWidth = (int)subBatch->batchWidth();

  // Design notes on factored embeddings:
  //  - regular:
  //     - y = x @ E    x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
  //  - factored:
  //     - u = x @ M    one-hot to U-dimensional multi-hot (all factors in one concatenated space)
  //        - each row of M contains the set of factors for one word => we want a CSR matrix
  //     - y = (x @ M) @ E   (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
  //  - first compute x @ M on the CPU
  //     - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
  //        - shape (U, specifically) not actually needed here
  //     - foreach input x[i]
  //        - locate row M[i,*]
  //        - copy through its index values (std::vector<push_back>)
  //     - create a matching ones vector (we can keep growing)
  //     - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
  //  - CSR matrix product with E
  //     - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
  //        - double-check if all dimensions are specified. Probably not for transpose
  //          (which would be like csc_dot()).
  //  - weighting:
  //     - core factors' gradients are sums over all words that use the factors;
  //        - core factors' embeddings move very fast
  //        - words will need to make up for the move; rare words cannot
  //     - so, we multiply each factor with 1/refCount
  //        - core factors get weighed down a lot
  //        - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before
  //        - but forward pass weighs them down, so that all factors are in a similar numeric range
  //        - if it is required to be in a different range, the embeddings can still learn that,
  //          but more slowly

  auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});

  // The previously #if'ed-out experimental mask variant (inline-fix source suppression)
  // was dead code and has been removed; the plain sub-batch mask is always used.
  auto batchMask = graph->constant({dimWidth, dimBatch, 1},
                                   inits::fromVector(subBatch->mask()));

  // give the graph inputs readable names for debugging and ONNX
  batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");

  return std::make_tuple(batchEmbeddings, batchMask);
}
|
||||
|
||||
// Embeds the given words and reshapes the result to `shape`.
// Without a factored vocabulary this simply forwards to applyIndices(); with one,
// the lookup goes through multiRows(), which also handles dropout (on the factor weights).
Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
  if(!factoredVocab_)
    return applyIndices(toWordIndexVector(words), shape);

  const float dropProb = options_->get<float>("dropout", 0.0f);
  Expr factoredEmbs = multiRows(words, dropProb);  // [(B*W) x E]
  // @TODO: replace with factor dropout (the per-position dropout used in applyIndices()
  // is intentionally not applied on this path)
  return reshape(factoredEmbs, shape);             // [W, B, E]
}
|
||||
|
||||
// Looks up the rows of E_ selected by embIdx and reshapes them to `shape`.
// Aborts if a factored vocabulary is in use. Outside inference mode, dropout with
// the "dropout" option (default 0) is applied using a {shape[-3], 1, 1} mask shape.
Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ {
  ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");

  auto indexNode = E_->graph()->indices(embIdx);
  indexNode->set_name("data_" + std::to_string(/*batchIndex_=*/0));  // @TODO: how to know the batch index?

  auto lookedUp = rows(E_, indexNode);         // [(B*W) x E]
  auto embs     = reshape(lookedUp, shape);    // [W, B, E]

  if(inference_)
    return embs;

  // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout
  // before reshape() (test that separately)
  const float dropProb = options_->get<float>("dropout", 0.0f);
  return dropout(embs, dropProb, { embs->shape()[-3], 1, 1 });
}
|
||||
|
||||
// standard encoder word embeddings
|
||||
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
|
||||
auto options = New<Options>(
|
||||
"dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
|
||||
"dimEmb", opt<int>("dim-emb"),
|
||||
"dropout", dropoutEmbeddings_,
|
||||
"inference", inference_,
|
||||
"prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb",
|
||||
"fixed", embeddingFix_,
|
||||
"vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
|
||||
if(options_->hasAndNotEmpty("embedding-vectors")) {
|
||||
auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
|
||||
options->set(
|
||||
"embFile", embFiles[batchIndex_],
|
||||
"normalization", opt<bool>("embedding-normalization"));
|
||||
}
|
||||
return New<Embedding>(graph_, options);
|
||||
}
|
||||
|
||||
// ULR word embeddings
|
||||
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const {
|
||||
return New<ULREmbedding>(graph_, New<Options>(
|
||||
"dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
|
||||
"dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt
|
||||
"dimUlrEmb", opt<int>("ulr-dim-emb"),
|
||||
"dimEmb", opt<int>("dim-emb"),
|
||||
"ulr-dropout", opt<float>("ulr-dropout"),
|
||||
"dropout", dropoutEmbeddings_,
|
||||
"inference", inference_,
|
||||
"ulrTrainTransform", opt<bool>("ulr-trainable-transformation"),
|
||||
"ulrQueryFile", opt<std::string>("ulr-query-vectors"),
|
||||
"ulrKeysFile", opt<std::string>("ulr-keys-vectors")));
|
||||
}
|
||||
|
||||
// get embedding layer for this encoder or decoder
|
||||
// This is lazy mostly because the constructors of the consuming objects are not
|
||||
// guaranteed presently to have access to their graph.
|
||||
Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const {
|
||||
if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
|
||||
if (embeddingLayers_.size() <= batchIndex_)
|
||||
embeddingLayers_.resize(batchIndex_ + 1);
|
||||
if (ulr)
|
||||
embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
|
||||
else
|
||||
embeddingLayers_[batchIndex_] = createEmbeddingLayer();
|
||||
}
|
||||
return embeddingLayers_[batchIndex_];
|
||||
}
|
||||
|
||||
}
|
||||
|
148
src/layers/embedding.h
Normal file
148
src/layers/embedding.h
Normal file
@ -0,0 +1,148 @@
|
||||
#pragma once
|
||||
#include "marian.h"
|
||||
#include "generic.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
// A regular embedding layer.
|
||||
// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
|
||||
// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
|
||||
// EncoderDecoderLayerBase, which knows to pass on all required parameters from options.
|
||||
class Embedding : public LayerBase, public IEmbeddingLayer {
|
||||
Expr E_;
|
||||
Ptr<FactoredVocab> factoredVocab_;
|
||||
Expr multiRows(const Words& data, float dropProb) const;
|
||||
bool inference_{false};
|
||||
|
||||
public:
|
||||
Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);
|
||||
|
||||
std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final;
|
||||
|
||||
Expr apply(const Words& words, const Shape& shape) const override final;
|
||||
|
||||
Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final;
|
||||
};
|
||||
|
||||
class ULREmbedding : public LayerBase, public IEmbeddingLayer {
|
||||
std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
|
||||
bool inference_{false};
|
||||
|
||||
public:
|
||||
ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
|
||||
std::string name = "url_embed"; //opt<std::string>("prefix");
|
||||
int dimKeys = opt<int>("dimTgtVoc");
|
||||
int dimQueries = opt<int>("dimSrcVoc");
|
||||
int dimEmb = opt<int>("dimEmb");
|
||||
int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
|
||||
bool fixed = opt<bool>("fixed", false);
|
||||
|
||||
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
|
||||
auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);
|
||||
|
||||
std::string queryFile = opt<std::string>("ulrQueryFile");
|
||||
std::string keyFile = opt<std::string>("ulrKeysFile");
|
||||
bool trainTrans = opt<bool>("ulrTrainTransform", false);
|
||||
if (!queryFile.empty() && !keyFile.empty()) {
|
||||
initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
|
||||
name = "ulr_query";
|
||||
fixed = true;
|
||||
auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(query_embed);
|
||||
// keys embeds
|
||||
initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
|
||||
name = "ulr_keys";
|
||||
fixed = true;
|
||||
auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(key_embed);
|
||||
// actual trainable embedding
|
||||
initFunc = inits::glorotUniform();
|
||||
name = "ulr_embed";
|
||||
fixed = false;
|
||||
auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim
|
||||
ulrEmbeddings_.push_back(ulr_embed);
|
||||
// init trainable src embedding
|
||||
name = "ulr_src_embed";
|
||||
auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(ulr_src_embed);
|
||||
// ulr transformation matrix
|
||||
//initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall we make this to the fixed case only
|
||||
if (trainTrans) {
|
||||
initFunc = inits::glorotUniform();
|
||||
fixed = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
initFunc = inits::eye(); // identity matrix
|
||||
fixed = true;
|
||||
}
|
||||
name = "ulr_transform";
|
||||
auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(ulrTransform);
|
||||
|
||||
initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only
|
||||
fixed = true;
|
||||
name = "ulr_shared";
|
||||
auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(share_embed);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final {
|
||||
auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
|
||||
auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
|
||||
auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
|
||||
auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
|
||||
auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
|
||||
auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
|
||||
int dimBatch = (int)subBatch->batchSize();
|
||||
int dimEmb = uniEmbed->shape()[-1];
|
||||
int dimWords = (int)subBatch->batchWidth();
|
||||
// D = K.A.QT
|
||||
// dimm(K) = univ_tok_vocab*uni_embed_size
|
||||
// dim A = uni_embed_size*uni_embed_size
|
||||
// dim Q: uni_embed_size * total_merged_vocab_size
|
||||
// dim D = univ_tok_vocab * total_merged_vocab_size
|
||||
// note all above can be precombuted and serialized if A is not trainiable and during decoding (TBD)
|
||||
// here we need to handle the mini-batch
|
||||
// extract raws corresponding to Xs in this minibatch from Q
|
||||
auto embIdx = toWordIndexVector(subBatch->data());
|
||||
auto queryEmbeddings = rows(queryEmbed, embIdx);
|
||||
auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
|
||||
auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
|
||||
auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
|
||||
auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]);
|
||||
qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magnitude with larger embeds sizes
|
||||
auto z = dot(qt, keyEmbed, false, true); // query-key similarity
|
||||
float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout
|
||||
if(!inference_)
|
||||
z = dropout(z, dropProb);
|
||||
|
||||
float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
|
||||
// temperature in softmax is to control randomness of predictions
|
||||
// high temperature Softmax outputs are more close to each other
|
||||
// low temperatures the softmax become more similar to "hardmax"
|
||||
auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ??
|
||||
auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
|
||||
auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
|
||||
auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb });
|
||||
auto graph = ulrEmbeddings_.front()->graph();
|
||||
auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
|
||||
inits::fromVector(subBatch->mask()));
|
||||
if(!inference_)
|
||||
batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1});
|
||||
return std::make_tuple(batchEmbeddings, batchMask);
|
||||
}
|
||||
|
||||
Expr apply(const Words& words, const Shape& shape) const override final {
|
||||
return applyIndices(toWordIndexVector(words), shape);
|
||||
}
|
||||
|
||||
Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
|
||||
embIdx; shape;
|
||||
ABORT("not implemented"); // @TODO: implement me
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -438,172 +438,4 @@ namespace marian {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Constructs the embedding layer and creates its parameter matrix E_.
// Options read here: "inference", "prefix" (parameter name), "dimVocab", "dimEmb",
// "fixed" (default false), "vocab" (default ""), and optionally "embFile" / "normalization".
Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
    : LayerBase(graph, options), inference_(opt<bool>("inference")) {
  const std::string paramName = opt<std::string>("prefix");
  int vocabDim                = opt<int>("dimVocab");
  const int embDim            = opt<int>("dimEmb");
  const bool isFixed          = opt<bool>("fixed", false);

  // If a factored vocabulary could be loaded, the number of rows of E_ is taken
  // from its factorVocabSize() instead of from "dimVocab".
  factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
  if(factoredVocab_) {
    vocabDim = static_cast<int>(factoredVocab_->factorVocabSize());
    LOG_ONCE(info, "[embedding] Factored embeddings enabled");
  }

  // Embedding layer initialization should depend only on embedding size, hence fanIn=false
  // -> embedding vectors have roughly unit length
  auto initializer = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);

  // A non-empty "embFile" replaces the random initializer with vectors read from that file.
  const std::string embFile = options_->has("embFile") ? opt<std::string>("embFile") : std::string();
  if(!embFile.empty()) {
    const bool normalize = opt<bool>("normalization", false);
    initializer = inits::fromWord2vec(embFile, vocabDim, embDim, normalize);
  }

  E_ = graph_->param(paramName, {vocabDim, embDim}, initializer, isFixed);
}
|
||||
|
||||
// helper to embed a sequence of words (given as indices) via factored embeddings
|
||||
Expr Embedding::multiRows(const Words& data, float dropProb) const {
|
||||
auto graph = E_->graph();
|
||||
auto factoredData = factoredVocab_->csr_rows(data);
|
||||
// multi-hot factor vectors are represented as a sparse CSR matrix
|
||||
// [row index = word position index] -> set of factor indices for word at this position
|
||||
ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??");
|
||||
// the CSR matrix is passed in pieces
|
||||
auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights));
|
||||
auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32);
|
||||
auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32);
|
||||
// apply dropout
|
||||
// We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
|
||||
if(!inference_)
|
||||
weights = dropout(weights, dropProb);
|
||||
// perform the product
|
||||
return csr_dot(factoredData.shape, weights, indices, offsets, E_);
|
||||
}
|
||||
|
||||
// Embeds a complete sub-batch.
// Returns a tuple of
//  - the embeddings of subBatch->data(), reshaped to {batchWidth, batchSize, dimEmb}, and
//  - a graph constant of shape {batchWidth, batchSize, 1} holding subBatch->mask().
std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ {
  auto graph   = E_->graph();
  int dimBatch = (int)subBatch->batchSize();
  int dimEmb   = E_->shape()[-1];
  int dimWidth = (int)subBatch->batchWidth();

  // Design notes on factored embeddings:
  //  - regular:
  //     - y = x @ E    x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
  //  - factored:
  //     - u = x @ M    one-hot to U-dimensional multi-hot (all factors in one concatenated space)
  //        - each row of M contains the set of factors for one word => we want a CSR matrix
  //     - y = (x @ M) @ E   (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
  //  - first compute x @ M on the CPU
  //     - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
  //        - shape (U, specifically) not actually needed here
  //     - foreach input x[i]
  //        - locate row M[i,*]
  //        - copy through its index values (std::vector<push_back>)
  //     - create a matching ones vector (we can keep growing)
  //     - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
  //  - CSR matrix product with E
  //     - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
  //        - double-check if all dimensions are specified. Probably not for transpose
  //          (which would be like csc_dot()).
  //  - weighting:
  //     - core factors' gradients are sums over all words that use the factors;
  //        - core factors' embeddings move very fast
  //        - words will need to make up for the move; rare words cannot
  //     - so, we multiply each factor with 1/refCount
  //        - core factors get weighed down a lot
  //        - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before
  //        - but forward pass weighs them down, so that all factors are in a similar numeric range
  //        - if it is required to be in a different range, the embeddings can still learn that,
  //          but more slowly

  auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});

  // The previously #if'ed-out experimental mask variant (inline-fix source suppression)
  // was dead code and has been removed; the plain sub-batch mask is always used.
  auto batchMask = graph->constant({dimWidth, dimBatch, 1},
                                   inits::fromVector(subBatch->mask()));

  // give the graph inputs readable names for debugging and ONNX
  batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");

  return std::make_tuple(batchEmbeddings, batchMask);
}
|
||||
|
||||
// Embeds the given words and reshapes the result to `shape`.
// Without a factored vocabulary this simply forwards to applyIndices(); with one,
// the lookup goes through multiRows(), which also handles dropout (on the factor weights).
Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
  if(!factoredVocab_)
    return applyIndices(toWordIndexVector(words), shape);

  const float dropProb = options_->get<float>("dropout", 0.0f);
  Expr factoredEmbs = multiRows(words, dropProb);  // [(B*W) x E]
  // @TODO: replace with factor dropout (the per-position dropout used in applyIndices()
  // is intentionally not applied on this path)
  return reshape(factoredEmbs, shape);             // [W, B, E]
}
|
||||
|
||||
// Looks up the rows of E_ selected by embIdx and reshapes them to `shape`.
// Aborts if a factored vocabulary is in use. Outside inference mode, dropout with
// the "dropout" option (default 0) is applied using a {shape[-3], 1, 1} mask shape.
Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ {
  ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");

  auto indexNode = E_->graph()->indices(embIdx);
  indexNode->set_name("data_" + std::to_string(/*batchIndex_=*/0));  // @TODO: how to know the batch index?

  auto lookedUp = rows(E_, indexNode);         // [(B*W) x E]
  auto embs     = reshape(lookedUp, shape);    // [W, B, E]

  if(inference_)
    return embs;

  // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout
  // before reshape() (test that separately)
  const float dropProb = options_->get<float>("dropout", 0.0f);
  return dropout(embs, dropProb, { embs->shape()[-3], 1, 1 });
}
|
||||
|
||||
// standard encoder word embeddings
|
||||
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
|
||||
auto options = New<Options>(
|
||||
"dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
|
||||
"dimEmb", opt<int>("dim-emb"),
|
||||
"dropout", dropoutEmbeddings_,
|
||||
"inference", inference_,
|
||||
"prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb",
|
||||
"fixed", embeddingFix_,
|
||||
"vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
|
||||
if(options_->hasAndNotEmpty("embedding-vectors")) {
|
||||
auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
|
||||
options->set(
|
||||
"embFile", embFiles[batchIndex_],
|
||||
"normalization", opt<bool>("embedding-normalization"));
|
||||
}
|
||||
return New<Embedding>(graph_, options);
|
||||
}
|
||||
|
||||
// ULR word embeddings
|
||||
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const {
|
||||
return New<ULREmbedding>(graph_, New<Options>(
|
||||
"dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
|
||||
"dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt
|
||||
"dimUlrEmb", opt<int>("ulr-dim-emb"),
|
||||
"dimEmb", opt<int>("dim-emb"),
|
||||
"ulr-dropout", opt<float>("ulr-dropout"),
|
||||
"dropout", dropoutEmbeddings_,
|
||||
"inference", inference_,
|
||||
"ulrTrainTransform", opt<bool>("ulr-trainable-transformation"),
|
||||
"ulrQueryFile", opt<std::string>("ulr-query-vectors"),
|
||||
"ulrKeysFile", opt<std::string>("ulr-keys-vectors")));
|
||||
}
|
||||
|
||||
// get embedding layer for this encoder or decoder
|
||||
// This is lazy mostly because the constructors of the consuming objects are not
|
||||
// guaranteed presently to have access to their graph.
|
||||
Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const {
|
||||
if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
|
||||
if (embeddingLayers_.size() <= batchIndex_)
|
||||
embeddingLayers_.resize(batchIndex_ + 1);
|
||||
if (ulr)
|
||||
embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
|
||||
else
|
||||
embeddingLayers_[batchIndex_] = createEmbeddingLayer();
|
||||
}
|
||||
return embeddingLayers_[batchIndex_];
|
||||
}
|
||||
} // namespace marian
|
||||
|
@ -295,146 +295,6 @@ public:
|
||||
|
||||
} // namespace mlp
|
||||
|
||||
// A regular embedding layer.
|
||||
// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
|
||||
// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
|
||||
// EncoderDecoderLayerBase, which knows to pass on all required parameters from options.
|
||||
class Embedding : public LayerBase, public IEmbeddingLayer {
|
||||
Expr E_;
|
||||
Ptr<FactoredVocab> factoredVocab_;
|
||||
Expr multiRows(const Words& data, float dropProb) const;
|
||||
bool inference_{false};
|
||||
|
||||
public:
|
||||
Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);
|
||||
|
||||
std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final;
|
||||
|
||||
Expr apply(const Words& words, const Shape& shape) const override final;
|
||||
|
||||
Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final;
|
||||
};
|
||||
|
||||
class ULREmbedding : public LayerBase, public IEmbeddingLayer {
|
||||
std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
|
||||
bool inference_{false};
|
||||
|
||||
public:
|
||||
ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
|
||||
std::string name = "url_embed"; //opt<std::string>("prefix");
|
||||
int dimKeys = opt<int>("dimTgtVoc");
|
||||
int dimQueries = opt<int>("dimSrcVoc");
|
||||
int dimEmb = opt<int>("dimEmb");
|
||||
int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
|
||||
bool fixed = opt<bool>("fixed", false);
|
||||
|
||||
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
|
||||
auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);
|
||||
|
||||
std::string queryFile = opt<std::string>("ulrQueryFile");
|
||||
std::string keyFile = opt<std::string>("ulrKeysFile");
|
||||
bool trainTrans = opt<bool>("ulrTrainTransform", false);
|
||||
if (!queryFile.empty() && !keyFile.empty()) {
|
||||
initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
|
||||
name = "ulr_query";
|
||||
fixed = true;
|
||||
auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(query_embed);
|
||||
// keys embeds
|
||||
initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
|
||||
name = "ulr_keys";
|
||||
fixed = true;
|
||||
auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(key_embed);
|
||||
// actual trainable embedding
|
||||
initFunc = inits::glorotUniform();
|
||||
name = "ulr_embed";
|
||||
fixed = false;
|
||||
auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim
|
||||
ulrEmbeddings_.push_back(ulr_embed);
|
||||
// init trainable src embedding
|
||||
name = "ulr_src_embed";
|
||||
auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(ulr_src_embed);
|
||||
// ulr transformation matrix
|
||||
//initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall we make this to the fixed case only
|
||||
if (trainTrans) {
|
||||
initFunc = inits::glorotUniform();
|
||||
fixed = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
initFunc = inits::eye(); // identity matrix
|
||||
fixed = true;
|
||||
}
|
||||
name = "ulr_transform";
|
||||
auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(ulrTransform);
|
||||
|
||||
initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only
|
||||
fixed = true;
|
||||
name = "ulr_shared";
|
||||
auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(share_embed);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final {
|
||||
auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
|
||||
auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
|
||||
auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
|
||||
auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
|
||||
auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
|
||||
auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
|
||||
int dimBatch = (int)subBatch->batchSize();
|
||||
int dimEmb = uniEmbed->shape()[-1];
|
||||
int dimWords = (int)subBatch->batchWidth();
|
||||
// D = K.A.QT
|
||||
// dimm(K) = univ_tok_vocab*uni_embed_size
|
||||
// dim A = uni_embed_size*uni_embed_size
|
||||
// dim Q: uni_embed_size * total_merged_vocab_size
|
||||
// dim D = univ_tok_vocab * total_merged_vocab_size
|
||||
// note all above can be precombuted and serialized if A is not trainiable and during decoding (TBD)
|
||||
// here we need to handle the mini-batch
|
||||
// extract raws corresponding to Xs in this minibatch from Q
|
||||
auto embIdx = toWordIndexVector(subBatch->data());
|
||||
auto queryEmbeddings = rows(queryEmbed, embIdx);
|
||||
auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
|
||||
auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
|
||||
auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
|
||||
auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]);
|
||||
qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magnitude with larger embeds sizes
|
||||
auto z = dot(qt, keyEmbed, false, true); // query-key similarity
|
||||
float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout
|
||||
if(!inference_)
|
||||
z = dropout(z, dropProb);
|
||||
|
||||
float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
|
||||
// temperature in softmax is to control randomness of predictions
|
||||
// high temperature Softmax outputs are more close to each other
|
||||
// low temperatures the softmax become more similar to "hardmax"
|
||||
auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ??
|
||||
auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
|
||||
auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
|
||||
auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb });
|
||||
auto graph = ulrEmbeddings_.front()->graph();
|
||||
auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
|
||||
inits::fromVector(subBatch->mask()));
|
||||
if(!inference_)
|
||||
batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1});
|
||||
return std::make_tuple(batchEmbeddings, batchMask);
|
||||
}
|
||||
|
||||
Expr apply(const Words& words, const Shape& shape) const override final {
|
||||
return applyIndices(toWordIndexVector(words), shape);
|
||||
}
|
||||
|
||||
Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
|
||||
embIdx; shape;
|
||||
ABORT("not implemented"); // @TODO: implement me
|
||||
}
|
||||
};
|
||||
|
||||
// --- a few layers with built-in parameters created on the fly, without proper object
|
||||
// @TODO: change to a proper layer object
|
||||
|
Loading…
Reference in New Issue
Block a user