2012-10-15 16:58:33 +04:00
# ifndef LM_VOCAB__
# define LM_VOCAB__
# include "lm/enumerate_vocab.hh"
# include "lm/lm_exception.hh"
# include "lm/virtual_interface.hh"
2012-11-05 00:36:42 +04:00
# include "util/pool.hh"
2012-10-15 16:58:33 +04:00
# include "util/probing_hash_table.hh"
# include "util/sorted_uniform.hh"
# include "util/string_piece.hh"
# include <limits>
# include <string>
# include <vector>
namespace lm {
struct ProbBackoff ;
class EnumerateVocab ;
namespace ngram {
struct Config ;
namespace detail {
// Hashes the len bytes starting at str into a 64-bit value.  The vocabulary
// classes in this header use that value as their lookup key.  Only declared
// here; the definition is not part of this header.
uint64_t HashForVocab(const char *str, std::size_t len);

// Convenience overload: forwards the StringPiece's data pointer and length to
// the (pointer, length) overload above.
inline uint64_t HashForVocab(const StringPiece &str) {
  return HashForVocab(str.data(), str.length());
}
class ProbingVocabularyHeader ;
} // namespace detail
class WriteWordsWrapper : public EnumerateVocab {
public :
WriteWordsWrapper ( EnumerateVocab * inner ) ;
~ WriteWordsWrapper ( ) ;
void Add ( WordIndex index , const StringPiece & str ) ;
2012-10-19 15:00:10 +04:00
void Write ( int fd , uint64_t start ) ;
2012-10-15 16:58:33 +04:00
private :
EnumerateVocab * inner_ ;
std : : string buffer_ ;
} ;
// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
class SortedVocabulary : public base : : Vocabulary {
public :
SortedVocabulary ( ) ;
WordIndex Index ( const StringPiece & str ) const {
const uint64_t * found ;
if ( util : : BoundedSortedUniformFind < const uint64_t * , util : : IdentityAccessor < uint64_t > , util : : Pivot64 > (
util : : IdentityAccessor < uint64_t > ( ) ,
begin_ - 1 , 0 ,
end_ , std : : numeric_limits < uint64_t > : : max ( ) ,
detail : : HashForVocab ( str ) , found ) ) {
return found - begin_ + 1 ; // +1 because <unk> is 0 and does not appear in the lookup table.
} else {
return 0 ;
}
}
// Size for purposes of file writing
static uint64_t Size ( uint64_t entries , const Config & config ) ;
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
WordIndex Bound ( ) const { return bound_ ; }
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory ( void * start , std : : size_t allocated , std : : size_t entries , const Config & config ) ;
void ConfigureEnumerate ( EnumerateVocab * to , std : : size_t max_entries ) ;
WordIndex Insert ( const StringPiece & str ) ;
// Reorders reorder_vocab so that the IDs are sorted.
void FinishedLoading ( ProbBackoff * reorder_vocab ) ;
// Trie stores the correct counts including <unk> in the header. If this was previously sized based on a count exluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>.
std : : size_t UnkCountChangePadding ( ) const { return SawUnk ( ) ? 0 : sizeof ( uint64_t ) ; }
bool SawUnk ( ) const { return saw_unk_ ; }
void LoadedBinary ( bool have_words , int fd , EnumerateVocab * to ) ;
private :
uint64_t * begin_ , * end_ ;
WordIndex bound_ ;
WordIndex highest_value_ ;
bool saw_unk_ ;
EnumerateVocab * enumerate_ ;
// Actual strings. Used only when loading from ARPA and enumerate_ != NULL
2012-11-05 00:36:42 +04:00
util : : Pool string_backing_ ;
std : : vector < StringPiece > strings_to_enumerate_ ;
2012-10-15 16:58:33 +04:00
} ;
# pragma pack(push)
# pragma pack(4)
struct ProbingVocabuaryEntry {
uint64_t key ;
WordIndex value ;
typedef uint64_t Key ;
uint64_t GetKey ( ) const {
return key ;
}
static ProbingVocabuaryEntry Make ( uint64_t key , WordIndex value ) {
ProbingVocabuaryEntry ret ;
ret . key = key ;
ret . value = value ;
return ret ;
}
} ;
# pragma pack(pop)
// Vocabulary storing a map from uint64_t to WordIndex.
class ProbingVocabulary : public base : : Vocabulary {
public :
ProbingVocabulary ( ) ;
WordIndex Index ( const StringPiece & str ) const {
Lookup : : ConstIterator i ;
return lookup_ . Find ( detail : : HashForVocab ( str ) , i ) ? i - > value : 0 ;
}
static uint64_t Size ( uint64_t entries , const Config & config ) ;
// Vocab words are [0, Bound()).
WordIndex Bound ( ) const { return bound_ ; }
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory ( void * start , std : : size_t allocated , std : : size_t entries , const Config & config ) ;
void ConfigureEnumerate ( EnumerateVocab * to , std : : size_t max_entries ) ;
WordIndex Insert ( const StringPiece & str ) ;
template < class Weights > void FinishedLoading ( Weights * /*reorder_vocab*/ ) {
InternalFinishedLoading ( ) ;
}
std : : size_t UnkCountChangePadding ( ) const { return 0 ; }
bool SawUnk ( ) const { return saw_unk_ ; }
void LoadedBinary ( bool have_words , int fd , EnumerateVocab * to ) ;
private :
void InternalFinishedLoading ( ) ;
typedef util : : ProbingHashTable < ProbingVocabuaryEntry , util : : IdentityHash > Lookup ;
Lookup lookup_ ;
WordIndex bound_ ;
bool saw_unk_ ;
EnumerateVocab * enumerate_ ;
detail : : ProbingVocabularyHeader * header_ ;
} ;
// Called by CheckSpecials when the vocabulary reports that <unk> was not seen.
// Declared to throw nothing other than SpecialWordMissingException; defined outside this header.
// NOTE(review): dynamic exception specifications are deprecated in C++11 and removed in C++17;
// dropping them has to be done together with the out-of-line definitions.
void MissingUnknown ( const Config & config ) throw ( SpecialWordMissingException ) ;
// Called by CheckSpecials when a sentence marker is not found; str is the marker's text.
// Same exception specification caveat as MissingUnknown.
void MissingSentenceMarker ( const Config & config , const char * str ) throw ( SpecialWordMissingException ) ;
// Verifies that the special words are present in vocab and reports each one
// that is missing:
//   - MissingUnknown(config) when vocab.SawUnk() is false;
//   - MissingSentenceMarker(config, "<s>") when BeginSentence() equals NotFound();
//   - MissingSentenceMarker(config, "</s>") when EndSentence() equals NotFound().
// Vocab must provide SawUnk(), BeginSentence(), EndSentence() and NotFound().
// Any exception raised by the reporting functions propagates to the caller.
// No dynamic exception specification is used here: that feature is deprecated
// in C++11 and ill-formed in C++17.
template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) {
  if (!vocab.SawUnk()) MissingUnknown(config);
  if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>");
  if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
}
} // namespace ngram
} // namespace lm
# endif // LM_VOCAB__