2010-11-17 17:06:21 +03:00
// $Id$
/***********************************************************************
Moses - factored phrase - based language decoder
Copyright ( C ) 2006 University of Edinburgh
This library is free software ; you can redistribute it and / or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation ; either
version 2.1 of the License , or ( at your option ) any later version .
This library is distributed in the hope that it will be useful ,
but WITHOUT ANY WARRANTY ; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE . See the GNU
Lesser General Public License for more details .
You should have received a copy of the GNU Lesser General Public
License along with this library ; if not , write to the Free Software
Foundation , Inc . , 51 Franklin Street , Fifth Floor , Boston , MA 02110 - 1301 USA
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2011-11-18 16:07:41 +04:00
# include "util/check.hh"
2010-11-17 17:06:21 +03:00
# include <limits>
# include <iostream>
# include <memory>
# include <sstream>
# include "FFState.h"
2011-10-13 18:27:01 +04:00
# include "LM/Implementation.h"
2010-11-17 17:06:21 +03:00
# include "TypeDef.h"
# include "Util.h"
# include "Manager.h"
# include "FactorCollection.h"
# include "Phrase.h"
# include "StaticData.h"
2011-10-10 15:15:13 +04:00
# include "ChartManager.h"
2011-09-21 20:06:48 +04:00
# include "ChartHypothesis.h"
2010-11-17 17:06:21 +03:00
using namespace std ;
namespace Moses
{
2011-09-21 20:06:48 +04:00
void LanguageModelImplementation : : ShiftOrPush ( std : : vector < const Word * > & contextFactor , const Word & word ) const
{
if ( contextFactor . size ( ) < GetNGramOrder ( ) ) {
contextFactor . push_back ( & word ) ;
} else {
// shift
for ( size_t currNGramOrder = 0 ; currNGramOrder < GetNGramOrder ( ) - 1 ; currNGramOrder + + ) {
contextFactor [ currNGramOrder ] = contextFactor [ currNGramOrder + 1 ] ;
}
contextFactor [ GetNGramOrder ( ) - 1 ] = & word ;
}
}
2011-03-08 02:21:09 +03:00
/** Score the last word of `contextFactor` given the words before it.
 * This base implementation makes no use of the information already held in
 * `state`: it forwards both arguments to GetValueForgotState(), which
 * receives `state` by non-const reference and may update it.
 */
