2010-11-06 03:40:16 +03:00
# include "lm/model.hh"
2010-09-10 04:36:07 +04:00
2011-01-25 22:11:48 +03:00
# include "lm/blank.hh"
2010-09-28 20:26:55 +04:00
# include "lm/lm_exception.hh"
2010-11-06 03:40:16 +03:00
# include "lm/search_hashed.hh"
# include "lm/search_trie.hh"
2010-09-28 20:26:55 +04:00
# include "lm/read_arpa.hh"
# include "util/murmur_hash.hh"
2010-09-10 04:36:07 +04:00
# include <algorithm>
# include <functional>
# include <numeric>
# include <cmath>
namespace lm {
namespace ngram {
2010-09-27 07:46:44 +04:00
namespace detail {
2011-07-14 00:53:18 +04:00
// Out-of-line definition of the static model-type tag; the value comes from
// the Search policy so each instantiation reports its own binary format.
template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;
// Total bytes required for a model with these n-gram counts: the vocabulary
// section comes first, followed by the search data structures.
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
  const size_t vocab_bytes = VocabularyT::Size(counts[0], config);
  const size_t search_bytes = Search::Size(counts, config);
  return vocab_bytes + search_bytes;
}
// Carve up a single memory block: vocabulary first, then the search
// structures.  Throws FormatLoadException if the consumed size disagrees
// with what Size() predicted.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) {
  uint8_t *cursor = static_cast<uint8_t*>(base);
  const size_t vocab_bytes = VocabularyT::Size(counts[0], config);
  vocab_.SetupMemory(cursor, vocab_bytes, counts[0], config);
  cursor += vocab_bytes;
  // The search structure claims the remainder and reports where it ended.
  cursor = search_.SetupMemory(cursor, counts, config);
  const std::ptrdiff_t used = cursor - static_cast<uint8_t*>(base);
  if (static_cast<std::size_t>(used) != Size(counts, config))
    UTIL_THROW(FormatLoadException, "The data structures took " << used << " but Size says they should take " << Size(counts, config));
}
// Load a model from file (ARPA or binary; LoadLM dispatches), then hand the
// base class its two canonical states.
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
  LoadLM(file, config, *this);

  // g++ prints warnings unless these are fully initialized.
  State bos = State();
  bos.length = 1;
  bos.words[0] = vocab_.BeginSentence();
  bos.backoff[0] = search_.unigram.Lookup(bos.words[0]).backoff;

  State empty = State();
  empty.length = 0;

