add fallback option for sampling, for back-compat

This commit is contained in:
Marcin Junczys-Dowmunt 2022-05-09 13:28:28 -07:00
parent 1a74358277
commit e4f3d0f740

View File

@ -374,7 +374,10 @@ Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
auto sampling = options->get<std::vector<std::string>>("output-sampling", {});
std::string method = sampling.size() > 0 ? sampling[0] : "full";
if(method == "full" || method == "1" /*for backwards-compat when output-sampling: true in yaml file*/) {
if(method == "0") { /*for backwards-compat when output-sampling: false in yaml file*/
// do normal decoding
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
} else if(method == "full" || method == "1" /*for backwards-compat when output-sampling: true in yaml file*/) {
LOG(info, "Output sampling from the full softmax distribution");
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
} else if(method == "topk") {