2010-10-27 21:50:40 +04:00
# include "lm/trie.hh"
2011-07-14 00:53:18 +04:00
# include "lm/bhiksha.hh"
2010-10-27 21:50:40 +04:00
# include "util/bit_packing.hh"
# include "util/exception.hh"
# include "util/sorted_uniform.hh"
# include <assert.h>
namespace lm {
namespace ngram {
namespace trie {
namespace {
2011-06-27 02:21:44 +04:00
class KeyAccessor {
2010-10-27 21:50:40 +04:00
public :
2011-06-27 02:21:44 +04:00
KeyAccessor ( const void * base , uint64_t key_mask , uint8_t key_bits , uint8_t total_bits )
: base_ ( reinterpret_cast < const uint8_t * > ( base ) ) , key_mask_ ( key_mask ) , key_bits_ ( key_bits ) , total_bits_ ( total_bits ) { }
2010-10-27 21:50:40 +04:00
2011-06-27 02:21:44 +04:00
typedef uint64_t Key ;
2010-10-27 21:50:40 +04:00
2011-06-27 02:21:44 +04:00
Key operator ( ) ( uint64_t index ) const {
return util : : ReadInt57 ( base_ , index * static_cast < uint64_t > ( total_bits_ ) , key_bits_ , key_mask_ ) ;
2010-10-27 21:50:40 +04:00
}
private :
const uint8_t * const base_ ;
2011-06-27 02:21:44 +04:00
const WordIndex key_mask_ ;
2010-11-28 05:54:56 +03:00
const uint8_t key_bits_ , total_bits_ ;
2010-10-27 21:50:40 +04:00
} ;
2011-06-27 02:21:44 +04:00
bool FindBitPacked ( const void * base , uint64_t key_mask , uint8_t key_bits , uint8_t total_bits , uint64_t begin_index , uint64_t end_index , const uint64_t max_vocab , const uint64_t key , uint64_t & at_index ) {
KeyAccessor accessor ( base , key_mask , key_bits , total_bits ) ;
2011-07-14 00:53:18 +04:00
if ( ! util : : BoundedSortedUniformFind < uint64_t , KeyAccessor , util : : PivotSelect < sizeof ( WordIndex ) > : : T > ( accessor , begin_index - 1 , ( uint64_t ) 0 , end_index , max_vocab , key , at_index ) ) return false ;
return true ;
2010-10-27 21:50:40 +04:00
}
} // namespace
2012-09-28 18:04:48 +04:00
uint64_t BitPacked : : BaseSize ( uint64_t entries , uint64_t max_vocab , uint8_t remaining_bits ) {
2011-06-27 02:21:44 +04:00
uint8_t total_bits = util : : RequiredBits ( max_vocab ) + remaining_bits ;
2010-10-27 21:50:40 +04:00
// Extra entry for next pointer at the end.
// +7 then / 8 to round up bits and convert to bytes
// +sizeof(uint64_t) so that ReadInt57 etc don't go segfault.
// Note that this waste is O(order), not O(number of ngrams).
return ( ( 1 + entries ) * total_bits + 7 ) / 8 + sizeof ( uint64_t ) ;
}
void BitPacked : : BaseInit ( void * base , uint64_t max_vocab , uint8_t remaining_bits ) {
util : : BitPackingSanity ( ) ;
word_bits_ = util : : RequiredBits ( max_vocab ) ;
word_mask_ = ( 1ULL < < word_bits_ ) - 1ULL ;
if ( word_bits_ > 57 ) UTIL_THROW ( util : : Exception , " Sorry, word indices more than " < < ( 1ULL < < 57 ) < < " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions. " ) ;
2011-06-27 02:21:44 +04:00
total_bits_ = word_bits_ + remaining_bits ;
2010-10-27 21:50:40 +04:00
base_ = static_cast < uint8_t * > ( base ) ;
insert_index_ = 0 ;
2011-05-23 06:23:01 +04:00
max_vocab_ = max_vocab ;
2010-10-27 21:50:40 +04:00
}
2012-09-28 18:04:48 +04:00
// Total bytes for a middle-order array: the Bhiksha policy's own storage plus
// the bit-packed records, whose per-record payload is quant_bits plus the
// bits the policy stores inline. The policy is sized for entries + 1
// pointers.
template <class Bhiksha> uint64_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
  const uint64_t pointer_count = entries + 1;
  const uint64_t bhiksha_bytes = Bhiksha::Size(pointer_count, max_ptr, config);
  const uint64_t packed_bytes = BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(pointer_count, max_ptr, config));
  return bhiksha_bytes + packed_bytes;
}
2012-06-28 18:58:59 +04:00
// Builds a middle-order array inside the memory at `base`. The Bhiksha policy
// is constructed at `base` itself; the bit-packed records start
// Bhiksha::Size(entries + 1, max_next, config) bytes after it.
// `next_source` is the array whose InsertIndex() supplies next pointers
// during Insert. Throws util::Exception when entries + 1 or max_next reaches
// 2^57.
template <class Bhiksha> BitPackedMiddle<Bhiksha>::BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config)
  : BitPacked(),
    quant_bits_(quant_bits),
    // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary.
    bhiksha_(base, entries + 1, max_next, config),
    next_source_(&next_source) {
  const unsigned long long kUnsupported = 1ULL << 57;
  const bool too_many_entries = (entries + 1 >= kUnsupported);
  const bool next_too_large = (max_next >= kUnsupported);
  if (too_many_entries || next_too_large) UTIL_THROW(util::Exception, " Sorry, this does not support more than " << kUnsupported << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions. ");
  // The packed records live immediately after the policy's storage.
  uint8_t *const packed_base = static_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config);
  BaseInit(packed_base, max_vocab, quant_bits_ + bhiksha_.InlineBits());
}
2012-06-28 18:58:59 +04:00
// Appends a record for `word` at insert_index_: writes the word index, leaves
// quant_bits_ bits for the caller to fill, and has the Bhiksha policy record
// next_source_'s current insert index as this record's next pointer.
// Returns the bit address just past the word index, i.e. where the
// quant_bits_ payload begins.
template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Insert(WordIndex word) {
  assert(word <= word_mask_);
  const uint64_t word_offset = insert_index_ * total_bits_;
  util::WriteInt57(base_, word_offset, word_bits_, word);

  const uint64_t payload_offset = word_offset + word_bits_;
  const uint64_t next_offset = payload_offset + quant_bits_;
  const uint64_t next_value = next_source_->InsertIndex();
  bhiksha_.WriteNext(base_, next_offset, insert_index_, next_value);

  ++insert_index_;
  return util::BitAddress(base_, payload_offset);
}
2012-06-28 18:58:59 +04:00
// Searches for `word` among the records selected by `range`.
// On a miss, returns BitAddress(NULL, 0) and leaves `range` and `pointer`
// untouched. On a hit, stores the record index in `pointer`, lets the Bhiksha
// policy update `range` through ReadNext, and returns the bit address just
// past the word index of the matching record.
template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Find(WordIndex word, NodeRange &range, uint64_t &pointer) const {
  uint64_t found_index;
  const bool found = FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, found_index);
  if (!found) {
    return util::BitAddress(NULL, 0);
  }
  pointer = found_index;
  // Skip the word index to reach the record's payload.
  const uint64_t payload_offset = found_index * total_bits_ + word_bits_;
  // The next pointer sits after the quant_bits_ payload.
  bhiksha_.ReadNext(base_, payload_offset + quant_bits_, pointer, total_bits_, range);

  return util::BitAddress(base_, payload_offset);
}
2012-06-28 18:58:59 +04:00
// Writes the final next pointer (`next_end`) into the record slot at
// insert_index_, then tells the Bhiksha policy that loading is finished.
template <class Bhiksha> void BitPackedMiddle<Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
  // Start of the record slot at insert_index_.
  const uint64_t slot_start = insert_index_ * total_bits_;
  // Next pointers occupy the last InlineBits() bits of a record.
  const uint64_t next_offset = slot_start + (total_bits_ - bhiksha_.InlineBits());
  bhiksha_.WriteNext(base_, next_offset, insert_index_, next_end);
  bhiksha_.FinishedLoading(config);
}
2012-06-28 18:58:59 +04:00
// Appends a record for `index` at insert_index_ and returns the bit address
// just past the stored word index, where the rest of the record begins.
util::BitAddress BitPackedLongest::Insert(WordIndex index) {
  assert(index <= word_mask_);
  const uint64_t word_offset = insert_index_ * total_bits_;
  util::WriteInt57(base_, word_offset, word_bits_, index);
  ++insert_index_;
  return util::BitAddress(base_, word_offset + word_bits_);
}
2012-06-28 18:58:59 +04:00
// Searches for `word` among the records selected by `range`.
// Returns BitAddress(NULL, 0) on a miss; otherwise the bit address just past
// the word index of the matching record.
util::BitAddress BitPackedLongest::Find(WordIndex word, const NodeRange &range) const {
  uint64_t found_index;
  const bool found = FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, found_index);
  if (!found) {
    return util::BitAddress(NULL, 0);
  }
  const uint64_t payload_offset = found_index * total_bits_ + word_bits_;
  return util::BitAddress(base_, payload_offset);
}
2012-06-28 18:58:59 +04:00
// Explicit instantiations: the BitPackedMiddle member templates are defined
// in this file, so emit them here for the DontBhiksha and ArrayBhiksha
// policies.
template class BitPackedMiddle < DontBhiksha > ;
template class BitPackedMiddle < ArrayBhiksha > ;
2011-06-27 02:21:44 +04:00
2010-10-27 21:50:40 +04:00
} // namespace trie
} // namespace ngram
} // namespace lm