#!/usr/bin/env python
# mosesdecoder/contrib/promix/nbest.py
import gzip
import os
import re
import numpy as np
import sys
from bleu import BleuScorer
from collections import OrderedDict

# Edit to set moses python path
sys.path.append(os.path.dirname(__file__) + "/../python")
import moses.dictree as binpt


class DataFormatException(Exception):
    pass


class Hypothesis:
    def __init__(self, text, fv, segments=False):
        self.alignment = []  # only stored for segmented hypos
        self.tokens = []  # only stored for segmented hypos
        if not segments:
            self.text = text
        else:
            # Recover the segmentation. Alignment entries are triples of
            # (source-start, source-end, target-end), where segment end
            # positions are 1 beyond the last token.
            align_re = re.compile(r"\|(\d+)-(\d+)\|")
            for token in text.split():
                match = align_re.match(token)
                if match:
                    self.alignment.append(
                        (int(match.group(1)), 1 + int(match.group(2)), len(self.tokens)))
                else:
                    self.tokens.append(token)
            self.text = " ".join(self.tokens)
            if not self.alignment:
                raise DataFormatException("Expected segmentation information not found in nbest")
        self.fv = np.array(fv)
        self.score = 0

    def __str__(self):
        return "{text=%s fv=%s score=%5.4f}" % (self.text, str(self.fv), self.score)


class NBestList:
    def __init__(self, id):
        self.id = id
        self.hyps = []


# Maps feature ids (short feature names) to their [start, end) positions
# in the feature vector
_feature_index = {}


def set_feature_start(name, index):
    indexes = _feature_index.get(name, [index, 0])
    indexes[0] = index
    _feature_index[name] = indexes


def set_feature_end(name, index):
    indexes = _feature_index.get(name, [0, index])
    indexes[1] = index
    _feature_index[name] = indexes


def get_feature_index(name):
    return _feature_index.get(name, [0, 0])
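

# Illustrative sketch (not part of the original module): after an nbest file
# has been parsed, get_feature_index gives the slice of the feature vector
# belonging to one feature name, shown here for hypothetical "lm" and "tm"
# features.
def _example_feature_index():
    set_feature_start("lm", 0)
    set_feature_end("lm", 1)
    set_feature_start("tm", 1)
    set_feature_end("tm", 5)
    fv = np.array([-42.0, -1.0, -2.0, -3.0, -4.0])
    start, end = get_feature_index("tm")
    return fv[start:end]  # the four tm scores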


def get_nbests(nbest_file, segments=False):
    """Iterate through nbest lists, yielding one NBestList per sentence"""
    if nbest_file.endswith("gz"):
        fh = gzip.GzipFile(nbest_file)
    else:
        fh = open(nbest_file)
    lineno = 0
    nbest = None
    for line in fh:
        lineno += 1
        fields = line.split(" ||| ")
        if len(fields) != 4:
            raise DataFormatException("nbest(%d): %s" % (lineno, line))
        (id, text, scores, total) = fields
        if nbest and nbest.id != id:
            yield nbest
            nbest = None
        if not nbest:
            nbest = NBestList(id)
        fv = []
        score_name = None
        for score in scores.split():
            if score.endswith(":"):
                score = score[:-1]
                if score_name:
                    set_feature_end(score_name, len(fv))
                score_name = score
                set_feature_start(score_name, len(fv))
            else:
                fv.append(float(score))
        if score_name:
            set_feature_end(score_name, len(fv))
        hyp = Hypothesis(text[:-1], fv, segments)
        nbest.hyps.append(hyp)
    if nbest:
        yield nbest
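

# Illustrative sketch (not part of the original module): nbest files are
# assumed to be in the usual moses format, one hypothesis per line,
#
#   id ||| hypothesis text ||| name: score score ... ||| total
#
# for example (hypothetical content):
#
#   0 ||| the small house  ||| lm: -42.0 tm: -1.0 -2.0 -3.0 -4.0 ||| -52.0
#
# The file name below is hypothetical.
def _example_get_nbests(nbest_file="dev.nbest.gz"):
    for nbest in get_nbests(nbest_file):
        print("%s: %d hypotheses" % (nbest.id, len(nbest.hyps)))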


def get_scores(score_data_file):
    """Iterate through the score data, yielding a list of scores (one per
    hypothesis) for each sentence"""
    scorer = BleuScorer()
    fh = open(score_data_file)
    score_vectors = None
    for line in fh:
        if line.startswith("SCORES_TXT_BEGIN"):
            score_vectors = []
        elif line.startswith("SCORES_TXT_END"):
            scores = [scorer.score(score_vector) for score_vector in score_vectors]
            yield scores
        else:
            score_vectors.append([float(i) for i in line[:-1].split()])
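

# Illustrative sketch (not part of the original module): the score data file
# is assumed to be mert extractor output, one block of sufficient-statistics
# lines per sentence between SCORES_TXT_BEGIN and SCORES_TXT_END markers,
# each line holding the statistics for one hypothesis. The file name below
# is hypothetical.
def _example_get_scores(score_data_file="dev.scores.dat"):
    for scores in get_scores(score_data_file):
        print("best sentence-level score: %f" % max(scores))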


def get_scored_nbests(nbest_file, score_data_file, input_file, segments=False):
    score_gen = get_scores(score_data_file)
    input_gen = None
    if input_file:
        input_gen = open(input_file)
    try:
        for nbest in get_nbests(nbest_file, segments=segments):
            scores = next(score_gen)
            if len(scores) != len(nbest.hyps):
                raise DataFormatException(
                    "Length of nbest %s does not match score list (%d != %d)" %
                    (nbest.id, len(nbest.hyps), len(scores)))
            input_line = None
            if input_gen:
                input_line = next(input_gen)[:-1]
            for hyp, score in zip(nbest.hyps, scores):
                hyp.score = score
                hyp.input_line = input_line
            yield nbest
    except StopIteration:
        raise DataFormatException("Score file shorter than nbest list file")


class PhraseCache:
    """An LRU cache for ttable lookups"""

    def __init__(self, max_size):
        self.max_size = max_size
        self.pairs_to_scores = OrderedDict()

    def get(self, source, target):
        key = (source, target)
        scores = self.pairs_to_scores.get(key, None)
        if scores:
            # cache hit - update access time by moving the pair to the end
            del self.pairs_to_scores[key]
            self.pairs_to_scores[key] = scores
        return scores

    def add(self, source, target, scores):
        key = (source, target)
        self.pairs_to_scores[key] = scores
        while len(self.pairs_to_scores) > self.max_size:
            # evict the least recently used pair
            self.pairs_to_scores.popitem(last=False)
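

# Illustrative sketch (not part of the original module): the OrderedDict
# gives LRU behaviour because re-inserting a key moves it to the end, while
# popitem(last=False) evicts from the front (the least recently used pair).
def _example_phrase_cache():
    cache = PhraseCache(max_size=2)
    cache.add("das haus", "the house", [0.5, 0.3, 0.5, 0.3])
    cache.add("das", "the", [0.9, 0.8, 0.9, 0.8])
    cache.get("das haus", "the house")  # refresh: this pair is now newest
    cache.add("haus", "house", [0.7, 0.6, 0.7, 0.6])  # evicts ("das", "the")
    assert cache.get("das", "the") is None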


#
# Should I store full lists of options, or just phrase pairs?
# Should probably store phrase pairs, but may want to add
# high-scoring pairs (say, 20?) when I load the translations
# of a given phrase
#
class CachedPhraseTable:
    def __init__(self, ttable_file, nscores=5, cache_size=20000):
        wa = False
        if binpt.PhraseDictionaryTree.canLoad(ttable_file, True):
            # assume word alignment is included
            wa = True
        self.ttable = binpt.PhraseDictionaryTree(
            ttable_file, nscores=nscores, wa=wa, tableLimit=0)
        self.cache = PhraseCache(cache_size)
        self.nscores = nscores

    def get_scores(self, phrase):
        source = " ".join(phrase[0])
        target_tuple = tuple(phrase[1])
        target = " ".join(target_tuple)
        scores = self.cache.get(source, target)
        if not scores:
            # cache miss - query the ttable and pick out the matching target
            scores = [0] * (self.nscores - 1)  # ignore penalty
            entries = self.ttable.query(source, converter=None)
            for entry in entries:
                if entry.rhs == target_tuple:
                    scores = entry.scores[:-1]
                    break
            self.cache.add(source, target, scores)
        return scores
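

# Illustrative sketch (not part of the original module): looking up the
# ttable scores for a single phrase pair. The phrase table path is
# hypothetical; it must point at a binarised table loadable by
# moses.dictree.PhraseDictionaryTree.
def _example_cached_ttable(ttable_file="phrase-table.0-0.1.1"):
    ttable = CachedPhraseTable(ttable_file, nscores=5)
    # a phrase pair is (source tokens, target tokens)
    phrase = (["das", "haus"], ["the", "house"])
    return ttable.get_scores(phrase)  # 4 scores, penalty dropped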


class MosesPhraseScorer:
    def __init__(self, ttable_files, cache_size=20000):
        self.ttables = []
        for ttable_file in ttable_files:
            self.ttables.append(CachedPhraseTable(ttable_file, cache_size=cache_size))

    def add_scores(self, hyp):
        """Add the phrase scores to a hypothesis"""
        # Collect up the phrase pairs
        phrases = []
        source_tokens = hyp.input_line.split()
        tgt_st = 0
        if not hyp.alignment:
            raise DataFormatException("Alignments missing from: " + str(hyp))
        for src_st, src_end, tgt_end in hyp.alignment:
            phrases.append((source_tokens[src_st:src_end], hyp.tokens[tgt_st:tgt_end]))
            tgt_st = tgt_end
        # Look up the scores for each phrase pair in each ttable
        phrase_scores = []
        for ttable in self.ttables:
            phrase_scores.append([])
            for phrase in phrases:
                phrase_scores[-1].append(ttable.get_scores(phrase))
        # Floor the scores to avoid zeros before any log is taken
        floor = np.exp(-100)
        phrase_scores = np.clip(np.array(phrase_scores), floor, np.inf)
        hyp.phrase_scores = phrase_scores
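

# Illustrative sketch (not part of the original module): tying the pieces
# together. File names are hypothetical; the nbest list must carry
# segmentation markers so phrase pairs can be recovered, and the input file
# supplies hyp.input_line for the source tokens.
def _example_score_nbests():
    scorer = MosesPhraseScorer(["phrase-table.0-0.1.1"])
    for nbest in get_scored_nbests(
            "dev.nbest.gz", "dev.scores.dat", "dev.input", segments=True):
        for hyp in nbest.hyps:
            scorer.add_scores(hyp)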