mosesdecoder/moses/LM/DALMWrapper.h

85 lines
2.2 KiB
C
Raw Normal View History

2013-11-05 18:37:56 +04:00
// $Id$
#pragma once
#include <vector>
2013-11-11 23:49:00 +04:00
#include <boost/bimap.hpp>
2013-12-16 18:17:56 +04:00
#include "Implementation.h"
#include "moses/Hypothesis.h"
2013-11-05 18:37:56 +04:00
// Forward declarations of the names from namespace DALM that this header
// refers to. The class below only uses Logger, Vocabulary, LM and State
// through pointers, so their full definitions are not required here.
namespace DALM
{
class Logger;
class Vocabulary;
2013-12-16 18:17:56 +04:00
class State;
class LM;
2013-11-11 23:49:00 +04:00
// Integer id type for vocabulary entries (see GetVocabId below).
// NOTE(review): presumably has to agree with the VocabId definition in DALM's
// own headers -- confirm.
typedef unsigned int VocabId;
}
2013-11-05 18:37:56 +04:00
namespace Moses
{
2013-11-11 23:49:00 +04:00
class Factor;
2013-11-05 18:37:56 +04:00
2013-12-16 18:17:56 +04:00
// LanguageModel subclass declared in DALMWrapper.h. It keeps pointers to a
// DALM::LM, a DALM::Vocabulary and a DALM::Logger and translates Moses
// Factor pointers into DALM::VocabId values.
class LanguageModelDALM : public LanguageModel
2013-11-05 18:37:56 +04:00
{
2013-12-16 18:17:56 +04:00
public:
// Constructs the object from `line`.
// NOTE(review): presumably the feature-function configuration line -- confirm
// against the implementation file.
LanguageModelDALM(const std::string &line);
virtual ~LanguageModelDALM();
// NOTE(review): not marked virtual here, unlike the methods below; presumably
// overrides a base-class loading hook -- confirm against LanguageModel.
void Load();
// Returns the state for an empty hypothesis. The InputType parameter is left
// unnamed in this declaration.
virtual const FFState *EmptyHypothesisState(const InputType &/*input*/) const;
// Scores `phrase`; fullScore, ngramScore and oovCount are passed by non-const
// reference and so can be written by the implementation.
virtual void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const;
// Evaluates `hypo` and returns an FFState pointer.
// NOTE(review): `ps` is presumably the previous state and `out` the score
// accumulator; ownership of the returned pointer is not visible here -- confirm.
virtual FFState *Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const;
// Chart-hypothesis counterpart of Evaluate().
virtual FFState *EvaluateChart(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *out) const;
virtual bool IsUseable(const FactorMask &mask) const;
// Receives one key/value configuration pair.
virtual void SetParameter(const std::string& key, const std::string& value);
2013-11-05 18:37:56 +04:00
// State shared with subclasses: configuration, DALM handles and the
// Factor <-> VocabId mapping.
protected:
2013-12-16 18:17:56 +04:00
const Factor *m_beginSentenceFactor;
// Factor type read from each word when converting it to a vocab id
// (see LastIDs).
FactorType m_factorType;
std::string m_filePath;
size_t m_nGramOrder; //! max n-gram length contained in this LM
// Raw pointers to the DALM objects.
// NOTE(review): ownership and lifetime are not visible in this header --
// confirm in the constructor/destructor/Load implementation.
DALM::Logger *m_logger;
DALM::Vocabulary *m_vocab;
DALM::LM *m_lm;
2013-11-11 23:49:00 +04:00
// wid_start is the id LastIDs emits once it steps before the first target
// word. NOTE(review): wid_end is presumably the matching end-of-sentence id
// -- confirm.
DALM::VocabId wid_start, wid_end;
2013-11-18 17:54:40 +04:00
// Bidirectional map between Moses Factor pointers and DALM vocab ids.
typedef boost::bimap<const Factor *, DALM::VocabId> VocabMap;
// Declared mutable, so const member functions are able to modify it.
mutable VocabMap m_vocabMap;
// NOTE(review): presumably fills m_vocabMap from the file named by
// `wordstxt` -- confirm.
void CreateVocabMapping(const std::string &wordstxt);
2013-11-11 23:49:00 +04:00
// Returns the DALM vocab id for `factor`.
DALM::VocabId GetVocabId(const Factor *factor) const;
2013-12-16 18:17:56 +04:00
// Internal scoring helpers.
private:
// Two overloads that produce an LMResult, one keyed by a DALM vocab id and one
// by a Moses Word.
// NOTE(review): `finalState` is a non-const pointer and presumably updated by
// the call -- confirm.
LMResult GetValue(DALM::VocabId wid, DALM::State* finalState) const;
LMResult GetValue(const Word &word, DALM::State* finalState) const;
// Takes the two score accumulators by pointer, plus the current score and
// word position.
// NOTE(review): the rule deciding which accumulator receives `score` lives in
// the implementation file -- confirm there.
void updateChartScore(float *prefixScore, float *finalizedScore, float score, size_t wordPos) const;
// Writes the vocab ids of the most recent target words of `hypo` into
// `indices`, newest word first, and returns a pointer one past the last id
// written. At most m_nGramOrder - 1 ids are stored. If the walk steps before
// the first target word (position -1), wid_start is stored as the final id
// and the walk stops.
// NOTE(review): relies on GetEndPos() converting to -1 as an int when the
// hypothesis has no target words -- confirm.
DALM::VocabId *LastIDs(const Hypothesis &hypo, DALM::VocabId *indices) const {
  DALM::VocabId *const limit = indices + m_nGramOrder - 1;
  DALM::VocabId *out = indices;
  int pos = hypo.GetCurrTargetWordsRange().GetEndPos();
  while (out != limit) {
    if (pos == -1) {
      // Walked past the start of the sentence: emit the start id and stop.
      *out = wid_start;
      ++out;
      break;
    }
    *out = GetVocabId(hypo.GetWord(pos).GetFactor(m_factorType));
    ++out;
    --pos;
  }
  return out;
}
2013-11-05 18:37:56 +04:00
};
}
2013-12-16 18:17:56 +04:00