mosesdecoder/scripts/training/create_count_tables.py
Rico Sennrich 908c006e32 online combination of multiple phrase tables
- creates a virtual phrase table at decoding time based on a vector of component models and a combination algorithm
  - linear interpolation or instance weighting
  - two possible component model types supported so far: 0 (in-memory) or 12 (compact)
  - weights can be set in config, and overridden on a sentence level through the mosesserver API
  - online optimization (perplexity minimization) using dlib and xmlrpc-c call
2013-04-22 13:21:59 +02:00

154 lines
5.2 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich <sennrich [AT] cl.uzh.ch>
# This script creates tables that store phrase pair frequencies rather than probabilities.
# These count tables can be used for a delayed, online computation of the original phrase translation features
# The benefit is that models can be combined quickly, with the same results as if we had trained a model on the concatenation of all data (except for differences in word alignment).
# Also, each model can be given a weight, which is applied to all frequencies of the model for the combination.
# Note: the input phrase table must have alignment information (--phrase-word-alignment in train-model.perl);
# it must be unsmoothed;
# additionally, the online model type requires the lexical counts files lex.counts.e2f and lex.counts.f2e to be in the same folder (--write-lexical-counts in train-model.perl)
# The results may differ from training on the concatenation of all data due to differences in word alignment, and rounding errors.
from __future__ import unicode_literals
import sys
import os
import gzip
from tempfile import NamedTemporaryFile
from subprocess import Popen, PIPE
# Validate the command line only when run as a script, so that importing this
# module (e.g. to reuse handle_file or create_count_lines) does not exit the
# interpreter. A usage error exits with a non-zero status so that calling
# pipelines notice the failure.
if __name__ == '__main__' and not 3 <= len(sys.argv) <= 4:
    sys.stderr.write('Usage: ' + sys.argv[0] + ' in_file out_path [prune_count]\nThis script will create the files out_path/count-table.gz and out_path/count-table-target.gz\n')
    sys.exit(1)
def handle_file(filename, action, fileobj=None, mode='r'):
    """Open or close a file, supporting stdin/stdout, plain and gzipped files.

    action == 'open':
        Return a file object for ``filename``. Mode 'r' is promoted to 'rb'
        so readers always get bytes.
        - '-' maps to stdin for read modes and stdout for write/append modes.
          On Python 3 the binary ``.buffer`` layer is returned.
        - When reading a path that does not exist but ``filename + '.gz'``
          does, the gzipped file is opened instead. If neither exists, an
          error is printed and the process exits with status 1.
        - Paths ending in '.gz' are opened with gzip.open.
    action == 'close':
        Close ``fileobj`` unless ``filename`` is '-' (std streams stay open).

    Any other action, or 'close' on '-', returns None.
    """
    if action == 'open':
        if mode == 'r':
            mode = 'rb'
        if filename == '-':
            stream = sys.stdin if mode.startswith('r') else sys.stdout
            # Python 3 std streams are text; their .buffer yields bytes, which
            # is what the byte-oriented parsing in this script expects.
            # Python 2 streams have no .buffer and are already byte streams.
            return getattr(stream, 'buffer', stream)
        if mode == 'rb' and not os.path.exists(filename):
            if os.path.exists(filename + '.gz'):
                filename = filename + '.gz'
            else:
                sys.stderr.write('Error: unable to open file. ' + filename + ' - aborting.\n')
                sys.exit(1)
        if filename.endswith('.gz'):
            return gzip.open(filename, mode)
        return open(filename, mode)
    elif action == 'close' and filename != '-':
        fileobj.close()
def sort_and_uniq(infile, outfile):
    """Write the bytewise-sorted, de-duplicated lines of ``infile`` to ``outfile``, gzipped.

    Equivalent to ``LC_ALL=C sort infile | uniq | gzip -c > outfile``.

    Raises:
        subprocess.CalledProcessError: if any pipeline stage exits non-zero
        (e.g. ``sort`` cannot read ``infile`` or runs out of temp space).
    """
    from subprocess import CalledProcessError

    cmd = ['sort', infile]
    # Keep the caller's environment (PATH, and TMPDIR for sort's temporary
    # files) and only force byte-order collation, for both sort and uniq.
    env = dict(os.environ)
    env['LC_ALL'] = 'C'
    sys.stderr.write('Executing: LC_ALL=C ' + ' '.join(cmd) + ' | uniq | gzip -c > ' + outfile + '\n')
    # gzip -c already compresses the stream, so the destination must be a
    # plain binary file, not a GzipFile wrapper.
    with open(outfile, 'wb') as fobj:
        p_sort = Popen(cmd, env=env, stdout=PIPE)
        p_uniq = Popen(['uniq'], stdin=p_sort.stdout, stdout=PIPE, env=env)
        # Drop the parent's copies of the pipe ends so an upstream stage gets
        # SIGPIPE if a downstream stage exits early.
        p_sort.stdout.close()
        p_compress = Popen(['gzip', '-c'], stdin=p_uniq.stdout, stdout=fobj, env=env)
        p_uniq.stdout.close()
        stages = [(p_compress, ['gzip', '-c']), (p_uniq, ['uniq']), (p_sort, cmd)]
        results = [(proc.wait(), args) for proc, args in stages]
    for returncode, args in results:
        if returncode != 0:
            raise CalledProcessError(returncode, args)
