Merged PR 11103: Clear cache for RNN object between batches

* Clears the cache for RNN objects in the transformer, otherwise stale tensors might be kept around.
* Adds missing `hash()` and `equal()` functions everywhere (the common pattern is sketched below).
* Fixes a bug from the deployment test.
Martin Junczys-Dowmunt 2020-01-11 20:29:43 +00:00
parent 0fab6ea850
commit af02867fb1
10 changed files with 180 additions and 25 deletions
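
All of the `hash()`/`equal()` additions in the hunks below follow the same pattern; here is a minimal sketch, where `ExampleNodeOp` and its `flag_` member are illustrative stand-ins for the real node classes and their semantic members (transA_, transB_, scalar_, eps_, final_, ...):

class ExampleNodeOp : public NaryNodeOp {   // illustrative only, not part of this diff
  bool flag_;                               // stands in for any member that affects the node's result
public:
  virtual size_t hash() override {
    size_t seed = NaryNodeOp::hash();       // start from the base-class hash
    util::hash_combine(seed, flag_);        // mix in every member that changes the node's semantics
    return seed;
  }
  virtual bool equal(Expr node) override {
    if(!NaryNodeOp::equal(node))            // base-class check first
      return false;
    auto cnode = std::dynamic_pointer_cast<ExampleNodeOp>(node);
    if(!cnode)                              // different node type
      return false;
    return flag_ == cnode->flag_;           // then compare exactly the members that were hashed
  }
};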

@ -1 +1 @@
Subproject commit 336740065d9c23e53e912a1befff18981d9d27ab
Subproject commit c19b7814d71febf1053bd93af6ac314b46204092

@ -87,8 +87,6 @@ if(CUDA_FOUND)
set(GENCODE "${GENCODE} -gencode=arch=compute_70,code=sm_70")
endif(COMPILE_CUDA_SM70)
message(${GENCODE})
# install nccl in ${CMAKE_BINARY_DIR}/local similar to /usr/local linux installation
ExternalProject_Add(nccl_install
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/nccl

@ -63,7 +63,6 @@ public:
// df/dB += alpha * dot(op(A).T, D)
// beta set to 1.0 in gemm, C = alpha * dot(op(A), op(B)) + beta * C
// to sum gradients from different graph parts
if(!transA_ && transB_)
return {NodeOp(Prod(child(0)->grad(),
adj_,
@ -130,6 +129,29 @@ public:
const std::string type() override { return "dot"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, transA_);
util::hash_combine(seed, transB_);
util::hash_combine(seed, scalar_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<DotNodeOp>(node);
if(!cnode)
return false;
if(transA_ != cnode->transA_)
return false;
if(transB_ != cnode->transB_)
return false;
if(scalar_ != cnode->scalar_)
return false;
return true;
}
const std::string color() override { return "orange"; }
};
@ -274,6 +296,30 @@ public:
}
const std::string type() override { return "affine"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, transA_);
util::hash_combine(seed, transB_);
util::hash_combine(seed, scalar_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<AffineNodeOp>(node);
if(!cnode)
return false;
if(transA_ != cnode->transA_)
return false;
if(transB_ != cnode->transB_)
return false;
if(scalar_ != cnode->scalar_)
return false;
return true;
}
};
class DotBatchedNodeOp : public NaryNodeOp {
@ -402,6 +448,29 @@ public:
const std::string type() override { return "bdot"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, transA_);
util::hash_combine(seed, transB_);
util::hash_combine(seed, scalar_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<DotBatchedNodeOp>(node);
if(!cnode)
return false;
if(transA_ != cnode->transA_)
return false;
if(transB_ != cnode->transB_)
return false;
if(scalar_ != cnode->scalar_)
return false;
return true;
}
const std::string color() override { return "orange"; }
};
@ -443,18 +512,42 @@ public:
}
NodeOps backwardOps() override {
return {nullptr, // can't backprop into the sparse matrix (the gradient is dense)
nullptr,
nullptr,
NodeOp(CSRProd(child(3)->grad(), // child(3) = D
graph()->allocator(),
child(0)->val(), child(1)->val(), child(2)->val(), // children(0..2) = A
adj_,
/*transS=*/!transS_, /*swapOperands=*/swapOperands_, /*beta=*/1))};
return { nullptr, // can't backprop into the sparse matrix (the gradient is dense)
nullptr,
nullptr,
NodeOp(CSRProd(child(3)->grad(), // child(3) = D
graph()->allocator(),
child(0)->val(), child(1)->val(), child(2)->val(), // children(0..2) = A
adj_,
/*transS=*/!transS_, /*swapOperands=*/swapOperands_, /*beta=*/1))};
}
const std::string type() override { return "csr_dot"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
for(auto s : shape())
util::hash_combine(seed, s);
util::hash_combine(seed, transS_);
util::hash_combine(seed, swapOperands_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<CSRDotNodeOp>(node);
if(!cnode)
return false;
if(transS_ != cnode->transS_)
return false;
if(shape() != cnode->shape())
return false;
if(swapOperands_ != cnode->swapOperands_)
return false;
return true;
}
const std::string color() override { return "orange"; }
};
@ -883,6 +976,29 @@ struct CmpNodeOp : public ElementBinaryNodeOp {
ABORT("Should not get here??");
}
virtual size_t hash() override {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, cmp_);
util::hash_combine(seed, not_);
hash_ = seed;
}
return hash_;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<CmpNodeOp>(node);
if(!cnode)
return false;
if(cmp_ != cnode->cmp_)
return false;
if(not_ != cnode->not_)
return false;
return true;
}
private:
int cmp_; // -1: less; 0: equal; 1: greater
bool not_; // invert result if true
@ -1015,6 +1131,23 @@ public:
const std::string type() override { return "layer_normalization"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, eps_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<LayerNormalizationOp>(node);
if(!cnode)
return false;
if(eps_ != cnode->eps_)
return false;
return true;
}
private:
float eps_;
};

@ -993,6 +993,8 @@ struct ShiftNodeOp : public UnaryNodeOp {
if(!cnode)
return false;
if(shift_ != cnode->shift_)
return false;
if(padValue_ != cnode->padValue_)
return false;
return true;
}

@ -47,13 +47,12 @@ public:
ABORT_IF(shortlist_, "How did a shortlist make it into training?");
const Words& data = subBatch->data();
Expr yData = graph_->indices(toWordIndexVector(data));
auto yShifted = shift(y, {1, 0, 0});
state->setTargetHistoryEmbeddings(yShifted);
state->setTargetMask(yMask);
const Words& data = subBatch->data();
state->setTargetWords(data);
}

@ -459,7 +459,7 @@ public:
int /*startPos*/) const {
float dropoutRnn = inference_ ? 0.f : opt<float>("dropout-rnn");
if(!perLayerRnn[prefix]) // lazily created and cache RNNs in the docoder to avoid costly recreation @TODO: turn this into class members
if(!perLayerRnn[prefix]) // lazily create and cache RNNs in the decoder to avoid costly recreation @TODO: turn this into class members
perLayerRnn[prefix] = rnn::rnn(
"type", opt<std::string>("dec-cell"),
"prefix", prefix,
@ -820,6 +820,10 @@ public:
output_->clear();
cache_.clear();
alignments_.clear();
perLayerRnn_.clear(); // this needs to be cleared between batches.
// @TODO: figure out how to detect stale nodes i.e. nodes that are referenced,
// but where underlying memory has been deallocated by dropping all tensors
// from a TensorAllocator object. This can happen during ExpressionGraph::clear()
}
};
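
The stale-node problem that the added perLayerRnn_.clear() and the @TODO above describe can be sketched as follows; cache, buildCell, and the Ptr<rnn::RNN> value type are illustrative assumptions, only graph->clear() and the per-layer RNN cache correspond to the real code:

std::unordered_map<std::string, Ptr<rnn::RNN>> cache;  // plays the role of perLayerRnn_
cache[prefix] = buildCell(graph, prefix);  // batch N: the cached cell holds Exprs built on this graph
// ... forward/backward for batch N ...
graph->clear();                            // drops all tensors from the graph's TensorAllocator
// Batch N+1: the cached cell still references nodes whose underlying memory is gone (stale),
// so the cache has to be dropped together with the graph and the cells recreated lazily:
cache.clear();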

@ -43,6 +43,23 @@ struct GRUFastNodeOp : public NaryNodeOp {
const std::string type() override { return "GRU-ops"; }
const std::string color() override { return "yellow"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, final_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<GRUFastNodeOp>(node);
if(!cnode)
return false;
if(final_ != cnode->final_)
return false;
return true;
}
};
Expr gruOps(const std::vector<Expr>& nodes, bool final) {

@ -1050,7 +1050,7 @@ void LayerNormalization(Tensor out_,
sqSum += ex * ex;
}
float sigma = std::sqrt(eps + sqSum / cols);
float sigma = std::sqrt(sqSum / cols + eps);
#pragma omp simd
for(int i = 0; i < cols; ++i) {
@ -1112,7 +1112,7 @@ void LayerNormalizationGrad(Tensor gradX_,
sum_sqr += ex * ex;
}
float sigma = std::sqrt(eps + sum_sqr / cols);
float sigma = std::sqrt(sum_sqr / cols + eps);
#pragma omp simd
for(size_t i = 0; i < cols; ++i) {
float grad_x = 0.f;
@ -1154,7 +1154,7 @@ void LayerNormalizationGrad(Tensor gradX_,
sum_sqr += ex * ex;
}
float sigma = std::sqrt(eps + sum_sqr / cols);
float sigma = std::sqrt(sum_sqr / cols + eps);
#pragma omp simd
for(size_t i = 0; i < cols; ++i) {
float grad_x = 0.f;

@ -1925,7 +1925,7 @@ __global__ void gLNormalization(T* out,
len = (len + 1) >> 1;
}
__syncthreads();
AccType sigma = functional::Ops<AccType>::sqrt(_sqSum[0] / N); // all AccType
AccType sigma = functional::Ops<AccType>::sqrt(_sqSum[0] / N + eps); // all AccType
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x) {
@ -1934,7 +1934,7 @@ __global__ void gLNormalization(T* out,
AccType gammav = (AccType)gamma[id];
AccType xv = (AccType)xRow[id];
AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
AccType lv = (xv - mean) / (sigma + eps);
AccType lv = (xv - mean) / sigma;
AccType y = gammav * lv + betav;
yRow[id] = (T)y;
}
@ -2022,7 +2022,7 @@ __global__ void gLayerNormalizationGrad(T* gradX,
AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
AccType gammav = (AccType)gamma[id];
AccType adjv = adjRow[id];
AccType lv = (yv - betav) / (gammav + eps); // go back to LN(x) from scaled and shifted version for accumulation
AccType lv = (yv - betav) / gammav; // go back to LN(x) from scaled and shifted version for accumulation
sum_x[threadIdx.x] += xv;
sum_adj_l[threadIdx.x] += adjv * lv;
@ -2064,7 +2064,7 @@ __global__ void gLayerNormalizationGrad(T* gradX,
len = (len + 1) >> 1;
}
__syncthreads();
AccType sigma = functional::Ops<AccType>::sqrt(sum_sqr[0] / N);
AccType sigma = functional::Ops<AccType>::sqrt(sum_sqr[0] / N + eps);
__syncthreads();
// Jacobian of layer norm
@ -2078,10 +2078,10 @@ __global__ void gLayerNormalizationGrad(T* gradX,
AccType xv = xRow[id];
AccType gammav = (AccType)gamma[id];
AccType adjv = adjRow[id];
AccType lv = (xv - mean) / (sigma + eps);
AccType lv = (xv - mean) / sigma;
AccType gradLv = N * adjv - lv * sum_adj_l[0] - sum_adj[0];
gradLv /= N * (sigma + eps); // eps has to be inside parentheses for correct gradient
gradLv /= N * sigma;
AccType gradXv = gammav * gradLv;
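
For reference, the epsilon placement the CPU and GPU kernels now share, written with the variable names used above (a sketch, not part of the diff):

sigma = sqrt(sqSum / N + eps);   // eps inside the square root
lv    = (xv - mean) / sigma;     // normalized value
y     = gammav * lv + betav;     // scale and shift

The previous GPU code computed sigma = sqrt(sqSum / N) and then divided by (sigma + eps); moving eps inside the root matches the CPU path and lets the gradient divide by N * sigma directly instead of N * (sigma + eps).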

@ -109,6 +109,8 @@ public:
auto cost = model->build(graph, batch);
fits = graph->fits();
LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits);
if(fits) {
stats->add(batch, multiplier);
start = current + 1;