Merged PR 14349: edited some comments in ONNX converter

edited some comments in ONNX converter
This commit is contained in:
Frank Seide 2020-07-24 18:12:28 +00:00
parent 435aa9505e
commit 38bd181937
2 changed files with 29 additions and 13 deletions

View File

@ -6,18 +6,33 @@ Library for converting certain types of Marian models to a standalone ONNX model
Because Marian and ONNX use very different philosophies, a conversion is not possible
for all possible Marian models. Specifically, currently we don't support recurrent
networks in the encoder.
networks in the encoder, and we can only decode with greedy search (not beam search).
This works by running a Marian decode for 2 output steps, and capturing pieces of
this graph that correspond to the encoder, the first decoding steps, and the second
decoding step. The graph of the second decoding step can be applied repeatedly in
This works by running a Marian decode for 2 output steps, and capturing three pieces of
Marian's internal graph that correspond to the encoder, the first decoding steps, and the
second decoding step. The graph of the second decoding step can be applied repeatedly in
order to decoder a variable-length sequence.
The three pieces are then composed with a greedy-search implementation, which is realized
directly via ONNX operators. This is facilitated by the onnx_fx library. As of this writing,
onnx_fx is still in experimental stage, and is not yet included in Release branches of
the onnxconverter-common distribution. Hence, you must use the latest master branch, not
the release.
The code below assumes that the onnxconverter_common repo is cloned next to the marian-dev
repo, and that you use the standard CMake build process on Linux. If not, please make sure
that the onnxconverter-common repo is included in PYTHONPATH, and you may need to pass the
binary path of Marian to export_marian_model_components() explicitly.
Prerequisites:
```
pip install onnxruntime
git clone https://github.com/microsoft/onnxconverter-common.git
```
You will also need to compile Marian with -DUSE_ONNX=ON.
Known issue: If the number of decoder layers is not 6, you need to manually adjust one
line of code in loop_body() below.
"""
import os, sys, inspect, subprocess
@ -26,7 +41,7 @@ from typing import List, Dict, Optional, Callable
# get the Marian root path
_marian_root_path = os.path.dirname(inspect.getfile(inspect.currentframe())) + "/../.."
# we assume onnxconverter-common to be available next to the marian-dev repo; you must adjust this if needed
# we assume onnxconverter-common to be available next to the marian-dev repo; you may need to adjust this
sys.path.append(_marian_root_path + "/../onnxconverter-common")
from onnxconverter_common.onnx_fx import Graph
from onnxconverter_common.onnx_fx import GraphFunctionType as _Ty
@ -48,7 +63,7 @@ def export_marian_model_components(marian_model_path: str, marian_vocab_paths: L
marian_vocab_paths: paths of vocab files (normally, this requires 2 entries, which may be identical)
marian_executable_path: path to Marian executable; will default to THIS_SCRIPT_PATH/../../build/marian
Returns:
Dict of ONNX Graph instances corresponding to pieces of Marian models.
Dict of onnx_fx.Graph instances corresponding to pieces of the Marian model.
"""
assert isinstance(marian_vocab_paths, list), "marian_vocab_paths must be a list of paths"
# default marian executable is found relative to location of this script (Linux/CMake only)
@ -66,7 +81,7 @@ def export_marian_model_components(marian_model_path: str, marian_vocab_paths: L
"--export-as", "onnx-encode"
]
subprocess.run([command] + args, check=True)
# load the tmp files into Python bytes objects
# load the tmp files into onnx_fx.Graph objects
graph_names = ["encode_source", "decode_first", "decode_next"] # Marian generates graphs with these names
output_paths = [output_path_stem + "." + graph_name + ".onnx" for graph_name in graph_names] # form pathnames under which Marian wrote the files
res = { graph_name: Graph.load(output_path) for graph_name, output_path in zip(graph_names, output_paths) }
@ -76,7 +91,7 @@ def export_marian_model_components(marian_model_path: str, marian_vocab_paths: L
return res
def combine_model_components_with_greedy_search(partial_models: Dict[str,bytes], num_decoder_layers: int):
def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes], num_decoder_layers: int):
"""
Create an ONNX model that implements greedy search over the exported Marian pieces.
@ -91,7 +106,7 @@ def combine_model_components_with_greedy_search(partial_models: Dict[str,bytes],
# ONNX graph inputs and outputs are named but not ordered. Therefore, we must define the parameter order here.
def define_parameter_order(graph, inputs, outputs):
tmppath = "/tmp/tmpmodel.onnx"
graph.save(tmppath) # unfortunately, Graph.load() cannot load from bytes, so use a tmp file
graph.save(tmppath) # unfortunately, Graph.load() cannot load from another Graph, so use a tmp file
graph = Graph.load(tmppath, inputs=inputs, outputs=outputs)
os.unlink(tmppath)
return graph
@ -109,7 +124,7 @@ def combine_model_components_with_greedy_search(partial_models: Dict[str,bytes],
[f"next_decoder_state_{i}" for i in range(num_decoder_layers)])
# create an ONNX graph that implements full greedy search
# The greedy search is implemented via the @Graph.trace decorator, which allows us to
# The greedy search is implemented via the @onnx_fx.Graph.trace decorator, which allows us to
# author the greedy search in Python, similar to @CNTK.Function and PyTorch trace-based jit.
# The decorator executes greedy_search() below on a dummy input in order to generate an ONNX graph
# via invoking operators from the onnx.fx library.
@ -151,7 +166,7 @@ def combine_model_components_with_greedy_search(partial_models: Dict[str,bytes],
def loop_body(iteration_count, condition, # these are not actually used inside
y_t,
out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5):
# Currently, we do not support variable number of arguments to the callable.
# @BUGBUG: Currently, we do not support variable number of arguments to the callable.
# @TODO: We have the information from the type signature in Graph.trace(), so this should be possible.
assert num_decoder_layers == 6, "Currently, decoder layers other than 6 require a manual code change"
out_decoder_states = [out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5]
@ -197,9 +212,10 @@ def combine_model_components_with_greedy_search(partial_models: Dict[str,bytes],
return greedy_search
def apply_model(greedy_search_fn: Callable, source_ids: List[int], target_eos_id: int) -> List[int]:
def apply_model(greedy_search_fn: Graph, source_ids: List[int], target_eos_id: int) -> List[int]:
"""
Apply model to an input sequence, e.g. run translation.
This function is meant for quick testing, and as an example of how to invoke the final graph.
Args:
greedy_search_fn: ONNX model created with combine_model_components_with_greedy_search()\

View File

@ -22,7 +22,7 @@ partial_models = mo.export_marian_model_components(marian_npz, marian_vocs)
# use the ONNX models in a greedy-search
# The result is a fully self-contained model that implements greedy search.
onnx_model = mo.combine_model_components_with_greedy_search(partial_models, num_decoder_layers)
onnx_model = mo.compose_model_components_with_greedy_search(partial_models, num_decoder_layers)
# save as ONNX file
onnx_model.save(onnx_model_path)