Merged PR 14334: full ONNX conversion script

This PR adds a full ONNX conversion script that exports a Marian model as a set of ONNX graphs and wraps them in a greedy search, itself implemented in ONNX, yielding a single self-contained model.
Frank Seide 2020-07-24 17:23:05 +00:00
parent c3fb60cbcd
commit 435aa9505e
8 changed files with 259 additions and 87 deletions
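
For context, the combined model is a single ONNX file that can be run with plain ONNX Runtime, roughly like this (a sketch only; the input names `X` and `eos_id` follow the traced greedy_search() signature in the new library below and are an assumption, as is the model filename):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")
source_ids = np.array([274, 35, 52, 0], dtype=np.int64)  # token ids, ending in EOS
eos_id = np.array([0], dtype=np.int64)
result_ids = sess.run(None, {"X": source_ids, "eos_id": eos_id})[0]
print(result_ids)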

View File

@@ -207,6 +207,7 @@ if(USE_ONNX)
message(STATUS "Enabling experimental ONNX support")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_ONNX")
set(EXT_LIBS ${EXT_LIBS} protobuf)
include_directories(${Protobuf_INCLUDE_DIRS})
endif()
# Find packages
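# To pick this code path up, configure with the flag above enabled, e.g.
# (a sketch, assuming an out-of-source build directory):
#   cmake .. -DUSE_ONNX=ON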

View File

@@ -1,81 +0,0 @@
import onnxruntime as ort
import numpy as np
import onnx
import os, sys, time

os.environ['OMP_NUM_THREADS'] = '1'
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 1
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

def get_function(path, output_vars):
    print("Reading ONNX function from", path)
    #model = onnx.load(path)
    #print("Done", flush=True)
    #print(model)
    ort_sess = ort.InferenceSession(path, sess_options)
    output_defs = ort_sess.get_outputs()
    for input in ort_sess.get_inputs():
        print(" input: ", input.name, input.shape, input.type)
    for output in output_defs:
        print(" output: ", output.name, output.shape, output.type)
    def invoke_model(**kwargs):
        def to_numpy(val):
            arr = np.array(val)
            if arr.dtype == np.double:
                arr = arr.astype(np.float32)
            elif arr.dtype == np.int64:
                arr = arr.astype(np.int32)
            return arr
        kwargs = { name: to_numpy(val) for name, val in kwargs.items() }
        output_vals = ort_sess.run(None, kwargs)
        output_dict = { output_def.name: output_val for output_val, output_def in zip(output_vals, output_defs) }
        return [output_dict[output_var] for output_var in output_vars]
    return invoke_model

id2word = { id: word.rstrip() for id, word in enumerate(open('c:/work/marian-dev/local/model/vocab_v1.wl', encoding='utf-8').readlines()) }
word2id = { word: id for id, word in id2word.items() }
unk_id = word2id["<unk>"]

model_path_prefix = "c:/work/marian-dev/local/model/model.npz.best-ce-mean-words-debug-sin-uniq-notrans-nounk"
encode_source = get_function(model_path_prefix + '.encode_source.onnx',
                             ['encoder_context_0'])
decode_first = get_function(model_path_prefix + '.decode_first.onnx',
                            ['first_logits', 'first_decoder_state_0', 'first_decoder_state_1', 'first_decoder_state_2', 'first_decoder_state_3', 'first_decoder_state_4', 'first_decoder_state_5'])
decode_next = get_function(model_path_prefix + '.decode_next.onnx',
                           ['next_logits', 'next_decoder_state_0', 'next_decoder_state_1', 'next_decoder_state_2', 'next_decoder_state_3', 'next_decoder_state_4', 'next_decoder_state_5'])

def greedy_decode(data_0):
    if len(data_0) == 1:  # special handling for the empty sentence, like Marian
        return data_0
    data_0_mask = [[[1.]]] * len(data_0)
    data_0_index_range = [[[float(t)]] for t in range(len(data_0))]
    #print(data_0, data_0_mask, data_0_index_range)
    max_len = len(data_0) * 3
    Y = []
    encoder_context_0, *_ = encode_source(data_0=data_0, data_0_mask=data_0_mask, data_0_posrange=data_0_index_range)
    logp, *out_decoder_states = decode_first(data_1_posrange=[[[float(0)]]],
                                             encoder_context_0=encoder_context_0, data_0_mask=data_0_mask)
    logp[:,:,:,unk_id] = -1e8  # suppress <unk>, like Marian
    Y.append(np.argmax(logp[0][0]))
    while Y[-1] != 0 and len(Y) < max_len:
        logp, *out_decoder_states = decode_next(prev_word=[Y[-1]], data_1_posrange=[[[float(len(Y))]]],
                                                encoder_context_0=encoder_context_0, data_0_mask=data_0_mask,
                                                decoder_state_0=out_decoder_states[0], decoder_state_1=out_decoder_states[1],
                                                decoder_state_2=out_decoder_states[2], decoder_state_3=out_decoder_states[3],
                                                decoder_state_4=out_decoder_states[4], decoder_state_5=out_decoder_states[5])
        logp[:,:,:,unk_id] = -1e8
        Y.append(np.argmax(logp[0][0]))
    return Y

start_time = time.time()
with open("C:/work/marian-dev/local/model/predictions.out-onnx-debug-sin-notrans-first100-d.tok", 'wt', encoding='utf-8') as out_f:
    for line in open("C:/work/marian-dev/local/model/predictions.in-first100.tok", encoding='utf-8').readlines():
        data = [word2id.get(w, unk_id) for w in (line.rstrip() + " </s>").split(' ') if w]
        Y = greedy_decode(data)
        print("input: ", ' '.join(id2word[x] for x in data))
        print("output:", ' '.join(id2word[y] for y in Y))
        print(' '.join(id2word[y] for y in Y[:-1]), file=out_f, flush=True)  # strip </s> for output to file
print("--- %s seconds ---" % (time.time() - start_time))

View File

@@ -0,0 +1,215 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Library for converting certain types of Marian models to a standalone ONNX model.

Because Marian and ONNX follow very different philosophies, a conversion is not possible
for all Marian models. Specifically, recurrent networks in the encoder are currently
not supported.

This works by running a Marian decode for 2 output steps and capturing the pieces of
the resulting graph that correspond to the encoder, the first decoding step, and the
second decoding step. The graph of the second decoding step can then be applied
repeatedly in order to decode a variable-length sequence.

Prerequisites:
```
pip install onnxruntime
git clone https://github.com/microsoft/onnxconverter-common.git
```
"""
import os, sys, inspect, subprocess
from typing import List, Dict, Optional, Callable

# get the Marian root path
_marian_root_path = os.path.dirname(inspect.getfile(inspect.currentframe())) + "/../.."

# we assume onnxconverter-common to be available next to the marian-dev repo; you must adjust this if needed
sys.path.append(_marian_root_path + "/../onnxconverter-common")
from onnxconverter_common.onnx_fx import Graph
from onnxconverter_common.onnx_fx import GraphFunctionType as _Ty

import onnxruntime as _ort

def _ort_apply_model(model, inputs):  # ORT execution is a callback so that Graph itself does not need to depend on ORT
    sess = _ort.InferenceSession(model.SerializeToString())
    return sess.run(None, inputs)
Graph.inference_runtime = _ort_apply_model
Graph.opset = 11
def export_marian_model_components(marian_model_path: str, marian_vocab_paths: List[str],
                                   marian_executable_path: Optional[str]=None) -> Dict[str,Graph]:
    """
    Export the Marian graph to a set of models.
    Args:
        marian_model_path: path to the Marian model to convert
        marian_vocab_paths: paths of vocab files (normally, this requires 2 entries, which may be identical)
        marian_executable_path: path to the Marian executable; defaults to THIS_SCRIPT_PATH/../../build/marian
    Returns:
        Dict of ONNX Graph instances corresponding to the pieces of the Marian model.
    """
    assert isinstance(marian_vocab_paths, list), "marian_vocab_paths must be a list of paths"
    # the default marian executable is found relative to the location of this script (Linux/CMake only)
    if marian_executable_path is None:
        marian_executable_path = _marian_root_path + "/build/marian"
    # partial models are written to /tmp
    output_path_stem = "/tmp/" + os.path.basename(marian_model_path)
    # exporting is done by invoking Marian via its command-line interface; models are written to tmp files
    command = marian_executable_path
    args = [
        "convert",
        "--from", marian_model_path,
        "--vocabs", *marian_vocab_paths,
        "--to", output_path_stem,
        "--export-as", "onnx-encode"
    ]
    subprocess.run([command] + args, check=True)
    # load the tmp files as Graph objects
    graph_names = ["encode_source", "decode_first", "decode_next"]  # Marian generates graphs with these names
    output_paths = [output_path_stem + "." + graph_name + ".onnx" for graph_name in graph_names]  # form pathnames under which Marian wrote the files
    res = { graph_name: Graph.load(output_path) for graph_name, output_path in zip(graph_names, output_paths) }
    # clean up after ourselves
    for output_path in output_paths:
        os.unlink(output_path)
    return res
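
# For reference, the subprocess call in export_marian_model_components() is
# equivalent to running Marian by hand, e.g. (paths are placeholders):
#   ./build/marian convert --from model.npz --vocabs src.wl tgt.wl \
#       --to /tmp/model.npz --export-as onnx-encode
# which writes /tmp/model.npz.{encode_source,decode_first,decode_next}.onnx.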
def combine_model_components_with_greedy_search(partial_models: Dict[str,Graph], num_decoder_layers: int):
    """
    Create an ONNX model that implements greedy search over the exported Marian pieces.
    Args:
        partial_models: models returned from export_marian_model_components()
        num_decoder_layers: must be specified since it presently cannot be inferred from the model files (e.g. 6)
    Returns:
        ONNX model that can be called as
        result_ids = greedy_search_fn(np.array(source_ids, dtype=np.int64), np.array([target_eos_id], dtype=np.int64))[0]
    """
    # load our partial functions
    # ONNX graph inputs and outputs are named but not ordered, so we must define the parameter order here.
    def define_parameter_order(graph, inputs, outputs):
        tmppath = "/tmp/tmpmodel.onnx"
        graph.save(tmppath)  # unfortunately, Graph.load() cannot load from bytes, so use a tmp file
        graph = Graph.load(tmppath, inputs=inputs, outputs=outputs)
        os.unlink(tmppath)
        return graph
    encode_source = define_parameter_order(partial_models["encode_source"],
                                           inputs=['data_0', 'data_0_mask', 'data_0_posrange'],  # define the order of arguments
                                           outputs=['encoder_context_0'])
    decode_first = define_parameter_order(partial_models["decode_first"],
                                          inputs=['data_1_posrange', 'encoder_context_0', 'data_0_mask'],
                                          outputs=['first_logits'] +
                                                  [f"first_decoder_state_{i}" for i in range(num_decoder_layers)])
    decode_next = define_parameter_order(partial_models["decode_next"],
                                         inputs=['prev_word', 'data_1_posrange', 'encoder_context_0', 'data_0_mask'] +
                                                [f"decoder_state_{i}" for i in range(num_decoder_layers)],
                                         outputs=['next_logits'] +
                                                 [f"next_decoder_state_{i}" for i in range(num_decoder_layers)])

    # create an ONNX graph that implements full greedy search
    # The greedy search is implemented via the @Graph.trace decorator, which allows us to
    # author the greedy search in Python, similar to @CNTK.Function and PyTorch trace-based jit.
    # The decorator executes greedy_search() below on a dummy input in order to generate an ONNX graph
    # by invoking operators from the onnx.fx library.
    # The partial functions exported from Marian are invoked (=inlined) by this.
    # The result is a full ONNX graph that implements greedy search using the Marian model.
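    # For example (an illustrative sketch, not part of this library), a traced function
    #   @Graph.trace(input_types=[_Ty.I(shape=['N'])], output_types=[_Ty.I(shape=['N'])], outputs="z")
    #   def double(x):
    #       return x + x
    # is compiled into an ONNX graph on first use and can then be called on numpy arrays.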
    @Graph.trace(
        input_types=[_Ty.I(shape=['N']), _Ty.I([1])],
        output_types=[_Ty.I(shape=['T'])],
        outputs="Y")
    def greedy_search(X, eos_id):
        """
        Args:
            X: sequence of input tokens, including EOS symbol, as integer indices into the input vocabulary
            eos_id: id of the EOS symbol in the output vocabulary
        """
        ox = X.ox
        data_0 = X
        data_0_shape = data_0.shape()
        data_0_mask = ox.constant_of_shape(data_0_shape, value=1.0)
        seq_len = data_0_shape[-1]
        data_0_index_range = ox.range([ox.constant(value=0), seq_len, ox.constant(value=1)]).cast(to=ox.float)
        data_0_index_range = ox.unsqueeze(data_0_index_range, axes=[1, 2])
        max_len = seq_len * 3

        encoder_context_0 = encode_source(data_0=data_0, data_0_mask=data_0_mask,
                                          data_0_posrange=data_0_index_range)

        y_len_0 = ox.constant(value=0.0)
        logp, *out_decoder_states = decode_first(data_1_posrange=y_len_0,
                                                 encoder_context_0=encoder_context_0, data_0_mask=data_0_mask)
        y_t = logp[0, 0, 0].argmax(axis=-1, keepdims=True)  # note: rank-1 tensor, not a scalar
        eos_token = eos_id + 0
        test_y_t = (y_t != eos_token)

        @Graph.trace(outputs=['ty_t', 'y_t_o', *(f'ods_{i}' for i in range(num_decoder_layers)), 'y_t_o2'],
                     output_types=[_Ty.b, _Ty.i] + [_Ty.f] * num_decoder_layers + [_Ty.i],
                     input_types=[_Ty.I([1]), _Ty.b, _Ty.i] + [_Ty.f] * num_decoder_layers)
        def loop_body(iteration_count, condition,  # these are not actually used inside
                      y_t,
                      out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5):
            # Currently, we do not support a variable number of arguments to the callable.
            # @TODO: We have the information from the type signature in Graph.trace(), so this should be possible.
            assert num_decoder_layers == 6, "Currently, decoder layers other than 6 require a manual code change"
            out_decoder_states = [out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5]
            """
            The loop body follows the requirements of ONNX Loop:

            "The graph run each iteration.
            It has 2+N inputs: (iteration_num, condition, loop carried dependencies...).
            It has 1+N+K outputs: (condition, loop carried dependencies..., scan_outputs...).
            Each scan_output is created by concatenating the value of the specified output value at the end of each iteration of the loop.
            It is an error if the dimensions or data type of these scan_outputs change across loop iterations."

            Inputs:
                iteration_num (not used by our function)
                test_y_t: condition (not used as an input)
                y_t, *out_decoder_states: N=(num_decoder_layers+1) loop-carried dependencies
            Outputs:
                test_y_t: condition; True if there is more to decode
                y_t, *out_decoder_states: N=(num_decoder_layers+1) loop-carried dependencies (same as in the Inputs section)
                y_t: K=1 scan output
            """
            pos = iteration_count + 1
            data_1_posrange = pos.cast(to=1).unsqueeze(axes=[0, 1, 2])
            logp, *out_decoder_states = decode_next(
                prev_word=y_t, data_1_posrange=data_1_posrange,
                encoder_context_0=encoder_context_0, data_0_mask=data_0_mask,
                **{f"decoder_state_{i}": out_decoder_states[i] for i in range(len(out_decoder_states))})
            y_t = logp[0, 0, 0].argmax(axis=-1, keepdims=True)
            test_y_t = (y_t != eos_token)
            return [test_y_t, y_t] + out_decoder_states + [y_t]

        # "Final N loop carried dependency values then K scan_outputs"
        ret_vals = ox.loop(max_len, test_y_t, loop_body,
                           inputs=[y_t] + out_decoder_states,
                           outputs=['gy_t_o', *[f"gods_{i}" for i in range(len(out_decoder_states))], 'greedy_out'])
        y = ret_vals[-1]  # scan_output

        # we must prepend the very first token
        Y = ox.concat([ox.unsqueeze(y_t), y], axis=0)  # note: y_t is a rank-1 tensor, not a scalar (ORT concat fails with scalars)
        return ox.squeeze(Y, axes=[1])

    greedy_search.to_model()  # this triggers the model tracing (which is lazy)
    return greedy_search
def apply_model(greedy_search_fn: Callable, source_ids: List[int], target_eos_id: int) -> List[int]:
    """
    Apply the model to an input sequence, e.g. run a translation.
    Args:
        greedy_search_fn: ONNX model created with combine_model_components_with_greedy_search()
        source_ids: list of source tokens, as indices into the source vocabulary, ending in the EOS symbol
        target_eos_id: id of the EOS symbol in the target vocabulary
    Returns:
        Result as a list of ids into the target vocabulary
    """
    import numpy as np
    Y = greedy_search_fn(
        np.array(source_ids, dtype=np.int64),
        np.array([target_eos_id], dtype=np.int64))[0]
    return Y

View File

@@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Example program demonstrating how to convert a Marian model using the marian_to_onnx library
to a self-contained ONNX model that implements greedy search.
"""
import os, sys
import marian_to_onnx as mo
# The following variables would normally be command-line arguments.
# We use constants here to keep it simple. Please just adjust these as needed.
my_dir = os.path.expanduser("~/")
marian_npz = my_dir + "model.npz.best-ce-mean-words.npz" # path to the Marian model to convert
num_decoder_layers = 6 # number of decoder layers
marian_vocs = [my_dir + "vocab_v1.wl"] * 2 # path to the vocabularies for source and target
onnx_model_path = my_dir + "model.npz.best-ce-mean-words.onnx" # resulting model gets written here
# export Marian model as multiple ONNX models
partial_models = mo.export_marian_model_components(marian_npz, marian_vocs)
# use the ONNX models in a greedy-search
# The result is a fully self-contained model that implements greedy search.
onnx_model = mo.combine_model_components_with_greedy_search(partial_models, num_decoder_layers)
# save as ONNX file
onnx_model.save(onnx_model_path)
# run a test sentence
Y = mo.apply_model(greedy_search_fn=onnx_model,
source_ids=[274, 35, 52, 791, 59, 4060, 6, 2688, 2, 7744, 9, 2128, 7, 2, 4695, 9, 950, 2561, 3, 0],
target_eos_id=0)
print(Y.shape, Y)
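
# A possible follow-up step (a sketch; it assumes vocab_v1.wl is a plain word list,
# one token per line, as in the test script that this PR removes):
id2word = {i: w.rstrip() for i, w in enumerate(open(my_dir + "vocab_v1.wl", encoding="utf-8"))}
print(" ".join(id2word[int(y)] for y in Y))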

View File

@@ -18,9 +18,10 @@
#pragma warning(disable : 4100 4125 4127 4244 4267 4512 4456 4510 4610 4800)
#endif
#ifdef __GNUC__
#pragma GCC diagnostic ignored "-Wunused-variable" // note: GCC <6.0 ignores this when inside push/pop
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-variable"
#pragma GCC diagnostic ignored "-Wsuggest-override"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
#define AuxillaryParseTableField AuxiliaryParseTableField // in protobuf 3.12, the generated source has a spelling error

View File

@@ -13,7 +13,7 @@
#error incompatible with your Protocol Buffer headers. Please update
#error your headers.
#endif
#if 3012000 < PROTOBUF_MIN_PROTOC_VERSION
#if 3012003 < PROTOBUF_MIN_PROTOC_VERSION
#error This file was generated by an older version of protoc which is
#error incompatible with your Protocol Buffer headers. Please
#error regenerate this file with a newer version of protoc.

View File

@@ -70,15 +70,17 @@ int main(int argc, char** argv) {
// added a flag for whether the weights need to be packed or not
graph->packAndSave(modelTo, configStr.str(), /* --gemm-type */ saveGemmType, Type::float32);
}
#ifdef USE_ONNX
else if (exportAs == "onnx-encode") {
#ifdef USE_ONNX
auto graph = New<ExpressionGraphONNXExporter>();
load(graph);
auto modelOptions = New<Options>(config)->with("vocabs", vocabPaths, "inference", true);
graph->exportToONNX(modelTo, modelOptions, vocabPaths);
}
#else
ABORT("--export-as onnx-encode requires Marian to be built with USE_ONNX=ON");
#endif // USE_ONNX
}
else
ABORT("Unknown --export-as value: {}", exportAs);

View File

@@ -622,7 +622,7 @@ namespace marian {
model.set_ir_version(IR_VERSION);
model.set_producer_name(producerName);
model.mutable_graph()->CopyFrom(graph);
#define OPSET_IMPORT_VERSION 9 // 9 is needed for some newer ops
#define OPSET_IMPORT_VERSION 11
model.add_opset_import()->set_version(OPSET_IMPORT_VERSION);
return model;
}
@@ -833,7 +833,7 @@
LOG(info, s);
}
// axis attribute
size_t axis;
size_t axis{};
std::vector<size_t> axes;
if (E::tryGetAxisAttribute<ConcatenateNodeOp>(expr, axis)// ||
//E::tryGetAxisAttribute<SelectNodeOp>(expr, axis)