diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index e4599c40..3718807a 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -30,6 +30,7 @@ set(MARIAN_SOURCES
   common/filesystem.cpp
   common/file_stream.cpp
   common/file_utils.cpp
+  common/hash.cpp
   common/signal_handling.cpp
   common/types.cpp
diff --git a/src/common/hash.cpp b/src/common/hash.cpp
new file mode 100644
index 00000000..57e5e914
--- /dev/null
+++ b/src/common/hash.cpp
@@ -0,0 +1,12 @@
+#include <string>
+
+#include "hash.h"
+#include "common/shape.h"
+
+namespace std {
+size_t hash<pair<string, marian::Shape>>::operator()(pair<string, marian::Shape> const& k) const {
+  size_t seed = hash<string>{}(k.first);
+  marian::util::hash_combine(seed, k.second.hash());
+  return seed;
+}
+} // namespace std
diff --git a/src/common/hash.h b/src/common/hash.h
index 7aca30de..37dab5e7 100644
--- a/src/common/hash.h
+++ b/src/common/hash.h
@@ -7,16 +7,18 @@ namespace util {
 
 template <class T> using hash = std::hash<T>;
 
-// This combinator is based on boost::hash_combine, but uses
-// std::hash as the hash implementation. Used as a drop-in
-// replacement for boost::hash_combine.
+/**
+ * Combine hash values.
+ * This combinator is based on boost::hash_combine, but uses std::hash as the hash implementation.
+ * Used as a drop-in replacement for boost::hash_combine.
+ */
 template <class HashType, class T>
 inline void hash_combine(HashType& seed, T const& v) {
   hash<T> hasher;
   seed ^= static_cast<HashType>(hasher(v)) + 0x9e3779b9 + (seed<<6) + (seed>>2);
 }
 
-// Hash a whole chunk of memory, mostly used for diagnostics
+/** Hash a whole chunk of memory. */
 template <class T, class HashType = std::size_t>
 inline HashType hashMem(const T* beg, size_t len) {
   HashType seed = 0;
@@ -25,5 +27,17 @@ inline HashType hashMem(const T* beg, size_t len) {
   return seed;
 }
 
-}
+} // namespace util
+
+struct Shape; // Forward declaration
+} // namespace marian
+
+namespace std {
+/**
+ * std::hash specialization for the string-shape pair used as a cache key in transformer.h.
+ */
+template <>
+struct hash<pair<string, marian::Shape>> {
+  size_t operator()(pair<string, marian::Shape> const& k) const;
+};
 }
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 7ec40dc5..af877600 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -5,6 +5,7 @@
 
 #include "marian.h"
 
+#include "common/hash.h"
 #include "layers/constructors.h"
 #include "models/decoder.h"
 #include "models/encoder.h"
@@ -28,7 +29,7 @@ class Transformer : public EncoderOrDecoderBase {
 protected:
   using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
-  std::unordered_map<std::string, Expr> cache_;  // caching transformation of the encoder that should not be created again
+  std::unordered_map<std::pair<std::string, Shape>, Expr> cache_;  // caching transformation of the encoder that should not be created again
   mutable/*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_;  // cached contributions to sinusoidal embeddings
 
   bool depthScaling_{false}; // As recommended in the GPT-2 paper, down-scale layer weights by a factor of 1 / sqrt(depth);
@@ -40,16 +41,16 @@ protected:
   std::vector<Expr> alignments_; // [max tgt len or 1][beam depth, max src length, batch size, 1]
 
   // @TODO: make this go away
-  template <typename T>
-  T opt(const char* const key) const { Ptr<Options> options = options_; return options->get<T>(key); }
+  template <typename T>
+  T opt(const char* const key) const { Ptr<Options> options = options_; return options->get<T>(key); }
 
-  template <typename T>
-  T opt(const std::string& key) const { return opt<T>(key.c_str()); }
+  template <typename T>
+  T opt(const std::string& key) const { return opt<T>(key.c_str()); }
 
-  template <typename T>
+  template <typename T>
   T opt(const char* const key, const T& def) const { Ptr<Options> options = options_; return options->get<T>(key, def); }
 
-  template <typename T>
+  template <typename T>
   T opt(const std::string& key, const T& def) const { opt<T>(key.c_str(), def); }
 
 public:
@@ -256,7 +257,7 @@
 
     // take softmax along src sequence axis (-1)
     auto weights = softmax(z); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: max src length]
-    
+
     if(saveAttentionWeights)
       collectOneHead(weights, dimBeam);
 
@@ -289,26 +290,26 @@ public:
     // Caching transformation of the encoder that should not be created again.
     // @TODO: set this automatically by memoizing encoder context and
    // memoization propagation (short-term)
-    if (cache  // if caching
-        && cache_.count(prefix + "_keys") > 0  // and the keys expression has been seen
-        && cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
-      kh = cache_[prefix + "_keys"];  // then return cached tensor
-    }
-    else {
+    std::pair<std::unordered_map<std::pair<std::string, Shape>, Expr>::iterator, bool> cache_result;
+    if (cache
+        && !((cache_result = cache_.insert(std::pair<std::pair<std::string, Shape>, Expr>({prefix + "_keys", keys->shape()}, kh))).second)
+       ) {
+      kh = cache_result.first->second;
+    } else {
       int dimKeys = keys->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
      auto Wk = graph_->param(prefix + "_Wk", {dimKeys, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
       auto bk = graph_->param(prefix + "_bk", {1, dimModel}, inits::zeros());
 
       kh = affine(keys, Wk, bk);     // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
       kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
-      cache_[prefix + "_keys"] = kh;
+      if (cache) cache_result.first->second = kh;
     }
 
     Expr vh;
-    if (cache
-        && cache_.count(prefix + "_values") > 0
-        && cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
-      vh = cache_[prefix + "_values"];
+    if (cache
+        && !((cache_result = cache_.insert(std::pair<std::pair<std::string, Shape>, Expr>({prefix + "_values", values->shape()}, vh))).second)
+       ) {
+      vh = cache_result.first->second;
     } else {
       int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
       auto Wv = graph_->param(prefix + "_Wv", {dimValues, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
@@ -316,7 +317,7 @@ public:
 
       vh = affine(values, Wv, bv); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
       vh = SplitHeads(vh, dimHeads);
-      cache_[prefix + "_values"] = vh;
+      if (cache) cache_result.first->second = vh;
     }
 
     int dimBeam = q->shape()[-4];
@@ -377,7 +378,7 @@ public:
 
     // multi-head self-attention over previous input
     output = MultiHead(prefix, dimModel, dimHeads, output, keys, values, mask, cache, saveAttentionWeights);
-    
+
     auto opsPost = opt<std::string>("transformer-postprocess");
     output = postProcess(prefix + "_Wo", opsPost, output, input, dropProb);
 
@@ -558,7 +559,7 @@ public:
     auto embeddingLayer = getEmbeddingLayer(opt<bool>("ulr", false));
     std::tie(batchEmbeddings, batchMask) = embeddingLayer->apply((*batch)[batchIndex_]);
     batchEmbeddings = addSpecialEmbeddings(batchEmbeddings, /*start=*/0, batch);
-    
+
     // reorganize batch and timestep
     batchEmbeddings = atleast_nd(batchEmbeddings, 4); // [beam depth=1, max length, batch size, vector dim]
     batchMask = atleast_nd(batchMask, 4); // [beam depth=1, max length, batch size, vector dim=1]
@@ -593,7 +594,7 @@ public:
     }
 
     // this allows to run a final layernorm operation after going through the transformer layer stack.
-    // By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da) 
+    // By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
     // it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
     auto opsTop = opt<std::string>("transformer-postprocess-top", "");
     layer = postProcess(prefix_ + "_top", opsTop, layer, prevLayer, dropProb);
@@ -622,14 +623,14 @@ public:
                                    int beamSize) const override {
 
     // @TODO: code duplication with DecoderState only because of isBatchMajor=true, should rather be a contructor argument of DecoderState?
-    
+
     std::vector<Ptr<EncoderState>> newEncStates;
-    for(auto& es : encStates_)
-      // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries 
+    for(auto& es : encStates_)
+      // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
       newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
 
     // Create hypothesis-selected state based on current state and hyp indices
-    auto selectedState = New<TransformerState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/true), logProbs_, newEncStates, batch_); 
+    auto selectedState = New<TransformerState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/true), logProbs_, newEncStates, batch_);
 
     // Set the same target token position as the current state
     // @TODO: This is the same as in base function.
@@ -763,8 +764,8 @@ public:
 
     // This would happen if something goes wrong during batch pruning.
     ABORT_IF(encoderContext->shape()[-3] != dimBatch,
-             "Context and query batch dimension do not match {} != {}", 
-             encoderContext->shape()[-3], 
+             "Context and query batch dimension do not match {} != {}",
+             encoderContext->shape()[-3],
              dimBatch);
 
     // LayerAttention expects mask in a different layout
@@ -871,7 +872,7 @@ public:
     }
 
     // This allows to run a final layernorm operation after going through the transformer layer stack.
-    // By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da) 
+    // By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
     // it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
     auto opsTop = opt<std::string>("transformer-postprocess-top", "");
     query = postProcess(prefix_ + "_top", opsTop, query, prevQuery, dropProb);
@@ -884,7 +885,7 @@ public:
     if(shortlist_)
       output_->setShortlist(shortlist_);
     auto logits = output_->applyAsLogits(decoderContext); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vocab or shortlist dim]
-    
+
     // return unormalized(!) probabilities
     Ptr<DecoderState> nextState;
     if (opt<std::string>("transformer-decoder-autoreg", "self-attention") == "rnn") {
@@ -909,9 +910,9 @@ public:
     output_->clear();
     cache_.clear();
     alignments_.clear();
-    perLayerRnn_.clear(); // this needs to be cleared between batches. 
-    // @TODO: figure out how to detect stale nodes i.e. nodes that are referenced, 
-    // but where underlying memory has been deallocated by dropping all tensors 
+    perLayerRnn_.clear(); // this needs to be cleared between batches.
+    // @TODO: figure out how to detect stale nodes i.e. nodes that are referenced,
+    // but where underlying memory has been deallocated by dropping all tensors
     // from a TensorAllocator object. This can happen during ExpressionGraph::clear()
   }
 };
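
Note on the caching idiom above (illustration only, not part of the patch): the cache key changes from the bare prefix string to a (name, shape) pair, so a cached projection is reused only when the full shape matches, and the former count()/operator[] lookup collapses into a single insert() whose returned bool distinguishes a hit from a miss. The following standalone sketch reproduces that idiom outside of Marian; ToyShape, Expr and getOrBuild are hypothetical stand-ins for marian::Shape, Marian's Expr and the key/value projection built in MultiHead.

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for marian::Shape: a list of dimensions with a hash() method.
struct ToyShape {
  std::vector<int> dims;
  bool operator==(const ToyShape& other) const { return dims == other.dims; }
  size_t hash() const {
    size_t seed = 0;
    for(int d : dims)  // same boost-style combinator as marian::util::hash_combine
      seed ^= std::hash<int>{}(d) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
  }
};

// Hypothetical stand-in for Marian's Expr; here just a string describing the node.
using Expr = std::string;

// std::hash specialization for the (name, shape) cache key, mirroring the one added in common/hash.h.
namespace std {
template <>
struct hash<pair<string, ToyShape>> {
  size_t operator()(pair<string, ToyShape> const& k) const {
    size_t seed = hash<string>{}(k.first);
    seed ^= k.second.hash() + 0x9e3779b9 + (seed << 6) + (seed >> 2);  // hash_combine
    return seed;
  }
};
}  // namespace std

int main() {
  std::unordered_map<std::pair<std::string, ToyShape>, Expr> cache;

  auto getOrBuild = [&](const std::string& name, const ToyShape& shape) -> Expr {
    // Insert a placeholder: if the key (name, shape) is already present, insert() leaves the map
    // unchanged and returns an iterator to the existing entry, like cache_result in MultiHead.
    auto result = cache.insert({{name, shape}, Expr{}});
    if(!result.second)
      return result.first->second;               // cache hit: reuse the stored expression
    result.first->second = "proj(" + name + ")"; // cache miss: build and store the expression
    return result.first->second;
  };

  std::cout << getOrBuild("enc_keys", {{8, 512}}) << "\n";  // builds
  std::cout << getOrBuild("enc_keys", {{8, 512}}) << "\n";  // same name and shape: reuses cache
  std::cout << getOrBuild("enc_keys", {{4, 512}}) << "\n";  // different shape: builds again
  return 0;
}

The single insert() also avoids the double hashing of the old count()/operator[] pattern; in the patch the final store is guarded with if (cache) because cache_result is only populated when caching is enabled.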