mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-26 05:14:36 +03:00
improve DALMwrapper.
This commit is contained in:
parent
4488d97629
commit
ba63f1eb97
@ -9,8 +9,8 @@
|
||||
#include "moses/FactorCollection.h"
|
||||
#include "moses/InputFileStream.h"
|
||||
#include "util/exception.hh"
|
||||
#include "ChartState.h"
|
||||
#include "util/exception.hh"
|
||||
#include "moses/ChartHypothesis.h"
|
||||
#include "moses/ChartManager.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
@ -58,6 +58,11 @@ public:
|
||||
delete state;
|
||||
}
|
||||
|
||||
void reset(const DALMState &from){
|
||||
delete state;
|
||||
state = new DALM::State(*from.state);
|
||||
}
|
||||
|
||||
virtual int Compare(const FFState& other) const{
|
||||
const DALMState &o = static_cast<const DALMState &>(other);
|
||||
if(state->get_count() < o.state->get_count()) return -1;
|
||||
@ -74,6 +79,67 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class DALMChartState : public FFState
|
||||
{
|
||||
private:
|
||||
const ChartHypothesis &m_hypo;
|
||||
DALM::VocabId *prefixIDs;
|
||||
size_t prefixLength;
|
||||
float prefixScore;
|
||||
DALMState *rightContext;
|
||||
bool isLarge;
|
||||
|
||||
public:
|
||||
DALMChartState(const ChartHypothesis &hypo, DALM::VocabId *prefixIDs, size_t prefixLength, float prefixScore, DALMState *rightContext, bool isLarge)
|
||||
: m_hypo(hypo), prefixIDs(prefixIDs), prefixLength(prefixLength), prefixScore(prefixScore), rightContext(rightContext), isLarge(isLarge)
|
||||
{}
|
||||
|
||||
virtual ~DALMChartState(){
|
||||
if(prefixIDs != NULL) delete [] prefixIDs;
|
||||
if(rightContext != NULL) delete rightContext;
|
||||
}
|
||||
|
||||
size_t GetPrefixLength() const{
|
||||
return prefixLength;
|
||||
}
|
||||
|
||||
const DALM::VocabId *GetPrefixIDs() const{
|
||||
return prefixIDs;
|
||||
}
|
||||
|
||||
float GetPrefixScore() const{
|
||||
return prefixScore;
|
||||
}
|
||||
|
||||
const DALMState *GetRightContext() const{
|
||||
return rightContext;
|
||||
}
|
||||
|
||||
bool LargeEnough() const{
|
||||
return isLarge;
|
||||
}
|
||||
|
||||
virtual int Compare(const FFState& other) const{
|
||||
const DALMChartState &o = static_cast<const DALMChartState &>(other);
|
||||
// prefix
|
||||
if (m_hypo.GetCurrSourceRange().GetStartPos() > 0) { // not for "<s> ..."
|
||||
if(prefixLength != o.prefixLength){
|
||||
return (prefixLength < o.prefixLength)?-1:1;
|
||||
}else{
|
||||
int ret = memcmp(prefixIDs, o.prefixIDs, prefixLength);
|
||||
if (ret != 0) return ret;
|
||||
}
|
||||
}
|
||||
// suffix
|
||||
size_t inputSize = m_hypo.GetManager().GetSource().GetSize();
|
||||
if (m_hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1) { // not for "... </s>"
|
||||
int ret = o.rightContext->Compare(*rightContext);
|
||||
if (ret != 0) return ret;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
LanguageModelDALM::LanguageModelDALM(const std::string &line)
|
||||
:LanguageModel(line)
|
||||
{
|
||||
@ -96,7 +162,7 @@ void LanguageModelDALM::Load()
|
||||
/////////////////////
|
||||
// READING INIFILE //
|
||||
/////////////////////
|
||||
string inifile= m_filePath + "/dalm.ini";
|
||||
string inifile= m_filePath + "/dalm.ini";
|
||||
|
||||
string model; // Path to the double-array file.
|
||||
string words; // Path to the vocabulary file.
|
||||
@ -104,8 +170,8 @@ void LanguageModelDALM::Load()
|
||||
read_ini(inifile.c_str(), model, words, wordstxt);
|
||||
|
||||
model = m_filePath + "/" + model;
|
||||
words = m_filePath + "/" + words;
|
||||
wordstxt = m_filePath + "/" + wordstxt;
|
||||
words = m_filePath + "/" + words;
|
||||
wordstxt = m_filePath + "/" + wordstxt;
|
||||
|
||||
UTIL_THROW_IF(model.empty() || words.empty() || wordstxt.empty(),
|
||||
util::FileOpenException,
|
||||
@ -237,110 +303,140 @@ FFState *LanguageModelDALM::Evaluate(const Hypothesis &hypo, const FFState *ps,
|
||||
}
|
||||
|
||||
FFState *LanguageModelDALM::EvaluateChart(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *out) const{
|
||||
LanguageModelChartState *ret = new LanguageModelChartState(hypo, featureID, m_nGramOrder);
|
||||
// initialize language model context state
|
||||
DALMState *dalm_state = new DALMState(m_nGramOrder);
|
||||
DALM::State *state = dalm_state->get_state();
|
||||
|
||||
size_t contextSize = m_nGramOrder-1;
|
||||
DALM::VocabId *prefixIDs = new DALM::VocabId[contextSize];
|
||||
size_t prefixLength = 0;
|
||||
bool isLarge = false;
|
||||
|
||||
// initial language model scores
|
||||
float prefixScore = 0.0; // not yet final for initial words (lack context)
|
||||
float finalizedScore = 0.0; // finalized, has sufficient context
|
||||
float prevScore = 0.0; // previous hypothesis
|
||||
|
||||
const TargetPhrase &targetPhrase = hypo.GetCurrTargetPhrase();
|
||||
size_t hypoSize = targetPhrase.GetSize();
|
||||
|
||||
// get index map for underlying hypotheses
|
||||
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
|
||||
hypo.GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap();
|
||||
targetPhrase.GetAlignNonTerm().GetNonTermIndexMap();
|
||||
|
||||
size_t phrasePos = 0;
|
||||
|
||||
// beginning of sentence.
|
||||
if(hypoSize > 0){
|
||||
const Word &word = targetPhrase.GetWord(0);
|
||||
if(!word.IsNonTerminal()){
|
||||
DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
|
||||
if(word.GetFactor(m_factorType) == m_beginSentenceFactor){
|
||||
m_lm->init_state(*state);
|
||||
if (prefixLength < contextSize){
|
||||
prefixIDs[prefixLength] = wid;
|
||||
prefixLength++;
|
||||
}else{
|
||||
isLarge = true;
|
||||
}
|
||||
}else{
|
||||
float score = m_lm->query(wid, *state);
|
||||
if (prefixLength < contextSize){
|
||||
prefixScore += score;
|
||||
prefixIDs[prefixLength] = wid;
|
||||
prefixLength++;
|
||||
}else{ finalizedScore += score; }
|
||||
}
|
||||
}else{
|
||||
// special case: rule starts with non-terminal -> copy everything
|
||||
size_t nonTermIndex = nonTermIndexMap[0];
|
||||
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndex);
|
||||
|
||||
const DALMChartState* prevState =
|
||||
static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));
|
||||
|
||||
// get prefixScore and finalizedScore
|
||||
prefixScore = prevState->GetPrefixScore();
|
||||
finalizedScore = -prefixScore;
|
||||
prevScore = prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0];
|
||||
|
||||
// get language model state
|
||||
dalm_state->reset(*prevState->GetRightContext());
|
||||
state = dalm_state->get_state();
|
||||
prefixLength = prevState->GetPrefixLength();
|
||||
std::memcpy(prefixIDs, prevState->GetPrefixIDs(), sizeof(DALM::VocabId)*prefixLength);
|
||||
}
|
||||
phrasePos++;
|
||||
}
|
||||
|
||||
// loop over rule
|
||||
for (size_t phrasePos = 0, wordPos = 0;
|
||||
phrasePos < hypo.GetCurrTargetPhrase().GetSize();
|
||||
phrasePos++) {
|
||||
for (; phrasePos < hypoSize; phrasePos++) {
|
||||
// consult rule for either word or non-terminal
|
||||
const Word &word = hypo.GetCurrTargetPhrase().GetWord(phrasePos);
|
||||
const Word &word = targetPhrase.GetWord(phrasePos);
|
||||
|
||||
// regular word
|
||||
if (!word.IsNonTerminal()) {
|
||||
// beginning of sentence symbol <s>? -> just update state
|
||||
if (word.GetFactor(m_factorType) == m_beginSentenceFactor) {
|
||||
UTIL_THROW_IF2(phrasePos != 0,
|
||||
"Sentence start symbol must be at the beginning of sentence");
|
||||
m_lm->init_state(*state);
|
||||
}
|
||||
// score a regular word added by the rule
|
||||
else {
|
||||
float score = m_lm->query(GetVocabId(word.GetFactor(m_factorType)), *state);
|
||||
|
||||
updateChartScore( &prefixScore, &finalizedScore, score, ++wordPos );
|
||||
}
|
||||
DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
|
||||
float score = m_lm->query(wid, *state);
|
||||
if (prefixLength < contextSize){
|
||||
prefixScore += score;
|
||||
prefixIDs[prefixLength] = wid;
|
||||
prefixLength++;
|
||||
}else{
|
||||
finalizedScore += score;
|
||||
isLarge = true;
|
||||
}
|
||||
}
|
||||
|
||||
// non-terminal, add phrase from underlying hypothesis
|
||||
// internal non-terminal
|
||||
else {
|
||||
// look up underlying hypothesis
|
||||
size_t nonTermIndex = nonTermIndexMap[phrasePos];
|
||||
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndex);
|
||||
|
||||
const LanguageModelChartState* prevState =
|
||||
static_cast<const LanguageModelChartState*>(prevHypo->GetFFState(featureID));
|
||||
const DALMChartState* prevState =
|
||||
static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));
|
||||
|
||||
size_t subPhraseLength = prevState->GetNumTargetTerminals();
|
||||
// special case: rule starts with non-terminal -> copy everything
|
||||
if (phrasePos == 0) {
|
||||
size_t prevPrefixLength = prevState->GetPrefixLength();
|
||||
const DALM::VocabId *prevPrefixIDs = prevState->GetPrefixIDs();
|
||||
|
||||
// get prefixScore and finalizedScore
|
||||
prefixScore = UntransformLMScore(prevState->GetPrefixScore());
|
||||
finalizedScore = UntransformLMScore(prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0]) - prefixScore;
|
||||
|
||||
// get language model state
|
||||
delete dalm_state;
|
||||
dalm_state = new DALMState( *static_cast<DALMState*>(prevState->GetRightContext()) );
|
||||
state = dalm_state->get_state();
|
||||
wordPos += subPhraseLength;
|
||||
// score its prefix
|
||||
for(size_t prefixPos = 0; prefixPos < prevPrefixLength; prefixPos++) {
|
||||
DALM::VocabId wid = prevPrefixIDs[prefixPos];
|
||||
float score = m_lm->query(wid, *state);
|
||||
if (prefixLength < contextSize){
|
||||
prefixScore += score;
|
||||
prefixIDs[prefixLength] = wid;
|
||||
prefixLength++;
|
||||
} else {
|
||||
finalizedScore += score;
|
||||
isLarge = true;
|
||||
}
|
||||
}
|
||||
|
||||
// internal non-terminal
|
||||
else {
|
||||
// score its prefix
|
||||
size_t wpos = wordPos;
|
||||
for(size_t prefixPos = 0;
|
||||
prefixPos < m_nGramOrder-1 // up to LM order window
|
||||
&& prefixPos < subPhraseLength; // up to length
|
||||
prefixPos++) {
|
||||
const Word &word = prevState->GetPrefix().GetWord(prefixPos);
|
||||
float score = m_lm->query(GetVocabId(word.GetFactor(m_factorType)), *state);
|
||||
updateChartScore( &prefixScore, &finalizedScore, score, ++wpos );
|
||||
}
|
||||
wordPos += subPhraseLength;
|
||||
// check if we are dealing with a large sub-phrase
|
||||
if (prevState->LargeEnough()) {
|
||||
// add its finalized language model score
|
||||
prevScore += prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0];
|
||||
finalizedScore -= prevState->GetPrefixScore();
|
||||
|
||||
// check if we are dealing with a large sub-phrase
|
||||
if (subPhraseLength > m_nGramOrder - 1) {
|
||||
// add its finalized language model score
|
||||
finalizedScore += UntransformLMScore(
|
||||
prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0] // full score
|
||||
- prevState->GetPrefixScore()); // - prefix score
|
||||
|
||||
// copy language model state
|
||||
delete dalm_state;
|
||||
dalm_state = new DALMState( *static_cast<DALMState*>(prevState->GetRightContext()) );
|
||||
state = dalm_state->get_state();
|
||||
}
|
||||
// copy language model state
|
||||
dalm_state->reset(*prevState->GetRightContext());
|
||||
state = dalm_state->get_state();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prefixScore = TransformLMScore(prefixScore);
|
||||
finalizedScore = TransformLMScore(finalizedScore);
|
||||
|
||||
// assign combined score to score breakdown
|
||||
out->Assign(this, prefixScore + finalizedScore);
|
||||
out->Assign(this, prevScore + TransformLMScore(prefixScore + finalizedScore));
|
||||
|
||||
ret->Set(prefixScore, dalm_state);
|
||||
return ret;
|
||||
return new DALMChartState(hypo, prefixIDs, prefixLength, prefixScore, dalm_state, isLarge);
|
||||
}
|
||||
|
||||
bool LanguageModelDALM::IsUseable(const FactorMask &mask) const
|
||||
{
|
||||
bool ret = mask[m_factorType];
|
||||
return ret;
|
||||
return mask[m_factorType];
|
||||
}
|
||||
|
||||
void LanguageModelDALM::CreateVocabMapping(const std::string &wordstxt)
|
||||
@ -392,13 +488,4 @@ void LanguageModelDALM::SetParameter(const std::string& key, const std::string&
|
||||
}
|
||||
}
|
||||
|
||||
void LanguageModelDALM::updateChartScore(float *prefixScore, float *finalizedScore, float score, size_t wordPos) const
|
||||
{
|
||||
if (wordPos < m_nGramOrder) {
|
||||
*prefixScore += score;
|
||||
} else {
|
||||
*finalizedScore += score;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -59,8 +59,6 @@ protected:
|
||||
DALM::VocabId GetVocabId(const Factor *factor) const;
|
||||
|
||||
private:
|
||||
void updateChartScore(float *prefixScore, float *finalizedScore, float score, size_t wordPos) const;
|
||||
|
||||
// Convert last words of hypothesis into vocab ids, returning an end pointer.
|
||||
DALM::VocabId *LastIDs(const Hypothesis &hypo, DALM::VocabId *indices) const {
|
||||
DALM::VocabId *index = indices;
|
||||
|
Loading…
Reference in New Issue
Block a user