test chart passed.

Jun-ya Norimatsu 2015-10-31 19:23:59 +09:00
parent b0bc978325
commit 690a948e9e


@@ -3,9 +3,7 @@
#include <algorithm>
#include "moses/FF/FFState.h"
#include "DALMWrapper.h"
#include "logger.h"
#include "dalm.h"
#include "vocabulary.h"
#include "moses/FactorTypeSet.h"
#include "moses/FactorCollection.h"
#include "moses/InputFileStream.h"
@@ -47,7 +45,7 @@ public:
Murmur(std::size_t seed=0): seed(seed){
}
virtual std::size_t operator()(const DALM::VocabId *words, std::size_t size) const{
return util::MurmurHashNative(words, sizeof(DALM::VocabId) * size, seed);
}
private:
std::size_t seed;
@@ -84,6 +82,7 @@ public:
// imitate KenLM
return state.hash(Murmur());
}
virtual bool operator==(const FFState& other) const {
const DALMState &o = static_cast<const DALMState &>(other);
return state.compare(o.state) == 0;
@@ -105,7 +104,6 @@ private:
unsigned char prefixLength;
DALM::State rightContext;
bool isLarge;
size_t hypoSize;
public:
DALMChartState()
@@ -113,38 +111,9 @@ public:
: prefixLength(0),
isLarge(false) {
}
/*
DALMChartState(const DALMChartState &other)
: prefixLength(other.prefixLength),
rightContext(other.rightContext),
isLarge(other.isLarge)
{
std::copy(
other.prefixFragments,
other.prefixFragments+other.prefixLength,
prefixFragments
);
}
*/
virtual ~DALMChartState() {
}
/*
DALMChartState &operator=(const DALMChartState &other){
prefixLength = other.prefixLength;
std::copy(
other.prefixFragments,
other.prefixFragments+other.prefixLength,
prefixFragments
);
rightContext = other.rightContext;
isLarge=other.isLarge;
return *this;
}
*/
inline unsigned char GetPrefixLength() const {
return prefixLength;
}
@@ -177,13 +146,6 @@ public:
isLarge=true;
}
inline size_t &GetHypoSize() {
return hypoSize;
}
inline size_t GetHypoSize() const {
return hypoSize;
}
virtual int Compare(const FFState& other) const {
const DALMChartState &o = static_cast<const DALMChartState &>(other);
if(prefixLength < o.prefixLength) return -1;
@@ -217,7 +179,7 @@ public:
const DALM::Fragment &f = prefixFragments[prefixLength-1];
const DALM::Fragment &of = o.prefixFragments[prefixLength-1];
if(DALM::compare_fragments(f, of) != 0) return false;
// check right state.
if(rightContext.get_count() != o.rightContext.get_count()) return false;
return rightContext.compare(o.rightContext) == 0;
@@ -295,45 +257,83 @@ const FFState *LanguageModelDALM::EmptyHypothesisState(const InputType &/*input*
void LanguageModelDALM::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const
{
fullScore = 0;
ngramScore = 0;
oovCount = 0;
fullScore = 0.0f;
ngramScore = 0.0f;
size_t phraseSize = phrase.GetSize();
if (!phraseSize) return;
size_t currPos = 0;
size_t hist_count = 0;
//size_t hist_count = 0;
DALM::State state;
if(phrase.GetWord(0).GetFactor(m_factorType) == m_beginSentenceFactor) {
m_lm->init_state(state);
currPos++;
hist_count++;
//hist_count++;
}
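// <s> is never scored itself; it only seeds the DALM state with sentence-initial context.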
float score;
float prefixScore=0.0f;
float partScore=0.0f;
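// Two-pass scoring: the first loop handles the first m_ContextSize words,
// whose scores are provisional until real left context is attached (their
// running total becomes prefixScore); the second loop scores the remaining
// words, whose scores no longer depend on outside left context.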
//std::cerr << std::setprecision(8);
//std::cerr << "# ";
while (currPos < phraseSize) {
const Word &word = phrase.GetWord(currPos);
hist_count++;
//hist_count++;
if (word.IsNonTerminal()) {
//std::cerr << "X ";
state.refresh();
hist_count = 0;
fullScore += partScore;
partScore = 0.0f;
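// A nonterminal breaks the n-gram history: bank the scores accumulated so far and restart from an empty state.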
//std::cerr << fullScore << " ";
//hist_count = 0;
} else {
//std::cerr << word.GetString(m_factorType).as_string() << " ";
DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
score = m_lm->query(wid, state);
fullScore += score;
if (hist_count >= m_nGramOrder) ngramScore += score;
partScore += score;
//std::cerr << partScore << " ";
//if (hist_count >= m_nGramOrder) ngramScore += score;
if (wid==m_vocab->unk()) ++oovCount;
}
currPos++;
if (currPos >= m_ContextSize){
break;
}
}
prefixScore = fullScore + partScore;
//std::cerr << prefixScore << " ";
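// Scores from here on are final: attaching left context cannot change them, so they all count toward ngramScore.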
while (currPos < phraseSize) {
const Word &word = phrase.GetWord(currPos);
//hist_count++;
if (word.IsNonTerminal()) {
//std::cerr << "X ";
fullScore += partScore;
partScore = 0.0f;
//std::cerr << fullScore << " ";
state.refresh();
//hist_count = 0;
} else {
//std::cerr << word.GetString(m_factorType).as_string() << " ";
DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
score = m_lm->query(wid, state);
partScore += score;
//std::cerr << partScore << " ";
if (wid==m_vocab->unk()) ++oovCount;
}
currPos++;
}
fullScore += partScore;
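// Convert to Moses' internal score scale. ngramScore excludes the provisional prefix estimate, which is re-evaluated once this phrase gains left context.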
ngramScore = TransformLMScore(fullScore - prefixScore);
fullScore = TransformLMScore(fullScore);
ngramScore = TransformLMScore(ngramScore);
}
FFState *LanguageModelDALM::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const
@@ -396,14 +396,12 @@ FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int
DALM::Fragment *prefixFragments = newState->GetPrefixFragments();
unsigned char &prefixLength = newState->GetPrefixLength();
size_t &hypoSizeAll = newState->GetHypoSize();
// initial language model scores
float hypoScore = 0.0; // diffs of scores.
float hypoScore = 0.0; // total hypothesis score.
const TargetPhrase &targetPhrase = hypo.GetCurrTargetPhrase();
size_t hypoSize = targetPhrase.GetSize();
hypoSizeAll = hypoSize;
// get index map for underlying hypotheses
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
@@ -415,6 +413,7 @@ FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int
if(hypoSize > 0) {
const Word &word = targetPhrase.GetWord(0);
if(word.GetFactor(m_factorType) == m_beginSentenceFactor) {
//std::cerr << word.GetString(m_factorType).as_string() << " ";
m_lm->init_state(state);
// state is finalized.
newState->SetAsLarge();
@@ -426,9 +425,9 @@ FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int
const DALMChartState* prevState =
static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));
// copy chart state
(*newState) = (*prevState);
hypoSizeAll = hypoSize+prevState->GetHypoSize()-1;
phrasePos++;
}
@@ -456,8 +455,8 @@ FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
const DALMChartState* prevState =
static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));
size_t prevTargetPhraseLength = prevHypo->GetCurrTargetPhrase().GetSize();
hypoSizeAll += prevState->GetHypoSize()-1;
EvaluateNonTerminal(
word, hypoScore,
@@ -467,9 +466,18 @@ FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int
);
}
}
hypoScore = TransformLMScore(hypoScore);
hypoScore -= hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0];
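// Subtract the estimate CalcScore already contributed for this rule, so only the difference enters the breakdown.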
// assign combined score to score breakdown
out->PlusEquals(this, TransformLMScore(hypoScore));
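// With the OOV feature enabled, this producer owns two score components: the LM score and the OOV count (nothing to add to the latter here).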
if (OOVFeatureEnabled()) {
std::vector<float> scores(2);
scores[0] = hypoScore;
scores[1] = 0.0;
out->PlusEquals(this, scores);
} else {
out->PlusEquals(this, hypoScore);
}
return newState;
}
@@ -543,6 +551,7 @@ void LanguageModelDALM::EvaluateTerminal(
float score = m_lm->query(wid, state);
hypoScore += score;
} else {
unsigned char prevLen = state.get_count();
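// Remember the history length before the query so we can tell whether DALM kept or truncated the context.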
float score = m_lm->query(wid, state, prefixFragments[prefixLength]);
if(score > 0) {
@@ -555,6 +564,9 @@
} else {
hypoScore += score;
prefixLength++;
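// If the state grew by less than one word (capped by the model order), DALM truncated the history, so the prefix can never be extended; finalize the state.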
if(state.get_count() < std::min(prevLen+1, (int)m_ContextSize)){
newState->SetAsLarge();
}
if(prefixLength >= m_ContextSize) newState->SetAsLarge();
}
}
@@ -577,7 +589,7 @@ void LanguageModelDALM::EvaluateNonTerminal(
if(prevPrefixLength == 0) {
newState->SetAsLarge();
hypoScore += state.sum_bows(0, state.get_count());
hypoScore += m_lm->sum_bows(state, 0, state.get_count());
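// With no prefix left to rescore, only the backoff weights across the boundary are charged (sum_bows accumulates them over the given range).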
state = prevState->GetRightContext();
return;
}
@@ -587,10 +599,12 @@
return;
}
DALM::Gap gap(state);
unsigned char prevLen = state.get_count();
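// As in EvaluateTerminal, track the history length so truncated context can be detected after each query.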
// score its prefix
for(size_t prefixPos = 0; prefixPos < prevPrefixLength; prefixPos++) {
const DALM::Fragment &f = prevPrefixFragments[prefixPos];
if (newState->LargeEnough()) {
float score = m_lm->query(f, state, gap);
hypoScore += score;
@@ -612,7 +626,9 @@
state = prevState->GetRightContext();
return;
} else if(state.get_count() <= prefixPos+1) {
if(!gap.is_finalized()) prefixLength++;
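// Only count the fragment toward the prefix when the state exactly covers it and the gap is still open.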
if(state.get_count() == prefixPos+1 && !gap.is_finalized()){
prefixLength++;
}
newState->SetAsLarge();
state = prevState->GetRightContext();
return;
@@ -620,18 +636,23 @@
newState->SetAsLarge();
} else {
prefixLength++;
if(state.get_count() < std::min(prevLen+1, (int)m_ContextSize)){
newState->SetAsLarge();
}
if(prefixLength >= m_ContextSize) newState->SetAsLarge();
}
}
gap.succ();
prevLen = state.get_count();
}
// check if we are dealing with a large sub-phrase
if (prevState->LargeEnough()) {
newState->SetAsLarge();
if(prevPrefixLength < prevState->GetHypoSize()) {
hypoScore += state.sum_bows(prevPrefixLength, state.get_count());
}
//if(prevPrefixLength < prevState->GetHypoSize()) {
hypoScore += m_lm->sum_bows(state, prevPrefixLength, state.get_count());
//}
// copy language model state
state = prevState->GetRightContext();
} else {