enable use of both edt and sd; reduction of translation option cache

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@1966 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
phkoehn 2009-01-01 18:16:54 +00:00
parent 2075f9dda1
commit 7507373b84
9 changed files with 120 additions and 40 deletions

View File

@ -4,6 +4,7 @@
#include <vector>
#include <set>
#include "Hypothesis.h"
#include "WordsBitmap.h"
namespace Moses
{
@ -22,6 +23,8 @@ public:
const_iterator end() const { return m_hypos.end(); }
size_t size() const { return m_hypos.size(); }
virtual inline float GetWorstScore() const { return -numeric_limits<float>::infinity(); };
virtual float GetWorstScoreForBitmap( WordsBitmapID ) { return -numeric_limits<float>::infinity(); };
virtual float GetWorstScoreForBitmap( WordsBitmap ) { return -numeric_limits<float>::infinity(); };
virtual ~HypothesisStack();
virtual bool AddPrune(Hypothesis *hypothesis) = 0;

View File

@ -60,9 +60,15 @@ pair<HypothesisStackNormal::iterator, bool> HypothesisStackNormal::Add(Hypothesi
VERBOSE(3,", best on stack");
m_bestScore = hypo->GetTotalScore();
// this may also affect the worst score
if ( m_bestScore + m_beamWidth > m_worstScore )
m_worstScore = m_bestScore + m_beamWidth;
}
if ( m_bestScore + m_beamWidth > m_worstScore )
m_worstScore = m_bestScore + m_beamWidth;
}
// update best/worst score for stack diversity 1
if ( m_minHypoStackDiversity == 1 &&
hypo->GetTotalScore() > GetWorstScoreForBitmap( hypo->GetWordsBitmap() ) )
{
SetWorstScoreForBitmap( hypo->GetWordsBitmap().GetID(), hypo->GetTotalScore() );
}
VERBOSE(3,", now size " << m_hypos.size());
@ -160,7 +166,7 @@ void HypothesisStackNormal::PruneToSize(size_t newSize)
}
// add best hyps for each coverage according to minStackDiversity
if ( m_minHypoStackDiversity > 0 )
if ( m_minHypoStackDiversity > 0 )
{
map< WordsBitmapID, size_t > diversityCount;
for(size_t i=0; i<hypos.size(); i++)

View File

@ -55,17 +55,20 @@ protected:
/** destroy all instances of Hypothesis in this collection */
void RemoveAll();
float GetWorstScoreForBitmap( const WordsBitmap &coverage ) {
WordsBitmapID id = coverage.GetID();
if (m_diversityWorstScore.find( id ) == m_diversityWorstScore.end())
return -numeric_limits<float>::infinity();
return m_diversityWorstScore[ id ];
}
void SetWorstScoreForBitmap( WordsBitmapID id, float worstScore ) {
m_diversityWorstScore[ id ] = worstScore;
}
public:
float GetWorstScoreForBitmap( WordsBitmapID id ) {
if (m_diversityWorstScore.find( id ) == m_diversityWorstScore.end())
return -numeric_limits<float>::infinity();
return m_diversityWorstScore[ id ];
}
float GetWorstScoreForBitmap( const WordsBitmap &coverage ) {
return GetWorstScoreForBitmap( coverage.GetID() );
}
HypothesisStackNormal();
/** adds the hypo, but only if within thresholds (beamThr, stackSize).

View File

@ -89,6 +89,7 @@ Parameter::Parameter()
AddParam("mbr-size", "number of translation candidates considered in MBR decoding (default 200)");
AddParam("mbr-scale", "scaling factor to convert log linear score probability in MBR decoding (default 1.0)");
AddParam("use-persistent-cache", "cache translation options across sentences (default true)");
AddParam("persistent-cache-size", "maximum size of cache for translation options (default 10,000 input phrases)");
AddParam("recover-input-path", "r", "(conf net/word lattice only) - recover input path corresponding to the best translation");
AddParam("output-word-graph", "owg", "Output stack info as word graph. Takes filename, 0=only hypos in stack, 1=stack + nbest hypos");
AddParam("time-out", "seconds after which is interrupted (-1=no time-out, default is -1)");

View File

@ -288,8 +288,15 @@ void SearchNormal::ExpandHypothesis(const Hypothesis &hypothesis, const Translat
{
// worst possible score may have changed -> recompute
size_t wordsTranslated = hypothesis.GetWordsBitmap().GetNumWordsCovered() + transOpt.GetSize();
float allowedScore = m_hypoStackColl[wordsTranslated]->GetWorstScore() + staticData.GetEarlyDiscardingThreshold();
float allowedScore = m_hypoStackColl[wordsTranslated]->GetWorstScore();
if (staticData.GetMinHypoStackDiversity())
{
WordsBitmapID id = hypothesis.GetWordsBitmap().GetIDPlus(transOpt.GetStartPos(), transOpt.GetEndPos());
float allowedScoreForBitmap = m_hypoStackColl[wordsTranslated]->GetWorstScoreForBitmap( id );
allowedScore = std::min( allowedScore, allowedScoreForBitmap );
}
allowedScore += staticData.GetEarlyDiscardingThreshold();
// add expected score of translation option
expectedScore += transOpt.GetFutureScore();
// TRACE_ERR("EXPECTED diff: " << (newHypo->GetTotalScore()-expectedScore) << " (pre " << (newHypo->GetTotalScore()-expectedScorePre) << ") " << hypothesis.GetTargetPhrase() << " ... " << transOpt.GetTargetPhrase() << " [" << expectedScorePre << "," << expectedScore << "," << newHypo->GetTotalScore() << "]" << endl);

View File

@ -181,7 +181,7 @@ bool StaticData::LoadData(Parameter *parameter)
}
m_outputSearchGraphPB = true;
}
else
else
m_outputSearchGraphPB = false;
#endif
@ -201,6 +201,8 @@ bool StaticData::LoadData(Parameter *parameter)
if (m_inputType == SentenceInput)
{
SetBooleanParameter( &m_useTransOptCache, "use-persistent-cache", true );
m_transOptCacheMaxSize = (m_parameter->GetParam("persistent-cache-size").size() > 0)
? Scan<size_t>(m_parameter->GetParam("persistent-cache-size")[0]) : DEFAULT_MAX_TRANS_OPT_CACHE_SIZE;
}
else
{
@ -322,8 +324,6 @@ bool StaticData::LoadData(Parameter *parameter)
Scan<size_t>(m_parameter->GetParam("time-out")[0]) : -1;
m_timeout = (GetTimeoutThreshold() == -1) ? false : true;
// Read in constraint decoding file, if provided
if(m_parameter->GetParam("constraint").size())
m_constraintFileName = m_parameter->GetParam("constraint")[0];
@ -348,8 +348,7 @@ bool StaticData::LoadData(Parameter *parameter)
m_searchAlgorithm = (m_parameter->GetParam("search-algorithm").size() > 0) ?
(SearchAlgorithm) Scan<size_t>(m_parameter->GetParam("search-algorithm")[0]) : Normal;
//default case
// use of xml in input
if (m_parameter->GetParam("xml-input").size() == 0) m_xmlInputType = XmlPassThrough;
else if (m_parameter->GetParam("xml-input")[0]=="exclusive") m_xmlInputType = XmlExclusive;
else if (m_parameter->GetParam("xml-input")[0]=="inclusive") m_xmlInputType = XmlInclusive;
@ -412,10 +411,10 @@ StaticData::~StaticData()
RemoveAllInColl(m_reorderModels);
// delete trans opt
map<std::pair<const DecodeGraph*, Phrase>, TranslationOptionList* >::iterator iterCache;
map<std::pair<const DecodeGraph*, Phrase>, std::pair< TranslationOptionList*, clock_t > >::iterator iterCache;
for (iterCache = m_transOptCache.begin() ; iterCache != m_transOptCache.end() ; ++iterCache)
{
TranslationOptionList *transOptList = iterCache->second;
TranslationOptionList *transOptList = iterCache->second.first;
delete transOptList;
}
@ -969,18 +968,53 @@ const TranslationOptionList* StaticData::FindTransOptListInCache(const DecodeGra
{
std::pair<const DecodeGraph*, Phrase> key(&decodeGraph, sourcePhrase);
std::map<std::pair<const DecodeGraph*, Phrase>, TranslationOptionList*>::const_iterator iter
std::map<std::pair<const DecodeGraph*, Phrase>, std::pair<TranslationOptionList*,clock_t> >::iterator iter
= m_transOptCache.find(key);
if (iter == m_transOptCache.end())
return NULL;
iter->second.second = clock(); // update last used time
return iter->second.first;
}
return iter->second;
void StaticData::ReduceTransOptCache() const
{
if (m_transOptCache.size() <= m_transOptCacheMaxSize) return; // not full
clock_t t = clock();
// find cutoff for last used time
priority_queue< clock_t > lastUsedTimes;
std::map<std::pair<const DecodeGraph*, Phrase>, std::pair<TranslationOptionList*,clock_t> >::iterator iter;
iter = m_transOptCache.begin();
while( iter != m_transOptCache.end() )
{
lastUsedTimes.push( iter->second.second );
iter++;
}
for( size_t i=0; i < lastUsedTimes.size()-m_transOptCacheMaxSize/2; i++ )
lastUsedTimes.pop();
clock_t cutoffLastUsedTime = lastUsedTimes.top();
// remove all old entries
iter = m_transOptCache.begin();
while( iter != m_transOptCache.end() )
{
if (iter->second.second < cutoffLastUsedTime)
{
std::map<std::pair<const DecodeGraph*, Phrase>, std::pair<TranslationOptionList*,clock_t> >::iterator iterRemove = iter++;
delete iterRemove->second.first;
m_transOptCache.erase(iterRemove);
}
else iter++;
}
VERBOSE(2,"Reduced persistent translation option cache in " << ((clock()-t)/(float)CLOCKS_PER_SEC) << " seconds." << std::endl);
}
void StaticData::AddTransOptListToCache(const DecodeGraph &decodeGraph, const Phrase &sourcePhrase, const TranslationOptionList &transOptList) const
{
std::pair<const DecodeGraph*, Phrase> pair(&decodeGraph, sourcePhrase);
m_transOptCache[pair] = new TranslationOptionList(transOptList);
std::pair<const DecodeGraph*, Phrase> key(&decodeGraph, sourcePhrase);
TranslationOptionList* storedTransOptList = new TranslationOptionList(transOptList);
m_transOptCache[key] = make_pair( storedTransOptList, clock() );
ReduceTransOptCache();
}
}

View File

@ -140,8 +140,9 @@ protected:
bool m_timeout; //! use timeout
size_t m_timeout_threshold; //! seconds after which time out is activated
bool m_useTransOptCache;
mutable std::map<std::pair<const DecodeGraph*, Phrase>, TranslationOptionList*> m_transOptCache;
bool m_useTransOptCache; //! flag indicating, if the persistent translation option cache should be used
mutable std::map<std::pair<const DecodeGraph*, Phrase>, pair<TranslationOptionList*,clock_t> > m_transOptCache; //! persistent translation option cache
size_t m_transOptCacheMaxSize; //! maximum size for persistent translation option cache
mutable const InputType* m_input; //! holds reference to current sentence
bool m_isAlwaysCreateDirectTranslationOption;
@ -457,6 +458,7 @@ public:
bool GetUseTransOptCache() const { return m_useTransOptCache; }
void AddTransOptListToCache(const DecodeGraph &decodeGraph, const Phrase &sourcePhrase, const TranslationOptionList &transOptList) const;
void StaticData::ReduceTransOptCache() const;
const TranslationOptionList* FindTransOptListInCache(const DecodeGraph &decodeGraph, const Phrase &sourcePhrase) const;
};

View File

@ -47,6 +47,7 @@ namespace Moses
const size_t DEFAULT_CUBE_PRUNING_POP_LIMIT = 1000;
const size_t DEFAULT_CUBE_PRUNING_DIVERSITY = 0;
const size_t DEFAULT_MAX_HYPOSTACK_SIZE = 200;
const size_t DEFAULT_MAX_TRANS_OPT_CACHE_SIZE = 10000;
const size_t DEFAULT_MAX_TRANS_OPT_SIZE = 50;
const size_t DEFAULT_MAX_PART_TRANS_OPT_SIZE = 10000;
const size_t DEFAULT_MAX_PHRASE_LENGTH = 20;

View File

@ -194,25 +194,48 @@ public:
//! TODO - ??? no idea
int GetFutureCosts(int lastPos) const ;
//! converts bitmap into an integer ID: it consists of two parts: the first 16 bit are the pattern between the first gap and the last word-1, the second 16 bit are the number of filled positions. enforces a sentence length limit of 65535 and a max distortion of 16
WordsBitmapID GetID() const {
assert(m_size < (1<<16));
//! converts bitmap into an integer ID: it consists of two parts: the first 16 bit are the pattern between the first gap and the last word-1, the second 16 bit are the number of filled positions. enforces a sentence length limit of 65535 and a max distortion of 16
WordsBitmapID GetID() const {
assert(m_size < (1<<16));
size_t start = GetFirstGapPos();
if (start == NOT_FOUND) start = m_size; // nothing left
size_t start = GetFirstGapPos();
if (start == NOT_FOUND) start = m_size; // nothing left
size_t end = GetLastPos();
if (end == NOT_FOUND) end = 0; // nothing translated yet
size_t end = GetLastPos();
if (end == NOT_FOUND) end = 0; // nothing translated yet
assert(end < start || end-start <= 16);
WordsBitmapID id = 0;
for(size_t pos = end; pos > start; pos--) {
id = id*2 + (int) GetValue(pos);
}
return id + (1<<16) * start;
}
assert(end < start || end-start <= 16);
WordsBitmapID id = 0;
for(size_t pos = end; pos > start; pos--) {
id = id*2 + (int) GetValue(pos);
}
return id + (1<<16) * start;
}
TO_STRING();
//! converts bitmap into an integer ID, with an additional span covered
WordsBitmapID GetIDPlus( size_t startPos, size_t endPos ) const {
assert(m_size < (1<<16));
size_t start = GetFirstGapPos();
if (start == NOT_FOUND) start = m_size; // nothing left
size_t end = GetLastPos();
if (end == NOT_FOUND) end = 0; // nothing translated yet
if (start == startPos) start = endPos+1;
if (end < endPos) end = endPos;
assert(end < start || end-start <= 16);
WordsBitmapID id = 0;
for(size_t pos = end; pos > start; pos--) {
id = id*2;
if (GetValue(pos) || (startPos<=pos && pos<=endPos))
id++;
}
return id + (1<<16) * start;
}
TO_STRING();
};
// friend