mosesdecoder/contrib/other-builds/moses2/LM/KENLM.cpp

325 lines
8.6 KiB
C++
Raw Normal View History

2015-11-04 16:03:26 +03:00
/*
* KENLM.cpp
*
* Created on: 4 Nov 2015
* Author: hieu
*/
2015-12-02 15:05:18 +03:00
#include <sstream>
2015-11-04 17:54:20 +03:00
#include <vector>
2015-11-04 16:03:26 +03:00
#include "KENLM.h"
2015-11-04 17:54:20 +03:00
#include "../TargetPhrase.h"
#include "../Scores.h"
2015-11-04 16:09:53 +03:00
#include "../System.h"
2015-11-04 19:11:56 +03:00
#include "../Search/Hypothesis.h"
#include "../Search/Manager.h"
2015-11-04 16:03:26 +03:00
#include "lm/state.hh"
2015-11-04 17:54:20 +03:00
#include "lm/left.hh"
2015-11-13 13:40:55 +03:00
#include "../legacy/FactorCollection.h"
2015-11-04 16:03:26 +03:00
2015-11-04 17:54:20 +03:00
using namespace std;
2015-12-10 23:49:30 +03:00
namespace Moses2
{
2015-11-12 23:34:58 +03:00
struct KenLMState : public FFState {
lm::ngram::State state;
2015-11-04 16:03:26 +03:00
virtual size_t hash() const {
size_t ret = hash_value(state);
return ret;
2015-11-04 16:03:26 +03:00
}
2015-11-12 23:34:58 +03:00
virtual bool operator==(const FFState& o) const {
2015-11-04 16:03:26 +03:00
const KenLMState &other = static_cast<const KenLMState &>(o);
bool ret = state == other.state;
2015-11-04 16:03:26 +03:00
return ret;
}
2015-12-02 15:05:18 +03:00
virtual std::string ToString() const
{
stringstream ss;
for (size_t i = 0; i < state.Length(); ++i) {
ss << state.words[i] << " ";
2015-12-02 15:05:18 +03:00
}
return ss.str();
}
2015-11-04 16:03:26 +03:00
};
/////////////////////////////////////////////////////////////////
class MappingBuilder : public lm::EnumerateVocab
{
public:
MappingBuilder(FactorCollection &factorCollection, System &system, std::vector<lm::WordIndex> &mapping)
2015-11-18 14:08:32 +03:00
: m_factorCollection(factorCollection)
, m_system(system)
, m_mapping(mapping)
2015-11-18 14:08:32 +03:00
{}
2015-11-04 16:03:26 +03:00
void Add(lm::WordIndex index, const StringPiece &str) {
std::size_t factorId = m_factorCollection.AddFactor(str, m_system)->GetId();
if (m_mapping.size() <= factorId) {
// 0 is <unk> :-)
m_mapping.resize(factorId + 1);
}
m_mapping[factorId] = index;
2015-11-04 16:03:26 +03:00
}
private:
2015-11-13 01:51:13 +03:00
FactorCollection &m_factorCollection;
std::vector<lm::WordIndex> &m_mapping;
2015-11-18 14:08:32 +03:00
System &m_system;
2015-11-04 16:03:26 +03:00
};
/////////////////////////////////////////////////////////////////
/**
 * Constructs the feature function from its configuration line.
 *
 * @param startInd forwarded unchanged to StatefulFeatureFunction
 * @param line     forwarded unchanged to StatefulFeatureFunction
 */
KENLM::KENLM(size_t startInd, const std::string &line)
  : StatefulFeatureFunction(startInd, line)
  , m_lazy(false)
{
  // Lazy loading is off unless a parameter turns it on.
  // NOTE(review): ReadParameters() presumably feeds the key/value pairs of
  // `line` into SetParameter(); m_factorType is not initialised here, so
  // confirm it has a default when no "factor" key is given.
  ReadParameters();
}
//! Destructor; performs no explicit clean-up.
KENLM::~KENLM()
{
}
/**
 * Loads the KenLM model found at m_path.
 *
 * Registers the sentence-boundary factors (BOS_/EOS_) in the system
 * vocabulary, then constructs the model with a MappingBuilder attached so
 * that m_lmIdLookup is filled with factor id -> KenLM word index entries while
 * the model's vocabulary is enumerated.
 *
 * @param system provides the vocabulary and is passed on to AddFactor
 */
void KENLM::Load(System &system)
{
  // One reference to the vocabulary serves both the boundary factors and the
  // mapping builder.
  FactorCollection &vocab = system.GetVocab();

  m_bos = vocab.AddFactor(BOS_, system, false);
  m_eos = vocab.AddFactor(EOS_, system, false);

  lm::ngram::Config config;
  config.messages = NULL; // no message stream is handed to KenLM
  config.load_method = m_lazy ? util::LAZY : util::POPULATE_OR_READ;

  // `builder` only has to live until the Model constructor below returns;
  // config (and its pointer to the builder) is a local as well.
  MappingBuilder builder(vocab, system, m_lmIdLookup);
  config.enumerate_vocab = &builder;

  m_ngram.reset(new Model(m_path.c_str(), config));
}
//! Per-input initialisation hook; this feature does nothing here.
void KENLM::InitializeForInput(const Manager &mgr) const
{
  // intentionally empty
}
//! Hook for releasing temporary memory, called after each sentence has been
//! processed; this feature has nothing to release.
void KENLM::CleanUpAfterSentenceProcessing(const Manager &mgr) const
{
  // intentionally empty
}
2015-11-04 16:03:26 +03:00
void KENLM::SetParameter(const std::string& key, const std::string& value)
{
if (key == "path") {
m_path = value;
}
else if (key == "factor") {
2015-11-13 01:51:13 +03:00
m_factorType = Scan<FactorType>(value);
2015-11-04 16:03:26 +03:00
}
2015-11-06 12:04:19 +03:00
else if (key == "lazyken") {
2015-11-12 02:29:58 +03:00
m_lazy = Scan<bool>(value);
2015-11-06 12:04:19 +03:00
}
2015-11-04 16:03:26 +03:00
else if (key == "order") {
// don't need to store it
}
else {
StatefulFeatureFunction::SetParameter(key, value);
}
}
2016-01-05 17:34:59 +03:00
/**
 * Creates a default-constructed KenLMState in storage obtained from `pool`.
 *
 * @param pool memory pool that supplies the storage
 * @return the newly constructed state
 */
FFState* KENLM::BlankState(MemPool &pool) const
{
  // Placement-new into the pool's storage instead of allocating on the heap.
  return new (pool.Allocate<KenLMState>()) KenLMState();
}
2015-11-04 16:03:26 +03:00
//! return the state associated with the empty hypothesis for a given sentence
2015-12-15 18:24:57 +03:00
void KENLM::EmptyHypothesisState(FFState &state,
const Manager &mgr,
const InputType &input,
const Hypothesis &hypo) const
2015-11-04 16:03:26 +03:00
{
2015-11-05 19:16:55 +03:00
KenLMState &stateCast = static_cast<KenLMState&>(state);
stateCast.state = m_ngram->BeginSentenceState();
2015-11-04 16:03:26 +03:00
}
void
KENLM::EvaluateInIsolation(MemPool &pool,
const System &system,
const Phrase &source,
const TargetPhrase &targetPhrase,
2015-11-04 16:03:26 +03:00
Scores &scores,
SCORE *estimatedScore) const
2015-11-04 16:03:26 +03:00
{
2015-11-04 18:25:09 +03:00
// contains factors used by this LM
float fullScore, nGramScore;
size_t oovCount;
2015-11-04 16:03:26 +03:00
CalcScore(targetPhrase, fullScore, nGramScore, oovCount);
2015-11-04 18:25:09 +03:00
float estimateScore = fullScore - nGramScore;
bool GetLMEnableOOVFeature = false;
if (GetLMEnableOOVFeature) {
float scoresVec[2], estimateScoresVec[2];
2015-11-04 18:25:09 +03:00
scoresVec[0] = nGramScore;
scoresVec[1] = oovCount;
scores.PlusEquals(system, *this, scoresVec);
2015-11-04 18:25:09 +03:00
estimateScoresVec[0] = estimateScore;
estimateScoresVec[1] = 0;
SCORE weightedScore = Scores::CalcWeightedScore(system, *this, estimateScoresVec);
(*estimatedScore) += weightedScore;
2015-11-04 18:25:09 +03:00
}
else {
scores.PlusEquals(system, *this, nGramScore);
SCORE weightedScore = Scores::CalcWeightedScore(system, *this, estimateScore);
(*estimatedScore) += weightedScore;
2015-11-04 18:25:09 +03:00
}
2015-11-04 16:03:26 +03:00
}
2015-11-05 19:35:31 +03:00
void KENLM::EvaluateWhenApplied(const Manager &mgr,
2015-11-04 16:03:26 +03:00
const Hypothesis &hypo,
2015-11-12 23:34:58 +03:00
const FFState &prevState,
2015-11-05 19:35:31 +03:00
Scores &scores,
2015-11-12 23:34:58 +03:00
FFState &state) const
2015-11-04 16:03:26 +03:00
{
2015-11-05 19:35:31 +03:00
KenLMState &stateCast = static_cast<KenLMState&>(state);
const System &system = mgr.system;
const lm::ngram::State &in_state = static_cast<const KenLMState&>(prevState).state;
2015-11-04 19:11:56 +03:00
if (!hypo.GetTargetPhrase().GetSize()) {
stateCast.state = in_state;
2015-11-05 19:35:31 +03:00
return;
2015-11-04 19:11:56 +03:00
}
const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos();
//[begin, end) in STL-like fashion.
const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
const std::size_t adjust_end = std::min(end, begin + m_ngram->Order() - 1);
std::size_t position = begin;
typename Model::State aux_state;
typename Model::State *state0 = &stateCast.state, *state1 = &aux_state;
float score = ScoreAndCache(mgr, in_state, TranslateID(hypo.GetWord(position)), *state0);
2016-02-17 16:45:05 +03:00
++position;
for (; position < adjust_end; ++position) {
score += ScoreAndCache(mgr, *state0, TranslateID(hypo.GetWord(position)), *state1);
std::swap(state0, state1);
}
2016-02-16 19:03:37 +03:00
if (hypo.GetBitmap().IsComplete()) {
// Score end of sentence.
std::vector<lm::WordIndex> indices(m_ngram->Order() - 1);
const lm::WordIndex *last = LastIDs(hypo, &indices.front());
score += m_ngram->FullScoreForgotState(&indices.front(), last, m_ngram->GetVocabulary().EndSentence(), stateCast.state).prob;
} else if (adjust_end < end) {
// Get state after adding a long phrase.
std::vector<lm::WordIndex> indices(m_ngram->Order() - 1);
const lm::WordIndex *last = LastIDs(hypo, &indices.front());
m_ngram->GetState(&indices.front(), last, stateCast.state);
} else if (state0 != &stateCast.state) {
// Short enough phrase that we can just reuse the state.
stateCast.state = *state0;
2016-02-16 19:03:37 +03:00
}
score = TransformLMScore(score);
2015-11-04 19:11:56 +03:00
bool OOVFeatureEnabled = false;
if (OOVFeatureEnabled) {
std::vector<float> scoresVec(2);
scoresVec[0] = score;
scoresVec[1] = 0.0;
scores.PlusEquals(system, *this, scoresVec);
} else {
scores.PlusEquals(system, *this, score);
}
2015-11-04 16:03:26 +03:00
}
void KENLM::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, std::size_t &oovCount) const
2015-11-04 17:54:20 +03:00
{
fullScore = 0;
ngramScore = 0;
oovCount = 0;
if (!phrase.GetSize()) return;
lm::ngram::ChartState discarded_sadly;
lm::ngram::RuleScore<Model> scorer(*m_ngram, discarded_sadly);
2015-11-04 17:54:20 +03:00
size_t position;
2015-11-04 17:55:18 +03:00
if (m_bos == phrase[0][m_factorType]) {
2015-11-04 17:54:20 +03:00
scorer.BeginSentence();
position = 1;
} else {
position = 0;
}
size_t ngramBoundary = m_ngram->Order() - 1;
size_t end_loop = std::min(ngramBoundary, phrase.GetSize());
for (; position < end_loop; ++position) {
const Word &word = phrase[position];
lm::WordIndex index = TranslateID(word);
scorer.Terminal(index);
if (!index) ++oovCount;
}
float before_boundary = fullScore + scorer.Finish();
for (; position < phrase.GetSize(); ++position) {
const Word &word = phrase[position];
lm::WordIndex index = TranslateID(word);
scorer.Terminal(index);
if (!index) ++oovCount;
}
fullScore += scorer.Finish();
2015-11-13 13:40:55 +03:00
ngramScore = TransformLMScore(fullScore - before_boundary);
fullScore = TransformLMScore(fullScore);
2015-11-04 17:54:20 +03:00
}
2015-11-04 19:11:56 +03:00
// Convert the last words of the hypothesis into vocab ids, walking backwards
// from the end of the current target range. At most Order()-1 ids are written
// to `indices`; if the start of the sentence is reached first, a
// begin-of-sentence id is appended and the walk stops. Returns a pointer one
// past the last id written.
lm::WordIndex *KENLM::LastIDs(const Hypothesis &hypo, lm::WordIndex *indices) const
{
  lm::WordIndex *const limit = indices + m_ngram->Order() - 1;
  lm::WordIndex *out = indices;
  // Signed on purpose: -1 marks "walked past the first word".
  int pos = hypo.GetCurrTargetWordsRange().GetEndPos();

  while (out != limit) {
    if (pos == -1) {
      *out = m_ngram->GetVocabulary().BeginSentence();
      return out + 1;
    }
    *out = TranslateID(hypo.GetWord(pos));
    ++out;
    --pos;
  }
  return out;
}
/**
 * Scores `new_word` following `in_state`, consulting the manager's LM cache
 * first.
 *
 * On a cache hit, FindLMCache() supplies both the score and `out_state`. On a
 * miss the model is queried and the result is stored with AddLMCache().
 *
 * @return the score for the word
 */
float KENLM::ScoreAndCache(const Manager &mgr, const lm::ngram::State &in_state, const lm::WordIndex new_word, lm::ngram::State &out_state) const
{
  float score;
  const bool cached = mgr.FindLMCache(in_state, new_word, score, out_state);
  if (!cached) {
    score = m_ngram->Score(in_state, new_word, out_state);
    mgr.AddLMCache(in_state, new_word, score, out_state);
  }
  return score;
}
2015-12-10 23:49:30 +03:00
}