LMResult LanguageModelImplementation::GetValueGivenState(
  const std::vector<const Word*> &contextFactor,
  FFState &state) const
{
  const LMResult result = GetValueForgotState(contextFactor, state);
  return result;
}
void LanguageModelImplementation : : GetState (
2011-02-24 16:14:42 +03:00
const std : : vector < const Word * > & contextFactor ,
FFState & state ) const
2010-11-17 17:06:21 +03:00
{
2011-02-24 16:14:42 +03:00
GetValueForgotState ( contextFactor , state ) ;
2010-11-17 17:06:21 +03:00
}
2011-02-24 16:14:42 +03:00
2011-10-11 17:50:44 +04:00
// Calculate score of a phrase.
void LanguageModelImplementation : : CalcScore ( const Phrase & phrase , float & fullScore , float & ngramScore , size_t & oovCount ) const {
fullScore = 0 ;
ngramScore = 0 ;
oovCount = 0 ;
size_t phraseSize = phrase . GetSize ( ) ;
if ( ! phraseSize ) return ;
vector < const Word * > contextFactor ;
contextFactor . reserve ( GetNGramOrder ( ) ) ;
std : : auto_ptr < FFState > state ( NewState ( ( phrase . GetWord ( 0 ) = = GetSentenceStartArray ( ) ) ?
GetBeginSentenceState ( ) : GetNullContextState ( ) ) ) ;
size_t currPos = 0 ;
while ( currPos < phraseSize ) {
const Word & word = phrase . GetWord ( currPos ) ;
if ( word . IsNonTerminal ( ) ) {
// do nothing. reset ngram. needed to score target phrases during pt loading in chart decoding
if ( ! contextFactor . empty ( ) ) {
// TODO: state operator= ?
state . reset ( NewState ( GetNullContextState ( ) ) ) ;
contextFactor . clear ( ) ;
}
} else {
ShiftOrPush ( contextFactor , word ) ;
2011-11-18 16:07:41 +04:00
CHECK ( contextFactor . size ( ) < = GetNGramOrder ( ) ) ;
2011-10-11 17:50:44 +04:00
if ( word = = GetSentenceStartArray ( ) ) {
// do nothing, don't include prob for <s> unigram
2011-11-01 14:24:40 +04:00
if ( currPos ! = 0 ) {
2011-12-01 17:21:55 +04:00
std : : cerr < < " Either your data contains <s> in a position other than the first word or your language model is missing <s>. Did you build your ARPA using IRSTLM and forget to run add-start-end.sh? " < < std : : endl ;
2011-11-01 14:24:40 +04:00
abort ( ) ;
}
2011-10-11 17:50:44 +04:00
} else {
LMResult result = GetValueGivenState ( contextFactor , * state ) ;
fullScore + = result . score ;
if ( contextFactor . size ( ) = = GetNGramOrder ( ) )
ngramScore + = result . score ;
2011-10-12 14:22:45 +04:00
if ( result . unknown ) + + oovCount ;
2011-10-11 17:50:44 +04:00
}
}
currPos + + ;
}
}
/** Phrase-based LM scoring of a hypothesis extension.
 *
 * Only n-grams that overlap the boundary between the previous hypothesis and
 * the newly added target phrase are scored here, plus the </s> n-gram once
 * the source is fully covered. Phrase-internal scores are taken directly from
 * the translation option.
 *
 * \param hypo    the hypothesis being scored
 * \param ps      LM state of the previous hypothesis (may be NULL)
 * \param out     receives the LM score via PlusEquals
 * \param feature the LM feature; if its OOV feature is enabled, two score
 *                components {lmScore, 0} are added, otherwise one
 * \return a newly allocated state for `hypo`. Returns NULL for LMs of order
 *         <= 1, or when the added phrase is empty and `ps` is NULL.
 */
