diff --git a/scripts/onnx/marian_to_onnx.py b/scripts/onnx/marian_to_onnx.py
index 6d7aed81..f015279f 100644
--- a/scripts/onnx/marian_to_onnx.py
+++ b/scripts/onnx/marian_to_onnx.py
@@ -45,7 +45,10 @@ _marian_root_path = os.path.dirname(inspect.getfile(inspect.currentframe())) + "
 sys.path.append(_marian_root_path + "/../onnxconverter-common")
 from onnxconverter_common.onnx_fx import Graph
 from onnxconverter_common.onnx_fx import GraphFunctionType as _Ty
+from onnxconverter_common import optimize_onnx_graph
 import onnxruntime as _ort
+from onnxruntime import quantization
+
 def _ort_apply_model(model, inputs):  # ORT execution is a callback so that Graph itself does not need to depend on ORT
     sess = _ort.InferenceSession(model.SerializeToString())
     return sess.run(None, inputs)
@@ -53,6 +56,22 @@ Graph.inference_runtime = _ort_apply_model
 Graph.opset = 11
 
+def _optimize_graph_in_place(graph: Graph):
+    # @TODO: This should really be a method on onnx_fx.Graph.
+    g = graph._oxml.graph
+    g_opt = optimize_onnx_graph(
+        onnx_nodes=g.node,              # the node list of the onnx model
+        nchw_inputs=None,               # the name list of the inputs that need to be transposed as NCHW
+        inputs=g.input,                 # the model inputs
+        outputs=g.output,               # the model outputs
+        initializers=g.initializer,     # the model initializers
+        stop_initializers=None,         # 'stop' optimization on these initializers
+        model_value_info=g.value_info,  # the model value_info
+        model_name=g.name,              # the internal name of the model
+        target_opset=graph.opset)
+    graph._oxml.graph.CopyFrom(g_opt)
+
+
 def export_marian_model_components(marian_model_path: str, marian_vocab_paths: List[str],
                                    marian_executable_path: Optional[str]=None) -> Dict[str,Graph]:
     """
@@ -85,13 +104,33 @@ def export_marian_model_components(marian_model_path: str, marian_vocab_paths: L
     graph_names = ["encode_source", "decode_first", "decode_next"]  # Marian generates graphs with these names
     output_paths = [output_path_stem + "." + graph_name + ".onnx" for graph_name in graph_names]  # form pathnames under which Marian wrote the files
     res = { graph_name: Graph.load(output_path) for graph_name, output_path in zip(graph_names, output_paths) }
+    # optimize the partial models in place, as Marian may not have used the most efficient way of expressing all operations
+    for graph_name in res.keys():
+        _optimize_graph_in_place(res[graph_name])
     # clean up after ourselves
     for output_path in output_paths:
         os.unlink(output_path)
     return res
 
-def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes], num_decoder_layers: int):
+def quantize_models_in_place(partial_models: Dict[str,Graph], to_bits: int=8):
+    """
+    Quantize the partial models in place.
+
+    Args:
+        partial_models: models returned from export_marian_model_components()
+        to_bits: number of bits to quantize to; currently only 8 is supported
+    """
+    for graph_name in partial_models.keys():  # quantize each partial model
+        partial_models[graph_name]._oxml = quantization.quantize(
+            partial_models[graph_name]._oxml,
+            nbits=to_bits,
+            quantization_mode=quantization.QuantizationMode.IntegerOps,
+            symmetric_weight=True,
+            force_fusions=True)
+
+
+def compose_model_components_with_greedy_search(partial_models: Dict[str,Graph], num_decoder_layers: int):
     """
     Create an ONNX model that implements greedy search over the exported Marian pieces.
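The hunk above adds the two new helpers: _optimize_graph_in_place() rewrites an exported subgraph through onnxconverter_common's optimize_onnx_graph(), and quantize_models_in_place() replaces each Graph's protobuf with an 8-bit IntegerOps quantization. A minimal sketch of how the pieces are meant to chain, with placeholder paths (the real driver is marian_to_onnx_example.py further down):

    import marian_to_onnx as mo

    # export_marian_model_components() now optimizes each partial graph in place
    partial_models = mo.export_marian_model_components("model.npz", ["vocab.wl"] * 2)
    mo.quantize_models_in_place(partial_models, to_bits=8)  # 8 is the only supported width
    onnx_model = mo.compose_model_components_with_greedy_search(partial_models, num_decoder_layers=6)
    onnx_model.save("model.quantized.onnx")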
@@ -102,6 +141,7 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
         ONNX model that can be called as
         result_ids = greedy_search_fn(np.array(source_ids, dtype=np.int64), np.array([target_eos_id], dtype=np.int64))[0]
     """
+    decoder_state_dim = num_decoder_layers * 2  # each decoder layer has two state variables
     # load our partial functions
     # ONNX graph inputs and outputs are named but not ordered. Therefore, we must define the parameter order here.
     def define_parameter_order(graph, inputs, outputs):
@@ -116,12 +156,12 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
     decode_first = define_parameter_order(partial_models["decode_first"],
                                           inputs=['data_1_posrange', 'encoder_context_0', 'data_0_mask'],
                                           outputs=['first_logits'] +
-                                                  [f"first_decoder_state_{i}" for i in range(num_decoder_layers)])
+                                                  [f"first_decoder_state_{i}" for i in range(decoder_state_dim)])
     decode_next = define_parameter_order(partial_models["decode_next"],
                                          inputs=['prev_word', 'data_1_posrange', 'encoder_context_0', 'data_0_mask'] +
-                                                [f"decoder_state_{i}" for i in range(num_decoder_layers)],
+                                                [f"decoder_state_{i}" for i in range(decoder_state_dim)],
                                          outputs=['next_logits'] +
-                                                 [f"next_decoder_state_{i}" for i in range(num_decoder_layers)])
+                                                 [f"next_decoder_state_{i}" for i in range(decoder_state_dim)])
 
     # create an ONNX graph that implements full greedy search
     # The greedy search is implemented via the @onnx_fx.Graph.trace decorator, which allows us to
@@ -160,16 +200,18 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
         eos_token = eos_id + 0
         test_y_t = (y_t != eos_token)
 
-        @Graph.trace(outputs=['ty_t', 'y_t_o', *(f'ods_{i}' for i in range(num_decoder_layers)), 'y_t_o2'],
-                     output_types=[_Ty.b, _Ty.i] + [_Ty.f] * 6 + [_Ty.i],
-                     input_types=[_Ty.I([1]), _Ty.b, _Ty.i] + [_Ty.f] * num_decoder_layers)
+        @Graph.trace(outputs=['ty_t', 'y_t_o', *(f'ods_{i}' for i in range(decoder_state_dim)), 'y_t_o2'],
+                     output_types=[_Ty.b, _Ty.i] + [_Ty.f] * decoder_state_dim + [_Ty.i],
+                     input_types=[_Ty.I([1]), _Ty.b, _Ty.i] + [_Ty.f] * decoder_state_dim)
         def loop_body(iteration_count, condition,  # these are not actually used inside
                       y_t,
-                      out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5):
+                      out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5,
+                      out_decoder_states_6, out_decoder_states_7, out_decoder_states_8, out_decoder_states_9, out_decoder_states_10, out_decoder_states_11):
             # @BUGBUG: Currently, we do not support variable number of arguments to the callable.
             # @TODO: We have the information from the type signature in Graph.trace(), so this should be possible.
-            assert num_decoder_layers == 6, "Currently, decoder layers other than 6 require a manual code change"
-            out_decoder_states = [out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5]
+            assert decoder_state_dim == 12, "Currently, decoder layers other than 6 require a manual code change"
+            out_decoder_states = [out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5,
+                                  out_decoder_states_6, out_decoder_states_7, out_decoder_states_8, out_decoder_states_9, out_decoder_states_10, out_decoder_states_11]
             """
             Loop body follows the requirements of ONNX Loop:
 
@@ -182,11 +224,11 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
             Inputs:
                 iteration_num (not used by our function)
                 test_y_t: condition (not used as an input)
-                y_t, *out_decoder_states: N=(num_decoder_layers+1) loop-carried dependencies
+                y_t, *out_decoder_states: N=(decoder_state_dim+1) loop-carried dependencies
             Outputs:
                 test_y_t: condition, return True if there is more to decode
-                y_t, *out_decoder_states: N=(num_decoder_layers+1) loop-carried dependencies (same as in the Inputs section)
+                y_t, *out_decoder_states: N=(decoder_state_dim+1) loop-carried dependencies (same as in the Inputs section)
                 y_t: K=1 outputs
             """
             pos = iteration_count + 1
@@ -209,6 +251,9 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
         Y = ox.concat([ox.unsqueeze(y_t), y], axis=0)  # note: y_t are rank-1 tensors, not scalars (ORT concat fails with scalars)
         return ox.squeeze(Y, axes=[1])
     greedy_search.to_model()  # this triggers the model tracing (which is lazy)
+    # optimize the final model as well
+    # @BUGBUG: This leads to a malformed or hanging model.
+    #_optimize_graph_in_place(greedy_search)
     return greedy_search
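A note on the renaming that runs through this file: each decoder layer carries two state tensors, so decoder_state_dim = 2 * num_decoder_layers, and the ONNX Loop body threads 12 loop-carried state tensors for the 6-layer decoder rather than 6. A plain-Python analogue of the Loop contract that loop_body() implements, for orientation only (the names are hypothetical, and step() stands in for the traced decode_next call):

    def greedy_loop(step, y_0, states_0, eos_id, max_len=100):
        y_t, states = y_0, list(states_0)     # 1 + decoder_state_dim loop-carried dependencies
        out = []
        for iteration_count in range(max_len):
            pos = iteration_count + 1         # as computed in loop_body()
            y_t, states = step(pos, y_t, states)
            out.append(y_t)                   # K=1 scan output per iteration
            if y_t == eos_id:                 # test_y_t, the loop condition
                break
        return out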
diff --git a/scripts/onnx/marian_to_onnx_example.py b/scripts/onnx/marian_to_onnx_example.py
index d0912bb3..d247fa8d 100644
--- a/scripts/onnx/marian_to_onnx_example.py
+++ b/scripts/onnx/marian_to_onnx_example.py
@@ -10,16 +10,21 @@ import os, sys
 import marian_to_onnx as mo
 
 # The following variables would normally be command-line arguments.
-# We use constants here to keep it simple. Please just adjust these as needed.
-my_dir = os.path.expanduser("~/")
-marian_npz = my_dir + "model.npz.best-ce-mean-words.npz"  # path to the Marian model to convert
-num_decoder_layers = 6                                    # number of decoder layers
-marian_vocs = [my_dir + "vocab_v1.wl"] * 2                # paths to the vocabularies for source and target
-onnx_model_path = my_dir + "model.npz.best-ce-mean-words.onnx"  # resulting model gets written here
+# We use constants here to keep it simple. They reflect an example use; you must adjust these to your setup.
+my_dir = os.path.expanduser("~/young/wngt 2019/")
+marian_npz = my_dir + "model.base.npz"            # path to the Marian model to convert
+num_decoder_layers = 6                            # number of decoder layers
+marian_vocs = [my_dir + "en-de.wl"] * 2           # paths to the vocabularies for source and target
+onnx_model_path = my_dir + "model.base.opt.onnx"  # resulting model gets written here
+quantize_to_bits = 8                              # None for no quantization
 
 # export Marian model as multiple ONNX models
 partial_models = mo.export_marian_model_components(marian_npz, marian_vocs)
 
+# quantize if desired
+if quantize_to_bits:
+    mo.quantize_models_in_place(partial_models, to_bits=quantize_to_bits)
+
 # use the ONNX models in a greedy search
 # The result is a fully self-contained model that implements greedy search.
 onnx_model = mo.compose_model_components_with_greedy_search(partial_models, num_decoder_layers)
@@ -28,7 +33,15 @@ onnx_model = mo.compose_model_components_with_greedy_search(partial_models, num_
 onnx_model.save(onnx_model_path)
 
 # run a test sentence
+w2is = [{ word.rstrip(): id for id, word in enumerate(open(voc_path, "r").readlines()) } for voc_path in marian_vocs]
+i2ws = [{ id: tok for tok, id in w2i.items() } for w2i in w2is]
+src_tokens = "▁Republican ▁leaders ▁justifie d ▁their ▁policy ▁by ▁the ▁need ▁to ▁combat ▁electoral ▁fraud ▁.".split()
+src_ids = [w2is[0][tok] for tok in src_tokens]
+print(src_tokens)
+print(src_ids)
 Y = mo.apply_model(greedy_search_fn=onnx_model,
-                   source_ids=[274, 35, 52, 791, 59, 4060, 6, 2688, 2, 7744, 9, 2128, 7, 2, 4695, 9, 950, 2561, 3, 0],
-                   target_eos_id=0)
+                   source_ids=src_ids + [w2is[0]["</s>"]],
+                   target_eos_id=w2is[1]["</s>"])
 print(Y.shape, Y)
+tgt_tokens = [i2ws[1][y] for y in Y]
+print(" ".join(tgt_tokens))
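As a possible sanity check after onnx_model.save(), the composed file can also be reloaded with a plain onnxruntime session, bypassing the onnx_fx wrapper, in the same way _ort_apply_model() runs the serialized model. The sketch below continues the example script's variables; the assumption that the first graph input takes the source ids and the second the EOS id mirrors the apply_model() call above and should be verified against sess.get_inputs():

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(onnx_model_path)
    input_names = [i.name for i in sess.get_inputs()]  # confirm the actual input names/order first
    feeds = {
        input_names[0]: np.array(src_ids + [w2is[0]["</s>"]], dtype=np.int64),  # assumed: source ids
        input_names[1]: np.array([w2is[1]["</s>"]], dtype=np.int64),            # assumed: target EOS id
    }
    ids = sess.run(None, feeds)[0]
    print(" ".join(i2ws[1][int(y)] for y in ids))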