mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 13053: minor updates to ONNX code
will not affect anything else
This commit is contained in:
parent
0dc318e993
commit
c8a62dd2c8
@ -12,6 +12,7 @@ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_AL
|
||||
def get_function(path, output_vars):
|
||||
print("Reading ONNX function from", path)
|
||||
#model = onnx.load(path)
|
||||
#print("Done", flush=True)
|
||||
#print(model)
|
||||
ort_sess = ort.InferenceSession(path, sess_options)
|
||||
output_defs = ort_sess.get_outputs()
|
||||
@ -37,19 +38,20 @@ id2word = { id : word.rstrip() for id, word in enumerate(open('c:/work/marian-de
|
||||
word2id = { word : id for id, word in id2word.items() }
|
||||
unk_id = word2id["<unk>"]
|
||||
|
||||
model_path_prefix = "c:/work/marian-dev/local/model/model.npz.best-ce-mean-words-debug-sin"
|
||||
model_path_prefix = "c:/work/marian-dev/local/model/model.npz.best-ce-mean-words-debug-sin-uniq-notrans"
|
||||
encode_source = get_function(model_path_prefix + '.encode_source.onnx',
|
||||
['encoder_context_0'])
|
||||
decode_first = get_function(model_path_prefix + '.decode_first.onnx',
|
||||
['logits', 'out_decoder_state_0', 'out_decoder_state_1', 'out_decoder_state_2', 'out_decoder_state_3', 'out_decoder_state_4', 'out_decoder_state_5'])
|
||||
['first_logits', 'first_decoder_state_0', 'first_decoder_state_1', 'first_decoder_state_2', 'first_decoder_state_3', 'first_decoder_state_4', 'first_decoder_state_5'])
|
||||
decode_next = get_function(model_path_prefix + '.decode_next.onnx',
|
||||
['logits', 'out_decoder_state_0', 'out_decoder_state_1', 'out_decoder_state_2', 'out_decoder_state_3', 'out_decoder_state_4', 'out_decoder_state_5'])
|
||||
['next_logits', 'next_decoder_state_0', 'next_decoder_state_1', 'next_decoder_state_2', 'next_decoder_state_3', 'next_decoder_state_4', 'next_decoder_state_5'])
|
||||
|
||||
def greedy_decode(data_0):
|
||||
if len(data_0) == 1: # special handling for the empty sentence, like Marian
|
||||
return data_0
|
||||
data_0_mask = [[[1.]]] * len(data_0)
|
||||
data_0_index_range = [[[float(t)]] for t in range(len(data_0))]
|
||||
#print(data_0, data_0_mask, data_0_index_range)
|
||||
|
||||
max_len = len(data_0) * 3
|
||||
Y = []
|
||||
@ -69,7 +71,7 @@ def greedy_decode(data_0):
|
||||
return Y
|
||||
|
||||
start_time = time.time()
|
||||
with open("C:/work/marian-dev/local/model/predictions.out-onnx-debug-sin-3-first100.tok", 'wt', encoding='utf-8') as out_f:
|
||||
with open("C:/work/marian-dev/local/model/predictions.out-onnx-debug-sin-notrans-first100-d.tok", 'wt', encoding='utf-8') as out_f:
|
||||
for line in open("C:/work/marian-dev/local/model/predictions.in-first100.tok", encoding='utf-8').readlines():
|
||||
data = [word2id.get(w, unk_id) for w in (line.rstrip() + " </s>").split(' ') if w]
|
||||
Y = greedy_decode(data)
|
||||
|
@ -137,9 +137,9 @@ namespace marian {
|
||||
inputs.emplace_back(encoderContexts[i]);
|
||||
inputs.emplace_back(encoderEmbeddingInputs[1+2*i]);
|
||||
}
|
||||
outputs.emplace_back(std::make_pair("logits", decodeFirstState->getLogProbs().getLogits()));
|
||||
outputs.emplace_back(std::make_pair("first_logits", decodeFirstState->getLogProbs().getLogits()));
|
||||
for (const auto& dss : extractStates(decodeFirstState))
|
||||
outputs.emplace_back(std::make_pair("out_decoder_state_" + std::to_string(outputs.size()-1), dss));
|
||||
outputs.emplace_back(std::make_pair("first_decoder_state_" + std::to_string(outputs.size()-1), dss));
|
||||
functionDefs["decode_first"] = std::make_pair(std::move(inputs), std::move(outputs));
|
||||
|
||||
// descriptor for decode_next(prev_word, data_1_posrange, encoder_context_0, data_0_mask, decoder_state_0, decoder_state_1, ...) -> next_logits, next_decoder_state_0, next_decoder_state_1, ...
|
||||
@ -151,9 +151,9 @@ namespace marian {
|
||||
}
|
||||
for (const auto& dss : extractStates(decodeFirstState))
|
||||
inputs.emplace_back(std::make_pair("decoder_state_" + std::to_string(inputs.size() - (numEncoders*2 + 2)), dss));
|
||||
outputs.emplace_back(std::make_pair("logits", decodeNextState->getLogProbs().getLogits()));
|
||||
outputs.emplace_back(std::make_pair("next_logits", decodeNextState->getLogProbs().getLogits()));
|
||||
for (const auto& dss : extractStates(decodeNextState))
|
||||
outputs.emplace_back(std::make_pair("out_decoder_state_" + std::to_string(outputs.size() - 1), dss));
|
||||
outputs.emplace_back(std::make_pair("next_decoder_state_" + std::to_string(outputs.size() - 1), dss));
|
||||
functionDefs["decode_next"] = std::make_pair(std::move(inputs), std::move(outputs));
|
||||
|
||||
// now export the sub-graph as given by the function descriptor
|
||||
|
@ -322,36 +322,62 @@ namespace marian {
|
||||
auto oneExpr = newConstant(v, {}, 1.0f, "one");
|
||||
n = s * y + (oneExpr - s) * x;
|
||||
}
|
||||
else if (v->type() == "dot" ||
|
||||
v->type() == "bdot" ||
|
||||
else if ( v->type() == "bdot" ||
|
||||
(v->type() == "dot" /* && (v->child(0)->shape().size() != 2 || v->child(1)->shape().size() != 2)*/) ||
|
||||
(v->type() == "affine" && (v->child(0)->shape().size() != 2 || v->child(1)->shape().size() != 2 || v->child(2)->shape().size() > 2))) {
|
||||
// ONNX MatMul behaves like Numpy matmul, and therefore implements batched semantics.
|
||||
// ONNX MatMul has no transA/B/scale parameters, so we must handle those as explicit operations.
|
||||
// affine() could also be ONNX Gemm, but that does not support outer ranks, so we just expand it into dot().
|
||||
// @TODO: ^^ we can just reshape(). Code is already below, but ONNX Gemm always crashes, so this is disabled for now.
|
||||
auto a = v->child(0);
|
||||
auto b = v->child(1);
|
||||
bool transA{}, transB{}; float scalar{}; // (gcc complains without the initializers, which I think is a compiler bug)
|
||||
E::tryGetMatMulAttributes<DotNodeOp> (v, transA, transB, scalar) ||
|
||||
E::tryGetMatMulAttributes<DotBatchedNodeOp>(v, transA, transB, scalar) ||
|
||||
E::tryGetMatMulAttributes<AffineNodeOp> (v, transA, transB, scalar) || E::fail();
|
||||
//LOG(info, "{} {}={}x{} trans = {}, {} and scalar = {}",
|
||||
// v->type(), std::string(v->shape()), std::string(a->shape()), std::string(b->shape()), transA, transB, scalar);
|
||||
if (transA || transB || scalar != 1.0f ||
|
||||
(v->type() == "affine" && (v->child(0)->shape().size() != 2 || v->child(1)->shape().size() != 2 || v->child(2)->shape().size() > 2))) {
|
||||
//LOG(info, "patching {} due to trans = {}, {} and scalar = {}", v->type(), transA, transB, scalar);
|
||||
if (transA)
|
||||
(v->type() == "affine" && (a->shape().size() != 2 || b->shape().size() != 2 || v->child(2)->shape().size() > 2))) {
|
||||
//LOG(info, "patching {} {}={}x{} due to trans = {}, {} and scalar = {}",
|
||||
// v->type(), std::string(v->shape()), std::string(a->shape()), std::string(b->shape()), transA, transB, scalar);
|
||||
if (transA) { // note: we don't optimize for this since it does not happen in present models
|
||||
a = swapAxes(a, -1, -2);
|
||||
if (transB)
|
||||
transA = false;
|
||||
}
|
||||
// @BUGBUG: Gemm always crashes with ONNX runtime. So we can't do this optimization.
|
||||
//if (v->type() != "bdot" && b->shape().size() == 2) { // [A,B,C,I,J] x [J,K] --> reshape into regular matrix product
|
||||
// ABORT_IF(transA, "Transposition not mapped away??");
|
||||
// a = reshape(a, Shape({ a->shape().elements() / a->shape()[-1], a->shape()[-1] })); // now it's a regular matrix product, can use Gemm
|
||||
//}
|
||||
/*else*/ if (transB) { // not a regular matrix product: cannot use Gemm, so must transpose manually
|
||||
b = swapAxes(b, -1, -2);
|
||||
if (v->type() == "bdot")
|
||||
n = bdot(a, b);
|
||||
else // dot, affine
|
||||
n = dot(a, b);
|
||||
//LOG(info, "{} {} x {} -> {}", v->type(), std::string(a->shape()), std::string(b->shape()), std::string(n->shape()));
|
||||
if (v->type() == "affine")
|
||||
n = n + v->child(2);
|
||||
transB = false;
|
||||
}
|
||||
float extraScalar = 1.0f;
|
||||
if (v->type() == "bdot") { // this maps to ONNX MatMul
|
||||
scalar = 1.0f; // we cannot scale in ONNX MatMul
|
||||
extraScalar = scalar; // must add extra scale operation at the end
|
||||
ABORT_IF(transA || transB || scalar != 1.0f, "Transposition and/or scalar not mapped away??");
|
||||
n = bdot(a, b, transA, transB, scalar);
|
||||
}
|
||||
else { // dot, affine
|
||||
// @BUGBUG: Gemm always crashes with ONNX runtime. So we can't do this optimization.
|
||||
//if (a->shape().size() != 2 || b->shape().size() != 2) { // not ONNX MatMul: must use explicit scale operation
|
||||
scalar = 1.0f;
|
||||
extraScalar = scalar;
|
||||
//}
|
||||
n = dot(a, b, transA, transB, scalar);
|
||||
//LOG(info, "{} {} x {} -> {}", v->type(), std::string(a->shape()), std::string(b->shape()), std::string(n->shape()));
|
||||
if (v->type() == "affine")
|
||||
n = n + v->child(2);
|
||||
}
|
||||
//if (v->type() == "affine")
|
||||
// LOG(info, "{} + {} -> {}", v->type(), std::string(v->child(2)->shape()), std::string(n->shape()));
|
||||
if (scalar != 1.0f)
|
||||
n = n * newConstant(v, {}, scalar, "scalar");
|
||||
if (extraScalar != 1.0f)
|
||||
n = n * newConstant(v, {}, extraScalar, "scalar");
|
||||
if (n->shape() != v->shape())
|
||||
n = reshape(n, v->shape()); // if we did some shaping to get a regular matrix product, reshape it back
|
||||
}
|
||||
}
|
||||
else if (v->type() == "affine" && v->children().size() > 3) {
|
||||
@ -596,7 +622,7 @@ namespace marian {
|
||||
model.set_ir_version(IR_VERSION);
|
||||
model.set_producer_name(producerName);
|
||||
model.mutable_graph()->CopyFrom(graph);
|
||||
#define OPSET_IMPORT_VERSION 8 // Azure supports 8. @TODO: Avoid hard-coded value. Is there a header we need to include?
|
||||
#define OPSET_IMPORT_VERSION 9 // 9 is needed for some newer ops
|
||||
model.add_opset_import()->set_version(OPSET_IMPORT_VERSION);
|
||||
return model;
|
||||
}
|
||||
@ -742,6 +768,10 @@ namespace marian {
|
||||
auto name = getExprName(expr, nameOverrides); // node name is used as both output name and node name
|
||||
auto op = mapExprOp(expr);
|
||||
|
||||
//if (op == "MatMul" && expr->child(0)->shape().size() == 2 && expr->child(1)->shape().size() == 2) {
|
||||
// op = "Gemm";
|
||||
//}
|
||||
|
||||
#if 1 // workaround for onnxruntime which does not handle Pad correctly
|
||||
if (op == "Pad") {
|
||||
// Implement Pad as Slice >> Concat
|
||||
@ -855,14 +885,18 @@ namespace marian {
|
||||
// matmul attributes
|
||||
bool transA, transB;
|
||||
float scalar;
|
||||
if (node.op_type() == "Gemm") {
|
||||
// Note: Not all affine() calls get mapped to Gemm.
|
||||
ABORT_IF(children[0]->shape().size() != 2 || children[1]->shape().size() != 2 || children[2]->shape().size() > 2,
|
||||
// @BUGBUG: I cannot get Gemm to work, ONNX runtime always crashes. So we will NEVER get here.
|
||||
if (node.op_type() == "Gemm") { // we get here for affine() or dot()
|
||||
// Note: We only get here if Gemm can implement this configuration.
|
||||
ABORT_IF(children[0]->shape().size() != 2 || children[1]->shape().size() != 2 ||
|
||||
(children.size() > 2 && children[2]->shape().size() > 2),
|
||||
"Gemm unexpectedly used for non-matrix inputs");
|
||||
E::tryGetMatMulAttributes<AffineNodeOp>(expr, transA, transB, scalar) || E::fail();
|
||||
if (transA) addAttribute(node, "transA", transA);
|
||||
if (transB) addAttribute(node, "transB", transB);
|
||||
if (scalar != 1.0f) addAttribute(node, "alpha", scalar);
|
||||
E::tryGetMatMulAttributes<AffineNodeOp>(expr, transA, transB, scalar) ||
|
||||
E::tryGetMatMulAttributes<DotNodeOp> (expr, transA, transB, scalar) || E::fail();
|
||||
/*if (transA) */ addAttribute(node, "transA", transA ? 1 : 0);
|
||||
/*if (transB) */ addAttribute(node, "transB", transB ? 1 : 0);
|
||||
/*if (scalar != 1.0f)*/ addAttribute(node, "alpha", scalar);
|
||||
//addAttribute(node, "beta", 0.0f);
|
||||
}
|
||||
else if (E::tryGetMatMulAttributes<DotNodeOp> (expr, transA, transB, scalar) ||
|
||||
E::tryGetMatMulAttributes<DotBatchedNodeOp>(expr, transA, transB, scalar)) {
|
||||
|
Loading…
Reference in New Issue
Block a user