mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
add fallback option for sampling, for back-compat
This commit is contained in:
parent
1a74358277
commit
e4f3d0f740
@ -374,7 +374,10 @@ Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
|
|||||||
auto sampling = options->get<std::vector<std::string>>("output-sampling", {});
|
auto sampling = options->get<std::vector<std::string>>("output-sampling", {});
|
||||||
std::string method = sampling.size() > 0 ? sampling[0] : "full";
|
std::string method = sampling.size() > 0 ? sampling[0] : "full";
|
||||||
|
|
||||||
if(method == "full" || method == "1" /*for backwards-compat when output-sampling: true in yaml file*/) {
|
if(method == "0") { /*for backwards-compat when output-sampling: false in yaml file*/
|
||||||
|
// do normal decoding
|
||||||
|
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
|
||||||
|
} else if(method == "full" || method == "1" /*for backwards-compat when output-sampling: true in yaml file*/) {
|
||||||
LOG(info, "Output sampling from the full softmax distribution");
|
LOG(info, "Output sampling from the full softmax distribution");
|
||||||
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
|
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
|
||||||
} else if(method == "topk") {
|
} else if(method == "topk") {
|
||||||
|
Loading…
Reference in New Issue
Block a user