Merged PR 20729: Add top-k sampling

This adds top-k sampling to Marian and extends the --output-sampling option to take arguments.
Marcin Junczys-Dowmunt 2021-11-22 03:32:54 +00:00
parent 1404201926
commit c85d060848
13 changed files with 226 additions and 53 deletions
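
For orientation, hedged example invocations of the extended option (model and vocabulary paths here are hypothetical):

    marian-decoder -m model.npz -v vocab.spm vocab.spm --output-sampling          # implicit 'full'
    marian-decoder -m model.npz -v vocab.spm vocab.spm --output-sampling topk 100 # top-100 sampling

The accepted values ('full', 'topk num') follow from the option help text and the option handling in the hunks below.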

@@ -1 +1 @@
Subproject commit 7d612ca5e4b27a76f92584dad76d240e34f216d0
Subproject commit 0aa7b6b7632732d1f22f3d8169d3262a7e6b1e9d


@@ -695,9 +695,10 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
"Use softmax shortlist: path first best prune");
cli.add<std::vector<float>>("--weights",
"Scorer weights");
cli.add<bool>("--output-sampling",
"Noise output layer with gumbel noise",
false);
cli.add<std::vector<std::string>>("--output-sampling",
"Noise output layer with gumbel noise. Implicit default is 'full' for sampling from full distribution. "
" Also accepts 'topk num' (e.g. topk 100) for top-100 sampling.")
->implicit_val("full");
cli.add<std::vector<int>>("--output-approx-knn",
"Use approximate knn search in output layer (currently only in transformer)")
->implicit_val("100 1024");


@@ -357,6 +357,13 @@ Expr gather(Expr a, int axis, Expr indices) {
return Expression<GatherNodeOp>(a, axis, indices);
}
// scatter() -- scatter arbitrary elements along an axis; batched or non-batched
// This is the reverse operation to gather.
Expr scatter(Expr a, int axis, Expr indices, Expr source) {
return Expression<ScatterNodeOp>(a, axis, indices, source);
}
// index_select() -- gather arbitrary elements along an axis from an unbatched
// input 'a'. Indices are specified as a 1D vector.
// This is used e.g. for embedding lookup.
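
To pin down the semantics: along the indexed axis, out[..., idx[..., j]] = source[..., j], and every position not named by an index keeps its value from a. A minimal standalone reference sketch of that behavior (illustrative code, not the Marian implementation; assumes a row-major 2D tensor and axis = -1, the only axis the new node currently supports):

    #include <cstddef>
    #include <vector>

    // Reference semantics of scatter on a [rows x cols] tensor along the last axis:
    // start from a copy of 'a', then overwrite the k indexed positions per row.
    std::vector<float> scatterLastAxis(const std::vector<float>& a,
                                       std::size_t rows, std::size_t cols,
                                       const std::vector<std::size_t>& idx, // [rows x k]
                                       const std::vector<float>& src,       // [rows x k]
                                       std::size_t k) {
      std::vector<float> out = a; // unindexed elements remain unchanged
      for(std::size_t r = 0; r < rows; ++r)
        for(std::size_t j = 0; j < k; ++j)
          out[r * cols + idx[r * k + j]] = src[r * k + j]; // out[r][idx[r][j]] = src[r][j]
      return out;
    }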


@@ -707,10 +707,23 @@ Expr stopGradient(Expr a);
* @param indices The indices to be gathered
* @returns Gathered expression with the same shape as @p indices
* @note @p a and @p indices must have the same rank
* @note The non-target axes of @p a and @p indicies must have the same size, or be broadcastable.
* @note The non-target axes of @p a and @p indices must have the same size, or be broadcastable.
*/
Expr gather(Expr a, int axis, Expr indices);
/**
* Scatter elements from source along an axis into a. Unindexed elements from a remain unchanged.
* This is the reverse operation to gather.
* @param a The input expression
* @param axis The axis along which to index
* @param indices The indices to be scattered
* @param source Expression with values to scatter.
* @returns Scattered expression with the same shape as @p a now containing values from @p source in positions @p indices
* @note @p source and @p indices must have the same rank
* @note In this version @p source and @p indices must have the same shape
*/
Expr scatter(Expr a, int axis, Expr indices, Expr source);
#if 0
// Reverse operation to gather: a is the expression into which values from b are inserted at positions indices along axis,
// with broadcasting
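
A hedged usage sketch of the declared operator, mirroring how the sampling step below combines it with topk() (assumes an existing Expr logits; k and the fill value are illustrative):

    // keep only the top-k logits; every other position receives an invalid score
    Expr val, idx;
    std::tie(val, idx) = topk(logits, /*k=*/8, /*axis=*/-1, /*descending=*/true);
    Expr filled = scatter(constant_like(logits, inits::fromValue(-1e9f)), -1, idx, val);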


@@ -1033,12 +1033,14 @@ struct GatherNodeOp : public NaryNodeOp {
NodeOps forwardOps() override {
return {NodeOp(
// @TODO: rename to gather
Select(val_, child(0)->val(), child(1)->val(), axis_))};
}
NodeOps backwardOps() override {
return {NodeOp(
Insert(child(0)->grad(), adj_, child(1)->val(), axis_))};
// @TODO: rename to scatter
Insert</*add=*/true>(child(0)->grad(), adj_, child(1)->val(), axis_))};
}
Shape newShape(Expr a, int axis, Expr indices) {
@@ -1046,7 +1048,6 @@ struct GatherNodeOp : public NaryNodeOp {
axis = shape.axis(axis);
auto rank = shape.size();
ABORT_IF(rank != indices->shape().size(), "Mismatching ranks for input ({}) and indices ({})", std::string(shape), std::string(indices->shape()));
axis = a->shape().axis(axis);
shape.set(axis, indices->shape()[axis]);
for (size_t i = 0; i < rank; ++i) {
if (i != axis) {
@@ -1086,6 +1087,62 @@ private:
int axis_;
};
struct ScatterNodeOp : public NaryNodeOp {
ScatterNodeOp(Expr a, int axis, Expr indices, Expr source)
: NaryNodeOp({a, indices, source}, newShape(a, axis, indices, source), a->value_type()),
axis_(a->shape().axis(axis)) {
matchOrAbort<IndexType>(indices->value_type());
}
NodeOps forwardOps() override {
return {NodeOp(
CopyCast(val_, child(0)->val()); // @TODO: use normal copy
Insert</*add=*/false>(val_, child(2)->val(), child(1)->val(), axis_)
)};
}
NodeOps backwardOps() override {
ABORT("backward for ScatterNodeOp not yet implemented");
}
Shape newShape(Expr a, int axis, Expr indices, Expr source) {
ABORT_IF(axis != -1, "Only the last axis is currently supported");
ABORT_IF(indices->shape() != source->shape(), "Shapes must match");
Shape shape = a->shape();
// @TODO: do proper checking
return shape;
}
const std::string type() override { return "scatter"; }
const std::string color() override { return "orange"; }
virtual size_t hash() override {
if(!hash_) {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, axis_);
hash_ = seed;
}
return hash_;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<ScatterNodeOp>(node);
if(!cnode)
return false;
if(axis_ != cnode->axis_)
return false;
return true;
}
private:
friend class SerializationHelpers;
int axis_;
};
struct ColsNodeOp : public NaryNodeOp {
ColsNodeOp(Expr a, Expr indices)
: NaryNodeOp({a, indices}, newShape(a, indices), a->value_type()) {


@@ -133,7 +133,7 @@ public:
}
void backward() override {
Insert(/*out*/child(0)->grad(), adj_, val_, axis_);
Insert</*add=*/true>(/*out*/child(0)->grad(), adj_, val_, axis_);
}
const std::string type() override { return "topk"; }


@@ -10,5 +10,40 @@ Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) {
return state;
}
Ptr<DecoderState> GumbelSoftmaxStep::apply(Ptr<DecoderState> state) {
state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
[](Expr logits) { // lemma gets gumbelled
return logsoftmax(logits + constant_like(logits, inits::gumbel()));
},
logsoftmax)); // factors don't
return state;
}
TopkGumbelSoftmaxStep::TopkGumbelSoftmaxStep(int k) : k_{k} {}
Ptr<DecoderState> TopkGumbelSoftmaxStep::apply(Ptr<DecoderState> state) {
state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
[=](Expr logits) { // lemma gets gumbelled
// create logits-sized tensor consisting only of invalid path scores
float invalidPathScore = NumericLimits<float>(logits->value_type()).lowest;
Expr invalidLogits = constant_like(logits, inits::fromValue(invalidPathScore));
// select top-k values
Expr val, idx;
std::tie(val, idx) = topk(logits, k_, /*axis=*/-1, /*descending=*/true);
// uncomment below to display probability mass in top-k selection
// debug(sum(gather(softmax(logits), -1, idx), -1), "sum");
// Add Gumbel noise to top-k values only and compute logsoftmax, used for argmax sampling later in beam-search
Expr gumbelVal = logsoftmax(val + constant_like(val, inits::gumbel()));
// Scatter gumbelled values back into logits to fill with usable values
return scatter(invalidLogits, -1, idx, gumbelVal);
},
logsoftmax)); // factors don't
return state;
}
} // namespace models
} // namespace marian
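
Both steps rely on the Gumbel-max trick: if g_i are i.i.d. Gumbel(0,1) draws, then argmax_i (log p_i + g_i) is an exact sample from the distribution p, which is why greedy beam-1 decoding over the noised scores performs sampling. A self-contained sketch of the trick (plain C++, not Marian code):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <random>
    #include <vector>

    int main() {
      // target distribution p = {0.7, 0.2, 0.1} and its log-probabilities
      std::vector<float> logProbs = {std::log(0.7f), std::log(0.2f), std::log(0.1f)};
      std::mt19937 rng(42);
      std::uniform_real_distribution<float> uni(1e-9f, 1.0f);
      std::vector<int> counts(logProbs.size(), 0);
      for(int n = 0; n < 100000; ++n) {
        std::vector<float> noised(logProbs.size());
        for(std::size_t i = 0; i < logProbs.size(); ++i) {
          float gumbel = -std::log(-std::log(uni(rng))); // Gumbel(0,1) noise
          noised[i] = logProbs[i] + gumbel;              // perturbed log-probability
        }
        // argmax over the noised scores is an exact sample from p
        counts[std::max_element(noised.begin(), noised.end()) - noised.begin()]++;
      }
      std::printf("%d %d %d\n", counts[0], counts[1], counts[2]); // roughly 70/20/10 percent
    }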


@@ -297,20 +297,30 @@ public:
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
};
// Gumbel-max noising for sampling during beam-search
// Seems to work well enough with beam-size=1. Turn on
// with --output-sampling during translation with marian-decoder
// Gumbel-max noising for sampling during translation.
// Produces accurate sampling with beam=1. Turn on
// with --output-sampling [full] during translation
// with marian-decoder for sampling from the full
// softmax distribution.
class GumbelSoftmaxStep : public ILogProbStep {
public:
virtual ~GumbelSoftmaxStep() {}
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
[](Expr logits) { // lemma gets gumbelled
return logsoftmax(logits + constant_like(logits, inits::gumbel()));
},
logsoftmax)); // factors don't
return state;
}
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
};
// Gumbel-max noising for top-k sampling during translation.
// Produces accurate sampling with beam=1. Turn on
// with --output-sampling topk [10] during translation
// with marian-decoder for top-10 sampling.
class TopkGumbelSoftmaxStep : public ILogProbStep {
private:
int k_{1};
public:
TopkGumbelSoftmaxStep(int k);
virtual ~TopkGumbelSoftmaxStep() {}
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
};
// class to wrap an IEncoderDecoder and a ILogProbStep that are executed in sequence,


@@ -370,10 +370,25 @@ Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
// add (log)softmax if requested
if (use == usage::translation) {
if(std::dynamic_pointer_cast<EncoderDecoder>(baseModel)) {
if(options->get<bool>("output-sampling", false))
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
else
if(options->hasAndNotEmpty("output-sampling")) {
auto sampling = options->get<std::vector<std::string>>("output-sampling", {});
std::string method = sampling.size() > 0 ? sampling[0] : "full";
if(method == "full" || method == "1" /*for backwards-compat when output-sampling: true in yaml file*/) {
LOG(info, "Output sampling from the full softmax distribution");
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
} else if(method == "topk") {
int k = sampling.size() > 1 ? std::stoi(sampling[1]) : 10;
if(k == 1)
LOG(info, "Output sampling with k=1 is equivalent to beam search with beam size 1");
LOG(info, "Output sampling via top-{} sampling", k);
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<TopkGumbelSoftmaxStep>(k));
} else {
ABORT("Unknown sampling method: {}", method);
}
} else {
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
}
}
#ifdef COMPILE_EXAMPLES
// note: 'usage::translation' here means 'inference'
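
Equivalent configurations this branch accepts (a sketch; the values are illustrative):

    # command line
    --output-sampling                 # implicit 'full'
    --output-sampling topk 100        # top-100; k defaults to 10 if omitted
    # config.yml
    output-sampling: [topk, 100]
    output-sampling: true             # legacy boolean, arrives as '1' and is treated as 'full'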


@@ -739,6 +739,7 @@ void Select(Tensor out,
}
}
template <bool add>
void Insert(Tensor out,
const Tensor in,
const Tensor indices,
@@ -760,10 +761,16 @@ void Insert(Tensor out,
int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
dims[axisCPU] = (int)indices->data<IndexType>()[idxIndex];
int outIndex = outShape.index(dims);
out->data()[outIndex] += in->data()[index];
if(add)
out->data()[outIndex] += in->data()[index];
else
out->data()[outIndex] = in->data()[index];
}
}
template void Insert<true>(Tensor out, const Tensor in, const Tensor indices, int axis);
template void Insert<false>(Tensor out, const Tensor in, const Tensor indices, int axis);
void GRUFastForward(Tensor out_, std::vector<Tensor> inputs, bool final) {
int rows = out_->shape().elements() / out_->shape().back();
int cols = out_->shape().back();


@@ -1309,7 +1309,7 @@ __global__ void gSelect(T* out,
}
}
template <typename T>
template <bool add, typename T>
__global__ void gInsert(T* out,
functional::Shape outShape,
const T* in,
@@ -1327,7 +1330,10 @@ __global__ void gInsert(T* out,
int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
dims[axis] = (int)d_indices[idxIndex];
int outIndex = outShape.index(dims);
out[outIndex] += in[index]; // this is probably wrong, atomicAdd?
if(add)
out[outIndex] += in[index]; // this is probably wrong, atomicAdd?
else
out[outIndex] = in[index];
}
}
}
@@ -1349,21 +1352,21 @@ void Select(Tensor out,
if(out->type() == Type::float32) {
gSelect<<<blocks, threads>>>(out->data<float>(),
out->shape(),
in->data<float>(),
in->shape(),
axisGPU,
indices->data<IndexType>(),
indices->shape());
#if COMPILE_FP16
} else if (out->type() == Type::float16) {
gSelect<<<blocks, threads>>>(out->data<half>(),
out->shape(),
in->data<half>(),
in->shape(),
axisGPU,
indices->data<IndexType>(),
indices->shape());
#endif
} else if(out->type() == Type::uint32) {
gSelect<<<blocks, threads>>>(out->data<IndexType>(),
@@ -1378,6 +1381,7 @@ void Select(Tensor out,
}
}
template <bool add>
void Insert(Tensor out,
const Tensor in,
const Tensor indices,
@@ -1393,28 +1397,31 @@ void Insert(Tensor out,
int axisGPU = axis + functional::Shape::size() - out->shape().size();
if(out->type() == Type::float32) {
gInsert<<<blocks, threads>>>(out->data<float>(),
out->shape(),
in->data<float>(),
in->shape(),
axisGPU,
indices->data<IndexType>(),
indices->shape());
gInsert<add><<<blocks, threads>>>(out->data<float>(),
out->shape(),
in->data<float>(),
in->shape(),
axisGPU,
indices->data<IndexType>(),
indices->shape());
#if COMPILE_FP16
} else if (out->type() == Type::float16) {
gInsert<<<blocks, threads>>>(out->data<half>(),
out->shape(),
in->data<half>(),
in->shape(),
axisGPU,
indices->data<IndexType>(),
indices->shape());
gInsert<add><<<blocks, threads>>>(out->data<half>(),
out->shape(),
in->data<half>(),
in->shape(),
axisGPU,
indices->data<IndexType>(),
indices->shape());
#endif
} else {
ABORT("Insert not implemented for type {}", out->type());
}
}
template void Insert<true>(Tensor out, const Tensor in, const Tensor indices, int axis);
template void Insert<false>(Tensor out, const Tensor in, const Tensor indices, int axis);
template <typename T>
__global__ void gGRUFastForward(T* out,
const T* state,


@@ -297,7 +297,28 @@ DISPATCH3(CopyCols, marian::Tensor, const marian::Tensor, const marian::Tensor)
DISPATCH3(PasteCols, marian::Tensor, const marian::Tensor, const marian::Tensor)
DISPATCH4(Select, marian::Tensor, const marian::Tensor, const marian::Tensor, int)
DISPATCH4(Insert, marian::Tensor, const marian::Tensor, const marian::Tensor, int)
#ifdef CUDA_FOUND
namespace gpu {
template <bool add>
void Insert(Tensor out, const Tensor in, const Tensor indices, int axis);
}
#endif
namespace cpu {
template <bool add>
void Insert(Tensor out, const Tensor in, const Tensor indices, int axis);
}
template <bool add>
static inline void Insert(Tensor out, const Tensor in, const Tensor indices, int axis) {
#ifdef CUDA_FOUND
if(out->getBackend()->getDeviceId().type == DeviceType::gpu)
gpu::Insert<add>(out, in, indices, axis);
else
#endif
cpu::Insert<add>(out, in, indices, axis);
}
DISPATCH7(TopK, marian::Tensor, marian::Tensor, Ptr<Allocator>, const marian::Tensor, int, int, bool);


@@ -119,7 +119,7 @@ public:
threadPool.enqueue(task, device, id++);
}
if(options_->get<bool>("output-sampling", false)) {
if(options_->hasAndNotEmpty("output-sampling")) {
if(options_->get<size_t>("beam-size") > 1)
LOG(warn,
"[warning] Output sampling and beam search (beam-size > 1) are contradictory methods "