mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-05 01:31:46 +03:00
Merged PR 14349: edited some comments in ONNX converter
edited some comments in ONNX converter
This commit is contained in:
parent
435aa9505e
commit
38bd181937
@ -6,18 +6,33 @@ Library for converting certain types of Marian models to a standalone ONNX model
|
||||
|
||||
Because Marian and ONNX use very different philosophies, a conversion is not possible
|
||||
for all possible Marian models. Specifically, currently we don't support recurrent
|
||||
networks in the encoder.
|
||||
networks in the encoder, and we can only decode with greedy search (not beam search).
|
||||
|
||||
This works by running a Marian decode for 2 output steps, and capturing pieces of
|
||||
this graph that correspond to the encoder, the first decoding steps, and the second
|
||||
decoding step. The graph of the second decoding step can be applied repeatedly in
|
||||
This works by running a Marian decode for 2 output steps, and capturing three pieces of
|
||||
Marian's internal graph that correspond to the encoder, the first decoding steps, and the
|
||||
second decoding step. The graph of the second decoding step can be applied repeatedly in
|
||||
order to decoder a variable-length sequence.
|
||||
|
||||
The three pieces are then composed with a greedy-search implementation, which is realized
|
||||
directly via ONNX operators. This is facilitated by the onnx_fx library. As of this writing,
|
||||
onnx_fx is still in experimental stage, and is not yet included in Release branches of
|
||||
the onnxconverter-common distribution. Hence, you must use the latest master branch, not
|
||||
the release.
|
||||
|
||||
The code below assumes that the onnxconverter_common repo is cloned next to the marian-dev
|
||||
repo, and that you use the standard CMake build process on Linux. If not, please make sure
|
||||
that the onnxconverter-common repo is included in PYTHONPATH, and you may need to pass the
|
||||
binary path of Marian to export_marian_model_components() explicitly.
|
||||
|
||||
Prerequisites:
|
||||
```
|
||||
pip install onnxruntime
|
||||
git clone https://github.com/microsoft/onnxconverter-common.git
|
||||
```
|
||||
You will also need to compile Marian with -DUSE_ONNX=ON.
|
||||
|
||||
Known issue: If the number of decoder layers is not 6, you need to manually adjust one
|
||||
line of code in loop_body() below.
|
||||
"""
|
||||
|
||||
import os, sys, inspect, subprocess
|
||||
@ -26,7 +41,7 @@ from typing import List, Dict, Optional, Callable
|
||||
# get the Marian root path
|
||||
_marian_root_path = os.path.dirname(inspect.getfile(inspect.currentframe())) + "/../.."
|
||||
|
||||
# we assume onnxconverter-common to be available next to the marian-dev repo; you must adjust this if needed
|
||||
# we assume onnxconverter-common to be available next to the marian-dev repo; you may need to adjust this
|
||||
sys.path.append(_marian_root_path + "/../onnxconverter-common")
|
||||
from onnxconverter_common.onnx_fx import Graph
|
||||
from onnxconverter_common.onnx_fx import GraphFunctionType as _Ty
|
||||
@ -48,7 +63,7 @@ def export_marian_model_components(marian_model_path: str, marian_vocab_paths: L
|
||||
marian_vocab_paths: paths of vocab files (normally, this requires 2 entries, which may be identical)
|
||||
marian_executable_path: path to Marian executable; will default to THIS_SCRIPT_PATH/../../build/marian
|
||||
Returns:
|
||||
Dict of ONNX Graph instances corresponding to pieces of Marian models.
|
||||
Dict of onnx_fx.Graph instances corresponding to pieces of the Marian model.
|
||||
"""
|
||||
assert isinstance(marian_vocab_paths, list), "marian_vocab_paths must be a list of paths"
|
||||
# default marian executable is found relative to location of this script (Linux/CMake only)
|
||||
@ -66,7 +81,7 @@ def export_marian_model_components(marian_model_path: str, marian_vocab_paths: L
|
||||
"--export-as", "onnx-encode"
|
||||
]
|
||||
subprocess.run([command] + args, check=True)
|
||||
# load the tmp files into Python bytes objects
|
||||
# load the tmp files into onnx_fx.Graph objects
|
||||
graph_names = ["encode_source", "decode_first", "decode_next"] # Marian generates graphs with these names
|
||||
output_paths = [output_path_stem + "." + graph_name + ".onnx" for graph_name in graph_names] # form pathnames under which Marian wrote the files
|
||||
res = { graph_name: Graph.load(output_path) for graph_name, output_path in zip(graph_names, output_paths) }
|
||||
@ -76,7 +91,7 @@ def export_marian_model_components(marian_model_path: str, marian_vocab_paths: L
|
||||
return res
|
||||
|
||||
|
||||
def combine_model_components_with_greedy_search(partial_models: Dict[str,bytes], num_decoder_layers: int):
|
||||
def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes], num_decoder_layers: int):
|
||||
"""
|
||||
Create an ONNX model that implements greedy search over the exported Marian pieces.
|
||||
|
||||
@ -91,7 +106,7 @@ def combine_model_components_with_greedy_search(partial_models: Dict[str,bytes],
|
||||
# ONNX graph inputs and outputs are named but not ordered. Therefore, we must define the parameter order here.
|
||||
def define_parameter_order(graph, inputs, outputs):
|
||||
tmppath = "/tmp/tmpmodel.onnx"
|
||||
graph.save(tmppath) # unfortunately, Graph.load() cannot load from bytes, so use a tmp file
|
||||
graph.save(tmppath) # unfortunately, Graph.load() cannot load from another Graph, so use a tmp file
|
||||
graph = Graph.load(tmppath, inputs=inputs, outputs=outputs)
|
||||
os.unlink(tmppath)
|
||||
return graph
|
||||
@ -109,7 +124,7 @@ def combine_model_components_with_greedy_search(partial_models: Dict[str,bytes],
|
||||
[f"next_decoder_state_{i}" for i in range(num_decoder_layers)])
|
||||
|
||||
# create an ONNX graph that implements full greedy search
|
||||
# The greedy search is implemented via the @Graph.trace decorator, which allows us to
|
||||
# The greedy search is implemented via the @onnx_fx.Graph.trace decorator, which allows us to
|
||||
# author the greedy search in Python, similar to @CNTK.Function and PyTorch trace-based jit.
|
||||
# The decorator executes greedy_search() below on a dummy input in order to generate an ONNX graph
|
||||
# via invoking operators from the onnx.fx library.
|
||||
@ -151,7 +166,7 @@ def combine_model_components_with_greedy_search(partial_models: Dict[str,bytes],
|
||||
def loop_body(iteration_count, condition, # these are not actually used inside
|
||||
y_t,
|
||||
out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5):
|
||||
# Currently, we do not support variable number of arguments to the callable.
|
||||
# @BUGBUG: Currently, we do not support variable number of arguments to the callable.
|
||||
# @TODO: We have the information from the type signature in Graph.trace(), so this should be possible.
|
||||
assert num_decoder_layers == 6, "Currently, decoder layers other than 6 require a manual code change"
|
||||
out_decoder_states = [out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5]
|
||||
@ -197,9 +212,10 @@ def combine_model_components_with_greedy_search(partial_models: Dict[str,bytes],
|
||||
return greedy_search
|
||||
|
||||
|
||||
def apply_model(greedy_search_fn: Callable, source_ids: List[int], target_eos_id: int) -> List[int]:
|
||||
def apply_model(greedy_search_fn: Graph, source_ids: List[int], target_eos_id: int) -> List[int]:
|
||||
"""
|
||||
Apply model to an input sequence, e.g. run translation.
|
||||
This function is meant for quick testing, and as an example of how to invoke the final graph.
|
||||
|
||||
Args:
|
||||
greedy_search_fn: ONNX model created with combine_model_components_with_greedy_search()\
|
||||
|
@ -22,7 +22,7 @@ partial_models = mo.export_marian_model_components(marian_npz, marian_vocs)
|
||||
|
||||
# use the ONNX models in a greedy-search
|
||||
# The result is a fully self-contained model that implements greedy search.
|
||||
onnx_model = mo.combine_model_components_with_greedy_search(partial_models, num_decoder_layers)
|
||||
onnx_model = mo.compose_model_components_with_greedy_search(partial_models, num_decoder_layers)
|
||||
|
||||
# save as ONNX file
|
||||
onnx_model.save(onnx_model_path)
|
||||
|
Loading…
Reference in New Issue
Block a user