mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-27 22:14:57 +03:00
enable use of both edt and sd; reduction of translation option cache
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@1966 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
parent
2075f9dda1
commit
7507373b84
@ -4,6 +4,7 @@
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include "Hypothesis.h"
|
||||
#include "WordsBitmap.h"
|
||||
|
||||
namespace Moses
|
||||
{
|
||||
@ -22,6 +23,8 @@ public:
|
||||
const_iterator end() const { return m_hypos.end(); }
|
||||
size_t size() const { return m_hypos.size(); }
|
||||
virtual inline float GetWorstScore() const { return -numeric_limits<float>::infinity(); };
|
||||
virtual float GetWorstScoreForBitmap( WordsBitmapID ) { return -numeric_limits<float>::infinity(); };
|
||||
virtual float GetWorstScoreForBitmap( WordsBitmap ) { return -numeric_limits<float>::infinity(); };
|
||||
|
||||
virtual ~HypothesisStack();
|
||||
virtual bool AddPrune(Hypothesis *hypothesis) = 0;
|
||||
|
@ -60,9 +60,15 @@ pair<HypothesisStackNormal::iterator, bool> HypothesisStackNormal::Add(Hypothesi
|
||||
VERBOSE(3,", best on stack");
|
||||
m_bestScore = hypo->GetTotalScore();
|
||||
// this may also affect the worst score
|
||||
if ( m_bestScore + m_beamWidth > m_worstScore )
|
||||
m_worstScore = m_bestScore + m_beamWidth;
|
||||
}
|
||||
if ( m_bestScore + m_beamWidth > m_worstScore )
|
||||
m_worstScore = m_bestScore + m_beamWidth;
|
||||
}
|
||||
// update best/worst score for stack diversity 1
|
||||
if ( m_minHypoStackDiversity == 1 &&
|
||||
hypo->GetTotalScore() > GetWorstScoreForBitmap( hypo->GetWordsBitmap() ) )
|
||||
{
|
||||
SetWorstScoreForBitmap( hypo->GetWordsBitmap().GetID(), hypo->GetTotalScore() );
|
||||
}
|
||||
|
||||
VERBOSE(3,", now size " << m_hypos.size());
|
||||
|
||||
@ -160,7 +166,7 @@ void HypothesisStackNormal::PruneToSize(size_t newSize)
|
||||
}
|
||||
|
||||
// add best hyps for each coverage according to minStackDiversity
|
||||
if ( m_minHypoStackDiversity > 0 )
|
||||
if ( m_minHypoStackDiversity > 0 )
|
||||
{
|
||||
map< WordsBitmapID, size_t > diversityCount;
|
||||
for(size_t i=0; i<hypos.size(); i++)
|
||||
|
@ -55,17 +55,20 @@ protected:
|
||||
/** destroy all instances of Hypothesis in this collection */
|
||||
void RemoveAll();
|
||||
|
||||
float GetWorstScoreForBitmap( const WordsBitmap &coverage ) {
|
||||
WordsBitmapID id = coverage.GetID();
|
||||
if (m_diversityWorstScore.find( id ) == m_diversityWorstScore.end())
|
||||
return -numeric_limits<float>::infinity();
|
||||
return m_diversityWorstScore[ id ];
|
||||
}
|
||||
void SetWorstScoreForBitmap( WordsBitmapID id, float worstScore ) {
|
||||
m_diversityWorstScore[ id ] = worstScore;
|
||||
}
|
||||
|
||||
public:
|
||||
float GetWorstScoreForBitmap( WordsBitmapID id ) {
|
||||
if (m_diversityWorstScore.find( id ) == m_diversityWorstScore.end())
|
||||
return -numeric_limits<float>::infinity();
|
||||
return m_diversityWorstScore[ id ];
|
||||
}
|
||||
float GetWorstScoreForBitmap( const WordsBitmap &coverage ) {
|
||||
return GetWorstScoreForBitmap( coverage.GetID() );
|
||||
}
|
||||
|
||||
HypothesisStackNormal();
|
||||
|
||||
/** adds the hypo, but only if within thresholds (beamThr, stackSize).
|
||||
|
@ -89,6 +89,7 @@ Parameter::Parameter()
|
||||
AddParam("mbr-size", "number of translation candidates considered in MBR decoding (default 200)");
|
||||
AddParam("mbr-scale", "scaling factor to convert log linear score probability in MBR decoding (default 1.0)");
|
||||
AddParam("use-persistent-cache", "cache translation options across sentences (default true)");
|
||||
AddParam("persistent-cache-size", "maximum size of cache for translation options (default 10,000 input phrases)");
|
||||
AddParam("recover-input-path", "r", "(conf net/word lattice only) - recover input path corresponding to the best translation");
|
||||
AddParam("output-word-graph", "owg", "Output stack info as word graph. Takes filename, 0=only hypos in stack, 1=stack + nbest hypos");
|
||||
AddParam("time-out", "seconds after which is interrupted (-1=no time-out, default is -1)");
|
||||
|
@ -288,8 +288,15 @@ void SearchNormal::ExpandHypothesis(const Hypothesis &hypothesis, const Translat
|
||||
{
|
||||
// worst possible score may have changed -> recompute
|
||||
size_t wordsTranslated = hypothesis.GetWordsBitmap().GetNumWordsCovered() + transOpt.GetSize();
|
||||
float allowedScore = m_hypoStackColl[wordsTranslated]->GetWorstScore() + staticData.GetEarlyDiscardingThreshold();
|
||||
|
||||
float allowedScore = m_hypoStackColl[wordsTranslated]->GetWorstScore();
|
||||
if (staticData.GetMinHypoStackDiversity())
|
||||
{
|
||||
WordsBitmapID id = hypothesis.GetWordsBitmap().GetIDPlus(transOpt.GetStartPos(), transOpt.GetEndPos());
|
||||
float allowedScoreForBitmap = m_hypoStackColl[wordsTranslated]->GetWorstScoreForBitmap( id );
|
||||
allowedScore = std::min( allowedScore, allowedScoreForBitmap );
|
||||
}
|
||||
allowedScore += staticData.GetEarlyDiscardingThreshold();
|
||||
|
||||
// add expected score of translation option
|
||||
expectedScore += transOpt.GetFutureScore();
|
||||
// TRACE_ERR("EXPECTED diff: " << (newHypo->GetTotalScore()-expectedScore) << " (pre " << (newHypo->GetTotalScore()-expectedScorePre) << ") " << hypothesis.GetTargetPhrase() << " ... " << transOpt.GetTargetPhrase() << " [" << expectedScorePre << "," << expectedScore << "," << newHypo->GetTotalScore() << "]" << endl);
|
||||
|
@ -181,7 +181,7 @@ bool StaticData::LoadData(Parameter *parameter)
|
||||
}
|
||||
m_outputSearchGraphPB = true;
|
||||
}
|
||||
else
|
||||
else
|
||||
m_outputSearchGraphPB = false;
|
||||
#endif
|
||||
|
||||
@ -201,6 +201,8 @@ bool StaticData::LoadData(Parameter *parameter)
|
||||
if (m_inputType == SentenceInput)
|
||||
{
|
||||
SetBooleanParameter( &m_useTransOptCache, "use-persistent-cache", true );
|
||||
m_transOptCacheMaxSize = (m_parameter->GetParam("persistent-cache-size").size() > 0)
|
||||
? Scan<size_t>(m_parameter->GetParam("persistent-cache-size")[0]) : DEFAULT_MAX_TRANS_OPT_CACHE_SIZE;
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -322,8 +324,6 @@ bool StaticData::LoadData(Parameter *parameter)
|
||||
Scan<size_t>(m_parameter->GetParam("time-out")[0]) : -1;
|
||||
m_timeout = (GetTimeoutThreshold() == -1) ? false : true;
|
||||
|
||||
|
||||
|
||||
// Read in constraint decoding file, if provided
|
||||
if(m_parameter->GetParam("constraint").size())
|
||||
m_constraintFileName = m_parameter->GetParam("constraint")[0];
|
||||
@ -348,8 +348,7 @@ bool StaticData::LoadData(Parameter *parameter)
|
||||
m_searchAlgorithm = (m_parameter->GetParam("search-algorithm").size() > 0) ?
|
||||
(SearchAlgorithm) Scan<size_t>(m_parameter->GetParam("search-algorithm")[0]) : Normal;
|
||||
|
||||
//default case
|
||||
|
||||
// use of xml in input
|
||||
if (m_parameter->GetParam("xml-input").size() == 0) m_xmlInputType = XmlPassThrough;
|
||||
else if (m_parameter->GetParam("xml-input")[0]=="exclusive") m_xmlInputType = XmlExclusive;
|
||||
else if (m_parameter->GetParam("xml-input")[0]=="inclusive") m_xmlInputType = XmlInclusive;
|
||||
@ -412,10 +411,10 @@ StaticData::~StaticData()
|
||||
RemoveAllInColl(m_reorderModels);
|
||||
|
||||
// delete trans opt
|
||||
map<std::pair<const DecodeGraph*, Phrase>, TranslationOptionList* >::iterator iterCache;
|
||||
map<std::pair<const DecodeGraph*, Phrase>, std::pair< TranslationOptionList*, clock_t > >::iterator iterCache;
|
||||
for (iterCache = m_transOptCache.begin() ; iterCache != m_transOptCache.end() ; ++iterCache)
|
||||
{
|
||||
TranslationOptionList *transOptList = iterCache->second;
|
||||
TranslationOptionList *transOptList = iterCache->second.first;
|
||||
delete transOptList;
|
||||
}
|
||||
|
||||
@ -969,18 +968,53 @@ const TranslationOptionList* StaticData::FindTransOptListInCache(const DecodeGra
|
||||
{
|
||||
std::pair<const DecodeGraph*, Phrase> key(&decodeGraph, sourcePhrase);
|
||||
|
||||
std::map<std::pair<const DecodeGraph*, Phrase>, TranslationOptionList*>::const_iterator iter
|
||||
std::map<std::pair<const DecodeGraph*, Phrase>, std::pair<TranslationOptionList*,clock_t> >::iterator iter
|
||||
= m_transOptCache.find(key);
|
||||
if (iter == m_transOptCache.end())
|
||||
return NULL;
|
||||
iter->second.second = clock(); // update last used time
|
||||
return iter->second.first;
|
||||
}
|
||||
|
||||
return iter->second;
|
||||
void StaticData::ReduceTransOptCache() const
|
||||
{
|
||||
if (m_transOptCache.size() <= m_transOptCacheMaxSize) return; // not full
|
||||
clock_t t = clock();
|
||||
|
||||
// find cutoff for last used time
|
||||
priority_queue< clock_t > lastUsedTimes;
|
||||
std::map<std::pair<const DecodeGraph*, Phrase>, std::pair<TranslationOptionList*,clock_t> >::iterator iter;
|
||||
iter = m_transOptCache.begin();
|
||||
while( iter != m_transOptCache.end() )
|
||||
{
|
||||
lastUsedTimes.push( iter->second.second );
|
||||
iter++;
|
||||
}
|
||||
for( size_t i=0; i < lastUsedTimes.size()-m_transOptCacheMaxSize/2; i++ )
|
||||
lastUsedTimes.pop();
|
||||
clock_t cutoffLastUsedTime = lastUsedTimes.top();
|
||||
|
||||
// remove all old entries
|
||||
iter = m_transOptCache.begin();
|
||||
while( iter != m_transOptCache.end() )
|
||||
{
|
||||
if (iter->second.second < cutoffLastUsedTime)
|
||||
{
|
||||
std::map<std::pair<const DecodeGraph*, Phrase>, std::pair<TranslationOptionList*,clock_t> >::iterator iterRemove = iter++;
|
||||
delete iterRemove->second.first;
|
||||
m_transOptCache.erase(iterRemove);
|
||||
}
|
||||
else iter++;
|
||||
}
|
||||
VERBOSE(2,"Reduced persistent translation option cache in " << ((clock()-t)/(float)CLOCKS_PER_SEC) << " seconds." << std::endl);
|
||||
}
|
||||
|
||||
void StaticData::AddTransOptListToCache(const DecodeGraph &decodeGraph, const Phrase &sourcePhrase, const TranslationOptionList &transOptList) const
|
||||
{
|
||||
std::pair<const DecodeGraph*, Phrase> pair(&decodeGraph, sourcePhrase);
|
||||
m_transOptCache[pair] = new TranslationOptionList(transOptList);
|
||||
std::pair<const DecodeGraph*, Phrase> key(&decodeGraph, sourcePhrase);
|
||||
TranslationOptionList* storedTransOptList = new TranslationOptionList(transOptList);
|
||||
m_transOptCache[key] = make_pair( storedTransOptList, clock() );
|
||||
ReduceTransOptCache();
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -140,8 +140,9 @@ protected:
|
||||
bool m_timeout; //! use timeout
|
||||
size_t m_timeout_threshold; //! seconds after which time out is activated
|
||||
|
||||
bool m_useTransOptCache;
|
||||
mutable std::map<std::pair<const DecodeGraph*, Phrase>, TranslationOptionList*> m_transOptCache;
|
||||
bool m_useTransOptCache; //! flag indicating, if the persistent translation option cache should be used
|
||||
mutable std::map<std::pair<const DecodeGraph*, Phrase>, pair<TranslationOptionList*,clock_t> > m_transOptCache; //! persistent translation option cache
|
||||
size_t m_transOptCacheMaxSize; //! maximum size for persistent translation option cache
|
||||
|
||||
mutable const InputType* m_input; //! holds reference to current sentence
|
||||
bool m_isAlwaysCreateDirectTranslationOption;
|
||||
@ -457,6 +458,7 @@ public:
|
||||
bool GetUseTransOptCache() const { return m_useTransOptCache; }
|
||||
|
||||
void AddTransOptListToCache(const DecodeGraph &decodeGraph, const Phrase &sourcePhrase, const TranslationOptionList &transOptList) const;
|
||||
void StaticData::ReduceTransOptCache() const;
|
||||
|
||||
const TranslationOptionList* FindTransOptListInCache(const DecodeGraph &decodeGraph, const Phrase &sourcePhrase) const;
|
||||
};
|
||||
|
@ -47,6 +47,7 @@ namespace Moses
|
||||
const size_t DEFAULT_CUBE_PRUNING_POP_LIMIT = 1000;
|
||||
const size_t DEFAULT_CUBE_PRUNING_DIVERSITY = 0;
|
||||
const size_t DEFAULT_MAX_HYPOSTACK_SIZE = 200;
|
||||
const size_t DEFAULT_MAX_TRANS_OPT_CACHE_SIZE = 10000;
|
||||
const size_t DEFAULT_MAX_TRANS_OPT_SIZE = 50;
|
||||
const size_t DEFAULT_MAX_PART_TRANS_OPT_SIZE = 10000;
|
||||
const size_t DEFAULT_MAX_PHRASE_LENGTH = 20;
|
||||
|
@ -194,25 +194,48 @@ public:
|
||||
//! TODO - ??? no idea
|
||||
int GetFutureCosts(int lastPos) const ;
|
||||
|
||||
//! converts bitmap into an integer ID: it consists of two parts: the first 16 bit are the pattern between the first gap and the last word-1, the second 16 bit are the number of filled positions. enforces a sentence length limit of 65535 and a max distortion of 16
|
||||
WordsBitmapID GetID() const {
|
||||
assert(m_size < (1<<16));
|
||||
//! converts bitmap into an integer ID: it consists of two parts: the first 16 bit are the pattern between the first gap and the last word-1, the second 16 bit are the number of filled positions. enforces a sentence length limit of 65535 and a max distortion of 16
|
||||
WordsBitmapID GetID() const {
|
||||
assert(m_size < (1<<16));
|
||||
|
||||
size_t start = GetFirstGapPos();
|
||||
if (start == NOT_FOUND) start = m_size; // nothing left
|
||||
size_t start = GetFirstGapPos();
|
||||
if (start == NOT_FOUND) start = m_size; // nothing left
|
||||
|
||||
size_t end = GetLastPos();
|
||||
if (end == NOT_FOUND) end = 0; // nothing translated yet
|
||||
size_t end = GetLastPos();
|
||||
if (end == NOT_FOUND) end = 0; // nothing translated yet
|
||||
|
||||
assert(end < start || end-start <= 16);
|
||||
WordsBitmapID id = 0;
|
||||
for(size_t pos = end; pos > start; pos--) {
|
||||
id = id*2 + (int) GetValue(pos);
|
||||
}
|
||||
return id + (1<<16) * start;
|
||||
}
|
||||
assert(end < start || end-start <= 16);
|
||||
WordsBitmapID id = 0;
|
||||
for(size_t pos = end; pos > start; pos--) {
|
||||
id = id*2 + (int) GetValue(pos);
|
||||
}
|
||||
return id + (1<<16) * start;
|
||||
}
|
||||
|
||||
TO_STRING();
|
||||
//! converts bitmap into an integer ID, with an additional span covered
|
||||
WordsBitmapID GetIDPlus( size_t startPos, size_t endPos ) const {
|
||||
assert(m_size < (1<<16));
|
||||
|
||||
size_t start = GetFirstGapPos();
|
||||
if (start == NOT_FOUND) start = m_size; // nothing left
|
||||
|
||||
size_t end = GetLastPos();
|
||||
if (end == NOT_FOUND) end = 0; // nothing translated yet
|
||||
|
||||
if (start == startPos) start = endPos+1;
|
||||
if (end < endPos) end = endPos;
|
||||
|
||||
assert(end < start || end-start <= 16);
|
||||
WordsBitmapID id = 0;
|
||||
for(size_t pos = end; pos > start; pos--) {
|
||||
id = id*2;
|
||||
if (GetValue(pos) || (startPos<=pos && pos<=endPos))
|
||||
id++;
|
||||
}
|
||||
return id + (1<<16) * start;
|
||||
}
|
||||
|
||||
TO_STRING();
|
||||
};
|
||||
|
||||
// friend
|
||||
|
Loading…
Reference in New Issue
Block a user