test chart passed.
commit 690a948e9e (parent b0bc978325)
@@ -3,9 +3,7 @@
#include <algorithm>
#include "moses/FF/FFState.h"
#include "DALMWrapper.h"
#include "logger.h"
#include "dalm.h"
#include "vocabulary.h"
#include "moses/FactorTypeSet.h"
#include "moses/FactorCollection.h"
#include "moses/InputFileStream.h"
@@ -47,7 +45,7 @@ public:
  Murmur(std::size_t seed=0): seed(seed){
  }
  virtual std::size_t operator()(const DALM::VocabId *words, std::size_t size) const{
    return util::MurmurHashNative(words, sizeof(DALM::VocabId) * size, seed);
    return util::MurmurHashNative(words, sizeof(DALM::VocabId) * size, seed);
  }
private:
  std::size_t seed;
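The Murmur functor above hashes the raw bytes of a DALM vocab-id array so that language-model states can be bucketed for hypothesis recombination; the next hunk shows it being passed to state.hash(), imitating KenLM. A self-contained sketch of the same idea follows: VocabId and IdSequenceHash are illustrative stand-ins, and an FNV-1a mix replaces util::MurmurHashNative, the KenLM utility the real code calls.

// Sketch only: VocabId and the byte-wise FNV-1a mixing are stand-ins,
// not the actual DALM types or util::MurmurHashNative.
#include <cstddef>
#include <cstdint>
#include <iostream>

typedef unsigned int VocabId;  // hypothetical stand-in for DALM::VocabId

struct IdSequenceHash {
  explicit IdSequenceHash(std::size_t seed = 0) : seed(seed) {}
  std::size_t operator()(const VocabId *words, std::size_t size) const {
    // Hash the raw bytes of the id array, just as the real functor hashes
    // sizeof(DALM::VocabId) * size bytes with MurmurHash.
    const unsigned char *bytes = reinterpret_cast<const unsigned char *>(words);
    std::uint64_t h = 14695981039346656037ULL ^ seed;  // FNV-1a offset basis, seed folded in
    for (std::size_t i = 0; i < sizeof(VocabId) * size; ++i) {
      h ^= bytes[i];
      h *= 1099511628211ULL;  // FNV-1a 64-bit prime
    }
    return static_cast<std::size_t>(h);
  }
private:
  std::size_t seed;
};

int main() {
  VocabId context[3] = {42, 7, 13};
  std::cout << IdSequenceHash()(context, 3) << "\n";  // identical id sequences hash identically
}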
@@ -84,6 +82,7 @@ public:
    // imitate KenLM
    return state.hash(Murmur());
  }

  virtual bool operator==(const FFState& other) const {
    const DALMState &o = static_cast<const DALMState &>(other);
    return state.compare(o.state) == 0;
@@ -105,7 +104,6 @@ private:
  unsigned char prefixLength;
  DALM::State rightContext;
  bool isLarge;
  size_t hypoSize;

public:
  DALMChartState()
@@ -113,38 +111,9 @@ public:
    isLarge(false) {
  }

  /*
  DALMChartState(const DALMChartState &other)
    : prefixLength(other.prefixLength),
      rightContext(other.rightContext),
      isLarge(other.isLarge)
  {
    std::copy(
      other.prefixFragments,
      other.prefixFragments+other.prefixLength,
      prefixFragments
    );
  }
  */

  virtual ~DALMChartState() {
  }

  /*
  DALMChartState &operator=(const DALMChartState &other){
    prefixLength = other.prefixLength;
    std::copy(
      other.prefixFragments,
      other.prefixFragments+other.prefixLength,
      prefixFragments
    );
    rightContext = other.rightContext;
    isLarge=other.isLarge;

    return *this;
  }
  */

  inline unsigned char GetPrefixLength() const {
    return prefixLength;
  }
@@ -177,13 +146,6 @@ public:
    isLarge=true;
  }

  inline size_t &GetHypoSize() {
    return hypoSize;
  }
  inline size_t GetHypoSize() const {
    return hypoSize;
  }

  virtual int Compare(const FFState& other) const {
    const DALMChartState &o = static_cast<const DALMChartState &>(other);
    if(prefixLength < o.prefixLength) return -1;
@@ -217,7 +179,7 @@ public:
    const DALM::Fragment &f = prefixFragments[prefixLength-1];
    const DALM::Fragment &of = o.prefixFragments[prefixLength-1];
    if(DALM::compare_fragments(f, of) != 0) return false;

    // check right state.
    if(rightContext.get_count() != o.rightContext.get_count()) return false;
    return rightContext.compare(o.rightContext) == 0;
@@ -295,45 +257,83 @@ const FFState *LanguageModelDALM::EmptyHypothesisState(const InputType &/*input*

void LanguageModelDALM::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const
{
  fullScore = 0;
  ngramScore = 0;

  oovCount = 0;
  fullScore = 0.0f;
  ngramScore = 0.0f;

  size_t phraseSize = phrase.GetSize();
  if (!phraseSize) return;

  size_t currPos = 0;
  size_t hist_count = 0;
  //size_t hist_count = 0;
  DALM::State state;

  if(phrase.GetWord(0).GetFactor(m_factorType) == m_beginSentenceFactor) {
    m_lm->init_state(state);
    currPos++;
    hist_count++;
    //hist_count++;
  }

  float score;
  float prefixScore=0.0f;
  float partScore=0.0f;
  //std::cerr << std::setprecision(8);
  //std::cerr << "# ";
  while (currPos < phraseSize) {
    const Word &word = phrase.GetWord(currPos);
    hist_count++;
    //hist_count++;

    if (word.IsNonTerminal()) {
      //std::cerr << "X ";
      state.refresh();
      hist_count = 0;
      fullScore += partScore;
      partScore = 0.0f;
      //std::cerr << fullScore << " ";
      //hist_count = 0;
    } else {
      //std::cerr << word.GetString(m_factorType).as_string() << " ";
      DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
      score = m_lm->query(wid, state);
      fullScore += score;
      if (hist_count >= m_nGramOrder) ngramScore += score;
      partScore += score;
      //std::cerr << partScore << " ";
      //if (hist_count >= m_nGramOrder) ngramScore += score;
      if (wid==m_vocab->unk()) ++oovCount;
    }

    currPos++;
    if (currPos >= m_ContextSize){
      break;
    }
  }
  prefixScore = fullScore + partScore;
  //std::cerr << prefixScore << " ";

  while (currPos < phraseSize) {
    const Word &word = phrase.GetWord(currPos);
    //hist_count++;

    if (word.IsNonTerminal()) {
      //std::cerr << "X ";
      fullScore += partScore;
      partScore = 0.0f;
      //std::cerr << fullScore << " ";
      state.refresh();
      //hist_count = 0;
    } else {
      //std::cerr << word.GetString(m_factorType).as_string() << " ";
      DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
      score = m_lm->query(wid, state);
      partScore += score;
      //std::cerr << partScore << " ";
      if (wid==m_vocab->unk()) ++oovCount;
    }

    currPos++;
  }
  fullScore += partScore;

  ngramScore = TransformLMScore(fullScore - prefixScore);
  fullScore = TransformLMScore(fullScore);
  ngramScore = TransformLMScore(ngramScore);
}

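The rewritten CalcScore above scores the phrase in two passes: the first loop stops after m_ContextSize positions and its running total is saved as prefixScore, the second loop scores the remainder, and ngramScore becomes the transformed difference fullScore - prefixScore, roughly the score mass whose n-gram context lies entirely inside the phrase, while fullScore stays the transformed total. The toy sketch below illustrates only that split; scoreWord, transform and contextSize are made-up stand-ins for m_lm->query, TransformLMScore and m_ContextSize, and the non-terminal refresh and OOV counting are left out.

// Toy sketch of the prefix/full score split; nothing here is Moses or DALM API.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

float scoreWord(int wordId, std::vector<int> &state) {
  state.push_back(wordId);        // pretend the LM state grows with each word
  return -0.5f - 0.01f * wordId;  // fake log10 probability
}

float transform(float log10Score) {
  return log10Score * std::log(10.0f);  // log10 -> natural log, like TransformLMScore
}

int main() {
  std::vector<int> phrase = {3, 1, 4, 1, 5, 9};
  std::size_t contextSize = 3;
  std::vector<int> state;

  float fullScore = 0.0f, prefixScore = 0.0f;
  std::size_t pos = 0;
  for (; pos < phrase.size() && pos < contextSize; ++pos)  // pass 1: prefix words
    fullScore += scoreWord(phrase[pos], state);
  prefixScore = fullScore;                                 // remember the prefix total
  for (; pos < phrase.size(); ++pos)                       // pass 2: remaining words
    fullScore += scoreWord(phrase[pos], state);

  float ngramScore = transform(fullScore - prefixScore);   // score beyond the prefix
  std::cout << transform(fullScore) << " " << ngramScore << "\n";
}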
FFState *LanguageModelDALM::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const
@@ -396,14 +396,12 @@ FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int

  DALM::Fragment *prefixFragments = newState->GetPrefixFragments();
  unsigned char &prefixLength = newState->GetPrefixLength();
  size_t &hypoSizeAll = newState->GetHypoSize();

  // initial language model scores
  float hypoScore = 0.0; // diffs of scores.
  float hypoScore = 0.0; // total hypothesis score.

  const TargetPhrase &targetPhrase = hypo.GetCurrTargetPhrase();
  size_t hypoSize = targetPhrase.GetSize();
  hypoSizeAll = hypoSize;

  // get index map for underlying hypotheses
  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
@@ -415,6 +413,7 @@ FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int
  if(hypoSize > 0) {
    const Word &word = targetPhrase.GetWord(0);
    if(word.GetFactor(m_factorType) == m_beginSentenceFactor) {
      //std::cerr << word.GetString(m_factorType).as_string() << " ";
      m_lm->init_state(state);
      // state is finalized.
      newState->SetAsLarge();
@@ -426,9 +425,9 @@ FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int
      const DALMChartState* prevState =
        static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));


      // copy chart state
      (*newState) = (*prevState);
      hypoSizeAll = hypoSize+prevState->GetHypoSize()-1;

      phrasePos++;
    }
@@ -456,8 +455,8 @@ FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int
      const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
      const DALMChartState* prevState =
        static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));

      size_t prevTargetPhraseLength = prevHypo->GetCurrTargetPhrase().GetSize();
      hypoSizeAll += prevState->GetHypoSize()-1;

      EvaluateNonTerminal(
        word, hypoScore,
@@ -467,9 +466,18 @@ FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int
      );
    }
  }
  hypoScore = TransformLMScore(hypoScore);
  hypoScore -= hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0];

  // assign combined score to score breakdown
  out->PlusEquals(this, TransformLMScore(hypoScore));
  if (OOVFeatureEnabled()) {
    std::vector<float> scores(2);
    scores[0] = hypoScore;
    scores[1] = 0.0;
    out->PlusEquals(this, scores);
  } else {
    out->PlusEquals(this, hypoScore);
  }

  return newState;
}
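In the hunk above, hypoScore is now the total LM score of the chart hypothesis; the estimate already added when the translation option was scored (GetScoresForProducer(this)[0]) is subtracted so that only the correction reaches the score breakdown, and with the OOV feature enabled the value is pushed as a two-component vector. A minimal sketch of that bookkeeping with plain floats; precomputedEstimate and the vector layout are assumptions mirroring the code above, not Moses API calls.

// Sketch of the delta-and-breakdown step with fabricated values.
#include <iostream>
#include <vector>

int main() {
  float totalLmScore = -12.4f;        // what the chart evaluation just computed
  float precomputedEstimate = -11.0f; // already credited when the rule was scored
  bool oovFeatureEnabled = true;

  float delta = totalLmScore - precomputedEstimate;  // only the correction is added now

  std::vector<float> breakdown;
  if (oovFeatureEnabled) {
    breakdown = {delta, 0.0f};  // slot 0: LM score, slot 1: OOV count feature
  } else {
    breakdown = {delta};
  }
  for (float f : breakdown) std::cout << f << " ";
  std::cout << "\n";
}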
@@ -543,6 +551,7 @@ void LanguageModelDALM::EvaluateTerminal(
    float score = m_lm->query(wid, state);
    hypoScore += score;
  } else {
    unsigned char prevLen = state.get_count();
    float score = m_lm->query(wid, state, prefixFragments[prefixLength]);

    if(score > 0) {
@@ -555,6 +564,9 @@ void LanguageModelDALM::EvaluateTerminal(
    } else {
      hypoScore += score;
      prefixLength++;
      if(state.get_count() < std::min(prevLen+1, (int)m_ContextSize)){
        newState->SetAsLarge();
      }
      if(prefixLength >= m_ContextSize) newState->SetAsLarge();
    }
  }
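The check added to EvaluateTerminal above compares the state length after the query with min(prevLen+1, m_ContextSize); if the history failed to grow by one, or the stored prefix has reached the context size, the chart state is marked large, which suggests that no additional left context could change this prefix's score. A stripped-down sketch of that decision with hypothetical names; stateCount, prevLen and contextSize stand in for state.get_count() after the query, the count before it, and m_ContextSize.

// Hypothetical sketch of the "is the prefix finalized?" test.
#include <algorithm>
#include <iostream>

bool prefixIsFinalized(int prevLen, int stateCount, int prefixLength, int contextSize) {
  // history did not grow -> a longer left context cannot matter any more
  if (stateCount < std::min(prevLen + 1, contextSize)) return true;
  // prefix already as long as the model order allows
  if (prefixLength >= contextSize) return true;
  return false;
}

int main() {
  std::cout << prefixIsFinalized(2, 2, 1, 5) << "\n";  // truncated history -> finalized
  std::cout << prefixIsFinalized(2, 3, 1, 5) << "\n";  // still growing -> not yet
}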
@@ -577,7 +589,7 @@ void LanguageModelDALM::EvaluateNonTerminal(

  if(prevPrefixLength == 0) {
    newState->SetAsLarge();
    hypoScore += state.sum_bows(0, state.get_count());
    hypoScore += m_lm->sum_bows(state, 0, state.get_count());
    state = prevState->GetRightContext();
    return;
  }
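The change in this hunk looks mechanical: the backoff-weight sum over a range of state positions moves from a method on DALM::State (state.sum_bows(begin, end)) to a method on the model object (m_lm->sum_bows(state, begin, end)); the same substitution appears again near the end of EvaluateNonTerminal below. For intuition only, the quantity being added is a sum of (log) backoff weights over context positions; the sketch below fabricates a bows array and is not the DALM API.

// Conceptual sketch: bows[] is a fabricated array of log10 backoff weights
// indexed by context position; DALM's real sum_bows works on its own state.
#include <cstddef>
#include <iostream>

float sumBows(const float *bows, std::size_t begin, std::size_t end) {
  float total = 0.0f;
  for (std::size_t i = begin; i < end; ++i) total += bows[i];
  return total;
}

int main() {
  float bows[4] = {-0.1f, -0.3f, -0.05f, -0.2f};
  std::cout << sumBows(bows, 0, 4) << "\n";  // the adjustment added to hypoScore
}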
@@ -587,10 +599,12 @@ void LanguageModelDALM::EvaluateNonTerminal(
    return;
  }
  DALM::Gap gap(state);
  unsigned char prevLen = state.get_count();

  // score its prefix
  for(size_t prefixPos = 0; prefixPos < prevPrefixLength; prefixPos++) {
    const DALM::Fragment &f = prevPrefixFragments[prefixPos];

    if (newState->LargeEnough()) {
      float score = m_lm->query(f, state, gap);
      hypoScore += score;
@@ -612,7 +626,9 @@ void LanguageModelDALM::EvaluateNonTerminal(
      state = prevState->GetRightContext();
      return;
    } else if(state.get_count() <= prefixPos+1) {
      if(!gap.is_finalized()) prefixLength++;
      if(state.get_count() == prefixPos+1 && !gap.is_finalized()){
        prefixLength++;
      }
      newState->SetAsLarge();
      state = prevState->GetRightContext();
      return;
@@ -620,18 +636,23 @@ void LanguageModelDALM::EvaluateNonTerminal(
        newState->SetAsLarge();
      } else {
        prefixLength++;
        if(state.get_count() < std::min(prevLen+1, (int)m_ContextSize)){
          newState->SetAsLarge();
        }

        if(prefixLength >= m_ContextSize) newState->SetAsLarge();
      }
    }
    gap.succ();
    prevLen = state.get_count();
  }

  // check if we are dealing with a large sub-phrase
  if (prevState->LargeEnough()) {
    newState->SetAsLarge();
    if(prevPrefixLength < prevState->GetHypoSize()) {
      hypoScore += state.sum_bows(prevPrefixLength, state.get_count());
    }
    //if(prevPrefixLength < prevState->GetHypoSize()) {
    hypoScore += m_lm->sum_bows(state, prevPrefixLength, state.get_count());
    //}
    // copy language model state
    state = prevState->GetRightContext();
  } else {