clang-format -i

Hieu Hoang 2021-03-05 22:54:05 -07:00
parent 55f4216552
commit ba19663784
14 changed files with 1053 additions and 848 deletions

View File

@@ -1,8 +1,8 @@
#pragma once
#include "layers/embedding.h"
#include "layers/factory.h"
#include "layers/generic.h"
#include "layers/embedding.h"
#include "layers/output.h"
namespace marian {
@@ -45,6 +45,7 @@ struct LogitLayerFactory : public Factory {
// @TODO: In the long run, I hope we can get rid of the abstract factories altogether.
class OutputFactory : public LogitLayerFactory {
using LogitLayerFactory::LogitLayerFactory;
protected:
std::string tiedTransposedName_;
Ptr<data::Shortlist> shortlist_;
@@ -55,9 +56,7 @@ public:
return Accumulator<OutputFactory>(*this);
}
void setShortlist(Ptr<data::Shortlist> shortlist) {
shortlist_ = shortlist;
}
void setShortlist(Ptr<data::Shortlist> shortlist) { shortlist_ = shortlist; }
Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) override {
auto output = New<Output>(graph, options_);
@@ -89,8 +88,7 @@ protected:
std::vector<Ptr<IUnaryLayer>> layers_;
public:
MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: graph_(graph), options_(options) {}
MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}
Expr apply(const std::vector<Expr>& av) override {
Expr output;
@@ -106,46 +104,53 @@ public:
}
Logits applyAsLogits(const std::vector<Expr>& av) override {
// same as apply() except for the last layer, we invoke applyAsLogits(), which has a different return type
// same as apply() except for the last layer, we invoke applyAsLogits(), which has a different
// return type
auto lastLayer = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers_.back());
ABORT_IF(!lastLayer, "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
if (layers_.size() == 1) {
if (av.size() == 1)
ABORT_IF(
!lastLayer,
"MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
if(layers_.size() == 1) {
if(av.size() == 1)
return lastLayer->applyAsLogits(av[0]);
else
return lastLayer->applyAsLogits(av);
}
else {
} else {
Expr output;
if (av.size() == 1)
if(av.size() == 1)
output = layers_[0]->apply(av[0]);
else
output = layers_[0]->apply(av);
for (size_t i = 1; i < layers_.size() - 1; ++i)
for(size_t i = 1; i < layers_.size() - 1; ++i)
output = layers_[i]->apply(output);
return lastLayer->applyAsLogits(output);
}
}
Expr apply(Expr e) override { return apply(std::vector<Expr>{ e }); }
Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{ e }); }
Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); }
Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); }
void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); }
void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); }
void setShortlist(Ptr<data::Shortlist> shortlist) override final {
auto p = tryAsHasShortlist();
ABORT_IF(!p, "setShortlist() called on an MLP with an output layer that does not support short lists");
ABORT_IF(
!p,
"setShortlist() called on an MLP with an output layer that does not support short lists");
p->setShortlist(shortlist);
}
void clear() override final {
auto p = tryAsHasShortlist();
if (p)
if(p)
p->clear();
}
private:
Ptr<IHasShortList> tryAsHasShortlist() const { return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); }
Ptr<IHasShortList> tryAsHasShortlist() const {
return std::dynamic_pointer_cast<IHasShortList>(layers_.back());
}
};
/**
@@ -154,6 +159,7 @@ private:
*/
class MLPFactory : public Factory {
using Factory::Factory;
private:
std::vector<Ptr<LayerFactory>> layers_;
@@ -177,23 +183,27 @@ public:
// which will go away if we get rid of the abstract factories, and instead just construct
// all layers immediately, which is my long-term goal for Marian.
private:
template<class WrappedFactory>
template <class WrappedFactory>
class AsLayerFactory : public LayerFactory {
WrappedFactory us;
WrappedFactory us;
public:
AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
return p;
}
AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
return p;
}
};
template<class WrappedFactory>
static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { return wrapped; }
template <class WrappedFactory>
static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) {
return wrapped;
}
public:
Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) {
push_back(AsLayerFactory<OutputFactory>(lf));
//layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
// layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
return Accumulator<MLPFactory>(*this);
}
};
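The MLP class in this file chains plain unary layers and hands the final step to an IUnaryLogitLayer so the last layer can return Logits instead of an Expr. A minimal standalone sketch of that control flow, using toy stand-in types rather than Marian's actual classes:

#include <cassert>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Toy stand-ins for Marian's Expr/Logits types; illustration only.
using Expr = std::vector<float>;
struct Logits { Expr values; };

struct IUnaryLayer {
  virtual ~IUnaryLayer() = default;
  virtual Expr apply(Expr x) = 0;
};

// A layer whose final application can return Logits instead of a plain Expr,
// mirroring the role the Output layer plays above.
struct IUnaryLogitLayer : IUnaryLayer {
  virtual Logits applyAsLogits(Expr x) { return Logits{apply(std::move(x))}; }
};

struct Scale : IUnaryLayer {
  float s;
  explicit Scale(float s) : s(s) {}
  Expr apply(Expr x) override { for (auto& v : x) v *= s; return x; }
};

struct OutputLike : IUnaryLogitLayer {
  Expr apply(Expr x) override { return x; }  // identity "projection"
};

struct ToyMLP {
  std::vector<std::shared_ptr<IUnaryLayer>> layers;

  // Same control flow as MLP::applyAsLogits(): run every layer but the last
  // with apply(), then let the last layer produce the Logits object.
  Logits applyAsLogits(Expr x) {
    auto last = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers.back());
    assert(last && "last layer must be an IUnaryLogitLayer");
    for (size_t i = 0; i + 1 < layers.size(); ++i)
      x = layers[i]->apply(std::move(x));
    return last->applyAsLogits(std::move(x));
  }
};

int main() {
  ToyMLP mlp;
  mlp.layers = {std::make_shared<Scale>(2.0f), std::make_shared<OutputLike>()};
  auto logits = mlp.applyAsLogits({1.0f, 2.0f, 3.0f});
  std::cout << logits.values[2] << "\n";  // prints 6
}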

View File

@@ -3,173 +3,205 @@
namespace marian {
Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
std::string name = opt<std::string>("prefix");
int dimVoc = opt<int>("dimVocab");
int dimEmb = opt<int>("dimEmb");
Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
std::string name = opt<std::string>("prefix");
int dimVoc = opt<int>("dimVocab");
int dimEmb = opt<int>("dimEmb");
bool fixed = opt<bool>("fixed", false);
bool fixed = opt<bool>("fixed", false);
factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
if (factoredVocab_) {
factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
if(factoredVocab_) {
dimVoc = (int)factoredVocab_->factorVocabSize();
LOG_ONCE(info, "[embedding] Factored embeddings enabled");
}
}
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
auto initFunc = inits::glorotUniform(
/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
if (options_->has("embFile")) {
if(options_->has("embFile")) {
std::string file = opt<std::string>("embFile");
if (!file.empty()) {
bool norm = opt<bool>("normalization", false);
initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
if(!file.empty()) {
bool norm = opt<bool>("normalization", false);
initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
}
}
}
E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
}
// helper to embed a sequence of words (given as indices) via factored embeddings
Expr Embedding::multiRows(const Words& data, float dropProb) const {
auto graph = E_->graph();
auto factoredData = factoredVocab_->csr_rows(data);
// multi-hot factor vectors are represented as a sparse CSR matrix
// [row index = word position index] -> set of factor indices for word at this position
ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??");
// the CSR matrix is passed in pieces
auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights));
auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32);
auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32);
// apply dropout
// We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
if(!inference_)
auto graph = E_->graph();
auto factoredData = factoredVocab_->csr_rows(data);
// multi-hot factor vectors are represented as a sparse CSR matrix
// [row index = word position index] -> set of factor indices for word at this position
ABORT_IF(factoredData.shape
!= Shape({(int)factoredData.offsets.size() - 1 /*=rows of CSR*/, E_->shape()[0]}),
"shape mismatch??");
// the CSR matrix is passed in pieces
auto weights = graph->constant({(int)factoredData.weights.size()},
inits::fromVector(factoredData.weights));
auto indices = graph->constant(
{(int)factoredData.indices.size()}, inits::fromVector(factoredData.indices), Type::uint32);
auto offsets = graph->constant(
{(int)factoredData.offsets.size()}, inits::fromVector(factoredData.offsets), Type::uint32);
// apply dropout
// We apply it to the weights, i.e. factors get dropped out separately, but always as entire
// vectors.
if(!inference_)
weights = dropout(weights, dropProb);
// perform the product
return csr_dot(factoredData.shape, weights, indices, offsets, E_);
// perform the product
return csr_dot(factoredData.shape, weights, indices, offsets, E_);
}
std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ {
auto graph = E_->graph();
int dimBatch = (int)subBatch->batchSize();
int dimEmb = E_->shape()[-1];
int dimWidth = (int)subBatch->batchWidth();
std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const
/*override final*/ {
auto graph = E_->graph();
int dimBatch = (int)subBatch->batchSize();
int dimEmb = E_->shape()[-1];
int dimWidth = (int)subBatch->batchWidth();
// factored embeddings:
// - regular:
// - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
// - factored:
// - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space)
// - each row of M contains the set of factors for one word => we want a CSR matrix
// - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
// - first compute x @ M on the CPU
// - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
// - shape (U, specifically) not actually needed here
// - foreach input x[i]
// - locate row M[i,*]
// - copy through its index values (std::vector<push_back>)
// - create a matching ones vector (we can keep growing)
// - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
// - CSR matrix product with E
// - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
// - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()).
// - weighting:
// - core factors' gradients are sums over all words that use the factors;
// - core factors' embeddings move very fast
// - words will need to make up for the move; rare words cannot
// - so, we multiply each factor with 1/refCount
// - core factors get weighed down a lot
// - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before
// - but forward pass weighs them down, so that all factors are in a similar numeric range
// - if it is required to be in a different range, the embeddings can still learn that, but more slowly
// factored embeddings:
// - regular:
// - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
// - factored:
// - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space)
// - each row of M contains the set of factors for one word => we want a CSR matrix
// - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
// - first compute x @ M on the CPU
// - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
// - shape (U, specifically) not actually needed here
// - foreach input x[i]
// - locate row M[i,*]
// - copy through its index values (std::vector<push_back>)
// - create a matching ones vector (we can keep growing)
// - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
// - CSR matrix product with E
// - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
// - double-check if all dimensions are specified. Probably not for transpose (which would
// be like csc_dot()).
// - weighting:
// - core factors' gradients are sums over all words that use the factors;
// - core factors' embeddings move very fast
// - words will need to make up for the move; rare words cannot
// - so, we multiply each factor with 1/refCount
// - core factors get weighed down a lot
// - no impact on gradients, as Adam makes up for it; embeddings still move fast just as
// before
// - but forward pass weighs them down, so that all factors are in a similar numeric range
// - if it is required to be in a different range, the embeddings can still learn that, but
// more slowly
auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
#if 1
auto batchMask = graph->constant({dimWidth, dimBatch, 1},
inits::fromVector(subBatch->mask()));
#else // @TODO: this is dead code now, get rid of it
// experimental: hide inline-fix source tokens from cross attention
auto batchMask = graph->constant({dimWidth, dimBatch, 1},
inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
auto batchMask = graph->constant({dimWidth, dimBatch, 1}, inits::fromVector(subBatch->mask()));
#else // @TODO: this is dead code now, get rid of it
// experimental: hide inline-fix source tokens from cross attention
auto batchMask
= graph->constant({dimWidth, dimBatch, 1},
inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
#endif
// give the graph inputs readable names for debugging and ONNX
batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
// give the graph inputs readable names for debugging and ONNX
batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
return std::make_tuple(batchEmbeddings, batchMask);
return std::make_tuple(batchEmbeddings, batchMask);
}
Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
if (factoredVocab_) {
Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
//selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
if(factoredVocab_) {
Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
// selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), {
// selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
return selectedEmbs;
}
else
} else
return applyIndices(toWordIndexVector(words), shape);
}
Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ {
ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
auto embIdxExpr = E_->graph()->indices(embIdx);
embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
// @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately)
if(!inference_)
selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
return selectedEmbs;
Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const
/*override final*/ {
ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
auto embIdxExpr = E_->graph()->indices(embIdx);
embIdxExpr->set_name("data_"
+ std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
// @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape()
// (test that separately)
if(!inference_)
selectedEmbs = dropout(
selectedEmbs, options_->get<float>("dropout", 0.0f), {selectedEmbs->shape()[-3], 1, 1});
return selectedEmbs;
}
// standard encoder word embeddings
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
auto options = New<Options>(
"dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
"dimEmb", opt<int>("dim-emb"),
"dropout", dropoutEmbeddings_,
"inference", inference_,
"prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb",
"fixed", embeddingFix_,
"vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
if(options_->hasAndNotEmpty("embedding-vectors")) {
auto options = New<Options>(
"dimVocab",
opt<std::vector<int>>("dim-vocabs")[batchIndex_],
"dimEmb",
opt<int>("dim-emb"),
"dropout",
dropoutEmbeddings_,
"inference",
inference_,
"prefix",
(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb"
: prefix_ + "_Wemb",
"fixed",
embeddingFix_,
"vocab",
opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
if(options_->hasAndNotEmpty("embedding-vectors")) {
auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
options->set(
"embFile", embFiles[batchIndex_],
"normalization", opt<bool>("embedding-normalization"));
}
return New<Embedding>(graph_, options);
"embFile", embFiles[batchIndex_], "normalization", opt<bool>("embedding-normalization"));
}
return New<Embedding>(graph_, options);
}
// ULR word embeddings
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const {
return New<ULREmbedding>(graph_, New<Options>(
"dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
"dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt
"dimUlrEmb", opt<int>("ulr-dim-emb"),
"dimEmb", opt<int>("dim-emb"),
"ulr-dropout", opt<float>("ulr-dropout"),
"dropout", dropoutEmbeddings_,
"inference", inference_,
"ulrTrainTransform", opt<bool>("ulr-trainable-transformation"),
"ulrQueryFile", opt<std::string>("ulr-query-vectors"),
"ulrKeysFile", opt<std::string>("ulr-keys-vectors")));
return New<ULREmbedding>(
graph_,
New<Options>("dimSrcVoc",
opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
"dimTgtVoc",
opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt
"dimUlrEmb",
opt<int>("ulr-dim-emb"),
"dimEmb",
opt<int>("dim-emb"),
"ulr-dropout",
opt<float>("ulr-dropout"),
"dropout",
dropoutEmbeddings_,
"inference",
inference_,
"ulrTrainTransform",
opt<bool>("ulr-trainable-transformation"),
"ulrQueryFile",
opt<std::string>("ulr-query-vectors"),
"ulrKeysFile",
opt<std::string>("ulr-keys-vectors")));
}
// get embedding layer for this encoder or decoder
// This is lazy mostly because the constructors of the consuming objects are not
// guaranteed presently to have access to their graph.
Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const {
if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
if (embeddingLayers_.size() <= batchIndex_)
embeddingLayers_.resize(batchIndex_ + 1);
if (ulr)
embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
if(embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
if(embeddingLayers_.size() <= batchIndex_)
embeddingLayers_.resize(batchIndex_ + 1);
if(ulr)
embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
else
embeddingLayers_[batchIndex_] = createEmbeddingLayer();
}
return embeddingLayers_[batchIndex_];
}
embeddingLayers_[batchIndex_] = createEmbeddingLayer();
}
return embeddingLayers_[batchIndex_];
}
} // namespace marian
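For reference, the multiRows()/csr_dot() path above encodes each token's multi-hot factor set as one CSR row and multiplies it against the embedding matrix E [U x D]. A small self-contained sketch of that product with toy dimensions (illustration only, not Marian's kernels):

#include <cstdint>
#include <iostream>
#include <vector>

// CSR pieces as in factoredData above: weights, column indices, row offsets.
struct CsrRows {
  std::vector<float>    weights;  // one weight per stored (row, col) entry
  std::vector<uint32_t> indices;  // factor (column) index per entry
  std::vector<uint32_t> offsets;  // row start offsets, size = numRows + 1
};

// y[r] = sum over the factors of token r of weight * E[factorIndex], i.e. a
// sparse-times-dense product like csr_dot(shape, weights, indices, offsets, E).
std::vector<float> csrTimesDense(const CsrRows& m,
                                 const std::vector<float>& E,  // U x D, row-major
                                 int D) {
  int numRows = (int)m.offsets.size() - 1;
  std::vector<float> out((size_t)numRows * D, 0.0f);
  for (int r = 0; r < numRows; ++r)
    for (uint32_t k = m.offsets[r]; k < m.offsets[r + 1]; ++k)
      for (int d = 0; d < D; ++d)
        out[(size_t)r * D + d] += m.weights[k] * E[(size_t)m.indices[k] * D + d];
  return out;
}

int main() {
  // Two factors (U = 2), embedding size D = 2.
  std::vector<float> E = {1, 0,   // factor 0
                          0, 1};  // factor 1
  // One token that carries both factors with weight 1.
  CsrRows m{{1, 1}, {0, 1}, {0, 2}};
  auto y = csrTimesDense(m, E, 2);
  std::cout << y[0] << " " << y[1] << "\n";  // 1 1
}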

View File

@@ -1,6 +1,6 @@
#pragma once
#include "marian.h"
#include "generic.h"
#include "marian.h"
namespace marian {
@@ -19,7 +19,8 @@ class Embedding : public LayerBase, public IEmbeddingLayer {
public:
Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);
std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final;
std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
Ptr<data::SubBatch> subBatch) const override final;
Expr apply(const Words& words, const Shape& shape) const override final;
@@ -27,17 +28,18 @@ public:
};
class ULREmbedding : public LayerBase, public IEmbeddingLayer {
std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
std::vector<Expr>
ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
bool inference_{false};
public:
ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
std::string name = "url_embed"; //opt<std::string>("prefix");
ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
std::string name = "url_embed"; // opt<std::string>("prefix");
int dimKeys = opt<int>("dimTgtVoc");
int dimQueries = opt<int>("dimSrcVoc");
int dimEmb = opt<int>("dimEmb");
int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
bool fixed = opt<bool>("fixed", false);
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
@@ -46,58 +48,61 @@ public:
std::string queryFile = opt<std::string>("ulrQueryFile");
std::string keyFile = opt<std::string>("ulrKeysFile");
bool trainTrans = opt<bool>("ulrTrainTransform", false);
if (!queryFile.empty() && !keyFile.empty()) {
if(!queryFile.empty() && !keyFile.empty()) {
initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
name = "ulr_query";
fixed = true;
auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed);
auto query_embed = graph_->param(name, {dimQueries, dimUlrEmb}, initFunc, fixed);
ulrEmbeddings_.push_back(query_embed);
// keys embeds
initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
name = "ulr_keys";
fixed = true;
auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed);
auto key_embed = graph_->param(name, {dimKeys, dimUlrEmb}, initFunc, fixed);
ulrEmbeddings_.push_back(key_embed);
// actual trainable embedding
initFunc = inits::glorotUniform();
name = "ulr_embed";
fixed = false;
auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim
auto ulr_embed
= graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed); // note the reverse dim
ulrEmbeddings_.push_back(ulr_embed);
// init trainable src embedding
name = "ulr_src_embed";
auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed);
auto ulr_src_embed = graph_->param(name, {dimQueries, dimEmb}, initFunc, fixed);
ulrEmbeddings_.push_back(ulr_src_embed);
// ulr transformation matrix
//initFunc = inits::eye(1.f); // identity matrix - is it ok to init with identity or shall we make this to the fixed case only
if (trainTrans) {
// initFunc = inits::eye(1.f); // identity matrix - is it ok to init with identity or shall
// we make this to the fixed case only
if(trainTrans) {
initFunc = inits::glorotUniform();
fixed = false;
}
else
{
initFunc = inits::eye(); // identity matrix
} else {
initFunc = inits::eye(); // identity matrix
fixed = true;
}
name = "ulr_transform";
auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed);
auto ulrTransform = graph_->param(name, {dimUlrEmb, dimUlrEmb}, initFunc, fixed);
ulrEmbeddings_.push_back(ulrTransform);
initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only
initFunc = inits::fromValue(
1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no
// universal embeddings - should be zero for top freq only
fixed = true;
name = "ulr_shared";
auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed);
auto share_embed = graph_->param(name, {dimQueries, 1}, initFunc, fixed);
ulrEmbeddings_.push_back(share_embed);
}
}
std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final {
auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
Ptr<data::SubBatch> subBatch) const override final {
auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
int dimBatch = (int)subBatch->batchSize();
int dimEmb = uniEmbed->shape()[-1];
int dimWords = (int)subBatch->batchWidth();
@@ -106,34 +111,42 @@ public:
// dim A = uni_embed_size*uni_embed_size
// dim Q: uni_embed_size * total_merged_vocab_size
// dim D = univ_tok_vocab * total_merged_vocab_size
// note all above can be precomputed and serialized if A is not trainable and during decoding (TBD)
// here we need to handle the mini-batch
// extract rows corresponding to Xs in this minibatch from Q
// note all above can be precomputed and serialized if A is not trainable and during decoding
// (TBD) here we need to handle the mini-batch extract rows corresponding to Xs in this
// minibatch from Q
auto embIdx = toWordIndexVector(subBatch->data());
auto queryEmbeddings = rows(queryEmbed, embIdx);
auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]);
qt = qt/sqrtDim; // normalize according to embed size to avoid dot product growing large in magnitude with larger embedding sizes
auto z = dot(qt, keyEmbed, false, true); // query-key similarity
auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
auto qt = dot(queryEmbeddings,
ulrTransform,
false,
false); // A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]);
qt = qt / sqrtDim; // normalize according to embed size to avoid dot product growing large in
// magnitude with larger embedding sizes
auto z = dot(qt, keyEmbed, false, true); // query-key similarity
float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout
if(!inference_)
z = dropout(z, dropProb);
float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
float tau
= this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
// temperature in softmax is to control randomness of predictions
// high temperature Softmax outputs are more close to each other
// low temperatures the softmax become more similar to "hardmax"
auto weights = softmax(z / tau); // assume default is dim=-1, what about temperature? - scalar ??
auto weights
= softmax(z / tau); // assume default is dim=-1, what about temperature? - scalar ??
auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb });
auto chosenEmbeddings_mix
= srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
auto batchEmbeddings = reshape(chosenEmbeddings_mix, {dimWords, dimBatch, dimEmb});
auto graph = ulrEmbeddings_.front()->graph();
auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
inits::fromVector(subBatch->mask()));
auto batchMask = graph->constant({dimWords, dimBatch, 1}, inits::fromVector(subBatch->mask()));
if(!inference_)
batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1});
batchEmbeddings = dropout(batchEmbeddings,
options_->get<float>("dropout-embeddings", 0.0f),
{batchEmbeddings->shape()[-3], 1, 1});
return std::make_tuple(batchEmbeddings, batchMask);
}
@@ -142,9 +155,10 @@ public:
}
Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
embIdx; shape;
ABORT("not implemented"); // @TODO: implement me
embIdx;
shape;
ABORT("not implemented"); // @TODO: implement me
}
};
}
} // namespace marian
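ULREmbedding::apply() above mixes a softmax-weighted universal embedding into a trainable source embedding, with the temperature tau controlling how peaked the weights are and alpha gating how much universal signal each word receives. A numeric sketch of that combination for a single token, with made-up toy values rather than Marian code:

#include <cmath>
#include <iostream>
#include <vector>

// Per-token ULR combination, as described above:
//   qt  = q * A / sqrt(dimUlrEmb)
//   z_k = qt . K[k]            (query-key similarity)
//   w   = softmax(z / tau)     (temperature-controlled weights)
//   e   = sum_k w[k] * U[k]    (weighted universal embedding)
//   out = src + alpha * e      (mix with the trainable source embedding)
int main() {
  const int dimUlr = 2, dimEmb = 2, numKeys = 2;
  std::vector<float> q   = {1.0f, 0.0f};             // query row of Q
  std::vector<float> A   = {1, 0, 0, 1};             // dimUlr x dimUlr transform (identity here)
  std::vector<float> K   = {1, 0, 0, 1};             // numKeys x dimUlr key embeddings
  std::vector<float> U   = {0.5f, 0.5f, 2.f, 2.f};   // numKeys x dimEmb universal embeddings
  std::vector<float> src = {0.1f, 0.1f};             // trainable source embedding
  float alpha = 1.0f, tau = 1.0f;

  // qt = q * A, scaled by sqrt(dimUlr) to keep dot products in range
  std::vector<float> qt(dimUlr, 0.f);
  for (int j = 0; j < dimUlr; ++j)
    for (int i = 0; i < dimUlr; ++i)
      qt[j] += q[i] * A[i * dimUlr + j];
  for (auto& v : qt) v /= std::sqrt((float)dimUlr);

  // z[k] = qt . K[k], then temperature softmax
  std::vector<float> w(numKeys, 0.f);
  float norm = 0.f;
  for (int k = 0; k < numKeys; ++k) {
    float z = 0.f;
    for (int j = 0; j < dimUlr; ++j) z += qt[j] * K[k * dimUlr + j];
    w[k] = std::exp(z / tau);
    norm += w[k];
  }
  for (auto& v : w) v /= norm;

  // out = src + alpha * (w . U)
  std::vector<float> out(src);
  for (int k = 0; k < numKeys; ++k)
    for (int d = 0; d < dimEmb; ++d)
      out[d] += alpha * w[k] * U[k * dimEmb + d];

  std::cout << out[0] << " " << out[1] << "\n";
}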

View File

@@ -1,13 +1,10 @@
#include "marian.h"
#include "layers/generic.h"
#include "layers/constructors.h"
#include "layers/loss.h"
#include "data/factored_vocab.h"
#include "models/states.h" // for EncoderState
#include "layers/constructors.h"
#include "layers/generic.h"
#include "layers/loss.h"
#include "layers/lsh.h"
#include "models/states.h" // for EncoderState
namespace marian {
} // namespace marian
namespace marian {} // namespace marian

View File

@@ -5,12 +5,14 @@
#include "data/shortlist.h"
#include "layers/factory.h"
namespace marian { namespace mlp {
/**
* @brief Activation functions
*/
enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
}}
namespace marian {
namespace mlp {
/**
* @brief Activation functions
*/
enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
} // namespace mlp
} // namespace marian
namespace marian {
@@ -23,8 +25,7 @@ protected:
Ptr<Options> options_;
public:
LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: graph_(graph), options_(options) {}
LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}
template <typename T>
T opt(const std::string key) const {
@@ -42,7 +43,7 @@ struct IUnaryLayer {
virtual ~IUnaryLayer() {}
virtual Expr apply(Expr) = 0;
virtual Expr apply(const std::vector<Expr>& es) {
ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
return apply(es.front());
}
};
@@ -54,7 +55,8 @@ struct IHasShortList {
// Embedding from corpus sub-batch to (emb, mask)
struct IEmbeddingLayer {
virtual std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const = 0;
virtual std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
Ptr<data::SubBatch> subBatch) const = 0;
virtual Expr apply(const Words& embIdx, const Shape& shape) const = 0;
@@ -63,28 +65,29 @@ struct IEmbeddingLayer {
virtual ~IEmbeddingLayer() {}
};
// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream index)
// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream
// index)
class EncoderDecoderLayerBase : public LayerBase {
protected:
const std::string prefix_;
const bool embeddingFix_;
const float dropoutEmbeddings_; // this drops out full embedding vectors
const float dropoutEmbeddings_; // this drops out full embedding vectors
const bool inference_;
const size_t batchIndex_;
mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created)
mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created)
EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph,
Ptr<Options> options,
const std::string& prefix,
EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph,
Ptr<Options> options,
const std::string& prefix,
size_t batchIndex,
float dropoutEmbeddings,
bool embeddingFix) :
LayerBase(graph, options),
prefix_(options->get<std::string>("prefix", prefix)),
embeddingFix_(embeddingFix),
dropoutEmbeddings_(dropoutEmbeddings),
inference_(options->get<bool>("inference", false)),
batchIndex_(options->get<size_t>("index", batchIndex)) {}
bool embeddingFix)
: LayerBase(graph, options),
prefix_(options->get<std::string>("prefix", prefix)),
embeddingFix_(embeddingFix),
dropoutEmbeddings_(dropoutEmbeddings),
inference_(options->get<bool>("inference", false)),
batchIndex_(options->get<size_t>("index", batchIndex)) {}
virtual ~EncoderDecoderLayerBase() {}
@@ -101,8 +104,7 @@ namespace mlp {
class Dense : public LayerBase, public IUnaryLayer {
public:
Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: LayerBase(graph, options) {}
Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {}
Expr apply(const std::vector<Expr>& inputs) override {
ABORT_IF(inputs.empty(), "No inputs");
@@ -124,21 +126,17 @@ public:
if(inputs.size() > 1)
num = std::to_string(i);
Expr W = g->param(
name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform());
Expr W = g->param(name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform());
Expr b = g->param(name + "_b" + num, {1, dim}, inits::zeros());
if(useLayerNorm) {
if(useNematusNorm) {
auto ln_s = g->param(
name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f));
auto ln_s = g->param(name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f));
auto ln_b = g->param(name + "_ln_b" + num, {1, dim}, inits::zeros());
outputs.push_back(
layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS));
outputs.push_back(layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS));
} else {
auto gamma = g->param(
name + "_gamma" + num, {1, dim}, inits::fromValue(1.0));
auto gamma = g->param(name + "_gamma" + num, {1, dim}, inits::fromValue(1.0));
outputs.push_back(layerNorm(dot(in, W), gamma, b));
}
@@ -165,39 +163,35 @@ public:
Expr apply(Expr input) override { return apply(std::vector<Expr>({input})); }
};
} // namespace mlp
} // namespace mlp
// --- a few layers with built-in parameters created on the fly, without proper object
// @TODO: change to a proper layer object
// like affine() but with built-in parameters, activation, and dropout
static inline
Expr denseInline(Expr x,
std::string prefix,
std::string suffix,
int outDim,
Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
const std::function<Expr(Expr)>& actFn = nullptr,
float dropProb = 0.0f)
{
static inline Expr denseInline(Expr x,
std::string prefix,
std::string suffix,
int outDim,
Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
const std::function<Expr(Expr)>& actFn = nullptr,
float dropProb = 0.0f) {
auto graph = x->graph();
auto W = graph->param(prefix + "_W" + suffix, { x->shape()[-1], outDim }, inits::glorotUniform());
auto b = graph->param(prefix + "_b" + suffix, { 1, outDim }, inits::zeros());
auto W = graph->param(prefix + "_W" + suffix, {x->shape()[-1], outDim}, inits::glorotUniform());
auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros());
x = affine(x, W, b);
if (actFn)
if(actFn)
x = actFn(x);
x = dropout(x, dropProb); // @TODO: check for inference?
x = dropout(x, dropProb); // @TODO: check for inference?
return x;
}
static inline
Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
int dimModel = x->shape()[-1];
auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, { 1, dimModel }, inits::ones());
auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, { 1, dimModel }, inits::zeros());
auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, {1, dimModel}, inits::ones());
auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, {1, dimModel}, inits::zeros());
return marian::layerNorm(x, scale, bias, 1e-6f);
}
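The layerNorm() helper above wraps standard layer normalization over the last dimension with a learned scale and bias. A minimal sketch of the per-row arithmetic it corresponds to (assuming the usual mean/variance normalization with epsilon inside the square root; not Marian's kernel):

#include <cmath>
#include <iostream>
#include <vector>

// Normalize one row over its last dimension, then apply scale and bias:
// y[i] = (x[i] - mean) / sqrt(var + eps) * scale[i] + bias[i]
std::vector<float> layerNormRow(const std::vector<float>& x,
                                const std::vector<float>& scale,
                                const std::vector<float>& bias,
                                float eps = 1e-6f) {
  float mean = 0.f;
  for (float v : x) mean += v;
  mean /= x.size();
  float var = 0.f;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= x.size();
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    y[i] = (x[i] - mean) / std::sqrt(var + eps) * scale[i] + bias[i];
  return y;
}

int main() {
  auto y = layerNormRow({1.f, 2.f, 3.f}, {1.f, 1.f, 1.f}, {0.f, 0.f, 0.f});
  std::cout << y[0] << " " << y[1] << " " << y[2] << "\n";  // ~ -1.22 0 1.22
}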

View File

@@ -1,212 +1,250 @@
#include "logits.h"
#include "loss.h"
#include "data/factored_vocab.h"
#include "rnn/types.h" // for State::select()
#include "loss.h"
#include "rnn/types.h" // for State::select()
namespace marian {
Logits::Logits(Expr logits) : Logits(New<RationalLoss>(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count)
Logits::Logits(Expr logits)
: Logits(New<RationalLoss>(logits, nullptr)) {
} // single-output constructor from Expr only (RationalLoss has no count)
Ptr<ExpressionGraph> Logits::graph() const {
ABORT_IF(logits_.empty(), "Empty logits object??");
return logits_.front()->loss()->graph();
Ptr<ExpressionGraph> Logits::graph() const {
ABORT_IF(logits_.empty(), "Empty logits object??");
return logits_.front()->loss()->graph();
}
// This function assumes that the object holds one or more factor logits.
// It applies the supplied loss function to each, and then returns the aggregate loss over all
// factors.
Expr Logits::applyLossFunction(
const Words& labels,
const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const {
LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size());
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
auto firstLogits = logits_.front()->loss();
ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
"Labels not matching logits shape ({} != {}, {})??",
labels.size() * firstLogits->shape()[-1],
firstLogits->shape().elements(),
firstLogits->shape());
// base case (no factors)
if(!factoredVocab_) {
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
return lossFn(firstLogits, indices(toWordIndexVector(labels)));
}
// This function assumes that the object holds one or more factor logits.
// It applies the supplied loss function to each, and then returns the aggregate loss over all factors.
Expr Logits::applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/, Expr/*indices*/)>& lossFn) const {
LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size());
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
auto numGroups = factoredVocab_->getNumGroups();
auto firstLogits = logits_.front()->loss();
ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
"Labels not matching logits shape ({} != {}, {})??",
labels.size() * firstLogits->shape()[-1],
firstLogits->shape().elements(),
firstLogits->shape());
// split labels into individual factor labels
auto allMaskedFactoredLabels
= factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened]
// base case (no factors)
if (!factoredVocab_) {
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
return lossFn(firstLogits, indices(toWordIndexVector(labels)));
// Expr indices = this->indices(toWordIndexVector(labels));
// accumulate all CEs for all words that have the factor
// Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors.
Expr loss;
for(size_t g = 0; g < numGroups; g++) {
if(!logits_[g])
continue; // empty factor --@TODO: use an array of indices of non-empty logits_[]
const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask)
auto factorIndices = indices(
maskedFactoredLabels
.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
auto factorMask
= constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with
// 0 for labels that don't have this factor
auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
// For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask
// it out next.
auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
if(loss)
factorLoss = cast(factorLoss, loss->value_type());
factorLoss
= factorLoss
* cast(
reshape(factorMask, factorLoss->shape()),
factorLoss->value_type()); // mask out factor for words that do not have that factor
loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1]
}
return loss;
}
// This function assumes this object holds a single factor that represents a rational loss (with
// count).
// Ptr<RationalLoss> Logits::getRationalLoss() const {
// ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on
// multi-factor outputs"); ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational
// loss without count"); return logits_.front();
//}
// get logits for one factor group
// For groupIndex == 0, the function also requires the shortlist if there is one.
Expr Logits::getFactoredLogits(size_t groupIndex,
Ptr<data::Shortlist> shortlist /*= nullptr*/,
const std::vector<IndexType>& hypIndices /*= {}*/,
size_t beamSize /*= 0*/) const {
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
// normalize for decoding:
// - all secondary factors: subtract their max
// - lemma: add all maxes of applicable factors
if(groupIndex > 0) {
sel = sel - max(sel, -1);
} else {
auto numGroups = getNumFactorGroups();
for(size_t g = 1; g < numGroups; g++) {
auto factorMaxima = max(logits_[g]->loss(),
-1); // we cast since loss is likely ce-loss which has type float32
auto factorMasks = constant(
getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>()));
sel = sel
+ cast(factorMaxima, sel->value_type())
* cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor
// get multiplied with 0
}
auto numGroups = factoredVocab_->getNumGroups();
// split labels into individual factor labels
auto allMaskedFactoredLabels = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened]
//Expr indices = this->indices(toWordIndexVector(labels));
// accumulate all CEs for all words that have the factor
// Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors.
Expr loss;
for (size_t g = 0; g < numGroups; g++) {
if (!logits_[g])
continue; // empty factor --@TODO: use an array of indices of non-empty logits_[]
const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask)
auto factorIndices = indices (maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
// For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
if(loss)
factorLoss = cast(factorLoss, loss->value_type());
factorLoss = factorLoss * cast(reshape(factorMask, factorLoss->shape()), factorLoss->value_type()); // mask out factor for words that do not have that factor
loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1]
}
return loss;
}
// This function assumes this object holds a single factor that represents a rational loss (with count).
//Ptr<RationalLoss> Logits::getRationalLoss() const {
// ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on multi-factor outputs");
// ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational loss without count");
// return logits_.front();
//}
// if selIdx are given, then we must reshuffle accordingly
if(!hypIndices.empty()) // use the same function that shuffles decoder state
sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
// get logits for one factor group
// For groupIndex == 0, the function also requires the shortlist if there is one.
Expr Logits::getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist /*= nullptr*/, const std::vector<IndexType>& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const {
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
return sel;
}
// normalize for decoding:
// - all secondary factors: subtract their max
// - lemma: add all maxes of applicable factors
if (groupIndex > 0) {
sel = sel - max(sel, -1);
}
else {
auto numGroups = getNumFactorGroups();
for (size_t g = 1; g < numGroups; g++) {
auto factorMaxima = max(logits_[g]->loss(), -1); // we cast since loss is likely ce-loss which has type float32
auto factorMasks = constant(getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>()));
sel = sel + cast(factorMaxima, sel->value_type()) * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor get multiplied with 0
}
}
// used for breakDown() only
// Index is flattened
Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const {
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
return logits_[groupIndex]->loss()->val();
}
// if selIdx are given, then we must reshuffle accordingly
if (!hypIndices.empty()) // use the same function that shuffles decoder state
sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
return sel;
// This function assumes that the object holds one or more factor logits, which are summed up
// into output-vocab logits according to the factored model (with correct normalization of factors).
// This is infeasible for realistic factor sets, and therefore only implemented for 1 factor.
// @TODO: remove altogether
Expr Logits::getLogits() const {
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
if(!factoredVocab_) {
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
return getFactoredLogits(0);
}
// used for breakDown() only
// Index is flattened
Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const {
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
return logits_[groupIndex]->loss()->val();
}
// This function assumes that the object holds one or more factor logits, which are summed up
// into output-vocab logits according to the factored model (with correct normalization of factors).
// This is infeasible for realistic factor sets, and therefore only implemented for 1 factor.
// @TODO: remove altogether
Expr Logits::getLogits() const {
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
if (!factoredVocab_) {
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
return getFactoredLogits(0);
}
#ifdef FACTOR_FULL_EXPANSION
// compute normalized factor log probs
std::vector<Expr> logProbs(logits_.size());
for (size_t g = 0; g < logits_.size(); g++)
logProbs[g] = logsoftmax(logits_[g]->loss());
auto y = concatenate(logProbs, /*axis=*/ -1);
// compute normalized factor log probs
std::vector<Expr> logProbs(logits_.size());
for(size_t g = 0; g < logits_.size(); g++)
logProbs[g] = logsoftmax(logits_[g]->loss());
auto y = concatenate(logProbs, /*axis=*/-1);
// sum up the unit logits across factors for each target word
auto graph = y->graph();
auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U]
y = dot_csr(
y, // [B x U]
factorMatrix.shape,
graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)),
graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32),
graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32),
/*transB=*/ true); // -> [B x V]
// sum up the unit logits across factors for each target word
auto graph = y->graph();
auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U]
y = dot_csr(
y, // [B x U]
factorMatrix.shape,
graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)),
graph->constant({(int)factorMatrix.indices.size()},
inits::fromVector(factorMatrix.indices),
Type::uint32),
graph->constant({(int)factorMatrix.offsets.size()},
inits::fromVector(factorMatrix.offsets),
Type::uint32),
/*transB=*/true); // -> [B x V]
// mask out gaps
auto gapLogMask = factoredVocab_->getGapLogMask(); // [V]
y = y + graph->constant({ (int)gapLogMask.size() }, inits::fromVector(gapLogMask));
// mask out gaps
auto gapLogMask = factoredVocab_->getGapLogMask(); // [V]
y = y + graph->constant({(int)gapLogMask.size()}, inits::fromVector(gapLogMask));
return y;
return y;
#else
ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible
ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible
#endif
}
}
void Logits::MaskedFactorIndices::push_back(size_t factorIndex) {
bool isValid = FactoredVocab::isFactorValid(factorIndex);
indices.push_back(isValid ? (WordIndex)factorIndex : 0);
masks.push_back((float)isValid);
}
void Logits::MaskedFactorIndices::push_back(size_t factorIndex) {
bool isValid = FactoredVocab::isFactorValid(factorIndex);
indices.push_back(isValid ? (WordIndex)factorIndex : 0);
masks.push_back((float)isValid);
}
std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words) const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices
if (!factoredVocab_) {
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
return {MaskedFactorIndices(words)};
}
auto numGroups = factoredVocab_->getNumGroups();
std::vector<MaskedFactorIndices> res(numGroups);
for (size_t g = 0; g < numGroups; g++) {
auto& resg = res[g];
resg.reserve(words.size());
for (const auto& word : words)
resg.push_back(factoredVocab_->getFactor(word, g));
}
return res;
std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words)
const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices
if(!factoredVocab_) {
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
return {MaskedFactorIndices(words)};
}
//// use first factor of each word to determine whether it has a specific factor
//std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0
// std::vector<float> res;
// res.reserve(words.size());
// for (const auto& word : words) {
// auto lemma = factoredVocab_->getFactor(word, 0);
// res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
// }
// return res;
//}
// return a vector of 1 or 0 indicating for each lemma whether it has a specific factor
// If 'indices' is given, then return the masks for the indices; otherwise for all lemmas
std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size();
std::vector<float> res;
res.reserve(n);
// @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this into FactoredVocab
for (size_t i = 0; i < n; i++) {
auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first);
res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
}
return res;
auto numGroups = factoredVocab_->getNumGroups();
std::vector<MaskedFactorIndices> res(numGroups);
for(size_t g = 0; g < numGroups; g++) {
auto& resg = res[g];
resg.reserve(words.size());
for(const auto& word : words)
resg.push_back(factoredVocab_->getFactor(word, g));
}
return res;
}
Logits Logits::applyUnaryFunction(const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
std::vector<Ptr<RationalLoss>> newLogits;
for (const auto& l : logits_)
newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count()));
return Logits(std::move(newLogits), factoredVocab_);
}
//// use first factor of each word to determine whether it has a specific factor
// std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0
// for words that do have this factor; else 0
// std::vector<float> res;
// res.reserve(words.size());
// for (const auto& word : words) {
// auto lemma = factoredVocab_->getFactor(word, 0);
// res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
// }
// return res;
//}
Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const {
std::vector<Ptr<RationalLoss>> newLogits;
bool first = true;
for (const auto& l : logits_) {
newLogits.emplace_back(New<RationalLoss>((first?f1:fother)(l->loss()), l->count())); // f1 for first, fother for all others
first = false;
}
return Logits(std::move(newLogits), factoredVocab_);
// return a vector of 1 or 0 indicating for each lemma whether it has a specific factor
// If 'indices' is given, then return the masks for the indices; otherwise for all lemmas
std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices)
const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
size_t n
= indices.empty()
? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first)
: indices.size();
std::vector<float> res;
res.reserve(n);
// @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this
// into FactoredVocab
for(size_t i = 0; i < n; i++) {
auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first);
res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
}
return res;
}
// @TODO: code dup with above; we can merge it into applyToRationalLoss()
Logits Logits::withCounts(const Expr& count) const { // create new Logits with 'count' implanted into all logits_
std::vector<Ptr<RationalLoss>> newLogits;
for (const auto& l : logits_)
newLogits.emplace_back(New<RationalLoss>(l->loss(), count));
return Logits(std::move(newLogits), factoredVocab_);
Logits Logits::applyUnaryFunction(
const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
std::vector<Ptr<RationalLoss>> newLogits;
for(const auto& l : logits_)
newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count()));
return Logits(std::move(newLogits), factoredVocab_);
}
Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1,
const std::function<Expr(Expr)>& fother) const {
std::vector<Ptr<RationalLoss>> newLogits;
bool first = true;
for(const auto& l : logits_) {
newLogits.emplace_back(New<RationalLoss>((first ? f1 : fother)(l->loss()),
l->count())); // f1 for first, fother for all others
first = false;
}
}
return Logits(std::move(newLogits), factoredVocab_);
}
// @TODO: code dup with above; we can merge it into applyToRationalLoss()
Logits Logits::withCounts(
const Expr& count) const { // create new Logits with 'count' implanted into all logits_
std::vector<Ptr<RationalLoss>> newLogits;
for(const auto& l : logits_)
newLogits.emplace_back(New<RationalLoss>(l->loss(), count));
return Logits(std::move(newLogits), factoredVocab_);
}
} // namespace marian
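Logits::applyLossFunction() above accumulates one masked loss term per factor group: positions whose word lacks a given factor get index 0 and a mask of 0, so they contribute nothing to that group's term. A small standalone sketch of that aggregation with toy numbers and a hand-rolled cross-entropy (not Marian's implementation):

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// Cross-entropy of one position: -log softmax(logits)[label].
float crossEntropy(const std::vector<float>& logits, int label) {
  float maxv = logits[0];
  for (float v : logits) maxv = std::max(maxv, v);
  float sum = 0.f;
  for (float v : logits) sum += std::exp(v - maxv);
  return -(logits[label] - maxv - std::log(sum));
}

int main() {
  // 2 positions, 2 factor groups; group 1 applies only to position 0.
  std::vector<std::vector<std::vector<float>>> groupLogits = {
      {{2.f, 0.f}, {0.f, 2.f}},   // group 0: logits per position
      {{1.f, 1.f}, {0.f, 0.f}}};  // group 1
  std::vector<std::vector<int>>   factorIndices = {{0, 1}, {1, 0}};
  std::vector<std::vector<float>> factorMasks   = {{1, 1}, {1, 0}};

  // Per-position aggregate loss: sum of masked per-group CE terms.
  std::vector<float> loss(2, 0.f);
  for (size_t g = 0; g < groupLogits.size(); ++g)
    for (size_t p = 0; p < loss.size(); ++p)
      loss[p] += factorMasks[g][p]
                 * crossEntropy(groupLogits[g][p], factorIndices[g][p]);

  std::cout << loss[0] << " " << loss[1] << "\n";
}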

View File

@@ -1,8 +1,8 @@
#pragma once
#include "marian.h"
#include "data/shortlist.h"
#include "generic.h"
#include "marian.h"
namespace marian {
@@ -16,46 +16,77 @@ class FactoredVocab;
class RationalLoss;
class Logits {
public:
Logits() {}
explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
logits_.push_back(logits);
}
explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
Logits() {}
explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
logits_.push_back(logits);
}
explicit Logits(
Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
Logits(std::vector<Ptr<RationalLoss>>&& logits,
Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
: logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {}
Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
//Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const;
Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to first and fother to all other values
Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
Expr getFactoredLogits(
size_t groupIndex,
Ptr<data::Shortlist> shortlist = nullptr,
const std::vector<IndexType>& hypIndices = {},
size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
// Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
Expr applyLossFunction(
const Words& labels,
const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const;
Logits applyUnaryFunction(
const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1,
const std::function<Expr(Expr)>& fother)
const; // clone this but apply f1 to first and fother to all other values
struct MaskedFactorIndices {
std::vector<WordIndex> indices; // factor index, or 0 if masked
std::vector<float> masks;
void reserve(size_t n) { indices.reserve(n); masks.reserve(n); }
void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries
MaskedFactorIndices() {}
MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case
};
std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices
Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
size_t getNumFactorGroups() const { return logits_.size(); }
bool empty() const { return logits_.empty(); }
Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_
struct MaskedFactorIndices {
std::vector<WordIndex> indices; // factor index, or 0 if masked
std::vector<float> masks;
void reserve(size_t n) {
indices.reserve(n);
masks.reserve(n);
}
void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0
// for invalid entries
MaskedFactorIndices() {}
MaskedFactorIndices(const Words& words) {
indices = toWordIndexVector(words);
} // we can leave masks uninitialized for this special use case
};
std::vector<MaskedFactorIndices> factorizeWords(
const Words& words) const; // breaks encoded Word into individual factor indices
Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
size_t getNumFactorGroups() const { return logits_.size(); }
bool empty() const { return logits_.empty(); }
Logits withCounts(
const Expr& count) const; // create new Logits with 'count' implanted into all logits_
private:
// helper functions
Ptr<ExpressionGraph> graph() const;
Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
template<typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector
Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type
std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const;
// helper functions
Ptr<ExpressionGraph> graph() const;
Expr constant(const Shape& shape, const std::vector<float>& data) const {
return graph()->constant(shape, inits::fromVector(data));
}
Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const {
return graph()->constant(shape, inits::fromVector(data));
}
template <typename T>
Expr constant(const std::vector<T>& data) const {
return constant(Shape{(int)data.size()}, data);
} // same as constant() but assuming vector
Expr indices(const std::vector<uint32_t>& data) const {
return graph()->indices(data);
} // actually the same as constant(data) for this data type
std::vector<float> getFactorMasks(size_t factorGroup,
const std::vector<WordIndex>& indices) const;
private:
// members
// @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr
std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
Ptr<FactoredVocab> factoredVocab_;
// members
// @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just
// by the Expr
std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
Ptr<FactoredVocab> factoredVocab_;
};
// Unary function that returns a Logits object
@ -65,12 +96,11 @@ private:
struct IUnaryLogitLayer : public IUnaryLayer {
virtual Logits applyAsLogits(Expr) = 0;
virtual Logits applyAsLogits(const std::vector<Expr>& es) {
ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
return applyAsLogits(es.front());
}
virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); }
virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); }
};
}
} // namespace marian

View File

@ -13,26 +13,30 @@ Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference) {
bool wordScores = options->get<bool>("word-scores", false);
return New<RescorerLoss>(wordScores);
} else if(unlikelihood) {
ABORT_IF(!options->hasAndNotEmpty("data-weighting")
&& options->get<std::string>("data-weighting-type") != "word",
"Unlikelihood loss training requires error annotation in form of per-target-label scores");
return New<SequenceUnlikelihoodLoss>(smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood loss depending on values given for data-weighting
} else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. E.g. what about ce-sum?
ABORT_IF(
!options->hasAndNotEmpty("data-weighting")
&& options->get<std::string>("data-weighting-type") != "word",
"Unlikelihood loss training requires error annotation in form of per-target-label scores");
return New<SequenceUnlikelihoodLoss>(
smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood loss depending on
// values given for data-weighting
} else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones.
// E.g. what about ce-sum?
return New<CrossEntropyLoss>(smoothing, factorWeight);
}
}
// see loss.h for detailed explanations of each class
Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) {
std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
if(multiLossType == "sum") // sum of sums
return New<SumMultiRationalLoss>();
else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale
return New<ScaledMultiRationalLoss>();
else if(multiLossType == "mean") // sum of means
return New<MeanMultiRationalLoss>();
else
ABORT("Unknown multi-loss-type {}", multiLossType);
std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
if(multiLossType == "sum") // sum of sums
return New<SumMultiRationalLoss>();
else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale
return New<ScaledMultiRationalLoss>();
else if(multiLossType == "mean") // sum of means
return New<MeanMultiRationalLoss>();
else
ABORT("Unknown multi-loss-type {}", multiLossType);
}
} // namespace marian

View File

@ -1,8 +1,8 @@
#pragma once
#include "graph/expression_operators.h"
#include "layers/logits.h" // for Logits (Frank's factor hack)
#include "data/types.h"
#include "graph/expression_operators.h"
#include "layers/logits.h" // for Logits (Frank's factor hack)
namespace marian {
@ -22,21 +22,18 @@ namespace marian {
*/
class RationalLoss {
protected:
Expr loss_; // numerator
Expr count_; // denominator
Expr loss_; // numerator
Expr count_; // denominator
RationalLoss() = default; // protected
RationalLoss() = default; // protected
public:
RationalLoss(Expr loss, Expr count)
: loss_(loss), count_(count) {}
RationalLoss(Expr loss, Expr count) : loss_(loss), count_(count) {}
RationalLoss(Expr loss, float count)
: loss_(loss),
count_(constant_like(loss, inits::fromValue(count))) {}
: loss_(loss), count_(constant_like(loss, inits::fromValue(count))) {}
RationalLoss(const RationalLoss& other)
: loss_(other.loss_), count_(other.count_) {}
RationalLoss(const RationalLoss& other) : loss_(other.loss_), count_(other.count_) {}
virtual ~RationalLoss() = default;
@ -50,7 +47,7 @@ public:
}
template <typename T>
T loss() const { // this will fail if loss is not a single value
T loss() const { // this will fail if loss is not a single value
ABORT_IF(!loss_, "Loss has not been defined");
return loss_->val()->scalar<T>();
}
@ -65,7 +62,7 @@ public:
}
template <typename T>
T count() const { // this will fail if loss is not a single value
T count() const { // this will fail if loss is not a single value
ABORT_IF(!count_, "Labels have not been defined");
return count_->val()->scalar<T>();
}
@ -85,21 +82,21 @@ public:
* RationalLoss object.
*/
struct StaticLoss {
float loss; // numerator
float count; // denominator
float loss; // numerator
float count; // denominator
StaticLoss() : loss(0.f), count(0.f) {}
StaticLoss(const RationalLoss& dynamic)
: loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}
: loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}
StaticLoss operator +(const StaticLoss& other) const {
StaticLoss operator+(const StaticLoss& other) const {
StaticLoss res(*this);
res += other;
return res;
}
StaticLoss& operator +=(const StaticLoss& other) {
StaticLoss& operator+=(const StaticLoss& other) {
loss = loss + other.loss;
count = count + other.count;
return *this;
@ -139,32 +136,21 @@ protected:
public:
MultiRationalLoss() : RationalLoss() {}
MultiRationalLoss(const RationalLoss& rl) : RationalLoss() {
push_back(rl);
}
MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { push_back(rl); }
virtual void push_back(const RationalLoss& current) {
loss_ = accumulateLoss(current);
count_ = accumulateCount(current);
loss_ = accumulateLoss(current);
count_ = accumulateCount(current);
partialLosses_.push_back(current);
}
const RationalLoss& operator[](size_t i) {
return partialLosses_[i];
}
const RationalLoss& operator[](size_t i) { return partialLosses_[i]; }
auto begin() -> decltype(partialLosses_.begin()) const {
return partialLosses_.begin();
}
auto begin() -> decltype(partialLosses_.begin()) const { return partialLosses_.begin(); }
auto end() -> decltype(partialLosses_.end()) const {
return partialLosses_.end();
}
size_t size() const {
return partialLosses_.size();
}
auto end() -> decltype(partialLosses_.end()) const { return partialLosses_.end(); }
size_t size() const { return partialLosses_.size(); }
};
/**
@ -212,17 +198,19 @@ private:
virtual Expr accumulateLoss(const RationalLoss& current) override {
if(loss_) {
const auto& first = partialLosses_.front();
return loss_ + current.loss() * first.count() / current.count(); // scale up/down to match scale of first loss
return loss_
+ current.loss() * first.count()
/ current.count(); // scale up/down to match scale of first loss
} else {
return current.loss(); // first reference loss, keeps to scale with this one
return current.loss(); // first reference loss, keeps to scale with this one
}
}
virtual Expr accumulateCount(const RationalLoss& current) override {
if(count_) {
return count_; // Keep first label count // or: count_ + first.count() / current.count();
return count_; // Keep first label count // or: count_ + first.count() / current.count();
} else {
return current.count(); // This is the first loss
return current.count(); // This is the first loss
}
}
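// In effect (a sketch of the scaled accumulation above): with partial losses L_i and label counts
// N_i, the aggregate is
//   loss  = L_1 + sum_{i>1} L_i * N_1 / N_i,   count = N_1
// i.e. every subsequent partial loss is rescaled to the label count of the first (reference) loss.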
@ -253,9 +241,10 @@ private:
virtual Expr accumulateCount(const RationalLoss& current) override {
if(count_)
return count_; // keep the existing '1'
return count_; // keep the existing '1'
else
return current.count()->graph()->ones({1}, current.loss()->value_type()); // just '1' as labels are factored into loss_
return current.count()->graph()->ones(
{1}, current.loss()->value_type()); // just '1' as labels are factored into loss_
}
public:
@ -279,18 +268,21 @@ class LabelwiseLoss {
protected:
std::vector<int> axes_;
virtual Expr compute(Logits logits, const Words& labels,
Expr mask = nullptr, Expr labelWeights = nullptr) = 0;
virtual Expr compute(Logits logits,
const Words& labels,
Expr mask = nullptr,
Expr labelWeights = nullptr)
= 0;
// label counts are available, reduce together with loss to obtain counts
RationalLoss reduce(Expr loss, Expr labels) {
ABORT_IF(!loss, "Loss has not been computed");
ABORT_IF(!labels, "Labels have not been computed");
Expr lossSum = cast(loss, Type::float32); // accumulate in float32
Expr labelsSum = cast(labels, Type::float32); // accumulate in float32
Expr lossSum = cast(loss, Type::float32); // accumulate in float32
Expr labelsSum = cast(labels, Type::float32); // accumulate in float32
for(int i = 0; i < axes_.size(); ++i) {
lossSum = sum(lossSum, axes_[i]);
lossSum = sum(lossSum, axes_[i]);
labelsSum = sum(labelsSum, axes_[i]);
}
@ -301,7 +293,7 @@ protected:
RationalLoss reduce(Expr loss) {
ABORT_IF(!loss, "Loss has not been computed");
Expr lossSum = cast(loss, Type::float32);
Expr lossSum = cast(loss, Type::float32);
for(int i = 0; i < axes_.size(); ++i)
lossSum = sum(lossSum, axes_[i]);
@ -311,17 +303,18 @@ protected:
}
public:
LabelwiseLoss(const std::vector<int>& axes)
: axes_(axes) { }
LabelwiseLoss(const std::vector<int>& axes) : axes_(axes) {}
virtual RationalLoss apply(Logits logits, const Words& labels,
Expr mask = nullptr, Expr labelWeights = nullptr) {
virtual RationalLoss apply(Logits logits,
const Words& labels,
Expr mask = nullptr,
Expr labelWeights = nullptr) {
Expr loss = compute(logits, labels, mask, labelWeights);
if(mask)
return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting
return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting
else
return reduce(loss); // we have no mask, assume all items are labels
return reduce(loss); // we have no mask, assume all items are labels
}
};
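// Roughly, apply() above yields, for per-position loss l[t,b] and mask m[t,b] summed over the
// configured axes_ (e.g. {-2, -3}, batch and time, for the CE losses below):
//   RationalLoss(sum_{t,b} l[t,b], sum_{t,b} m[t,b])
// so the denominator counts the real (non-padding) labels; without a mask, all items are assumed
// to be labels.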
@ -331,28 +324,34 @@ public:
class CrossEntropyLoss : public LabelwiseLoss {
public:
CrossEntropyLoss(float labelSmoothing, float factorWeight)
: CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1
: CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {
} // cross-entropy already reduces over axis -1
CrossEntropyLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
: LabelwiseLoss(axes), // cross-entropy already reduces over axis -1
labelSmoothing_(labelSmoothing), factorWeight_(factorWeight) {}
: LabelwiseLoss(axes), // cross-entropy already reduces over axis -1
labelSmoothing_(labelSmoothing),
factorWeight_(factorWeight) {}
virtual ~CrossEntropyLoss() {}
protected:
float labelSmoothing_; // interpolation factor for label smoothing, see below
float factorWeight_; // give extra weight to factors
virtual Expr compute(Logits logits, const Words& labels,
Expr mask = nullptr, Expr labelWeights = nullptr) override {
// logits may be factored; in that case, the getLoss() function computes one loss for each, and sums them up
protected:
float labelSmoothing_; // interpolation factor for label smoothing, see below
float factorWeight_; // give extra weight to factors
virtual Expr compute(Logits logits,
const Words& labels,
Expr mask = nullptr,
Expr labelWeights = nullptr) override {
// logits may be factored; in that case, the getLoss() function computes one loss for each, and
// sums them up
int inFactor = false;
auto ce = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
logits = atleast_3d(logits); // we always assume a time and batch dimension exists.
logits = atleast_3d(logits); // we always assume a time and batch dimension exists.
// for bert training or classification the time dimension is lost.
// Here safeguard against 2d classifier output, adds 1 on the left, non-op.
Expr ce = cross_entropy(logits, indices, inFactor ? 0.f : labelSmoothing_, Type::float32);
if (inFactor && factorWeight_ != 1.0f) {
if(inFactor && factorWeight_ != 1.0f) {
LOG_ONCE(info, "scaling factor losses with weight {}", factorWeight_);
ce = ce * factorWeight_;
}
@ -365,8 +364,10 @@ protected:
if(labelWeights) {
// We currently do not know how to use target factors and word-level label weights together
bool wordlevel = labelWeights->shape()[-3] > 1; // Time-dimension is not trivially 1, hence we have word-level weights.
ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, "CE loss with word-level label weights is not implemented for factors");
bool wordlevel = labelWeights->shape()[-3]
> 1; // Time-dimension is not trivially 1, hence we have word-level weights.
ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1,
"CE loss with word-level label weights is not implemented for factors");
ce = ce * cast(labelWeights, Type::float32);
}
@ -374,13 +375,12 @@ protected:
}
};
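// For reference, labelSmoothing_ is the interpolation factor eps of the usual label-smoothing
// formulation (a textbook sketch, not a claim about the exact cross_entropy() internals): with
// vocabulary size V and p = softmax(logits),
//   CE_ls(y) = (1 - eps) * (-log p[y]) + eps * mean_v(-log p[v])
// i.e. the one-hot target is mixed with a uniform distribution; eps = 0 recovers plain CE. Note
// that the code above disables smoothing inside factor groups (inFactor ? 0.f : labelSmoothing_).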
/**
* @brief Unlikelihood loss across last axis, summed up over batch and time dimensions. This is an
* implementation of sequence-level unlikelihood loss from https://arxiv.org/abs/1908.04319.
* We rely on word-level label weights where 1 marks a correct label and 0 marks an error. If there
* are no zeros for a sentence, it is going to be trained with the normal CE loss; if there is at
* least one 0, it is going to flip over to use SUL for that sentence to penalize the selected words.
*
* SUL is implemented as:
* -log(gather(1 - softmax(logits), -1, indices))
@ -390,35 +390,45 @@ protected:
class SequenceUnlikelihoodLoss : public CrossEntropyLoss {
public:
SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight)
: CrossEntropyLoss(labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1
: CrossEntropyLoss(labelSmoothing, factorWeight) {
} // cross-entropy already reduces over axis -1
SequenceUnlikelihoodLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
: CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}
: CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}
protected:
virtual Expr compute(Logits logits, const Words& labels,
Expr mask = nullptr, Expr labelWeights = nullptr) override {
auto ce = CrossEntropyLoss::compute(logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE
virtual Expr compute(Logits logits,
const Words& labels,
Expr mask = nullptr,
Expr labelWeights = nullptr) override {
auto ce = CrossEntropyLoss::compute(
logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE
if(!labelWeights)
return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)?
return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)?
// We currently do not know how to use target factors and word-level label weights together
ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors");
ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete.
// use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask again to eliminate padding (might be obsolete)
ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by
// default 1, which would make this obsolete.
// use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask
// again to eliminate padding (might be obsolete)
auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32);
auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
return cast(unlikelihood(logits, indices), Type::float32);
});
// Decide whether we want to use CE or UL. If there are no errors, train with CE; otherwise train
// _only on_ the errors with UL. This is the "mixed" training schedule from
// https://arxiv.org/abs/1908.04319. Providing labels with or without error scores, we can easily
// switch between CE and UL.
auto onlyCe = eq(sum(errorMask, /*axis=*/-3),
0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present
ceUl = errorMask * ceUl; // don't use for correct label or padding
auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce or unlikelihood part are never
// simultaneously used as cost per batch entry
return cost;
}
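// Putting the pieces above together, per target position with p = softmax(logits) and reference
// label y:
//   CE(y)  = -log p[y]
//   SUL(y) = -log(1 - p[y])               // the unlikelihood() call
//   cost   = onlyCe * CE + (1 - onlyCe) * errorMask * SUL
// so a sentence without annotated errors keeps plain CE, while a sentence with at least one
// 0-weighted label is trained only on those positions, with SUL.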
@ -463,7 +473,6 @@ public:
}
};
/**
* @brief Factory for label-wise loss functions
*/

View File

@ -1,120 +1,131 @@
#include "output.h"
#include "data/factored_vocab.h"
#include "common/timer.h"
#include "layers/lsh.h"
#include "data/factored_vocab.h"
#include "layers/loss.h"
#include "layers/lsh.h"
namespace marian {
namespace mlp {
/*private*/ void Output::lazyConstruct(int inputDim) {
// We must construct lazily since we won't know tying nor input dim in constructor.
if (Wt_)
// We must construct lazily since we won't know tying nor input dim in constructor.
if(Wt_)
return;
// this option is only set in the decoder
if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) {
auto k = opt<std::vector<int>>("output-approx-knn")[0];
// this option is only set in the decoder
if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) {
auto k = opt<std::vector<int>>("output-approx-knn")[0];
auto nbits = opt<std::vector<int>>("output-approx-knn")[1];
lsh_ = New<LSH>(k, nbits);
}
}
auto name = options_->get<std::string>("prefix");
auto numOutputClasses = options_->get<int>("dim");
auto name = options_->get<std::string>("prefix");
auto numOutputClasses = options_->get<int>("dim");
factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
if (factoredVocab_) {
factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
if(factoredVocab_) {
numOutputClasses = (int)factoredVocab_->factorVocabSize();
LOG_ONCE(info, "[embedding] Factored outputs enabled");
}
}
if(tiedParam_) {
if(tiedParam_) {
Wt_ = tiedParam_;
} else {
if (graph_->get(name + "_W")) { // support of legacy models that did not transpose
Wt_ = graph_->param(name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false));
isLegacyUntransposedW = true;
}
else // this is the regular case:
Wt_ = graph_->param(name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true));
}
} else {
if(graph_->get(name + "_W")) { // support of legacy models that did not transpose
Wt_ = graph_->param(
name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false));
isLegacyUntransposedW = true;
} else // this is the regular case:
Wt_ = graph_->param(
name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true));
}
if(hasBias_)
if(hasBias_)
b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros());
/*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
if (lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix
/*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
if(lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix
#define HARDMAX_HACK
#ifdef HARDMAX_HACK
lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number
lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number
#endif
auto range = factoredVocab_->getGroupRange(0);
auto lemmaVocabDim = (int)(range.second - range.first);
auto initFunc = inits::glorotUniform(/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length
lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, initFunc); // [L x U] L=lemmaDimEmb; transposed for speed
}
auto initFunc = inits::glorotUniform(
/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length
lemmaEt_ = graph_->param(name + "_lemmaEt",
{lemmaDimEmb, lemmaVocabDim},
initFunc); // [L x U] L=lemmaDimEmb; transposed for speed
}
}
Logits Output::applyAsLogits(Expr input) /*override final*/ {
lazyConstruct(input->shape()[-1]);
lazyConstruct(input->shape()[-1]);
auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
if(b)
return affine(x, W, b, transA, transB);
return affine(x, W, b, transA, transB);
else
return dot(x, W, transA, transB);
};
return dot(x, W, transA, transB);
};
auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
if(lsh_) {
ABORT_IF( transA, "Transposed query not supported for LSH");
ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
return lsh_->apply(x, W, b); // knows how to deal with undefined bias
ABORT_IF(transA, "Transposed query not supported for LSH");
ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
return lsh_->apply(x, W, b); // knows how to deal with undefined bias
} else {
return affineOrDot(x, W, b, transA, transB);
return affineOrDot(x, W, b, transA, transB);
}
};
};
if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed
cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one
// batch, then clear()ed
cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
if(hasBias_)
cachedShortb_ = index_select(b_ , -1, shortlist_->indices());
}
cachedShortb_ = index_select(b_, -1, shortlist_->indices());
}
if (factoredVocab_) {
if(factoredVocab_) {
auto graph = input->graph();
// project each factor separately
auto numGroups = factoredVocab_->getNumGroups();
std::vector<Ptr<RationalLoss>> allLogits(numGroups, nullptr); // (note: null entries for absent factors)
Expr input1 = input; // [B... x D]
Expr Plemma = nullptr; // used for lemmaDimEmb=-1
Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3
for (size_t g = 0; g < numGroups; g++) {
auto range = factoredVocab_->getGroupRange(g);
if (g > 0 && range.first == range.second) // empty entry
std::vector<Ptr<RationalLoss>> allLogits(numGroups,
nullptr); // (note: null entries for absent factors)
Expr input1 = input; // [B... x D]
Expr Plemma = nullptr; // used for lemmaDimEmb=-1
Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3
for(size_t g = 0; g < numGroups; g++) {
auto range = factoredVocab_->getGroupRange(g);
if(g > 0 && range.first == range.second) // empty entry
continue;
ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g-1).second, "Factor groups must be consecutive (group {} vs predecessor)", g);
// slice this group's section out of W_
Expr factorWt, factorB;
if (g == 0 && shortlist_) {
ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g - 1).second,
"Factor groups must be consecutive (group {} vs predecessor)",
g);
// slice this group's section out of W_
Expr factorWt, factorB;
if(g == 0 && shortlist_) {
factorWt = cachedShortWt_;
factorB = cachedShortb_;
}
else {
factorWt = slice(Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second));
factorB = cachedShortb_;
} else {
factorWt = slice(
Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second));
if(hasBias_)
factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
}
/*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
if ((lemmaDimEmb == -2 || lemmaDimEmb == -3) && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max)
factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
}
/*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
if((lemmaDimEmb == -2 || lemmaDimEmb == -3)
&& g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max)
LOG_ONCE(info, "[embedding] using lemma conditioning with gate");
// this mimics one transformer layer
// - attention over two inputs:
// - e = current lemma. We use the original embedding vector; specifically, expectation over all lemmas.
// - e = current lemma. We use the original embedding vector; specifically, expectation
// over all lemmas.
// - input = hidden state FF(h_enc+h_dec)
// - dot-prod attention to allow both sides to influence (unlike our recurrent self-attention)
// - dot-prod attention to allow both sides to influence (unlike our recurrent
// self-attention)
// - multi-head to allow for multiple conditions to be modeled
// - add & norm, for gradient flow and scaling
// - FF layer --this is expensive; it is per-factor
@ -122,112 +133,161 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
int inputDim = input->shape()[-1];
int heads = 8;
auto name = options_->get<std::string>("prefix") + "_factor" + std::to_string(g);
auto Wq = graph_->param(name + "_Wq", { inputDim, inputDim }, inits::glorotUniform());
auto Wk = graph_->param(name + "_Wk", { inputDim, inputDim }, inits::glorotUniform());
auto Wv = graph_->param(name + "_Wv", { inputDim, inputDim }, inits::glorotUniform());
auto Wq = graph_->param(name + "_Wq", {inputDim, inputDim}, inits::glorotUniform());
auto Wk = graph_->param(name + "_Wk", {inputDim, inputDim}, inits::glorotUniform());
auto Wv = graph_->param(name + "_Wv", {inputDim, inputDim}, inits::glorotUniform());
auto toMultiHead = [&](Expr x, int heads) {
const auto& shape = x->shape();
int inputDim = shape[-1];
int otherDim = shape.elements() / inputDim;
ABORT_IF(inputDim / heads * heads != inputDim, "inputDim ({}) must be multiple of number of heads ({})", inputDim, heads);
return reshape(x, { otherDim, heads, 1, inputDim / heads });
const auto& shape = x->shape();
int inputDim = shape[-1];
int otherDim = shape.elements() / inputDim;
ABORT_IF(inputDim / heads * heads != inputDim,
"inputDim ({}) must be multiple of number of heads ({})",
inputDim,
heads);
return reshape(x, {otherDim, heads, 1, inputDim / heads});
};
input1 = inputLemma;
auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query
auto kdm = toMultiHead(dot(input1 - input, Wk), heads); // [B... x H x D/H] the two data vectors projected as keys. Use diff and sigmoid, instead of softmax.
auto vem = toMultiHead(dot(input1, Wv), heads); // [B... x H x D/H] one of the two data vectors projected as values
auto vim = toMultiHead(dot( input, Wv), heads); // [B... x H x D/H] the other
auto zm = bdot(qm, kdm, false, true); // [B... x H x 1]
auto sm = sigmoid(zm); // [B... x H x 1]
auto rm = sm * (vem - vim) + vim; // [B... x H x D/H]
auto r = reshape(rm, input->shape()); // [B... x D]
auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query
auto kdm = toMultiHead(dot(input1 - input, Wk),
heads); // [B... x H x D/H] the two data vectors projected as keys.
// Use diff and sigmoid, instead of softmax.
auto vem = toMultiHead(
dot(input1, Wv),
heads); // [B... x H x D/H] one of the two data vectors projected as values
auto vim = toMultiHead(dot(input, Wv), heads); // [B... x H x D/H] the other
auto zm = bdot(qm, kdm, false, true); // [B... x H x 1]
auto sm = sigmoid(zm); // [B... x H x 1]
auto rm = sm * (vem - vim) + vim; // [B... x H x D/H]
auto r = reshape(rm, input->shape()); // [B... x D]
// add & norm
input1 = r + input1;
input1 = layerNorm(input1, name + "_att");
// FF layer
auto ffnDropProb = 0.1f; // @TODO: get as a parameter
auto ffnDim = inputDim * 2; // @TODO: get as a parameter
auto f = denseInline(input1, name + "_ffn", /*suffix=*/"1", ffnDim, inits::glorotUniform(), (ActivationFunction*)relu, ffnDropProb);
f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
auto ffnDropProb = 0.1f; // @TODO: get as a parameter
auto ffnDim = inputDim * 2; // @TODO: get as a parameter
auto f = denseInline(input1,
name + "_ffn",
/*suffix=*/"1",
ffnDim,
inits::glorotUniform(),
(ActivationFunction*)relu,
ffnDropProb);
f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
// add & norm
input1 = f + input1;
input1 = layerNorm(input1, name + "_ffn");
}
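// Rough summary of the gated block above, per head, with e = inputLemma and h = input:
//   q = e Wq,  k = (e - h) Wk,  v_e = e Wv,  v_i = h Wv
//   s = sigmoid(q . k)                   // scalar gate per head instead of a softmax
//   r = s * (v_e - v_i) + v_i            // interpolate between the two value projections
//   e  <- LayerNorm(r + e);  e <- LayerNorm(FFN(e) + e)
// i.e. one transformer-style layer whose attention is a two-way sigmoid gate between the lemma
// embedding and the hidden state.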
// @TODO: b_ should be a vector, not a matrix; but shortlists use cols() in, which requires a matrix
Expr factorLogits;
if(g == 0)
factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
else
factorLogits = affineOrDot(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
// optionally add lemma-dependent bias
if (Plemma) { // [B... x U0]
}
// @TODO: b_ should be a vector, not a matrix; but shortlists use cols() in, which requires a
// matrix
Expr factorLogits;
if(g == 0)
factorLogits = affineOrLSH(
input1,
factorWt,
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
else
factorLogits = affineOrDot(
input1,
factorWt,
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
// optionally add lemma-dependent bias
if(Plemma) { // [B... x U0]
int lemmaVocabDim = Plemma->shape()[-1];
int factorVocabDim = factorLogits->shape()[-1];
auto name = options_->get<std::string>("prefix");
Expr lemmaBt = graph_->param(name + "_lemmaBt_" + std::to_string(g), {factorVocabDim, lemmaVocabDim}, inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma
auto b = dot(Plemma, lemmaBt, false, true); // [B... x U]
Expr lemmaBt
= graph_->param(name + "_lemmaBt_" + std::to_string(g),
{factorVocabDim, lemmaVocabDim},
inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma
auto b = dot(Plemma, lemmaBt, false, true); // [B... x U]
factorLogits = factorLogits + b;
}
allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
// optionally add a soft embedding of lemma back to create some lemma dependency
// @TODO: if this works, move it into lazyConstruct
if (lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure
}
allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
// optionally add a soft embedding of lemma back to create some lemma dependency
// @TODO: if this works, move it into lazyConstruct
if(lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure
LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version");
// get expected lemma embedding vector
auto factorLogSoftmax = logsoftmax(factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set
auto factorLogSoftmax = logsoftmax(
factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set
auto factorSoftmax = exp(factorLogSoftmax);
inputLemma = dot(factorSoftmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
}
else if (lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max
inputLemma = dot(factorSoftmax,
factorWt,
false,
/*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
} else if(lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max
LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version");
// get max-lemma embedding vector
auto maxVal = max(factorLogits, -1); // [B... x U] note: with shortlist, this is not the full lemma set
auto maxVal = max(factorLogits,
-1); // [B... x U] note: with shortlist, this is not the full lemma set
auto factorHardmax = eq(factorLogits, maxVal);
inputLemma = dot(factorHardmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
}
else if (lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias
inputLemma = dot(factorHardmax,
factorWt,
false,
/*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
} else if(lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias
ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented");
LOG_ONCE(info, "[embedding] using lemma-dependent bias");
auto factorLogSoftmax = logsoftmax(factorLogits); // (we do that again later, CSE will kick in)
auto z = /*stopGradient*/(factorLogSoftmax);
Plemma = exp(z); // [B... x U]
}
else if (lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix
auto factorLogSoftmax
= logsoftmax(factorLogits); // (we do that again later, CSE will kick in)
auto z = /*stopGradient*/ (factorLogSoftmax);
Plemma = exp(z); // [B... x U]
} else if(lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix
LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb);
// compute softmax. We compute logsoftmax() separately because this way, computation will be reused later via CSE
// compute softmax. We compute logsoftmax() separately because this way, computation will be
// reused later via CSE
auto factorLogSoftmax = logsoftmax(factorLogits);
auto factorSoftmax = exp(factorLogSoftmax);
#ifdef HARDMAX_HACK
bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation)
if (hardmax) {
lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
auto maxVal = max(factorSoftmax, -1);
factorSoftmax = eq(factorSoftmax, maxVal);
bool hardmax = (lemmaDimEmb & 1)
!= 0; // odd value triggers hardmax for now (for quick experimentation)
if(hardmax) {
lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
auto maxVal = max(factorSoftmax, -1);
factorSoftmax = eq(factorSoftmax, maxVal);
}
#endif
// re-embedding lookup, soft-indexed by softmax
if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L]
if(shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
auto e = dot(factorSoftmax,
cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_,
false,
true); // [B... x L]
// project it back to regular hidden dim
int inputDim = input1->shape()[-1];
auto name = options_->get<std::string>("prefix");
// note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also length 1
Expr lemmaWt = inputDim == lemmaDimEmb ? nullptr : graph_->param(name + "_lemmaWt", { inputDim, lemmaDimEmb }, inits::glorotUniform()); // [D x L] D=hidden-vector dimension
auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D]
// note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also
// length 1
Expr lemmaWt
= inputDim == lemmaDimEmb
? nullptr
: graph_->param(name + "_lemmaWt",
{inputDim, lemmaDimEmb},
inits::glorotUniform()); // [D x L] D=hidden-vector dimension
auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D]
// augment the original hidden vector with this additional information
input1 = input1 + f;
}
}
}
return Logits(std::move(allLogits), factoredVocab_);
} else if (shortlist_) {
return Logits(affineOrLSH(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true));
} else {
return Logits(affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
}
} else if(shortlist_) {
return Logits(affineOrLSH(input,
cachedShortWt_,
cachedShortb_,
false,
/*transB=*/isLegacyUntransposedW ? false : true));
} else {
return Logits(
affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
}
}
}
}
} // namespace mlp
} // namespace marian

View File

@ -1,10 +1,10 @@
#pragma once
#include "marian.h"
#include "generic.h"
#include "logits.h"
#include "data/shortlist.h"
#include "generic.h"
#include "layers/factory.h"
#include "logits.h"
#include "marian.h"
namespace marian {
class LSH;
@ -14,42 +14,45 @@ namespace mlp {
class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList {
private:
// parameters held by this layer
Expr Wt_; // weight matrix is stored transposed for efficiency
Expr Wt_; // weight matrix is stored transposed for efficiency
Expr b_;
Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize]
bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form
Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize]
bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form
bool hasBias_{true};
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
Ptr<FactoredVocab> factoredVocab_;
// optional parameters set/updated after construction
Expr tiedParam_;
Ptr<data::Shortlist> shortlist_;
Ptr<LSH> lsh_;
void lazyConstruct(int inputDim);
public:
Output(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: LayerBase(graph, options),
hasBias_{!options->get<bool>("output-omit-bias", false)} {
: LayerBase(graph, options), hasBias_{!options->get<bool>("output-omit-bias", false)} {
clear();
}
void tieTransposed(Expr tied) {
if (Wt_)
ABORT_IF(tiedParam_.get() != tied.get(), "Tied output projection cannot be changed once weights have been created");
if(Wt_)
ABORT_IF(tiedParam_.get() != tied.get(),
"Tied output projection cannot be changed once weights have been created");
else
tiedParam_ = tied;
}
void setShortlist(Ptr<data::Shortlist> shortlist) override final {
if (shortlist_)
ABORT_IF(shortlist.get() != shortlist_.get(), "Output shortlist cannot be changed except after clear()");
if(shortlist_)
ABORT_IF(shortlist.get() != shortlist_.get(),
"Output shortlist cannot be changed except after clear()");
else {
ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, "No shortlist but cached parameters??");
ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_,
"No shortlist but cached parameters??");
shortlist_ = shortlist;
}
// cachedShortWt_ and cachedShortb_ will be created lazily inside apply()
@ -60,7 +63,7 @@ public:
void clear() override final {
shortlist_ = nullptr;
cachedShortWt_ = nullptr;
cachedShortb_ = nullptr;
cachedShortb_ = nullptr;
cachedShortLemmaEt_ = nullptr;
}
@ -69,6 +72,4 @@ public:
} // namespace mlp
}
} // namespace marian

View File

@ -4,13 +4,11 @@ namespace marian {
namespace models {
Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) {
// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax));
// @TODO: This is becoming more and more opaque ^^. Can we simplify this?
return state;
}
}
// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax));
// @TODO: This is becoming more and more opaque ^^. Can we simplify this?
return state;
}
} // namespace models
} // namespace marian

View File

@ -4,8 +4,8 @@
#include "layers/guided_alignment.h"
#include "layers/loss.h"
#include "layers/weight.h"
#include "models/encoder_decoder.h"
#include "models/encoder_classifier.h"
#include "models/encoder_decoder.h"
#include "models/encoder_pooler.h"
namespace marian {
@ -22,10 +22,12 @@ namespace models {
class ICost {
public:
virtual Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model?
Ptr<data::Batch> batch,
bool clearGraph = true) = 0;
virtual Ptr<MultiRationalLoss> apply(
Ptr<IModel> model,
Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model?
Ptr<data::Batch> batch,
bool clearGraph = true)
= 0;
virtual ~ICost() {}
};
@ -45,10 +47,9 @@ public:
: options_(options), inference_(options->get<bool>("inference", false)) {
loss_ = newLoss(options_, inference_);
toBeWeighted_
= (options_->hasAndNotEmpty("data-weighting") && !inference_)
|| (options_->has("dynamic-weighting") && options_->get<bool>("dynamic-weighting")
&& !inference_);
toBeWeighted_ = (options_->hasAndNotEmpty("data-weighting") && !inference_)
|| (options_->has("dynamic-weighting")
&& options_->get<bool>("dynamic-weighting") && !inference_);
if(toBeWeighted_)
weighter_ = WeightingFactory(options_);
}
@ -56,9 +57,9 @@ public:
virtual ~EncoderDecoderCECost() {}
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto encdec = std::static_pointer_cast<EncoderDecoder>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
@ -72,17 +73,17 @@ public:
Ptr<MultiRationalLoss> multiLoss = newMultiLoss(options_);
// @TODO: adapt to multi-objective training with multiple decoders
auto partialLoss = loss_->apply(state->getLogProbs(),
state->getTargetWords(),
state->getTargetMask(),
weights);
auto partialLoss = loss_->apply(
state->getLogProbs(), state->getTargetWords(), state->getTargetMask(), weights);
multiLoss->push_back(partialLoss);
if(options_->get("guided-alignment", std::string("none")) != "none" && !inference_) {
auto attentionVectors = encdec->getDecoders()[0]->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1]
auto attentionVectors
= encdec->getDecoders()[0]
->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1]
ABORT_IF(attentionVectors.empty(), "Model does not seem to support alignments");
auto attention = concatenate(attentionVectors, /*axis =*/ -1);
auto attention = concatenate(attentionVectors, /*axis =*/-1);
auto alignmentLoss = guidedAlignmentCost(graph, corpusBatch, options_, attention);
multiLoss->push_back(alignmentLoss);
@ -109,10 +110,9 @@ public:
}
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto enccls = std::static_pointer_cast<EncoderClassifier>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
@ -141,21 +141,20 @@ protected:
public:
EncoderPoolerRankCost(Ptr<Options> options)
: options_(options),
inference_(options->get<bool>("inference", false)) {
auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {});
ABORT_IF(trainEmbedderRank.empty(), "EncoderPoolerRankCost expects train-embedder-rank to be set");
: options_(options), inference_(options->get<bool>("inference", false)) {
auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {});
ABORT_IF(trainEmbedderRank.empty(),
"EncoderPoolerRankCost expects train-embedder-rank to be set");
margin_ = std::stof(trainEmbedderRank[0]);
if(trainEmbedderRank.size() > 1)
normalizer_ = std::stof(trainEmbedderRank[1]);
margin_ = std::stof(trainEmbedderRank[0]);
if(trainEmbedderRank.size() > 1)
normalizer_ = std::stof(trainEmbedderRank[1]);
}
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto encpool = std::static_pointer_cast<EncoderPooler>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
std::vector<Expr> dotProducts = encpool->apply(graph, corpusBatch, clearGraph);
@ -167,28 +166,41 @@ public:
ABORT_IF(dotProducts.size() != 3, "Three dot products required for margin loss");
// multi-objective training
auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability
auto exponent = dotProducts[0] - maxDot - margin_; // subtract maximum and margin from dot product
auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability
auto exponent
= dotProducts[0] - maxDot - margin_; // subtract maximum and margin from dot product
auto dp = exp(exponent);
Expr dn1, dn2;
if(normalizer_ != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the magnitude of the sum of negative examples in the denominator.
dn1 = normalizer_ * mean(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example
dn2 = normalizer_ * mean(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example
if(normalizer_
!= 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the
// magnitude of the sum of negative examples in the denominator.
dn1 = normalizer_
* mean(exp(dotProducts[1] - maxDot),
-1); // dot product of anchor and first negative example
dn2 = normalizer_
* mean(exp(dotProducts[2] - maxDot),
-1); // dot product of positive examples and first negative example
} else {
dn1 = sum(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example
dn2 = sum(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example
dn1 = sum(exp(dotProducts[1] - maxDot),
-1); // dot product of anchor and first negative example
dn2 = sum(exp(dotProducts[2] - maxDot),
-1); // dot product of positive examples and first negative example
}
// We rewrite the loss so it looks more like a log-softmax, presumably more stable?
// Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m)
auto marginLoss1 = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples
auto marginLoss2 = log(dp + dn2) - exponent; // symmetric version of the above with positive example vs negative examples
auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2);
// Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp
// + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m)
auto marginLoss1
= log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples
auto marginLoss2
= log(dp + dn2)
- exponent; // symmetric version of the above with positive example vs negative examples
auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2);
RationalLoss loss(marginLoss, (float)dimBatch);
multiLoss->push_back(loss);
return multiLoss;
}
};
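// Compact form of the rank loss above: with phi = dotProducts[0] (anchor vs. positive), margin m,
// exponent = phi - maxDot - m, dp = exp(exponent), and dn1/dn2 the (optionally normalized)
// aggregated exp-scores of the two negative-example dot products,
//   L = sum_batch [ (log(dp + dn1) - exponent) + (log(dp + dn2) - exponent) ]
// which is the symmetric softmax-margin loss written in the log-sum form derived in the comment.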
@ -199,8 +211,7 @@ protected:
Ptr<ICost> cost_;
public:
Trainer(Ptr<IModel> model, Ptr<ICost> cost)
: model_(model), cost_(cost) {}
Trainer(Ptr<IModel> model, Ptr<ICost> cost) : model_(model), cost_(cost) {}
virtual ~Trainer() {}
@ -219,8 +230,8 @@ public:
}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
Ptr<data::Batch> batch,
bool clearGraph = true) override {
return cost_->apply(model_, graph, batch, clearGraph);
};
@ -230,24 +241,25 @@ public:
class ILogProb {
public:
virtual Logits apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) = 0;
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true)
= 0;
};
// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for
// the ground truth?
// Beam search uses it in the former sense, while 'marian score' and validation use it in the
// latter. This class is for the former use. The latter is done using Trainer.
class Scorer : public IModel {
protected:
Ptr<IModel> model_;
Ptr<ILogProb> logProb_;
public:
Scorer(Ptr<IModel> model, Ptr<ILogProb> cost)
: model_(model), logProb_(cost) {}
Scorer(Ptr<IModel> model, Ptr<ILogProb> cost) : model_(model), logProb_(cost) {}
virtual ~Scorer(){}
virtual ~Scorer() {}
Ptr<IModel> getModel() { return model_; }
@ -264,8 +276,8 @@ public:
}
virtual Logits build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
Ptr<data::Batch> batch,
bool clearGraph = true) override {
return logProb_->apply(model_, graph, batch, clearGraph);
};
@ -293,10 +305,10 @@ public:
virtual ~GumbelSoftmaxStep() {}
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
[](Expr logits){ // lemma gets gumbelled
return logsoftmax(logits + constant_like(logits, inits::gumbel()));
},
logsoftmax)); // factors don't
[](Expr logits) { // lemma gets gumbelled
return logsoftmax(logits + constant_like(logits, inits::gumbel()));
},
logsoftmax)); // factors don't
return state;
}
};
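// Background for the step above (the standard Gumbel-max argument, not specific to this code): if
// u ~ Uniform(0,1) and g = -log(-log(u)), then argmax_v(logits[v] + g[v]) is distributed like a
// sample from softmax(logits). Adding inits::gumbel() noise to the lemma logits before logsoftmax
// therefore makes argmax decoding over lemmas behave like sampling, while factor logits stay
// deterministic.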
@ -311,8 +323,7 @@ protected:
Ptr<ILogProbStep> cost_;
public:
Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost)
: encdec_(encdec), cost_(cost) {}
Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {}
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
@ -346,12 +357,13 @@ public:
return encdec_->startState(graph, batch);
}
virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
const Words& words, // [beamIndex * activeBatchSize + batchIndex]
const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) override {
virtual Ptr<DecoderState> step(
Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
const Words& words, // [beamIndex * activeBatchSize + batchIndex]
const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) override {
auto nextState = encdec_->step(graph, state, hypIndices, words, batchIndices, beamSize);
return cost_->apply(nextState);
}
@ -369,9 +381,7 @@ public:
encdec_->setShortlistGenerator(shortlistGenerator);
};
virtual Ptr<data::Shortlist> getShortlist() override {
return encdec_->getShortlist();
};
virtual Ptr<data::Shortlist> getShortlist() override { return encdec_->getShortlist(); };
virtual data::SoftAlignment getAlignment() override { return encdec_->getAlignment(); }
};

View File

@ -1,7 +1,7 @@
#pragma once
#include "layers/logits.h" // @HACK: for factored embeddings only so far
#include "marian.h"
#include "layers/logits.h" // @HACK: for factored embeddings only so far
#include "rnn/types.h"
namespace marian {
@ -9,7 +9,7 @@ namespace marian {
class EncoderState {
private:
Expr context_;
Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask
Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask
Ptr<data::CorpusBatch> batch_;
public:
@ -19,31 +19,34 @@ public:
EncoderState() {}
virtual ~EncoderState() {}
virtual Expr getContext() const { return context_; }
virtual Expr getAttended() const { return context_; }
virtual Expr getMask() const { return mask_; } // source batch mask; may have additional positions suppressed
virtual Expr getContext() const { return context_; }
virtual Expr getAttended() const { return context_; }
virtual Expr getMask() const {
return mask_;
} // source batch mask; may have additional positions suppressed
virtual const Words& getSourceWords() {
return batch_->front()->data();
}
virtual const Words& getSourceWords() { return batch_->front()->data(); }
// Sub-select active batch entries from encoder context and context mask
Ptr<EncoderState> select(const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
// Dimension -2 is OK for both RNN and Transformer models, as the encoder context in Transformer gets transposed to the same dimension layout
return New<EncoderState>(index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
Ptr<EncoderState> select(
const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
// Dimension -2 is OK for both RNN and Transformer models, as the encoder context in Transformer
// gets transposed to the same dimension layout
return New<EncoderState>(
index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
}
};
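The sub-selection in EncoderState::select() is a gather along the batch axis. Below is a standalone sketch of the same operation on a [length x batch x dim] tensor stored as nested vectors; it is plain C++, not Marian's index_select, and the names are illustrative.

#include <cstdio>
#include <vector>

using Tensor3 = std::vector<std::vector<std::vector<float>>>;  // [length][batch][dim]

// Keep only the batch entries listed in batchIndices, preserving their order,
// analogous to index_select(context_, -2, batchIndices) in EncoderState::select().
static Tensor3 selectBatchEntries(const Tensor3& t, const std::vector<size_t>& batchIndices) {
  Tensor3 out(t.size());
  for(size_t time = 0; time < t.size(); ++time) {
    out[time].reserve(batchIndices.size());
    for(size_t idx : batchIndices)
      out[time].push_back(t[time][idx]);
  }
  return out;
}

int main() {
  // 2 time steps, 4 sentences, dim 1
  Tensor3 context = {{{0.f}, {1.f}, {2.f}, {3.f}}, {{10.f}, {11.f}, {12.f}, {13.f}}};
  auto active = selectBatchEntries(context, {0, 2});  // sentences 1 and 3 have finished decoding
  std::printf("kept %zu of %zu batch entries\n", active[0].size(), context[0].size());
  return 0;
}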
class DecoderState {
protected:
rnn::States states_; // states of individual decoder layers
rnn::States states_; // states of individual decoder layers
Logits logProbs_;
std::vector<Ptr<EncoderState>> encStates_;
Ptr<data::CorpusBatch> batch_;
Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
Expr targetMask_;
Words targetWords_; // target labels
Words targetWords_; // target labels
// Keep track of current target token position during translation
size_t position_{0};
@ -57,26 +60,30 @@ public:
virtual ~DecoderState() {}
// @TODO: Do we need all these to be virtual?
virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const {
return encStates_;
}
virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const { return encStates_; }
virtual Logits getLogProbs() const { return logProbs_; }
virtual void setLogProbs(Logits logProbs) { logProbs_ = logProbs; }
// @TODO: should this be a constructor? Then derived classes can call this without the New<> in the loop
virtual Ptr<DecoderState> select(const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) const {
// @TODO: should this be a constructor? Then derived classes can call this without the New<> in
// the loop
virtual Ptr<DecoderState> select(
const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) const {
std::vector<Ptr<EncoderState>> newEncStates;
for(auto& es : encStates_)
// If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
// If the size of the batch dimension of the encoder state context changed, subselect the
// correct batch entries
newEncStates.push_back(
es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
// hypIndices matches batchIndices in terms of the batch dimension, so we only need hypIndices
auto selectedState = New<DecoderState>(
states_.select(hypIndices, beamSize, /*isBatchMajor=*/false), logProbs_, newEncStates, batch_);
auto selectedState
= New<DecoderState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/false),
logProbs_,
newEncStates,
batch_);
// Set position of the new state based on the target token position of the current state
selectedState->setPosition(getPosition());
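The hypIndices layout is easiest to see with concrete numbers: with beam size 3 and 2 active sentences, hypotheses live on one flat axis at position beamIndex * activeBatchSize + batchIndex, and select() is a gather over that axis. A small illustrative sketch in plain C++ follows; the survivor choices are made up, and none of this is Marian code.

#include <cstdio>
#include <string>
#include <vector>

int main() {
  const size_t beamSize = 3, activeBatchSize = 2;
  // Flattened hypothesis states, laid out as [beamIndex * activeBatchSize + batchIndex].
  std::vector<std::string> hyps(beamSize * activeBatchSize);
  for(size_t beam = 0; beam < beamSize; ++beam)
    for(size_t batch = 0; batch < activeBatchSize; ++batch)
      hyps[beam * activeBatchSize + batch]
          = "beam" + std::to_string(beam) + "/sent" + std::to_string(batch);

  // Suppose beam search decided that, for sentence 0, the surviving hypotheses come from
  // beams {0, 0, 2}, and for sentence 1 from beams {1, 2, 2}. The flat indices are:
  std::vector<size_t> hypIndices;
  const size_t survivors0[] = {0, 0, 2}, survivors1[] = {1, 2, 2};
  for(size_t k = 0; k < beamSize; ++k) {
    hypIndices.push_back(survivors0[k] * activeBatchSize + 0);
    hypIndices.push_back(survivors1[k] * activeBatchSize + 1);
  }

  // select(): gather the surviving states in the new flat order.
  for(size_t i : hypIndices)
    std::printf("%s\n", hyps[i].c_str());
  return 0;
}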
@ -86,7 +93,9 @@ public:
virtual const rnn::States& getStates() const { return states_; }
virtual Expr getTargetHistoryEmbeddings() const { return targetHistoryEmbeddings_; };
virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) { targetHistoryEmbeddings_ = targetHistoryEmbeddings; }
virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) {
targetHistoryEmbeddings_ = targetHistoryEmbeddings;
}
virtual const Words& getTargetWords() const { return targetWords_; };
virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }
@ -94,9 +103,7 @@ public:
virtual Expr getTargetMask() const { return targetMask_; };
virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }
virtual const Words& getSourceWords() const {
return getEncoderStates()[0]->getSourceWords();
}
virtual const Words& getSourceWords() const { return getEncoderStates()[0]->getSourceWords(); }
Ptr<data::CorpusBatch> getBatch() const { return batch_; }
@ -111,7 +118,8 @@ public:
/**
* Classifier output based on DecoderState
* @TODO: should be unified with DecoderState or not be used at all, as Classifiers do not really have stateful output.
* @TODO: should be unified with DecoderState or not be used at all, as Classifiers do not really
* have stateful output.
*/
class ClassifierState {
private: