Merged PR 10827: Sequential unlikelihood training and fixed gather operation

This implements Sequential Unlikelihood Training from https://arxiv.org/abs/1908.04319
* implementation as expensive multi-op, special node in-progress.
* fixed gather operator to work in batched cases
This commit is contained in:
Martin Junczys-Dowmunt 2019-12-13 18:55:36 +00:00
parent 5be8558c35
commit e0500b20b8
13 changed files with 157 additions and 64 deletions

View File

@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
- Sequence-level unlikelihood training
- Allow file name templated valid-translation-output files
- Support for lexical shortlists in marian-server
- Support for 8-bit matrix multiplication with FBGEMM
@ -29,6 +30,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Gradient-checkpointing
### Fixed
- Gather-operation for all index sizes
- Fix word weighting with max length cropping
- Fixed compilation on CPUs without support for AVX
- FastOpt now reads "n" and "y" values as strings, not as boolean values

View File

@ -328,6 +328,8 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-mean");
cli.add<std::string>("--multi-loss-type",
"How to accumulate multi-objective losses: sum, scaled, mean", "sum");
cli.add<bool>("--unlikelihood-loss",
"Use word-level weights as indicators for sequence-level unlikelihood training");
cli.add<bool>("--overwrite",
"Do not create model checkpoints, only overwrite main model file with last checkpoint. "
"Reduces disk usage");

View File

