mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-26 13:23:25 +03:00
234 lines
7.3 KiB
C++
234 lines
7.3 KiB
C++
#ifndef LM_QUANTIZE_H
|
|
#define LM_QUANTIZE_H
|
|
|
|
#include "lm/blank.hh"
|
|
#include "lm/config.hh"
|
|
#include "lm/max_order.hh"
|
|
#include "lm/model_type.hh"
|
|
#include "util/bit_packing.hh"
|
|
|
|
#include <algorithm>
|
|
#include <vector>
|
|
|
|
#include <stdint.h>
|
|
|
|
#include <iostream>
|
|
|
|
namespace lm {
|
|
namespace ngram {
|
|
|
|
struct Config;
|
|
class BinaryFormat;
|
|
|
|
/* Store values directly and don't quantize. */
|
|
class DontQuantize {
|
|
public:
|
|
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
|
|
static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {}
|
|
static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
|
|
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
|
|
static uint8_t LongestBits(const Config &/*config*/) { return 31; }
|
|
|
|
class MiddlePointer {
|
|
public:
|
|
MiddlePointer(const DontQuantize & /*quant*/, unsigned char /*order_minus_2*/, util::BitAddress address) : address_(address) {}
|
|
|
|
MiddlePointer() : address_(NULL, 0) {}
|
|
|
|
bool Found() const {
|
|
return address_.base != NULL;
|
|
}
|
|
|
|
float Prob() const {
|
|
return util::ReadNonPositiveFloat31(address_.base, address_.offset);
|
|
}
|
|
|
|
float Backoff() const {
|
|
return util::ReadFloat32(address_.base, address_.offset + 31);
|
|
}
|
|
|
|
float Rest() const { return Prob(); }
|
|
|
|
void Write(float prob, float backoff) {
|
|
util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
|
|
util::WriteFloat32(address_.base, address_.offset + 31, backoff);
|
|
}
|
|
|
|
private:
|
|
util::BitAddress address_;
|
|
};
|
|
|
|
class LongestPointer {
|
|
public:
|
|
explicit LongestPointer(const DontQuantize &/*quant*/, util::BitAddress address) : address_(address) {}
|
|
|
|
LongestPointer() : address_(NULL, 0) {}
|
|
|
|
bool Found() const {
|
|
return address_.base != NULL;
|
|
}
|
|
|
|
float Prob() const {
|
|
return util::ReadNonPositiveFloat31(address_.base, address_.offset);
|
|
}
|
|
|
|
void Write(float prob) {
|
|
util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
|
|
}
|
|
|
|
private:
|
|
util::BitAddress address_;
|
|
};
|
|
|
|
DontQuantize() {}
|
|
|
|
void SetupMemory(void * /*start*/, unsigned char /*order*/, const Config & /*config*/) {}
|
|
|
|
static const bool kTrain = false;
|
|
// These should never be called because kTrain is false.
|
|
void Train(uint8_t /*order*/, std::vector<float> &/*prob*/, std::vector<float> &/*backoff*/) {}
|
|
void TrainProb(uint8_t, std::vector<float> &/*prob*/) {}
|
|
|
|
void FinishedLoading(const Config &) {}
|
|
};
|
|
|
|
class SeparatelyQuantize {
|
|
private:
|
|
class Bins {
|
|
public:
|
|
// Sigh C++ default constructor
|
|
Bins() {}
|
|
|
|
Bins(uint8_t bits, float *begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {}
|
|
|
|
float *Populate() { return begin_; }
|
|
|
|
uint64_t EncodeProb(float value) const {
|
|
return Encode(value, 0);
|
|
}
|
|
|
|
uint64_t EncodeBackoff(float value) const {
|
|
if (value == 0.0) {
|
|
return HasExtension(value) ? kExtensionQuant : kNoExtensionQuant;
|
|
}
|
|
return Encode(value, 2);
|
|
}
|
|
|
|
float Decode(std::size_t off) const { return begin_[off]; }
|
|
|
|
uint8_t Bits() const { return bits_; }
|
|
|
|
uint64_t Mask() const { return mask_; }
|
|
|
|
private:
|
|
uint64_t Encode(float value, size_t reserved) const {
|
|
const float *above = std::lower_bound(static_cast<const float*>(begin_) + reserved, end_, value);
|
|
if (above == begin_ + reserved) return reserved;
|
|
if (above == end_) return end_ - begin_ - 1;
|
|
return above - begin_ - (value - *(above - 1) < *above - value);
|
|
}
|
|
|
|
float *begin_;
|
|
const float *end_;
|
|
uint8_t bits_;
|
|
uint64_t mask_;
|
|
};
|
|
|
|
public:
|
|
static const ModelType kModelTypeAdd = kQuantAdd;
|
|
|
|
static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
|
|
|
|
static uint64_t Size(uint8_t order, const Config &config) {
|
|
uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float);
|
|
uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table;
|
|
// unigrams are currently not quantized so no need for a table.
|
|
return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8;
|
|
}
|
|
|
|
static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; }
|
|
static uint8_t LongestBits(const Config &config) { return config.prob_bits; }
|
|
|
|
class MiddlePointer {
|
|
public:
|
|
MiddlePointer(const SeparatelyQuantize &quant, unsigned char order_minus_2, const util::BitAddress &address) : bins_(quant.GetTables(order_minus_2)), address_(address) {}
|
|
|
|
MiddlePointer() : address_(NULL, 0) {}
|
|
|
|
bool Found() const { return address_.base != NULL; }
|
|
|
|
float Prob() const {
|
|
return ProbBins().Decode(util::ReadInt25(address_.base, address_.offset + BackoffBins().Bits(), ProbBins().Bits(), ProbBins().Mask()));
|
|
}
|
|
|
|
float Backoff() const {
|
|
return BackoffBins().Decode(util::ReadInt25(address_.base, address_.offset, BackoffBins().Bits(), BackoffBins().Mask()));
|
|
}
|
|
|
|
float Rest() const { return Prob(); }
|
|
|
|
void Write(float prob, float backoff) const {
|
|
util::WriteInt57(address_.base, address_.offset, ProbBins().Bits() + BackoffBins().Bits(),
|
|
(ProbBins().EncodeProb(prob) << BackoffBins().Bits()) | BackoffBins().EncodeBackoff(backoff));
|
|
}
|
|
|
|
private:
|
|
const Bins &ProbBins() const { return bins_[0]; }
|
|
const Bins &BackoffBins() const { return bins_[1]; }
|
|
const Bins *bins_;
|
|
|
|
util::BitAddress address_;
|
|
};
|
|
|
|
class LongestPointer {
|
|
public:
|
|
LongestPointer(const SeparatelyQuantize &quant, const util::BitAddress &address) : table_(&quant.LongestTable()), address_(address) {}
|
|
|
|
LongestPointer() : address_(NULL, 0) {}
|
|
|
|
bool Found() const { return address_.base != NULL; }
|
|
|
|
void Write(float prob) const {
|
|
util::WriteInt25(address_.base, address_.offset, table_->Bits(), table_->EncodeProb(prob));
|
|
}
|
|
|
|
float Prob() const {
|
|
return table_->Decode(util::ReadInt25(address_.base, address_.offset, table_->Bits(), table_->Mask()));
|
|
}
|
|
|
|
private:
|
|
const Bins *table_;
|
|
util::BitAddress address_;
|
|
};
|
|
|
|
SeparatelyQuantize() {}
|
|
|
|
void SetupMemory(void *start, unsigned char order, const Config &config);
|
|
|
|
static const bool kTrain = true;
|
|
// Assumes 0.0 is removed from backoff.
|
|
void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff);
|
|
// Train just probabilities (for longest order).
|
|
void TrainProb(uint8_t order, std::vector<float> &prob);
|
|
|
|
void FinishedLoading(const Config &config);
|
|
|
|
const Bins *GetTables(unsigned char order_minus_2) const { return tables_[order_minus_2]; }
|
|
|
|
const Bins &LongestTable() const { return longest_; }
|
|
|
|
private:
|
|
Bins tables_[KENLM_MAX_ORDER - 1][2];
|
|
|
|
Bins longest_;
|
|
|
|
uint8_t *actual_base_;
|
|
|
|
uint8_t prob_bits_, backoff_bits_;
|
|
};
|
|
|
|
} // namespace ngram
|
|
} // namespace lm
|
|
|
|
#endif // LM_QUANTIZE_H
|