mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
Improve checks on transformer cache (#881)
* Fix caching in transformer attention * Move hash specialization * Swap comments to doxygen * Include string header
This commit is contained in:
parent
b64e258bda
commit
894a07ad5b
@ -30,6 +30,7 @@ set(MARIAN_SOURCES
|
||||
common/filesystem.cpp
|
||||
common/file_stream.cpp
|
||||
common/file_utils.cpp
|
||||
common/hash.cpp
|
||||
common/signal_handling.cpp
|
||||
common/types.cpp
|
||||
|
||||
|
12
src/common/hash.cpp
Normal file
12
src/common/hash.cpp
Normal file
@ -0,0 +1,12 @@
|
||||
#include <string>
|
||||
|
||||
#include "hash.h"
|
||||
#include "common/shape.h"
|
||||
|
||||
namespace std {
|
||||
size_t hash<pair<string, marian::Shape>>::operator()(pair<string, marian::Shape> const& k) const {
|
||||
size_t seed = hash<string>{}(k.first);
|
||||
marian::util::hash_combine(seed, k.second.hash());
|
||||
return seed;
|
||||
}
|
||||
} // namespace std
|
@ -7,16 +7,18 @@ namespace util {
|
||||
|
||||
template <class T> using hash = std::hash<T>;
|
||||
|
||||
// This combinator is based on boost::hash_combine, but uses
|
||||
// std::hash as the hash implementation. Used as a drop-in
|
||||
// replacement for boost::hash_combine.
|
||||
/**
|
||||
* Combine hash values.
|
||||
* This combinator is based on boost::hash_combine, but uses std::hash as the hash implementation.
|
||||
* Used as a drop-in replacement for boost::hash_combine.
|
||||
*/
|
||||
template <class T, class HashType = std::size_t>
|
||||
inline void hash_combine(HashType& seed, T const& v) {
|
||||
hash<T> hasher;
|
||||
seed ^= static_cast<HashType>(hasher(v)) + 0x9e3779b9 + (seed<<6) + (seed>>2);
|
||||
}
|
||||
|
||||
// Hash a whole chunk of memory, mostly used for diagnostics
|
||||
/** Hash a whole chunk of memory. */
|
||||
template <class T, class HashType = std::size_t>
|
||||
inline HashType hashMem(const T* beg, size_t len) {
|
||||
HashType seed = 0;
|
||||
@ -25,5 +27,17 @@ inline HashType hashMem(const T* beg, size_t len) {
|
||||
return seed;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace util
|
||||
|
||||
struct Shape; // Forward declaration
|
||||
} // namespace marian
|
||||
|
||||
namespace std {
|
||||
/**
|
||||
* std::hash specialization for the string-shape pair used as a cache key in transformer.h.
|
||||
*/
|
||||
template <>
|
||||
struct hash<pair<string, marian::Shape>> {
|
||||
size_t operator()(pair<string, marian::Shape> const& k) const;
|
||||
};
|
||||
}
|
||||
|
@ -5,6 +5,7 @@
|
||||
|
||||
#include "marian.h"
|
||||
|
||||
#include "common/hash.h"
|
||||
#include "layers/constructors.h"
|
||||
#include "models/decoder.h"
|
||||
#include "models/encoder.h"
|
||||
@ -28,7 +29,7 @@ class Transformer : public EncoderOrDecoderBase {
|
||||
|
||||
protected:
|
||||
using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
|
||||
std::unordered_map<std::string, Expr> cache_; // caching transformation of the encoder that should not be created again
|
||||
std::unordered_map<std::pair<std::string, Shape>, Expr> cache_; // caching transformation of the encoder that should not be created again
|
||||
mutable/*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings
|
||||
|
||||
bool depthScaling_{false}; // As recommended in the GPT-2 paper, down-scale layer weights by a factor of 1 / sqrt(depth);
|
||||
@ -289,26 +290,26 @@ public:
|
||||
// Caching transformation of the encoder that should not be created again.
|
||||
// @TODO: set this automatically by memoizing encoder context and
|
||||
// memoization propagation (short-term)
|
||||
if (cache // if caching
|
||||
&& cache_.count(prefix + "_keys") > 0 // and the keys expression has been seen
|
||||
&& cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
|
||||
kh = cache_[prefix + "_keys"]; // then return cached tensor
|
||||
}
|
||||
else {
|
||||
std::pair<std::unordered_map<std::pair<std::string, Shape>, Expr>::iterator, bool> cache_result;
|
||||
if (cache
|
||||
&& !((cache_result = cache_.insert(std::pair<std::pair<std::string, Shape>, Expr>({prefix + "_keys", keys->shape()}, kh))).second)
|
||||
) {
|
||||
kh = cache_result.first->second;
|
||||
} else {
|
||||
int dimKeys = keys->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
|
||||
auto Wk = graph_->param(prefix + "_Wk", {dimKeys, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
|
||||
auto bk = graph_->param(prefix + "_bk", {1, dimModel}, inits::zeros());
|
||||
|
||||
kh = affine(keys, Wk, bk); // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
|
||||
kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
|
||||
cache_[prefix + "_keys"] = kh;
|
||||
if (cache) cache_result.first->second = kh;
|
||||
}
|
||||
|
||||
Expr vh;
|
||||
if (cache
|
||||
&& cache_.count(prefix + "_values") > 0
|
||||
&& cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
|
||||
vh = cache_[prefix + "_values"];
|
||||
&& !((cache_result = cache_.insert(std::pair<std::pair<std::string, Shape>, Expr>({prefix + "_values", values->shape()}, vh))).second)
|
||||
) {
|
||||
vh = cache_result.first->second;
|
||||
} else {
|
||||
int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
|
||||
auto Wv = graph_->param(prefix + "_Wv", {dimValues, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
|
||||
@ -316,7 +317,7 @@ public:
|
||||
|
||||
vh = affine(values, Wv, bv); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
|
||||
vh = SplitHeads(vh, dimHeads);
|
||||
cache_[prefix + "_values"] = vh;
|
||||
if (cache) cache_result.first->second = vh;
|
||||
}
|
||||
|
||||
int dimBeam = q->shape()[-4];
|
||||
|
Loading…
Reference in New Issue
Block a user