Hypergraph output

This commit is contained in:
mjdenkowski 2014-11-03 09:16:12 -05:00
parent c9a8238474
commit 40e8f2eca0

View File

@ -109,6 +109,8 @@ def main(argv):
n_best_out = None
n_best_size = None
n_best_distinct = False
hg_ext = None
hg_dir = None
tmp_dir = '/tmp'
xml_found = False
xml_input = 'exclusive'
@ -149,6 +151,15 @@ def main(argv):
cmd = cmd[:i] + cmd[i + 4:]
else:
cmd = cmd[:i] + cmd[i + 3:]
elif cmd[i] == '-output-search-graph-hypergraph':
# cmd[i + 1] == true
hg_ext = cmd[i + 2]
if i + 3 < len(cmd) and cmd[i + 3][0] != '-':
hg_dir = cmd[i + 3]
cmd = cmd[:i] + cmd[i + 4:]
else:
hg_dir = 'hypergraph'
cmd = cmd[:i] + cmd[i + 3:]
elif cmd[i] == '-tmp':
tmp_dir = cmd[i + 1]
cmd = cmd[:i] + cmd[i + 2:]
@ -251,6 +262,8 @@ def main(argv):
sys.stderr.write('Batch size: {}\n'.format(batch_size))
if n_best_out:
sys.stderr.write('N-best list: {} ({}{})\n'.format(n_best_out, n_best_size, ', distinct' if n_best_distinct else ''))
if hg_dir:
sys.stderr.write('Hypergraph dir: {} ({})\n'.format(hg_dir, hg_ext))
sys.stderr.write('Temp dir: {}\n'.format(work_dir))
# Accumulate seen lines
@ -315,6 +328,11 @@ def main(argv):
work_cmd.append(str(n_best_size))
if n_best_distinct:
work_cmd.append('distinct')
if hg_dir:
work_cmd.append('-output-search-graph-hypergraph')
work_cmd.append('true')
work_cmd.append(hg_ext)
work_cmd.append(os.path.join(work_dir, 'hg.{}'.format(i)))
in_file = os.path.join(work_dir, 'input.{}.xml'.format(i))
out_file = os.path.join(work_dir, 'out.{}'.format(i))
err_file = os.path.join(work_dir, 'err.{}'.format(i))
@ -333,6 +351,15 @@ def main(argv):
for line in open(os.path.join(work_dir, 'nbest.{}'.format(i)), 'r'):
entry = line.partition(' ')
out.write('{} {}'.format(int(entry[0]) + (i * batch_size), entry[2]))
# Gather hypergraphs
if hg_dir:
if not os.path.exists(hg_dir):
os.mkdir(hg_dir)
shutil.copy(os.path.join(work_dir, 'hg.0', 'weights'), os.path.join(hg_dir, 'weights'))
for i in range(threads):
for j in range(batch_size):
shutil.copy(os.path.join(work_dir, 'hg.{}'.format(i), '{}.{}'.format(j, hg_ext)), os.path.join(hg_dir, '{}.{}'.format((i * batch_size) + j, hg_ext)))
# Gather stdout
for i in range(threads):