mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
clang-format -i
This commit is contained in:
parent
55f4216552
commit
ba19663784
@ -1,8 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "layers/embedding.h"
|
||||
#include "layers/factory.h"
|
||||
#include "layers/generic.h"
|
||||
#include "layers/embedding.h"
|
||||
#include "layers/output.h"
|
||||
|
||||
namespace marian {
|
||||
@ -45,6 +45,7 @@ struct LogitLayerFactory : public Factory {
|
||||
// @TODO: In the long run, I hope we can get rid of the abstract factories altogether.
|
||||
class OutputFactory : public LogitLayerFactory {
|
||||
using LogitLayerFactory::LogitLayerFactory;
|
||||
|
||||
protected:
|
||||
std::string tiedTransposedName_;
|
||||
Ptr<data::Shortlist> shortlist_;
|
||||
@ -55,9 +56,7 @@ public:
|
||||
return Accumulator<OutputFactory>(*this);
|
||||
}
|
||||
|
||||
void setShortlist(Ptr<data::Shortlist> shortlist) {
|
||||
shortlist_ = shortlist;
|
||||
}
|
||||
void setShortlist(Ptr<data::Shortlist> shortlist) { shortlist_ = shortlist; }
|
||||
|
||||
Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) override {
|
||||
auto output = New<Output>(graph, options_);
|
||||
@ -89,8 +88,7 @@ protected:
|
||||
std::vector<Ptr<IUnaryLayer>> layers_;
|
||||
|
||||
public:
|
||||
MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: graph_(graph), options_(options) {}
|
||||
MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}
|
||||
|
||||
Expr apply(const std::vector<Expr>& av) override {
|
||||
Expr output;
|
||||
@ -106,46 +104,53 @@ public:
|
||||
}
|
||||
|
||||
Logits applyAsLogits(const std::vector<Expr>& av) override {
|
||||
// same as apply() except for the last layer, we invoke applyAsLogits(), which has a different return type
|
||||
// same as apply() except for the last layer, we invoke applyAsLogits(), which has a different
|
||||
// return type
|
||||
auto lastLayer = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers_.back());
|
||||
ABORT_IF(!lastLayer, "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
|
||||
if (layers_.size() == 1) {
|
||||
if (av.size() == 1)
|
||||
ABORT_IF(
|
||||
!lastLayer,
|
||||
"MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
|
||||
if(layers_.size() == 1) {
|
||||
if(av.size() == 1)
|
||||
return lastLayer->applyAsLogits(av[0]);
|
||||
else
|
||||
return lastLayer->applyAsLogits(av);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
Expr output;
|
||||
if (av.size() == 1)
|
||||
if(av.size() == 1)
|
||||
output = layers_[0]->apply(av[0]);
|
||||
else
|
||||
output = layers_[0]->apply(av);
|
||||
for (size_t i = 1; i < layers_.size() - 1; ++i)
|
||||
for(size_t i = 1; i < layers_.size() - 1; ++i)
|
||||
output = layers_[i]->apply(output);
|
||||
return lastLayer->applyAsLogits(output);
|
||||
}
|
||||
}
|
||||
|
||||
Expr apply(Expr e) override { return apply(std::vector<Expr>{ e }); }
|
||||
Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{ e }); }
|
||||
Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); }
|
||||
Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); }
|
||||
|
||||
void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); }
|
||||
void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); }
|
||||
|
||||
void setShortlist(Ptr<data::Shortlist> shortlist) override final {
|
||||
auto p = tryAsHasShortlist();
|
||||
ABORT_IF(!p, "setShortlist() called on an MLP with an output layer that does not support short lists");
|
||||
ABORT_IF(
|
||||
!p,
|
||||
"setShortlist() called on an MLP with an output layer that does not support short lists");
|
||||
p->setShortlist(shortlist);
|
||||
}
|
||||
|
||||
void clear() override final {
|
||||
auto p = tryAsHasShortlist();
|
||||
if (p)
|
||||
if(p)
|
||||
p->clear();
|
||||
}
|
||||
|
||||
private:
|
||||
Ptr<IHasShortList> tryAsHasShortlist() const { return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); }
|
||||
Ptr<IHasShortList> tryAsHasShortlist() const {
|
||||
return std::dynamic_pointer_cast<IHasShortList>(layers_.back());
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@ -154,6 +159,7 @@ private:
|
||||
*/
|
||||
class MLPFactory : public Factory {
|
||||
using Factory::Factory;
|
||||
|
||||
private:
|
||||
std::vector<Ptr<LayerFactory>> layers_;
|
||||
|
||||
@ -177,23 +183,27 @@ public:
|
||||
// which will go away if we get rid of the abstract factories, and instead just construct
|
||||
// all layers immediately, which is my long-term goal for Marian.
|
||||
private:
|
||||
template<class WrappedFactory>
|
||||
template <class WrappedFactory>
|
||||
class AsLayerFactory : public LayerFactory {
|
||||
WrappedFactory us;
|
||||
WrappedFactory us;
|
||||
|
||||
public:
|
||||
AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
|
||||
Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
|
||||
auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
|
||||
ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
|
||||
return p;
|
||||
}
|
||||
AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
|
||||
Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
|
||||
auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
|
||||
ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
|
||||
return p;
|
||||
}
|
||||
};
|
||||
template<class WrappedFactory>
|
||||
static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { return wrapped; }
|
||||
template <class WrappedFactory>
|
||||
static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) {
|
||||
return wrapped;
|
||||
}
|
||||
|
||||
public:
|
||||
Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) {
|
||||
push_back(AsLayerFactory<OutputFactory>(lf));
|
||||
//layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
|
||||
// layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
|
||||
return Accumulator<MLPFactory>(*this);
|
||||
}
|
||||
};
|
||||
|
@ -3,173 +3,205 @@
|
||||
|
||||
namespace marian {
|
||||
|
||||
Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
|
||||
std::string name = opt<std::string>("prefix");
|
||||
int dimVoc = opt<int>("dimVocab");
|
||||
int dimEmb = opt<int>("dimEmb");
|
||||
Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
|
||||
std::string name = opt<std::string>("prefix");
|
||||
int dimVoc = opt<int>("dimVocab");
|
||||
int dimEmb = opt<int>("dimEmb");
|
||||
|
||||
bool fixed = opt<bool>("fixed", false);
|
||||
bool fixed = opt<bool>("fixed", false);
|
||||
|
||||
factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
|
||||
if (factoredVocab_) {
|
||||
factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
|
||||
if(factoredVocab_) {
|
||||
dimVoc = (int)factoredVocab_->factorVocabSize();
|
||||
LOG_ONCE(info, "[embedding] Factored embeddings enabled");
|
||||
}
|
||||
}
|
||||
|
||||
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
|
||||
auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
|
||||
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
|
||||
auto initFunc = inits::glorotUniform(
|
||||
/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
|
||||
|
||||
if (options_->has("embFile")) {
|
||||
if(options_->has("embFile")) {
|
||||
std::string file = opt<std::string>("embFile");
|
||||
if (!file.empty()) {
|
||||
bool norm = opt<bool>("normalization", false);
|
||||
initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
|
||||
if(!file.empty()) {
|
||||
bool norm = opt<bool>("normalization", false);
|
||||
initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
|
||||
E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
|
||||
}
|
||||
|
||||
// helper to embed a sequence of words (given as indices) via factored embeddings
|
||||
Expr Embedding::multiRows(const Words& data, float dropProb) const {
|
||||
auto graph = E_->graph();
|
||||
auto factoredData = factoredVocab_->csr_rows(data);
|
||||
// multi-hot factor vectors are represented as a sparse CSR matrix
|
||||
// [row index = word position index] -> set of factor indices for word at this position
|
||||
ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??");
|
||||
// the CSR matrix is passed in pieces
|
||||
auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights));
|
||||
auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32);
|
||||
auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32);
|
||||
// apply dropout
|
||||
// We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
|
||||
if(!inference_)
|
||||
auto graph = E_->graph();
|
||||
auto factoredData = factoredVocab_->csr_rows(data);
|
||||
// multi-hot factor vectors are represented as a sparse CSR matrix
|
||||
// [row index = word position index] -> set of factor indices for word at this position
|
||||
ABORT_IF(factoredData.shape
|
||||
!= Shape({(int)factoredData.offsets.size() - 1 /*=rows of CSR*/, E_->shape()[0]}),
|
||||
"shape mismatch??");
|
||||
// the CSR matrix is passed in pieces
|
||||
auto weights = graph->constant({(int)factoredData.weights.size()},
|
||||
inits::fromVector(factoredData.weights));
|
||||
auto indices = graph->constant(
|
||||
{(int)factoredData.indices.size()}, inits::fromVector(factoredData.indices), Type::uint32);
|
||||
auto offsets = graph->constant(
|
||||
{(int)factoredData.offsets.size()}, inits::fromVector(factoredData.offsets), Type::uint32);
|
||||
// apply dropout
|
||||
// We apply it to the weights, i.e. factors get dropped out separately, but always as entire
|
||||
// vectors.
|
||||
if(!inference_)
|
||||
weights = dropout(weights, dropProb);
|
||||
// perform the product
|
||||
return csr_dot(factoredData.shape, weights, indices, offsets, E_);
|
||||
// perform the product
|
||||
return csr_dot(factoredData.shape, weights, indices, offsets, E_);
|
||||
}
|
||||
|
||||
std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ {
|
||||
auto graph = E_->graph();
|
||||
int dimBatch = (int)subBatch->batchSize();
|
||||
int dimEmb = E_->shape()[-1];
|
||||
int dimWidth = (int)subBatch->batchWidth();
|
||||
std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const
|
||||
/*override final*/ {
|
||||
auto graph = E_->graph();
|
||||
int dimBatch = (int)subBatch->batchSize();
|
||||
int dimEmb = E_->shape()[-1];
|
||||
int dimWidth = (int)subBatch->batchWidth();
|
||||
|
||||
// factored embeddings:
|
||||
// - regular:
|
||||
// - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
|
||||
// - factored:
|
||||
// - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space)
|
||||
// - each row of M contains the set of factors for one word => we want a CSR matrix
|
||||
// - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
|
||||
// - first compute x @ M on the CPU
|
||||
// - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
|
||||
// - shape (U, specifically) not actually needed here
|
||||
// - foreach input x[i]
|
||||
// - locate row M[i,*]
|
||||
// - copy through its index values (std::vector<push_back>)
|
||||
// - create a matching ones vector (we can keep growing)
|
||||
// - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
|
||||
// - CSR matrix product with E
|
||||
// - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
|
||||
// - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()).
|
||||
// - weighting:
|
||||
// - core factors' gradients are sums over all words that use the factors;
|
||||
// - core factors' embeddings move very fast
|
||||
// - words will need to make up for the move; rare words cannot
|
||||
// - so, we multiply each factor with 1/refCount
|
||||
// - core factors get weighed down a lot
|
||||
// - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before
|
||||
// - but forward pass weighs them down, so that all factors are in a similar numeric range
|
||||
// - if it is required to be in a different range, the embeddings can still learn that, but more slowly
|
||||
// factored embeddings:
|
||||
// - regular:
|
||||
// - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
|
||||
// - factored:
|
||||
// - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space)
|
||||
// - each row of M contains the set of factors for one word => we want a CSR matrix
|
||||
// - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
|
||||
// - first compute x @ M on the CPU
|
||||
// - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
|
||||
// - shape (U, specifically) not actually needed here
|
||||
// - foreach input x[i]
|
||||
// - locate row M[i,*]
|
||||
// - copy through its index values (std::vector<push_back>)
|
||||
// - create a matching ones vector (we can keep growing)
|
||||
// - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
|
||||
// - CSR matrix product with E
|
||||
// - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
|
||||
// - double-check if all dimensions are specified. Probably not for transpose (which would
|
||||
// be like csc_dot()).
|
||||
// - weighting:
|
||||
// - core factors' gradients are sums over all words that use the factors;
|
||||
// - core factors' embeddings move very fast
|
||||
// - words will need to make up for the move; rare words cannot
|
||||
// - so, we multiply each factor with 1/refCount
|
||||
// - core factors get weighed down a lot
|
||||
// - no impact on gradients, as Adam makes up for it; embeddings still move fast just as
|
||||
// before
|
||||
// - but forward pass weighs them down, so that all factors are in a similar numeric range
|
||||
// - if it is required to be in a different range, the embeddings can still learn that, but
|
||||
// more slowly
|
||||
|
||||
auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
|
||||
auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
|
||||
#if 1
|
||||
auto batchMask = graph->constant({dimWidth, dimBatch, 1},
|
||||
inits::fromVector(subBatch->mask()));
|
||||
#else // @TODO: this is dead code now, get rid of it
|
||||
// experimental: hide inline-fix source tokens from cross attention
|
||||
auto batchMask = graph->constant({dimWidth, dimBatch, 1},
|
||||
inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
|
||||
auto batchMask = graph->constant({dimWidth, dimBatch, 1}, inits::fromVector(subBatch->mask()));
|
||||
#else // @TODO: this is dead code now, get rid of it
|
||||
// experimental: hide inline-fix source tokens from cross attention
|
||||
auto batchMask
|
||||
= graph->constant({dimWidth, dimBatch, 1},
|
||||
inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
|
||||
#endif
|
||||
// give the graph inputs readable names for debugging and ONNX
|
||||
batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
|
||||
// give the graph inputs readable names for debugging and ONNX
|
||||
batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
|
||||
|
||||
return std::make_tuple(batchEmbeddings, batchMask);
|
||||
return std::make_tuple(batchEmbeddings, batchMask);
|
||||
}
|
||||
|
||||
Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
|
||||
if (factoredVocab_) {
|
||||
Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
|
||||
selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
|
||||
//selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
|
||||
if(factoredVocab_) {
|
||||
Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
|
||||
selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
|
||||
// selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), {
|
||||
// selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
|
||||
return selectedEmbs;
|
||||
}
|
||||
else
|
||||
} else
|
||||
return applyIndices(toWordIndexVector(words), shape);
|
||||
}
|
||||
|
||||
Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ {
|
||||
ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
|
||||
auto embIdxExpr = E_->graph()->indices(embIdx);
|
||||
embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
|
||||
auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
|
||||
selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
|
||||
// @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately)
|
||||
if(!inference_)
|
||||
selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
|
||||
return selectedEmbs;
|
||||
Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const
|
||||
/*override final*/ {
|
||||
ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
|
||||
auto embIdxExpr = E_->graph()->indices(embIdx);
|
||||
embIdxExpr->set_name("data_"
|
||||
+ std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
|
||||
auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
|
||||
selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
|
||||
// @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape()
|
||||
// (test that separately)
|
||||
if(!inference_)
|
||||
selectedEmbs = dropout(
|
||||
selectedEmbs, options_->get<float>("dropout", 0.0f), {selectedEmbs->shape()[-3], 1, 1});
|
||||
return selectedEmbs;
|
||||
}
|
||||
|
||||
// standard encoder word embeddings
|
||||
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
|
||||
auto options = New<Options>(
|
||||
"dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
|
||||
"dimEmb", opt<int>("dim-emb"),
|
||||
"dropout", dropoutEmbeddings_,
|
||||
"inference", inference_,
|
||||
"prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb",
|
||||
"fixed", embeddingFix_,
|
||||
"vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
|
||||
if(options_->hasAndNotEmpty("embedding-vectors")) {
|
||||
auto options = New<Options>(
|
||||
"dimVocab",
|
||||
opt<std::vector<int>>("dim-vocabs")[batchIndex_],
|
||||
"dimEmb",
|
||||
opt<int>("dim-emb"),
|
||||
"dropout",
|
||||
dropoutEmbeddings_,
|
||||
"inference",
|
||||
inference_,
|
||||
"prefix",
|
||||
(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb"
|
||||
: prefix_ + "_Wemb",
|
||||
"fixed",
|
||||
embeddingFix_,
|
||||
"vocab",
|
||||
opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
|
||||
if(options_->hasAndNotEmpty("embedding-vectors")) {
|
||||
auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
|
||||
options->set(
|
||||
"embFile", embFiles[batchIndex_],
|
||||
"normalization", opt<bool>("embedding-normalization"));
|
||||
}
|
||||
return New<Embedding>(graph_, options);
|
||||
"embFile", embFiles[batchIndex_], "normalization", opt<bool>("embedding-normalization"));
|
||||
}
|
||||
return New<Embedding>(graph_, options);
|
||||
}
|
||||
|
||||
// ULR word embeddings
|
||||
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const {
|
||||
return New<ULREmbedding>(graph_, New<Options>(
|
||||
"dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
|
||||
"dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt
|
||||
"dimUlrEmb", opt<int>("ulr-dim-emb"),
|
||||
"dimEmb", opt<int>("dim-emb"),
|
||||
"ulr-dropout", opt<float>("ulr-dropout"),
|
||||
"dropout", dropoutEmbeddings_,
|
||||
"inference", inference_,
|
||||
"ulrTrainTransform", opt<bool>("ulr-trainable-transformation"),
|
||||
"ulrQueryFile", opt<std::string>("ulr-query-vectors"),
|
||||
"ulrKeysFile", opt<std::string>("ulr-keys-vectors")));
|
||||
return New<ULREmbedding>(
|
||||
graph_,
|
||||
New<Options>("dimSrcVoc",
|
||||
opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
|
||||
"dimTgtVoc",
|
||||
opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt
|
||||
"dimUlrEmb",
|
||||
opt<int>("ulr-dim-emb"),
|
||||
"dimEmb",
|
||||
opt<int>("dim-emb"),
|
||||
"ulr-dropout",
|
||||
opt<float>("ulr-dropout"),
|
||||
"dropout",
|
||||
dropoutEmbeddings_,
|
||||
"inference",
|
||||
inference_,
|
||||
"ulrTrainTransform",
|
||||
opt<bool>("ulr-trainable-transformation"),
|
||||
"ulrQueryFile",
|
||||
opt<std::string>("ulr-query-vectors"),
|
||||
"ulrKeysFile",
|
||||
opt<std::string>("ulr-keys-vectors")));
|
||||
}
|
||||
|
||||
// get embedding layer for this encoder or decoder
|
||||
// This is lazy mostly because the constructors of the consuming objects are not
|
||||
// guaranteed presently to have access to their graph.
|
||||
Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const {
|
||||
if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
|
||||
if (embeddingLayers_.size() <= batchIndex_)
|
||||
embeddingLayers_.resize(batchIndex_ + 1);
|
||||
if (ulr)
|
||||
embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
|
||||
if(embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
|
||||
if(embeddingLayers_.size() <= batchIndex_)
|
||||
embeddingLayers_.resize(batchIndex_ + 1);
|
||||
if(ulr)
|
||||
embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
|
||||
else
|
||||
embeddingLayers_[batchIndex_] = createEmbeddingLayer();
|
||||
}
|
||||
return embeddingLayers_[batchIndex_];
|
||||
}
|
||||
|
||||
embeddingLayers_[batchIndex_] = createEmbeddingLayer();
|
||||
}
|
||||
return embeddingLayers_[batchIndex_];
|
||||
}
|
||||
|
||||
} // namespace marian
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
#include "marian.h"
|
||||
#include "generic.h"
|
||||
#include "marian.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
@ -19,7 +19,8 @@ class Embedding : public LayerBase, public IEmbeddingLayer {
|
||||
public:
|
||||
Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);
|
||||
|
||||
std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final;
|
||||
std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
|
||||
Ptr<data::SubBatch> subBatch) const override final;
|
||||
|
||||
Expr apply(const Words& words, const Shape& shape) const override final;
|
||||
|
||||
@ -27,17 +28,18 @@ public:
|
||||
};
|
||||
|
||||
class ULREmbedding : public LayerBase, public IEmbeddingLayer {
|
||||
std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
|
||||
std::vector<Expr>
|
||||
ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
|
||||
bool inference_{false};
|
||||
|
||||
public:
|
||||
ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
|
||||
std::string name = "url_embed"; //opt<std::string>("prefix");
|
||||
ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
|
||||
std::string name = "url_embed"; // opt<std::string>("prefix");
|
||||
int dimKeys = opt<int>("dimTgtVoc");
|
||||
int dimQueries = opt<int>("dimSrcVoc");
|
||||
int dimEmb = opt<int>("dimEmb");
|
||||
int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
|
||||
int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
|
||||
bool fixed = opt<bool>("fixed", false);
|
||||
|
||||
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
|
||||
@ -46,58 +48,61 @@ public:
|
||||
std::string queryFile = opt<std::string>("ulrQueryFile");
|
||||
std::string keyFile = opt<std::string>("ulrKeysFile");
|
||||
bool trainTrans = opt<bool>("ulrTrainTransform", false);
|
||||
if (!queryFile.empty() && !keyFile.empty()) {
|
||||
if(!queryFile.empty() && !keyFile.empty()) {
|
||||
initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
|
||||
name = "ulr_query";
|
||||
fixed = true;
|
||||
auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed);
|
||||
auto query_embed = graph_->param(name, {dimQueries, dimUlrEmb}, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(query_embed);
|
||||
// keys embeds
|
||||
initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
|
||||
name = "ulr_keys";
|
||||
fixed = true;
|
||||
auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed);
|
||||
auto key_embed = graph_->param(name, {dimKeys, dimUlrEmb}, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(key_embed);
|
||||
// actual trainable embedding
|
||||
initFunc = inits::glorotUniform();
|
||||
name = "ulr_embed";
|
||||
fixed = false;
|
||||
auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim
|
||||
auto ulr_embed
|
||||
= graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed); // note the reverse dim
|
||||
ulrEmbeddings_.push_back(ulr_embed);
|
||||
// init trainable src embedding
|
||||
name = "ulr_src_embed";
|
||||
auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed);
|
||||
auto ulr_src_embed = graph_->param(name, {dimQueries, dimEmb}, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(ulr_src_embed);
|
||||
// ulr transformation matrix
|
||||
//initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall we make this to the fixed case only
|
||||
if (trainTrans) {
|
||||
// initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall
|
||||
// we make this to the fixed case only
|
||||
if(trainTrans) {
|
||||
initFunc = inits::glorotUniform();
|
||||
fixed = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
initFunc = inits::eye(); // identity matrix
|
||||
} else {
|
||||
initFunc = inits::eye(); // identity matrix
|
||||
fixed = true;
|
||||
}
|
||||
name = "ulr_transform";
|
||||
auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed);
|
||||
auto ulrTransform = graph_->param(name, {dimUlrEmb, dimUlrEmb}, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(ulrTransform);
|
||||
|
||||
initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only
|
||||
initFunc = inits::fromValue(
|
||||
1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no
|
||||
// universal embeddings - should be zero for top freq only
|
||||
fixed = true;
|
||||
name = "ulr_shared";
|
||||
auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed);
|
||||
auto share_embed = graph_->param(name, {dimQueries, 1}, initFunc, fixed);
|
||||
ulrEmbeddings_.push_back(share_embed);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final {
|
||||
auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
|
||||
auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
|
||||
auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
|
||||
auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
|
||||
auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
|
||||
auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
|
||||
std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
|
||||
Ptr<data::SubBatch> subBatch) const override final {
|
||||
auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
|
||||
auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
|
||||
auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
|
||||
auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
|
||||
auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
|
||||
auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
|
||||
int dimBatch = (int)subBatch->batchSize();
|
||||
int dimEmb = uniEmbed->shape()[-1];
|
||||
int dimWords = (int)subBatch->batchWidth();
|
||||
@ -106,34 +111,42 @@ public:
|
||||
// dim A = uni_embed_size*uni_embed_size
|
||||
// dim Q: uni_embed_size * total_merged_vocab_size
|
||||
// dim D = univ_tok_vocab * total_merged_vocab_size
|
||||
// note all above can be precombuted and serialized if A is not trainiable and during decoding (TBD)
|
||||
// here we need to handle the mini-batch
|
||||
// extract raws corresponding to Xs in this minibatch from Q
|
||||
// note all above can be precombuted and serialized if A is not trainiable and during decoding
|
||||
// (TBD) here we need to handle the mini-batch extract raws corresponding to Xs in this
|
||||
// minibatch from Q
|
||||
auto embIdx = toWordIndexVector(subBatch->data());
|
||||
auto queryEmbeddings = rows(queryEmbed, embIdx);
|
||||
auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
|
||||
auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
|
||||
auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
|
||||
auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]);
|
||||
qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magnitude with larger embeds sizes
|
||||
auto z = dot(qt, keyEmbed, false, true); // query-key similarity
|
||||
auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
|
||||
auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
|
||||
auto qt = dot(queryEmbeddings,
|
||||
ulrTransform,
|
||||
false,
|
||||
false); // A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
|
||||
auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]);
|
||||
qt = qt / sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in
|
||||
// magnitude with larger embeds sizes
|
||||
auto z = dot(qt, keyEmbed, false, true); // query-key similarity
|
||||
float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout
|
||||
if(!inference_)
|
||||
z = dropout(z, dropProb);
|
||||
|
||||
float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
|
||||
float tau
|
||||
= this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
|
||||
// temperature in softmax is to control randomness of predictions
|
||||
// high temperature Softmax outputs are more close to each other
|
||||
// low temperatures the softmax become more similar to "hardmax"
|
||||
auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ??
|
||||
auto weights
|
||||
= softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ??
|
||||
auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
|
||||
auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
|
||||
auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb });
|
||||
auto chosenEmbeddings_mix
|
||||
= srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
|
||||
auto batchEmbeddings = reshape(chosenEmbeddings_mix, {dimWords, dimBatch, dimEmb});
|
||||
auto graph = ulrEmbeddings_.front()->graph();
|
||||
auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
|
||||
inits::fromVector(subBatch->mask()));
|
||||
auto batchMask = graph->constant({dimWords, dimBatch, 1}, inits::fromVector(subBatch->mask()));
|
||||
if(!inference_)
|
||||
batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1});
|
||||
batchEmbeddings = dropout(batchEmbeddings,
|
||||
options_->get<float>("dropout-embeddings", 0.0f),
|
||||
{batchEmbeddings->shape()[-3], 1, 1});
|
||||
return std::make_tuple(batchEmbeddings, batchMask);
|
||||
}
|
||||
|
||||
@ -142,9 +155,10 @@ public:
|
||||
}
|
||||
|
||||
Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
|
||||
embIdx; shape;
|
||||
ABORT("not implemented"); // @TODO: implement me
|
||||
embIdx;
|
||||
shape;
|
||||
ABORT("not implemented"); // @TODO: implement me
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace marian
|
||||
|
@ -1,13 +1,10 @@
|
||||
#include "marian.h"
|
||||
|
||||
#include "layers/generic.h"
|
||||
#include "layers/constructors.h"
|
||||
#include "layers/loss.h"
|
||||
#include "data/factored_vocab.h"
|
||||
#include "models/states.h" // for EncoderState
|
||||
#include "layers/constructors.h"
|
||||
#include "layers/generic.h"
|
||||
#include "layers/loss.h"
|
||||
#include "layers/lsh.h"
|
||||
#include "models/states.h" // for EncoderState
|
||||
|
||||
namespace marian {
|
||||
|
||||
|
||||
} // namespace marian
|
||||
namespace marian {} // namespace marian
|
||||
|
@ -5,12 +5,14 @@
|
||||
#include "data/shortlist.h"
|
||||
#include "layers/factory.h"
|
||||
|
||||
namespace marian { namespace mlp {
|
||||
/**
|
||||
* @brief Activation functions
|
||||
*/
|
||||
enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
|
||||
}}
|
||||
namespace marian {
|
||||
namespace mlp {
|
||||
/**
|
||||
* @brief Activation functions
|
||||
*/
|
||||
enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
|
||||
} // namespace mlp
|
||||
} // namespace marian
|
||||
|
||||
namespace marian {
|
||||
|
||||
@ -23,8 +25,7 @@ protected:
|
||||
Ptr<Options> options_;
|
||||
|
||||
public:
|
||||
LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: graph_(graph), options_(options) {}
|
||||
LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}
|
||||
|
||||
template <typename T>
|
||||
T opt(const std::string key) const {
|
||||
@ -42,7 +43,7 @@ struct IUnaryLayer {
|
||||
virtual ~IUnaryLayer() {}
|
||||
virtual Expr apply(Expr) = 0;
|
||||
virtual Expr apply(const std::vector<Expr>& es) {
|
||||
ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
|
||||
ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
|
||||
return apply(es.front());
|
||||
}
|
||||
};
|
||||
@ -54,7 +55,8 @@ struct IHasShortList {
|
||||
|
||||
// Embedding from corpus sub-batch to (emb, mask)
|
||||
struct IEmbeddingLayer {
|
||||
virtual std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const = 0;
|
||||
virtual std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
|
||||
Ptr<data::SubBatch> subBatch) const = 0;
|
||||
|
||||
virtual Expr apply(const Words& embIdx, const Shape& shape) const = 0;
|
||||
|
||||
@ -63,28 +65,29 @@ struct IEmbeddingLayer {
|
||||
virtual ~IEmbeddingLayer() {}
|
||||
};
|
||||
|
||||
// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream index)
|
||||
// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream
|
||||
// index)
|
||||
class EncoderDecoderLayerBase : public LayerBase {
|
||||
protected:
|
||||
const std::string prefix_;
|
||||
const bool embeddingFix_;
|
||||
const float dropoutEmbeddings_; // this drops out full embedding vectors
|
||||
const float dropoutEmbeddings_; // this drops out full embedding vectors
|
||||
const bool inference_;
|
||||
const size_t batchIndex_;
|
||||
mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created)
|
||||
mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created)
|
||||
|
||||
EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph,
|
||||
Ptr<Options> options,
|
||||
const std::string& prefix,
|
||||
EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph,
|
||||
Ptr<Options> options,
|
||||
const std::string& prefix,
|
||||
size_t batchIndex,
|
||||
float dropoutEmbeddings,
|
||||
bool embeddingFix) :
|
||||
LayerBase(graph, options),
|
||||
prefix_(options->get<std::string>("prefix", prefix)),
|
||||
embeddingFix_(embeddingFix),
|
||||
dropoutEmbeddings_(dropoutEmbeddings),
|
||||
inference_(options->get<bool>("inference", false)),
|
||||
batchIndex_(options->get<size_t>("index", batchIndex)) {}
|
||||
bool embeddingFix)
|
||||
: LayerBase(graph, options),
|
||||
prefix_(options->get<std::string>("prefix", prefix)),
|
||||
embeddingFix_(embeddingFix),
|
||||
dropoutEmbeddings_(dropoutEmbeddings),
|
||||
inference_(options->get<bool>("inference", false)),
|
||||
batchIndex_(options->get<size_t>("index", batchIndex)) {}
|
||||
|
||||
virtual ~EncoderDecoderLayerBase() {}
|
||||
|
||||
@ -101,8 +104,7 @@ namespace mlp {
|
||||
|
||||
class Dense : public LayerBase, public IUnaryLayer {
|
||||
public:
|
||||
Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: LayerBase(graph, options) {}
|
||||
Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {}
|
||||
|
||||
Expr apply(const std::vector<Expr>& inputs) override {
|
||||
ABORT_IF(inputs.empty(), "No inputs");
|
||||
@ -124,21 +126,17 @@ public:
|
||||
if(inputs.size() > 1)
|
||||
num = std::to_string(i);
|
||||
|
||||
Expr W = g->param(
|
||||
name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform());
|
||||
Expr W = g->param(name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform());
|
||||
Expr b = g->param(name + "_b" + num, {1, dim}, inits::zeros());
|
||||
|
||||
if(useLayerNorm) {
|
||||
if(useNematusNorm) {
|
||||
auto ln_s = g->param(
|
||||
name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f));
|
||||
auto ln_s = g->param(name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f));
|
||||
auto ln_b = g->param(name + "_ln_b" + num, {1, dim}, inits::zeros());
|
||||
|
||||
outputs.push_back(
|
||||
layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS));
|
||||
outputs.push_back(layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS));
|
||||
} else {
|
||||
auto gamma = g->param(
|
||||
name + "_gamma" + num, {1, dim}, inits::fromValue(1.0));
|
||||
auto gamma = g->param(name + "_gamma" + num, {1, dim}, inits::fromValue(1.0));
|
||||
|
||||
outputs.push_back(layerNorm(dot(in, W), gamma, b));
|
||||
}
|
||||
@ -165,39 +163,35 @@ public:
|
||||
Expr apply(Expr input) override { return apply(std::vector<Expr>({input})); }
|
||||
};
|
||||
|
||||
} // namespace mlp
|
||||
|
||||
} // namespace mlp
|
||||
|
||||
// --- a few layers with built-in parameters created on the fly, without proper object
|
||||
// @TODO: change to a proper layer object
|
||||
|
||||
// like affine() but with built-in parameters, activation, and dropout
|
||||
static inline
|
||||
Expr denseInline(Expr x,
|
||||
std::string prefix,
|
||||
std::string suffix,
|
||||
int outDim,
|
||||
Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
|
||||
const std::function<Expr(Expr)>& actFn = nullptr,
|
||||
float dropProb = 0.0f)
|
||||
{
|
||||
static inline Expr denseInline(Expr x,
|
||||
std::string prefix,
|
||||
std::string suffix,
|
||||
int outDim,
|
||||
Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
|
||||
const std::function<Expr(Expr)>& actFn = nullptr,
|
||||
float dropProb = 0.0f) {
|
||||
auto graph = x->graph();
|
||||
|
||||
auto W = graph->param(prefix + "_W" + suffix, { x->shape()[-1], outDim }, inits::glorotUniform());
|
||||
auto b = graph->param(prefix + "_b" + suffix, { 1, outDim }, inits::zeros());
|
||||
auto W = graph->param(prefix + "_W" + suffix, {x->shape()[-1], outDim}, inits::glorotUniform());
|
||||
auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros());
|
||||
|
||||
x = affine(x, W, b);
|
||||
if (actFn)
|
||||
if(actFn)
|
||||
x = actFn(x);
|
||||
x = dropout(x, dropProb); // @TODO: check for infernce?
|
||||
x = dropout(x, dropProb); // @TODO: check for infernce?
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline
|
||||
Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
|
||||
static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
|
||||
int dimModel = x->shape()[-1];
|
||||
auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, { 1, dimModel }, inits::ones());
|
||||
auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, { 1, dimModel }, inits::zeros());
|
||||
auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, {1, dimModel}, inits::ones());
|
||||
auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, {1, dimModel}, inits::zeros());
|
||||
return marian::layerNorm(x, scale, bias, 1e-6f);
|
||||
}
|
||||
|
||||
|
@ -1,212 +1,250 @@
|
||||
#include "logits.h"
|
||||
#include "loss.h"
|
||||
#include "data/factored_vocab.h"
|
||||
#include "rnn/types.h" // for State::select()
|
||||
#include "loss.h"
|
||||
#include "rnn/types.h" // for State::select()
|
||||
|
||||
namespace marian {
|
||||
Logits::Logits(Expr logits) : Logits(New<RationalLoss>(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count)
|
||||
Logits::Logits(Expr logits)
|
||||
: Logits(New<RationalLoss>(logits, nullptr)) {
|
||||
} // single-output constructor from Expr only (RationalLoss has no count)
|
||||
|
||||
Ptr<ExpressionGraph> Logits::graph() const {
|
||||
ABORT_IF(logits_.empty(), "Empty logits object??");
|
||||
return logits_.front()->loss()->graph();
|
||||
Ptr<ExpressionGraph> Logits::graph() const {
|
||||
ABORT_IF(logits_.empty(), "Empty logits object??");
|
||||
return logits_.front()->loss()->graph();
|
||||
}
|
||||
|
||||
// This function assumes that the object holds one or more factor logits.
|
||||
// It applies the supplied loss function to each, and then returns the aggregate loss over all
|
||||
// factors.
|
||||
Expr Logits::applyLossFunction(
|
||||
const Words& labels,
|
||||
const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const {
|
||||
LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size());
|
||||
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
|
||||
|
||||
auto firstLogits = logits_.front()->loss();
|
||||
ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
|
||||
"Labels not matching logits shape ({} != {}, {})??",
|
||||
labels.size() * firstLogits->shape()[-1],
|
||||
firstLogits->shape().elements(),
|
||||
firstLogits->shape());
|
||||
|
||||
// base case (no factors)
|
||||
if(!factoredVocab_) {
|
||||
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
|
||||
return lossFn(firstLogits, indices(toWordIndexVector(labels)));
|
||||
}
|
||||
|
||||
// This function assumes that the object holds one or more factor logits.
|
||||
// It applies the supplied loss function to each, and then returns the aggregate loss over all factors.
|
||||
Expr Logits::applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/, Expr/*indices*/)>& lossFn) const {
|
||||
LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size());
|
||||
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
|
||||
auto numGroups = factoredVocab_->getNumGroups();
|
||||
|
||||
auto firstLogits = logits_.front()->loss();
|
||||
ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
|
||||
"Labels not matching logits shape ({} != {}, {})??",
|
||||
labels.size() * firstLogits->shape()[-1],
|
||||
firstLogits->shape().elements(),
|
||||
firstLogits->shape());
|
||||
// split labels into individual factor labels
|
||||
auto allMaskedFactoredLabels
|
||||
= factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened]
|
||||
|
||||
// base case (no factors)
|
||||
if (!factoredVocab_) {
|
||||
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
|
||||
return lossFn(firstLogits, indices(toWordIndexVector(labels)));
|
||||
// Expr indices = this->indices(toWordIndexVector(labels));
|
||||
// accumulate all CEs for all words that have the factor
|
||||
// Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors.
|
||||
Expr loss;
|
||||
for(size_t g = 0; g < numGroups; g++) {
|
||||
if(!logits_[g])
|
||||
continue; // empty factor --@TODO: use an array of indices of non-empty logits_[]
|
||||
const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask)
|
||||
auto factorIndices = indices(
|
||||
maskedFactoredLabels
|
||||
.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
|
||||
auto factorMask
|
||||
= constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with
|
||||
// 0 for labels that don't have this factor
|
||||
auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
|
||||
// For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask
|
||||
// it out next.
|
||||
auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
|
||||
if(loss)
|
||||
factorLoss = cast(factorLoss, loss->value_type());
|
||||
factorLoss
|
||||
= factorLoss
|
||||
* cast(
|
||||
reshape(factorMask, factorLoss->shape()),
|
||||
factorLoss->value_type()); // mask out factor for words that do not have that factor
|
||||
loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1]
|
||||
}
|
||||
return loss;
|
||||
}
|
||||
|
||||
// This function assumes this object holds a single factor that represents a rational loss (with
|
||||
// count).
|
||||
// Ptr<RationalLoss> Logits::getRationalLoss() const {
|
||||
// ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on
|
||||
// multi-factor outputs"); ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational
|
||||
// loss without count"); return logits_.front();
|
||||
//}
|
||||
|
||||
// get logits for one factor group
|
||||
// For groupIndex == 0, the function also requires the shortlist if there is one.
|
||||
Expr Logits::getFactoredLogits(size_t groupIndex,
|
||||
Ptr<data::Shortlist> shortlist /*= nullptr*/,
|
||||
const std::vector<IndexType>& hypIndices /*= {}*/,
|
||||
size_t beamSize /*= 0*/) const {
|
||||
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
|
||||
|
||||
auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
|
||||
|
||||
// normalize for decoding:
|
||||
// - all secondary factors: subtract their max
|
||||
// - lemma: add all maxes of applicable factors
|
||||
if(groupIndex > 0) {
|
||||
sel = sel - max(sel, -1);
|
||||
} else {
|
||||
auto numGroups = getNumFactorGroups();
|
||||
for(size_t g = 1; g < numGroups; g++) {
|
||||
auto factorMaxima = max(logits_[g]->loss(),
|
||||
-1); // we cast since loss is likely ce-loss which has type float32
|
||||
auto factorMasks = constant(
|
||||
getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>()));
|
||||
sel = sel
|
||||
+ cast(factorMaxima, sel->value_type())
|
||||
* cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor
|
||||
// get multiplied with 0
|
||||
}
|
||||
|
||||
auto numGroups = factoredVocab_->getNumGroups();
|
||||
|
||||
// split labels into individual factor labels
|
||||
auto allMaskedFactoredLabels = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened]
|
||||
|
||||
//Expr indices = this->indices(toWordIndexVector(labels));
|
||||
// accumulate all CEs for all words that have the factor
|
||||
// Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors.
|
||||
Expr loss;
|
||||
for (size_t g = 0; g < numGroups; g++) {
|
||||
if (!logits_[g])
|
||||
continue; // empty factor --@TODO: use an array of indices of non-empty logits_[]
|
||||
const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask)
|
||||
auto factorIndices = indices (maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
|
||||
auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
|
||||
auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
|
||||
// For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
|
||||
auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
|
||||
if(loss)
|
||||
factorLoss = cast(factorLoss, loss->value_type());
|
||||
factorLoss = factorLoss * cast(reshape(factorMask, factorLoss->shape()), factorLoss->value_type()); // mask out factor for words that do not have that factor
|
||||
loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1]
|
||||
}
|
||||
return loss;
|
||||
}
|
||||
|
||||
// This function assumes this object holds a single factor that represents a rational loss (with count).
|
||||
//Ptr<RationalLoss> Logits::getRationalLoss() const {
|
||||
// ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on multi-factor outputs");
|
||||
// ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational loss without count");
|
||||
// return logits_.front();
|
||||
//}
|
||||
// if selIdx are given, then we must reshuffle accordingly
|
||||
if(!hypIndices.empty()) // use the same function that shuffles decoder state
|
||||
sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
|
||||
|
||||
// get logits for one factor group
|
||||
// For groupIndex == 0, the function also requires the shortlist if there is one.
|
||||
Expr Logits::getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist /*= nullptr*/, const std::vector<IndexType>& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const {
|
||||
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
|
||||
|
||||
auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
|
||||
return sel;
|
||||
}
|
||||
|
||||
// normalize for decoding:
|
||||
// - all secondary factors: subtract their max
|
||||
// - lemma: add all maxes of applicable factors
|
||||
if (groupIndex > 0) {
|
||||
sel = sel - max(sel, -1);
|
||||
}
|
||||
else {
|
||||
auto numGroups = getNumFactorGroups();
|
||||
for (size_t g = 1; g < numGroups; g++) {
|
||||
auto factorMaxima = max(logits_[g]->loss(), -1); // we cast since loss is likely ce-loss which has type float32
|
||||
auto factorMasks = constant(getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>()));
|
||||
sel = sel + cast(factorMaxima, sel->value_type()) * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor get multiplied with 0
|
||||
}
|
||||
}
|
||||
// used for breakDown() only
|
||||
// Index is flattened
|
||||
Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const {
|
||||
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
|
||||
return logits_[groupIndex]->loss()->val();
|
||||
}
|
||||
|
||||
// if selIdx are given, then we must reshuffle accordingly
|
||||
if (!hypIndices.empty()) // use the same function that shuffles decoder state
|
||||
sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
|
||||
|
||||
return sel;
|
||||
// This function assumes that the object holds one or more factor logits, which are summed up
|
||||
// into output-vocab logits according to the factored model (with correct normalization of factors).
// This is infeasible for realistic factor sets, and therefore only implemented for 1 factor.
// @TODO: remove altogether
Expr Logits::getLogits() const {
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
if(!factoredVocab_) {
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
return getFactoredLogits(0);
}

// used for breakDown() only
// Index is flattened
Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const {
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
return logits_[groupIndex]->loss()->val();
}

// This function assumes that the object holds one or more factor logits, which are summed up
// into output-vocab logits according to the factored model (with correct normalization of factors).
// This is infeasible for realistic factor sets, and therefore only implemented for 1 factor.
// @TODO: remove altogether
Expr Logits::getLogits() const {
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
if (!factoredVocab_) {
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
return getFactoredLogits(0);
}

#ifdef FACTOR_FULL_EXPANSION
// compute normalized factor log probs
std::vector<Expr> logProbs(logits_.size());
for (size_t g = 0; g < logits_.size(); g++)
logProbs[g] = logsoftmax(logits_[g]->loss());
auto y = concatenate(logProbs, /*axis=*/ -1);
// compute normalized factor log probs
std::vector<Expr> logProbs(logits_.size());
for(size_t g = 0; g < logits_.size(); g++)
logProbs[g] = logsoftmax(logits_[g]->loss());
auto y = concatenate(logProbs, /*axis=*/-1);

// sum up the unit logits across factors for each target word
auto graph = y->graph();
auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U]
y = dot_csr(
y, // [B x U]
factorMatrix.shape,
graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)),
graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32),
graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32),
/*transB=*/ true); // -> [B x V]
// sum up the unit logits across factors for each target word
auto graph = y->graph();
auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U]
y = dot_csr(
y, // [B x U]
factorMatrix.shape,
graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)),
graph->constant({(int)factorMatrix.indices.size()},
inits::fromVector(factorMatrix.indices),
Type::uint32),
graph->constant({(int)factorMatrix.offsets.size()},
inits::fromVector(factorMatrix.offsets),
Type::uint32),
/*transB=*/true); // -> [B x V]

// mask out gaps
auto gapLogMask = factoredVocab_->getGapLogMask(); // [V]
y = y + graph->constant({ (int)gapLogMask.size() }, inits::fromVector(gapLogMask));
// mask out gaps
auto gapLogMask = factoredVocab_->getGapLogMask(); // [V]
y = y + graph->constant({(int)gapLogMask.size()}, inits::fromVector(gapLogMask));

return y;
return y;
#else
ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible
ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible
#endif
}
}

void Logits::MaskedFactorIndices::push_back(size_t factorIndex) {
bool isValid = FactoredVocab::isFactorValid(factorIndex);
indices.push_back(isValid ? (WordIndex)factorIndex : 0);
masks.push_back((float)isValid);
}
void Logits::MaskedFactorIndices::push_back(size_t factorIndex) {
bool isValid = FactoredVocab::isFactorValid(factorIndex);
indices.push_back(isValid ? (WordIndex)factorIndex : 0);
masks.push_back((float)isValid);
}
|
||||
|
||||
std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words) const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices
|
||||
if (!factoredVocab_) {
|
||||
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
|
||||
return {MaskedFactorIndices(words)};
|
||||
}
|
||||
auto numGroups = factoredVocab_->getNumGroups();
|
||||
std::vector<MaskedFactorIndices> res(numGroups);
|
||||
for (size_t g = 0; g < numGroups; g++) {
|
||||
auto& resg = res[g];
|
||||
resg.reserve(words.size());
|
||||
for (const auto& word : words)
|
||||
resg.push_back(factoredVocab_->getFactor(word, g));
|
||||
}
|
||||
return res;
|
||||
std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words)
|
||||
const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices
|
||||
if(!factoredVocab_) {
|
||||
ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
|
||||
return {MaskedFactorIndices(words)};
|
||||
}
|
||||
|
||||
//// use first factor of each word to determine whether it has a specific factor
|
||||
//std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0
|
||||
// std::vector<float> res;
|
||||
// res.reserve(words.size());
|
||||
// for (const auto& word : words) {
|
||||
// auto lemma = factoredVocab_->getFactor(word, 0);
|
||||
// res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
|
||||
// }
|
||||
// return res;
|
||||
//}
|
||||
|
||||
// return a vector of 1 or 0 indicating for each lemma whether it has a specific factor
|
||||
// If 'indices' is given, then return the masks for the indices; otherwise for all lemmas
|
||||
std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
|
||||
size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size();
|
||||
std::vector<float> res;
|
||||
res.reserve(n);
|
||||
// @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this into FactoredVocab
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first);
|
||||
res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
|
||||
}
|
||||
return res;
|
||||
auto numGroups = factoredVocab_->getNumGroups();
|
||||
std::vector<MaskedFactorIndices> res(numGroups);
|
||||
for(size_t g = 0; g < numGroups; g++) {
|
||||
auto& resg = res[g];
|
||||
resg.reserve(words.size());
|
||||
for(const auto& word : words)
|
||||
resg.push_back(factoredVocab_->getFactor(word, g));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
Logits Logits::applyUnaryFunction(const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
|
||||
std::vector<Ptr<RationalLoss>> newLogits;
|
||||
for (const auto& l : logits_)
|
||||
newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count()));
|
||||
return Logits(std::move(newLogits), factoredVocab_);
|
||||
}
|
||||
//// use first factor of each word to determine whether it has a specific factor
|
||||
// std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0
|
||||
// for words that do have this factor; else 0
|
||||
// std::vector<float> res;
|
||||
// res.reserve(words.size());
|
||||
// for (const auto& word : words) {
|
||||
// auto lemma = factoredVocab_->getFactor(word, 0);
|
||||
// res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
|
||||
// }
|
||||
// return res;
|
||||
//}
|
||||
|
||||
Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const {
|
||||
std::vector<Ptr<RationalLoss>> newLogits;
|
||||
bool first = true;
|
||||
for (const auto& l : logits_) {
|
||||
newLogits.emplace_back(New<RationalLoss>((first?f1:fother)(l->loss()), l->count())); // f1 for first, fother for all others
|
||||
first = false;
|
||||
}
|
||||
return Logits(std::move(newLogits), factoredVocab_);
|
||||
// return a vector of 1 or 0 indicating for each lemma whether it has a specific factor
|
||||
// If 'indices' is given, then return the masks for the indices; otherwise for all lemmas
|
||||
std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices)
|
||||
const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
|
||||
size_t n
|
||||
= indices.empty()
|
||||
? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first)
|
||||
: indices.size();
|
||||
std::vector<float> res;
|
||||
res.reserve(n);
|
||||
// @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this
|
||||
// into FactoredVocab
|
||||
for(size_t i = 0; i < n; i++) {
|
||||
auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first);
|
||||
res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// @TODO: code dup with above; we can merge it into applyToRationalLoss()
|
||||
Logits Logits::withCounts(const Expr& count) const { // create new Logits with 'count' implanted into all logits_
|
||||
std::vector<Ptr<RationalLoss>> newLogits;
|
||||
for (const auto& l : logits_)
|
||||
newLogits.emplace_back(New<RationalLoss>(l->loss(), count));
|
||||
return Logits(std::move(newLogits), factoredVocab_);
|
||||
Logits Logits::applyUnaryFunction(
|
||||
const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
|
||||
std::vector<Ptr<RationalLoss>> newLogits;
|
||||
for(const auto& l : logits_)
|
||||
newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count()));
|
||||
return Logits(std::move(newLogits), factoredVocab_);
|
||||
}
|
||||
|
||||
Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1,
|
||||
const std::function<Expr(Expr)>& fother) const {
|
||||
std::vector<Ptr<RationalLoss>> newLogits;
|
||||
bool first = true;
|
||||
for(const auto& l : logits_) {
|
||||
newLogits.emplace_back(New<RationalLoss>((first ? f1 : fother)(l->loss()),
|
||||
l->count())); // f1 for first, fother for all others
|
||||
first = false;
|
||||
}
|
||||
}
|
||||
return Logits(std::move(newLogits), factoredVocab_);
|
||||
}
|
||||
|
||||
// @TODO: code dup with above; we can merge it into applyToRationalLoss()
|
||||
Logits Logits::withCounts(
|
||||
const Expr& count) const { // create new Logits with 'count' implanted into all logits_
|
||||
std::vector<Ptr<RationalLoss>> newLogits;
|
||||
for(const auto& l : logits_)
|
||||
newLogits.emplace_back(New<RationalLoss>(l->loss(), count));
|
||||
return Logits(std::move(newLogits), factoredVocab_);
|
||||
}
|
||||
} // namespace marian
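A minimal usage sketch (not part of this commit) for the Logits::applyUnaryFunction() interface shown above; the Logits object `logits` and the scaling factor `invTemperature` are assumed to already exist in the caller:

// Illustrative sketch only; `logits` (a Logits instance) and `invTemperature` (a float)
// are assumptions, not code from the repository. applyUnaryFunction() clones the object
// and applies the lambda to every factor group's loss Expr, as implemented above.
Logits scaled = logits.applyUnaryFunction(
    [invTemperature](Expr l) { return l * invTemperature; });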
|
@ -1,8 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "marian.h"
|
||||
#include "data/shortlist.h"
|
||||
#include "generic.h"
|
||||
#include "marian.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
@ -16,46 +16,77 @@ class FactoredVocab;
|
||||
class RationalLoss;
|
||||
class Logits {
|
||||
public:
|
||||
Logits() {}
|
||||
explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
|
||||
logits_.push_back(logits);
|
||||
}
|
||||
explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
|
||||
Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
|
||||
Logits() {}
|
||||
explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
|
||||
logits_.push_back(logits);
|
||||
}
|
||||
explicit Logits(
|
||||
Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
|
||||
Logits(std::vector<Ptr<RationalLoss>>&& logits,
|
||||
Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
|
||||
: logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {}
|
||||
Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
|
||||
Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
|
||||
//Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
|
||||
Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const;
|
||||
Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
|
||||
Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to first and fother to all other values
|
||||
Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
|
||||
Expr getFactoredLogits(
|
||||
size_t groupIndex,
|
||||
Ptr<data::Shortlist> shortlist = nullptr,
|
||||
const std::vector<IndexType>& hypIndices = {},
|
||||
size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
|
||||
// Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
|
||||
Expr applyLossFunction(
|
||||
const Words& labels,
|
||||
const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const;
|
||||
Logits applyUnaryFunction(
|
||||
const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
|
||||
Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1,
|
||||
const std::function<Expr(Expr)>& fother)
|
||||
const; // clone this but apply f1 to first and fother to all other values
|
||||
|
||||
struct MaskedFactorIndices {
|
||||
std::vector<WordIndex> indices; // factor index, or 0 if masked
|
||||
std::vector<float> masks;
|
||||
void reserve(size_t n) { indices.reserve(n); masks.reserve(n); }
|
||||
void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries
|
||||
MaskedFactorIndices() {}
|
||||
MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case
|
||||
};
|
||||
std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices
|
||||
Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
|
||||
size_t getNumFactorGroups() const { return logits_.size(); }
|
||||
bool empty() const { return logits_.empty(); }
|
||||
Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_
|
||||
struct MaskedFactorIndices {
|
||||
std::vector<WordIndex> indices; // factor index, or 0 if masked
|
||||
std::vector<float> masks;
|
||||
void reserve(size_t n) {
|
||||
indices.reserve(n);
|
||||
masks.reserve(n);
|
||||
}
|
||||
void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0
|
||||
// for invalid entries
|
||||
MaskedFactorIndices() {}
|
||||
MaskedFactorIndices(const Words& words) {
|
||||
indices = toWordIndexVector(words);
|
||||
} // we can leave masks uninitialized for this special use case
|
||||
};
|
||||
std::vector<MaskedFactorIndices> factorizeWords(
|
||||
const Words& words) const; // breaks encoded Word into individual factor indices
|
||||
Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
|
||||
size_t getNumFactorGroups() const { return logits_.size(); }
|
||||
bool empty() const { return logits_.empty(); }
|
||||
Logits withCounts(
|
||||
const Expr& count) const; // create new Logits with 'count' implanted into all logits_
|
||||
private:
|
||||
// helper functions
|
||||
Ptr<ExpressionGraph> graph() const;
|
||||
Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
|
||||
Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
|
||||
template<typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector
|
||||
Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type
|
||||
std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const;
|
||||
// helper functions
|
||||
Ptr<ExpressionGraph> graph() const;
|
||||
Expr constant(const Shape& shape, const std::vector<float>& data) const {
|
||||
return graph()->constant(shape, inits::fromVector(data));
|
||||
}
|
||||
Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const {
|
||||
return graph()->constant(shape, inits::fromVector(data));
|
||||
}
|
||||
template <typename T>
|
||||
Expr constant(const std::vector<T>& data) const {
|
||||
return constant(Shape{(int)data.size()}, data);
|
||||
} // same as constant() but assuming vector
|
||||
Expr indices(const std::vector<uint32_t>& data) const {
|
||||
return graph()->indices(data);
|
||||
} // actually the same as constant(data) for this data type
|
||||
std::vector<float> getFactorMasks(size_t factorGroup,
|
||||
const std::vector<WordIndex>& indices) const;
|
||||
|
||||
private:
|
||||
// members
|
||||
// @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr
|
||||
std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
|
||||
Ptr<FactoredVocab> factoredVocab_;
|
||||
// members
|
||||
// @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just
|
||||
// by the Expr
|
||||
std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
|
||||
Ptr<FactoredVocab> factoredVocab_;
|
||||
};
|
||||
|
||||
// Unary function that returns a Logits object
|
||||
@ -65,12 +96,11 @@ private:
|
||||
struct IUnaryLogitLayer : public IUnaryLayer {
virtual Logits applyAsLogits(Expr) = 0;
virtual Logits applyAsLogits(const std::vector<Expr>& es) {
ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
return applyAsLogits(es.front());
}
virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); }
virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); }
};

}

} // namespace marian
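A brief, hypothetical example (not from this commit) of how the factorizeWords() declaration above could be consumed; `logits` (a Logits object) and `labels` (a Words sequence) are assumed to exist:

// Illustrative sketch only; `logits` and `labels` are assumptions, not code from the repo.
auto perGroup = logits.factorizeWords(labels);   // [numGroups][labels.size()]
const auto& lemmaIndices = perGroup[0].indices;  // factor index per label for group 0
const auto& lemmaMasks   = perGroup[0].masks;    // 1.0 where the factor is present, else 0.0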
|
||||
|
@ -13,26 +13,30 @@ Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference) {
|
||||
bool wordScores = options->get<bool>("word-scores", false);
return New<RescorerLoss>(wordScores);
} else if(unlikelihood) {
ABORT_IF(!options->hasAndNotEmpty("data-weighting")
&& options->get<std::string>("data-weighting-type") != "word",
"Unlikelihood loss training requires error annotation in form of per-target-label scores");
return New<SequenceUnlikelihoodLoss>(smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood loss depending on values given for data-weighting
} else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. E.g. what about ce-sum?
ABORT_IF(
!options->hasAndNotEmpty("data-weighting")
&& options->get<std::string>("data-weighting-type") != "word",
"Unlikelihood loss training requires error annotation in form of per-target-label scores");
return New<SequenceUnlikelihoodLoss>(
smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood loss depending on
// values given for data-weighting
} else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones.
// E.g. what about ce-sum?
return New<CrossEntropyLoss>(smoothing, factorWeight);
}
}
|
||||
|
||||
// see loss.h for detailed explanations of each class
|
||||
Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) {
|
||||
std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
|
||||
if(multiLossType == "sum") // sum of sums
|
||||
return New<SumMultiRationalLoss>();
|
||||
else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale
|
||||
return New<ScaledMultiRationalLoss>();
|
||||
else if(multiLossType == "mean") // sum of means
|
||||
return New<MeanMultiRationalLoss>();
|
||||
else
|
||||
ABORT("Unknown multi-loss-type {}", multiLossType);
|
||||
std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
|
||||
if(multiLossType == "sum") // sum of sums
|
||||
return New<SumMultiRationalLoss>();
|
||||
else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale
|
||||
return New<ScaledMultiRationalLoss>();
|
||||
else if(multiLossType == "mean") // sum of means
|
||||
return New<MeanMultiRationalLoss>();
|
||||
else
|
||||
ABORT("Unknown multi-loss-type {}", multiLossType);
|
||||
}
|
||||
|
||||
} // namespace marian
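A short, hedged sketch (not part of this commit) of how the newMultiLoss() dispatch above might be driven; `options`, `ceLoss`, and `labelCount` are assumed to exist in the caller:

// Illustrative sketch only; `options` (Ptr<Options>), `ceLoss` and `labelCount` (Expr)
// are assumptions. The accumulator type follows the "multi-loss-type" option
// ("sum", "scaled", or "mean") exactly as dispatched in newMultiLoss() above.
auto multiLoss = newMultiLoss(options);
multiLoss->push_back(RationalLoss(ceLoss, labelCount));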
|
||||
|
@ -1,8 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "graph/expression_operators.h"
|
||||
#include "layers/logits.h" // for Logits (Frank's factor hack)
|
||||
#include "data/types.h"
|
||||
#include "graph/expression_operators.h"
|
||||
#include "layers/logits.h" // for Logits (Frank's factor hack)
|
||||
|
||||
namespace marian {
|
||||
|
||||
@ -22,21 +22,18 @@ namespace marian {
|
||||
*/
|
||||
class RationalLoss {
|
||||
protected:
|
||||
Expr loss_; // numerator
|
||||
Expr count_; // denominator
|
||||
Expr loss_; // numerator
|
||||
Expr count_; // denominator
|
||||
|
||||
RationalLoss() = default; // protected
|
||||
RationalLoss() = default; // protected
|
||||
|
||||
public:
|
||||
RationalLoss(Expr loss, Expr count)
|
||||
: loss_(loss), count_(count) {}
|
||||
RationalLoss(Expr loss, Expr count) : loss_(loss), count_(count) {}
|
||||
|
||||
RationalLoss(Expr loss, float count)
|
||||
: loss_(loss),
|
||||
count_(constant_like(loss, inits::fromValue(count))) {}
|
||||
: loss_(loss), count_(constant_like(loss, inits::fromValue(count))) {}
|
||||
|
||||
RationalLoss(const RationalLoss& other)
|
||||
: loss_(other.loss_), count_(other.count_) {}
|
||||
RationalLoss(const RationalLoss& other) : loss_(other.loss_), count_(other.count_) {}
|
||||
|
||||
virtual ~RationalLoss() = default;
|
||||
|
||||
@ -50,7 +47,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T loss() const { // this will fail if loss is not a single value
|
||||
T loss() const { // this will fail if loss is not a single value
|
||||
ABORT_IF(!loss_, "Loss has not been defined");
|
||||
return loss_->val()->scalar<T>();
|
||||
}
|
||||
@ -65,7 +62,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T count() const { // this will fail if loss is not a single value
|
||||
T count() const { // this will fail if loss is not a single value
|
||||
ABORT_IF(!count_, "Labels have not been defined");
|
||||
return count_->val()->scalar<T>();
|
||||
}
|
||||
@ -85,21 +82,21 @@ public:
|
||||
* RationalLoss object.
|
||||
*/
|
||||
struct StaticLoss {
|
||||
float loss; // numerator
|
||||
float count; // denominator
|
||||
float loss; // numerator
|
||||
float count; // denominator
|
||||
|
||||
StaticLoss() : loss(0.f), count(0.f) {}
|
||||
|
||||
StaticLoss(const RationalLoss& dynamic)
|
||||
: loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}
|
||||
: loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}
|
||||
|
||||
StaticLoss operator +(const StaticLoss& other) const {
|
||||
StaticLoss operator+(const StaticLoss& other) const {
|
||||
StaticLoss res(*this);
|
||||
res += other;
|
||||
return res;
|
||||
}
|
||||
|
||||
StaticLoss& operator +=(const StaticLoss& other) {
|
||||
StaticLoss& operator+=(const StaticLoss& other) {
|
||||
loss = loss + other.loss;
|
||||
count = count + other.count;
|
||||
return *this;
|
||||
@ -139,32 +136,21 @@ protected:
|
||||
public:
|
||||
MultiRationalLoss() : RationalLoss() {}
|
||||
|
||||
MultiRationalLoss(const RationalLoss& rl) : RationalLoss() {
|
||||
push_back(rl);
|
||||
}
|
||||
MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { push_back(rl); }
|
||||
|
||||
virtual void push_back(const RationalLoss& current) {
|
||||
loss_ = accumulateLoss(current);
|
||||
count_ = accumulateCount(current);
|
||||
loss_ = accumulateLoss(current);
|
||||
count_ = accumulateCount(current);
|
||||
partialLosses_.push_back(current);
|
||||
}
|
||||
|
||||
const RationalLoss& operator[](size_t i) {
|
||||
return partialLosses_[i];
|
||||
}
|
||||
const RationalLoss& operator[](size_t i) { return partialLosses_[i]; }
|
||||
|
||||
auto begin() -> decltype(partialLosses_.begin()) const {
|
||||
return partialLosses_.begin();
|
||||
}
|
||||
auto begin() -> decltype(partialLosses_.begin()) const { return partialLosses_.begin(); }
|
||||
|
||||
auto end() -> decltype(partialLosses_.end()) const {
|
||||
return partialLosses_.end();
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return partialLosses_.size();
|
||||
}
|
||||
auto end() -> decltype(partialLosses_.end()) const { return partialLosses_.end(); }
|
||||
|
||||
size_t size() const { return partialLosses_.size(); }
|
||||
};
|
||||
|
||||
/**
|
||||
@ -212,17 +198,19 @@ private:
|
||||
virtual Expr accumulateLoss(const RationalLoss& current) override {
|
||||
if(loss_) {
|
||||
const auto& first = partialLosses_.front();
|
||||
return loss_ + current.loss() * first.count() / current.count(); // scale up/down to match scale of first loss
|
||||
return loss_
|
||||
+ current.loss() * first.count()
|
||||
/ current.count(); // scale up/down to match scale of first loss
|
||||
} else {
|
||||
return current.loss(); // first reference loss, keeps to scale with this one
|
||||
return current.loss(); // first reference loss, keeps to scale with this one
|
||||
}
|
||||
}
|
||||
|
||||
virtual Expr accumulateCount(const RationalLoss& current) override {
|
||||
if(count_) {
|
||||
return count_; // Keep first label count // or: count_ + first.count() / current.count();
|
||||
return count_; // Keep first label count // or: count_ + first.count() / current.count();
|
||||
} else {
|
||||
return current.count(); // This is the first loss
|
||||
return current.count(); // This is the first loss
|
||||
}
|
||||
}
|
||||
|
||||
@ -253,9 +241,10 @@ private:
|
||||
|
||||
virtual Expr accumulateCount(const RationalLoss& current) override {
|
||||
if(count_)
|
||||
return count_; // keep the existing '1'
|
||||
return count_; // keep the existing '1'
|
||||
else
|
||||
return current.count()->graph()->ones({1}, current.loss()->value_type()); // just '1' as labels are factored into loss_
|
||||
return current.count()->graph()->ones(
|
||||
{1}, current.loss()->value_type()); // just '1' as labels are factored into loss_
|
||||
}
|
||||
|
||||
public:
|
||||
@ -279,18 +268,21 @@ class LabelwiseLoss {
|
||||
protected:
|
||||
std::vector<int> axes_;
|
||||
|
||||
virtual Expr compute(Logits logits, const Words& labels,
|
||||
Expr mask = nullptr, Expr labelWeights = nullptr) = 0;
|
||||
virtual Expr compute(Logits logits,
|
||||
const Words& labels,
|
||||
Expr mask = nullptr,
|
||||
Expr labelWeights = nullptr)
|
||||
= 0;
|
||||
|
||||
// label counts are available, reduce together with loss to obtain counts
|
||||
RationalLoss reduce(Expr loss, Expr labels) {
|
||||
ABORT_IF(!loss, "Loss has not been computed");
|
||||
ABORT_IF(!labels, "Labels have not been computed");
|
||||
|
||||
Expr lossSum = cast(loss, Type::float32); // accumulate in float32
|
||||
Expr labelsSum = cast(labels, Type::float32); // accumulate in float32
|
||||
Expr lossSum = cast(loss, Type::float32); // accumulate in float32
|
||||
Expr labelsSum = cast(labels, Type::float32); // accumulate in float32
|
||||
for(int i = 0; i < axes_.size(); ++i) {
|
||||
lossSum = sum(lossSum, axes_[i]);
|
||||
lossSum = sum(lossSum, axes_[i]);
|
||||
labelsSum = sum(labelsSum, axes_[i]);
|
||||
}
|
||||
|
||||
@ -301,7 +293,7 @@ protected:
|
||||
RationalLoss reduce(Expr loss) {
|
||||
ABORT_IF(!loss, "Loss has not been computed");
|
||||
|
||||
Expr lossSum = cast(loss, Type::float32);
|
||||
Expr lossSum = cast(loss, Type::float32);
|
||||
for(int i = 0; i < axes_.size(); ++i)
|
||||
lossSum = sum(lossSum, axes_[i]);
|
||||
|
||||
@ -311,17 +303,18 @@ protected:
|
||||
}
|
||||
|
||||
public:
|
||||
LabelwiseLoss(const std::vector<int>& axes)
|
||||
: axes_(axes) { }
|
||||
LabelwiseLoss(const std::vector<int>& axes) : axes_(axes) {}
|
||||
|
||||
virtual RationalLoss apply(Logits logits, const Words& labels,
|
||||
Expr mask = nullptr, Expr labelWeights = nullptr) {
|
||||
virtual RationalLoss apply(Logits logits,
|
||||
const Words& labels,
|
||||
Expr mask = nullptr,
|
||||
Expr labelWeights = nullptr) {
|
||||
Expr loss = compute(logits, labels, mask, labelWeights);
|
||||
|
||||
if(mask)
|
||||
return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting
|
||||
return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting
|
||||
else
|
||||
return reduce(loss); // we have no mask, assume all items are labels
|
||||
return reduce(loss); // we have no mask, assume all items are labels
|
||||
}
|
||||
};
|
||||
|
||||
@ -331,28 +324,34 @@ public:
|
||||
class CrossEntropyLoss : public LabelwiseLoss {
|
||||
public:
|
||||
CrossEntropyLoss(float labelSmoothing, float factorWeight)
|
||||
: CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1
|
||||
: CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {
|
||||
} // cross-entropy already reduces over axis -1
|
||||
|
||||
CrossEntropyLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
|
||||
: LabelwiseLoss(axes), // cross-entropy already reduces over axis -1
|
||||
labelSmoothing_(labelSmoothing), factorWeight_(factorWeight) {}
|
||||
: LabelwiseLoss(axes), // cross-entropy already reduces over axis -1
|
||||
labelSmoothing_(labelSmoothing),
|
||||
factorWeight_(factorWeight) {}
|
||||
|
||||
virtual ~CrossEntropyLoss() {}
|
||||
protected:
|
||||
float labelSmoothing_; // interpolation factor for label smoothing, see below
|
||||
float factorWeight_; // give extra weight to factors
|
||||
|
||||
virtual Expr compute(Logits logits, const Words& labels,
|
||||
Expr mask = nullptr, Expr labelWeights = nullptr) override {
|
||||
// logits may be factored; in that case, the getLoss() function computes one loss for each, and sums them up
|
||||
protected:
|
||||
float labelSmoothing_; // interpolation factor for label smoothing, see below
|
||||
float factorWeight_; // give extra weight to factors
|
||||
|
||||
virtual Expr compute(Logits logits,
|
||||
const Words& labels,
|
||||
Expr mask = nullptr,
|
||||
Expr labelWeights = nullptr) override {
|
||||
// logits may be factored; in that case, the getLoss() function computes one loss for each, and
|
||||
// sums them up
|
||||
int inFactor = false;
|
||||
auto ce = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
|
||||
logits = atleast_3d(logits); // we always assume a time and batch dimension exists.
|
||||
logits = atleast_3d(logits); // we always assume a time and batch dimension exists.
|
||||
// for bert training or classification the time dimension is lost.
|
||||
// Here safeguard against 2d classifier output, adds 1 on the left, non-op.
|
||||
|
||||
|
||||
Expr ce = cross_entropy(logits, indices, inFactor ? 0.f : labelSmoothing_, Type::float32);
|
||||
if (inFactor && factorWeight_ != 1.0f) {
|
||||
if(inFactor && factorWeight_ != 1.0f) {
|
||||
LOG_ONCE(info, "scaling factor losses with weight {}", factorWeight_);
|
||||
ce = ce * factorWeight_;
|
||||
}
|
||||
@ -365,8 +364,10 @@ protected:
|
||||
|
||||
if(labelWeights) {
|
||||
// We currently do not know how to use target factors and word-level label weights together
|
||||
bool wordlevel = labelWeights->shape()[-3] > 1; // Time-dimension is not trivially 1, hence we have word-level weights.
|
||||
ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, "CE loss with word-level label weights is not implemented for factors");
|
||||
bool wordlevel = labelWeights->shape()[-3]
|
||||
> 1; // Time-dimension is not trivially 1, hence we have word-level weights.
|
||||
ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1,
|
||||
"CE loss with word-level label weights is not implemented for factors");
|
||||
ce = ce * cast(labelWeights, Type::float32);
|
||||
}
|
||||
|
||||
@ -374,13 +375,12 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* @brief Unlikelihood loss across last axis, summed up over batch and time dimensions. This is an
|
||||
* implementation of sequence-level unlikelihood loss from https://arxiv.org/abs/1908.04319.
|
||||
 * We rely on word-level label weights where 1 is correct and 0 marks an error. If there are no zeros
 * for a sentence, it is going to be trained with normal CE loss; if there is at least one 0, it is going
 * to flip over to use SUL for that sentence to penalize the selected word.
 * We rely on word-level label weights where 1 is correct and 0 marks an error. If there are
 * no zeros for a sentence, it is going to be trained with normal CE loss; if there is at least one 0,
 * it is going to flip over to use SUL for that sentence to penalize the selected word.
|
||||
*
|
||||
* SUL is implemented as:
|
||||
* -log(gather(1 - softmax(logits), -1, indices))
|
||||
@ -390,35 +390,45 @@ protected:
|
||||
class SequenceUnlikelihoodLoss : public CrossEntropyLoss {
|
||||
public:
|
||||
SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight)
|
||||
: CrossEntropyLoss(labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1
|
||||
: CrossEntropyLoss(labelSmoothing, factorWeight) {
|
||||
} // cross-entropy already reduces over axis -1
|
||||
|
||||
SequenceUnlikelihoodLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
|
||||
: CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}
|
||||
: CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}
|
||||
|
||||
protected:
|
||||
virtual Expr compute(Logits logits, const Words& labels,
|
||||
Expr mask = nullptr, Expr labelWeights = nullptr) override {
|
||||
auto ce = CrossEntropyLoss::compute(logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE
|
||||
virtual Expr compute(Logits logits,
|
||||
const Words& labels,
|
||||
Expr mask = nullptr,
|
||||
Expr labelWeights = nullptr) override {
|
||||
auto ce = CrossEntropyLoss::compute(
|
||||
logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE
|
||||
if(!labelWeights)
|
||||
return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)?
|
||||
return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)?
|
||||
|
||||
// We currently do not know how to use target factors and word-level label weights together
|
||||
ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors");
|
||||
|
||||
ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete.
|
||||
// use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask again to eliminate padding (might be obsolete)
|
||||
ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by
|
||||
// default 1, which would make this obsolete.
|
||||
// use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask
|
||||
// again to eliminate padding (might be obsolete)
|
||||
auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32);
|
||||
|
||||
auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
|
||||
return cast(unlikelihood(logits, indices), Type::float32);
|
||||
});
|
||||
|
||||
// compute if want to use CE or UL. If there are no errors train with CE, otherwise train _only on_ the errors with UL. This is the "mixed" training
|
||||
// schedule from https://arxiv.org/abs/1908.04319. Providing labels with or without error scores we can easily switch between CE and UL.
|
||||
auto onlyCe = eq(sum(errorMask, /*axis=*/-3), 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present
|
||||
ceUl = errorMask * ceUl; // don't use for correct label or padding
|
||||
// compute if want to use CE or UL. If there are no errors train with CE, otherwise train _only
|
||||
// on_ the errors with UL. This is the "mixed" training schedule from
|
||||
// https://arxiv.org/abs/1908.04319. Providing labels with or without error scores we can easily
|
||||
// switch between CE and UL.
|
||||
auto onlyCe = eq(sum(errorMask, /*axis=*/-3),
|
||||
0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present
|
||||
ceUl = errorMask * ceUl; // don't use for correct label or padding
|
||||
|
||||
auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce or unlikelihood part are never simultaneously used as cost per batch entry
|
||||
auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce or unlikelihood part are never
|
||||
// simultaneously used as cost per batch entry
|
||||
|
||||
return cost;
|
||||
}
|
||||
@ -463,7 +473,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* @brief Factory for label-wise loss functions
|
||||
*/
|
||||
|
@ -1,120 +1,131 @@
|
||||
#include "output.h"
|
||||
#include "data/factored_vocab.h"
|
||||
#include "common/timer.h"
|
||||
#include "layers/lsh.h"
|
||||
#include "data/factored_vocab.h"
|
||||
#include "layers/loss.h"
|
||||
#include "layers/lsh.h"
|
||||
|
||||
namespace marian {
|
||||
namespace mlp {
|
||||
|
||||
/*private*/ void Output::lazyConstruct(int inputDim) {
|
||||
// We must construct lazily since we won't know tying nor input dim in constructor.
|
||||
if (Wt_)
|
||||
// We must construct lazily since we won't know tying nor input dim in constructor.
|
||||
if(Wt_)
|
||||
return;
|
||||
|
||||
// this option is only set in the decoder
|
||||
if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) {
|
||||
auto k = opt<std::vector<int>>("output-approx-knn")[0];
|
||||
// this option is only set in the decoder
|
||||
if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) {
|
||||
auto k = opt<std::vector<int>>("output-approx-knn")[0];
|
||||
auto nbits = opt<std::vector<int>>("output-approx-knn")[1];
|
||||
lsh_ = New<LSH>(k, nbits);
|
||||
}
|
||||
}
|
||||
|
||||
auto name = options_->get<std::string>("prefix");
|
||||
auto numOutputClasses = options_->get<int>("dim");
|
||||
auto name = options_->get<std::string>("prefix");
|
||||
auto numOutputClasses = options_->get<int>("dim");
|
||||
|
||||
factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
|
||||
if (factoredVocab_) {
|
||||
factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
|
||||
if(factoredVocab_) {
|
||||
numOutputClasses = (int)factoredVocab_->factorVocabSize();
|
||||
LOG_ONCE(info, "[embedding] Factored outputs enabled");
|
||||
}
|
||||
}
|
||||
|
||||
if(tiedParam_) {
|
||||
if(tiedParam_) {
|
||||
Wt_ = tiedParam_;
|
||||
} else {
|
||||
if (graph_->get(name + "_W")) { // support of legacy models that did not transpose
|
||||
Wt_ = graph_->param(name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false));
|
||||
isLegacyUntransposedW = true;
|
||||
}
|
||||
else // this is the regular case:
|
||||
Wt_ = graph_->param(name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true));
|
||||
}
|
||||
} else {
|
||||
if(graph_->get(name + "_W")) { // support of legacy models that did not transpose
|
||||
Wt_ = graph_->param(
|
||||
name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false));
|
||||
isLegacyUntransposedW = true;
|
||||
} else // this is the regular case:
|
||||
Wt_ = graph_->param(
|
||||
name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true));
|
||||
}
|
||||
|
||||
if(hasBias_)
|
||||
if(hasBias_)
|
||||
b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros());
|
||||
|
||||
/*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
|
||||
ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
|
||||
if (lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix
|
||||
/*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
|
||||
ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
|
||||
if(lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix
|
||||
#define HARDMAX_HACK
|
||||
#ifdef HARDMAX_HACK
|
||||
lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number
|
||||
lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number
|
||||
#endif
|
||||
auto range = factoredVocab_->getGroupRange(0);
|
||||
auto lemmaVocabDim = (int)(range.second - range.first);
|
||||
auto initFunc = inits::glorotUniform(/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length
|
||||
lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, initFunc); // [L x U] L=lemmaDimEmb; transposed for speed
|
||||
}
|
||||
auto initFunc = inits::glorotUniform(
|
||||
/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length
|
||||
lemmaEt_ = graph_->param(name + "_lemmaEt",
|
||||
{lemmaDimEmb, lemmaVocabDim},
|
||||
initFunc); // [L x U] L=lemmaDimEmb; transposed for speed
|
||||
}
|
||||
}
|
||||
|
||||
Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
lazyConstruct(input->shape()[-1]);
|
||||
lazyConstruct(input->shape()[-1]);
|
||||
|
||||
auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
if(b)
|
||||
return affine(x, W, b, transA, transB);
|
||||
return affine(x, W, b, transA, transB);
|
||||
else
|
||||
return dot(x, W, transA, transB);
|
||||
};
|
||||
return dot(x, W, transA, transB);
|
||||
};
|
||||
|
||||
auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
if(lsh_) {
|
||||
ABORT_IF( transA, "Transposed query not supported for LSH");
|
||||
ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
|
||||
return lsh_->apply(x, W, b); // knows how to deal with undefined bias
|
||||
ABORT_IF(transA, "Transposed query not supported for LSH");
|
||||
ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
|
||||
return lsh_->apply(x, W, b); // knows how to deal with undefined bias
|
||||
} else {
|
||||
return affineOrDot(x, W, b, transA, transB);
|
||||
return affineOrDot(x, W, b, transA, transB);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed
|
||||
cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
|
||||
if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one
|
||||
// batch, then clear()ed
|
||||
cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
|
||||
if(hasBias_)
|
||||
cachedShortb_ = index_select(b_ , -1, shortlist_->indices());
|
||||
}
|
||||
cachedShortb_ = index_select(b_, -1, shortlist_->indices());
|
||||
}
|
||||
|
||||
if (factoredVocab_) {
|
||||
if(factoredVocab_) {
|
||||
auto graph = input->graph();
|
||||
|
||||
// project each factor separately
|
||||
auto numGroups = factoredVocab_->getNumGroups();
|
||||
std::vector<Ptr<RationalLoss>> allLogits(numGroups, nullptr); // (note: null entries for absent factors)
|
||||
Expr input1 = input; // [B... x D]
|
||||
Expr Plemma = nullptr; // used for lemmaDimEmb=-1
|
||||
Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3
|
||||
for (size_t g = 0; g < numGroups; g++) {
|
||||
auto range = factoredVocab_->getGroupRange(g);
|
||||
if (g > 0 && range.first == range.second) // empty entry
|
||||
std::vector<Ptr<RationalLoss>> allLogits(numGroups,
|
||||
nullptr); // (note: null entries for absent factors)
|
||||
Expr input1 = input; // [B... x D]
|
||||
Expr Plemma = nullptr; // used for lemmaDimEmb=-1
|
||||
Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3
|
||||
for(size_t g = 0; g < numGroups; g++) {
|
||||
auto range = factoredVocab_->getGroupRange(g);
|
||||
if(g > 0 && range.first == range.second) // empty entry
|
||||
continue;
|
||||
ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g-1).second, "Factor groups must be consecutive (group {} vs predecessor)", g);
|
||||
// slice this group's section out of W_
|
||||
Expr factorWt, factorB;
|
||||
if (g == 0 && shortlist_) {
|
||||
ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g - 1).second,
|
||||
"Factor groups must be consecutive (group {} vs predecessor)",
|
||||
g);
|
||||
// slice this group's section out of W_
|
||||
Expr factorWt, factorB;
|
||||
if(g == 0 && shortlist_) {
|
||||
factorWt = cachedShortWt_;
|
||||
factorB = cachedShortb_;
|
||||
}
|
||||
else {
|
||||
factorWt = slice(Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second));
|
||||
factorB = cachedShortb_;
|
||||
} else {
|
||||
factorWt = slice(
|
||||
Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second));
|
||||
if(hasBias_)
|
||||
factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
|
||||
}
|
||||
/*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
|
||||
if ((lemmaDimEmb == -2 || lemmaDimEmb == -3) && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max)
|
||||
factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
|
||||
}
|
||||
/*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
|
||||
if((lemmaDimEmb == -2 || lemmaDimEmb == -3)
|
||||
&& g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max)
|
||||
LOG_ONCE(info, "[embedding] using lemma conditioning with gate");
|
||||
// this mimics one transformer layer
|
||||
// - attention over two inputs:
|
||||
// - e = current lemma. We use the original embedding vector; specifically, expectation over all lemmas.
|
||||
// - e = current lemma. We use the original embedding vector; specifically, expectation
|
||||
// over all lemmas.
|
||||
// - input = hidden state FF(h_enc+h_dec)
|
||||
// - dot-prod attention to allow both sides to influence (unlike our recurrent self-attention)
|
||||
// - dot-prod attention to allow both sides to influence (unlike our recurrent
|
||||
// self-attention)
|
||||
// - multi-head to allow for multiple conditions to be modeled
|
||||
// - add & norm, for gradient flow and scaling
|
||||
// - FF layer --this is expensive; it is per-factor
|
||||
@ -122,112 +133,161 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
int inputDim = input->shape()[-1];
|
||||
int heads = 8;
|
||||
auto name = options_->get<std::string>("prefix") + "_factor" + std::to_string(g);
|
||||
auto Wq = graph_->param(name + "_Wq", { inputDim, inputDim }, inits::glorotUniform());
|
||||
auto Wk = graph_->param(name + "_Wk", { inputDim, inputDim }, inits::glorotUniform());
|
||||
auto Wv = graph_->param(name + "_Wv", { inputDim, inputDim }, inits::glorotUniform());
|
||||
auto Wq = graph_->param(name + "_Wq", {inputDim, inputDim}, inits::glorotUniform());
|
||||
auto Wk = graph_->param(name + "_Wk", {inputDim, inputDim}, inits::glorotUniform());
|
||||
auto Wv = graph_->param(name + "_Wv", {inputDim, inputDim}, inits::glorotUniform());
|
||||
auto toMultiHead = [&](Expr x, int heads) {
|
||||
const auto& shape = x->shape();
|
||||
int inputDim = shape[-1];
|
||||
int otherDim = shape.elements() / inputDim;
|
||||
ABORT_IF(inputDim / heads * heads != inputDim, "inputDim ({}) must be multiple of number of heads ({})", inputDim, heads);
|
||||
return reshape(x, { otherDim, heads, 1, inputDim / heads });
|
||||
const auto& shape = x->shape();
|
||||
int inputDim = shape[-1];
|
||||
int otherDim = shape.elements() / inputDim;
|
||||
ABORT_IF(inputDim / heads * heads != inputDim,
|
||||
"inputDim ({}) must be multiple of number of heads ({})",
|
||||
inputDim,
|
||||
heads);
|
||||
return reshape(x, {otherDim, heads, 1, inputDim / heads});
|
||||
};
|
||||
input1 = inputLemma;
|
||||
auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query
|
||||
auto kdm = toMultiHead(dot(input1 - input, Wk), heads); // [B... x H x D/H] the two data vectors projected as keys. Use diff and sigmoid, instead of softmax.
|
||||
auto vem = toMultiHead(dot(input1, Wv), heads); // [B... x H x D/H] one of the two data vectors projected as values
|
||||
auto vim = toMultiHead(dot( input, Wv), heads); // [B... x H x D/H] the other
|
||||
auto zm = bdot(qm, kdm, false, true); // [B... x H x 1]
|
||||
auto sm = sigmoid(zm); // [B... x H x 1]
|
||||
auto rm = sm * (vem - vim) + vim; // [B... x H x D/H]
|
||||
auto r = reshape(rm, input->shape()); // [B... x D]
|
||||
auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query
|
||||
auto kdm = toMultiHead(dot(input1 - input, Wk),
|
||||
heads); // [B... x H x D/H] the two data vectors projected as keys.
|
||||
// Use diff and sigmoid, instead of softmax.
|
||||
auto vem = toMultiHead(
|
||||
dot(input1, Wv),
|
||||
heads); // [B... x H x D/H] one of the two data vectors projected as values
|
||||
auto vim = toMultiHead(dot(input, Wv), heads); // [B... x H x D/H] the other
|
||||
auto zm = bdot(qm, kdm, false, true); // [B... x H x 1]
|
||||
auto sm = sigmoid(zm); // [B... x H x 1]
|
||||
auto rm = sm * (vem - vim) + vim; // [B... x H x D/H]
|
||||
auto r = reshape(rm, input->shape()); // [B... x D]
|
||||
// add & norm
|
||||
input1 = r + input1;
|
||||
input1 = layerNorm(input1, name + "_att");
|
||||
// FF layer
|
||||
auto ffnDropProb = 0.1f; // @TODO: get as a parameter
|
||||
auto ffnDim = inputDim * 2; // @TODO: get as a parameter
|
||||
auto f = denseInline(input1, name + "_ffn", /*suffix=*/"1", ffnDim, inits::glorotUniform(), (ActivationFunction*)relu, ffnDropProb);
|
||||
f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
|
||||
auto ffnDropProb = 0.1f; // @TODO: get as a parameter
|
||||
auto ffnDim = inputDim * 2; // @TODO: get as a parameter
|
||||
auto f = denseInline(input1,
|
||||
name + "_ffn",
|
||||
/*suffix=*/"1",
|
||||
ffnDim,
|
||||
inits::glorotUniform(),
|
||||
(ActivationFunction*)relu,
|
||||
ffnDropProb);
|
||||
f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
|
||||
// add & norm
|
||||
input1 = f + input1;
|
||||
input1 = layerNorm(input1, name + "_ffn");
|
||||
}
|
||||
// @TODO: b_ should be a vector, not a matrix; but shortlists use cols() in, which requires a matrix
|
||||
Expr factorLogits;
|
||||
if(g == 0)
|
||||
factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
else
|
||||
factorLogits = affineOrDot(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
|
||||
// optionally add lemma-dependent bias
|
||||
if (Plemma) { // [B... x U0]
|
||||
}
|
||||
// @TODO: b_ should be a vector, not a matrix; but shortlists use cols() in, which requires a
|
||||
// matrix
|
||||
Expr factorLogits;
|
||||
if(g == 0)
|
||||
factorLogits = affineOrLSH(
|
||||
input1,
|
||||
factorWt,
|
||||
factorB,
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
else
|
||||
factorLogits = affineOrDot(
|
||||
input1,
|
||||
factorWt,
|
||||
factorB,
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
|
||||
// optionally add lemma-dependent bias
|
||||
if(Plemma) { // [B... x U0]
|
||||
int lemmaVocabDim = Plemma->shape()[-1];
|
||||
int factorVocabDim = factorLogits->shape()[-1];
|
||||
auto name = options_->get<std::string>("prefix");
|
||||
Expr lemmaBt = graph_->param(name + "_lemmaBt_" + std::to_string(g), {factorVocabDim, lemmaVocabDim}, inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma
|
||||
auto b = dot(Plemma, lemmaBt, false, true); // [B... x U]
|
||||
Expr lemmaBt
|
||||
= graph_->param(name + "_lemmaBt_" + std::to_string(g),
|
||||
{factorVocabDim, lemmaVocabDim},
|
||||
inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma
|
||||
auto b = dot(Plemma, lemmaBt, false, true); // [B... x U]
|
||||
factorLogits = factorLogits + b;
|
||||
}
|
||||
allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
|
||||
// optionally add a soft embedding of lemma back to create some lemma dependency
|
||||
// @TODO: if this works, move it into lazyConstruct
|
||||
if (lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure
|
||||
}
|
||||
allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
|
||||
// optionally add a soft embedding of lemma back to create some lemma dependency
|
||||
// @TODO: if this works, move it into lazyConstruct
|
||||
if(lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure
|
||||
LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version");
|
||||
// get expected lemma embedding vector
|
||||
auto factorLogSoftmax = logsoftmax(factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set
|
||||
auto factorLogSoftmax = logsoftmax(
|
||||
factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set
|
||||
auto factorSoftmax = exp(factorLogSoftmax);
|
||||
inputLemma = dot(factorSoftmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
|
||||
}
|
||||
else if (lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max
|
||||
inputLemma = dot(factorSoftmax,
|
||||
factorWt,
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
|
||||
} else if(lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max
|
||||
LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version");
|
||||
// get max-lemma embedding vector
|
||||
auto maxVal = max(factorLogits, -1); // [B... x U] note: with shortlist, this is not the full lemma set
|
||||
auto maxVal = max(factorLogits,
|
||||
-1); // [B... x U] note: with shortlist, this is not the full lemma set
|
||||
auto factorHardmax = eq(factorLogits, maxVal);
|
||||
inputLemma = dot(factorHardmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
|
||||
}
|
||||
else if (lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias
|
||||
inputLemma = dot(factorHardmax,
|
||||
factorWt,
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
|
||||
} else if(lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias
|
||||
ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented");
|
||||
LOG_ONCE(info, "[embedding] using lemma-dependent bias");
|
||||
auto factorLogSoftmax = logsoftmax(factorLogits); // (we do that again later, CSE will kick in)
|
||||
auto z = /*stopGradient*/(factorLogSoftmax);
|
||||
Plemma = exp(z); // [B... x U]
|
||||
}
|
||||
else if (lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix
|
||||
auto factorLogSoftmax
|
||||
= logsoftmax(factorLogits); // (we do that again later, CSE will kick in)
auto z = /*stopGradient*/ (factorLogSoftmax);
Plemma = exp(z); // [B... x U]
} else if(lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix
LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb);
// compute softmax. We compute logsoftmax() separately because this way, computation will be reused later via CSE
// compute softmax. We compute logsoftmax() separately because this way, computation will be
// reused later via CSE
auto factorLogSoftmax = logsoftmax(factorLogits);
auto factorSoftmax = exp(factorLogSoftmax);
#ifdef HARDMAX_HACK
bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation)
if (hardmax) {
lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
auto maxVal = max(factorSoftmax, -1);
factorSoftmax = eq(factorSoftmax, maxVal);
bool hardmax = (lemmaDimEmb & 1)
!= 0; // odd value triggers hardmax for now (for quick experimentation)
if(hardmax) {
lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
auto maxVal = max(factorSoftmax, -1);
factorSoftmax = eq(factorSoftmax, maxVal);
}
#endif
// re-embedding lookup, soft-indexed by softmax
if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L]
if(shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
auto e = dot(factorSoftmax,
cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_,
false,
true); // [B... x L]
// project it back to regular hidden dim
int inputDim = input1->shape()[-1];
auto name = options_->get<std::string>("prefix");
// note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also length 1
Expr lemmaWt = inputDim == lemmaDimEmb ? nullptr : graph_->param(name + "_lemmaWt", { inputDim, lemmaDimEmb }, inits::glorotUniform()); // [D x L] D=hidden-vector dimension
auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D]
// note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also
// length 1
Expr lemmaWt
= inputDim == lemmaDimEmb
? nullptr
: graph_->param(name + "_lemmaWt",
{inputDim, lemmaDimEmb},
inits::glorotUniform()); // [D x L] D=hidden-vector dimension
auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D]
// augment the original hidden vector with this additional information
input1 = input1 + f;
}
}
}
return Logits(std::move(allLogits), factoredVocab_);
} else if (shortlist_) {
return Logits(affineOrLSH(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true));
} else {
return Logits(affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
}
} else if(shortlist_) {
return Logits(affineOrLSH(input,
cachedShortWt_,
cachedShortb_,
false,
/*transB=*/isLegacyUntransposedW ? false : true));
} else {
return Logits(
affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
}
}

}
}
} // namespace mlp
} // namespace marian
@ -1,10 +1,10 @@
#pragma once

#include "marian.h"
#include "generic.h"
#include "logits.h"
#include "data/shortlist.h"
#include "generic.h"
#include "layers/factory.h"
#include "logits.h"
#include "marian.h"

namespace marian {
class LSH;
@ -14,42 +14,45 @@ namespace mlp {
class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList {
private:
// parameters held by this layer
Expr Wt_; // weight matrix is stored transposed for efficiency
Expr Wt_; // weight matrix is stored transposed for efficiency
Expr b_;
Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize]
bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form
Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize]
bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form
bool hasBias_{true};

Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
Ptr<FactoredVocab> factoredVocab_;


// optional parameters set/updated after construction
Expr tiedParam_;
Ptr<data::Shortlist> shortlist_;
Ptr<LSH> lsh_;

void lazyConstruct(int inputDim);

public:
Output(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: LayerBase(graph, options),
hasBias_{!options->get<bool>("output-omit-bias", false)} {
: LayerBase(graph, options), hasBias_{!options->get<bool>("output-omit-bias", false)} {
clear();
}

void tieTransposed(Expr tied) {
if (Wt_)
ABORT_IF(tiedParam_.get() != tied.get(), "Tied output projection cannot be changed once weights have been created");
if(Wt_)
ABORT_IF(tiedParam_.get() != tied.get(),
"Tied output projection cannot be changed once weights have been created");
else
tiedParam_ = tied;
}

void setShortlist(Ptr<data::Shortlist> shortlist) override final {
if (shortlist_)
ABORT_IF(shortlist.get() != shortlist_.get(), "Output shortlist cannot be changed except after clear()");
if(shortlist_)
ABORT_IF(shortlist.get() != shortlist_.get(),
"Output shortlist cannot be changed except after clear()");
else {
ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, "No shortlist but cached parameters??");
ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_,
"No shortlist but cached parameters??");
shortlist_ = shortlist;
}
// cachedShortWt_ and cachedShortb_ will be created lazily inside apply()
@ -60,7 +63,7 @@ public:
void clear() override final {
shortlist_ = nullptr;
cachedShortWt_ = nullptr;
cachedShortb_ = nullptr;
cachedShortb_ = nullptr;
cachedShortLemmaEt_ = nullptr;
}

@ -69,6 +72,4 @@ public:

} // namespace mlp

}


} // namespace marian
@ -4,13 +4,11 @@ namespace marian {
namespace models {

Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) {
// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax));
// @TODO: This is becoming more and more opaque ^^. Can we simplify this?
return state;
}


}
// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax));
// @TODO: This is becoming more and more opaque ^^. Can we simplify this?
return state;
}

} // namespace models
} // namespace marian
@ -4,8 +4,8 @@
#include "layers/guided_alignment.h"
#include "layers/loss.h"
#include "layers/weight.h"
#include "models/encoder_decoder.h"
#include "models/encoder_classifier.h"
#include "models/encoder_decoder.h"
#include "models/encoder_pooler.h"

namespace marian {
@ -22,10 +22,12 @@ namespace models {

class ICost {
public:
virtual Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model?
Ptr<data::Batch> batch,
bool clearGraph = true) = 0;
virtual Ptr<MultiRationalLoss> apply(
Ptr<IModel> model,
Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model?
Ptr<data::Batch> batch,
bool clearGraph = true)
= 0;
virtual ~ICost() {}
};

@ -45,10 +47,9 @@ public:
: options_(options), inference_(options->get<bool>("inference", false)) {
loss_ = newLoss(options_, inference_);

toBeWeighted_
= (options_->hasAndNotEmpty("data-weighting") && !inference_)
|| (options_->has("dynamic-weighting") && options_->get<bool>("dynamic-weighting")
&& !inference_);
toBeWeighted_ = (options_->hasAndNotEmpty("data-weighting") && !inference_)
|| (options_->has("dynamic-weighting")
&& options_->get<bool>("dynamic-weighting") && !inference_);
if(toBeWeighted_)
weighter_ = WeightingFactory(options_);
}
@ -56,9 +57,9 @@ public:
virtual ~EncoderDecoderCECost() {}

Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto encdec = std::static_pointer_cast<EncoderDecoder>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);

@ -72,17 +73,17 @@ public:
Ptr<MultiRationalLoss> multiLoss = newMultiLoss(options_);

// @TODO: adapt to multi-objective training with multiple decoders
auto partialLoss = loss_->apply(state->getLogProbs(),
state->getTargetWords(),
state->getTargetMask(),
weights);
auto partialLoss = loss_->apply(
state->getLogProbs(), state->getTargetWords(), state->getTargetMask(), weights);
multiLoss->push_back(partialLoss);

if(options_->get("guided-alignment", std::string("none")) != "none" && !inference_) {
auto attentionVectors = encdec->getDecoders()[0]->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1]
auto attentionVectors
= encdec->getDecoders()[0]
->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1]
ABORT_IF(attentionVectors.empty(), "Model does not seem to support alignments");

auto attention = concatenate(attentionVectors, /*axis =*/ -1);
auto attention = concatenate(attentionVectors, /*axis =*/-1);

auto alignmentLoss = guidedAlignmentCost(graph, corpusBatch, options_, attention);
multiLoss->push_back(alignmentLoss);
@ -109,10 +110,9 @@ public:
}

Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {

Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto enccls = std::static_pointer_cast<EncoderClassifier>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);

@ -141,21 +141,20 @@ protected:

public:
EncoderPoolerRankCost(Ptr<Options> options)
: options_(options),
inference_(options->get<bool>("inference", false)) {
auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {});
ABORT_IF(trainEmbedderRank.empty(), "EncoderPoolerRankCost expects train-embedder-rank to be set");
: options_(options), inference_(options->get<bool>("inference", false)) {
auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {});
ABORT_IF(trainEmbedderRank.empty(),
"EncoderPoolerRankCost expects train-embedder-rank to be set");

margin_ = std::stof(trainEmbedderRank[0]);
if(trainEmbedderRank.size() > 1)
normalizer_ = std::stof(trainEmbedderRank[1]);
margin_ = std::stof(trainEmbedderRank[0]);
if(trainEmbedderRank.size() > 1)
normalizer_ = std::stof(trainEmbedderRank[1]);
}

Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {

auto encpool = std::static_pointer_cast<EncoderPooler>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
std::vector<Expr> dotProducts = encpool->apply(graph, corpusBatch, clearGraph);
@ -167,28 +166,41 @@ public:
ABORT_IF(dotProducts.size() != 3, "Three dot products required for margin loss");

// multi-objective training
auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability
auto exponent = dotProducts[0] - maxDot - margin_; // substract maximum and margin from dot product
auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability
auto exponent
= dotProducts[0] - maxDot - margin_; // substract maximum and margin from dot product
auto dp = exp(exponent);

Expr dn1, dn2;
if(normalizer_ != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the magnitude of the sum of negative examples in the denominator.
dn1 = normalizer_ * mean(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example
dn2 = normalizer_ * mean(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example
if(normalizer_
!= 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the
// magnitude of the sum of negative examples in the denominator.
dn1 = normalizer_
* mean(exp(dotProducts[1] - maxDot),
-1); // dot product of anchor and first negative example
dn2 = normalizer_
* mean(exp(dotProducts[2] - maxDot),
-1); // dot product of positive examples and first negative example
} else {
dn1 = sum(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example
dn2 = sum(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example
dn1 = sum(exp(dotProducts[1] - maxDot),
-1); // dot product of anchor and first negative example
dn2 = sum(exp(dotProducts[2] - maxDot),
-1); // dot product of positive examples and first negative example
}

// We rewrite the loss so it looks more like a log-softmax, presumably more stable?
// Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m)
auto marginLoss1 = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples
auto marginLoss2 = log(dp + dn2) - exponent; // symmetric version of the above with positive example vs negative examples
auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2);

// Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp
// + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m)
auto marginLoss1
= log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples
auto marginLoss2
= log(dp + dn2)
- exponent; // symmetric version of the above with positive example vs negative examples
auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2);

RationalLoss loss(marginLoss, (float)dimBatch);
multiLoss->push_back(loss);


return multiLoss;
}
};
@ -199,8 +211,7 @@ protected:
Ptr<ICost> cost_;

public:
Trainer(Ptr<IModel> model, Ptr<ICost> cost)
: model_(model), cost_(cost) {}
Trainer(Ptr<IModel> model, Ptr<ICost> cost) : model_(model), cost_(cost) {}

virtual ~Trainer() {}

@ -219,8 +230,8 @@ public:
}

virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
Ptr<data::Batch> batch,
bool clearGraph = true) override {
return cost_->apply(model_, graph, batch, clearGraph);
};

@ -230,24 +241,25 @@ public:
class ILogProb {
public:
virtual Logits apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) = 0;
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true)
= 0;
};

// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for the ground truth?
// Beam search uses it for the former meaning, while 'marian score' and validation in the latter.
// This class is for the former use. The latter is done using Trainer.
// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for
// the ground truth?
// Beam search uses it for the former meaning, while 'marian score' and validation in the
// latter. This class is for the former use. The latter is done using Trainer.
class Scorer : public IModel {
protected:
Ptr<IModel> model_;
Ptr<ILogProb> logProb_;

public:
Scorer(Ptr<IModel> model, Ptr<ILogProb> cost)
: model_(model), logProb_(cost) {}
Scorer(Ptr<IModel> model, Ptr<ILogProb> cost) : model_(model), logProb_(cost) {}

virtual ~Scorer(){}
virtual ~Scorer() {}

Ptr<IModel> getModel() { return model_; }

@ -264,8 +276,8 @@ public:
}

virtual Logits build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
Ptr<data::Batch> batch,
bool clearGraph = true) override {
return logProb_->apply(model_, graph, batch, clearGraph);
};

@ -293,10 +305,10 @@ public:
virtual ~GumbelSoftmaxStep() {}
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
[](Expr logits){ // lemma gets gumbelled
return logsoftmax(logits + constant_like(logits, inits::gumbel()));
},
logsoftmax)); // factors don't
[](Expr logits) { // lemma gets gumbelled
return logsoftmax(logits + constant_like(logits, inits::gumbel()));
},
logsoftmax)); // factors don't
return state;
}
};
@ -311,8 +323,7 @@ protected:
Ptr<ILogProbStep> cost_;

public:
Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost)
: encdec_(encdec), cost_(cost) {}
Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {}

virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
@ -346,12 +357,13 @@ public:
return encdec_->startState(graph, batch);
}

virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
const Words& words, // [beamIndex * activeBatchSize + batchIndex]
const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) override {
virtual Ptr<DecoderState> step(
Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
const Words& words, // [beamIndex * activeBatchSize + batchIndex]
const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) override {
auto nextState = encdec_->step(graph, state, hypIndices, words, batchIndices, beamSize);
return cost_->apply(nextState);
}
@ -369,9 +381,7 @@ public:
encdec_->setShortlistGenerator(shortlistGenerator);
};

virtual Ptr<data::Shortlist> getShortlist() override {
return encdec_->getShortlist();
};
virtual Ptr<data::Shortlist> getShortlist() override { return encdec_->getShortlist(); };

virtual data::SoftAlignment getAlignment() override { return encdec_->getAlignment(); }
};
@ -1,7 +1,7 @@
#pragma once

#include "layers/logits.h" // @HACK: for factored embeddings only so far
#include "marian.h"
#include "layers/logits.h" // @HACK: for factored embeddings only so far
#include "rnn/types.h"

namespace marian {
@ -9,7 +9,7 @@ namespace marian {
class EncoderState {
private:
Expr context_;
Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask
Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask
Ptr<data::CorpusBatch> batch_;

public:
@ -19,31 +19,34 @@ public:
EncoderState() {}
virtual ~EncoderState() {}

virtual Expr getContext() const { return context_; }
virtual Expr getAttended() const { return context_; }
virtual Expr getMask() const { return mask_; } // source batch mask; may have additional positions suppressed
virtual Expr getContext() const { return context_; }
virtual Expr getAttended() const { return context_; }
virtual Expr getMask() const {
return mask_;
} // source batch mask; may have additional positions suppressed

virtual const Words& getSourceWords() {
return batch_->front()->data();
}
virtual const Words& getSourceWords() { return batch_->front()->data(); }

// Sub-select active batch entries from encoder context and context mask
Ptr<EncoderState> select(const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
// Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer gets transposed to the same dimension layout
return New<EncoderState>(index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
Ptr<EncoderState> select(
const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
// Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer
// gets transposed to the same dimension layout
return New<EncoderState>(
index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
}
};

class DecoderState {
protected:
rnn::States states_; // states of individual decoder layers
rnn::States states_; // states of individual decoder layers
Logits logProbs_;
std::vector<Ptr<EncoderState>> encStates_;
Ptr<data::CorpusBatch> batch_;

Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
Expr targetMask_;
Words targetWords_; // target labels
Words targetWords_; // target labels

// Keep track of current target token position during translation
size_t position_{0};
@ -57,26 +60,30 @@ public:
virtual ~DecoderState() {}

// @TODO: Do we need all these to be virtual?
virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const {
return encStates_;
}
virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const { return encStates_; }

virtual Logits getLogProbs() const { return logProbs_; }
virtual void setLogProbs(Logits logProbs) { logProbs_ = logProbs; }

// @TODO: should this be a constructor? Then derived classes can call this without the New<> in the loop
virtual Ptr<DecoderState> select(const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) const {

// @TODO: should this be a constructor? Then derived classes can call this without the New<> in
// the loop
virtual Ptr<DecoderState> select(
const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) const {
std::vector<Ptr<EncoderState>> newEncStates;
for(auto& es : encStates_)
// If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
// If the size of the batch dimension of the encoder state context changed, subselect the
// correct batch entries
newEncStates.push_back(
es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));

// hypindices matches batchIndices in terms of batch dimension, so we only need hypIndices
auto selectedState = New<DecoderState>(
states_.select(hypIndices, beamSize, /*isBatchMajor=*/false), logProbs_, newEncStates, batch_);
auto selectedState
= New<DecoderState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/false),
logProbs_,
newEncStates,
batch_);

// Set positon of new state based on the target token position of current state
selectedState->setPosition(getPosition());
@ -86,7 +93,9 @@ public:
virtual const rnn::States& getStates() const { return states_; }

virtual Expr getTargetHistoryEmbeddings() const { return targetHistoryEmbeddings_; };
virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) { targetHistoryEmbeddings_ = targetHistoryEmbeddings; }
virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) {
targetHistoryEmbeddings_ = targetHistoryEmbeddings;
}

virtual const Words& getTargetWords() const { return targetWords_; };
virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }
@ -94,9 +103,7 @@ public:
virtual Expr getTargetMask() const { return targetMask_; };
virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }

virtual const Words& getSourceWords() const {
return getEncoderStates()[0]->getSourceWords();
}
virtual const Words& getSourceWords() const { return getEncoderStates()[0]->getSourceWords(); }

Ptr<data::CorpusBatch> getBatch() const { return batch_; }

@ -111,7 +118,8 @@ public:

/**
* Classifier output based on DecoderState
* @TODO: should be unified with DecoderState or not be used at all as Classifier do not really have stateful output.
* @TODO: should be unified with DecoderState or not be used at all as Classifier do not really have
* stateful output.
*/
class ClassifierState {
private: