mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Improve checks on transformer cache (#881)
* Fix caching in transformer attention * Move hash specialization * Swap comments to doxygen * Include string header
This commit is contained in:
parent
b64e258bda
commit
894a07ad5b
@ -30,6 +30,7 @@ set(MARIAN_SOURCES
|
|||||||
common/filesystem.cpp
|
common/filesystem.cpp
|
||||||
common/file_stream.cpp
|
common/file_stream.cpp
|
||||||
common/file_utils.cpp
|
common/file_utils.cpp
|
||||||
|
common/hash.cpp
|
||||||
common/signal_handling.cpp
|
common/signal_handling.cpp
|
||||||
common/types.cpp
|
common/types.cpp
|
||||||
|
|
||||||
|
12
src/common/hash.cpp
Normal file
12
src/common/hash.cpp
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "hash.h"
|
||||||
|
#include "common/shape.h"
|
||||||
|
|
||||||
|
namespace std {
|
||||||
|
size_t hash<pair<string, marian::Shape>>::operator()(pair<string, marian::Shape> const& k) const {
|
||||||
|
size_t seed = hash<string>{}(k.first);
|
||||||
|
marian::util::hash_combine(seed, k.second.hash());
|
||||||
|
return seed;
|
||||||
|
}
|
||||||
|
} // namespace std
|
@ -7,16 +7,18 @@ namespace util {
|
|||||||
|
|
||||||
template <class T> using hash = std::hash<T>;
|
template <class T> using hash = std::hash<T>;
|
||||||
|
|
||||||
// This combinator is based on boost::hash_combine, but uses
|
/**
|
||||||
// std::hash as the hash implementation. Used as a drop-in
|
* Combine hash values.
|
||||||
// replacement for boost::hash_combine.
|
* This combinator is based on boost::hash_combine, but uses std::hash as the hash implementation.
|
||||||
|
* Used as a drop-in replacement for boost::hash_combine.
|
||||||
|
*/
|
||||||
template <class T, class HashType = std::size_t>
|
template <class T, class HashType = std::size_t>
|
||||||
inline void hash_combine(HashType& seed, T const& v) {
|
inline void hash_combine(HashType& seed, T const& v) {
|
||||||
hash<T> hasher;
|
hash<T> hasher;
|
||||||
seed ^= static_cast<HashType>(hasher(v)) + 0x9e3779b9 + (seed<<6) + (seed>>2);
|
seed ^= static_cast<HashType>(hasher(v)) + 0x9e3779b9 + (seed<<6) + (seed>>2);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hash a whole chunk of memory, mostly used for diagnostics
|
/** Hash a whole chunk of memory. */
|
||||||
template <class T, class HashType = std::size_t>
|
template <class T, class HashType = std::size_t>
|
||||||
inline HashType hashMem(const T* beg, size_t len) {
|
inline HashType hashMem(const T* beg, size_t len) {
|
||||||
HashType seed = 0;
|
HashType seed = 0;
|
||||||
@ -25,5 +27,17 @@ inline HashType hashMem(const T* beg, size_t len) {
|
|||||||
return seed;
|
return seed;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
} // namespace util
|
||||||
|
|
||||||
|
struct Shape; // Forward declaration
|
||||||
|
} // namespace marian
|
||||||
|
|
||||||
|
namespace std {
|
||||||
|
/**
|
||||||
|
* std::hash specialization for the string-shape pair used as a cache key in transformer.h.
|
||||||
|
*/
|
||||||
|
template <>
|
||||||
|
struct hash<pair<string, marian::Shape>> {
|
||||||
|
size_t operator()(pair<string, marian::Shape> const& k) const;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
#include "marian.h"
|
#include "marian.h"
|
||||||
|
|
||||||
|
#include "common/hash.h"
|
||||||
#include "layers/constructors.h"
|
#include "layers/constructors.h"
|
||||||
#include "models/decoder.h"
|
#include "models/decoder.h"
|
||||||
#include "models/encoder.h"
|
#include "models/encoder.h"
|
||||||
@ -28,7 +29,7 @@ class Transformer : public EncoderOrDecoderBase {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
|
using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
|
||||||
std::unordered_map<std::string, Expr> cache_; // caching transformation of the encoder that should not be created again
|
std::unordered_map<std::pair<std::string, Shape>, Expr> cache_; // caching transformation of the encoder that should not be created again
|
||||||
mutable/*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings
|
mutable/*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings
|
||||||
|
|
||||||
bool depthScaling_{false}; // As recommended in the GPT-2 paper, down-scale layer weights by a factor of 1 / sqrt(depth);
|
bool depthScaling_{false}; // As recommended in the GPT-2 paper, down-scale layer weights by a factor of 1 / sqrt(depth);
|
||||||
@ -40,16 +41,16 @@ protected:
|
|||||||
std::vector<Expr> alignments_; // [max tgt len or 1][beam depth, max src length, batch size, 1]
|
std::vector<Expr> alignments_; // [max tgt len or 1][beam depth, max src length, batch size, 1]
|
||||||
|
|
||||||
// @TODO: make this go away
|
// @TODO: make this go away
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T opt(const char* const key) const { Ptr<Options> options = options_; return options->get<T>(key); }
|
T opt(const char* const key) const { Ptr<Options> options = options_; return options->get<T>(key); }
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T opt(const std::string& key) const { return opt<T>(key.c_str()); }
|
T opt(const std::string& key) const { return opt<T>(key.c_str()); }
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T opt(const char* const key, const T& def) const { Ptr<Options> options = options_; return options->get<T>(key, def); }
|
T opt(const char* const key, const T& def) const { Ptr<Options> options = options_; return options->get<T>(key, def); }
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T opt(const std::string& key, const T& def) const { opt<T>(key.c_str(), def); }
|
T opt(const std::string& key, const T& def) const { opt<T>(key.c_str(), def); }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -256,7 +257,7 @@ public:
|
|||||||
|
|
||||||
// take softmax along src sequence axis (-1)
|
// take softmax along src sequence axis (-1)
|
||||||
auto weights = softmax(z); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: max src length]
|
auto weights = softmax(z); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: max src length]
|
||||||
|
|
||||||
if(saveAttentionWeights)
|
if(saveAttentionWeights)
|
||||||
collectOneHead(weights, dimBeam);
|
collectOneHead(weights, dimBeam);
|
||||||
|
|
||||||
@ -289,26 +290,26 @@ public:
|
|||||||
// Caching transformation of the encoder that should not be created again.
|
// Caching transformation of the encoder that should not be created again.
|
||||||
// @TODO: set this automatically by memoizing encoder context and
|
// @TODO: set this automatically by memoizing encoder context and
|
||||||
// memoization propagation (short-term)
|
// memoization propagation (short-term)
|
||||||
if (cache // if caching
|
std::pair<std::unordered_map<std::pair<std::string, Shape>, Expr>::iterator, bool> cache_result;
|
||||||
&& cache_.count(prefix + "_keys") > 0 // and the keys expression has been seen
|
if (cache
|
||||||
&& cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
|
&& !((cache_result = cache_.insert(std::pair<std::pair<std::string, Shape>, Expr>({prefix + "_keys", keys->shape()}, kh))).second)
|
||||||
kh = cache_[prefix + "_keys"]; // then return cached tensor
|
) {
|
||||||
}
|
kh = cache_result.first->second;
|
||||||
else {
|
} else {
|
||||||
int dimKeys = keys->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
|
int dimKeys = keys->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
|
||||||
auto Wk = graph_->param(prefix + "_Wk", {dimKeys, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
|
auto Wk = graph_->param(prefix + "_Wk", {dimKeys, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
|
||||||
auto bk = graph_->param(prefix + "_bk", {1, dimModel}, inits::zeros());
|
auto bk = graph_->param(prefix + "_bk", {1, dimModel}, inits::zeros());
|
||||||
|
|
||||||
kh = affine(keys, Wk, bk); // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
|
kh = affine(keys, Wk, bk); // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
|
||||||
kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
|
kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
|
||||||
cache_[prefix + "_keys"] = kh;
|
if (cache) cache_result.first->second = kh;
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr vh;
|
Expr vh;
|
||||||
if (cache
|
if (cache
|
||||||
&& cache_.count(prefix + "_values") > 0
|
&& !((cache_result = cache_.insert(std::pair<std::pair<std::string, Shape>, Expr>({prefix + "_values", values->shape()}, vh))).second)
|
||||||
&& cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
|
) {
|
||||||
vh = cache_[prefix + "_values"];
|
vh = cache_result.first->second;
|
||||||
} else {
|
} else {
|
||||||
int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
|
int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
|
||||||
auto Wv = graph_->param(prefix + "_Wv", {dimValues, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
|
auto Wv = graph_->param(prefix + "_Wv", {dimValues, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
|
||||||
@ -316,7 +317,7 @@ public:
|
|||||||
|
|
||||||
vh = affine(values, Wv, bv); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
|
vh = affine(values, Wv, bv); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
|
||||||
vh = SplitHeads(vh, dimHeads);
|
vh = SplitHeads(vh, dimHeads);
|
||||||
cache_[prefix + "_values"] = vh;
|
if (cache) cache_result.first->second = vh;
|
||||||
}
|
}
|
||||||
|
|
||||||
int dimBeam = q->shape()[-4];
|
int dimBeam = q->shape()[-4];
|
||||||
@ -377,7 +378,7 @@ public:
|
|||||||
|
|
||||||
// multi-head self-attention over previous input
|
// multi-head self-attention over previous input
|
||||||
output = MultiHead(prefix, dimModel, dimHeads, output, keys, values, mask, cache, saveAttentionWeights);
|
output = MultiHead(prefix, dimModel, dimHeads, output, keys, values, mask, cache, saveAttentionWeights);
|
||||||
|
|
||||||
auto opsPost = opt<std::string>("transformer-postprocess");
|
auto opsPost = opt<std::string>("transformer-postprocess");
|
||||||
output = postProcess(prefix + "_Wo", opsPost, output, input, dropProb);
|
output = postProcess(prefix + "_Wo", opsPost, output, input, dropProb);
|
||||||
|
|
||||||
@ -558,7 +559,7 @@ public:
|
|||||||
auto embeddingLayer = getEmbeddingLayer(opt<bool>("ulr", false));
|
auto embeddingLayer = getEmbeddingLayer(opt<bool>("ulr", false));
|
||||||
std::tie(batchEmbeddings, batchMask) = embeddingLayer->apply((*batch)[batchIndex_]);
|
std::tie(batchEmbeddings, batchMask) = embeddingLayer->apply((*batch)[batchIndex_]);
|
||||||
batchEmbeddings = addSpecialEmbeddings(batchEmbeddings, /*start=*/0, batch);
|
batchEmbeddings = addSpecialEmbeddings(batchEmbeddings, /*start=*/0, batch);
|
||||||
|
|
||||||
// reorganize batch and timestep
|
// reorganize batch and timestep
|
||||||
batchEmbeddings = atleast_nd(batchEmbeddings, 4); // [beam depth=1, max length, batch size, vector dim]
|
batchEmbeddings = atleast_nd(batchEmbeddings, 4); // [beam depth=1, max length, batch size, vector dim]
|
||||||
batchMask = atleast_nd(batchMask, 4); // [beam depth=1, max length, batch size, vector dim=1]
|
batchMask = atleast_nd(batchMask, 4); // [beam depth=1, max length, batch size, vector dim=1]
|
||||||
@ -593,7 +594,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// this allows to run a final layernorm operation after going through the transformer layer stack.
|
// this allows to run a final layernorm operation after going through the transformer layer stack.
|
||||||
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
|
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
|
||||||
// it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
|
// it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
|
||||||
auto opsTop = opt<std::string>("transformer-postprocess-top", "");
|
auto opsTop = opt<std::string>("transformer-postprocess-top", "");
|
||||||
layer = postProcess(prefix_ + "_top", opsTop, layer, prevLayer, dropProb);
|
layer = postProcess(prefix_ + "_top", opsTop, layer, prevLayer, dropProb);
|
||||||
@ -622,14 +623,14 @@ public:
|
|||||||
int beamSize) const override {
|
int beamSize) const override {
|
||||||
|
|
||||||
// @TODO: code duplication with DecoderState only because of isBatchMajor=true, should rather be a contructor argument of DecoderState?
|
// @TODO: code duplication with DecoderState only because of isBatchMajor=true, should rather be a contructor argument of DecoderState?
|
||||||
|
|
||||||
std::vector<Ptr<EncoderState>> newEncStates;
|
std::vector<Ptr<EncoderState>> newEncStates;
|
||||||
for(auto& es : encStates_)
|
for(auto& es : encStates_)
|
||||||
// If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
|
// If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
|
||||||
newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
|
newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
|
||||||
|
|
||||||
// Create hypothesis-selected state based on current state and hyp indices
|
// Create hypothesis-selected state based on current state and hyp indices
|
||||||
auto selectedState = New<TransformerState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/true), logProbs_, newEncStates, batch_);
|
auto selectedState = New<TransformerState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/true), logProbs_, newEncStates, batch_);
|
||||||
|
|
||||||
// Set the same target token position as the current state
|
// Set the same target token position as the current state
|
||||||
// @TODO: This is the same as in base function.
|
// @TODO: This is the same as in base function.
|
||||||
@ -763,8 +764,8 @@ public:
|
|||||||
|
|
||||||
// This would happen if something goes wrong during batch pruning.
|
// This would happen if something goes wrong during batch pruning.
|
||||||
ABORT_IF(encoderContext->shape()[-3] != dimBatch,
|
ABORT_IF(encoderContext->shape()[-3] != dimBatch,
|
||||||
"Context and query batch dimension do not match {} != {}",
|
"Context and query batch dimension do not match {} != {}",
|
||||||
encoderContext->shape()[-3],
|
encoderContext->shape()[-3],
|
||||||
dimBatch);
|
dimBatch);
|
||||||
|
|
||||||
// LayerAttention expects mask in a different layout
|
// LayerAttention expects mask in a different layout
|
||||||
@ -871,7 +872,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// This allows to run a final layernorm operation after going through the transformer layer stack.
|
// This allows to run a final layernorm operation after going through the transformer layer stack.
|
||||||
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
|
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
|
||||||
// it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
|
// it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
|
||||||
auto opsTop = opt<std::string>("transformer-postprocess-top", "");
|
auto opsTop = opt<std::string>("transformer-postprocess-top", "");
|
||||||
query = postProcess(prefix_ + "_top", opsTop, query, prevQuery, dropProb);
|
query = postProcess(prefix_ + "_top", opsTop, query, prevQuery, dropProb);
|
||||||
@ -884,7 +885,7 @@ public:
|
|||||||
if(shortlist_)
|
if(shortlist_)
|
||||||
output_->setShortlist(shortlist_);
|
output_->setShortlist(shortlist_);
|
||||||
auto logits = output_->applyAsLogits(decoderContext); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vocab or shortlist dim]
|
auto logits = output_->applyAsLogits(decoderContext); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vocab or shortlist dim]
|
||||||
|
|
||||||
// return unormalized(!) probabilities
|
// return unormalized(!) probabilities
|
||||||
Ptr<DecoderState> nextState;
|
Ptr<DecoderState> nextState;
|
||||||
if (opt<std::string>("transformer-decoder-autoreg", "self-attention") == "rnn") {
|
if (opt<std::string>("transformer-decoder-autoreg", "self-attention") == "rnn") {
|
||||||
@ -909,9 +910,9 @@ public:
|
|||||||
output_->clear();
|
output_->clear();
|
||||||
cache_.clear();
|
cache_.clear();
|
||||||
alignments_.clear();
|
alignments_.clear();
|
||||||
perLayerRnn_.clear(); // this needs to be cleared between batches.
|
perLayerRnn_.clear(); // this needs to be cleared between batches.
|
||||||
// @TODO: figure out how to detect stale nodes i.e. nodes that are referenced,
|
// @TODO: figure out how to detect stale nodes i.e. nodes that are referenced,
|
||||||
// but where underlying memory has been deallocated by dropping all tensors
|
// but where underlying memory has been deallocated by dropping all tensors
|
||||||
// from a TensorAllocator object. This can happen during ExpressionGraph::clear()
|
// from a TensorAllocator object. This can happen during ExpressionGraph::clear()
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user