diff --git a/moses/LM/IRST.cpp b/moses/LM/IRST.cpp index 19a5f2c82..386eb09ec 100644 --- a/moses/LM/IRST.cpp +++ b/moses/LM/IRST.cpp @@ -29,6 +29,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA using namespace irstlm; #include "IRST.h" +#include "moses/LM/PointerState.h" #include "moses/TypeDef.h" #include "moses/Util.h" #include "moses/FactorCollection.h" @@ -42,9 +43,28 @@ using namespace std; namespace Moses { +class IRSTLMState : public PointerState +{ +public: + IRSTLMState():PointerState(NULL) {} + IRSTLMState(const void* lms):PointerState(lms) {} + IRSTLMState(const IRSTLMState& copy_from):PointerState(copy_from.lmstate) {} + + IRSTLMState& operator=( const IRSTLMState& rhs ) + { + lmstate = rhs.lmstate; + return *this; + } + + const void* GetState() const + { + return lmstate; + } +}; + LanguageModelIRST::LanguageModelIRST(const std::string &line) :LanguageModelSingleFactor(line) - ,m_lmtb_dub(0) + ,m_lmtb_dub(0), m_lmtb_size(0) { const StaticData &staticData = StaticData::Instance(); int threadCount = staticData.ThreadCount(); @@ -54,6 +74,10 @@ LanguageModelIRST::LanguageModelIRST(const std::string &line) ReadParameters(); + VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_lmtb_dub:|" << m_lmtb_dub << "|" << std::endl); + VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_filePath:|" << m_filePath << "|" << std::endl); + VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_factorType:|" << m_factorType << "|" << std::endl); + VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_lmtb_size:|" << m_lmtb_size << "|" << std::endl); } LanguageModelIRST::~LanguageModelIRST() @@ -70,17 +94,15 @@ LanguageModelIRST::~LanguageModelIRST() void LanguageModelIRST::Load() { - cerr << "In LanguageModelIRST::Load: nGramOrder = " << m_nGramOrder << "\n"; - FactorCollection &factorCollection = 
FactorCollection::Instance(); m_lmtb = m_lmtb->CreateLanguageModel(m_filePath); - m_lmtb->setMaxLoadedLevel(1000); + if (m_lmtb_size > 0) m_lmtb->setMaxLoadedLevel(m_lmtb_size); m_lmtb->load(m_filePath); d=m_lmtb->getDict(); d->incflag(1); - m_lmtb_size=m_lmtb->maxlevel(); + m_nGramOrder = m_lmtb_size = m_lmtb->maxlevel(); // LM can be ok, just outputs warnings // Mauro: in the original, the following two instructions are wrongly switched: @@ -89,7 +111,7 @@ void LanguageModelIRST::Load() CreateFactors(factorCollection); - VERBOSE(1, "IRST: m_unknownId=" << m_unknownId << std::endl); + VERBOSE(1, GetScoreProducerDescription() << " LanguageModelIRST::Load() m_unknownId=" << m_unknownId << std::endl); //install caches to save time (only if PS_CACHE_ENABLE is defined through compilation flags) m_lmtb->init_caches(m_lmtb_size>2?m_lmtb_size-1:2); @@ -117,6 +139,8 @@ void LanguageModelIRST::CreateFactors(FactorCollection &factorCollection) m_sentenceStart = factorCollection.AddFactor(Output, m_factorType, BOS_); factorId = m_sentenceStart->GetId(); + const std::string bs = BOS_; + const std::string es = EOS_; m_lmtb_sentenceStart=lmIdMap[factorId] = GetLmID(BOS_); maxFactorId = (factorId > maxFactorId) ? factorId : maxFactorId; m_sentenceStartWord[m_factorType] = m_sentenceStart; @@ -142,6 +166,11 @@ int LanguageModelIRST::GetLmID( const std::string &str ) const return d->encode( str.c_str() ); // at the level of micro tags } +int LanguageModelIRST::GetLmID( const Word &word ) const +{ + return GetLmID( word.GetFactor(m_factorType) ); +} + int LanguageModelIRST::GetLmID( const Factor *factor ) const { size_t factorId = factor->GetId(); @@ -157,7 +186,7 @@ int LanguageModelIRST::GetLmID( const Factor *factor ) const ///di cui non sia stato ancora calcolato il suo codice target abbia ///comunque un factorID noto (e quindi minore di m_lmIdLookup.size()) ///E' necessario dunque identificare questi casi di indeterminatezza - ///del codice target. 
Attualamente, questo controllo e' stato implementato + ///del codice target. Attualmente, questo controllo e' stato implementato ///impostando a m_empty tutti i termini che non hanno ancora //ricevuto un codice target effettivo /////////// @@ -198,6 +227,102 @@ int LanguageModelIRST::GetLmID( const Factor *factor ) const } } +void LanguageModelIRST::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const +{ + fullScore = 0; + ngramScore = 0; + oovCount = 0; + + if ( !phrase.GetSize() ) return; + + if ( m_lmtb_size > (int) phrase.GetSize()) return; + + int codes[m_lmtb_size]; + int idx = 0; + int position = 0; + + for (; position < m_lmtb_size; ++position) + { + codes[idx] = GetLmID(phrase.GetWord(position)); + if (codes[idx] == m_unknownId) ++oovCount; + ++idx; + } + + char* msp = NULL; + + ngramScore = m_lmtb->clprob(codes,idx,NULL,NULL,&msp); + int end_loop = (int) phrase.GetSize(); + + for (; position < end_loop; ++position) { + for (idx = 1; idx < m_lmtb_size; ++idx) + { + codes[idx-1] = codes[idx]; + } + codes[idx-1] = GetLmID(phrase.GetWord(position)); + if (codes[idx-1] == m_unknownId) ++oovCount; + ngramScore += m_lmtb->clprob(codes,idx,NULL,NULL,&msp); + } + ngramScore = TransformLMScore(ngramScore); + fullScore = ngramScore; +} + +FFState* LanguageModelIRST::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const +{ + if (!hypo.GetCurrTargetLength()) { + std::auto_ptr<IRSTLMState> ret(new IRSTLMState(ps)); + return ret.release(); + } + + const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos(); + //[begin, end) in STL-like fashion. 
+ const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1; + const std::size_t adjust_end = std::min(end, begin + m_lmtb_size - 1); + + // set up context + int codes[m_lmtb_size]; + int idx=m_lmtb_size-1; + int position = adjust_end-1; + + //fill the farthest positions with at most ONE sentenceStart symbol and at most ONE sentenceEnd symbol, if "empty" positions are available + //so that the vector looks like = "<s> <s> context_word context_word" for a two-word context and a LM of order 5 + while (position >= (const int) begin) { + codes[idx] = GetLmID(hypo.GetWord(position)); + --idx; + --position; + } + if (idx == 1){ + codes[1] = m_lmtb_sentenceStart; + codes[0] = m_lmtb_sentenceStart; + } + else if (idx == 0) + { + codes[0] = m_lmtb_sentenceEnd; + } + + char* msp = NULL; + float score = m_lmtb->clprob(codes,m_lmtb_size,NULL,NULL,&msp); + score = TransformLMScore(score); + out->PlusEquals(this, score); + + if (adjust_end < end) + { + idx=m_lmtb_size-1; + position = end-1; + + while (idx>=0) + { + codes[idx] = GetLmID(hypo.GetWord(position)); + --idx; + --position; + } + msp = (char *) m_lmtb->cmaxsuffptr(codes,m_lmtb_size); + } + + std::auto_ptr<IRSTLMState> ret(new IRSTLMState(msp)); + + return ret.release(); +} + LMResult LanguageModelIRST::GetValue(const vector<const Word*> &contextFactor, State* finalState) const { FactorType factorType = GetFactorType(); @@ -219,14 +344,16 @@ LMResult LanguageModelIRST::GetValue(const vector<const Word*> &contextFactor, S if (count < (size_t) m_lmtb_size) codes[idx++] = m_lmtb_sentenceStart; for (size_t i = 0 ; i < count ; i++) { - codes[idx++] = GetLmID((*contextFactor[i])[factorType]); + //codes[idx] = GetLmID((*contextFactor[i])[factorType]); + codes[idx] = GetLmID(*contextFactor[i]); + ++idx; } + LMResult result; result.unknown = (codes[idx - 1] == m_unknownId); char* msp = NULL; - unsigned int ilen; - result.score = m_lmtb->clprob(codes,idx,NULL,NULL,&msp,&ilen); + result.score = m_lmtb->clprob(codes,idx,NULL,NULL,&msp); if (finalState) 
*finalState=(State *) msp; @@ -235,7 +362,7 @@ LMResult LanguageModelIRST::GetValue(const vector<const Word*> &contextFactor, S } -bool LMCacheCleanup(size_t sentences_done, size_t m_lmcache_cleanup_threshold) +bool LMCacheCleanup(const int sentences_done, const size_t m_lmcache_cleanup_threshold) { if (sentences_done==-1) return true; if (m_lmcache_cleanup_threshold) @@ -266,5 +393,15 @@ void LanguageModelIRST::CleanUpAfterSentenceProcessing(const InputType& source) } } +void LanguageModelIRST::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "dub") { + m_lmtb_dub = Scan<int>(value); + } else { + LanguageModelSingleFactor::SetParameter(key, value); + } + m_lmtb_size = m_nGramOrder; +} + } diff --git a/moses/LM/IRST.h b/moses/LM/IRST.h index 9d46fc759..35235160b 100644 --- a/moses/LM/IRST.h +++ b/moses/LM/IRST.h @@ -24,10 +24,13 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA #include <string> #include <vector> + #include "moses/Factor.h" +#include "moses/LM/SingleFactor.h" +#include "moses/Hypothesis.h" #include "moses/TypeDef.h" + #include "moses/Util.h" -#include "SingleFactor.h" //this is required because: //- IRSTLM package uses the namespace irstlm @@ -44,6 +47,9 @@ class dictionary; namespace Moses { + +//class LanguageModel; +class FFState; class Phrase; /** Implementation of single factor LM using IRST's code. 
@@ -59,31 +65,47 @@ protected: int m_empty; //code of an empty position int m_lmtb_sentenceStart; //lmtb symbols to initialize ngram with int m_lmtb_sentenceEnd; //lmt symbol to initialize ngram with - int m_lmtb_size; //max ngram stored in the table int m_lmtb_dub; //dictionary upperboud + int m_lmtb_size; //max ngram stored in the table + + dictionary* d; std::string m_mapFilePath; void CreateFactors(FactorCollection &factorCollection); + + int GetLmID( const Word &word ) const; int GetLmID( const std::string &str ) const; int GetLmID( const Factor *factor ) const; - dictionary* d; public: LanguageModelIRST(const std::string &line); ~LanguageModelIRST(); + + void SetParameter(const std::string& key, const std::string& value); + void Load(); virtual LMResult GetValue(const std::vector<const Word*> &contextFactor, State* finalState = NULL) const; + + virtual FFState *EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const; + + virtual void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const; + +/* + virtual FFState *EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection *accumulator) const; + + virtual FFState *EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const; +*/ + void InitializeForInput(InputType const& source); void CleanUpAfterSentenceProcessing(const InputType& source); void set_dictionary_upperbound(int dub) { m_lmtb_size=dub ; -//m_lmtb->set_dictionary_upperbound(dub); }; };