Merged PR 14593: bug fix in ONNX exporter

The number of layers was confused with the number of output states.
This commit is contained in:
Frank Seide 2020-08-06 03:35:53 +00:00
parent 83a0af231a
commit 00f10c2288
2 changed files with 78 additions and 20 deletions

View File

@ -45,7 +45,10 @@ _marian_root_path = os.path.dirname(inspect.getfile(inspect.currentframe())) + "
sys.path.append(_marian_root_path + "/../onnxconverter-common")
from onnxconverter_common.onnx_fx import Graph
from onnxconverter_common.onnx_fx import GraphFunctionType as _Ty
from onnxconverter_common import optimize_onnx_graph
import onnxruntime as _ort
from onnxruntime import quantization
def _ort_apply_model(model, inputs): # ORT execution is a callback so that Graph itself does not need to depend on ORT
sess = _ort.InferenceSession(model.SerializeToString())
return sess.run(None, inputs)
@ -53,6 +56,22 @@ Graph.inference_runtime = _ort_apply_model
Graph.opset = 11
def _optimize_graph_in_place(graph: Graph):
# @TODO: This should really be methods on onnx_fx.Graph.
g = graph._oxml.graph
g_opt = optimize_onnx_graph(
onnx_nodes=g.node, # the onnx node list in onnx model.
nchw_inputs=None, # the name list of the inputs needed to be transposed as NCHW
inputs=g.input, # the model input
outputs=g.output, # the model output
initializers=g.initializer, # the model initializers
stop_initializers=None, # 'stop' optimization on these initializers
model_value_info=g.value_info, # the model value_info
model_name=g.name, # the internal name of model
target_opset=graph.opset)
graph._oxml.graph.CopyFrom(g_opt)
def export_marian_model_components(marian_model_path: str, marian_vocab_paths: List[str],
marian_executable_path: Optional[str]=None) -> Dict[str,Graph]:
"""
@ -85,13 +104,33 @@ def export_marian_model_components(marian_model_path: str, marian_vocab_paths: L
graph_names = ["encode_source", "decode_first", "decode_next"] # Marian generates graphs with these names
output_paths = [output_path_stem + "." + graph_name + ".onnx" for graph_name in graph_names] # form pathnames under which Marian wrote the files
res = { graph_name: Graph.load(output_path) for graph_name, output_path in zip(graph_names, output_paths) }
# optimize the partial models in place, as Marian may not have used the most optimal way of expressing all operations
for graph_name in res.keys():
_optimize_graph_in_place(res[graph_name])
# clean up after ourselves
for output_path in output_paths:
os.unlink(output_path)
return res
def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes], num_decoder_layers: int):
def quantize_models_in_place(partial_models: Dict[str,Graph], to_bits: int=8):
"""
Quantize the partial models in place.
Args:
partial_models: models returned from export_marian_model_components()
to_bits: number of bits to quantize to, currently only supports 8
"""
for graph_name in partial_models.keys(): # quantize each partial model
partial_models[graph_name]._oxml = quantization.quantize(
partial_models[graph_name]._oxml,
nbits=to_bits,
quantization_mode=quantization.QuantizationMode.IntegerOps,
symmetric_weight=True,
force_fusions=True)
def compose_model_components_with_greedy_search(partial_models: Dict[str,Graph], num_decoder_layers: int):
"""
Create an ONNX model that implements greedy search over the exported Marian pieces.
@ -102,6 +141,7 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
ONNX model that can be called as
result_ids = greedy_search_fn(np.array(source_ids, dtype=np.int64), np.array([target_eos_id], dtype=np.int64))[0]
"""
decoder_state_dim = num_decoder_layers * 2 # each decoder has two state variables
# load our partial functions
# ONNX graph inputs and outputs are named but not ordered. Therefore, we must define the parameter order here.
def define_parameter_order(graph, inputs, outputs):
@ -116,12 +156,12 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
decode_first = define_parameter_order(partial_models["decode_first"],
inputs=['data_1_posrange', 'encoder_context_0', 'data_0_mask'],
outputs=['first_logits'] +
[f"first_decoder_state_{i}" for i in range(num_decoder_layers)])
[f"first_decoder_state_{i}" for i in range(decoder_state_dim)])
decode_next = define_parameter_order(partial_models["decode_next"],
inputs=['prev_word', 'data_1_posrange', 'encoder_context_0', 'data_0_mask'] +
[f"decoder_state_{i}" for i in range(num_decoder_layers)],
[f"decoder_state_{i}" for i in range(decoder_state_dim)],
outputs=['next_logits'] +
[f"next_decoder_state_{i}" for i in range(num_decoder_layers)])
[f"next_decoder_state_{i}" for i in range(decoder_state_dim)])
# create an ONNX graph that implements full greedy search
# The greedy search is implemented via the @onnx_fx.Graph.trace decorator, which allows us to
@ -160,16 +200,18 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
eos_token = eos_id + 0
test_y_t = (y_t != eos_token)
@Graph.trace(outputs=['ty_t', 'y_t_o', *(f'ods_{i}' for i in range(num_decoder_layers)), 'y_t_o2'],
output_types=[_Ty.b, _Ty.i] + [_Ty.f] * 6 + [_Ty.i],
input_types=[_Ty.I([1]), _Ty.b, _Ty.i] + [_Ty.f] * num_decoder_layers)
@Graph.trace(outputs=['ty_t', 'y_t_o', *(f'ods_{i}' for i in range(decoder_state_dim)), 'y_t_o2'],
output_types=[_Ty.b, _Ty.i] + [_Ty.f] * decoder_state_dim + [_Ty.i],
input_types=[_Ty.I([1]), _Ty.b, _Ty.i] + [_Ty.f] * decoder_state_dim)
def loop_body(iteration_count, condition, # these are not actually used inside
y_t,
out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5):
out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5,
out_decoder_states_6, out_decoder_states_7, out_decoder_states_8, out_decoder_states_9, out_decoder_states_10, out_decoder_states_11):
# @BUGBUG: Currently, we do not support variable number of arguments to the callable.
# @TODO: We have the information from the type signature in Graph.trace(), so this should be possible.
assert num_decoder_layers == 6, "Currently, decoder layers other than 6 require a manual code change"
out_decoder_states = [out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5]
assert decoder_state_dim == 12, "Currently, decoder layers other than 6 require a manual code change"
out_decoder_states = [out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5,
out_decoder_states_6, out_decoder_states_7, out_decoder_states_8, out_decoder_states_9, out_decoder_states_10, out_decoder_states_11]
"""
Loop body follows the requirements of ONNX Loop:
@ -182,11 +224,11 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
Inputs:
iteration_num (not used by our function)
test_y_t: condition (not used as an input)
y_t, *out_decoder_states: N=(num_decoder_layers+1) loop-carried dependencies
y_t, *out_decoder_states: N=(decoder_state_dim+1) loop-carried dependencies
Outputs:
test_y_t: condition, return True if there is more to decode
y_t, *out_decoder_states: N=(num_decoder_layers+1) loop-carried dependencies (same as in the Inputs section)
y_t, *out_decoder_states: N=(decoder_state_dim+1) loop-carried dependencies (same as in the Inputs section)
y_t: K=1 outputs
"""
pos = iteration_count + 1
@ -209,6 +251,9 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
Y = ox.concat([ox.unsqueeze(y_t), y], axis=0) # note: y_t are rank-1 tensors, not scalars (ORT concat fails with scalars)
return ox.squeeze(Y, axes=[1])
greedy_search.to_model() # this triggers the model tracing (which is lazy)
# optimize the final model as well
# @BUGBUG: This leads to a malformed or hanging model.
#_optimize_graph_in_place(greedy_search)
return greedy_search

View File

@ -10,16 +10,21 @@ import os, sys
import marian_to_onnx as mo
# The following variables would normally be command-line arguments.
# We use constants here to keep it simple. Please just adjust these as needed.
my_dir = os.path.expanduser("~/")
marian_npz = my_dir + "model.npz.best-ce-mean-words.npz" # path to the Marian model to convert
num_decoder_layers = 6 # number of decoder layers
marian_vocs = [my_dir + "vocab_v1.wl"] * 2 # path to the vocabularies for source and target
onnx_model_path = my_dir + "model.npz.best-ce-mean-words.onnx" # resulting model gets written here
# We use constants here to keep it simple. They reflect an example use. You must adjust these.
my_dir = os.path.expanduser("~/young/wngt 2019/")
marian_npz = my_dir + "model.base.npz" # path to the Marian model to convert
num_decoder_layers = 6 # number of decoder layers
marian_vocs = [my_dir + "en-de.wl"] * 2 # path to the vocabularies for source and target
onnx_model_path = my_dir + "model.base.opt.onnx" # resulting model gets written here
quantize_to_bits = 8 # None for no quantization
# export Marian model as multiple ONNX models
partial_models = mo.export_marian_model_components(marian_npz, marian_vocs)
# quantize if desired
if quantize_to_bits:
mo.quantize_models_in_place(partial_models, to_bits=quantize_to_bits)
# use the ONNX models in a greedy-search
# The result is a fully self-contained model that implements greedy search.
onnx_model = mo.compose_model_components_with_greedy_search(partial_models, num_decoder_layers)
@ -28,7 +33,15 @@ onnx_model = mo.compose_model_components_with_greedy_search(partial_models, num_
onnx_model.save(onnx_model_path)
# run a test sentence
w2is = [{ word.rstrip(): id for id, word in enumerate(open(voc_path, "r").readlines()) } for voc_path in marian_vocs]
i2ws = [{ id: tok for tok, id in w2i.items() } for w2i in w2is]
src_tokens = "▁Republican ▁leaders ▁justifie d ▁their ▁policy ▁by ▁the ▁need ▁to ▁combat ▁electoral ▁fraud ▁.".split()
src_ids = [w2is[0][tok] for tok in src_tokens]
print(src_tokens)
print(src_ids)
Y = mo.apply_model(greedy_search_fn=onnx_model,
source_ids=[274, 35, 52, 791, 59, 4060, 6, 2688, 2, 7744, 9, 2128, 7, 2, 4695, 9, 950, 2561, 3, 0],
target_eos_id=0)
source_ids=src_ids + [w2is[0]["</s>"]],
target_eos_id=w2is[1]["</s>"])
print(Y.shape, Y)
tgt_tokens = [i2ws[1][y] for y in Y]
print(" ".join(tgt_tokens))