Mirror of https://github.com/moses-smt/mosesdecoder.git (synced 2024-12-26 05:14:36 +03:00)
implement EvaluateChart() for TargetNgramFeature, add sparse features to chart decoder nbest list
This commit is contained in:
parent 6f0cc81aa4
commit 1dcc3c3c6f
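
Summary of the change: the chart decoder's n-best writer (IOWrapper) now emits sparse feature values and fixes the spacing of the total-score separator, and TargetNgramFeature gains a real EvaluateChart() implementation. A hypothetical labeled n-best line after this commit (all feature names and values are invented for illustration):

0 ||| this is a house ||| tm: -4.12 -6.01 lm: -12.3 dlmn_the:house: 1 ||| -7.934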
@@ -47,6 +47,9 @@ POSSIBILITY OF SUCH DAMAGE.
#include "ChartHypothesis.h"
#include "DotChart.h"

#include <boost/algorithm/string.hpp>
#include "FeatureVector.h"


using namespace std;
using namespace Moses;
@@ -345,7 +348,7 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha
// print the surface factor of the translation
out << translationId << " ||| ";
OutputSurface(out, outputPhrase, m_outputFactorOrder, false);
out << " |||";
out << " ||| ";

// print the scores in a hardwired order
// before each model type, the corresponding command-line-like name must be emitted
@@ -362,26 +365,23 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha
}
}


std::string lastName = "";

// translation components
const vector<PhraseDictionaryFeature*>& pds = system->GetPhraseDictionaries();
if (pds.size() > 0) {

for( size_t i=0; i<pds.size(); i++ ) {
size_t pd_numinputscore = pds[i]->GetNumInputScores();
vector<float> scores = path.GetScoreBreakdown().GetScoresForProducer( pds[i] );
for (size_t j = 0; j<scores.size(); ++j){

if (labeledOutput && (i == 0) ){
if ((j == 0) || (j == pd_numinputscore)){
lastName = pds[i]->GetScoreProducerWeightShortName(j);
out << " " << lastName << ":";
}
}
out << " " << scores[j];
}
size_t pd_numinputscore = pds[i]->GetNumInputScores();
vector<float> scores = path.GetScoreBreakdown().GetScoresForProducer( pds[i] );
for (size_t j = 0; j<scores.size(); ++j){
if (labeledOutput && (i == 0) ){
if ((j == 0) || (j == pd_numinputscore)){
lastName = pds[i]->GetScoreProducerWeightShortName(j);
out << " " << lastName << ":";
}
}
out << " " << scores[j];
}
}
}

@@ -393,26 +393,36 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha
// generation
const vector<GenerationDictionary*>& gds = system->GetGenerationDictionaries();
if (gds.size() > 0) {

for( size_t i=0; i<gds.size(); i++ ) {
size_t pd_numinputscore = gds[i]->GetNumInputScores();
vector<float> scores = path.GetScoreBreakdown().GetScoresForProducer( gds[i] );
for (size_t j = 0; j<scores.size(); ++j){

if (labeledOutput && (i == 0) ){
if ((j == 0) || (j == pd_numinputscore)){
lastName = gds[i]->GetScoreProducerWeightShortName(j);
out << " " << lastName << ":";
}
}
out << " " << scores[j];
}
size_t pd_numinputscore = gds[i]->GetNumInputScores();
vector<float> scores = path.GetScoreBreakdown().GetScoresForProducer( gds[i] );
for (size_t j = 0; j<scores.size(); ++j){
if (labeledOutput && (i == 0) ){
if ((j == 0) || (j == pd_numinputscore)){
lastName = gds[i]->GetScoreProducerWeightShortName(j);
out << " " << lastName << ":";
}
}
out << " " << scores[j];
}
}
}

// output sparse features
lastName = "";
const vector<const StatefulFeatureFunction*>& sff = system->GetStatefulFeatureFunctions();
for( size_t i=0; i<sff.size(); i++ )
if (sff[i]->GetNumScoreComponents() == ScoreProducer::unlimited)
OutputSparseFeatureScores( out, path, sff[i], lastName );

const vector<const StatelessFeatureFunction*>& slf = system->GetStatelessFeatureFunctions();
for( size_t i=0; i<slf.size(); i++ )
if (slf[i]->GetNumScoreComponents() == ScoreProducer::unlimited)
OutputSparseFeatureScores( out, path, slf[i], lastName );


// total
out << " |||" << path.GetTotalScore();
out << " ||| " << path.GetTotalScore();

/*
if (includeAlignment) {
@@ -443,6 +453,32 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha
m_nBestOutputCollector->Write(translationId, out.str());
}

void IOWrapper::OutputSparseFeatureScores( std::ostream& out, const ChartTrellisPath &path, const FeatureFunction *ff, std::string &lastName )
{
const StaticData &staticData = StaticData::Instance();
bool labeledOutput = staticData.IsLabeledNBestList();
const FVector scores = path.GetScoreBreakdown().GetVectorForProducer( ff );

// report weighted aggregate
if (! ff->GetSparseFeatureReporting()) {
const FVector &weights = staticData.GetAllWeights().GetScoresVector();
if (labeledOutput && !boost::contains(ff->GetScoreProducerDescription(), ":"))
out << " " << ff->GetScoreProducerWeightShortName() << ":";
out << " " << scores.inner_product(weights);
}

// report each feature
else {
for(FVector::FNVmap::const_iterator i = scores.cbegin(); i != scores.cend(); i++) {
if (i->second != 0) { // do not report zero-valued features
if (labeledOutput)
out << " " << i->first << ":";
out << " " << i->second;
}
}
}
}

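The function above has two reporting modes: when sparse reporting is off it folds the producer's sparse vector into a single weighted aggregate (the inner product with the global weights); when it is on, each non-zero feature is printed by name. A hypothetical excerpt for the same hypothesis in both modes (names and values invented for illustration):

aggregate: ... dlmn: -2.7134 ...
per-feature: ... dlmn_the:house: 1 dlmn_house:is: 2 ...
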
void IOWrapper::FixPrecision(std::ostream &stream, size_t size)
{
stream.setf(std::ios::fixed);

@@ -44,6 +44,8 @@ POSSIBILITY OF SUCH DAMAGE.
#include "OutputCollector.h"
#include "ChartHypothesis.h"

#include "ChartTrellisPath.h"

namespace Moses
{
class FactorCollection;
@@ -82,6 +84,7 @@ public:
void OutputBestHypo(const Moses::ChartHypothesis *hypo, long translationId, bool reportSegmentation, bool reportAllFactors);
void OutputBestHypo(const std::vector<const Moses::Factor*>& mbrBestHypo, long translationId, bool reportSegmentation, bool reportAllFactors);
void OutputNBestList(const Moses::ChartTrellisPathList &nBestList, const Moses::ChartHypothesis *bestHypo, const Moses::TranslationSystem* system, long translationId);
void OutputSparseFeatureScores(std::ostream& out, const Moses::ChartTrellisPath &path, const Moses::FeatureFunction *ff, std::string &lastName);
void OutputDetailedTranslationReport(const Moses::ChartHypothesis *hypo, long translationId);
void Backtrack(const Moses::ChartHypothesis *hypo);

@@ -290,6 +290,10 @@ namespace Moses {
return (m_fv->m_features[m_name] += lhs);
}

FValue operator -=(FValue lhs) {
return (m_fv->m_features[m_name] -= lhs);
}

private:
FValue m_tmp;

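The new operator-= completes the proxy's arithmetic interface alongside the existing +=. A minimal usage sketch, assuming (as this header suggests) that FVector::operator[] returns this proxy keyed by an FName of producer root and feature name:

FVector fv;
fv[FName("dlmn", "the:house")] += 2; // existing proxy operator
fv[FName("dlmn", "the:house")] -= 1; // added here; the entry now holds 1
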
@@ -149,6 +149,13 @@ public:
m_scores -= rhs.m_scores;
}

//For features which have an unbounded number of components
void MinusEquals(const ScoreProducer*sp, const std::string& name, float score)
{
assert(sp->GetNumScoreComponents() == ScoreProducer::unlimited);
FName fname(sp->GetScoreProducerDescription(),name);
m_scores[fname] -= score;
}

//! Add scores from a single ScoreProducer only
//! The length of scores must be equal to the number of score components

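The sparse MinusEquals mirrors the sparse PlusEquals that feature functions already call, and EvaluateChart() below relies on exactly this pair to retract double-counted n-grams. A sketch (the feature pointer and n-gram string are illustrative):

// given a ScoreComponentCollection* accumulator and a ScoreProducer* feature
accumulator->PlusEquals(feature, "the:house", 1); // fire an observed n-gram
accumulator->MinusEquals(feature, "the:house", 1); // retract a double-counted one
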
@@ -1459,14 +1459,12 @@ bool StaticData::LoadReferences()

bool StaticData::LoadDiscrimLMFeature()
{
cerr << "Loading discriminative language models.. ";

// only load if specified
// only load if specified
const vector<string> &wordFile = m_parameter->GetParam("discrim-lmodel-file");
if (wordFile.empty()) {
return true;
}
cerr << wordFile.size() << " models" << endl;
cerr << "Loading " << wordFile.size() << " discriminative language model(s).." << endl;

// if this weight is specified, the sparse DLM weights will be scaled with an additional weight
vector<string> dlmWeightStr = m_parameter->GetParam("weight-dlm");

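The loader is driven by two moses.ini parameters. A hedged configuration sketch follows; the field layout of a discrim-lmodel-file entry is not visible in this hunk, so the entry line is only a placeholder:

[discrim-lmodel-file]
<one entry per discriminative LM>

[weight-dlm]
0.5
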
@@ -3,6 +3,7 @@
#include "TargetPhrase.h"
#include "Hypothesis.h"
#include "ScoreComponentCollection.h"
#include "ChartHypothesis.h"

namespace Moses {

@@ -12,25 +13,25 @@ int TargetNgramState::Compare(const FFState& other) const {
const TargetNgramState& rhs = dynamic_cast<const TargetNgramState&>(other);
int result;
if (m_words.size() == rhs.m_words.size()) {
for (size_t i = 0; i < m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]);
if (result != 0) return result;
}
for (size_t i = 0; i < m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]);
if (result != 0) return result;
}
return 0;
}
else if (m_words.size() < rhs.m_words.size()) {
for (size_t i = 0; i < m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]);
if (result != 0) return result;
}
return -1;
for (size_t i = 0; i < m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]);
if (result != 0) return result;
}
return -1;
}
else {
for (size_t i = 0; i < rhs.m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]);
if (result != 0) return result;
}
return 1;
for (size_t i = 0; i < rhs.m_words.size(); ++i) {
result = Word::Compare(m_words[i],rhs.m_words[i]);
if (result != 0) return result;
}
return 1;
}
}

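The hunk above looks like a whitespace-only reindentation; the comparison logic is unchanged. It orders states by the shared word prefix first and uses length as the tie-breaker: for example, {this, is} vs {this, was} falls through to Word::Compare on the second word, while {this} vs {this, is} agrees on the shared prefix and returns -1 because the left state is shorter.
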
@@ -45,7 +46,7 @@ bool TargetNgramFeature::Load(const std::string &filePath)

std::string line;
m_vocab.insert(BOS_);
m_vocab.insert(BOS_);
m_vocab.insert(EOS_);
while (getline(inFile, line)) {
m_vocab.insert(line);
}
@@ -54,7 +55,6 @@ bool TargetNgramFeature::Load(const std::string &filePath)
return true;
}


string TargetNgramFeature::GetScoreProducerWeightShortName(unsigned) const
{
return "dlmn";
@@ -65,7 +65,6 @@ size_t TargetNgramFeature::GetNumInputScores() const
return 0;
}


const FFState* TargetNgramFeature::EmptyHypothesisState(const InputType &/*input*/) const
{
vector<Word> bos(1,m_bos);
@@ -177,5 +176,254 @@ void TargetNgramFeature::appendNgram(const Word& word, bool& skip, string& ngram
ngram.append(":");
}
}

FFState* TargetNgramFeature::EvaluateChart(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection* accumulator) const
{
TargetNgramChartState *ret = new TargetNgramChartState(cur_hypo, featureID, GetNGramOrder());
// data structure for factored context phrase (history and predicted word)
vector<const Word*> contextFactor;
contextFactor.reserve(GetNGramOrder());

// initialize language model context state
FFState *lmState = NewState( GetNullContextState() );

// get index map for underlying hypotheses
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
cur_hypo.GetCurrTargetPhrase().GetAlignmentInfo().GetNonTermIndexMap();

// loop over rule
bool makePrefix = false;
bool makeSuffix = false;
bool beforeSubphrase = true;
size_t terminalsBeforeSubphrase = 0;
size_t terminalsAfterSubphrase = 0;
for (size_t phrasePos = 0, wordPos = 0;
phrasePos < cur_hypo.GetCurrTargetPhrase().GetSize();
phrasePos++)
{
// consult rule for either word or non-terminal
const Word &word = cur_hypo.GetCurrTargetPhrase().GetWord(phrasePos);
// cerr << "word: " << word << endl;

// regular word
if (!word.IsNonTerminal())
{
if (phrasePos==0)
makePrefix = true;

contextFactor.push_back(&word);

// beginning of sentence symbol <s>?
if (word.GetString(GetFactorType(), false).compare("<s>") == 0)
{
assert(phrasePos == 0);
delete lmState;
lmState = NewState( GetBeginSentenceState() );

terminalsBeforeSubphrase++;
}
// end of sentence symbol </s>?
else if (word.GetString(GetFactorType(), false).compare("</s>") == 0) {
terminalsAfterSubphrase++;
}
// everything else
else {
string curr_ngram = word.GetString(GetFactorType(), false);
// cerr << "ngram: " << curr_ngram << endl;
accumulator->PlusEquals(this,curr_ngram,1);
}

}

// non-terminal, add phrase from underlying hypothesis
else
{
// look up underlying hypothesis
size_t nonTermIndex = nonTermIndexMap[phrasePos];
const ChartHypothesis *prevHypo = cur_hypo.GetPrevHypo(nonTermIndex);

const TargetNgramChartState* prevState =
static_cast<const TargetNgramChartState*>(prevHypo->GetFFState(featureID));

size_t subPhraseLength = prevState->GetNumTargetTerminals();
if (subPhraseLength==1) {
if (beforeSubphrase)
terminalsBeforeSubphrase++;
else
terminalsAfterSubphrase++;
}
else {
beforeSubphrase = false;
}

// special case: rule starts with non-terminal -> copy everything
if (phrasePos == 0) {
if (subPhraseLength == 1)
makePrefix = true;

// get language model state
delete lmState;
lmState = NewState( prevState->GetRightContext() );

// push suffix
// cerr << "suffix of NT in the beginning" << endl;
int suffixPos = prevState->GetSuffix().GetSize() - (GetNGramOrder()-1);
if (suffixPos < 0) suffixPos = 0; // push all words if less than order
for(;(size_t)suffixPos < prevState->GetSuffix().GetSize(); suffixPos++)
{
const Word &word = prevState->GetSuffix().GetWord(suffixPos);
// cerr << "NT0 --> : " << word << endl;
contextFactor.push_back(&word);
wordPos++;
}
}

// internal non-terminal
else
{
if (subPhraseLength == 1 && phrasePos == cur_hypo.GetCurrTargetPhrase().GetSize()-1)
makeSuffix = true;

// cerr << "prefix of subphrase for left context" << endl;
// score its prefix
for(size_t prefixPos = 0;
prefixPos < GetNGramOrder()-1 // up to LM order window
&& prefixPos < subPhraseLength; // up to length
prefixPos++)
{
const Word &word = prevState->GetPrefix().GetWord(prefixPos);
// cerr << "NT --> " << word << endl;
contextFactor.push_back(&word);
}

bool next = false;
if (phrasePos < cur_hypo.GetCurrTargetPhrase().GetSize() - 1)
next = true;

// check if we are dealing with a large sub-phrase
if (next && subPhraseLength > GetNGramOrder() - 1) // TODO: CHECK??
{
// clear up pending ngrams
MakePrefixNgrams(contextFactor, accumulator, terminalsBeforeSubphrase);
contextFactor.clear();
makePrefix = false;
makeSuffix = true;
// cerr << "suffix of subphrase for right context (only if something is following)" << endl;

// copy language model state
delete lmState;
lmState = NewState( prevState->GetRightContext() );

// push its suffix
size_t remainingWords = subPhraseLength - (GetNGramOrder()-1);
if (remainingWords > GetNGramOrder()-1) {
// only what is needed for the history window
remainingWords = GetNGramOrder()-1;
}
for(size_t suffixPos = 0; suffixPos < prevState->GetSuffix().GetSize(); suffixPos++) {
const Word &word = prevState->GetSuffix().GetWord(suffixPos);
// cerr << "NT --> : " << word << endl;
contextFactor.push_back(&word);
}
wordPos += subPhraseLength;
}
}
}
}

if (makePrefix) {
size_t terminals = beforeSubphrase? 1 : terminalsBeforeSubphrase;
MakePrefixNgrams(contextFactor, accumulator, terminals);
}
if (makeSuffix) {
size_t terminals = beforeSubphrase? 1 : terminalsAfterSubphrase;
MakeSuffixNgrams(contextFactor, accumulator, terminals);
}

// remove duplicates
if (makePrefix && makeSuffix && (contextFactor.size() <= GetNGramOrder())) {
string curr_ngram;
for (size_t i = 0; i < contextFactor.size(); ++i) {
curr_ngram.append((*contextFactor[i]).GetString(GetFactorType(), false));
if (i < contextFactor.size()-1)
curr_ngram.append(":");
}
accumulator->MinusEquals(this,curr_ngram,1);
}

ret->Set(lmState);
// cerr << endl;
return ret;
}

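A worked trace of the prefix/suffix bookkeeping above, for order 3 and a completed span whose contextFactor ends up as [the, house] (words illustrative): both n-gram passes fire the same full-span feature, and the duplicate-removal step retracts one copy.

MakePrefixNgrams -> PlusEquals(this, "the:house", 1)
MakeSuffixNgrams -> PlusEquals(this, "the:house", 1)
duplicate removal -> MinusEquals(this, "the:house", 1) // net count: 1

Unigram features are fired directly in the rule loop, one PlusEquals per regular terminal.
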
void TargetNgramFeature::ShiftOrPush(std::vector<const Word*> &contextFactor, const Word &word) const
{
if (contextFactor.size() < GetNGramOrder()) {
contextFactor.push_back(&word);
} else {
// shift
for (size_t currNGramOrder = 0 ; currNGramOrder < GetNGramOrder() - 1 ; currNGramOrder++) {
contextFactor[currNGramOrder] = contextFactor[currNGramOrder + 1];
}
contextFactor[GetNGramOrder() - 1] = &word;
}
}

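ShiftOrPush keeps contextFactor as a sliding window of at most GetNGramOrder() words (words illustrative, order 3):

[the] + house -> [the, house] // below the order: push
[the, house, is] + big -> [house, is, big] // at the order: shift left, overwrite the last slot
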
void TargetNgramFeature::MakePrefixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator, size_t numberOfStartPos) const {
string curr_ngram;
size_t size = contextFactor.size();
for (size_t k = 0; k < numberOfStartPos; ++k) {
size_t max_length = (size < GetNGramOrder())? size: GetNGramOrder();
for (size_t end = 1+k; end < max_length+k; ++end) {
for (size_t i=k; i <= end; ++i) {
if (i > k)
curr_ngram.append(":");
curr_ngram.append((*contextFactor[i]).GetString(GetFactorType(), false));
}
if (curr_ngram != "<s>" && curr_ngram != "</s>") {
// cerr << "p-ngram: " << curr_ngram << endl;
accumulator->PlusEquals(this,curr_ngram,1);
}
curr_ngram.clear();
}
}
}

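For contextFactor = [<s>, the, house], order 3 and numberOfStartPos = 1, the loops above fire "<s>:the" and "<s>:the:house", i.e. every n-gram of length 2 up to the order that starts at the first position; the guard skips bare sentence-boundary symbols.
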
void TargetNgramFeature::MakeSuffixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator, size_t numberOfEndPos) const {
string curr_ngram;
size_t size = contextFactor.size();
for (size_t k = 0; k < numberOfEndPos; ++k) {
size_t min_start = (size > GetNGramOrder())? (size - GetNGramOrder()): 0;
size_t end = size-1;
for (size_t start=min_start-k; start < end-k; ++start) {
for (size_t j=start; j < size-k; ++j){
curr_ngram.append((*contextFactor[j]).GetString(GetFactorType(), false));
if (j < size-k-1)
curr_ngram.append(":");
}
if (curr_ngram != "<s>" && curr_ngram != "</s>") {
// cerr << "s-ngram: " << curr_ngram << endl;
accumulator->PlusEquals(this,curr_ngram,1);
}
curr_ngram.clear();
}
}
}

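The suffix counterpart: for contextFactor = [the, house, </s>], order 3 and numberOfEndPos = 1 this fires "the:house:</s>" and "house:</s>", i.e. every n-gram of length 2 and up that ends at the last position. Note that start = min_start-k is computed in size_t arithmetic, so callers must keep k <= min_start to avoid wrap-around.
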
bool TargetNgramFeature::Load(const std::string &filePath, FactorType factorType, size_t nGramOrder) {
// dummy
cerr << "This method has not been implemented.." << endl;
assert(false);
return false;
}

LMResult TargetNgramFeature::GetValue(const std::vector<const Word*> &contextFactor, State* finalState) const {
// dummy (returned by value; no heap allocation needed)
LMResult result;
cerr << "This method has not been implemented.." << endl;
assert(false);
return result;
}

}

@@ -9,6 +9,10 @@
#include "FFState.h"
#include "Word.h"

#include "LM/SingleFactor.h"
#include "ChartHypothesis.h"
#include "ChartManager.h"

namespace Moses
{

@@ -22,9 +26,168 @@ class TargetNgramState : public FFState {
std::vector<Word> m_words;
};

class TargetNgramChartState : public FFState
{
private:
FFState* m_lmRightContext;

Phrase m_contextPrefix, m_contextSuffix;

size_t m_numTargetTerminals; // This isn't really correct except for the surviving hypothesis

const ChartHypothesis &m_hypo;

/** Construct the prefix string of up to specified size
* \param ret prefix string
* \param size maximum size (typically max lm context window)
*/
size_t CalcPrefix(const ChartHypothesis &hypo, int featureID, Phrase &ret, size_t size) const
{
const TargetPhrase &target = hypo.GetCurrTargetPhrase();
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
target.GetAlignmentInfo().GetNonTermIndexMap();

// loop over the rule that is being applied
for (size_t pos = 0; pos < target.GetSize(); ++pos) {
const Word &word = target.GetWord(pos);

// for non-terminals, retrieve it from underlying hypothesis
if (word.IsNonTerminal()) {
size_t nonTermInd = nonTermIndexMap[pos];
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
size = static_cast<const TargetNgramChartState*>(prevHypo->GetFFState(featureID))->CalcPrefix(*prevHypo, featureID, ret, size);
}
// for words, add word
else {
ret.AddWord(target.GetWord(pos));
size--;
}

// finish when maximum length reached
if (size==0)
break;
}

return size;
}

/** Construct the suffix phrase of up to specified size
* will always be called after the construction of prefix phrase
* \param ret suffix phrase
* \param size maximum size of suffix
*/
size_t CalcSuffix(const ChartHypothesis &hypo, int featureID, Phrase &ret, size_t size) const
{
assert(m_contextPrefix.GetSize() <= m_numTargetTerminals);

// special handling for small hypotheses
// does the prefix match the entire hypothesis string? -> just copy prefix
if (m_contextPrefix.GetSize() == m_numTargetTerminals) {
size_t maxCount = std::min(m_contextPrefix.GetSize(), size);
size_t pos= m_contextPrefix.GetSize() - 1;

for (size_t ind = 0; ind < maxCount; ++ind) {
const Word &word = m_contextPrefix.GetWord(pos);
ret.PrependWord(word);
--pos;
}

size -= maxCount;
return size;
}
// construct suffix analogous to prefix
else {
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
hypo.GetCurrTargetPhrase().GetAlignmentInfo().GetNonTermIndexMap();
for (int pos = (int) hypo.GetCurrTargetPhrase().GetSize() - 1; pos >= 0 ; --pos) {
const Word &word = hypo.GetCurrTargetPhrase().GetWord(pos);

if (word.IsNonTerminal()) {
size_t nonTermInd = nonTermIndexMap[pos];
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
size = static_cast<const TargetNgramChartState*>(prevHypo->GetFFState(featureID))->CalcSuffix(*prevHypo, featureID, ret, size);
}
else {
ret.PrependWord(hypo.GetCurrTargetPhrase().GetWord(pos));
size--;
}

if (size==0)
break;
}

return size;
}
}

public:
TargetNgramChartState(const ChartHypothesis &hypo, int featureID, size_t order)
:m_lmRightContext(NULL)
,m_contextPrefix(Output, order - 1)
,m_contextSuffix(Output, order - 1)
,m_hypo(hypo)
{
m_numTargetTerminals = hypo.GetCurrTargetPhrase().GetNumTerminals();

for (std::vector<const ChartHypothesis*>::const_iterator i = hypo.GetPrevHypos().begin(); i != hypo.GetPrevHypos().end(); ++i) {
// keep count of words (= length of generated string)
m_numTargetTerminals += static_cast<const TargetNgramChartState*>((*i)->GetFFState(featureID))->GetNumTargetTerminals();
}

CalcPrefix(hypo, featureID, m_contextPrefix, order - 1);
CalcSuffix(hypo, featureID, m_contextSuffix, order - 1);
}

~TargetNgramChartState() {
delete m_lmRightContext;
}

void Set(FFState *rightState) {
m_lmRightContext = rightState;
}

FFState* GetRightContext() const {
return m_lmRightContext;
}

size_t GetNumTargetTerminals() const {
return m_numTargetTerminals;
}

const Phrase &GetPrefix() const {
return m_contextPrefix;
}
const Phrase &GetSuffix() const {
return m_contextSuffix;
}

int Compare(const FFState& o) const {
const TargetNgramChartState &other =
dynamic_cast<const TargetNgramChartState &>( o );

// prefix
if (m_hypo.GetCurrSourceRange().GetStartPos() > 0) // not for "<s> ..."
{
int ret = GetPrefix().Compare(other.GetPrefix());
if (ret != 0)
return ret;
}

// suffix
size_t inputSize = m_hypo.GetManager().GetSource().GetSize();
if (m_hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1)// not for "... </s>"
{
int ret = other.GetRightContext()->Compare(*m_lmRightContext);
if (ret != 0)
return ret;
}
return 0;
}
};

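For recombination this state treats two chart hypotheses as equivalent when their (order-1)-word context prefixes match, a check skipped for sentence-initial spans where no n-gram can ever extend further left, and when their right language model contexts match, skipped likewise for sentence-final spans; only those windows can still influence n-gram features fired by later rule applications.
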
/** Sets the features of observed ngrams.
*/
class TargetNgramFeature : public StatefulFeatureFunction {
class TargetNgramFeature : public StatefulFeatureFunction, public LanguageModelPointerState {
public:
TargetNgramFeature(FactorType factorType = 0, size_t n = 3, bool lower_ngrams = true):
StatefulFeatureFunction("dlmn", ScoreProducer::unlimited),
@@ -39,8 +202,8 @@ public:
m_bos.SetFactor(m_factorType,bosFactor);
}


bool Load(const std::string &filePath);
bool Load(const std::string&, Moses::FactorType, size_t);

std::string GetScoreProducerWeightShortName(unsigned) const;
size_t GetNumInputScores() const;
@@ -53,13 +216,15 @@ public:
virtual FFState* Evaluate(const Hypothesis& cur_hypo, const FFState* prev_state,
ScoreComponentCollection* accumulator) const;

virtual FFState* EvaluateChart( const ChartHypothesis& /* cur_hypo */,
int /* featureID */,
ScoreComponentCollection* ) const
{
/* Not implemented */
assert(0);
}
virtual FFState* EvaluateChart(const ChartHypothesis& cur_hypo, int featureID,
ScoreComponentCollection* accumulator) const;

LMResult GetValue(const std::vector<const Word*> &contextFactor, State* finalState = NULL) const;

size_t GetNGramOrder() const {
return m_n;
}

private:
FactorType m_factorType;
Word m_bos;
@@ -71,6 +236,17 @@ private:
float m_sparseProducerWeight;

void appendNgram(const Word& word, bool& skip, std::string& ngram) const;
void ShiftOrPush(std::vector<const Word*> &contextFactor, const Word &word) const;
void MakePrefixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator,
size_t numberOfStartPos = 1) const;
void MakeSuffixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator,
size_t numberOfEndPos = 1) const;

std::vector<FactorType> GetFactorType() const {
std::vector<FactorType> factorType;
factorType.push_back(m_factorType);
return factorType;
}
};

}