2010-09-22 02:43:29 +04:00
// $Id$
/***********************************************************************
Moses - factored phrase - based language decoder
Copyright ( C ) 2006 University of Edinburgh
This library is free software ; you can redistribute it and / or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation ; either
version 2.1 of the License , or ( at your option ) any later version .
This library is distributed in the hope that it will be useful ,
but WITHOUT ANY WARRANTY ; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE . See the GNU
Lesser General Public License for more details .
You should have received a copy of the GNU Lesser General Public
License along with this library ; if not , write to the Free Software
Foundation , Inc . , 51 Franklin Street , Fifth Floor , Boston , MA 02110 - 1301 USA
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2010-10-27 21:50:40 +04:00
#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdlib.h>
2010-10-28 05:05:04 +04:00
# include "lm/binary_format.hh"
2010-10-27 21:50:40 +04:00
# include "lm/enumerate_vocab.hh"
2011-10-13 16:33:05 +04:00
# include "lm/left.hh"
2010-11-06 03:40:16 +03:00
# include "lm/model.hh"
2010-09-22 02:43:29 +04:00
2011-10-13 18:27:01 +04:00
# include "LM/Ken.h"
# include "LM/Base.h"
2010-10-28 05:05:04 +04:00
# include "FFState.h"
2010-09-22 02:43:29 +04:00
# include "TypeDef.h"
# include "Util.h"
# include "FactorCollection.h"
# include "Phrase.h"
# include "InputFileStream.h"
# include "StaticData.h"
2011-09-21 20:06:48 +04:00
# include "ChartHypothesis.h"
2010-09-22 02:43:29 +04:00
2011-10-17 13:30:30 +04:00
# include <boost/shared_ptr.hpp>
2011-10-13 16:33:05 +04:00
2010-09-22 02:43:29 +04:00
using namespace std ;
2011-10-12 23:49:27 +04:00
namespace Moses {
namespace {
2010-11-06 03:40:16 +03:00
struct KenLMState : public FFState {
2011-02-24 16:14:42 +03:00
lm : : ngram : : State state ;
int Compare ( const FFState & o ) const {
const KenLMState & other = static_cast < const KenLMState & > ( o ) ;
2011-09-21 20:06:48 +04:00
if ( state . length < other . state . length ) return - 1 ;
if ( state . length > other . state . length ) return 1 ;
return std : : memcmp ( state . words , other . state . words , sizeof ( lm : : WordIndex ) * state . length ) ;
2011-02-24 16:14:42 +03:00
}
2010-11-06 03:40:16 +03:00
} ;
2011-09-21 20:06:48 +04:00
/*
* An implementation of single factor LM using Ken ' s code .
*/
2011-10-13 16:33:05 +04:00
template < class Model > class LanguageModelKen : public LanguageModel {
2011-10-12 23:49:27 +04:00
public :
2011-10-13 16:33:05 +04:00
LanguageModelKen ( const std : : string & file , ScoreIndexManager & manager , FactorType factorType , bool lazy ) ;
2011-08-24 14:45:41 +04:00
2011-10-13 16:33:05 +04:00
LanguageModel * Duplicate ( ScoreIndexManager & scoreIndexManager ) const ;
2010-10-28 05:05:04 +04:00
2011-10-13 16:33:05 +04:00
bool Useable ( const Phrase & phrase ) const {
return ( phrase . GetSize ( ) > 0 & & phrase . GetFactor ( 0 , m_factorType ) ! = NULL ) ;
}
2010-10-28 05:05:04 +04:00
2011-10-13 16:33:05 +04:00
std : : string GetScoreProducerDescription ( unsigned ) const {
std : : ostringstream oss ;
2012-01-13 19:07:30 +04:00
oss < < " LM_ " < < ( unsigned ) m_ngram - > Order ( ) < < " gram " ;
2011-10-13 16:33:05 +04:00
return oss . str ( ) ;
}
2010-10-28 05:05:04 +04:00
2011-10-13 16:33:05 +04:00
const FFState * EmptyHypothesisState ( const InputType & /*input*/ ) const {
KenLMState * ret = new KenLMState ( ) ;
ret - > state = m_ngram - > BeginSentenceState ( ) ;
return ret ;
}
2010-10-28 05:05:04 +04:00
2011-10-13 16:33:05 +04:00
void CalcScore ( const Phrase & phrase , float & fullScore , float & ngramScore , size_t & oovCount ) const ;
2011-10-12 17:04:12 +04:00
2011-10-13 16:33:05 +04:00
FFState * Evaluate ( const Hypothesis & hypo , const FFState * ps , ScoreComponentCollection * out ) const ;
2010-10-28 05:05:04 +04:00
2011-10-13 16:33:05 +04:00
FFState * EvaluateChart ( const ChartHypothesis & cur_hypo , int featureID , ScoreComponentCollection * accumulator ) const ;
2011-09-21 20:06:48 +04:00
2011-10-12 23:49:27 +04:00
private :
2011-10-13 16:33:05 +04:00
LanguageModelKen ( ScoreIndexManager & manager , const LanguageModelKen < Model > & copy_from ) ;
2011-09-21 20:06:48 +04:00
2011-10-12 23:49:27 +04:00
lm : : WordIndex TranslateID ( const Word & word ) const {
2011-10-13 16:33:05 +04:00
std : : size_t factor = word . GetFactor ( m_factorType ) - > GetId ( ) ;
2011-10-12 23:49:27 +04:00
return ( factor > = m_lmIdLookup . size ( ) ? 0 : m_lmIdLookup [ factor ] ) ;
2011-09-21 20:06:48 +04:00
}
2011-10-12 17:04:12 +04:00
2011-10-12 23:49:27 +04:00
// Convert last words of hypothesis into vocab ids, returning an end pointer.
lm : : WordIndex * LastIDs ( const Hypothesis & hypo , lm : : WordIndex * indices ) const {
lm : : WordIndex * index = indices ;
lm : : WordIndex * end = indices + m_ngram - > Order ( ) - 1 ;
int position = hypo . GetCurrTargetWordsRange ( ) . GetEndPos ( ) ;
for ( ; ; + + index , - - position ) {
2011-12-21 08:12:35 +04:00
if ( index = = end ) return index ;
2011-10-12 23:49:27 +04:00
if ( position = = - 1 ) {
* index = m_ngram - > GetVocabulary ( ) . BeginSentence ( ) ;
return index + 1 ;
}
* index = TranslateID ( hypo . GetWord ( position ) ) ;
}
}
2010-10-27 21:50:40 +04:00
2011-10-13 16:33:05 +04:00
boost : : shared_ptr < Model > m_ngram ;
2011-10-17 13:30:30 +04:00
2011-10-13 16:33:05 +04:00
std : : vector < lm : : WordIndex > m_lmIdLookup ;
2010-09-22 02:43:29 +04:00
2011-10-13 16:33:05 +04:00
FactorType m_factorType ;
const Factor * m_beginSentenceFactor ;
} ;
2010-09-22 02:43:29 +04:00
2011-10-12 23:49:27 +04:00
class MappingBuilder : public lm : : EnumerateVocab {
public :
MappingBuilder ( FactorCollection & factorCollection , std : : vector < lm : : WordIndex > & mapping )
: m_factorCollection ( factorCollection ) , m_mapping ( mapping ) { }
void Add ( lm : : WordIndex index , const StringPiece & str ) {
2011-10-13 17:32:14 +04:00
std : : size_t factorId = m_factorCollection . AddFactor ( str ) - > GetId ( ) ;
2011-10-12 23:49:27 +04:00
if ( m_mapping . size ( ) < = factorId ) {
// 0 is <unk> :-)
m_mapping . resize ( factorId + 1 ) ;
}
m_mapping [ factorId ] = index ;
}
private :
FactorCollection & m_factorCollection ;
std : : vector < lm : : WordIndex > & m_mapping ;
} ;
2011-10-13 16:33:05 +04:00
// Loads the kenlm model from file, filling m_lmIdLookup as kenlm enumerates
// its vocabulary, then caches the BOS_ factor and calls Init(manager).
template <class Model> LanguageModelKen<Model>::LanguageModelKen(const std::string &file, ScoreIndexManager &manager, FactorType factorType, bool lazy) : m_factorType(factorType) {
  FactorCollection &collection = FactorCollection::Instance();
  MappingBuilder builder(collection, m_lmIdLookup);

  lm::ngram::Config config;
  // kenlm progress/status messages go to stderr only when verbose.
  IFVERBOSE(1) {
    config.messages = &std::cerr;
  } else {
    config.messages = NULL;
  }
  config.enumerate_vocab = &builder;
  if (lazy) {
    config.load_method = util::LAZY;
  } else {
    config.load_method = util::POPULATE_OR_READ;
  }

  m_ngram.reset(new Model(file.c_str(), config));

  m_beginSentenceFactor = collection.AddFactor(BOS_);
  Init(manager);
}
2011-10-13 16:33:05 +04:00
// Clones this LM for another ScoreIndexManager via the private copy constructor.
template <class Model> LanguageModel *LanguageModelKen<Model>::Duplicate(ScoreIndexManager &manager) const {
  LanguageModelKen<Model> *clone = new LanguageModelKen<Model>(manager, *this);
  return clone;
}
2010-09-22 02:43:29 +04:00
2011-10-13 16:33:05 +04:00
// Copy constructor behind Duplicate: the loaded model is shared through the
// shared_ptr, while the vocabulary lookup table is copied.
template <class Model> LanguageModelKen<Model>::LanguageModelKen(ScoreIndexManager &manager, const LanguageModelKen<Model> &copy_from)
  : m_ngram(copy_from.m_ngram),
    // TODO: don't copy this.
    m_lmIdLookup(copy_from.m_lmIdLookup),
    m_factorType(copy_from.m_factorType),
    m_beginSentenceFactor(copy_from.m_beginSentenceFactor) {
  Init(manager);
}
2011-10-13 16:33:05 +04:00
// Scores a phrase in isolation.
//   fullScore  - sum of every word's transformed LM score.
//   ngramScore - only words whose full Order()-gram context lies inside the
//                phrase (left context not cut off by the phrase start or by a
//                non-terminal).
//   oovCount   - number of words whose kenlm id is 0.
// A leading BOS_ factor is used as <s> context and not scored; <s> anywhere
// else aborts the process.
template <class Model> void LanguageModelKen<Model>::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const {
  fullScore = 0;
  ngramScore = 0;
  oovCount = 0;

  const size_t length = phrase.GetSize();
  if (length == 0) return;

  // Two states used alternately: Score reads one and writes its successor into the other.
  typename Model::State buffers[2];
  typename Model::State *current = &buffers[0];
  typename Model::State *next = &buffers[1];

  size_t pos = 0;
  if (phrase.GetWord(0).GetFactor(m_factorType) == m_beginSentenceFactor) {
    *current = m_ngram->BeginSentenceState();
    pos = 1;
  } else {
    *current = m_ngram->NullContextState();
  }

  // First position whose n-gram is complete.
  size_t firstComplete = m_ngram->Order() - 1;

  for (; pos < length; ++pos) {
    const Word &word = phrase.GetWord(pos);

    if (word.IsNonTerminal()) {
      // If there's a non-terminal at 1 and we have a 5-gram LM, then positions 2 3 4 and 5 will be incomplete while position 6 is complete.
      firstComplete = m_ngram->Order() + pos;
      *current = m_ngram->NullContextState();
      continue;
    }

    const lm::WordIndex id = TranslateID(word);
    if (id == m_ngram->GetVocabulary().BeginSentence()) {
      std::cerr << "Either your data contains <s> in a position other than the first word or your language model is missing <s>. Did you build your ARPA using IRSTLM and forget to run add-start-end.sh?" << std::endl;
      abort();
    }

    const float wordScore = TransformLMScore(m_ngram->Score(*current, id, *next));
    std::swap(current, next);
    fullScore += wordScore;
    if (pos >= firstComplete) ngramScore += wordScore;
    if (!id) ++oovCount;
  }
}
2011-10-13 16:33:05 +04:00
// Phrase-based scoring of the words a hypothesis just appended.
// Scores the new words starting from the previous hypothesis' state ps.
// Once the source is fully covered, it also scores </s>. The sum is
// transformed and added to out; if the OOV feature is enabled, a second
// component of 0 is added as well.
// Returns a newly allocated KenLMState owned by the caller.
//
// The output state is built in a local and heap-allocated only at the end,
// so nothing can leak if scoring throws. This also avoids std::auto_ptr,
// which is deprecated in C++11 and removed in C++17.
template <class Model> FFState *LanguageModelKen<Model>::Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const {
  const lm::ngram::State &in_state = static_cast<const KenLMState&>(*ps).state;

  // No new target words: the LM context is unchanged.
  if (!hypo.GetCurrTargetLength()) {
    KenLMState *unchanged = new KenLMState();
    unchanged->state = in_state;
    return unchanged;
  }

  const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos();
  //[begin, end) in STL-like fashion.
  const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
  // Only the first Order()-1 new words can see context from in_state.
  const std::size_t adjust_end = std::min(end, begin + m_ngram->Order() - 1);

  lm::ngram::State out_state;
  std::size_t position = begin;
  typename Model::State aux_state;
  typename Model::State *state0 = &out_state, *state1 = &aux_state;
  float score = m_ngram->Score(in_state, TranslateID(hypo.GetWord(position)), *state0);
  ++position;
  for (; position < adjust_end; ++position) {
    score += m_ngram->Score(*state0, TranslateID(hypo.GetWord(position)), *state1);
    std::swap(state0, state1);
  }

  if (hypo.IsSourceCompleted()) {
    // Score end of sentence.
    std::vector<lm::WordIndex> indices(m_ngram->Order() - 1);
    const lm::WordIndex *last = LastIDs(hypo, &indices.front());
    score += m_ngram->FullScoreForgotState(&indices.front(), last, m_ngram->GetVocabulary().EndSentence(), out_state).prob;
  } else if (adjust_end < end) {
    // Get state after adding a long phrase.
    std::vector<lm::WordIndex> indices(m_ngram->Order() - 1);
    const lm::WordIndex *last = LastIDs(hypo, &indices.front());
    m_ngram->GetState(&indices.front(), last, out_state);
  } else if (state0 != &out_state) {
    // Short enough phrase that we can just reuse the state.
    out_state = *state0;
  }

  score = TransformLMScore(score);
  if (OOVFeatureEnabled()) {
    std::vector<float> scores(2);
    scores[0] = score;
    scores[1] = 0.0;
    out->PlusEquals(this, scores);
  } else {
    out->PlusEquals(this, score);
  }

  KenLMState *ret = new KenLMState();
  ret->state = out_state;
  return ret;
}
// Chart-decoding LM state: wraps the kenlm ChartState that RuleScore fills in.
class LanguageModelChartStateKenLM : public FFState {
  private:
    lm::ngram::ChartState m_state;

