mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 20729: Add top-k sampling
This adds Top-K sampling to Marian and extends the --output-sampling option to take arguments
This commit is contained in:
parent
1404201926
commit
c85d060848
@ -1 +1 @@
|
||||
Subproject commit 7d612ca5e4b27a76f92584dad76d240e34f216d0
|
||||
Subproject commit 0aa7b6b7632732d1f22f3d8169d3262a7e6b1e9d
|
@ -695,9 +695,10 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
|
||||
"Use softmax shortlist: path first best prune");
|
||||
cli.add<std::vector<float>>("--weights",
|
||||
"Scorer weights");
|
||||
cli.add<bool>("--output-sampling",
|
||||
"Noise output layer with gumbel noise",
|
||||
false);
|
||||
cli.add<std::vector<std::string>>("--output-sampling",
|
||||
"Noise output layer with gumbel noise. Implicit default is 'full' for sampling from full distribution. "
|
||||
" Also accepts 'topk num' (e.g. topk 100) for top-100 sampling.")
|
||||
->implicit_val("full");
|
||||
cli.add<std::vector<int>>("--output-approx-knn",
|
||||
"Use approximate knn search in output layer (currently only in transformer)")
|
||||
->implicit_val("100 1024");
|
||||
|
@ -357,6 +357,13 @@ Expr gather(Expr a, int axis, Expr indices) {
|
||||
return Expression<GatherNodeOp>(a, axis, indices);
|
||||
}
|
||||
|
||||
// scatter() -- returns a copy of 'a' in which the elements of 'source' have been written
// to the positions given by 'indices' along 'axis'; batched or non-batched.
// Positions of 'a' that are not addressed by 'indices' keep their original values.
|
||||
// This is the reverse operation to gather.
|
||||
Expr scatter(Expr a, int axis, Expr indices, Expr source) {
|
||||
// Axis and shape validation is performed by ScatterNodeOp (via its newShape()) on construction.
return Expression<ScatterNodeOp>(a, axis, indices, source);
|
||||
}
|
||||
|
||||
|
||||
// index_select() -- gather arbitrary elements along an axis from an unbatched
|
||||
// input 'a'. Indices are specified as a 1D vector.
|
||||
// This is used e.g. for embedding lookup.
|
||||
|
@ -707,10 +707,23 @@ Expr stopGradient(Expr a);
|
||||
* @param indices The indices to be gathered
|
||||
* @returns Gathered expression with the same shape as @p indices
|
||||
* @note @p a and @p indices must have the same rank
|
||||
* @note The non-target axes of @p a and @p indicies must have the same size, or be broadcastable.
|
||||
* @note The non-target axes of @p a and @p indices must have the same size, or be broadcastable.
|
||||
*/
|
||||
Expr gather(Expr a, int axis, Expr indices);
|
||||
|
||||
/**
|
||||
* Scatter elements from source along an axis into a. Unindexed elements from a remain unchanged.
|
||||
* This is the reverse operation to gather.
|
||||
* @param a The input expression
|
||||
* @param axis The axis along which to index
|
||||
* @param indices The indices to be scattered
|
||||
* @param source Expression with values to scatter.
|
||||
* @returns Scattered expression with the same shape as @p a now containing values from @p source in positions @p indices
|
||||
* @note @p source and @p indices must have the same rank
|
||||
* @note In this version @p source and @p indices must have the same shape
|
||||
*/
|
||||
Expr scatter(Expr a, int axis, Expr indices, Expr source);
|
||||
|
||||
#if 0
|
||||
// reverse operation to gather. a is expression into with values from b are inserted and positions indices along axis.
|
||||
// with broadcasting
|
||||
|
@ -1033,12 +1033,14 @@ struct GatherNodeOp : public NaryNodeOp {
|
||||
|
||||
NodeOps forwardOps() override {
|
||||
return {NodeOp(
|
||||
// @TODO: rename to gather
|
||||
Select(val_, child(0)->val(), child(1)->val(), axis_))};
|
||||
}
|
||||
|
||||
NodeOps backwardOps() override {
|
||||
return {NodeOp(
|
||||
Insert(child(0)->grad(), adj_, child(1)->val(), axis_))};
|
||||
// @TODO: rename to scatter
|
||||
Insert</*add=*/true>(child(0)->grad(), adj_, child(1)->val(), axis_))};
|
||||
}
|
||||
|
||||
Shape newShape(Expr a, int axis, Expr indices) {
|
||||
@ -1046,7 +1048,6 @@ struct GatherNodeOp : public NaryNodeOp {
|
||||
axis = shape.axis(axis);
|
||||
auto rank = shape.size();
|
||||
ABORT_IF(rank != indices->shape().size(), "Mismatching ranks for input ({}) and indices ({})", std::string(shape), std::string(indices->shape()));
|
||||
axis = a->shape().axis(axis);
|
||||
shape.set(axis, indices->shape()[axis]);
|
||||
for (size_t i = 0; i < rank; ++i) {
|
||||
if (i != axis) {
|
||||
@ -1086,6 +1087,62 @@ private:
|
||||
int axis_;
|
||||
};
|
||||
|
||||
struct ScatterNodeOp : public NaryNodeOp {
|
||||
ScatterNodeOp(Expr a, int axis, Expr indices, Expr source)
|
||||
: NaryNodeOp({a, indices, source}, newShape(a, axis, indices, source), a->value_type()),
|
||||
axis_(a->shape().axis(axis)) {
|
||||
matchOrAbort<IndexType>(indices->value_type());
|
||||
}
|
||||
|
||||
NodeOps forwardOps() override {
|
||||
return {NodeOp(
|
||||
CopyCast(val_, child(0)->val()); // @TODO: use normal copy
|
||||
Insert</*add=*/false>(val_, child(2)->val(), child(1)->val(), axis_)
|
||||
)};
|
||||
}
|
||||
|
||||
NodeOps backwardOps() override {
|
||||
ABORT("backward for ScatterNodeOp not yet implemented");
|
||||
}
|
||||
|
||||
Shape newShape(Expr a, int axis, Expr indices, Expr source) {
|
||||
ABORT_IF(axis != -1, "only last dimensions");
|
||||
ABORT_IF(indices->shape() != source->shape(), "Shapes must match");
|
||||
|
||||
Shape shape = a->shape();
|
||||
// @TODO: do proper checking
|
||||
return shape;
|
||||
}
|
||||
|
||||
const std::string type() override { return "scatter"; }
|
||||
|
||||
const std::string color() override { return "orange"; }
|
||||
|
||||
virtual size_t hash() override {
|
||||
if(!hash_) {
|
||||
size_t seed = NaryNodeOp::hash();
|
||||
util::hash_combine(seed, axis_);
|
||||
hash_ = seed;
|
||||
}
|
||||
return hash_;
|
||||
}
|
||||
|
||||
virtual bool equal(Expr node) override {
|
||||
if(!NaryNodeOp::equal(node))
|
||||
return false;
|
||||
auto cnode = std::dynamic_pointer_cast<ScatterNodeOp>(node);
|
||||
if(!cnode)
|
||||
return false;
|
||||
if(axis_ != cnode->axis_)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class SerializationHelpers;
|
||||
int axis_;
|
||||
};
|
||||
|
||||
struct ColsNodeOp : public NaryNodeOp {
|
||||
ColsNodeOp(Expr a, Expr indices)
|
||||
: NaryNodeOp({a, indices}, newShape(a, indices), a->value_type()) {
|
||||
|
@ -133,7 +133,7 @@ public:
|
||||
}
|
||||
|
||||
void backward() override {
|
||||
Insert(/*out*/child(0)->grad(), adj_, val_, axis_);
|
||||
Insert</*add=*/true>(/*out*/child(0)->grad(), adj_, val_, axis_);
|
||||
}
|
||||
|
||||
const std::string type() override { return "topk"; }
|
||||
|
@ -10,5 +10,40 @@ Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) {
|
||||
return state;
|
||||
}
|
||||
|
||||
// Replaces the log-probabilities in `state` by Gumbel-perturbed ones and returns `state`.
// Only the lemma logits receive noise; factor logits are merely log-softmax-normalized.
Ptr<DecoderState> GumbelSoftmaxStep::apply(Ptr<DecoderState> state) {
  // lemma gets gumbelled
  auto gumbelLogSoftmax = [](Expr logits) {
    Expr noise = constant_like(logits, inits::gumbel());
    return logsoftmax(logits + noise);
  };

  auto noisedLogProbs = state->getLogProbs().applyUnaryFunctions(gumbelLogSoftmax,
                                                                 logsoftmax); // factors don't
  state->setLogProbs(noisedLogProbs);
  return state;
}
|
||||
|
||||
// Creates a top-k sampling step that samples from the k highest-scoring candidates.
// k is user-supplied (parsed from --output-sampling), so reject values for which
// a top-k selection is meaningless instead of failing later inside the graph.
TopkGumbelSoftmaxStep::TopkGumbelSoftmaxStep(int k) : k_{k} {
  ABORT_IF(k_ < 1, "Top-k sampling requires k >= 1, but got k={}", k_);
}
|
||||
|
||||
// Replaces the log-probabilities in `state` by values suitable for top-k sampling and
// returns `state`. For the lemma logits, only the k largest entries along the last axis
// are perturbed with Gumbel noise and normalized; every other entry is set to the lowest
// representable score. Factor logits are merely log-softmax-normalized.
Ptr<DecoderState> TopkGumbelSoftmaxStep::apply(Ptr<DecoderState> state) {
  // Capture a plain copy of k: a default [=] capture would implicitly capture `this`
  // and tie the closure to the lifetime of this object.
  const int k = k_;

  auto topkGumbelLogSoftmax = [k](Expr logits) { // lemma gets gumbelled
    // create logits-sized tensor consisting only of invalid path scores
    float invalidPathScore = NumericLimits<float>(logits->value_type()).lowest;
    Expr invalidLogits = constant_like(logits, inits::fromValue(invalidPathScore));

    // select top-k values
    Expr val, idx;
    std::tie(val, idx) = topk(logits, k, /*axis=*/-1, /*descending=*/true);

    // uncomment below to display probability mass in top-k selection
    // debug(sum(gather(softmax(logits), -1, idx), -1), "sum");

    // Add Gumbel noise to top-k values only and compute logsoftmax, used for argmax sampling later in beam-search
    Expr gumbelVal = logsoftmax(val + constant_like(val, inits::gumbel()));

    // Scatter gumbelled values back into logits to fill with usable values
    return scatter(invalidLogits, -1, idx, gumbelVal);
  };

  state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
      topkGumbelLogSoftmax,
      logsoftmax)); // factors don't
  return state;
}
|
||||
|
||||
} // namespace models
|
||||
} // namespace marian
|
||||
|
@ -297,20 +297,30 @@ public:
|
||||
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
|
||||
};
|
||||
|
||||
// Gumbel-max noising for sampling during beam-search
|
||||
// Seems to work well enough with beam-size=1. Turn on
|
||||
// with --output-sampling during translation with marian-decoder
|
||||
// Gumbel-max noising for sampling during translation.
|
||||
// Produces accurate sampling with beam=1. Turn on
|
||||
// with --output-sampling [full] during translation
|
||||
// with marian-decoder for sampling from the full
|
||||
// softmax distribution.
|
||||
class GumbelSoftmaxStep : public ILogProbStep {
|
||||
public:
|
||||
virtual ~GumbelSoftmaxStep() {}
|
||||
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
|
||||
state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
|
||||
[](Expr logits) { // lemma gets gumbelled
|
||||
return logsoftmax(logits + constant_like(logits, inits::gumbel()));
|
||||
},
|
||||
logsoftmax)); // factors don't
|
||||
return state;
|
||||
}
|
||||
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
|
||||
};
|
||||
|
||||
|
||||
// Gumbel-max noising for top-k sampling during translation.
|
||||
// Produces accurate sampling with beam=1. Turn on
|
||||
// with --output-sampling topk [10] during translation
|
||||
// with marian-decoder for top-10 sampling.
|
||||
class TopkGumbelSoftmaxStep : public ILogProbStep {
|
||||
private:
|
||||
int k_{1};
|
||||
|
||||
public:
|
||||
TopkGumbelSoftmaxStep(int k);
|
||||
virtual ~TopkGumbelSoftmaxStep() {}
|
||||
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
|
||||
};
|
||||
|
||||
// class to wrap an IEncoderDecoder and a ILogProbStep that are executed in sequence,
|
||||
|
@ -370,10 +370,25 @@ Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
|
||||
// add (log)softmax if requested
|
||||
if (use == usage::translation) {
|
||||
if(std::dynamic_pointer_cast<EncoderDecoder>(baseModel)) {
|
||||
if(options->get<bool>("output-sampling", false))
|
||||
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
|
||||
else
|
||||
if(options->hasAndNotEmpty("output-sampling")) {
|
||||
auto sampling = options->get<std::vector<std::string>>("output-sampling", {});
|
||||
std::string method = sampling.size() > 0 ? sampling[0] : "full";
|
||||
|
||||
if(method == "full" || method == "1" /*for backwards-compat when output-sampling: true in yaml file*/) {
|
||||
LOG(info, "Output sampling from the full softmax distribution");
|
||||
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
|
||||
} else if(method == "topk") {
|
||||
int k = sampling.size() > 1 ? std::stoi(sampling[1]) : 10;
|
||||
if(k == 1)
|
||||
LOG(info, "Output sampling with k=1 is equivalent to beam search with beam size 1");
|
||||
LOG(info, "Output sampling via top-{} sampling", k);
|
||||
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<TopkGumbelSoftmaxStep>(k));
|
||||
} else {
|
||||
ABORT("Unknown sampling method: {}", method);
|
||||
}
|
||||
} else {
|
||||
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
|
||||
}
|
||||
}
|
||||
#ifdef COMPILE_EXAMPLES
|
||||
// note: 'usage::translation' here means 'inference'
|
||||
|
@ -739,6 +739,7 @@ void Select(Tensor out,
|
||||
}
|
||||
}
|
||||
|
||||
template <bool add>
|
||||
void Insert(Tensor out,
|
||||
const Tensor in,
|
||||
const Tensor indices,
|
||||
@ -760,10 +761,16 @@ void Insert(Tensor out,
|
||||
int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
|
||||
dims[axisCPU] = (int)indices->data<IndexType>()[idxIndex];
|
||||
int outIndex = outShape.index(dims);
|
||||
out->data()[outIndex] += in->data()[index];
|
||||
if(add)
|
||||
out->data()[outIndex] += in->data()[index];
|
||||
else
|
||||
out->data()[outIndex] = in->data()[index];
|
||||
}
|
||||
}
|
||||
|
||||
template void Insert<true>(Tensor out, const Tensor in, const Tensor indices, int axis);
|
||||
template void Insert<false>(Tensor out, const Tensor in, const Tensor indices, int axis);
|
||||
|
||||
void GRUFastForward(Tensor out_, std::vector<Tensor> inputs, bool final) {
|
||||
int rows = out_->shape().elements() / out_->shape().back();
|
||||
int cols = out_->shape().back();
|
||||
|
@ -1309,7 +1309,7 @@ __global__ void gSelect(T* out,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <bool add, typename T>
|
||||
__global__ void gInsert(T* out,
|
||||
functional::Shape outShape,
|
||||
const T* in,
|
||||
@ -1327,7 +1327,10 @@ __global__ void gInsert(T* out,
|
||||
int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
|
||||
dims[axis] = (int)d_indices[idxIndex];
|
||||
int outIndex = outShape.index(dims);
|
||||
out[outIndex] += in[index]; // this is probably wrong, atomicAdd?
|
||||
if(add)
|
||||
out[outIndex] += in[index]; // this is probably wrong, atomicAdd?
|
||||
else
|
||||
out[outIndex] = in[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1349,21 +1352,21 @@ void Select(Tensor out,
|
||||
|
||||
if(out->type() == Type::float32) {
|
||||
gSelect<<<blocks, threads>>>(out->data<float>(),
|
||||
out->shape(),
|
||||
in->data<float>(),
|
||||
in->shape(),
|
||||
axisGPU,
|
||||
indices->data<IndexType>(),
|
||||
indices->shape());
|
||||
out->shape(),
|
||||
in->data<float>(),
|
||||
in->shape(),
|
||||
axisGPU,
|
||||
indices->data<IndexType>(),
|
||||
indices->shape());
|
||||
#if COMPILE_FP16
|
||||
} else if (out->type() == Type::float16) {
|
||||
gSelect<<<blocks, threads>>>(out->data<half>(),
|
||||
out->shape(),
|
||||
in->data<half>(),
|
||||
in->shape(),
|
||||
axisGPU,
|
||||
indices->data<IndexType>(),
|
||||
indices->shape());
|
||||
out->shape(),
|
||||
in->data<half>(),
|
||||
in->shape(),
|
||||
axisGPU,
|
||||
indices->data<IndexType>(),
|
||||
indices->shape());
|
||||
#endif
|
||||
} else if(out->type() == Type::uint32) {
|
||||
gSelect<<<blocks, threads>>>(out->data<IndexType>(),
|
||||
@ -1378,6 +1381,7 @@ void Select(Tensor out,
|
||||
}
|
||||
}
|
||||
|
||||
template <bool add>
|
||||
void Insert(Tensor out,
|
||||
const Tensor in,
|
||||
const Tensor indices,
|
||||
@ -1393,28 +1397,31 @@ void Insert(Tensor out,
|
||||
int axisGPU = axis + functional::Shape::size() - out->shape().size();
|
||||
|
||||
if(out->type() == Type::float32) {
|
||||
gInsert<<<blocks, threads>>>(out->data<float>(),
|
||||
out->shape(),
|
||||
in->data<float>(),
|
||||
in->shape(),
|
||||
axisGPU,
|
||||
indices->data<IndexType>(),
|
||||
indices->shape());
|
||||
gInsert<add><<<blocks, threads>>>(out->data<float>(),
|
||||
out->shape(),
|
||||
in->data<float>(),
|
||||
in->shape(),
|
||||
axisGPU,
|
||||
indices->data<IndexType>(),
|
||||
indices->shape());
|
||||
#if COMPILE_FP16
|
||||
} else if (out->type() == Type::float16) {
|
||||
gInsert<<<blocks, threads>>>(out->data<half>(),
|
||||
out->shape(),
|
||||
in->data<half>(),
|
||||
in->shape(),
|
||||
axisGPU,
|
||||
indices->data<IndexType>(),
|
||||
indices->shape());
|
||||
gInsert<add><<<blocks, threads>>>(out->data<half>(),
|
||||
out->shape(),
|
||||
in->data<half>(),
|
||||
in->shape(),
|
||||
axisGPU,
|
||||
indices->data<IndexType>(),
|
||||
indices->shape());
|
||||
#endif
|
||||
} else {
|
||||
ABORT("Insert not implemented for type {}", out->type());
|
||||
}
|
||||
}
|
||||
|
||||
template void Insert<true>(Tensor out, const Tensor in, const Tensor indices, int axis);
|
||||
template void Insert<false>(Tensor out, const Tensor in, const Tensor indices, int axis);
|
||||
|
||||
template <typename T>
|
||||
__global__ void gGRUFastForward(T* out,
|
||||
const T* state,
|
||||
|
@ -297,7 +297,28 @@ DISPATCH3(CopyCols, marian::Tensor, const marian::Tensor, const marian::Tensor)
|
||||
DISPATCH3(PasteCols, marian::Tensor, const marian::Tensor, const marian::Tensor)
|
||||
|
||||
DISPATCH4(Select, marian::Tensor, const marian::Tensor, const marian::Tensor, int)
|
||||
DISPATCH4(Insert, marian::Tensor, const marian::Tensor, const marian::Tensor, int)
|
||||
|
||||
#ifdef CUDA_FOUND
|
||||
namespace gpu {
|
||||
template <bool add>
|
||||
void Insert(Tensor out, const Tensor in, const Tensor indices, int axis);
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace cpu {
|
||||
template <bool add>
|
||||
void Insert(Tensor out, const Tensor in, const Tensor indices, int axis);
|
||||
}
|
||||
|
||||
template <bool add>
|
||||
static inline void Insert(Tensor out, const Tensor in, const Tensor indices, int axis) {
|
||||
#ifdef CUDA_FOUND
|
||||
if(out->getBackend()->getDeviceId().type == DeviceType::gpu)
|
||||
gpu::Insert<add>(out, in, indices, axis);
|
||||
else
|
||||
#endif
|
||||
cpu::Insert<add>(out, in, indices, axis);
|
||||
}
|
||||
|
||||
DISPATCH7(TopK, marian::Tensor, marian::Tensor, Ptr<Allocator>, const marian::Tensor, int, int, bool);
|
||||
|
||||
|
@ -119,7 +119,7 @@ public:
|
||||
threadPool.enqueue(task, device, id++);
|
||||
}
|
||||
|
||||
if(options_->get<bool>("output-sampling", false)) {
|
||||
if(options_->hasAndNotEmpty("output-sampling")) {
|
||||
if(options_->get<size_t>("beam-size") > 1)
|
||||
LOG(warn,
|
||||
"[warning] Output sampling and beam search (beam-size > 1) are contradictory methods "
|
||||
|
Loading…
Reference in New Issue
Block a user