def create_count_lines(fobj, countobj, countobj_target, prune=0):
    """Convert phrase-table lines into count-table lines.

    Each input line is expected to look like
    ``src ||| tgt ||| scores ||| alignment ||| counts ...`` (bytes).
    The counts field supplies ``ft fs [fst]`` (presumably the target, source
    and joint counts). When ``fst`` is absent, it is reconstructed as
    ``round(scores[0] * ft)``. The score field is then replaced by
    ``fst ft fs``, and the rewritten line is written to ``countobj``.

    For every input line, ``tgt ||| X ||| ft ||| |||`` is written to
    ``countobj_target``. This happens even for lines removed by pruning.

    If ``prune`` is non-zero, lines are buffered per contiguous run of equal
    source phrases. Only ``prune`` of them are emitted per run, via
    write_batch, ranked by ``fst``.

    A dot is written to stderr every 100000 lines as progress output.
    Both ``countobj`` and ``countobj_target`` are closed before returning.
    """
    sep = b' ||| '
    # Sentinel for "no source phrase seen yet": never equal to a bytes key,
    # and avoids comparing bytes against a (unicode_literals) text string.
    source = None
    store_lines = set()
    original_pos = 0
    for i, line in enumerate(fobj):
        if not i % 100000:
            sys.stderr.write('.')
        fields = line.split(sep)
        current_source = fields[0]
        scores = fields[2].split()
        comments = fields[4].split()
        ft = comments[0]
        fs = comments[1]
        try:
            fst = comments[2]
        except IndexError:
            fst = str(int(round(float(scores[0]) * float(ft)))).encode()
        fields[2] = b' '.join([fst, ft, fs])
        count_line = sep.join(fields)
        if prune:
            if current_source != source:
                write_batch(store_lines, countobj, prune)
                source = current_source
                store_lines = set()
                original_pos = 0
            store_lines.add((float(fst), original_pos, count_line))
            original_pos += 1
        else:
            countobj.write(count_line)
        # target count file
        tline = sep.join([fields[1], b'X', ft]) + b' ||| |||\n'  # if you use string formatting to make this look nicer, you may break Python 3 compatibility.
        countobj_target.write(tline)
    if prune:
        write_batch(store_lines, countobj, prune)
    countobj.close()
    countobj_target.close()
def write_batch(store_lines, outfile, prune):
    """Emit the ``prune`` best-scoring buffered lines, restoring input order.

    ``store_lines`` holds ``(score, original_pos, line)`` tuples. They are
    ranked by descending tuple order, so ties on score favour the larger
    position. The top ``prune`` entries are written to ``outfile`` sorted by
    ``original_pos``.
    """
    ranked = sorted(store_lines, reverse=True)
    kept = ranked[:prune]
    kept.sort(key=lambda entry: entry[1])  # back to original input order
    for _score, _pos, text in kept:
        outfile.write(text)
if __name__ == '__main__':
    # Script entry point: in_file out_path [prune_count]
    if len(sys.argv) == 4:
        prune = int(sys.argv[3])
    else:
        prune = 0
    fileobj = handle_file(sys.argv[1], 'open')
    out_path = sys.argv[2]
    count_table_file = gzip.open(os.path.join(out_path, 'count-table.gz'), 'w')
    count_table_target_file = os.path.join(out_path, 'count-table-target.gz')
    # Unsorted target counts go to a temp file first; they are sorted,
    # de-duplicated and compressed into count-table-target.gz afterwards.
    count_table_target_file_temp = NamedTemporaryFile(delete=False)
    try:
        sys.stderr.write('Creating temporary file for unsorted target counts file: ' + count_table_target_file_temp.name + '\n')
        create_count_lines(fileobj, count_table_file, count_table_target_file_temp, prune)
        count_table_target_file_temp.close()
        sys.stderr.write('Finished writing, now re-sorting and compressing target count file\n')
        sort_and_uniq(count_table_target_file_temp.name, count_table_target_file)
    finally:
        # Single cleanup path for success and failure. close() is a no-op on
        # already-closed files, and the existence check keeps a missing temp
        # file from masking the original exception.
        handle_file(sys.argv[1], 'close', fileobj)
        count_table_file.close()
        count_table_target_file_temp.close()
        if os.path.exists(count_table_target_file_temp.name):
            os.remove(count_table_target_file_temp.name)
    sys.stderr.write('Done\n')