  public:
    LanguageModelChartStateKenLM() {}

    const lm::ngram::ChartState &GetChartState() const { return m_state; }
    lm::ngram::ChartState &GetChartState() { return m_state; }

    // Ordering is delegated entirely to ChartState::Compare.
    int Compare(const FFState &o) const {
      return m_state.Compare(static_cast<const LanguageModelChartStateKenLM&>(o).m_state);
    }
};
2011-10-13 16:33:05 +04:00
// Chart scoring of a rule's target side. Terminals are scored directly. Each
// non-terminal reuses the chart state and LM score stored on the matching
// previous hypothesis. A leading BOS_ factor starts the rule with <s> context.
// The result is assigned to accumulator; the returned state is newly allocated.
template <class Model> FFState *LanguageModelKen<Model>::EvaluateChart(const ChartHypothesis &hypo, int featureID, ScoreComponentCollection *accumulator) const {
  LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM();
  // ruleScore writes its result directly into newState's chart state.
  lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState->GetChartState());
  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = hypo.GetCurrTargetPhrase().GetAlignmentInfo().GetNonTermIndexMap();

  const size_t wordCount = hypo.GetCurrTargetPhrase().GetSize();
  for (size_t pos = 0; pos < wordCount; ++pos) {
    const Word &word = hypo.GetCurrTargetPhrase().GetWord(pos);

    if (pos == 0 && word.GetFactor(m_factorType) == m_beginSentenceFactor) {
      // Begin of sentence.
      ruleScore.BeginSentence();
    } else if (word.IsNonTerminal()) {
      const ChartHypothesis *child = hypo.GetPrevHypo(nonTermIndexMap[pos]);
      const lm::ngram::ChartState &childState = static_cast<const LanguageModelChartStateKenLM*>(child->GetFFState(featureID))->GetChartState();
      if (pos == 0) {
        // Non-terminal is first so we can copy instead of rescoring.
        ruleScore.BeginNonTerminal(childState, child->GetScoreBreakdown().GetScoresForProducer(this)[0]);
      } else {
        ruleScore.NonTerminal(childState, child->GetScoreBreakdown().GetScoresForProducer(this)[0]);
      }
    } else {
      ruleScore.Terminal(TranslateID(word));
    }
  }

  accumulator->Assign(this, ruleScore.Finish());
  return newState;
}
2010-10-28 05:05:04 +04:00
} // namespace
2011-10-13 16:33:05 +04:00
// Factory: inspects file and instantiates LanguageModelKen with the matching
// kenlm model class.
// Any std::exception raised while loading is printed to stderr, after which
// the process aborts. An unrecognised binary model type also aborts.
LanguageModel *ConstructKenLM(const std::string &file, ScoreIndexManager &manager, FactorType factorType, bool lazy) {
  try {
    lm::ngram::ModelType model_type;
    if (lm::ngram::RecognizeBinary(file.c_str(), model_type)) {
      switch (model_type) {
        case lm::ngram::PROBING:
          return new LanguageModelKen<lm::ngram::ProbingModel>(file, manager, factorType, lazy);
        case lm::ngram::REST_PROBING:
          return new LanguageModelKen<lm::ngram::RestProbingModel>(file, manager, factorType, lazy);
        case lm::ngram::TRIE:
          return new LanguageModelKen<lm::ngram::TrieModel>(file, manager, factorType, lazy);
        case lm::ngram::QUANT_TRIE:
          return new LanguageModelKen<lm::ngram::QuantTrieModel>(file, manager, factorType, lazy);
        case lm::ngram::ARRAY_TRIE:
          return new LanguageModelKen<lm::ngram::ArrayTrieModel>(file, manager, factorType, lazy);
        case lm::ngram::QUANT_ARRAY_TRIE:
          return new LanguageModelKen<lm::ngram::QuantArrayTrieModel>(file, manager, factorType, lazy);
        default:
          std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
          abort();
      }
    } else {
      // Not a kenlm binary (presumably an ARPA text file): load it into the
      // default probing structure.
      return new LanguageModelKen<lm::ngram::ProbingModel>(file, manager, factorType, lazy);
    }
  } catch (const std::exception &e) {
    // Catch by const reference: the handler only reads what().
    std::cerr << e.what() << std::endl;
    abort();
  }
}
2010-09-22 02:43:29 +04:00
}