  // Order is derived from the number of middle layers plus unigram and longest.
  P::Init(bos, empty, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2);
}
/* Attach to an already-mapped binary model.  Memory layout must be
 * established before the vocabulary and search structures finish loading
 * from the file descriptor, so the call order here matters.
 */
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
  SetupMemory(start, params.counts, config);
  vocab_.LoadedBinary(fd, config.enumerate_vocab);
  search_.LoadedBinary();
}
/* Build the model from an ARPA text file.  The vocabulary and search
 * structures do the heavy lifting; this function wires them to the backing
 * file and annotates any exception with the byte offset where it occurred.
 */
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
  // Backing file is the ARPA.  Steal it so we can make the backing file the mmap output if any.
  util::FilePiece f(backing_.file.release(), file, config.messages);
  try {
    std::vector<uint64_t> counts;
    // File counts do not include pruned trigrams that extend to quadgrams etc.  These will be fixed by search_.
    ReadARPACounts(f, counts);
    if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ".  Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
    if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
    if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
    std::size_t vocab_size = VocabularyT::Size(counts[0], config);
    // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.
    vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
    if (config.write_mmap) {
      // Capture words as they are enumerated so they can be appended to the binary file.
      WriteWordsWrapper wrap(config.enumerate_vocab);
      vocab_.ConfigureEnumerate(&wrap, counts[0]);
      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
      wrap.Write(backing_.file.get());
    } else {
      vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
    }
    if (!vocab_.SawUnk()) {
      // NOTE(review): presumably the vocabulary throws under THROW_UP before
      // reaching here — confirm against the vocab implementation.
      assert(config.unknown_missing != THROW_UP);
      // Default probabilities for unknown.
      search_.unigram.Unknown().backoff = 0.0;
      search_.unigram.Unknown().prob = config.unknown_missing_logprob;
    }
    FinishFile(config, kModelType, kVersion, counts, backing_);
  } catch (util::Exception &e) {
    // Report where in the ARPA file the failure happened.
    e << " Byte: " << f.Offset();
    throw;
  }
}
/* Read configuration recorded in a binary file: seek past the vocabulary
 * section, then let the search-specific reader pull its settings.
 */
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
  util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
  Search::UpdateConfigFromBinary(fd, counts, config);
}
// Score new_word given in_state, writing the successor state to out_state.
// The core lookup excludes backoff; here we charge backoff weights for the
// context words beyond the matched n-gram length.
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
  FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
  for (unsigned char idx = ret.ngram_length - 1; idx < in_state.length; ++idx) {
    ret.prob += in_state.backoff[idx];
  }
  return ret;
}
/* Score new_word given only a reversed word array (no State kept by the
 * caller).  Truncates the context to Order() - 1 words, scores, then adds the
 * backoff weights the State-based path would have cached.
 */
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
  FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);

  // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
  unsigned char start = ret.ngram_length;
  if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret;
  if (start <= 1) {
    // Unigram backoff lives in its own table, not in the middle layers.
    ret.prob += search_.unigram.Lookup(*context_rbegin).backoff;
    start = 2;
  }
  typename Search::Node node;
  // Rebuild the search node for the (start - 1)-word context prefix.
  if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
    return ret;
  }
  float backoff;
  // i is the order of the backoff we're looking for.
  const Middle *mid_iter = search_.MiddleBegin() + start - 2;
  for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) {
    // Stop at the first order whose backoff entry is absent.
    if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break;
    ret.prob += backoff;
  }
  return ret;
}
/* Generate a State from a reversed context array, without scoring anything.
 * The state records backoff weights and only as many words as can still
 * extend (per HasExtension).
 */
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
  // Generate a state from context.
  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
  if (context_rend == context_rbegin) {
    // Empty context: the null state.
    out_state.length = 0;
    return;
  }
  FullScoreReturn ignored;
  typename Search::Node node;
  search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored);
  // A word counts toward state length only if its backoff can extend.
  out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
  float *backoff_out = out_state.backoff + 1;
  const typename Search::Middle *mid = search_.MiddleBegin();
  for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
    if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
      // No longer n-gram exists; keep what we have so far.
      std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
      return;
    }
    if (HasExtension(*backoff_out)) out_state.length = i - context_rbegin + 1;
  }
  std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
}
/* Extend an already-matched n-gram (identified by extend_pointer with
 * extend_length words) leftward with the reversed words in
 * [add_rbegin, add_rend).  backoff_in supplies backoffs for the added words;
 * backoff_out receives backoffs for each newly matched order; next_use
 * reports how many added words remain useful for further extension.  The
 * stored value of the original n-gram (subtract_me) is removed at every
 * return so the caller receives the adjustment relative to it.
 */
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ExtendLeft(
    const WordIndex *add_rbegin, const WordIndex *add_rend,
    const float *backoff_in,
    uint64_t extend_pointer,
    unsigned char extend_length,
    float *backoff_out,
    unsigned char &next_use) const {
  FullScoreReturn ret;
  float subtract_me;
  // Recover the search node and stored value for the n-gram being extended.
  typename Search::Node node(search_.Unpack(extend_pointer, extend_length, subtract_me));
  ret.prob = subtract_me;
  ret.ngram_length = extend_length;
  next_use = 0;
  // If this function is called, then it does depend on left words.
  ret.independent_left = false;
  ret.extend_left = extend_pointer;
  const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1;
  const WordIndex *i = add_rbegin;
  for (; ; ++i, ++backoff_out, ++mid_iter) {
    if (i == add_rend) {
      // Ran out of words.
      for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
      ret.prob -= subtract_me;
      return ret;
    }
    // Reached the highest middle order; fall through to the longest table.
    if (mid_iter == search_.MiddleEnd()) break;
    if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *i, *backoff_out, node, ret)) {
      // Didn't match a word.
      ret.independent_left = true;
      for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
      ret.prob -= subtract_me;
      return ret;
    }
    ret.ngram_length = mid_iter - search_.MiddleBegin() + 2;
    if (HasExtension(*backoff_out)) next_use = i - add_rbegin + 1;
  }
  if (ret.independent_left || !search_.LookupLongest(*i, ret.prob, node)) {
    // The last backoff weight, for Order() - 1.
    ret.prob += backoff_in[i - add_rbegin];
  } else {
    ret.ngram_length = P::Order();
  }
  // A full-order match cannot be extended further left.
  ret.independent_left = true;
  ret.prob -= subtract_me;
  return ret;
}
namespace {
// Do a paraonoid copy of history, assuming new_word has already been copied
2011-09-21 20:06:48 +04:00
// (hence the -1). out_state.length could be zero so I avoided using
2011-01-25 22:11:48 +03:00
// std::copy.
void CopyRemainingHistory ( const WordIndex * from , State & out_state ) {
2011-09-21 20:06:48 +04:00
WordIndex * out = out_state . words + 1 ;
const WordIndex * in_end = from + static_cast < ptrdiff_t > ( out_state . length ) - 1 ;
2011-01-25 22:11:48 +03:00
for ( const WordIndex * in = from ; in < in_end ; + + in , + + out ) * out = * in ;
2010-10-27 21:50:40 +04:00
}
2011-01-25 22:11:48 +03:00
} // namespace
2010-10-27 21:50:40 +04:00
/* Ugly optimized function.  Produce a score excluding backoff.
 * The search goes in increasing order of ngram length.
 * Context goes backward, so context_rbegin is the word immediately preceding
 * new_word.
 */
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
    const WordIndex *context_rbegin,
    const WordIndex *context_rend,
    const WordIndex new_word,
    State &out_state) const {
  FullScoreReturn ret;
  // ret.ngram_length contains the last known non-blank ngram length.
  ret.ngram_length = 1;

  float *backoff_out(out_state.backoff);
  typename Search::Node node;
  // Unigram lookup also seeds ret (prob etc.) for the shortest match.
  search_.LookupUnigram(new_word, *backoff_out, node, ret);
  // This is the length of the context that should be used for continuation to the right.
  out_state.length = HasExtension(*backoff_out) ? 1 : 0;
  // We'll write the word anyway since it will probably be used and does no harm being there.
  out_state.words[0] = new_word;
  if (context_rbegin == context_rend) return ret;
  ++backoff_out;

  // Ok start by looking up the bigram.
  const WordIndex *hist_iter = context_rbegin;
  const typename Search::Middle *mid_iter = search_.MiddleBegin();
  for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
    if (hist_iter == context_rend) {
      // Ran out of history.  Typically no backoff, but this could be a blank.
      CopyRemainingHistory(context_rbegin, out_state);
      // ret.prob was already set.
      return ret;
    }

    // Exhausted the middle orders; only the longest table remains.
    if (mid_iter == search_.MiddleEnd()) break;

    if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *hist_iter, *backoff_out, node, ret)) {
      // Didn't find an ngram using hist_iter.
      CopyRemainingHistory(context_rbegin, out_state);
      // ret.prob was already set.
      ret.independent_left = true;
      return ret;
    }
    ret.ngram_length = hist_iter - context_rbegin + 2;
    if (HasExtension(*backoff_out)) {
      out_state.length = ret.ngram_length;
    }
  }

  // It passed every lookup in search_.middle.  All that's left is to check search_.longest.
  if (!ret.independent_left && search_.LookupLongest(*hist_iter, ret.prob, node)) {
    // It's an P::Order()-gram.
    // There is no blank in longest_.
    ret.ngram_length = P::Order();
  }
  // This handles (N-1)-grams and N-grams.
  CopyRemainingHistory(context_rbegin, out_state);
  ret.independent_left = true;
  return ret;
}
// Explicit instantiations: one per supported search/vocabulary combination
// (each Search policy carries its own kModelType tag).
template class GenericModel<ProbingHashedSearch, ProbingVocabulary>;  // HASH_PROBING
template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED
template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED_QUANT
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
2010-09-27 07:46:44 +04:00
2010-09-10 04:36:07 +04:00
} // namespace detail
} // namespace ngram
} // namespace lm