2010-11-06 03:40:16 +03:00
# include "lm/model.hh"
2010-09-10 04:36:07 +04:00
2011-01-25 22:11:48 +03:00
# include "lm/blank.hh"
2010-09-28 20:26:55 +04:00
# include "lm/lm_exception.hh"
2010-11-06 03:40:16 +03:00
# include "lm/search_hashed.hh"
# include "lm/search_trie.hh"
2010-09-28 20:26:55 +04:00
# include "lm/read_arpa.hh"
# include "util/murmur_hash.hh"
2010-09-10 04:36:07 +04:00
# include <algorithm>
# include <functional>
# include <numeric>
# include <cmath>
namespace lm {
namespace ngram {
2010-09-16 23:53:33 +04:00
size_t hash_value ( const State & state ) {
return util : : MurmurHashNative ( state . history_ , sizeof ( WordIndex ) * state . valid_length_ ) ;
}
2010-09-27 07:46:44 +04:00
namespace detail {
2011-07-14 00:53:18 +04:00
// Each GenericModel instantiation reports the model type of its Search backend.
template <class Search, class VocabularyT>
const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;
2010-10-27 21:50:40 +04:00
// Number of bytes needed for the model's data structures: the vocabulary table,
// sized from counts[0] (the unigram count), followed by the search structures.
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
  const size_t vocab_bytes = VocabularyT::Size(counts[0], config);
  const size_t search_bytes = Search::Size(counts, config);
  return vocab_bytes + search_bytes;
}
2010-10-27 21:50:40 +04:00
// Lay out vocab_ and search_ in the single memory region starting at base:
// first the vocabulary table, then the search structures immediately after it.
// Throws FormatLoadException if the bytes actually consumed differ from what
// Size(counts, config) reports.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) {
  uint8_t *start = static_cast<uint8_t*>(base);
  size_t allocated = VocabularyT::Size(counts[0], config);
  vocab_.SetupMemory(start, allocated, counts[0], config);
  start += allocated;
  // search_.SetupMemory returns the first byte past the search structures.
  start = search_.SetupMemory(start, counts, config);
  // Compute both sizes once instead of re-evaluating them for the error message.
  const std::size_t used = static_cast<std::size_t>(start - static_cast<uint8_t*>(base));
  const std::size_t expected = Size(counts, config);
  if (used != expected) UTIL_THROW(FormatLoadException, " The data structures took " << used << " but Size says they should take " << expected);
}
2010-09-10 04:36:07 +04:00
2010-10-27 21:50:40 +04:00
// Load the model from file, then register the begin-of-sentence and
// null-context states with the base class.
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
  LoadLM(file, config, *this);

  // g++ prints warnings unless these are fully initialized, hence State().
  State null_context = State();
  null_context.valid_length_ = 0;

  State begin_sentence = State();
  begin_sentence.valid_length_ = 1;
  begin_sentence.history_[0] = vocab_.BeginSentence();
  begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff;

