mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2025-01-01 08:21:47 +03:00
query results can now be sorted and pruned
This commit is contained in:
parent
be5678afb7
commit
ab191b562e
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,11 @@
|
||||
from libcpp.string cimport string
|
||||
from libcpp.vector cimport vector
|
||||
import os
|
||||
import cython
|
||||
|
||||
cpdef int fsign(float x):
|
||||
'''Simply returns the sign of float x (zero is assumed +), it's defined here just so one gains a little bit with static typing'''
|
||||
return 1 if x >= 0 else -1
|
||||
|
||||
cdef bytes as_str(data):
|
||||
if isinstance(data, bytes):
|
||||
@ -43,6 +48,13 @@ cdef class QueryResult(object):
|
||||
'''Word-alignment info (as string)'''
|
||||
return self._wa
|
||||
|
||||
@staticmethod
|
||||
def desc(x, y, keys = lambda r: r.scores[0]):
|
||||
'''Returns the sign of keys(y) - keys(x).
|
||||
Can only be used if scores is not an empty vector as
|
||||
keys defaults to scores[0]'''
|
||||
return fsign(keys(y) - keys(x))
|
||||
|
||||
def __str__(self):
|
||||
'''Returns a string such as: <words> ||| <scores> [||| word-alignment info]'''
|
||||
if self._wa:
|
||||
@ -126,14 +138,16 @@ cdef class BinaryPhraseTable(object):
|
||||
def delimiters(self):
|
||||
return self._delimiters
|
||||
|
||||
def query(self, line):
|
||||
def query(self, line, cmp = None, top = 0):
|
||||
'''Queries the phrase table and returns a list of matches.
|
||||
Each match is a QueryResult.'''
|
||||
Each match is a QueryResult.
|
||||
If 'cmp' is defined the return list is sorted.
|
||||
If 'top' is defined, onlye the top elements will be returned.'''
|
||||
cdef bytes text = as_str(line)
|
||||
cdef vector[string] fphrase = Tokenize(string(text), string(self._delimiters))
|
||||
cdef vector[StringTgtCand]* rv = new vector[StringTgtCand]()
|
||||
cdef vector[string]* wa = NULL
|
||||
|
||||
cdef list phrases
|
||||
if not self.__tree.UseWordAlignment():
|
||||
self.__tree.GetTargetCandidates(fphrase, rv[0])
|
||||
phrases = [get_query_result(rv[0][i]) for i in range(rv.size())]
|
||||
@ -143,5 +157,10 @@ cdef class BinaryPhraseTable(object):
|
||||
phrases = [get_query_result(rv[0][i], wa[0][i].c_str()) for i in range(rv.size())]
|
||||
del wa
|
||||
del rv
|
||||
return phrases
|
||||
if cmp:
|
||||
phrases.sort(cmp=cmp)
|
||||
if top > 0:
|
||||
return phrases[0:top]
|
||||
else:
|
||||
return phrases
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
from binpt import BinaryPhraseTable
|
||||
import binpt
|
||||
#from binpt import QueryResult
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
if len(sys.argv) < 3:
|
||||
print "Usage: %s phrase-table nscores [wa] < query > result" % (sys.argv[0])
|
||||
sys.exit(0)
|
||||
@ -11,12 +11,12 @@ pt_file = sys.argv[1]
|
||||
nscores = int(sys.argv[2])
|
||||
wa = len(sys.argv) == 4
|
||||
|
||||
pt = BinaryPhraseTable(pt_file, nscores, wa)
|
||||
pt = binpt.BinaryPhraseTable(pt_file, nscores, wa)
|
||||
print >> sys.stderr, "-ttable %s -nscores %d -alignment-info %s -delimiter '%s'\n" %(pt.path, pt.nscores, str(pt.wa), pt.delimiters)
|
||||
|
||||
for line in sys.stdin:
|
||||
f = line.strip()
|
||||
matches = pt.query(f)
|
||||
matches = pt.query(f, cmp = binpt.QueryResult.desc, top = 20)
|
||||
print '\n'.join([' ||| '.join((f, str(e))) for e in matches])
|
||||
'''
|
||||
# This is how one would use the QueryResult object
|
||||
|
Loading…
Reference in New Issue
Block a user