mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-26 21:42:19 +03:00
Hypergraph output
This commit is contained in:
parent
c9a8238474
commit
40e8f2eca0
@ -109,6 +109,8 @@ def main(argv):
|
||||
n_best_out = None
|
||||
n_best_size = None
|
||||
n_best_distinct = False
|
||||
hg_ext = None
|
||||
hg_dir = None
|
||||
tmp_dir = '/tmp'
|
||||
xml_found = False
|
||||
xml_input = 'exclusive'
|
||||
@ -149,6 +151,15 @@ def main(argv):
|
||||
cmd = cmd[:i] + cmd[i + 4:]
|
||||
else:
|
||||
cmd = cmd[:i] + cmd[i + 3:]
|
||||
elif cmd[i] == '-output-search-graph-hypergraph':
|
||||
# cmd[i + 1] == true
|
||||
hg_ext = cmd[i + 2]
|
||||
if i + 3 < len(cmd) and cmd[i + 3][0] != '-':
|
||||
hg_dir = cmd[i + 3]
|
||||
cmd = cmd[:i] + cmd[i + 4:]
|
||||
else:
|
||||
hg_dir = 'hypergraph'
|
||||
cmd = cmd[:i] + cmd[i + 3:]
|
||||
elif cmd[i] == '-tmp':
|
||||
tmp_dir = cmd[i + 1]
|
||||
cmd = cmd[:i] + cmd[i + 2:]
|
||||
@ -251,6 +262,8 @@ def main(argv):
|
||||
sys.stderr.write('Batch size: {}\n'.format(batch_size))
|
||||
if n_best_out:
|
||||
sys.stderr.write('N-best list: {} ({}{})\n'.format(n_best_out, n_best_size, ', distinct' if n_best_distinct else ''))
|
||||
if hg_dir:
|
||||
sys.stderr.write('Hypergraph dir: {} ({})\n'.format(hg_dir, hg_ext))
|
||||
sys.stderr.write('Temp dir: {}\n'.format(work_dir))
|
||||
|
||||
# Accumulate seen lines
|
||||
@ -315,6 +328,11 @@ def main(argv):
|
||||
work_cmd.append(str(n_best_size))
|
||||
if n_best_distinct:
|
||||
work_cmd.append('distinct')
|
||||
if hg_dir:
|
||||
work_cmd.append('-output-search-graph-hypergraph')
|
||||
work_cmd.append('true')
|
||||
work_cmd.append(hg_ext)
|
||||
work_cmd.append(os.path.join(work_dir, 'hg.{}'.format(i)))
|
||||
in_file = os.path.join(work_dir, 'input.{}.xml'.format(i))
|
||||
out_file = os.path.join(work_dir, 'out.{}'.format(i))
|
||||
err_file = os.path.join(work_dir, 'err.{}'.format(i))
|
||||
@ -333,6 +351,15 @@ def main(argv):
|
||||
for line in open(os.path.join(work_dir, 'nbest.{}'.format(i)), 'r'):
|
||||
entry = line.partition(' ')
|
||||
out.write('{} {}'.format(int(entry[0]) + (i * batch_size), entry[2]))
|
||||
|
||||
# Gather hypergraphs
|
||||
if hg_dir:
|
||||
if not os.path.exists(hg_dir):
|
||||
os.mkdir(hg_dir)
|
||||
shutil.copy(os.path.join(work_dir, 'hg.0', 'weights'), os.path.join(hg_dir, 'weights'))
|
||||
for i in range(threads):
|
||||
for j in range(batch_size):
|
||||
shutil.copy(os.path.join(work_dir, 'hg.{}'.format(i), '{}.{}'.format(j, hg_ext)), os.path.join(hg_dir, '{}.{}'.format((i * batch_size) + j, hg_ext)))
|
||||
|
||||
# Gather stdout
|
||||
for i in range(threads):
|
||||
|
Loading…
Reference in New Issue
Block a user