query results can now be sorted and pruned

This commit is contained in:
Wilker Aziz 2012-09-18 11:39:39 +02:00
parent be5678afb7
commit ab191b562e
3 changed files with 1456 additions and 444 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,11 @@
from libcpp.string cimport string
from libcpp.vector cimport vector
import os
import cython
cpdef int fsign(float x):
    '''Sign of x as an int: 1 when x >= 0 (so zero counts as positive), -1 otherwise.
    Kept as a statically typed cpdef helper so callers gain a little from static typing.'''
    if x >= 0:
        return 1
    return -1
cdef bytes as_str(data):
if isinstance(data, bytes):
@ -43,6 +48,13 @@ cdef class QueryResult(object):
'''Word-alignment info (as string)'''
return self._wa
@staticmethod
def desc(x, y, keys = lambda r: r.scores[0]):
    '''Comparator that orders results by descending key.

    x, y: the two items to compare.
    keys: callable mapping an item to its sort key; defaults to scores[0],
          so the default can only be used when scores is not empty.

    Returns 1 if keys(y) > keys(x), -1 if keys(y) < keys(x) and 0 when the
    keys compare equal, i.e. the three-way result a cmp-style sort expects.'''
    key_x = keys(x)
    key_y = keys(y)
    # Compare the keys directly rather than taking the sign of their
    # difference: ties must yield 0 for the comparator to be consistent
    # when the arguments are swapped, and no precision is lost by
    # squeezing the difference through a C float.
    if key_y > key_x:
        return 1
    if key_y < key_x:
        return -1
    return 0
def __str__(self):
'''Returns a string such as: <words> ||| <scores> [||| word-alignment info]'''
if self._wa:
@ -126,14 +138,16 @@ cdef class BinaryPhraseTable(object):
def delimiters(self):
return self._delimiters
def query(self, line):
def query(self, line, cmp = None, top = 0):
'''Queries the phrase table and returns a list of matches.
Each match is a QueryResult.'''
Each match is a QueryResult.
If 'cmp' is defined the return list is sorted.
If 'top' is defined, only the top elements will be returned.'''
cdef bytes text = as_str(line)
cdef vector[string] fphrase = Tokenize(string(text), string(self._delimiters))
cdef vector[StringTgtCand]* rv = new vector[StringTgtCand]()
cdef vector[string]* wa = NULL
cdef list phrases
if not self.__tree.UseWordAlignment():
self.__tree.GetTargetCandidates(fphrase, rv[0])
phrases = [get_query_result(rv[0][i]) for i in range(rv.size())]
@ -143,5 +157,10 @@ cdef class BinaryPhraseTable(object):
phrases = [get_query_result(rv[0][i], wa[0][i].c_str()) for i in range(rv.size())]
del wa
del rv
return phrases
if cmp:
phrases.sort(cmp=cmp)
if top > 0:
return phrases[0:top]
else:
return phrases

View File

@ -1,8 +1,8 @@
from binpt import BinaryPhraseTable
import binpt
#from binpt import QueryResult
import sys
if len(sys.argv) < 3:
print "Usage: %s phrase-table nscores [wa] < query > result" % (sys.argv[0])
sys.exit(0)
@ -11,12 +11,12 @@ pt_file = sys.argv[1]
nscores = int(sys.argv[2])
wa = len(sys.argv) == 4
pt = BinaryPhraseTable(pt_file, nscores, wa)
pt = binpt.BinaryPhraseTable(pt_file, nscores, wa)
print >> sys.stderr, "-ttable %s -nscores %d -alignment-info %s -delimiter '%s'\n" %(pt.path, pt.nscores, str(pt.wa), pt.delimiters)
for line in sys.stdin:
f = line.strip()
matches = pt.query(f)
matches = pt.query(f, cmp = binpt.QueryResult.desc, top = 20)
print '\n'.join([' ||| '.join((f, str(e))) for e in matches])
'''
# This is how one would use the QueryResult object