hacky NaN handling

Marcin Junczys-Dowmunt 2019-01-27 23:03:46 -08:00
parent 83fbd248d0
commit abe9467471
6 changed files with 62 additions and 17 deletions

View File

@@ -113,6 +113,7 @@ for layer in range(config["enc-depth"]):
marianModel[marianPrefix + "_ffn_ffn_ln_scale"] = tfModel[tfPrefix + "/output/LayerNorm/gamma:0"]
marianModel[marianPrefix + "_ffn_ffn_ln_bias"] = tfModel[tfPrefix + "/output/LayerNorm/beta:0"]
# Training objectives
# Masked-LM output layer
marianModel["masked-lm_ff_logit_l1_W"] = tfModel["cls/predictions/transform/dense/kernel:0"]
marianModel["masked-lm_ff_logit_l1_b"] = tfModel["cls/predictions/transform/dense/bias:0"]

View File

@@ -14,6 +14,11 @@ namespace marian {
namespace cpu {
void IsNan(const Tensor in, Ptr<Allocator> allocator, bool& isNan, bool& isInf) {
ABORT("Not implemented");
}
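The CPU path is stubbed out in this commit. A minimal sketch of what it could look like, assuming float32 data, is below; it is not part of the commit, and the allocator parameter is simply unused on CPU:

// Hypothetical CPU fallback (sketch, not in this commit); needs <cmath>.
// Scans the buffer once with std::isnan / std::isinf.
void IsNan(const Tensor in, Ptr<Allocator> /*allocator*/, bool& isNan, bool& isInf) {
  isNan = false;
  isInf = false;
  const float* data = in->data<float>();
  for(size_t i = 0; i < in->size(); ++i) {
    if(std::isnan(data[i])) isNan = true;
    if(std::isinf(data[i])) isInf = true;
  }
}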
inline float stableSigmoid(float x) {
if(x >= 0) {
float z = expf(-x);
@@ -394,7 +399,7 @@ void CopyRows(Tensor out_,
for(size_t j = 0; j < rows; ++j) {
size_t dst = j;
// @TODO: consider moving type checking to this function
// instead of matchOrAbort above
size_t src = (size_t)indices->data<IndexType>()[j];
@@ -494,7 +499,7 @@ void Select(Tensor out,
functional::Array<int, functional::Shape::size()> dims;
int axisCPU = (int)(axis + functional::Shape::size() - out->shape().size());
for(int index = 0; index < length; ++index) {
outShape.dims(index, dims);
dims[axisCPU] = (int)indices->data<IndexType>()[dims[axisCPU]];

View File

@@ -55,6 +55,7 @@ void fill(Ptr<Backend> backend, T* begin, T* end, T value) {
CUDA_CHECK(cudaStreamSynchronize(0));
}
template void fill<bool>(Ptr<Backend>, bool*, bool*, bool);
template void fill<int8_t>(Ptr<Backend>, int8_t*, int8_t*, int8_t);
template void fill<int16_t>(Ptr<Backend>, int16_t*, int16_t*, int16_t);
template void fill<int32_t>(Ptr<Backend>, int32_t*, int32_t*, int32_t);
@@ -84,7 +85,7 @@ __global__ void gSwap(T* d_v1, T* d_v2, int size) {
if(index < size) {
T temp = d_v1[index];
d_v1[index] = d_v2[index];
d_v2[index] = temp;
}
}
@@ -93,7 +94,7 @@ void swap_ranges(Ptr<Backend> backend, T* begin, T* end, T* dest) {
int size = end - begin;
if (size == 0)
return;
CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
int threadsPerBlock = std::min(MAX_THREADS, size);
int blocks = (size / threadsPerBlock) + (size % threadsPerBlock != 0); // @TODO: (size+threadsPerBlock-1)/threadsPerBlock or CeilDiv(a,b)
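The TODO above asks for a ceiling-division helper; a minimal sketch:

// Integer ceiling division: smallest number of b-sized blocks covering a elements.
// E.g. CeilDiv(10, 3) == 4. Assumes a >= 0 and b > 0.
inline int CeilDiv(int a, int b) { return (a + b - 1) / b; }
// Then: int blocks = CeilDiv(size, threadsPerBlock);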

View File

@@ -27,14 +27,43 @@ __device__ inline float stableSigmoid(float x) {
}
}
bool IsNan(Tensor in) {
// cudaSetDevice(in->getDeviceId().no);
// thrust::device_ptr<float> begin = thrust::device_pointer_cast(in->data());
// thrust::device_ptr<float> end
// = thrust::device_pointer_cast(in->data() + in->size());
// return thrust::transform_reduce(
// begin, end, isnan_test(), 0, thrust::plus<bool>());
return false;
template <typename T>
__global__ void gIsNan(const T* in, int length, bool* isNan, bool* isInf) {
for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < length) {
if(isnan((float)in[index])) *isNan = true;
if(isinf((float)in[index])) *isInf = true;
//if(isinf2(in[index])) *isInf = true;
}
}
}
void IsNan(const Tensor in, Ptr<Allocator> allocator, bool& isNan, bool& isInf) {
cudaSetDevice(in->getDeviceId().no);
int length = in->size();
int threads = std::min(MAX_THREADS, length);
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
auto mem = allocator->alloc<bool>(2);
bool* dIsNan = &mem->data<bool>()[0];
bool* dIsInf = &mem->data<bool>()[1];
fill(in->getBackend(), dIsNan, dIsNan + 2, false);
if(in->type() == Type::float32) {
gIsNan<<<blocks, threads>>>(in->data<float>(), length, dIsNan, dIsInf);
} else {
ABORT("IsNan for type {} not implemented", in->type());
}
CudaCopy(dIsNan, dIsNan + 1, &isNan);
CudaCopy(dIsInf, dIsInf + 1, &isInf);
allocator->free(mem);
cudaStreamSynchronize(0);
}
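The same flag-scan pattern, distilled out of Marian's Tensor/Allocator types into standalone CUDA as a sketch. The concurrent unsynchronized writes to the flags are benign because every writing thread stores the same value, true:

#include <cstdio>
#include <cmath>

// flags[0]: saw NaN, flags[1]: saw Inf. Racy writes are harmless here:
// threads only ever store `true`, never clear the flags.
__global__ void scanNonFinite(const float* in, int length, bool* flags) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if(index < length) {
    if(isnan(in[index])) flags[0] = true;
    if(isinf(in[index])) flags[1] = true;
  }
}

int main() {
  const int n = 1024;
  float host[n] = {0};
  host[17] = NAN; // plant a NaN so the scan has something to find
  float* dIn;
  bool* dFlags;
  cudaMalloc(&dIn, n * sizeof(float));
  cudaMalloc(&dFlags, 2 * sizeof(bool));
  cudaMemcpy(dIn, host, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(dFlags, 0, 2 * sizeof(bool));
  scanNonFinite<<<(n + 255) / 256, 256>>>(dIn, n, dFlags);
  bool flags[2];
  cudaMemcpy(flags, dFlags, 2 * sizeof(bool), cudaMemcpyDeviceToHost);
  printf("nan=%d inf=%d\n", flags[0], flags[1]);
  cudaFree(dIn);
  cudaFree(dFlags);
  return 0;
}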
void ConcatCont(Tensor out, const std::vector<Tensor>& inputs, int axis) {
@@ -1176,9 +1205,9 @@ __global__ void gCrossEntropyPick(float* out,
}
// In each j-th row, take the corresponding j-th label index i from indices and compute:
// For each vocabulary item v, the only non-zero element in a row in the sum is the item
// that matches the label indexed by i (the picked element).
// C = sum_{v in V}(-logsoftmax(A)[v] * delta(v, i)) = -logsoftmax(A)[i]
void CrossEntropyPick(Tensor out, Tensor in, Tensor indices) {
matchOrAbort<IndexType>(indices->type());
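A plain-C++ reference for the semantics described in the comment above, as a sketch rather than Marian's kernel: for row j with picked label i = indices[j], the loss is -logsoftmax(A[j])[i], computed with the usual max-shift for numerical stability:

// Reference picked cross-entropy for one row of logits (sketch).
// logsoftmax(a)[i] = a[i] - max(a) - log(sum_v exp(a[v] - max(a)))
#include <cmath>
#include <vector>
#include <algorithm>

float crossEntropyPickRow(const std::vector<float>& logits, size_t pick) {
  float max = *std::max_element(logits.begin(), logits.end());
  float sum = 0.f;
  for(float a : logits)
    sum += std::exp(a - max);
  return -(logits[pick] - max - std::log(sum)); // -logsoftmax(logits)[pick]
}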

View File

@@ -34,6 +34,8 @@ void copy(Ptr<Backend> backend, const InIt beg, const InIt end, OutIt it) {
std::copy(beg, end, it);
}
DISPATCH4(IsNan, const Tensor, Ptr<Allocator>, bool&, bool&);
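The DISPATCH4 macro routes the call to the cpu:: or gpu:: implementation according to the tensor's backend. Roughly, it expands to something like the following sketch; this is the idea, not Marian's exact macro definition:

// Hypothetical expansion of DISPATCH4(IsNan, ...) -- a sketch of the idea:
static inline void IsNan(const Tensor in, Ptr<Allocator> allocator,
                         bool& isNan, bool& isInf) {
#ifdef CUDA_FOUND
  if(in->getBackend()->getDeviceId().type == DeviceType::gpu)
    gpu::IsNan(in, allocator, isNan, isInf);
  else
#endif
    cpu::IsNan(in, allocator, isNan, isInf);
}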
template <class Functor, class... Tensors>
void Element(Functor functor, marian::Tensor out, Tensors... tensors) {
#ifdef CUDA_FOUND

View File

@@ -356,12 +356,19 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
auto rationalLoss = builders_[localDeviceIndex]->build(graph, subBatch);
graph->forward();
StaticLoss tempLoss = *rationalLoss; // needed for overstuff
tempLoss.loss /= (float)overstuff; // @TODO: @fseide: scale only loss? should this scale labels too?
localDeviceLosses[localDeviceIndex] += tempLoss;
graph->backward(/*zero=*/false); // (gradients are reset before we get here)
bool hasNan = false, hasInf = false;
IsNan(graph->params()->grads(), graph->allocator(), hasNan, hasInf);
if(hasNan || hasInf) {
LOG(warn, "Seen NaN ({}) or Inf ({}) in gradient, zeroing gradient", hasNan, hasInf);
graph->params()->grads()->set(0.f);
}
}
});
// At this point, each device on each MPI process has a gradient aggregated over a subset of the sub-batches.
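The gradient guard above could be factored into a small helper; a hypothetical sketch using only calls that appear in this diff:

// Hypothetical helper (not in this commit): zero a graph's gradients if they
// contain any non-finite value; returns true if the gradient was discarded.
static bool zeroGradOnNonFinite(Ptr<ExpressionGraph> graph) {
  bool hasNan = false, hasInf = false;
  IsNan(graph->params()->grads(), graph->allocator(), hasNan, hasInf);
  if(hasNan || hasInf) {
    LOG(warn, "Seen NaN ({}) or Inf ({}) in gradient, zeroing gradient", hasNan, hasInf);
    graph->params()->grads()->set(0.f);
    return true;
  }
  return false;
}

Zeroing rather than aborting the step keeps all devices in lockstep: a device that saw a bad sub-batch simply contributes nothing to the aggregated gradient.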
@@ -521,7 +528,7 @@ void SyncGraphGroup::save(bool final) /*override*/ {
return comm_->gatherState(getShardFn);
},
isMainProcess());
barrier(); // (for better grouping of log messages)
}