@ -198,8 +198,6 @@ public:
subWidth = j + 1;
}
}
//if (subWidth < width_)
// LOG(info, "[data] sub-batch {} of {} wide batch has effective width of {}", pos / targetSize, width_, subWidth);
// create sub-batch
auto sb = New<SubBatch>(subSize, subWidth, vocab_);
@ -369,7 +367,7 @@ public:
// split each stream separately
for(auto batchStream : subBatches_) {
size_t i = 0; // index into split batch
for(auto splitSubBatch : batchStream->split(n, sizeLimit)) {
for(auto splitSubBatch : batchStream->split(n, sizeLimit)) { // splits a batch into pieces, can also change width
if(subs.size() <= i)
subs.resize(i + 1);
subs[i++].push_back(splitSubBatch); // this forms tuples across streams
@ -424,12 +422,11 @@ public:
if(!dataWeights_.empty()) {
size_t oldSize = size();
size_t width = 1;
// There are more weights than sentences, i.e. these are word weights.
if(dataWeights_.size() != oldSize)
width = subBatches_.back()->batchWidth();
for(auto split : splits) {
auto cb = std::static_pointer_cast<CorpusBatch>(split);
size_t width = 1; // One weight per sentence in case of sentence-level weights
if(dataWeights_.size() != oldSize) // if number of weights does not correspond to number of sentences we have word-level weights
width = cb->back()->batchWidth(); // splitting also affects width, hence we need to accomodate this here
std::vector<float> ws(width * split->size(), 1.0f);
// this needs to be split along the batch dimension

View File

@ -145,20 +145,20 @@ struct ConstantShape {
HOST_DEVICE_INLINE int elements() const { return (int)elements_; }
// The following functions iterate over shape dimensions and use resursive
// The following functions iterate over shape dimensions and use recursive
// templates. They unroll over a compile-time defined number of dimensions.
// Struct for recurrent template calls over shape dimensions,
// version for K > 0
template <const int K, const int D> struct I {
HOST_DEVICE_INLINE static int index(const Array<int, D>& dims,
const Array<int, D>& stride) {
const Array<int, D>& stride) {
return dims[K] * stride[K] + I<K-1, D>::index(dims, stride);
}
HOST_DEVICE_INLINE static int index(int si,
const Array<int, D>& shape,
const Array<int, D>& stride) {
const Array<int, D>& shape,
const Array<int, D>& stride) {
return (si % shape[K]) * stride[K] + I<K-1, D>::index(si / shape[K], shape, stride);
}
@ -175,19 +175,19 @@ struct ConstantShape {
// specialization for K == 0
template <const int D> struct I<0, D> {
HOST_DEVICE_INLINE static int index(const Array<int, D>& dims,
const Array<int, D>& stride) {
const Array<int, D>& stride) {
return dims[0] * stride[0];
}
HOST_DEVICE_INLINE static int index(int si,
const Array<int, D>& shape,
const Array<int, D>& stride) {
const Array<int, D>& shape,
const Array<int, D>& stride) {
return (si % shape[0]) * stride[0];
}
HOST_DEVICE_INLINE static void dims(int si,
Array<int, D>& dims,
const Array<int, D>& shape) {
Array<int, D>& dims,
const Array<int, D>& shape) {
dims[0] = si % shape[0];
}
};

View File

@ -317,6 +317,7 @@ Expr index_select(Expr a, int axis, Expr indices) {
indices = reshape(indices, shape); // move index to axis
return gather(a, axis, indices);
}
Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices) {
auto indexExpr = a->graph()->indices(indices);
return index_select(a, axis, indexExpr);
@ -612,8 +613,20 @@ Expr cast(Expr a, Type type) {
}
}
Expr cross_entropy(Expr a, Expr indices) {
return Expression<CrossEntropyNodeOp>(a, indices);
Expr cross_entropy(Expr logits, Expr indices) {
return Expression<CrossEntropyNodeOp>(logits, indices);
}
// Unlikelihood loss following https://arxiv.org/abs/1908.04319.
// Computes -log(1 - softmax(logits)) and picks, along the last axis, the entry
// selected by `indices`.
Expr unlikelihood(Expr logits, Expr indices) {
  const int timeSteps = logits->shape()[-3];
  const int batchSize = logits->shape()[-2];

  // @TODO: fix the outside of this function in decoder.h etc.
  // Bring the indices into the layout [1, timeSteps, batchSize, 1] before gathering.
  auto shapedIndices = reshape(indices, {1, timeSteps, batchSize, 1});

  // This is currently implemented with multiple ops; it might be worth adding a
  // dedicated operation, as was done for cross_entropy.
  auto complement = 1.f - softmax(logits);
  auto picked     = gather(complement, /*axis=*/-1, shapedIndices);
  return -log(picked);
}
Expr plus(const std::vector<Expr>& nodes) {

View File

@ -200,6 +200,8 @@ Expr logsoftmax(Expr a);
Expr cross_entropy(Expr a, Expr b);
Expr unlikelihood(Expr a, Expr b);
Expr scalar_product(Expr a, Expr b, int ax = 0);
Expr weighted_average(Expr in, Expr weights, int ax = 0);

View File

@ -569,8 +569,6 @@ struct RowsNodeOp : public NaryNodeOp {
// out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
// out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
// 'a' and 'indices' must have the same rank.
// @TODO: The current implementation does not support batched indices (third scenario above).
// I.e. all axes of 'indices' except 'axis' must have dimension 1.
struct GatherNodeOp : public NaryNodeOp {
GatherNodeOp(Expr a, int axis, Expr indices)
: NaryNodeOp({a, indices}, newShape(a, axis, indices), a->value_type()),
@ -599,10 +597,6 @@ struct GatherNodeOp : public NaryNodeOp {
if (i != axis) {
ABORT_IF(indices->shape()[i] != shape[i] && indices->shape()[i] != 1,
"Dimensions must match or broadcast for input ({}) and indices ({})", std::string(shape), std::string(indices->shape()));
#if 1 // presently, this implementation does not support batched indices
ABORT_IF(indices->shape()[i] != 1,
"Presently, gather() does not implement batched indices");
#endif
}
}
return shape;

View File

@ -7,9 +7,15 @@ Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference) {
float smoothing = inference ? 0.f : options->get<float>("label-smoothing");
float factorWeight = options->get<float>("factor-weight", 1.0f);
std::string costType = options->get<std::string>("cost-type", "ce-mean");
bool unlikelihood = options->get<bool>("unlikelihood-loss", false);
if(costType == "ce-rescore") { // returns per-batch-item scores (while ce-mean reduces over batch)
return New<RescorerLoss>();
} else if(unlikelihood) {
ABORT_IF(!options->hasAndNotEmpty("data-weighting")
&& options->get<std::string>("data-weighting-type") != "word",
"Unlikelihood loss training requires error annotation in form of per-target-label scores");
return New<SequenceUnlikelihoodLoss>(smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood less depending on values given for data-weighting
} else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. E.g. what about ce-sum?
return New<CrossEntropyLoss>(smoothing, factorWeight);
}

View File

@ -366,13 +366,68 @@ protected:
if(mask)
ce = ce * cast(mask, Type::float32);
if(labelWeights)
if(labelWeights) {
// We currently do not know how to use target factors and word-level label weights together
bool wordlevel = labelWeights->shape()[-3] > 1; // Time-dimension is not trivially 1, hence we have word-level weights.
ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, "CE loss with word-level label weights is not implemented for factors");
ce = ce * cast(labelWeights, Type::float32);
}
return ce;
}
};
/**
 * @brief Sequence-level unlikelihood loss (SUL), see https://arxiv.org/abs/1908.04319.
 *
 * Word-level label weights act as error annotation: a weight of 1 marks a correct
 * label, a weight of 0 marks an error. A sentence without any 0 weight is trained
 * with the regular cross-entropy loss. As soon as a sentence contains at least one
 * 0 weight, only its error positions contribute, and they do so through the
 * unlikelihood term, which penalizes the selected word.
 *
 * The unlikelihood term is computed as:
 *   -log(gather(1 - softmax(logits), -1, indices))
 *
 * Factors are currently not supported.
 */
class SequenceUnlikelihoodLoss : public CrossEntropyLoss {
public:
  // Both constructors simply forward to CrossEntropyLoss.
  SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight)
      : CrossEntropyLoss(labelSmoothing, factorWeight) {}  // cross-entropy already reduces over axis -1

  SequenceUnlikelihoodLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
      : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}

protected:
  virtual Expr compute(Logits logits, const Words& labels,
                       Expr mask = nullptr, Expr labelWeights = nullptr) override {
    // Plain cross-entropy part; the label weights are deliberately withheld from it.
    auto crossEntropy = CrossEntropyLoss::compute(logits, labels, mask, /*labelWeights=*/nullptr);

    // Without label weights (e.g. during validation) fall back to cross-entropy.
    // @TODO: maybe rather abort or LOG_ONCE(warn, ...)?
    if(!labelWeights) {
      return crossEntropy;
    }

    // We currently do not know how to combine target factors with word-level label weights.
    ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors");
    // @TODO: check this, it seems weights for padding are 1 by default, which would make this obsolete.
    ABORT_IF(!mask, "mask is required");

    // Label weights use 1 for GOOD and 0 for BAD. Invert them so that 1 marks an
    // error, then mask again to eliminate padding (might be obsolete).
    auto invertedWeights = 1.f - cast(labelWeights, Type::float32);
    auto errors          = invertedWeights * cast(mask, Type::float32);

    // Unlikelihood term for every position.
    auto unlikelihoodFn = [](Expr x, Expr idx) {
      return cast(unlikelihood(x, idx), Type::float32);
    };
    auto rawUnlikelihood = logits.applyLossFunction(labels, unlikelihoodFn);

    // Decide between CE and UL: with no errors train with CE, otherwise train
    // _only on_ the errors with UL. This is the "mixed" training schedule from
    // https://arxiv.org/abs/1908.04319; supplying labels with or without error
    // scores switches between the two.
    auto useCeOnly = eq(sum(errors, /*axis=*/-3), 0.f);  // [1, 1, dimBatch, 1] - equal 1 if no errors are present

    // Do not apply the unlikelihood term to correct labels or padding.
    auto maskedUnlikelihood = errors * rawUnlikelihood;

    // The CE and UL parts are never used simultaneously as cost for one batch entry.
    return useCeOnly * crossEntropy + (1.f - useCeOnly) * maskedUnlikelihood;
  }
};
/**
* @brief Cross entropy in rescorer used for computing sentences-level log probabilities
*/

View File

@ -15,8 +15,16 @@ Expr DataWeighting::getWeights(Ptr<ExpressionGraph> graph,
bool sentenceWeighting = weightingType_ == "sentence";
int dimBatch = (int)batch->size();
int dimWords = sentenceWeighting ? 1 : (int)batch->back()->batchWidth();
// This would abort anyway in fromVector(...), but has clearer error message
// here for this particular case
ABORT_IF(batch->getDataWeights().size() != dimWords * dimBatch,
"Number of sentence/word-level weights ({}) does not match tensor size ({})",
batch->getDataWeights().size(), dimWords * dimBatch);
auto weights = graph->constant({1, dimWords, dimBatch, 1},
inits::fromVector(batch->getDataWeights()));
return weights;
return weights; // [1, dimWords, dimBatch, 1] in case of word-level weights or
// [1, 1, dimBatch, 1] in case of sentence-level weights
}
} // namespace marian

View File

@ -686,20 +686,22 @@ void Select(Tensor out,
// @TODO: make this efficient
functional::Shape outShape = out->shape();
functional::Shape inShape = in->shape();
functional::Shape inShape = in->shape();
functional::Shape idxShape = indices->shape();
int length = outShape.elements();
functional::Array<int, functional::Shape::size()> dims;
int axisCPU = (int)(axis + functional::Shape::size() - out->shape().size());
if(axisCPU == 2) // specialization for axis==2, assuming N=4
if(axisCPU == 2 && outShape == idxShape) // specialization for axis==2 when there is no broadcasting, @TODO to be removed once we have a faster implementation below
return SelectAxis2(out, in, indices);
for(int index = 0; index < length; ++index) {
outShape.dims(index, dims);
dims[axisCPU] = (int)indices->data<IndexType>()[dims[axisCPU]];
int inIndex = inShape.index(dims);
out->data()[index] = in->data()[inIndex];
outShape.dims(index, dims); // compute dimension-based indices from global index;
int idxIndex = idxShape.bindex(dims); // return global index for indices based on dimension-specific indices from out, take broadcasting into account;
dims[axisCPU] = (int)indices->data<IndexType>()[idxIndex]; // substitute index of out-tensor with corresponding axis-local position from in-tensor;
int inIndex = inShape.index(dims); // compute global index from dimension-specific indices, no broadcasting as out and in match in all dimensions apart from axis
out->data()[index] = in->data()[inIndex]; // assign corresponding values.
}
}
@ -712,7 +714,8 @@ void Insert(Tensor out,
// @TODO: make this efficient
functional::Shape outShape = out->shape();
functional::Shape inShape = in->shape();
functional::Shape inShape = in->shape();
functional::Shape idxShape = indices->shape();
int length = inShape.elements();
functional::Array<int, functional::Shape::size()> dims;
@ -720,7 +723,8 @@ void Insert(Tensor out,
for(int index = 0; index < length; ++index) {
inShape.dims(index, dims);
dims[axisCPU] = (int)indices->data<IndexType>()[dims[axisCPU]];
int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
dims[axisCPU] = (int)indices->data<IndexType>()[idxIndex];
int outIndex = outShape.index(dims);
out->data()[outIndex] += in->data()[index];
}
@ -887,7 +891,7 @@ void CrossEntropyPick(Tensor out, Tensor in, Tensor labelIndices) {
// Groundtruth label index
IndexType i = labelIndices->data<IndexType>()[j];
// This appears to be safe i.e. that i >= 0 && i < cols is known
out->data()[j] = std::log(sum) - sp[i] + max;
out->data()[j] = std::log(sum) - sp[i] + max; // -log(p_i) = - logsoftmax(x_i - max) = - (x_i - max) - log(sum_j exp(x_j - max))
}
}
@ -920,7 +924,8 @@ void CrossEntropyPickBackward(Tensor out,
// cross-entropy
for(int i = 0; i < cols; ++i) {
float sub = (float)(i == (int)labelIndices->data<IndexType>()[j]); // delta, true if label index and column index match
so[i] += adj->data()[j] * (std::exp(sp[i] - max) / sum - sub);
auto softmax = std::exp(sp[i] - max) / sum;
so[i] += adj->data()[j] * (softmax - sub);
}
}
}

View File

@ -1132,7 +1132,8 @@ __global__ void gSelect(T* out,
const T* in,
const functional::Shape inShape,
int axis,
IndexType* d_indices) {
const IndexType* d_indices,
const functional::Shape idxShape) {
int length = outShape.elements();
functional::Array<int, functional::Shape::size()> dims;
@ -1140,7 +1141,8 @@ __global__ void gSelect(T* out,
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < length) {
outShape.dims(index, dims);
dims[axis] = d_indices[dims[axis]];
int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
dims[axis] = (int)d_indices[idxIndex];
int inIndex = inShape.index(dims);
out[index] = in[inIndex];
}
@ -1153,7 +1155,8 @@ __global__ void gInsert(T* out,
const T* in,
const functional::Shape inShape,
int axis,
IndexType* d_indices) {
const IndexType* d_indices,
const functional::Shape idxShape) {
int length = inShape.elements();
functional::Array<int, functional::Shape::size()> dims;
@ -1161,7 +1164,8 @@ __global__ void gInsert(T* out,
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < length) {
inShape.dims(index, dims);
dims[axis] = d_indices[dims[axis]];
int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
dims[axis] = (int)d_indices[idxIndex];
int outIndex = outShape.index(dims);
out[outIndex] += in[index]; // this is probably wrong, atomicAdd?
}
@ -1189,7 +1193,8 @@ void Select(Tensor out,
in->data<float>(),
in->shape(),
axisGPU,
indices->data<IndexType>());
indices->data<IndexType>(),
indices->shape());
#if COMPILE_FP16
} else if (out->type() == Type::float16) {
gSelect<<<blocks, threads>>>(out->data<half>(),
@ -1197,7 +1202,8 @@ void Select(Tensor out,
in->data<half>(),
in->shape(),
axisGPU,
indices->data<IndexType>());
indices->data<IndexType>(),
indices->shape());
#endif
} else {
ABORT("Select not implemented for type {}", out->type());
@ -1224,7 +1230,8 @@ void Insert(Tensor out,
in->data<float>(),
in->shape(),
axisGPU,
indices->data<IndexType>());
indices->data<IndexType>(),
indices->shape());
#if COMPILE_FP16
} else if (out->type() == Type::float16) {
gInsert<<<blocks, threads>>>(out->data<half>(),
@ -1232,7 +1239,8 @@ void Insert(Tensor out,
in->data<half>(),
in->shape(),
axisGPU,
indices->data<IndexType>());
indices->data<IndexType>(),
indices->shape());
#endif
} else {
ABORT("Insert not implemented for type {}", out->type());
@ -1522,11 +1530,11 @@ __global__ void gCrossEntropyPick(T* out,
__syncthreads();
// cross-entropy
auto sum = _sum[0];
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id == (int)pick[j]) {
out[j] = (T)functional::Ops<AccType>::log(_sum[0]) - sp[id] + max;
}
if(id == (int)pick[j])
out[j] = (T)functional::Ops<AccType>::log(sum) - sp[id] + max;
}
}
__syncthreads();
@ -1628,7 +1636,8 @@ __global__ void gCrossEntropyPickBackward(T* out,
int id = tid + threadIdx.x;
if(id < cols) {
AccType sub = (AccType)(id == (int)pick[j]);
so[id] += (AccType)adj[j] * (functional::Ops<AccType>::exp(sp[id] - max) / _sum[0] - sub);
auto softmax = functional::Ops<AccType>::exp(sp[id] - max) / _sum[0];
so[id] += (AccType)adj[j] * (softmax - sub);
}
}
}

View File

@ -670,16 +670,16 @@ void tests(DeviceType device, Type floatType = Type::float32) {
values.clear();
std::vector<T> vA({ 1, -2, 3,
-4, 5, -6,
7, -8, 9,
-10, 11, -12});
-4, 5, -6,
7, -8, 9,
-10, 11, -12});
std::vector<T> vC({ 1, -2, // C = np.array([1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12]).reshape((2, 3, 2))
3, -4,
5, -6,
3, -4,
5, -6,
7, -8,
9, -10,
11, -12 });
7, -8,
9, -10,
11, -12 });
std::vector<T> vB1({1, -2, 3});
std::vector<T> vB2({1, -4, 7, -10});
std::vector<T> vB3({-2, 5, -8, 11});
@ -687,7 +687,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
std::vector<T> vD1(vB4);
std::vector<T> vD2({5, -6, 11, -12});
std::vector<T> vD3({1, -2, 5, -6, 7, -8, 11, -12}); // C[:,(0,2),:]
//std::vector<float> vD4({5, -6, 3, -4, 7, -8, 11, -12}); // [C[0,(2,1),:],C[1,(0,2),:]]
std::vector<T> vD4({5, -6, 3, -4, 7, -8, 11, -12}); // [C[0,(2,1),:],C[1,(0,2),:]]
std::vector<T> vS1({7, -8, 9});
std::vector<T> vS2({-4, 5, -6, 7, -8, 9});
std::vector<T> vS3({7, -8, 9, -10, 11, -12});
@ -714,11 +714,11 @@ void tests(DeviceType device, Type floatType = Type::float32) {
CHECK(D1->type() == "sliceView");
CHECK(D2->type() == "gather");
// enable this once gather() supports batched indices:
//auto D4 = gather(C, 1, graph->constant({2, 2, 1}, // [C[0,(2,1),:],C[1,(0,2),:]]
// inits::fromVector(std::vector<IndexType>{
// 2, 1,
// 0, 2 }),
// Type::uint32));
auto D4 = gather(C, 1, graph->constant({2, 2, 1}, // [C[0,(2,1),:],C[1,(0,2),:]]
inits::fromVector(std::vector<IndexType>{
2, 1,
0, 2 }),
Type::uint32));
auto S1 = slice(A, 0, 2);
auto S2 = narrow(A, 0, 1, 2);
@ -736,7 +736,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
CHECK(D1->shape() == Shape({1, 3, 2})); D1->val()->get(values); CHECK( values == vD1 );
CHECK(D2->shape() == Shape({2, 1, 2})); D2->val()->get(values); CHECK( values == vD2 );
CHECK(D3->shape() == Shape({2, 2, 2})); D3->val()->get(values); CHECK( values == vD3 );
//CHECK(D4->shape() == Shape({2, 2, 2})); D4->val()->get(values); CHECK( values == vD4 );
CHECK(D4->shape() == Shape({2, 2, 2})); D4->val()->get(values); CHECK( values == vD4 );
CHECK(S1->shape() == Shape({1,3})); S1->val()->get(values); CHECK(values == vS1);
CHECK(S2->shape() == Shape({2,3})); S2->val()->get(values); CHECK(values == vS2);