mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 13081: minor regression fix in ONNX exporter
This fixes a silly little regression that snuck into the last commit.
This commit is contained in:
parent
ae7cae2760
commit
6d2bfa68c0
@ -38,7 +38,7 @@ id2word = { id : word.rstrip() for id, word in enumerate(open('c:/work/marian-de
|
||||
word2id = { word : id for id, word in id2word.items() }
|
||||
unk_id = word2id["<unk>"]
|
||||
|
||||
model_path_prefix = "c:/work/marian-dev/local/model/model.npz.best-ce-mean-words-debug-sin-uniq-notrans"
|
||||
model_path_prefix = "c:/work/marian-dev/local/model/model.npz.best-ce-mean-words-debug-sin-uniq-notrans-nounk"
|
||||
encode_source = get_function(model_path_prefix + '.encode_source.onnx',
|
||||
['encoder_context_0'])
|
||||
decode_first = get_function(model_path_prefix + '.decode_first.onnx',
|
||||
|
@ -47,6 +47,21 @@ namespace marian {
|
||||
}
|
||||
setInference(true); // note: must also set "inference" parameter on options
|
||||
|
||||
// if we must suppress <unk>, we do that by patching the bias
|
||||
const auto trgUnkId = vocabs.back()->getUnkId();
|
||||
int unkColId = -1;
|
||||
if (trgUnkId != Word::NONE && !modelOptions->get<bool>("allow-unk", false)) { // do we need to suppress unk?
|
||||
unkColId = trgUnkId.toWordIndex(); // what's the raw index of unk in the log prob vector?
|
||||
// find the bias
|
||||
const std::string outputBiasName = "decoder_ff_logit_out_b";
|
||||
auto outputBias = graph->get(outputBiasName);
|
||||
auto outputBiasVal = outputBias->val();
|
||||
std::vector<float> outputBiasVec;
|
||||
outputBiasVal->get(outputBiasVec);
|
||||
outputBiasVec[unkColId] = -std::numeric_limits<float>::infinity();
|
||||
outputBiasVal->set(outputBiasVec);
|
||||
}
|
||||
|
||||
// the input length is represented by a value that hopefully is not used elsewhere
|
||||
const size_t sentinelDim = 97; // who uses prime numbers as dimensions anyways!
|
||||
size_t numEncoders = vocabs.size() - 1; // @TODO: test this exporter for >1 encoder
|
||||
|
4
src/onnx/expression_graph_onnx_serialization.cpp
Executable file → Normal file
4
src/onnx/expression_graph_onnx_serialization.cpp
Executable file → Normal file
@ -356,16 +356,16 @@ namespace marian {
|
||||
}
|
||||
float extraScalar = 1.0f;
|
||||
if (v->type() == "bdot") { // this maps to ONNX MatMul
|
||||
scalar = 1.0f; // we cannot scale in ONNX MatMul
|
||||
extraScalar = scalar; // must add extra scale operation at the end
|
||||
scalar = 1.0f; // we cannot scale in ONNX MatMul
|
||||
ABORT_IF(transA || transB || scalar != 1.0f, "Transposition and/or scalar not mapped away??");
|
||||
n = bdot(a, b, transA, transB, scalar);
|
||||
}
|
||||
else { // dot, affine
|
||||
// @BUGBUG: Gemm always crashes with ONNX runtime. So we can't do this optimization.
|
||||
//if (a->shape().size() != 2 || b->shape().size() != 2) { // not ONNX MatMul: must use explicit scale operation
|
||||
scalar = 1.0f;
|
||||
extraScalar = scalar;
|
||||
scalar = 1.0f;
|
||||
//}
|
||||
n = dot(a, b, transA, transB, scalar);
|
||||
//LOG(info, "{} {} x {} -> {}", v->type(), std::string(a->shape()), std::string(b->shape()), std::string(n->shape()));
|
||||
|
Loading…
Reference in New Issue
Block a user