Merged PR 13081: minor regression fix in ONNX exporter

This fixes a silly little regression that snuck into the last commit.
Frank Seide 2020-05-27 06:07:28 +00:00
parent ae7cae2760
commit 6d2bfa68c0
3 changed files with 18 additions and 3 deletions


@@ -38,7 +38,7 @@ id2word = { id : word.rstrip() for id, word in enumerate(open('c:/work/marian-de
word2id = { word : id for id, word in id2word.items() }
unk_id = word2id["<unk>"]
-model_path_prefix = "c:/work/marian-dev/local/model/model.npz.best-ce-mean-words-debug-sin-uniq-notrans"
+model_path_prefix = "c:/work/marian-dev/local/model/model.npz.best-ce-mean-words-debug-sin-uniq-notrans-nounk"
encode_source = get_function(model_path_prefix + '.encode_source.onnx',
['encoder_context_0'])
decode_first = get_function(model_path_prefix + '.decode_first.onnx',
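
For context, get_function is the script's helper that wraps one exported ONNX file as a callable; its definition is not part of this diff. The following is only a minimal sketch of what such a wrapper could look like with onnxruntime — the body, and the way inputs are passed, are assumptions:

    import onnxruntime as ort  # pip install onnxruntime

    def get_function(onnx_path, output_names):
        # Hypothetical stand-in for the script's helper: load the exported
        # ONNX graph and return a callable producing the named outputs.
        sess = ort.InferenceSession(onnx_path)
        def fn(inputs):
            # inputs: dict mapping ONNX input names to numpy arrays
            return sess.run(output_names, inputs)
        return fn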


@@ -47,6 +47,21 @@ namespace marian {
}
setInference(true); // note: must also set "inference" parameter on options
+// if we must suppress <unk>, we do that by patching the bias
+const auto trgUnkId = vocabs.back()->getUnkId();
+int unkColId = -1;
+if (trgUnkId != Word::NONE && !modelOptions->get<bool>("allow-unk", false)) { // do we need to suppress unk?
+    unkColId = trgUnkId.toWordIndex(); // what's the raw index of unk in the log prob vector?
+    // find the bias
+    const std::string outputBiasName = "decoder_ff_logit_out_b";
+    auto outputBias = graph->get(outputBiasName);
+    auto outputBiasVal = outputBias->val();
+    std::vector<float> outputBiasVec;
+    outputBiasVal->get(outputBiasVec);
+    outputBiasVec[unkColId] = -std::numeric_limits<float>::infinity();
+    outputBiasVal->set(outputBiasVec);
+}
// the input length is represented by a value that hopefully is not used elsewhere
const size_t sentinelDim = 97; // who uses prime numbers as dimensions anyways!
size_t numEncoders = vocabs.size() - 1; // @TODO: test this exporter for >1 encoder
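
Why patching the bias works: the output layer computes logits = W*h + b, so a bias entry of negative infinity drives that token's softmax probability to exactly zero for every hidden state, and <unk> can never be emitted. A minimal numpy illustration of the effect (standalone, not Marian code; the vocabulary size and the <unk> index are made up):

    import numpy as np

    logits = np.array([1.2, 0.5, 3.1, 0.8], dtype=np.float32)  # pretend index 2 is <unk>
    bias = np.zeros_like(logits)
    bias[2] = -np.inf                  # patch the bias column for <unk>

    patched = logits + bias            # [1.2, 0.5, -inf, 0.8]
    probs = np.exp(patched - patched.max())
    probs /= probs.sum()
    print(probs)                       # probability at index 2 is exactly 0.0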

src/onnx/expression_graph_onnx_serialization.cpp (4 changes; mode changed: Executable file → Normal file)

@@ -356,16 +356,16 @@ namespace marian {
}
float extraScalar = 1.0f;
if (v->type() == "bdot") { // this maps to ONNX MatMul
-scalar = 1.0f; // we cannot scale in ONNX MatMul
extraScalar = scalar; // must add extra scale operation at the end
+scalar = 1.0f; // we cannot scale in ONNX MatMul
ABORT_IF(transA || transB || scalar != 1.0f, "Transposition and/or scalar not mapped away??");
n = bdot(a, b, transA, transB, scalar);
}
else { // dot, affine
// @BUGBUG: Gemm always crashes with ONNX runtime. So we can't do this optimization.
//if (a->shape().size() != 2 || b->shape().size() != 2) { // not ONNX MatMul: must use explicit scale operation
-scalar = 1.0f;
extraScalar = scalar;
+scalar = 1.0f;
//}
n = dot(a, b, transA, transB, scalar);
//LOG(info, "{} {} x {} -> {}", v->type(), std::string(a->shape()), std::string(b->shape()), std::string(n->shape()));
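
The regression, for the record: the old order assigned scalar = 1.0f before copying it into extraScalar, so the real scale factor was overwritten and the exported graph silently dropped the scaling. Since ONNX MatMul cannot carry a scale, the exporter must remember the factor and emit a separate multiply at the end, which is what the swapped order restores. A minimal numeric sketch of the intended equivalence (numpy stands in for both Marian's fused dot and the exported graph):

    import numpy as np

    A = np.random.rand(2, 3).astype(np.float32)
    B = np.random.rand(3, 4).astype(np.float32)
    scalar = 0.125  # e.g. 1/sqrt(d_k) from scaled dot-product attention

    # what Marian's scaled dot computes in one fused op:
    fused = scalar * (A @ B)

    # what the exporter emits: MatMul (no scale) followed by an
    # explicit multiply with the saved extraScalar
    extraScalar = scalar
    deferred = (A @ B) * extraScalar

    assert np.allclose(fused, deferred)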