mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2025-01-06 11:38:34 +03:00
ea8e19f286
TODO: kill istream
280 lines
8.5 KiB
C++
280 lines
8.5 KiB
C++
#ifndef LM_VOCAB_H
|
|
#define LM_VOCAB_H
|
|
|
|
#include "lm/enumerate_vocab.hh"
|
|
#include "lm/lm_exception.hh"
|
|
#include "lm/virtual_interface.hh"
|
|
#include "util/file_stream.hh"
|
|
#include "util/murmur_hash.hh"
|
|
#include "util/pool.hh"
|
|
#include "util/probing_hash_table.hh"
|
|
#include "util/sorted_uniform.hh"
|
|
#include "util/string_piece.hh"
|
|
|
|
#include <limits>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
namespace lm {
|
|
struct ProbBackoff;
|
|
class EnumerateVocab;
|
|
|
|
namespace ngram {
|
|
struct Config;
|
|
|
|
namespace detail {
|
|
uint64_t HashForVocab(const char *str, std::size_t len);
|
|
inline uint64_t HashForVocab(const StringPiece &str) {
|
|
return HashForVocab(str.data(), str.length());
|
|
}
|
|
struct ProbingVocabularyHeader;
|
|
} // namespace detail
|
|
|
|
// Writes words immediately to a file instead of buffering, because we know
|
|
// where in the file to put them.
|
|
class ImmediateWriteWordsWrapper : public EnumerateVocab {
|
|
public:
|
|
ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start);
|
|
|
|
void Add(WordIndex index, const StringPiece &str) {
|
|
stream_ << str << '\0';
|
|
if (inner_) inner_->Add(index, str);
|
|
}
|
|
|
|
private:
|
|
EnumerateVocab *inner_;
|
|
|
|
util::FileStream stream_;
|
|
};
|
|
|
|
// When the binary size isn't known yet.
|
|
class WriteWordsWrapper : public EnumerateVocab {
|
|
public:
|
|
WriteWordsWrapper(EnumerateVocab *inner);
|
|
|
|
void Add(WordIndex index, const StringPiece &str);
|
|
|
|
const std::string &Buffer() const { return buffer_; }
|
|
void Write(int fd, uint64_t start);
|
|
|
|
private:
|
|
EnumerateVocab *inner_;
|
|
|
|
std::string buffer_;
|
|
};
|
|
|
|
// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
|
|
class SortedVocabulary : public base::Vocabulary {
|
|
public:
|
|
SortedVocabulary();
|
|
|
|
WordIndex Index(const StringPiece &str) const {
|
|
const uint64_t *found;
|
|
if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>(
|
|
util::IdentityAccessor<uint64_t>(),
|
|
begin_ - 1, 0,
|
|
end_, std::numeric_limits<uint64_t>::max(),
|
|
detail::HashForVocab(str), found)) {
|
|
return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
|
|
} else {
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
// Size for purposes of file writing
|
|
static uint64_t Size(uint64_t entries, const Config &config);
|
|
|
|
/* Read null-delimited words from file from_words, renumber according to
|
|
* hash order, write null-delimited words to to_words, and create a mapping
|
|
* from old id to new id. The 0th vocab word must be <unk>.
|
|
*/
|
|
static void ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping);
|
|
|
|
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
|
|
WordIndex Bound() const { return bound_; }
|
|
|
|
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
|
|
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
|
|
|
|
void Relocate(void *new_start);
|
|
|
|
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
|
|
|
|
// Insert and FinishedLoading go together.
|
|
WordIndex Insert(const StringPiece &str);
|
|
// Reorders reorder_vocab so that the IDs are sorted.
|
|
void FinishedLoading(ProbBackoff *reorder_vocab);
|
|
|
|
// Trie stores the correct counts including <unk> in the header. If this was previously sized based on a count exluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>.
|
|
std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); }
|
|
|
|
bool SawUnk() const { return saw_unk_; }
|
|
|
|
void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
|
|
|
|
uint64_t *&EndHack() { return end_; }
|
|
|
|
void Populated();
|
|
|
|
private:
|
|
template <class T> void GenericFinished(T *reorder);
|
|
|
|
uint64_t *begin_, *end_;
|
|
|
|
WordIndex bound_;
|
|
|
|
bool saw_unk_;
|
|
|
|
EnumerateVocab *enumerate_;
|
|
|
|
// Actual strings. Used only when loading from ARPA and enumerate_ != NULL
|
|
util::Pool string_backing_;
|
|
|
|
std::vector<StringPiece> strings_to_enumerate_;
|
|
};
|
|
|
|
#pragma pack(push)
|
|
#pragma pack(4)
|
|
struct ProbingVocabularyEntry {
|
|
uint64_t key;
|
|
WordIndex value;
|
|
|
|
typedef uint64_t Key;
|
|
uint64_t GetKey() const { return key; }
|
|
void SetKey(uint64_t to) { key = to; }
|
|
|
|
static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) {
|
|
ProbingVocabularyEntry ret;
|
|
ret.key = key;
|
|
ret.value = value;
|
|
return ret;
|
|
}
|
|
};
|
|
#pragma pack(pop)
|
|
|
|
// Vocabulary storing a map from uint64_t to WordIndex.
|
|
class ProbingVocabulary : public base::Vocabulary {
|
|
public:
|
|
ProbingVocabulary();
|
|
|
|
WordIndex Index(const StringPiece &str) const {
|
|
Lookup::ConstIterator i;
|
|
return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
|
|
}
|
|
|
|
static uint64_t Size(uint64_t entries, float probing_multiplier);
|
|
// This just unwraps Config to get the probing_multiplier.
|
|
static uint64_t Size(uint64_t entries, const Config &config);
|
|
|
|
// Vocab words are [0, Bound()).
|
|
WordIndex Bound() const { return bound_; }
|
|
|
|
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
|
|
void SetupMemory(void *start, std::size_t allocated);
|
|
void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) {
|
|
SetupMemory(start, allocated);
|
|
}
|
|
|
|
void Relocate(void *new_start);
|
|
|
|
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
|
|
|
|
WordIndex Insert(const StringPiece &str);
|
|
|
|
template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) {
|
|
InternalFinishedLoading();
|
|
}
|
|
|
|
std::size_t UnkCountChangePadding() const { return 0; }
|
|
|
|
bool SawUnk() const { return saw_unk_; }
|
|
|
|
void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
|
|
|
|
private:
|
|
void InternalFinishedLoading();
|
|
|
|
typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup;
|
|
|
|
Lookup lookup_;
|
|
|
|
WordIndex bound_;
|
|
|
|
bool saw_unk_;
|
|
|
|
EnumerateVocab *enumerate_;
|
|
|
|
detail::ProbingVocabularyHeader *header_;
|
|
};
|
|
|
|
// Handlers called by CheckSpecials when <unk> (MissingUnknown) or a sentence
// marker such as "<s>" / "</s>" (MissingSentenceMarker, marker passed as str)
// is absent from a vocabulary.  Per their exception specifications the only
// exception they may emit is SpecialWordMissingException.  Defined out of line.
// NOTE(review): dynamic exception specifications are deprecated in C++11 and
// removed in C++17.  Dropping them here also requires dropping them on the
// out-of-line definitions, which are not visible in this header -- confirm
// and change both together.
void MissingUnknown(const Config &config) throw(SpecialWordMissingException);
|
|
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException);
|
|
|
|
template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) {
|
|
if (!vocab.SawUnk()) MissingUnknown(config);
|
|
if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>");
|
|
if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
|
|
}
|
|
|
|
class WriteUniqueWords {
|
|
public:
|
|
explicit WriteUniqueWords(int fd) : word_list_(fd) {}
|
|
|
|
void operator()(const StringPiece &word) {
|
|
word_list_ << word << '\0';
|
|
}
|
|
|
|
private:
|
|
util::FileStream word_list_;
|
|
};
|
|
|
|
class NoOpUniqueWords {
|
|
public:
|
|
NoOpUniqueWords() {}
|
|
void operator()(const StringPiece &word) {}
|
|
};
|
|
|
|
template <class NewWordAction = NoOpUniqueWords> class GrowableVocab {
|
|
public:
|
|
static std::size_t MemUsage(WordIndex content) {
|
|
return Lookup::MemUsage(content > 2 ? content : 2);
|
|
}
|
|
|
|
// Does not take ownership of write_wordi
|
|
template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction())
|
|
: lookup_(initial_size), new_word_(new_word_construct) {
|
|
FindOrInsert("<unk>"); // Force 0
|
|
FindOrInsert("<s>"); // Force 1
|
|
FindOrInsert("</s>"); // Force 2
|
|
}
|
|
|
|
WordIndex Index(const StringPiece &str) const {
|
|
Lookup::ConstIterator i;
|
|
return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
|
|
}
|
|
|
|
WordIndex FindOrInsert(const StringPiece &word) {
|
|
ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size());
|
|
Lookup::MutableIterator it;
|
|
if (!lookup_.FindOrInsert(entry, it)) {
|
|
new_word_(word);
|
|
UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh");
|
|
}
|
|
return it->value;
|
|
}
|
|
|
|
WordIndex Size() const { return lookup_.Size(); }
|
|
|
|
private:
|
|
typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup;
|
|
|
|
Lookup lookup_;
|
|
|
|
NewWordAction new_word_;
|
|
};
|
|
|
|
} // namespace ngram
|
|
} // namespace lm
|
|
|
|
#endif // LM_VOCAB_H
|