2010-12-08 06:15:37 +03:00
/* This is where the trie is built. It's on-disk. */
2010-11-06 03:40:16 +03:00
# include "lm/search_trie.hh"
2010-10-27 21:50:40 +04:00
2011-07-14 00:53:18 +04:00
# include "lm/bhiksha.hh"
2011-09-21 20:06:48 +04:00
# include "lm/binary_format.hh"
2011-01-25 22:11:48 +03:00
# include "lm/blank.hh"
2010-10-27 21:50:40 +04:00
# include "lm/lm_exception.hh"
2012-09-28 18:04:48 +04:00
# include "lm/max_order.hh"
2011-06-27 02:21:44 +04:00
# include "lm/quantize.hh"
2010-10-27 21:50:40 +04:00
# include "lm/trie.hh"
2011-09-21 20:06:48 +04:00
# include "lm/trie_sort.hh"
2010-10-27 21:50:40 +04:00
# include "lm/vocab.hh"
# include "lm/weights.hh"
# include "lm/word_index.hh"
# include "util/ersatz_progress.hh"
2011-11-11 00:46:59 +04:00
# include "util/mmap.hh"
2010-12-08 06:15:37 +03:00
# include "util/proxy_iterator.hh"
2010-10-27 21:50:40 +04:00
# include "util/scoped.hh"
2011-09-21 20:06:48 +04:00
# include "util/sized_iterator.hh"
2010-10-27 21:50:40 +04:00
# include <algorithm>
# include <cstring>
# include <cstdio>
2011-11-11 00:46:59 +04:00
# include <cstdlib>
2011-09-21 20:06:48 +04:00
# include <queue>
2010-10-27 21:50:40 +04:00
# include <limits>
2011-06-27 02:21:44 +04:00
# include <numeric>
2010-10-27 21:50:40 +04:00
# include <vector>
2011-11-11 00:46:59 +04:00
# if defined(_WIN32) || defined(_WIN64)
# include <windows.h>
# endif
2010-10-27 21:50:40 +04:00
namespace lm {
namespace ngram {
namespace trie {
namespace {
void ReadOrThrow ( FILE * from , void * data , size_t size ) {
2011-09-21 20:06:48 +04:00
UTIL_THROW_IF ( 1 ! = std : : fread ( data , size , 1 , from ) , util : : ErrnoException , " Short read " ) ;
2010-10-27 21:50:40 +04:00
}
2011-09-21 20:06:48 +04:00
int Compare ( unsigned char order , const void * first_void , const void * second_void ) {
const WordIndex * first = reinterpret_cast < const WordIndex * > ( first_void ) , * second = reinterpret_cast < const WordIndex * > ( second_void ) ;
const WordIndex * end = first + order ;
for ( ; first ! = end ; + + first , + + second ) {
if ( * first < * second ) return - 1 ;
if ( * first > * second ) return 1 ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
return 0 ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
// Names the float that a backoff message is added into: base[array][index]
// in BackoffMessages::Apply. SRISucks::Send sets array to the blank's order
// minus one and index to the blank's position in that order's value list.
// The layout matters: its sizeof is part of each BackoffMessages entry size.
struct ProbPointer {
// Which per-order array of blank probabilities.
unsigned char array ;
// Offset within that array.
uint64_t index ;
} ;
2010-10-27 21:50:40 +04:00
2013-01-05 01:02:47 +04:00
// Array of n-grams and float indices.
2011-09-21 20:06:48 +04:00
class BackoffMessages {
2010-10-27 21:50:40 +04:00
public :
2011-09-21 20:06:48 +04:00
void Init ( std : : size_t entry_size ) {
current_ = NULL ;
allocated_ = NULL ;
entry_size_ = entry_size ;
2010-10-27 21:50:40 +04:00
}
2011-09-21 20:06:48 +04:00
void Add ( const WordIndex * to , ProbPointer index ) {
while ( current_ + entry_size_ > allocated_ ) {
std : : size_t allocated_size = allocated_ - ( uint8_t * ) backing_ . get ( ) ;
Resize ( std : : max < std : : size_t > ( allocated_size * 2 , entry_size_ ) ) ;
}
memcpy ( current_ , to , entry_size_ - sizeof ( ProbPointer ) ) ;
* reinterpret_cast < ProbPointer * > ( current_ + entry_size_ - sizeof ( ProbPointer ) ) = index ;
current_ + = entry_size_ ;
}
void Apply ( float * const * const base , FILE * unigrams ) {
FinishedAdding ( ) ;
if ( current_ = = allocated_ ) return ;
rewind ( unigrams ) ;
ProbBackoff weights ;
WordIndex unigram = 0 ;
ReadOrThrow ( unigrams , & weights , sizeof ( weights ) ) ;
for ( ; current_ ! = allocated_ ; current_ + = entry_size_ ) {
const WordIndex & cur_word = * reinterpret_cast < const WordIndex * > ( current_ ) ;
for ( ; unigram < cur_word ; + + unigram ) {
ReadOrThrow ( unigrams , & weights , sizeof ( weights ) ) ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
if ( ! HasExtension ( weights . backoff ) ) {
weights . backoff = kExtensionBackoff ;
UTIL_THROW_IF ( fseek ( unigrams , - sizeof ( weights ) , SEEK_CUR ) , util : : ErrnoException , " Seeking backwards to denote unigram extension failed. " ) ;
2012-09-28 18:04:48 +04:00
util : : WriteOrThrow ( unigrams , & weights , sizeof ( weights ) ) ;
2011-09-21 20:06:48 +04:00
}
const ProbPointer & write_to = * reinterpret_cast < const ProbPointer * > ( current_ + sizeof ( WordIndex ) ) ;
base [ write_to . array ] [ write_to . index ] + = weights . backoff ;
2010-10-27 21:50:40 +04:00
}
2011-09-21 20:06:48 +04:00
backing_ . reset ( ) ;
}
void Apply ( float * const * const base , RecordReader & reader ) {
FinishedAdding ( ) ;
if ( current_ = = allocated_ ) return ;
2013-01-05 01:02:47 +04:00
// We'll also use the same buffer to record messages to blanks that they extend.
2011-09-21 20:06:48 +04:00
WordIndex * extend_out = reinterpret_cast < WordIndex * > ( current_ ) ;
const unsigned char order = ( entry_size_ - sizeof ( ProbPointer ) ) / sizeof ( WordIndex ) ;
for ( reader . Rewind ( ) ; reader & & ( current_ ! = allocated_ ) ; ) {
switch ( Compare ( order , reader . Data ( ) , current_ ) ) {
case - 1 :
+ + reader ;
break ;
case 1 :
2013-01-05 01:02:47 +04:00
// Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends.
2011-09-21 20:06:48 +04:00
for ( const WordIndex * w = reinterpret_cast < const WordIndex * > ( current_ ) ; w ! = reinterpret_cast < const WordIndex * > ( current_ ) + order ; + + w , + + extend_out ) * extend_out = * w ;
current_ + = entry_size_ ;
break ;
case 0 :
float & backoff = reinterpret_cast < ProbBackoff * > ( ( uint8_t * ) reader . Data ( ) + order * sizeof ( WordIndex ) ) - > backoff ;
if ( ! HasExtension ( backoff ) ) {
backoff = kExtensionBackoff ;
reader . Overwrite ( & backoff , sizeof ( float ) ) ;
} else {
const ProbPointer & write_to = * reinterpret_cast < const ProbPointer * > ( current_ + entry_size_ - sizeof ( ProbPointer ) ) ;
base [ write_to . array ] [ write_to . index ] + = backoff ;
}
current_ + = entry_size_ ;
break ;
}
}
2013-01-05 01:02:47 +04:00
// Now this is a list of blanks that extend right.
2011-09-21 20:06:48 +04:00
entry_size_ = sizeof ( WordIndex ) * order ;
Resize ( sizeof ( WordIndex ) * ( extend_out - ( const WordIndex * ) backing_ . get ( ) ) ) ;
current_ = ( uint8_t * ) backing_ . get ( ) ;
2010-10-27 21:50:40 +04:00
}
2011-09-21 20:06:48 +04:00
// Call after Apply
bool Extends ( unsigned char order , const WordIndex * words ) {
if ( current_ = = allocated_ ) return false ;
assert ( order * sizeof ( WordIndex ) = = entry_size_ ) ;
while ( true ) {
switch ( Compare ( order , words , current_ ) ) {
case 1 :
current_ + = entry_size_ ;
if ( current_ = = allocated_ ) return false ;
break ;
case - 1 :
return false ;
case 0 :
return true ;
}
}
2010-10-27 21:50:40 +04:00
}
2011-09-21 20:06:48 +04:00
private :
void FinishedAdding ( ) {
Resize ( current_ - ( uint8_t * ) backing_ . get ( ) ) ;
2013-01-05 01:02:47 +04:00
// Sort requests in same order as files.
2011-10-11 22:27:36 +04:00
std : : sort (
util : : SizedIterator ( util : : SizedProxy ( backing_ . get ( ) , entry_size_ ) ) ,
util : : SizedIterator ( util : : SizedProxy ( current_ , entry_size_ ) ) ,
util : : SizedCompare < EntryCompare > ( EntryCompare ( ( entry_size_ - sizeof ( ProbPointer ) ) / sizeof ( WordIndex ) ) ) ) ;
2011-09-21 20:06:48 +04:00
current_ = ( uint8_t * ) backing_ . get ( ) ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
void Resize ( std : : size_t to ) {
std : : size_t current = current_ - ( uint8_t * ) backing_ . get ( ) ;
backing_ . call_realloc ( to ) ;
current_ = ( uint8_t * ) backing_ . get ( ) + current ;
allocated_ = ( uint8_t * ) backing_ . get ( ) + to ;
2010-10-27 21:50:40 +04:00
}
2011-09-21 20:06:48 +04:00
util : : scoped_malloc backing_ ;
2010-10-27 21:50:40 +04:00
2011-09-21 20:06:48 +04:00
uint8_t * current_ , * allocated_ ;
2011-01-25 22:11:48 +03:00
2011-09-21 20:06:48 +04:00
std : : size_t entry_size_ ;
2010-10-27 21:50:40 +04:00
} ;
2011-09-21 20:06:48 +04:00
// Sentinel meaning "no usable probability". BlankManager fills unset basis
// slots with it and overwrites the slots of blanks with it; SRISucks::Send
// asserts it never receives it.
const float kBadProb(std::numeric_limits<float>::infinity());
2011-01-25 22:11:48 +03:00
2011-09-21 20:06:48 +04:00
class SRISucks {
2011-01-25 22:11:48 +03:00
public :
2011-09-21 20:06:48 +04:00
SRISucks ( ) {
2012-08-09 00:22:13 +04:00
for ( BackoffMessages * i = messages_ ; i ! = messages_ + KENLM_MAX_ORDER - 1 ; + + i )
2011-09-21 20:06:48 +04:00
i - > Init ( sizeof ( ProbPointer ) + sizeof ( WordIndex ) * ( i - messages_ + 1 ) ) ;
}
void Send ( unsigned char begin , unsigned char order , const WordIndex * to , float prob_basis ) {
assert ( prob_basis ! = kBadProb ) ;
ProbPointer pointer ;
pointer . array = order - 1 ;
pointer . index = values_ [ order - 1 ] . size ( ) ;
for ( unsigned char i = begin ; i < order ; + + i ) {
messages_ [ i - 1 ] . Add ( to , pointer ) ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
values_ [ order - 1 ] . push_back ( prob_basis ) ;
2011-01-25 22:11:48 +03:00
}
2010-12-08 06:15:37 +03:00
2011-09-21 20:06:48 +04:00
void ObtainBackoffs ( unsigned char total_order , FILE * unigram_file , RecordReader * reader ) {
2012-08-09 00:22:13 +04:00
for ( unsigned char i = 0 ; i < KENLM_MAX_ORDER - 1 ; + + i ) {
2012-02-28 22:58:00 +04:00
it_ [ i ] = values_ [ i ] . empty ( ) ? NULL : & * values_ [ i ] . begin ( ) ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
messages_ [ 0 ] . Apply ( it_ , unigram_file ) ;
BackoffMessages * messages = messages_ + 1 ;
const RecordReader * end = reader + total_order - 2 /* exclude unigrams and longest order */ ;
for ( ; reader ! = end ; + + messages , + + reader ) {
messages - > Apply ( it_ , * reader ) ;
2011-01-25 22:11:48 +03:00
}
}
2011-09-21 20:06:48 +04:00
ProbBackoff GetBlank ( unsigned char total_order , unsigned char order , const WordIndex * indices ) {
assert ( order > 1 ) ;
ProbBackoff ret ;
ret . prob = * ( it_ [ order - 1 ] + + ) ;
ret . backoff = ( ( order ! = total_order - 1 ) & & messages_ [ order - 1 ] . Extends ( order , indices ) ) ? kExtensionBackoff : kNoExtensionBackoff ;
return ret ;
2010-10-27 21:50:40 +04:00
}
2011-09-21 20:06:48 +04:00
const std : : vector < float > & Values ( unsigned char order ) const {
return values_ [ order - 1 ] ;
}
2010-10-27 21:50:40 +04:00
2011-09-21 20:06:48 +04:00
private :
2013-01-05 01:02:47 +04:00
// This used to be one array. Then I needed to separate it by order for quantization to work.
2012-08-09 00:22:13 +04:00
std : : vector < float > values_ [ KENLM_MAX_ORDER - 1 ] ;
BackoffMessages messages_ [ KENLM_MAX_ORDER - 1 ] ;
2011-01-25 22:11:48 +03:00
2012-08-09 00:22:13 +04:00
float * it_ [ KENLM_MAX_ORDER - 1 ] ;
2011-09-21 20:06:48 +04:00
} ;
2011-01-25 22:11:48 +03:00
2011-09-21 20:06:48 +04:00
class FindBlanks {
public :
2012-02-28 22:58:00 +04:00
FindBlanks ( unsigned char order , const ProbBackoff * unigrams , SRISucks & messages )
: counts_ ( order ) , unigrams_ ( unigrams ) , sri_ ( messages ) { }
2010-10-27 21:50:40 +04:00
2011-09-21 20:06:48 +04:00
float UnigramProb ( WordIndex index ) const {
return unigrams_ [ index ] . prob ;
2011-01-25 22:11:48 +03:00
}
2010-10-27 21:50:40 +04:00
2011-09-21 20:06:48 +04:00
void Unigram ( WordIndex /*index*/ ) {
+ + counts_ [ 0 ] ;
2010-10-27 21:50:40 +04:00
}
2011-01-25 22:11:48 +03:00
2011-09-21 20:06:48 +04:00
void MiddleBlank ( const unsigned char order , const WordIndex * indices , unsigned char lower , float prob_basis ) {
sri_ . Send ( lower , order , indices + 1 , prob_basis ) ;
+ + counts_ [ order - 1 ] ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
void Middle ( const unsigned char order , const void * /*data*/ ) {
+ + counts_ [ order - 1 ] ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
void Longest ( const void * /*data*/ ) {
2012-02-28 22:58:00 +04:00
+ + counts_ . back ( ) ;
2011-01-25 22:11:48 +03:00
}
2012-02-28 22:58:00 +04:00
const std : : vector < uint64_t > & Counts ( ) const {
return counts_ ;
}
2011-01-25 22:11:48 +03:00
private :
2012-02-28 22:58:00 +04:00
std : : vector < uint64_t > counts_ ;
2011-09-21 20:06:48 +04:00
const ProbBackoff * unigrams_ ;
SRISucks & sri_ ;
2011-01-25 22:11:48 +03:00
} ;
2013-01-05 01:02:47 +04:00
// Phase to actually write n-grams to the trie.
2011-07-14 00:53:18 +04:00
template < class Quant , class Bhiksha > class WriteEntries {
2011-01-25 22:11:48 +03:00
public :
2013-01-05 01:02:47 +04:00
WriteEntries ( RecordReader * contexts , const Quant & quant , UnigramValue * unigrams , BitPackedMiddle < Bhiksha > * middle , BitPackedLongest & longest , unsigned char order , SRISucks & sri ) :
2011-01-25 22:11:48 +03:00
contexts_ ( contexts ) ,
2012-06-28 18:58:59 +04:00
quant_ ( quant ) ,
2011-01-25 22:11:48 +03:00
unigrams_ ( unigrams ) ,
middle_ ( middle ) ,
2013-01-05 01:02:47 +04:00
longest_ ( longest ) ,
2011-09-21 20:06:48 +04:00
bigram_pack_ ( ( order = = 2 ) ? static_cast < BitPacked & > ( longest_ ) : static_cast < BitPacked & > ( * middle_ ) ) ,
order_ ( order ) ,
sri_ ( sri ) { }
2011-01-25 22:11:48 +03:00
2011-09-21 20:06:48 +04:00
float UnigramProb ( WordIndex index ) const { return unigrams_ [ index ] . weights . prob ; }
void Unigram ( WordIndex word ) {
unigrams_ [ word ] . next = bigram_pack_ . InsertIndex ( ) ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
void MiddleBlank ( const unsigned char order , const WordIndex * indices , unsigned char /*lower*/ , float /*prob_base*/ ) {
ProbBackoff weights = sri_ . GetBlank ( order_ , order , indices ) ;
2012-06-28 18:58:59 +04:00
typename Quant : : MiddlePointer ( quant_ , order - 2 , middle_ [ order - 2 ] . Insert ( indices [ order - 1 ] ) ) . Write ( weights . prob , weights . backoff ) ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
void Middle ( const unsigned char order , const void * data ) {
RecordReader & context = contexts_ [ order - 1 ] ;
const WordIndex * words = reinterpret_cast < const WordIndex * > ( data ) ;
ProbBackoff weights = * reinterpret_cast < const ProbBackoff * > ( words + order ) ;
if ( context & & ! memcmp ( data , context . Data ( ) , sizeof ( WordIndex ) * order ) ) {
2011-01-25 22:11:48 +03:00
SetExtension ( weights . backoff ) ;
+ + context ;
}
2012-06-28 18:58:59 +04:00
typename Quant : : MiddlePointer ( quant_ , order - 2 , middle_ [ order - 2 ] . Insert ( words [ order - 1 ] ) ) . Write ( weights . prob , weights . backoff ) ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
void Longest ( const void * data ) {
const WordIndex * words = reinterpret_cast < const WordIndex * > ( data ) ;
2012-06-28 18:58:59 +04:00
typename Quant : : LongestPointer ( quant_ , longest_ . Insert ( words [ order_ - 1 ] ) ) . Write ( reinterpret_cast < const Prob * > ( words + order_ ) - > prob ) ;
2011-01-25 22:11:48 +03:00
}
private :
2011-09-21 20:06:48 +04:00
RecordReader * contexts_ ;
2012-06-28 18:58:59 +04:00
const Quant & quant_ ;
2011-01-25 22:11:48 +03:00
UnigramValue * const unigrams_ ;
2012-06-28 18:58:59 +04:00
BitPackedMiddle < Bhiksha > * const middle_ ;
BitPackedLongest & longest_ ;
2011-01-25 22:11:48 +03:00
BitPacked & bigram_pack_ ;
2011-09-21 20:06:48 +04:00
const unsigned char order_ ;
SRISucks & sri_ ;
2011-01-25 22:11:48 +03:00
} ;
2011-09-21 20:06:48 +04:00
struct Gram {
Gram ( const WordIndex * in_begin , unsigned char order ) : begin ( in_begin ) , end ( in_begin + order ) { }
2011-01-25 22:11:48 +03:00
2011-09-21 20:06:48 +04:00
const WordIndex * begin , * end ;
2011-01-25 22:11:48 +03:00
2013-01-05 01:02:47 +04:00
// For queue, this is the direction we want.
2011-09-21 20:06:48 +04:00
bool operator < ( const Gram & other ) const {
return std : : lexicographical_compare ( other . begin , other . end , begin , end ) ;
}
} ;
2011-01-25 22:11:48 +03:00
2011-09-21 20:06:48 +04:00
template < class Doing > class BlankManager {
public :
BlankManager ( unsigned char total_order , Doing & doing ) : total_order_ ( total_order ) , been_length_ ( 0 ) , doing_ ( doing ) {
2012-08-09 00:22:13 +04:00
for ( float * i = basis_ ; i ! = basis_ + KENLM_MAX_ORDER - 1 ; + + i ) * i = kBadProb ;
2011-09-21 20:06:48 +04:00
}
2011-01-25 22:11:48 +03:00
2011-09-21 20:06:48 +04:00
void Visit ( const WordIndex * to , unsigned char length , float prob ) {
basis_ [ length - 1 ] = prob ;
unsigned char overlap = std : : min < unsigned char > ( length - 1 , been_length_ ) ;
const WordIndex * cur ;
WordIndex * pre ;
for ( cur = to , pre = been_ ; cur ! = to + overlap ; + + cur , + + pre ) {
if ( * pre ! = * cur ) break ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
if ( cur = = to + length - 1 ) {
* pre = * cur ;
been_length_ = length ;
return ;
2011-01-25 22:11:48 +03:00
}
2013-01-05 01:02:47 +04:00
// There are blanks to insert starting with order blank.
2011-09-21 20:06:48 +04:00
unsigned char blank = cur - to + 1 ;
UTIL_THROW_IF ( blank = = 1 , FormatLoadException , " Missing a unigram that appears as context. " ) ;
const float * lower_basis ;
for ( lower_basis = basis_ + blank - 2 ; * lower_basis = = kBadProb ; - - lower_basis ) { }
unsigned char based_on = lower_basis - basis_ + 1 ;
for ( ; cur ! = to + length - 1 ; + + blank , + + cur , + + pre ) {
2011-11-11 00:46:59 +04:00
assert ( * lower_basis ! = kBadProb ) ;
2011-09-21 20:06:48 +04:00
doing_ . MiddleBlank ( blank , to , based_on , * lower_basis ) ;
* pre = * cur ;
2013-01-05 01:02:47 +04:00
// Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
2011-09-21 20:06:48 +04:00
basis_ [ blank - 1 ] = kBadProb ;
2011-01-25 22:11:48 +03:00
}
2011-11-03 23:51:54 +04:00
* pre = * cur ;
2011-09-21 20:06:48 +04:00
been_length_ = length ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
private :
const unsigned char total_order_ ;
2011-01-25 22:11:48 +03:00
2012-08-09 00:22:13 +04:00
WordIndex been_ [ KENLM_MAX_ORDER ] ;
2011-09-21 20:06:48 +04:00
unsigned char been_length_ ;
2011-01-25 22:11:48 +03:00
2012-08-09 00:22:13 +04:00
float basis_ [ KENLM_MAX_ORDER ] ;
2013-01-05 01:02:47 +04:00
2011-09-21 20:06:48 +04:00
Doing & doing_ ;
} ;
2011-01-25 22:11:48 +03:00
2011-09-21 20:06:48 +04:00
template < class Doing > void RecursiveInsert ( const unsigned char total_order , const WordIndex unigram_count , RecordReader * input , std : : ostream * progress_out , const char * message , Doing & doing ) {
2012-06-28 18:58:59 +04:00
util : : ErsatzProgress progress ( unigram_count + 1 , progress_out , message ) ;
2011-11-17 23:12:19 +04:00
WordIndex unigram = 0 ;
2011-09-21 20:06:48 +04:00
std : : priority_queue < Gram > grams ;
2014-01-02 01:19:06 +04:00
if ( unigram_count ) grams . push ( Gram ( & unigram , 1 ) ) ;
2011-09-21 20:06:48 +04:00
for ( unsigned char i = 2 ; i < = total_order ; + + i ) {
if ( input [ i - 2 ] ) grams . push ( Gram ( reinterpret_cast < const WordIndex * > ( input [ i - 2 ] . Data ( ) ) , i ) ) ;
}
2011-01-25 22:11:48 +03:00
2011-09-21 20:06:48 +04:00
BlankManager < Doing > blank ( total_order , doing ) ;
2014-01-02 01:19:06 +04:00
while ( ! grams . empty ( ) ) {
2011-09-21 20:06:48 +04:00
Gram top = grams . top ( ) ;
grams . pop ( ) ;
unsigned char order = top . end - top . begin ;
if ( order = = 1 ) {
blank . Visit ( & unigram , 1 , doing . UnigramProb ( unigram ) ) ;
doing . Unigram ( unigram ) ;
progress . Set ( unigram ) ;
2014-01-02 01:19:06 +04:00
if ( + + unigram < unigram_count ) grams . push ( top ) ;
2011-09-21 20:06:48 +04:00
} else {
if ( order = = total_order ) {
blank . Visit ( top . begin , order , reinterpret_cast < const Prob * > ( top . end ) - > prob ) ;
doing . Longest ( top . begin ) ;
} else {
blank . Visit ( top . begin , order , reinterpret_cast < const ProbBackoff * > ( top . end ) - > prob ) ;
doing . Middle ( order , top . begin ) ;
}
RecordReader & reader = input [ order - 2 ] ;
if ( + + reader ) grams . push ( top ) ;
}
}
}
2011-01-25 22:11:48 +03:00
void SanityCheckCounts ( const std : : vector < uint64_t > & initial , const std : : vector < uint64_t > & fixed ) {
if ( fixed [ 0 ] ! = initial [ 0 ] ) UTIL_THROW ( util : : Exception , " Unigram count should be constant but initial is " < < initial [ 0 ] < < " and recounted is " < < fixed [ 0 ] ) ;
2011-03-21 22:51:08 +03:00
if ( fixed . back ( ) ! = initial . back ( ) ) UTIL_THROW ( util : : Exception , " Longest count should be constant but it changed from " < < initial . back ( ) < < " to " < < fixed . back ( ) ) ;
2011-01-25 22:11:48 +03:00
for ( unsigned char i = 0 ; i < initial . size ( ) ; + + i ) {
if ( fixed [ i ] < initial [ i ] ) UTIL_THROW ( util : : Exception , " Counts came out lower than expected. This shouldn't happen " ) ;
2010-10-27 21:50:40 +04:00
}
2011-01-25 22:11:48 +03:00
}
2010-10-27 21:50:40 +04:00
2011-09-21 20:06:48 +04:00
template < class Quant > void TrainQuantizer ( uint8_t order , uint64_t count , const std : : vector < float > & additional , RecordReader & reader , util : : ErsatzProgress & progress , Quant & quant ) {
std : : vector < float > probs ( additional ) , backoffs ;
probs . reserve ( count + additional . size ( ) ) ;
2011-06-27 02:21:44 +04:00
backoffs . reserve ( count ) ;
2011-09-21 20:06:48 +04:00
for ( reader . Rewind ( ) ; reader ; + + reader ) {
const ProbBackoff & weights = * reinterpret_cast < const ProbBackoff * > ( reinterpret_cast < const uint8_t * > ( reader . Data ( ) ) + sizeof ( WordIndex ) * order ) ;
probs . push_back ( weights . prob ) ;
if ( weights . backoff ! = 0.0 ) backoffs . push_back ( weights . backoff ) ;
+ + progress ;
2011-06-27 02:21:44 +04:00
}
quant . Train ( order , probs , backoffs ) ;
}
2011-09-21 20:06:48 +04:00
template < class Quant > void TrainProbQuantizer ( uint8_t order , uint64_t count , RecordReader & reader , util : : ErsatzProgress & progress , Quant & quant ) {
2011-06-27 02:21:44 +04:00
std : : vector < float > probs , backoffs ;
probs . reserve ( count ) ;
2011-09-21 20:06:48 +04:00
for ( reader . Rewind ( ) ; reader ; + + reader ) {
const Prob & weights = * reinterpret_cast < const Prob * > ( reinterpret_cast < const uint8_t * > ( reader . Data ( ) ) + sizeof ( WordIndex ) * order ) ;
probs . push_back ( weights . prob ) ;
+ + progress ;
2011-06-27 02:21:44 +04:00
}
quant . TrainProb ( order , probs ) ;
}
2011-09-21 20:06:48 +04:00
void PopulateUnigramWeights ( FILE * file , WordIndex unigram_count , RecordReader & contexts , UnigramValue * unigrams ) {
2013-01-05 01:02:47 +04:00
// Fill unigram probabilities.
2011-09-21 20:06:48 +04:00
try {
rewind ( file ) ;
for ( WordIndex i = 0 ; i < unigram_count ; + + i ) {
ReadOrThrow ( file , & unigrams [ i ] . weights , sizeof ( ProbBackoff ) ) ;
if ( contexts & & * reinterpret_cast < const WordIndex * > ( contexts . Data ( ) ) = = i ) {
SetExtension ( unigrams [ i ] . weights . backoff ) ;
+ + contexts ;
}
}
} catch ( util : : Exception & e ) {
e < < " while re-reading unigram probabilities " ;
throw ;
}
}
2011-06-27 02:21:44 +04:00
} // namespace
2011-11-11 00:46:59 +04:00
template < class Quant , class Bhiksha > void BuildTrie ( SortedFiles & files , std : : vector < uint64_t > & counts , const Config & config , TrieSearch < Quant , Bhiksha > & out , Quant & quant , const SortedVocabulary & vocab , Backing & backing ) {
2012-08-09 00:22:13 +04:00
RecordReader inputs [ KENLM_MAX_ORDER - 1 ] ;
RecordReader contexts [ KENLM_MAX_ORDER - 1 ] ;
2011-01-25 22:11:48 +03:00
2010-10-27 21:50:40 +04:00
for ( unsigned char i = 2 ; i < = counts . size ( ) ; + + i ) {
2011-11-11 00:46:59 +04:00
inputs [ i - 2 ] . Init ( files . Full ( i ) , i * sizeof ( WordIndex ) + ( i = = counts . size ( ) ? sizeof ( Prob ) : sizeof ( ProbBackoff ) ) ) ;
contexts [ i - 2 ] . Init ( files . Context ( i ) , ( i - 1 ) * sizeof ( WordIndex ) ) ;
2011-01-25 22:11:48 +03:00
}
2011-09-21 20:06:48 +04:00
SRISucks sri ;
2012-02-28 22:58:00 +04:00
std : : vector < uint64_t > fixed_counts ;
2011-11-11 00:46:59 +04:00
util : : scoped_FILE unigram_file ;
util : : scoped_fd unigram_fd ( files . StealUnigram ( ) ) ;
2011-01-25 22:11:48 +03:00
{
2011-09-21 20:06:48 +04:00
util : : scoped_memory unigrams ;
2011-11-11 00:46:59 +04:00
MapRead ( util : : POPULATE_OR_READ , unigram_fd . get ( ) , 0 , counts [ 0 ] * sizeof ( ProbBackoff ) , unigrams ) ;
2012-02-28 22:58:00 +04:00
FindBlanks finder ( counts . size ( ) , reinterpret_cast < const ProbBackoff * > ( unigrams . get ( ) ) , sri ) ;
2013-01-05 01:02:47 +04:00
RecursiveInsert ( counts . size ( ) , counts [ 0 ] , inputs , config . ProgressMessages ( ) , " Identifying n-grams omitted by SRI " , finder ) ;
2012-02-28 22:58:00 +04:00
fixed_counts = finder . Counts ( ) ;
2010-10-27 21:50:40 +04:00
}
2011-11-11 00:46:59 +04:00
unigram_file . reset ( util : : FDOpenOrThrow ( unigram_fd ) ) ;
2011-09-21 20:06:48 +04:00
for ( const RecordReader * i = inputs ; i ! = inputs + counts . size ( ) - 2 ; + + i ) {
if ( * i ) UTIL_THROW ( FormatLoadException , " There's a bug in the trie implementation: the " < < ( i - inputs + 2 ) < < " -gram table did not complete reading " ) ;
2011-03-21 22:51:08 +03:00
}
2011-01-25 22:11:48 +03:00
SanityCheckCounts ( counts , fixed_counts ) ;
2011-02-24 20:11:53 +03:00
counts = fixed_counts ;
2011-01-25 22:11:48 +03:00
2011-09-21 20:06:48 +04:00
sri . ObtainBackoffs ( counts . size ( ) , unigram_file . get ( ) , inputs ) ;
2011-08-16 16:57:21 +04:00
out . SetupMemory ( GrowForSearch ( config , vocab . UnkCountChangePadding ( ) , TrieSearch < Quant , Bhiksha > : : Size ( fixed_counts , config ) , backing ) , fixed_counts , config ) ;
2011-06-27 02:21:44 +04:00
2011-09-21 20:06:48 +04:00
for ( unsigned char i = 2 ; i < = counts . size ( ) ; + + i ) {
inputs [ i - 2 ] . Rewind ( ) ;
}
2011-06-27 02:21:44 +04:00
if ( Quant : : kTrain ) {
2013-01-05 01:02:47 +04:00
util : : ErsatzProgress progress ( std : : accumulate ( counts . begin ( ) + 1 , counts . end ( ) , 0 ) ,
config . ProgressMessages ( ) , " Quantizing " ) ;
2011-06-27 02:21:44 +04:00
for ( unsigned char i = 2 ; i < counts . size ( ) ; + + i ) {
2011-09-21 20:06:48 +04:00
TrainQuantizer ( i , counts [ i - 1 ] , sri . Values ( i ) , inputs [ i - 2 ] , progress , quant ) ;
2011-06-27 02:21:44 +04:00
}
TrainProbQuantizer ( counts . size ( ) , counts . back ( ) , inputs [ counts . size ( ) - 2 ] , progress , quant ) ;
quant . FinishedLoading ( config ) ;
}
2010-10-27 21:50:40 +04:00
2012-06-28 18:58:59 +04:00
UnigramValue * unigrams = out . unigram_ . Raw ( ) ;
2011-09-21 20:06:48 +04:00
PopulateUnigramWeights ( unigram_file . get ( ) , counts [ 0 ] , contexts [ 0 ] , unigrams ) ;
unigram_file . reset ( ) ;
2011-01-25 22:11:48 +03:00
for ( unsigned char i = 2 ; i < = counts . size ( ) ; + + i ) {
inputs [ i - 2 ] . Rewind ( ) ;
}
2013-01-05 01:02:47 +04:00
// Fill entries except unigram probabilities.
2010-10-27 21:50:40 +04:00
{
2012-06-28 18:58:59 +04:00
WriteEntries < Quant , Bhiksha > writer ( contexts , quant , unigrams , out . middle_begin_ , out . longest_ , counts . size ( ) , sri ) ;
2013-01-05 01:02:47 +04:00
RecursiveInsert ( counts . size ( ) , counts [ 0 ] , inputs , config . ProgressMessages ( ) , " Writing trie " , writer ) ;
2014-01-02 01:19:06 +04:00
// Write the last unigram entry, which is the end pointer for the bigrams.
writer . Unigram ( counts [ 0 ] ) ;
2011-01-25 22:11:48 +03:00
}
2013-01-05 01:02:47 +04:00
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
2011-01-25 22:11:48 +03:00
for ( unsigned char order = 2 ; order < = counts . size ( ) ; + + order ) {
2011-09-21 20:06:48 +04:00
const RecordReader & context = contexts [ order - 2 ] ;
2011-01-25 22:11:48 +03:00
if ( context ) {
FormatLoadException e ;
2011-10-11 14:12:17 +04:00
e < < " A " < < static_cast < unsigned int > ( order ) < < " -gram has context " ;
2011-09-21 20:06:48 +04:00
const WordIndex * ctx = reinterpret_cast < const WordIndex * > ( context . Data ( ) ) ;
for ( const WordIndex * i = ctx ; i ! = ctx + order - 1 ; + + i ) {
2011-01-25 22:11:48 +03:00
e < < ' ' < < * i ;
}
2011-02-24 20:11:53 +03:00
e < < " so this context must appear in the model as a " < < static_cast < unsigned int > ( order - 1 ) < < " -gram but it does not " ;
2011-01-25 22:11:48 +03:00
throw e ;
2010-10-27 21:50:40 +04:00
}
}
/* Set ending offsets so the last entry will be sized properly */
2013-01-05 01:02:47 +04:00
// Last entry for unigrams was already set.
2011-06-27 02:21:44 +04:00
if ( out . middle_begin_ ! = out . middle_end_ ) {
2011-07-14 00:53:18 +04:00
for ( typename TrieSearch < Quant , Bhiksha > : : Middle * i = out . middle_begin_ ; i ! = out . middle_end_ - 1 ; + + i ) {
i - > FinishedLoading ( ( i + 1 ) - > InsertIndex ( ) , config ) ;
2010-10-27 21:50:40 +04:00
}
2012-06-28 18:58:59 +04:00
( out . middle_end_ - 1 ) - > FinishedLoading ( out . longest_ . InsertIndex ( ) , config ) ;
2013-01-05 01:02:47 +04:00
}
2010-10-27 21:50:40 +04:00
}
2011-07-14 00:53:18 +04:00
// Lay out the quantizer, unigram, middle and longest tables consecutively
// from `start` and return the first byte past them. The middle tables are
// placement-constructed in a malloc'd array, which FreeMiddles releases.
template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
  quant_.SetupMemory(start, counts.size(), config);
  start += Quant::Size(counts.size(), config);
  unigram_.Init(start);
  start += Unigram::Size(counts[0]);
  FreeMiddles();
  const std::size_t middle_count = counts.size() - 2;
  middle_begin_ = static_cast<Middle*>(malloc(sizeof(Middle) * middle_count));
  // malloc returns NULL on failure. Placement-new into it would be undefined
  // behavior, so throw instead, leaving an empty [begin, end) range for
  // FreeMiddles. (malloc(0) may legitimately return NULL.)
  if (!middle_begin_ && middle_count) {
    middle_end_ = middle_begin_;
    UTIL_THROW(util::ErrnoException, "Failed to allocate " << middle_count << " middle trie levels");
  }
  middle_end_ = middle_begin_ + middle_count;
  std::vector<uint8_t*> middle_starts(middle_count);
  for (unsigned char i = 2; i < counts.size(); ++i) {
    middle_starts[i-2] = start;
    start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config);
  }
  // Crazy backwards thing so we initialize using pointers to ones that have already been initialized
  for (unsigned char i = counts.size() - 1; i >= 2; --i) {
    new (middle_begin_ + i - 2) Middle(
        middle_starts[i-2],
        quant_.MiddleBits(config),
        counts[i-1],
        counts[0],
        counts[i],
        (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest_) : static_cast<const BitPacked&>(middle_begin_[i-1]),
        config);
  }
  longest_.Init(start, quant_.LongestBits(config), counts[0]);
  return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
2011-07-14 00:53:18 +04:00
// Notify every table (unigrams, each middle order in sequence, then the
// longest order) that its contents were loaded from a binary file.
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
  unigram_.LoadedBinary();
  for (Middle *level = middle_begin_; level != middle_end_; ++level) level->LoadedBinary();
  longest_.LoadedBinary();
}
2010-10-27 21:50:40 +04:00
2011-07-14 00:53:18 +04:00
// Sort the ARPA n-grams into temporary files, then build the trie from them.
// The temporary file prefix is chosen in this order: the configured
// temporary directory, the mmap output path, the ARPA file name.
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
  const std::string temporary_prefix(
      config.temporary_directory_prefix ? config.temporary_directory_prefix
                                        : (config.write_mmap ? config.write_mmap : file));
  // At least 1MB sorting memory.
  const std::size_t sort_memory = std::max<size_t>(config.building_memory, 1048576);
  SortedFiles sorted(config, f, counts, sort_memory, temporary_prefix, vocab);

  BuildTrie(sorted, counts, config, *this, quant_, vocab, backing);
}
2011-07-14 00:53:18 +04:00
// Explicit instantiations. The member templates above are defined only in
// this file, so each Quant x Bhiksha combination is instantiated here for
// other translation units to link against.
template class TrieSearch < DontQuantize , DontBhiksha > ;
template class TrieSearch < DontQuantize , ArrayBhiksha > ;
template class TrieSearch < SeparatelyQuantize , DontBhiksha > ;
template class TrieSearch < SeparatelyQuantize , ArrayBhiksha > ;
2011-06-27 02:21:44 +04:00
2010-10-27 21:50:40 +04:00
} // namespace trie
} // namespace ngram
} // namespace lm