mosesdecoder/moses/LM/DALMWrapper.cpp

//#include <boost/functional/hash.hpp>
#include <algorithm>
#include <fstream>
#include <string>
#include "moses/FF/FFState.h"
#include "DALMWrapper.h"
#include "logger.h"
#include "dalm.h"
#include "vocabulary.h"
#include "moses/FactorTypeSet.h"
#include "moses/FactorCollection.h"
#include "moses/InputFileStream.h"
#include "util/exception.hh"
#include "moses/ChartHypothesis.h"
#include "moses/ChartManager.h"
using namespace std;
/////////////////////////
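// Parse a DALM "dalm.ini" file of simple KEY=VALUE lines. The keys read here
// are MODEL, WORDS and WORDSTXT. An illustrative (not authoritative) example:
//   MODEL=dalm.trie
//   WORDS=dalm.words
//   WORDSTXT=words.txt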
void read_ini(const char *inifile, string &model, string &words, string &wordstxt)
{
ifstream ifs(inifile);
string line;
getline(ifs, line);
while(ifs) {
// Each line has the form KEY=VALUE; lines without '=' are ignored.
string::size_type pos = line.find("=");
if(pos != string::npos) {
string key = line.substr(0, pos);
string value = line.substr(pos+1);
if(key=="MODEL") {
model = value;
} else if(key=="WORDS") {
words = value;
} else if(key=="WORDSTXT") {
wordstxt = value;
}
}
getline(ifs, line);
}
}
/////////////////////////
namespace Moses
{
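// FFState wrapper for phrase-based decoding: holds the DALM::State carrying
// the language-model context between consecutive hypothesis expansions.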
class DALMState : public FFState
{
private:
DALM::State state;
public:
DALMState() {
}
DALMState(const DALMState &from) {
state = from.state;
}
virtual ~DALMState() {
}
void reset(const DALMState &from) {
state = from.state;
}
virtual int Compare(const FFState& other) const {
const DALMState &o = static_cast<const DALMState &>(other);
if(state.get_count() < o.state.get_count()) return -1;
else if(state.get_count() > o.state.get_count()) return 1;
else return state.compare(o.state);
}
DALM::State &get_state() {
return state;
}
void refresh() {
state.refresh();
}
};
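// FFState for chart decoding. Roughly: prefixFragments/prefixLength record the
// left edge of the hypothesis' target yield (kept as DALM fragments so those
// words can be re-scored once more left context becomes available),
// rightContext is the DALM state used to score words appended to the right,
// and isLarge marks states whose left edge no longer needs re-scoring.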
class DALMChartState : public FFState
{
private:
DALM::Fragment prefixFragments[DALM_MAX_ORDER-1];
unsigned char prefixLength;
DALM::State rightContext;
bool isLarge;
size_t hypoSize;
public:
DALMChartState()
: prefixLength(0),
isLarge(false),
hypoSize(0) {
}
/*
DALMChartState(const DALMChartState &other)
: prefixLength(other.prefixLength),
rightContext(other.rightContext),
isLarge(other.isLarge)
{
std::copy(
other.prefixFragments,
other.prefixFragments+other.prefixLength,
prefixFragments
);
}
*/
virtual ~DALMChartState() {
}
/*
DALMChartState &operator=(const DALMChartState &other){
prefixLength = other.prefixLength;
std::copy(
other.prefixFragments,
other.prefixFragments+other.prefixLength,
prefixFragments
);
rightContext = other.rightContext;
isLarge=other.isLarge;
return *this;
}
*/
inline unsigned char GetPrefixLength() const {
return prefixLength;
}
inline unsigned char &GetPrefixLength() {
return prefixLength;
}
inline const DALM::Fragment *GetPrefixFragments() const {
return prefixFragments;
}
inline DALM::Fragment *GetPrefixFragments() {
return prefixFragments;
}
inline const DALM::State &GetRightContext() const {
return rightContext;
}
inline DALM::State &GetRightContext() {
return rightContext;
}
inline bool LargeEnough() const {
return isLarge;
}
inline void SetAsLarge() {
isLarge=true;
}
inline size_t &GetHypoSize() {
return hypoSize;
}
inline size_t GetHypoSize() const {
return hypoSize;
}
virtual int Compare(const FFState& other) const {
const DALMChartState &o = static_cast<const DALMChartState &>(other);
if(prefixLength < o.prefixLength) return -1;
if(prefixLength > o.prefixLength) return 1;
if(prefixLength!=0) {
const DALM::Fragment &f = prefixFragments[prefixLength-1];
const DALM::Fragment &of = o.prefixFragments[prefixLength-1];
int ret = DALM::compare_fragments(f,of);
if(ret != 0) return ret;
}
if(isLarge != o.isLarge) return (int)isLarge - (int)o.isLarge;
if(rightContext.get_count() < o.rightContext.get_count()) return -1;
if(rightContext.get_count() > o.rightContext.get_count()) return 1;
return rightContext.compare(o.rightContext);
}
};
LanguageModelDALM::LanguageModelDALM(const std::string &line)
:LanguageModel(line)
{
ReadParameters();
if (m_factorType == NOT_FOUND) {
m_factorType = 0;
}
}
LanguageModelDALM::~LanguageModelDALM()
{
delete m_logger;
delete m_vocab;
delete m_lm;
}
void LanguageModelDALM::Load()
{
/////////////////////
// READING INIFILE //
/////////////////////
string inifile = m_filePath + "/dalm.ini";
string model; // Path to the double-array file.
string words; // Path to the vocabulary file.
string wordstxt; // Path to the vocabulary file in text format.
read_ini(inifile.c_str(), model, words, wordstxt);
// Check before prepending the directory; otherwise the paths are never empty
// and a missing or unreadable dalm.ini would go undetected here.
UTIL_THROW_IF(model.empty() || words.empty() || wordstxt.empty(),
util::FileOpenException,
"Failed to read DALM ini file " << inifile << ". Probably doesn't exist");
model = m_filePath + "/" + model;
words = m_filePath + "/" + words;
wordstxt = m_filePath + "/" + wordstxt;
////////////////
// LOADING LM //
////////////////
// Preparing a logger object.
m_logger = new DALM::Logger(stderr);
m_logger->setLevel(DALM::LOGGER_INFO);
// Load the vocabulary file.
m_vocab = new DALM::Vocabulary(words, *m_logger);
// Load the language model.
m_lm = new DALM::LM(model, *m_vocab, m_nGramOrder, *m_logger);
wid_start = m_vocab->lookup(BOS_);
wid_end = m_vocab->lookup(EOS_);
// vocab mapping
CreateVocabMapping(wordstxt);
FactorCollection &collection = FactorCollection::Instance();
m_beginSentenceFactor = collection.AddFactor(BOS_);
}
const FFState *LanguageModelDALM::EmptyHypothesisState(const InputType &/*input*/) const
{
DALMState *s = new DALMState();
m_lm->init_state(s->get_state());
return s;
}
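// Estimate the LM score of a phrase in isolation (translation-option scoring).
// fullScore sums all word scores, ngramScore only the scores of words with a
// complete (order-1)-word history inside the phrase, and oovCount counts
// unknown words; non-terminals reset the history.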
void LanguageModelDALM::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const
{
fullScore = 0;
ngramScore = 0;
oovCount = 0;
size_t phraseSize = phrase.GetSize();
if (!phraseSize) return;
size_t currPos = 0;
size_t hist_count = 0;
DALM::State state;
if(phrase.GetWord(0).GetFactor(m_factorType) == m_beginSentenceFactor) {
m_lm->init_state(state);
currPos++;
hist_count++;
}
float score;
while (currPos < phraseSize) {
const Word &word = phrase.GetWord(currPos);
hist_count++;
if (word.IsNonTerminal()) {
state.refresh();
hist_count = 0;
} else {
DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
score = m_lm->query(wid, state);
fullScore += score;
if (hist_count >= m_nGramOrder) ngramScore += score;
if (wid==m_vocab->unk()) ++oovCount;
}
currPos++;
}
fullScore = TransformLMScore(fullScore);
ngramScore = TransformLMScore(ngramScore);
}
FFState *LanguageModelDALM::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const
{
// In this function, we only compute the LM scores of n-grams that overlap a
// phrase boundary. Phrase-internal scores are taken directly from the
// translation option.
const DALMState *dalm_ps = static_cast<const DALMState *>(ps);
// Empty phrase added? nothing to be done
if (hypo.GetCurrTargetLength() == 0) {
return dalm_ps ? new DALMState(*dalm_ps) : NULL;
}
const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos();
//[begin, end) in STL-like fashion.
const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
const std::size_t adjust_end = std::min(end, begin + m_nGramOrder - 1);
DALMState *dalm_state = new DALMState(*dalm_ps);
DALM::State &state = dalm_state->get_state();
float score = 0.0;
for(std::size_t position=begin; position < adjust_end; position++) {
score += m_lm->query(GetVocabId(hypo.GetWord(position).GetFactor(m_factorType)), state);
}
if (hypo.IsSourceCompleted()) {
// Score end of sentence.
std::vector<DALM::VocabId> indices(m_nGramOrder-1);
const DALM::VocabId *last = LastIDs(hypo, &indices.front());
m_lm->set_state(&indices.front(), (last-&indices.front()), state);
score += m_lm->query(wid_end, state);
} else if (adjust_end < end) {
// Get state after adding a long phrase.
std::vector<DALM::VocabId> indices(m_nGramOrder-1);
const DALM::VocabId *last = LastIDs(hypo, &indices.front());
m_lm->set_state(&indices.front(), (last-&indices.front()), state);
}
score = TransformLMScore(score);
if (OOVFeatureEnabled()) {
std::vector<float> scores(2);
scores[0] = score;
scores[1] = 0.0;
out->PlusEquals(this, scores);
} else {
out->PlusEquals(this, score);
}
return dalm_state;
}
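// Chart-decoder scoring. Roughly: words on the left edge of a hypothesis
// cannot yet be scored with their final context, so they are scored
// provisionally and remembered as prefix fragments; when this hypothesis is
// later plugged into a larger rule, EvaluateNonTerminal revisits those
// fragments with the newly available left context. hypoScore therefore holds
// score differences relative to what the sub-hypotheses already contributed.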
FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *out) const
{
// initialize language model context state
DALMChartState *newState = new DALMChartState();
DALM::State &state = newState->GetRightContext();
DALM::Fragment *prefixFragments = newState->GetPrefixFragments();
unsigned char &prefixLength = newState->GetPrefixLength();
size_t &hypoSizeAll = newState->GetHypoSize();
// initial language model scores
float hypoScore = 0.0; // diffs of scores.
const TargetPhrase &targetPhrase = hypo.GetCurrTargetPhrase();
size_t hypoSize = targetPhrase.GetSize();
hypoSizeAll = hypoSize;
// get index map for underlying hypotheses
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
targetPhrase.GetAlignNonTerm().GetNonTermIndexMap();
size_t phrasePos = 0;
// beginning of sentence.
if(hypoSize > 0) {
const Word &word = targetPhrase.GetWord(0);
if(word.GetFactor(m_factorType) == m_beginSentenceFactor) {
m_lm->init_state(state);
// state is finalized.
newState->SetAsLarge();
phrasePos++;
} else if(word.IsNonTerminal()) {
// special case: rule starts with non-terminal -> copy everything
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[0]);
const DALMChartState* prevState =
static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));
// copy chart state
(*newState) = (*prevState);
hypoSizeAll = hypoSize+prevState->GetHypoSize()-1;
phrasePos++;
}
}
// loop over rule
for (; phrasePos < hypoSize; phrasePos++) {
// consult rule for either word or non-terminal
const Word &word = targetPhrase.GetWord(phrasePos);
// regular word
if (!word.IsNonTerminal()) {
EvaluateTerminal(
word, hypoScore,
newState, state,
prefixFragments, prefixLength
);
}
// non-terminal, add phrase from underlying hypothesis
// internal non-terminal
else {
// look up underlying hypothesis
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
const DALMChartState* prevState =
static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));
size_t prevTargetPhraseLength = prevHypo->GetCurrTargetPhrase().GetSize();
hypoSizeAll += prevState->GetHypoSize()-1;
EvaluateNonTerminal(
word, hypoScore,
newState, state,
prefixFragments, prefixLength,
prevState, prevTargetPhraseLength
);
}
}
// assign combined score to score breakdown
out->PlusEquals(this, TransformLMScore(hypoScore));
return newState;
}
bool LanguageModelDALM::IsUseable(const FactorMask &mask) const
{
return mask[m_factorType];
}
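// Build m_vocabMap, a table indexed by Moses factor id that maps each factor
// to its DALM vocabulary id (read from the plain-text word list WORDSTXT).
// Factors not present in the list map to DALM's <unk> id.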
void LanguageModelDALM::CreateVocabMapping(const std::string &wordstxt)
{
InputFileStream vocabStrm(wordstxt);
std::vector< std::pair<std::size_t, DALM::VocabId> > vlist;
string line;
std::size_t max_fid = 0;
while(getline(vocabStrm, line)) {
const Factor *factor = FactorCollection::Instance().AddFactor(line);
std::size_t fid = factor->GetId();
DALM::VocabId wid = m_vocab->lookup(line.c_str());
vlist.push_back(std::pair<std::size_t, DALM::VocabId>(fid, wid));
if(max_fid < fid) max_fid = fid;
}
for(std::size_t i = 0; i < m_vocabMap.size(); i++) {
m_vocabMap[i] = m_vocab->unk();
}
m_vocabMap.resize(max_fid+1, m_vocab->unk());
std::vector< std::pair<std::size_t, DALM::VocabId> >::iterator it = vlist.begin();
while(it != vlist.end()) {
std::pair<std::size_t, DALM::VocabId> &entry = *it;
m_vocabMap[entry.first] = entry.second;
++it;
}
}
DALM::VocabId LanguageModelDALM::GetVocabId(const Factor *factor) const
{
std::size_t fid = factor->GetId();
return (m_vocabMap.size() > fid)? m_vocabMap[fid] : m_vocab->unk();
}
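// Feature parameters parsed from the moses.ini feature line. A hedged example
// of such a line (the exact feature name is defined where the feature is
// registered, not in this file):
//   DALM factor=0 order=5 path=/path/to/dalm-model-dir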
void LanguageModelDALM::SetParameter(const std::string& key, const std::string& value)
{
if (key == "factor") {
m_factorType = Scan<FactorType>(value);
} else if (key == "order") {
m_nGramOrder = Scan<size_t>(value);
} else if (key == "path") {
m_filePath = value;
} else {
LanguageModel::SetParameter(key, value);
}
m_ContextSize = m_nGramOrder-1;
}
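// Score one terminal word during chart decoding. Once the state is "large"
// (enough left context is known) the word is scored directly; otherwise it is
// additionally recorded as a prefix fragment so EvaluateNonTerminal can
// re-score it later when more left context arrives.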
void LanguageModelDALM::EvaluateTerminal(
const Word &word,
float &hypoScore,
DALMChartState *newState,
DALM::State &state,
DALM::Fragment *prefixFragments,
unsigned char &prefixLength) const
{
DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
if (newState->LargeEnough()) {
float score = m_lm->query(wid, state);
hypoScore += score;
} else {
float score = m_lm->query(wid, state, prefixFragments[prefixLength]);
if(score > 0) {
hypoScore -= score;
newState->SetAsLarge();
} else if(state.get_count()<=prefixLength) {
hypoScore += score;
prefixLength++;
newState->SetAsLarge();
} else {
hypoScore += score;
prefixLength++;
if(prefixLength >= m_ContextSize) newState->SetAsLarge();
}
}
}
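// Integrate a sub-hypothesis (non-terminal) into the current rule. Its prefix
// fragments are re-queried against the current left context through a
// DALM::Gap, adding score corrections for words that previously lacked full
// context; afterwards the sub-hypothesis' right context becomes the current
// state (directly, or merged via m_lm->set_state()).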
void LanguageModelDALM::EvaluateNonTerminal(
const Word &word,
float &hypoScore,
DALMChartState *newState,
DALM::State &state,
DALM::Fragment *prefixFragments,
unsigned char &prefixLength,
const DALMChartState *prevState,
size_t prevTargetPhraseLength
) const
{
const unsigned char prevPrefixLength = prevState->GetPrefixLength();
const DALM::Fragment *prevPrefixFragments = prevState->GetPrefixFragments();
if(prevPrefixLength == 0) {
newState->SetAsLarge();
hypoScore += state.sum_bows(0, state.get_count());
state = prevState->GetRightContext();
return;
}
if(!state.has_context()) {
newState->SetAsLarge();
state = prevState->GetRightContext();
return;
}
DALM::Gap gap(state);
// score its prefix
for(size_t prefixPos = 0; prefixPos < prevPrefixLength; prefixPos++) {
const DALM::Fragment &f = prevPrefixFragments[prefixPos];
if (newState->LargeEnough()) {
float score = m_lm->query(f, state, gap);
hypoScore += score;
if(!gap.is_extended()) {
state = prevState->GetRightContext();
return;
} else if(state.get_count() <= prefixPos+1) {
state = prevState->GetRightContext();
return;
}
} else {
DALM::Fragment &fnew = prefixFragments[prefixLength];
float score = m_lm->query(f, state, gap, fnew);
hypoScore += score;
if(!gap.is_extended()) {
newState->SetAsLarge();
state = prevState->GetRightContext();
return;
} else if(state.get_count() <= prefixPos+1) {
if(!gap.is_finalized()) prefixLength++;
newState->SetAsLarge();
state = prevState->GetRightContext();
return;
} else if(gap.is_finalized()) {
newState->SetAsLarge();
} else {
prefixLength++;
if(prefixLength >= m_ContextSize) newState->SetAsLarge();
}
}
gap.succ();
}
// check if we are dealing with a large sub-phrase
if (prevState->LargeEnough()) {
newState->SetAsLarge();
if(prevPrefixLength < prevState->GetHypoSize()) {
hypoScore += state.sum_bows(prevPrefixLength, state.get_count());
}
// copy language model state
state = prevState->GetRightContext();
} else {
m_lm->set_state(state, prevState->GetRightContext(), prevPrefixFragments, gap);
}
}
}