  // Last argument: number of middle orders plus two (unigrams and the highest order).
  P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2);
}
2010-10-27 21:50:40 +04:00
// Attach vocab_ and search_ to already-loaded binary data at start, then let
// each component finish its own binary-specific initialization.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
  SetupMemory(start, params.counts, config);

  // The vocabulary is given the file descriptor and the caller's enumerator.
  vocab_.LoadedBinary(fd, config.enumerate_vocab);
  search_.LoadedBinary();
}
2011-01-25 22:11:48 +03:00
// Build the model by parsing the ARPA file named `file`. When config.write_mmap
// is set, the vocabulary words are also written to the backing file. Any
// util::Exception raised while loading is annotated with the current byte
// offset in the ARPA file and rethrown.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
  // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
  util::FilePiece f(backing_.file.release(), file, config.messages);
  try {
    std::vector<uint64_t> counts;
    // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
    ReadARPACounts(f, counts);

    // Validate the model order and configuration before allocating anything.
    if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, " This model has order " << counts.size() << " . Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile. ");
    if (counts.size() < 2) UTIL_THROW(FormatLoadException, " This ngram implementation assumes at least a bigram model. ");
    if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, " probing multiplier must be > 1.0 ");

    // Set up the binary file for the vocab lookup table; search_ grows the file to its own needs.
    std::size_t vocab_bytes = VocabularyT::Size(counts[0], config);
    vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_bytes, backing_), vocab_bytes, counts[0], config);

    if (!config.write_mmap) {
      vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
    } else {
      // Route enumeration through a wrapper (built from the caller's enumerator),
      // then write its contents to the backing file once search_ is built.
      WriteWordsWrapper wrap(config.enumerate_vocab);
      vocab_.ConfigureEnumerate(&wrap, counts[0]);
      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
      wrap.Write(backing_.file.get());
    }

    if (!vocab_.SawUnk()) {
      assert(config.unknown_missing != THROW_UP);
      // The ARPA file had no <unk> entry: use the configured defaults.
      search_.unigram.Unknown().backoff = 0.0;
      search_.unigram.Unknown().prob = config.unknown_missing_logprob;
    }
    FinishFile(config, kModelType, counts, backing_);
  } catch (util::Exception &e) {
    // Tell the user where in the ARPA file loading failed.
    e << " Byte: " << f.Offset();
    throw;
  }
}
// Score new_word following in_state, writing the successor state to out_state.
// The returned prob includes the backoff weights of every context n-gram longer
// than the context actually used by the matched n-gram.
// NOTE: out_state must not be the same object as in_state. ScoreExceptBackoff
// writes out_state.history_/backoff_ while in_state.history_ is still being
// read, and in_state.backoff_ is read afterwards.
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
  FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state);
  // The match used ngram_length - 1 context words. If that covers the whole
  // valid context, no backoff applies.
  if (ret.ngram_length - 1 >= in_state.valid_length_) return ret;
  // in_state.backoff_[k] belongs to the context n-gram of length k + 1; charge
  // lengths ngram_length .. valid_length_.
  ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
  return ret;
}
2010-10-27 21:50:40 +04:00
// Score new_word after a reversed context given as a plain word array
// [context_rbegin, context_rend) instead of a State. *context_rbegin is the word
// immediately preceding new_word. The backoff weights that FullScore would read
// from a State are looked up in the model and added to the returned prob.
// out_state receives the successor state.
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
  // Only the most recent P::Order() - 1 context words can participate. Clamp by
  // length: computing context_rbegin + P::Order() - 1 unconditionally (as a
  // std::min on pointers does) forms a pointer past the end of a short caller
  // array, which is undefined behavior.
  if (context_rend - context_rbegin > static_cast<std::ptrdiff_t>(P::Order()) - 1) {
    context_rend = context_rbegin + P::Order() - 1;
  }
  FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);

  // Add the backoff weights of the context n-grams context_rbegin[0 .. k) for
  // k = start .. (context_rend - context_rbegin), stopping at the first one the
  // model does not contain.
  unsigned char start = ret.ngram_length;
  if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret;
  if (start <= 1) {
    // The length-1 context's backoff comes from the unigram table.
    ret.prob += search_.unigram.Lookup(*context_rbegin).backoff;
    start = 2;
  }
  typename Search::Node node;
  // Build node for context_rbegin[0 .. start - 1); nothing more to add if absent.
  if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
    return ret;
  }
  float backoff;
  // mid_iter is the middle table for n-grams of length (i - context_rbegin) + 1
  // (the first middle table holds bigrams).
  const typename Search::Middle *mid_iter = search_.MiddleBegin() + start - 2;
  for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) {
    if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break;
    ret.prob += backoff;
  }
  return ret;
}
// Build out_state for the reversed context [context_rbegin, context_rend)
// without scoring a word. *context_rbegin is the most recent word.
// out_state.backoff_[k] receives the backoff of the context n-gram of length
// k + 1, for each such n-gram found. valid_length_ is the longest such length
// whose entry HasExtension. history_ receives that many context words.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
  // Only the most recent P::Order() - 1 words can matter. Clamp by length so that
  // no pointer past the caller's array is formed: context_rbegin + P::Order() - 1
  // is undefined behavior when the context is shorter than that.
  if (context_rend - context_rbegin > static_cast<std::ptrdiff_t>(P::Order()) - 1) {
    context_rend = context_rbegin + P::Order() - 1;
  }
  if (context_rend == context_rbegin) {
    out_state.valid_length_ = 0;
    return;
  }
  float ignored_prob;
  typename Search::Node node;
  search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
  out_state.valid_length_ = HasExtension(out_state.backoff_[0]) ? 1 : 0;
  float *backoff_out = out_state.backoff_ + 1;
  const typename Search::Middle *mid = search_.MiddleBegin();
  for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
    // Stop at the first context n-gram the model does not contain.
    if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) break;
    if (HasExtension(*backoff_out)) out_state.valid_length_ = i - context_rbegin + 1;
  }
  std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
}
2011-01-25 22:11:48 +03:00
namespace {
// Do a paraonoid copy of history, assuming new_word has already been copied
// (hence the -1). out_state.valid_length_ could be zero so I avoided using
// std::copy.
void CopyRemainingHistory ( const WordIndex * from , State & out_state ) {
WordIndex * out = out_state . history_ + 1 ;
const WordIndex * in_end = from + static_cast < ptrdiff_t > ( out_state . valid_length_ ) - 1 ;
for ( const WordIndex * in = from ; in < in_end ; + + in , + + out ) * out = * in ;
2010-10-27 21:50:40 +04:00
}
2011-01-25 22:11:48 +03:00
} // namespace
2010-10-27 21:50:40 +04:00
/* Ugly optimized function. Produce a score excluding backoff.
2010-09-10 04:36:07 +04:00
* The search goes in increasing order of ngram length .
2010-10-27 21:50:40 +04:00
* Context goes backward , so context_begin is the word immediately preceeding
* new_word .
2010-09-10 04:36:07 +04:00
*/
2010-10-27 21:50:40 +04:00
template < class Search , class VocabularyT > FullScoreReturn GenericModel < Search , VocabularyT > : : ScoreExceptBackoff (
const WordIndex * context_rbegin ,
const WordIndex * context_rend ,
2010-09-10 04:36:07 +04:00
const WordIndex new_word ,
State & out_state ) const {
FullScoreReturn ret ;
2011-01-25 22:11:48 +03:00
// ret.ngram_length contains the last known non-blank ngram length.
ret . ngram_length = 1 ;
2010-10-27 21:50:40 +04:00
typename Search : : Node node ;
float * backoff_out ( out_state . backoff_ ) ;
search_ . LookupUnigram ( new_word , ret . prob , * backoff_out , node ) ;
2011-01-25 22:11:48 +03:00
// This is the length of the context that should be used for continuation.
out_state . valid_length_ = HasExtension ( * backoff_out ) ? 1 : 0 ;
// We'll write the word anyway since it will probably be used and does no harm being there.
2010-09-10 04:36:07 +04:00
out_state . history_ [ 0 ] = new_word ;
2011-01-25 22:11:48 +03:00
if ( context_rbegin = = context_rend ) return ret ;
2010-09-10 04:36:07 +04:00
+ + backoff_out ;
// Ok now we now that the bigram contains known words. Start by looking it up.
2010-10-27 21:50:40 +04:00
const WordIndex * hist_iter = context_rbegin ;
2011-06-27 02:21:44 +04:00
const typename Search : : Middle * mid_iter = search_ . MiddleBegin ( ) ;
2010-09-10 04:36:07 +04:00
for ( ; ; + + mid_iter , + + hist_iter , + + backoff_out ) {
2010-10-27 21:50:40 +04:00
if ( hist_iter = = context_rend ) {
2011-01-25 22:11:48 +03:00
// Ran out of history. Typically no backoff, but this could be a blank.
CopyRemainingHistory ( context_rbegin , out_state ) ;
2010-09-10 04:36:07 +04:00
// ret.prob was already set.
return ret ;
}
2010-10-27 21:50:40 +04:00
2011-06-27 02:21:44 +04:00
if ( mid_iter = = search_ . MiddleEnd ( ) ) break ;
2010-10-27 21:50:40 +04:00
2011-01-25 22:11:48 +03:00
float revert = ret . prob ;
2010-10-27 21:50:40 +04:00
if ( ! search_ . LookupMiddle ( * mid_iter , * hist_iter , ret . prob , * backoff_out , node ) ) {
2010-09-10 04:36:07 +04:00
// Didn't find an ngram using hist_iter.
2011-01-25 22:11:48 +03:00
CopyRemainingHistory ( context_rbegin , out_state ) ;
2010-10-27 21:50:40 +04:00
// ret.prob was already set.
2010-09-10 04:36:07 +04:00
return ret ;
}
2011-01-25 22:11:48 +03:00
if ( ret . prob = = kBlankProb ) {
// It's a blank. Go back to the old probability.
ret . prob = revert ;
} else {
ret . ngram_length = hist_iter - context_rbegin + 2 ;
if ( HasExtension ( * backoff_out ) ) {
out_state . valid_length_ = ret . ngram_length ;
}
}
2010-09-10 04:36:07 +04:00
}
2010-10-27 21:50:40 +04:00
2011-01-25 22:11:48 +03:00
// It passed every lookup in search_.middle. All that's left is to check search_.longest.
2010-09-10 04:36:07 +04:00
2010-10-27 21:50:40 +04:00
if ( ! search_ . LookupLongest ( * hist_iter , ret . prob , node ) ) {
2011-01-25 22:11:48 +03:00
// Failed to find a longest n-gram. Fall back to the most recent non-blank.
CopyRemainingHistory ( context_rbegin , out_state ) ;
2010-10-27 21:50:40 +04:00
// ret.prob was already set.
2010-09-10 04:36:07 +04:00
return ret ;
}
2011-01-25 22:11:48 +03:00
// It's an P::Order()-gram.
CopyRemainingHistory ( context_rbegin , out_state ) ;
// There is no blank in longest_.
2010-09-10 04:36:07 +04:00
ret . ngram_length = P : : Order ( ) ;
2010-09-27 07:46:44 +04:00
return ret ;
}
2011-06-27 02:21:44 +04:00
// Explicit instantiations: one GenericModel per supported search backend.

// Probing hash table search.
template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; // HASH_PROBING

// Trie search without quantization, with each pointer-compression variant.
template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED
template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>; // as above, with trie::ArrayBhiksha

// Trie search with SeparatelyQuantize, with each pointer-compression variant.
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED_QUANT
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>; // as above, with trie::ArrayBhiksha
2010-09-27 07:46:44 +04:00
2010-09-10 04:36:07 +04:00
} // namespace detail
} // namespace ngram
} // namespace lm