FFState* LanguageModelImplementation::Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out, const LanguageModel *feature) const {
  const size_t order = GetNGramOrder();

  // In the case of unigram language models, there is no overlap, so we don't
  // need to do anything.
  if (order <= 1)
    return NULL;

  clock_t t = 0;
  IFVERBOSE(2) {
    t = clock(); // track time
  }

  // Empty phrase added? nothing to be done
  if (hypo.GetCurrTargetLength() == 0)
    return ps ? NewState(ps) : NULL;

  const size_t currEndPos = hypo.GetCurrTargetWordsRange().GetEndPos();
  const size_t startPos = hypo.GetCurrTargetWordsRange().GetStartPos();

  // 1st n-gram: the order-1 words before the new phrase (padded with <s>
  // before the start of the sentence) followed by its first word
  vector<const Word*> contextFactor(order);
  size_t index = 0;
  for (int currPos = static_cast<int>(startPos) - static_cast<int>(order) + 1;
       currPos <= static_cast<int>(startPos); currPos++) {
    if (currPos >= 0)
      contextFactor[index++] = &hypo.GetWord(currPos);
    else
      contextFactor[index++] = &GetSentenceStartArray();
  }
  FFState *res = NewState(ps);
  float lmScore = ps ? GetValueGivenState(contextFactor, *res).score
                     : GetValueForgotState(contextFactor, *res).score;

  // main loop: remaining n-grams that still reach back across the boundary
  const size_t endPos = std::min(startPos + order - 2, currEndPos);
  for (size_t currPos = startPos + 1; currPos <= endPos; currPos++) {
    // shift all args down 1 place
    for (size_t i = 0; i < order - 1; i++)
      contextFactor[i] = contextFactor[i + 1];

    // add last factor
    contextFactor.back() = &hypo.GetWord(currPos);

    lmScore += GetValueGivenState(contextFactor, *res).score;
  }

  if (hypo.IsSourceCompleted()) {
    // end of sentence: score </s> after the last order-1 words of the hypothesis
    const int size = static_cast<int>(hypo.GetSize());
    contextFactor.back() = &GetSentenceEndArray();

    for (size_t i = 0; i < order - 1; i++) {
      // Signed arithmetic: the position is negative when the hypothesis is
      // shorter than order-1 words. Computing it in size_t (as before) wrapped
      // around and relied on an implementation-defined conversion to int.
      const int currPos = size - static_cast<int>(order) + static_cast<int>(i) + 1;
      if (currPos < 0)
        contextFactor[i] = &GetSentenceStartArray();
      else
        contextFactor[i] = &hypo.GetWord(static_cast<size_t>(currPos));
    }
    lmScore += GetValueForgotState(contextFactor, *res).score;
  } else {
    if (endPos < currEndPos) {
      // need to get the LM state (otherwise the last LM state is fine)
      for (size_t currPos = endPos + 1; currPos <= currEndPos; currPos++) {
        for (size_t i = 0; i < order - 1; i++)
          contextFactor[i] = contextFactor[i + 1];
        contextFactor.back() = &hypo.GetWord(currPos);
      }
      GetState(contextFactor, *res);
    }
  }

  if (feature->OOVFeatureEnabled()) {
    vector<float> scores(2);
    scores[0] = lmScore;
    scores[1] = 0;
    out->PlusEquals(feature, scores);
  } else {
    out->PlusEquals(feature, lmScore);
  }

  IFVERBOSE(2) {
    hypo.GetManager().GetSentenceStats().AddTimeCalcLM( clock()-t );
  }
  return res;
}
2011-10-10 15:15:13 +04:00
namespace {
// This is the FFState used by LanguageModelImplementation::EvaluateChart.
2011-10-10 20:25:56 +04:00
// Though svn blame goes back to heafield, don't blame me. I just moved this from LanguageModelChartState.cpp and ChartHypothesis.cpp.
2011-10-10 15:15:13 +04:00
class LanguageModelChartState : public FFState
{
private :
float m_prefixScore ;
FFState * m_lmRightContext ;
2011-10-10 20:25:56 +04:00
Phrase m_contextPrefix , m_contextSuffix ;
size_t m_numTargetTerminals ; // This isn't really correct except for the surviving hypothesis
const ChartHypothesis & m_hypo ;
/** Construct the prefix string of up to specified size
* \ param ret prefix string
* \ param size maximum size ( typically max lm context window )
*/
size_t CalcPrefix ( const ChartHypothesis & hypo , int featureID , Phrase & ret , size_t size ) const
{
const TargetPhrase & target = hypo . GetCurrTargetPhrase ( ) ;
const AlignmentInfo : : NonTermIndexMap & nonTermIndexMap =
2012-09-27 01:49:33 +04:00
target . GetAlignmentInfo ( ) . GetNonTermIndexMap ( ) ;
2012-09-22 19:09:34 +04:00
2011-10-10 20:25:56 +04:00
// loop over the rule that is being applied
for ( size_t pos = 0 ; pos < target . GetSize ( ) ; + + pos ) {
const Word & word = target . GetWord ( pos ) ;
// for non-terminals, retrieve it from underlying hypothesis
if ( word . IsNonTerminal ( ) ) {
size_t nonTermInd = nonTermIndexMap [ pos ] ;
const ChartHypothesis * prevHypo = hypo . GetPrevHypo ( nonTermInd ) ;
size = static_cast < const LanguageModelChartState * > ( prevHypo - > GetFFState ( featureID ) ) - > CalcPrefix ( * prevHypo , featureID , ret , size ) ;
}
// for words, add word
else {
ret . AddWord ( target . GetWord ( pos ) ) ;
size - - ;
}
// finish when maximum length reached
if ( size = = 0 )
break ;
}
return size ;
}
/** Construct the suffix phrase of up to specified size
* will always be called after the construction of prefix phrase
* \ param ret suffix phrase
* \ param size maximum size of suffix
*/
size_t CalcSuffix ( const ChartHypothesis & hypo , int featureID , Phrase & ret , size_t size ) const
{
2011-11-18 16:07:41 +04:00
CHECK ( m_contextPrefix . GetSize ( ) < = m_numTargetTerminals ) ;
2011-10-10 20:25:56 +04:00
// special handling for small hypotheses
// does the prefix match the entire hypothesis string? -> just copy prefix
if ( m_contextPrefix . GetSize ( ) = = m_numTargetTerminals ) {
size_t maxCount = std : : min ( m_contextPrefix . GetSize ( ) , size ) ;
size_t pos = m_contextPrefix . GetSize ( ) - 1 ;
for ( size_t ind = 0 ; ind < maxCount ; + + ind ) {
const Word & word = m_contextPrefix . GetWord ( pos ) ;
ret . PrependWord ( word ) ;
- - pos ;
}
size - = maxCount ;
return size ;
}
// construct suffix analogous to prefix
else {
2012-09-22 19:09:34 +04:00
const TargetPhrase & target = hypo . GetCurrTargetPhrase ( ) ;
2011-10-10 20:25:56 +04:00
const AlignmentInfo : : NonTermIndexMap & nonTermIndexMap =
2012-09-27 01:49:33 +04:00
target . GetAlignmentInfo ( ) . GetNonTermIndexMap ( ) ;
2012-09-22 19:09:34 +04:00
for ( int pos = ( int ) target . GetSize ( ) - 1 ; pos > = 0 ; - - pos ) {
const Word & word = target . GetWord ( pos ) ;
2011-10-10 20:25:56 +04:00
if ( word . IsNonTerminal ( ) ) {
size_t nonTermInd = nonTermIndexMap [ pos ] ;
const ChartHypothesis * prevHypo = hypo . GetPrevHypo ( nonTermInd ) ;
size = static_cast < const LanguageModelChartState * > ( prevHypo - > GetFFState ( featureID ) ) - > CalcSuffix ( * prevHypo , featureID , ret , size ) ;
}
else {
ret . PrependWord ( hypo . GetCurrTargetPhrase ( ) . GetWord ( pos ) ) ;
size - - ;
}
if ( size = = 0 )
break ;
}
return size ;
}
}
2011-10-10 15:15:13 +04:00
public :
2011-10-10 20:25:56 +04:00
LanguageModelChartState ( const ChartHypothesis & hypo , int featureID , size_t order )
: m_lmRightContext ( NULL )
2011-11-21 14:49:26 +04:00
, m_contextPrefix ( order - 1 )
, m_contextSuffix ( order - 1 )
2011-10-10 20:25:56 +04:00
, m_hypo ( hypo )
{
m_numTargetTerminals = hypo . GetCurrTargetPhrase ( ) . GetNumTerminals ( ) ;
for ( std : : vector < const ChartHypothesis * > : : const_iterator i = hypo . GetPrevHypos ( ) . begin ( ) ; i ! = hypo . GetPrevHypos ( ) . end ( ) ; + + i ) {
// keep count of words (= length of generated string)
m_numTargetTerminals + = static_cast < const LanguageModelChartState * > ( ( * i ) - > GetFFState ( featureID ) ) - > GetNumTargetTerminals ( ) ;
}
CalcPrefix ( hypo , featureID , m_contextPrefix , order - 1 ) ;
CalcSuffix ( hypo , featureID , m_contextSuffix , order - 1 ) ;
}
2011-10-10 15:15:13 +04:00
~ LanguageModelChartState ( ) {
delete m_lmRightContext ;
}
2011-10-10 20:25:56 +04:00
void Set ( float prefixScore , FFState * rightState ) {
m_prefixScore = prefixScore ;
m_lmRightContext = rightState ;
}
2011-10-10 15:15:13 +04:00
float GetPrefixScore ( ) const { return m_prefixScore ; }
FFState * GetRightContext ( ) const { return m_lmRightContext ; }
2011-10-10 20:25:56 +04:00
size_t GetNumTargetTerminals ( ) const {
return m_numTargetTerminals ;
}
const Phrase & GetPrefix ( ) const {
return m_contextPrefix ;
}
const Phrase & GetSuffix ( ) const {
return m_contextSuffix ;
}
2011-10-10 15:15:13 +04:00
int Compare ( const FFState & o ) const {
const LanguageModelChartState & other =
dynamic_cast < const LanguageModelChartState & > ( o ) ;
// prefix
2011-10-10 20:25:56 +04:00
if ( m_hypo . GetCurrSourceRange ( ) . GetStartPos ( ) > 0 ) // not for "<s> ..."
2011-10-10 15:15:13 +04:00
{
2011-10-10 20:25:56 +04:00
int ret = GetPrefix ( ) . Compare ( other . GetPrefix ( ) ) ;
2011-10-10 15:15:13 +04:00
if ( ret ! = 0 )
return ret ;
}
// suffix
2011-10-10 20:25:56 +04:00
size_t inputSize = m_hypo . GetManager ( ) . GetSource ( ) . GetSize ( ) ;
if ( m_hypo . GetCurrSourceRange ( ) . GetEndPos ( ) < inputSize - 1 ) // not for "... </s>"
2011-10-10 15:15:13 +04:00
{
2011-10-10 20:25:56 +04:00
int ret = other . GetRightContext ( ) - > Compare ( * m_lmRightContext ) ;
2011-10-10 15:15:13 +04:00
if ( ret ! = 0 )
return ret ;
}
return 0 ;
}
} ;
} // namespace
2011-09-21 20:06:48 +04:00
/** Chart-decoding LM scoring of a hypothesis.
 *
 * Walks the target side of the applied rule. Terminals are scored directly.
 * For a non-terminal, the sub-hypothesis' finalized score is reused and only
 * its prefix words are rescored, because they now get more left context.
 * Scores of the first words of the resulting string (wordPos < order) are kept
 * in a separate prefix bucket, since those words still lack full context.
 *
 * \param hypo      the chart hypothesis being scored
 * \param featureID index of this LM's state in the sub-hypotheses
 * \param out       receives prefixScore + finalizedScore via Assign
 * \param scorer    the LM feature the score is assigned to
 * \return a new LanguageModelChartState for `hypo`; it owns the final LM state
 */
