# include "lm/trie_sort.hh"
# include "lm/config.hh"
# include "lm/lm_exception.hh"
# include "lm/read_arpa.hh"
# include "lm/vocab.hh"
# include "lm/weights.hh"
# include "lm/word_index.hh"
# include "util/file_piece.hh"
# include "util/mmap.hh"
# include "util/proxy_iterator.hh"
# include "util/sized_iterator.hh"
# include <algorithm>
# include <cstring>
# include <cstdio>
# include <cstdlib>
# include <deque>
# include <iterator>
# include <limits>
# include <vector>
namespace lm {
namespace ngram {
namespace trie {
namespace {
// Iterator over fixed-size n-gram records laid out contiguously in memory.
typedef util::SizedIterator NGramIter;
// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams.
class PartialViewProxy {
public :
PartialViewProxy ( ) : attention_size_ ( 0 ) , inner_ ( ) { }
PartialViewProxy ( void * ptr , std : : size_t block_size , std : : size_t attention_size ) : attention_size_ ( attention_size ) , inner_ ( ptr , block_size ) { }
operator std : : string ( ) const {
return std : : string ( reinterpret_cast < const char * > ( inner_ . Data ( ) ) , attention_size_ ) ;
}
PartialViewProxy & operator = ( const PartialViewProxy & from ) {
memcpy ( inner_ . Data ( ) , from . inner_ . Data ( ) , attention_size_ ) ;
return * this ;
}
PartialViewProxy & operator = ( const std : : string & from ) {
memcpy ( inner_ . Data ( ) , from . data ( ) , attention_size_ ) ;
return * this ;
}
const void * Data ( ) const { return inner_ . Data ( ) ; }
void * Data ( ) { return inner_ . Data ( ) ; }
2014-01-28 04:51:35 +04:00
friend void swap ( PartialViewProxy first , PartialViewProxy second ) {
std : : swap_ranges ( reinterpret_cast < char * > ( first . Data ( ) ) , reinterpret_cast < char * > ( first . Data ( ) ) + first . attention_size_ , reinterpret_cast < char * > ( second . Data ( ) ) ) ;
}
2011-09-21 20:06:48 +04:00
private :
friend class util : : ProxyIterator < PartialViewProxy > ;
typedef std : : string value_type ;
const std : : size_t attention_size_ ;
typedef util : : SizedInnerIterator InnerIterator ;
InnerIterator & Inner ( ) { return inner_ ; }
const InnerIterator & Inner ( ) const { return inner_ ; }
InnerIterator inner_ ;
} ;
typedef util : : ProxyIterator < PartialViewProxy > PartialIter ;
2013-01-17 15:58:58 +04:00
FILE * DiskFlush ( const void * mem_begin , const void * mem_end , const std : : string & temp_prefix ) {
util : : scoped_fd file ( util : : MakeTemp ( temp_prefix ) ) ;
2011-11-11 00:46:59 +04:00
util : : WriteOrThrow ( file . get ( ) , mem_begin , ( uint8_t * ) mem_end - ( uint8_t * ) mem_begin ) ;
return util : : FDOpenOrThrow ( file ) ;
2011-09-21 20:06:48 +04:00
}
2013-01-17 15:58:58 +04:00
FILE * WriteContextFile ( uint8_t * begin , uint8_t * end , const std : : string & temp_prefix , std : : size_t entry_size , unsigned char order ) {
2011-09-21 20:06:48 +04:00
const size_t context_size = sizeof ( WordIndex ) * ( order - 1 ) ;
// Sort just the contexts using the same memory.
PartialIter context_begin ( PartialViewProxy ( begin + sizeof ( WordIndex ) , entry_size , context_size ) ) ;
PartialIter context_end ( PartialViewProxy ( end + sizeof ( WordIndex ) , entry_size , context_size ) ) ;
2012-02-28 22:58:00 +04:00
# if defined(_WIN32) || defined(_WIN64)
std : : stable_sort
# else
std : : sort
# endif
( context_begin , context_end , util : : SizedCompare < EntryCompare , PartialViewProxy > ( EntryCompare ( order - 1 ) ) ) ;
2011-09-21 20:06:48 +04:00
2013-01-17 15:58:58 +04:00
util : : scoped_FILE out ( util : : FMakeTemp ( temp_prefix ) ) ;
2011-09-21 20:06:48 +04:00
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
2011-11-11 00:46:59 +04:00
if ( context_begin = = context_end ) return out . release ( ) ;
2011-09-21 20:06:48 +04:00
PartialIter i ( context_begin ) ;
2012-09-28 18:04:48 +04:00
util : : WriteOrThrow ( out . get ( ) , i - > Data ( ) , context_size ) ;
2011-09-21 20:06:48 +04:00
const void * previous = i - > Data ( ) ;
+ + i ;
for ( ; i ! = context_end ; + + i ) {
if ( memcmp ( previous , i - > Data ( ) , context_size ) ) {
2012-09-28 18:04:48 +04:00
util : : WriteOrThrow ( out . get ( ) , i - > Data ( ) , context_size ) ;
2011-09-21 20:06:48 +04:00
previous = i - > Data ( ) ;
}
}
2011-11-11 00:46:59 +04:00
return out . release ( ) ;
2011-09-21 20:06:48 +04:00
}
struct ThrowCombine {
2014-08-28 07:23:39 +04:00
void operator ( ) ( std : : size_t entry_size , unsigned char order , const void * first , const void * second , FILE * /*out*/ ) const {
const WordIndex * base = reinterpret_cast < const WordIndex * > ( first ) ;
FormatLoadException e ;
e < < " Duplicate n-gram detected with vocab ids " ;
for ( const WordIndex * i = base ; i ! = base + order ; + + i ) {
e < < ' ' < < * i ;
}
throw e ;
2011-09-21 20:06:48 +04:00
}
} ;
// Useful for context files that just contain records with no value.
struct FirstCombine {
2014-08-28 07:23:39 +04:00
void operator ( ) ( std : : size_t entry_size , unsigned char /*order*/ , const void * first , const void * /*second*/ , FILE * out ) const {
2012-09-28 18:04:48 +04:00
util : : WriteOrThrow ( out , first , entry_size ) ;
2011-09-21 20:06:48 +04:00
}
} ;
2013-01-17 15:58:58 +04:00
template < class Combine > FILE * MergeSortedFiles ( FILE * first_file , FILE * second_file , const std : : string & temp_prefix , std : : size_t weights_size , unsigned char order , const Combine & combine ) {
2011-09-21 20:06:48 +04:00
std : : size_t entry_size = sizeof ( WordIndex ) * order + weights_size ;
RecordReader first , second ;
2011-11-11 00:46:59 +04:00
first . Init ( first_file , entry_size ) ;
second . Init ( second_file , entry_size ) ;
2013-01-17 15:58:58 +04:00
util : : scoped_FILE out_file ( util : : FMakeTemp ( temp_prefix ) ) ;
2011-09-21 20:06:48 +04:00
EntryCompare less ( order ) ;
while ( first & & second ) {
if ( less ( first . Data ( ) , second . Data ( ) ) ) {
2012-09-28 18:04:48 +04:00
util : : WriteOrThrow ( out_file . get ( ) , first . Data ( ) , entry_size ) ;
2011-09-21 20:06:48 +04:00
+ + first ;
} else if ( less ( second . Data ( ) , first . Data ( ) ) ) {
2012-09-28 18:04:48 +04:00
util : : WriteOrThrow ( out_file . get ( ) , second . Data ( ) , entry_size ) ;
2011-09-21 20:06:48 +04:00
+ + second ;
} else {
2014-08-28 07:23:39 +04:00
combine ( entry_size , order , first . Data ( ) , second . Data ( ) , out_file . get ( ) ) ;
2011-09-21 20:06:48 +04:00
+ + first ; + + second ;
}
}
2011-10-08 14:59:54 +04:00
for ( RecordReader & remains = ( first ? first : second ) ; remains ; + + remains ) {
2012-09-28 18:04:48 +04:00
util : : WriteOrThrow ( out_file . get ( ) , remains . Data ( ) , entry_size ) ;
2011-09-21 20:06:48 +04:00
}
2011-11-11 00:46:59 +04:00
return out_file . release ( ) ;
2011-09-21 20:06:48 +04:00
}
} // namespace
2011-11-11 00:46:59 +04:00
void RecordReader : : Init ( FILE * file , std : : size_t entry_size ) {
2012-08-17 00:01:43 +04:00
entry_size_ = entry_size ;
2011-09-21 20:06:48 +04:00
data_ . reset ( malloc ( entry_size ) ) ;
UTIL_THROW_IF ( ! data_ . get ( ) , util : : ErrnoException , " Failed to malloc read buffer " ) ;
2012-08-17 00:01:43 +04:00
file_ = file ;
if ( file ) {
rewind ( file ) ;
remains_ = true ;
+ + * this ;
} else {
remains_ = false ;
}
2011-09-21 20:06:48 +04:00
}
void RecordReader : : Overwrite ( const void * start , std : : size_t amount ) {
long internal = ( uint8_t * ) start - ( uint8_t * ) data_ . get ( ) ;
2011-11-11 00:46:59 +04:00
UTIL_THROW_IF ( fseek ( file_ , internal - entry_size_ , SEEK_CUR ) , util : : ErrnoException , " Couldn't seek backwards for revision " ) ;
2012-09-28 18:04:48 +04:00
util : : WriteOrThrow ( file_ , start , amount ) ;
2011-09-21 20:06:48 +04:00
long forward = entry_size_ - internal - amount ;
2012-02-28 22:58:00 +04:00
# if !defined(_WIN32) && !defined(_WIN64)
if ( forward )
# endif
UTIL_THROW_IF ( fseek ( file_ , forward , SEEK_CUR ) , util : : ErrnoException , " Couldn't seek forwards past revision " ) ;
2011-09-21 20:06:48 +04:00
}
2011-11-11 00:46:59 +04:00
void RecordReader : : Rewind ( ) {
2012-08-17 00:01:43 +04:00
if ( file_ ) {
rewind ( file_ ) ;
remains_ = true ;
+ + * this ;
} else {
remains_ = false ;
}
2011-11-11 00:46:59 +04:00
}
// Read an ARPA file: unigrams go straight to a mmapped temporary file; each
// higher order is converted to sorted on-disk records via ConvertToSorted.
// counts[0] is bumped if <unk> was absent from the data.
SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
  PositiveProbWarn warn(config.positive_log_probability);
  unigram_.reset(util::MakeTemp(file_prefix));
  {
    // In case <unk> appears.
    size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff);
    util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_.get(), size_out), size_out);
    Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);
    CheckSpecials(config, vocab);
    if (!vocab.SawUnk()) ++counts[0];
  }
  // Only use as much buffer as we need: the largest in-memory record batch
  // over all orders (middle orders carry prob + backoff, the top order prob
  // only).
  size_t buffer_use = 0;
  for (unsigned int order = 2; order < counts.size(); ++order) {
    buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
  }
  buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
  buffer = std::min<size_t>(buffer, buffer_use);
  util::scoped_malloc mem;
  mem.reset(malloc(buffer));
  if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
  for (unsigned char order = 2; order <= counts.size(); ++order) {
    ConvertToSorted(f, vocab, counts, file_prefix, order, warn, mem.get(), buffer);
  }
  ReadEnd(f);
}
2011-11-11 00:46:59 +04:00
namespace {
class Closer {
public :
explicit Closer ( std : : deque < FILE * > & files ) : files_ ( files ) { }
~ Closer ( ) {
for ( std : : deque < FILE * > : : iterator i = files_ . begin ( ) ; i ! = files_ . end ( ) ; + + i ) {
util : : scoped_FILE deleter ( * i ) ;
}
}
void PopFront ( ) {
util : : scoped_FILE deleter ( files_ . front ( ) ) ;
files_ . pop_front ( ) ;
}
private :
std : : deque < FILE * > & files_ ;
} ;
} // namespace
2013-01-17 15:58:58 +04:00
void SortedFiles : : ConvertToSorted ( util : : FilePiece & f , const SortedVocabulary & vocab , const std : : vector < uint64_t > & counts , const std : : string & file_prefix , unsigned char order , PositiveProbWarn & warn , void * mem , std : : size_t mem_size ) {
2011-11-11 00:46:59 +04:00
ReadNGramHeader ( f , order ) ;
const size_t count = counts [ order - 1 ] ;
// Size of weights. Does it include backoff?
const size_t words_size = sizeof ( WordIndex ) * order ;
const size_t weights_size = sizeof ( float ) + ( ( order = = counts . size ( ) ) ? 0 : sizeof ( float ) ) ;
const size_t entry_size = words_size + weights_size ;
const size_t batch_size = std : : min ( count , mem_size / entry_size ) ;
uint8_t * const begin = reinterpret_cast < uint8_t * > ( mem ) ;
std : : deque < FILE * > files , contexts ;
Closer files_closer ( files ) , contexts_closer ( contexts ) ;
for ( std : : size_t batch = 0 , done = 0 ; done < count ; + + batch ) {
uint8_t * out = begin ;
uint8_t * out_end = out + std : : min ( count - done , batch_size ) * entry_size ;
if ( order = = counts . size ( ) ) {
for ( ; out ! = out_end ; out + = entry_size ) {
2014-06-02 21:28:02 +04:00
std : : reverse_iterator < WordIndex * > it ( reinterpret_cast < WordIndex * > ( out ) + order ) ;
ReadNGram ( f , order , vocab , it , * reinterpret_cast < Prob * > ( out + words_size ) , warn ) ;
2011-11-11 00:46:59 +04:00
}
} else {
for ( ; out ! = out_end ; out + = entry_size ) {
2014-06-02 21:28:02 +04:00
std : : reverse_iterator < WordIndex * > it ( reinterpret_cast < WordIndex * > ( out ) + order ) ;
ReadNGram ( f , order , vocab , it , * reinterpret_cast < ProbBackoff * > ( out + words_size ) , warn ) ;
2011-11-11 00:46:59 +04:00
}
}
// Sort full records by full n-gram.
util : : SizedProxy proxy_begin ( begin , entry_size ) , proxy_end ( out_end , entry_size ) ;
2012-02-28 22:58:00 +04:00
// parallel_sort uses too much RAM. TODO: figure out why windows sort doesn't like my proxies.
# if defined(_WIN32) || defined(_WIN64)
std : : stable_sort
# else
std : : sort
# endif
( NGramIter ( proxy_begin ) , NGramIter ( proxy_end ) , util : : SizedCompare < EntryCompare > ( EntryCompare ( order ) ) ) ;
2013-01-17 15:58:58 +04:00
files . push_back ( DiskFlush ( begin , out_end , file_prefix ) ) ;
contexts . push_back ( WriteContextFile ( begin , out_end , file_prefix , entry_size , order ) ) ;
2011-11-11 00:46:59 +04:00
done + = ( out_end - begin ) / entry_size ;
}
// All individual files created. Merge them.
while ( files . size ( ) > 1 ) {
2013-01-17 15:58:58 +04:00
files . push_back ( MergeSortedFiles ( files [ 0 ] , files [ 1 ] , file_prefix , weights_size , order , ThrowCombine ( ) ) ) ;
2011-11-11 00:46:59 +04:00
files_closer . PopFront ( ) ;
files_closer . PopFront ( ) ;
2013-01-17 15:58:58 +04:00
contexts . push_back ( MergeSortedFiles ( contexts [ 0 ] , contexts [ 1 ] , file_prefix , 0 , order - 1 , FirstCombine ( ) ) ) ;
2011-11-11 00:46:59 +04:00
contexts_closer . PopFront ( ) ;
contexts_closer . PopFront ( ) ;
}
if ( ! files . empty ( ) ) {
// Steal from closers.
full_ [ order - 2 ] . reset ( files . front ( ) ) ;
files . pop_front ( ) ;
context_ [ order - 2 ] . reset ( contexts . front ( ) ) ;
contexts . pop_front ( ) ;
}
}
2011-09-21 20:06:48 +04:00
} // namespace trie
} // namespace ngram
} // namespace lm