Merged PR 14593: bug fix in ONNX exporter
The number of decoder layers was confused with the number of decoder output states: each decoder layer carries two state variables, so the greedy-search graph must thread 2 * num_decoder_layers states, not num_decoder_layers.
parent 83a0af231a
commit 00f10c2288
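For context, the arithmetic behind the fix as a minimal standalone sketch (not part of the commit; the constant 6 matches the example configuration used throughout this diff):

    num_decoder_layers = 6
    decoder_state_dim = num_decoder_layers * 2   # two state variables per decoder layer

    # Old code sized the state lists by the layer count -- 6 names, too few:
    old_state_names = [f"decoder_state_{i}" for i in range(num_decoder_layers)]
    # Fixed code sizes them by the number of output states -- 12 names:
    new_state_names = [f"decoder_state_{i}" for i in range(decoder_state_dim)]
    assert len(old_state_names) == 6 and len(new_state_names) == 12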
@@ -45,7 +45,10 @@ _marian_root_path = os.path.dirname(inspect.getfile(inspect.currentframe())) + "
 sys.path.append(_marian_root_path + "/../onnxconverter-common")
 from onnxconverter_common.onnx_fx import Graph
 from onnxconverter_common.onnx_fx import GraphFunctionType as _Ty
+from onnxconverter_common import optimize_onnx_graph
 import onnxruntime as _ort
+from onnxruntime import quantization
+
 def _ort_apply_model(model, inputs):  # ORT execution is a callback so that Graph itself does not need to depend on ORT
     sess = _ort.InferenceSession(model.SerializeToString())
     return sess.run(None, inputs)
@@ -53,6 +56,22 @@ Graph.inference_runtime = _ort_apply_model
 Graph.opset = 11
 
 
+def _optimize_graph_in_place(graph: Graph):
+    # @TODO: This should really be methods on onnx_fx.Graph.
+    g = graph._oxml.graph
+    g_opt = optimize_onnx_graph(
+        onnx_nodes=g.node,               # the onnx node list in onnx model.
+        nchw_inputs=None,                # the name list of the inputs needed to be transposed as NCHW
+        inputs=g.input,                  # the model input
+        outputs=g.output,                # the model output
+        initializers=g.initializer,      # the model initializers
+        stop_initializers=None,          # 'stop' optimization on these initializers
+        model_value_info=g.value_info,   # the model value_info
+        model_name=g.name,               # the internal name of model
+        target_opset=graph.opset)
+    graph._oxml.graph.CopyFrom(g_opt)
+
+
 def export_marian_model_components(marian_model_path: str, marian_vocab_paths: List[str],
                                    marian_executable_path: Optional[str]=None) -> Dict[str,Graph]:
     """
@@ -85,13 +104,33 @@ def export_marian_model_components(marian_model_path: str, marian_vocab_paths: L
     graph_names = ["encode_source", "decode_first", "decode_next"]  # Marian generates graphs with these names
     output_paths = [output_path_stem + "." + graph_name + ".onnx" for graph_name in graph_names]  # form pathnames under which Marian wrote the files
     res = { graph_name: Graph.load(output_path) for graph_name, output_path in zip(graph_names, output_paths) }
+    # optimize the partial models in place, as Marian may not have used the most optimal way of expressing all operations
+    for graph_name in res.keys():
+        _optimize_graph_in_place(res[graph_name])
     # clean up after ourselves
     for output_path in output_paths:
         os.unlink(output_path)
     return res
 
 
-def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes], num_decoder_layers: int):
+def quantize_models_in_place(partial_models: Dict[str,Graph], to_bits: int=8):
+    """
+    Quantize the partial models in place.
+
+    Args:
+        partial_models: models returned from export_marian_model_components()
+        to_bits: number of bits to quantize to, currently only supports 8
+    """
+    for graph_name in partial_models.keys():  # quantize each partial model
+        partial_models[graph_name]._oxml = quantization.quantize(
+            partial_models[graph_name]._oxml,
+            nbits=to_bits,
+            quantization_mode=quantization.QuantizationMode.IntegerOps,
+            symmetric_weight=True,
+            force_fusions=True)
+
+
+def compose_model_components_with_greedy_search(partial_models: Dict[str,Graph], num_decoder_layers: int):
     """
     Create an ONNX model that implements greedy search over the exported Marian pieces.
 
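For orientation, a minimal sketch of how the new helper slots into the export pipeline (placeholder paths; the real call site appears in the example script at the end of this diff):

    import marian_to_onnx as mo

    # export, quantize in place, then compose -- mirrors the example script below
    partial_models = mo.export_marian_model_components("model.npz", ["vocab.wl"] * 2)
    mo.quantize_models_in_place(partial_models, to_bits=8)  # 8 bits is the only supported width
    onnx_model = mo.compose_model_components_with_greedy_search(partial_models, num_decoder_layers=6)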
@@ -102,6 +141,7 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
     ONNX model that can be called as
         result_ids = greedy_search_fn(np.array(source_ids, dtype=np.int64), np.array([target_eos_id], dtype=np.int64))[0]
     """
+    decoder_state_dim = num_decoder_layers * 2  # each decoder layer has two state variables
     # load our partial functions
     # ONNX graph inputs and outputs are named but not ordered. Therefore, we must define the parameter order here.
     def define_parameter_order(graph, inputs, outputs):
@@ -116,12 +156,12 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
     decode_first = define_parameter_order(partial_models["decode_first"],
                                           inputs=['data_1_posrange', 'encoder_context_0', 'data_0_mask'],
                                           outputs=['first_logits'] +
-                                                  [f"first_decoder_state_{i}" for i in range(num_decoder_layers)])
+                                                  [f"first_decoder_state_{i}" for i in range(decoder_state_dim)])
     decode_next = define_parameter_order(partial_models["decode_next"],
                                          inputs=['prev_word', 'data_1_posrange', 'encoder_context_0', 'data_0_mask'] +
-                                                [f"decoder_state_{i}" for i in range(num_decoder_layers)],
+                                                [f"decoder_state_{i}" for i in range(decoder_state_dim)],
                                          outputs=['next_logits'] +
-                                                 [f"next_decoder_state_{i}" for i in range(num_decoder_layers)])
+                                                 [f"next_decoder_state_{i}" for i in range(decoder_state_dim)])
 
     # create an ONNX graph that implements full greedy search
     # The greedy search is implemented via the @onnx_fx.Graph.trace decorator, which allows us to
@@ -160,16 +200,18 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
         eos_token = eos_id + 0
         test_y_t = (y_t != eos_token)
 
-        @Graph.trace(outputs=['ty_t', 'y_t_o', *(f'ods_{i}' for i in range(num_decoder_layers)), 'y_t_o2'],
-                     output_types=[_Ty.b, _Ty.i] + [_Ty.f] * 6 + [_Ty.i],
-                     input_types=[_Ty.I([1]), _Ty.b, _Ty.i] + [_Ty.f] * num_decoder_layers)
+        @Graph.trace(outputs=['ty_t', 'y_t_o', *(f'ods_{i}' for i in range(decoder_state_dim)), 'y_t_o2'],
+                     output_types=[_Ty.b, _Ty.i] + [_Ty.f] * decoder_state_dim + [_Ty.i],
+                     input_types=[_Ty.I([1]), _Ty.b, _Ty.i] + [_Ty.f] * decoder_state_dim)
         def loop_body(iteration_count, condition,  # these are not actually used inside
                       y_t,
-                      out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5):
+                      out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5,
+                      out_decoder_states_6, out_decoder_states_7, out_decoder_states_8, out_decoder_states_9, out_decoder_states_10, out_decoder_states_11):
             # @BUGBUG: Currently, we do not support variable number of arguments to the callable.
             # @TODO: We have the information from the type signature in Graph.trace(), so this should be possible.
-            assert num_decoder_layers == 6, "Currently, decoder layers other than 6 require a manual code change"
-            out_decoder_states = [out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5]
+            assert decoder_state_dim == 12, "Currently, decoder layers other than 6 require a manual code change"
+            out_decoder_states = [out_decoder_states_0, out_decoder_states_1, out_decoder_states_2, out_decoder_states_3, out_decoder_states_4, out_decoder_states_5,
+                                  out_decoder_states_6, out_decoder_states_7, out_decoder_states_8, out_decoder_states_9, out_decoder_states_10, out_decoder_states_11]
             """
             Loop body follows the requirements of ONNX Loop:
 
@@ -182,11 +224,11 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
             Inputs:
                 iteration_num (not used by our function)
                 test_y_t: condition (not used as an input)
-                y_t, *out_decoder_states: N=(num_decoder_layers+1) loop-carried dependencies
+                y_t, *out_decoder_states: N=(decoder_state_dim+1) loop-carried dependencies
 
             Outputs:
                 test_y_t: condition, return True if there is more to decode
-                y_t, *out_decoder_states: N=(num_decoder_layers+1) loop-carried dependencies (same as in the Inputs section)
+                y_t, *out_decoder_states: N=(decoder_state_dim+1) loop-carried dependencies (same as in the Inputs section)
                 y_t: K=1 outputs
             """
             pos = iteration_count + 1
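As a reference for the ONNX Loop contract this docstring describes, a small framework-free sketch (a hypothetical helper, not part of the commit): the body receives the iteration number, the condition, and the N loop-carried values, and must return the condition, the N updated values, and K per-iteration outputs.

    import numpy as np

    def run_loop(body, initial_carried, max_trip_count):
        # Reference semantics of ONNX Loop: N carried values, K scan outputs.
        n = len(initial_carried)         # N; in this diff, y_t plus the 12 decoder states
        carried = list(initial_carried)
        scan_outputs = []                # collects the K=1 scan output (y_t) of each iteration
        cond = np.array(True)
        for i in range(max_trip_count):
            if not cond:
                break
            out = body(np.array(i), cond, *carried)  # body(iteration_num, condition, *carried)
            cond = out[0]                            # first output: continue-condition
            carried = list(out[1:1 + n])             # next N outputs: updated carried values
            scan_outputs.append(out[1 + n:])         # remaining K outputs: scanned per iteration
        return carried, scan_outputs

    # toy usage: count down from 3, carrying one value and scanning it each step
    final, scans = run_loop(lambda i, c, x: (x > 1, x - 1, x), [np.array(3)], max_trip_count=10)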
@@ -209,6 +251,9 @@ def compose_model_components_with_greedy_search(partial_models: Dict[str,bytes],
         Y = ox.concat([ox.unsqueeze(y_t), y], axis=0)  # note: y_t are rank-1 tensors, not scalars (ORT concat fails with scalars)
         return ox.squeeze(Y, axes=[1])
     greedy_search.to_model()  # this triggers the model tracing (which is lazy)
+    # optimize the final model as well
+    # @BUGBUG: This leads to a malformed or hanging model.
+    #_optimize_graph_in_place(greedy_search)
     return greedy_search
 
 
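(The remaining hunks apply to the second file in this commit, the example export script, which drives the module above via import marian_to_onnx as mo.)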
@@ -10,16 +10,21 @@ import os, sys
 import marian_to_onnx as mo
 
 # The following variables would normally be command-line arguments.
-# We use constants here to keep it simple. Please just adjust these as needed.
-my_dir = os.path.expanduser("~/")
-marian_npz = my_dir + "model.npz.best-ce-mean-words.npz"  # path to the Marian model to convert
+# We use constants here to keep it simple. They reflect an example use. You must adjust these.
+my_dir = os.path.expanduser("~/young/wngt 2019/")
+marian_npz = my_dir + "model.base.npz"  # path to the Marian model to convert
 num_decoder_layers = 6  # number of decoder layers
-marian_vocs = [my_dir + "vocab_v1.wl"] * 2  # path to the vocabularies for source and target
-onnx_model_path = my_dir + "model.npz.best-ce-mean-words.onnx"  # resulting model gets written here
+marian_vocs = [my_dir + "en-de.wl"] * 2  # path to the vocabularies for source and target
+onnx_model_path = my_dir + "model.base.opt.onnx"  # resulting model gets written here
+quantize_to_bits = 8  # None for no quantization
 
 # export Marian model as multiple ONNX models
 partial_models = mo.export_marian_model_components(marian_npz, marian_vocs)
+
+# quantize if desired
+if quantize_to_bits:
+    mo.quantize_models_in_place(partial_models, to_bits=quantize_to_bits)
 
 # use the ONNX models in a greedy-search
 # The result is a fully self-contained model that implements greedy search.
 onnx_model = mo.compose_model_components_with_greedy_search(partial_models, num_decoder_layers)
@@ -28,7 +33,15 @@ onnx_model = mo.compose_model_components_with_greedy_search(partial_models, num_
 onnx_model.save(onnx_model_path)
 
 # run a test sentence
+w2is = [{ word.rstrip(): id for id, word in enumerate(open(voc_path, "r").readlines()) } for voc_path in marian_vocs]
+i2ws = [{ id: tok for tok, id in w2i.items() } for w2i in w2is]
+src_tokens = "▁Republican ▁leaders ▁justifie d ▁their ▁policy ▁by ▁the ▁need ▁to ▁combat ▁electoral ▁fraud ▁.".split()
+src_ids = [w2is[0][tok] for tok in src_tokens]
+print(src_tokens)
+print(src_ids)
 Y = mo.apply_model(greedy_search_fn=onnx_model,
-                   source_ids=[274, 35, 52, 791, 59, 4060, 6, 2688, 2, 7744, 9, 2128, 7, 2, 4695, 9, 950, 2561, 3, 0],
-                   target_eos_id=0)
+                   source_ids=src_ids + [w2is[0]["</s>"]],
+                   target_eos_id=w2is[1]["</s>"])
 print(Y.shape, Y)
+tgt_tokens = [i2ws[1][y] for y in Y]
+print(" ".join(tgt_tokens))