FFState* LanguageModelImplementation::EvaluateChart(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection* out, const LanguageModel *scorer) const {
  LanguageModelChartState *ret = new LanguageModelChartState(hypo, featureID, GetNGramOrder());

  // number of preceding words the LM conditions on
  const size_t historySize = GetNGramOrder() - 1;

  // data structure for factored context phrase (history and predicted word)
  vector<const Word*> contextFactor;
  contextFactor.reserve(GetNGramOrder());

  // initialize language model context state
  FFState *lmState = NewState( GetNullContextState() );

  // initial language model scores
  float prefixScore = 0.0;    // not yet final for initial words (lack context)
  float finalizedScore = 0.0; // finalized, has sufficient context

  // get index map for underlying hypotheses
  const TargetPhrase &target = hypo.GetCurrTargetPhrase();
  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
    target.GetAlignmentInfo().GetNonTermIndexMap();

  // loop over rule
  for (size_t phrasePos = 0, wordPos = 0; phrasePos < target.GetSize(); phrasePos++) {
    // consult rule for either word or non-terminal
    const Word &word = target.GetWord(phrasePos);

    if (!word.IsNonTerminal()) {
      // regular word
      ShiftOrPush(contextFactor, word);

      if (word == GetSentenceStartArray()) {
        // beginning of sentence symbol <s>? -> just update state
        CHECK(phrasePos == 0);
        delete lmState;
        lmState = NewState( GetBeginSentenceState() );
      } else {
        // score a regular word added by the rule
        updateChartScore( &prefixScore, &finalizedScore, UntransformLMScore(GetValueGivenState(contextFactor, *lmState).score), ++wordPos );
      }
      continue;
    }

    // non-terminal, add phrase from underlying hypothesis
    size_t nonTermIndex = nonTermIndexMap[phrasePos];
    const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndex);

    const LanguageModelChartState* prevState =
      static_cast<const LanguageModelChartState*>(prevHypo->GetFFState(featureID));
    const Phrase &prevSuffix = prevState->GetSuffix();
    size_t subPhraseLength = prevState->GetNumTargetTerminals();

    if (phrasePos == 0) {
      // special case: rule starts with non-terminal -> copy everything

      // get prefixScore and finalizedScore
      prefixScore = prevState->GetPrefixScore();
      finalizedScore = prevHypo->GetScoreBreakdown().GetScoresForProducer(scorer)[0] - prefixScore;

      // get language model state
      delete lmState;
      lmState = NewState( prevState->GetRightContext() );

      // push the last (at most historySize) suffix words. The start index is
      // clamped explicitly instead of relying on size_t underflow converted to int.
      size_t suffixPos = prevSuffix.GetSize() > historySize
                         ? prevSuffix.GetSize() - historySize
                         : 0;
      for (; suffixPos < prevSuffix.GetSize(); suffixPos++) {
        const Word &suffixWord = prevSuffix.GetWord(suffixPos);
        ShiftOrPush(contextFactor, suffixWord);
        wordPos++;
      }
    } else {
      // internal non-terminal

      // score its prefix: up to LM order window and up to its length
      for (size_t prefixPos = 0;
           prefixPos < historySize && prefixPos < subPhraseLength;
           prefixPos++) {
        const Word &prefixWord = prevState->GetPrefix().GetWord(prefixPos);
        ShiftOrPush(contextFactor, prefixWord);
        updateChartScore( &prefixScore, &finalizedScore, UntransformLMScore(GetValueGivenState(contextFactor, *lmState).score), ++wordPos );
      }

      // check if we are dealing with a large sub-phrase
      if (subPhraseLength > historySize) {
        // add its finalized language model score
        finalizedScore +=
          prevHypo->GetScoreBreakdown().GetScoresForProducer(scorer)[0] // full score
          - prevState->GetPrefixScore();                              // - prefix score

        // copy language model state
        delete lmState;
        lmState = NewState( prevState->GetRightContext() );

        // push its suffix: only what is needed for the history window
        size_t remainingWords = subPhraseLength - historySize;
        if (remainingWords > historySize) {
          remainingWords = historySize;
        }
        for (size_t suffixPos = prevSuffix.GetSize() - remainingWords;
             suffixPos < prevSuffix.GetSize();
             suffixPos++) {
          const Word &suffixWord = prevSuffix.GetWord(suffixPos);
          ShiftOrPush(contextFactor, suffixWord);
        }
        // only compared against the LM order in updateChartScore, so
        // over-counting here is harmless
        wordPos += subPhraseLength;
      }
    }
  }

  // assign combined score to score breakdown
  // NOTE(review): Evaluate() adds two components when scorer->OOVFeatureEnabled();
  // this path always assigns a single value. Confirm the OOV feature is never
  // combined with chart decoding.
  out->Assign(scorer, prefixScore + finalizedScore);

  ret->Set(prefixScore, lmState);
  return ret;
}
2011-10-11 17:50:44 +04:00
void LanguageModelImplementation : : updateChartScore ( float * prefixScore , float * finalizedScore , float score , size_t wordPos ) const {
2011-09-21 20:06:48 +04:00
if ( wordPos < GetNGramOrder ( ) ) {
* prefixScore + = score ;
}
else {
* finalizedScore + = score ;
}
}
2010-11-17 17:06:21 +03:00
}