Refactoring of lattice mbr code so that it can be used with mosesmt

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@2907 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
bhaddow 2010-02-17 17:25:56 +00:00
parent 7f45cd12f9
commit 11cb44ba5b
7 changed files with 76 additions and 57 deletions

View File

@ -227,27 +227,27 @@ void IOWrapper::Backtrack(const Hypothesis *hypo){
}
}
// Prints an MBR-selected hypothesis, given as an ordered list of factors, as a
// single line: the factors' strings separated by one space, then end-of-line.
//
// NOTE(review): this span carries two signatures and paired cout/out statements.
// The IOWrapper member form writes to cout; the free-function form writes to the
// caller-supplied `out`. Confirm which form is live before relying on this text.
//
//   mbrBestHypo         factors to print, in order
//   (translationId)     left unnamed in the definition, so it is never read
//   reportSegmentation  accepted but not read anywhere in this body
//   reportAllFactors    accepted but not read anywhere in this body
//   out                 destination stream (free-function form only)
void IOWrapper::OutputBestHypo(const std::vector<const Factor*>& mbrBestHypo, long /*translationId*/, bool reportSegmentation, bool reportAllFactors)
void OutputBestHypo(const std::vector<const Factor*>& mbrBestHypo, long /*translationId*/, bool reportSegmentation, bool reportAllFactors, ostream& out)
{
for (size_t i = 0 ; i < mbrBestHypo.size() ; i++)
{
const Factor *factor = mbrBestHypo[i];
// A separator goes before every token except the first, so the line has no
// leading or trailing space.
if (i>0) cout << " ";
cout << factor->GetString();
if (i>0) out << " ";
out << factor->GetString();
}
// Always terminate the line, even for an empty hypothesis; endl also flushes.
cout << endl;
out << endl;
}
// Prints an MBR-selected hypothesis, given as an ordered list of words, as a
// single line. For each word only the factor at the first position of the
// output factor order is printed; tokens are separated by one space and the
// line is ended with endl.
//
// NOTE(review): this span carries two signatures and paired statements. The
// IOWrapper member form reads m_outputFactorOrder and writes to cout; the
// free-function form reads StaticData::Instance().GetOutputFactorOrder() and
// writes to the caller-supplied `out`. Confirm which form is live.
//
//   mbrBestHypo         words to print, in order
//   (translationId)     left unnamed in the definition, so it is never read
//   reportSegmentation  accepted but not read anywhere in this body
//   reportAllFactors    accepted but not read anywhere in this body
//   out                 destination stream (free-function form only)
void IOWrapper::OutputBestHypo(const std::vector<Word>& mbrBestHypo, long /*translationId*/, bool reportSegmentation, bool reportAllFactors)
void OutputBestHypo(const std::vector<Word>& mbrBestHypo, long /*translationId*/, bool reportSegmentation, bool reportAllFactors, ostream& out)
{
for (size_t i = 0 ; i < mbrBestHypo.size() ; i++)
{
// Index [0] is used unchecked: assumes the output factor order is non-empty
// - TODO confirm this is guaranteed by configuration loading.
const Factor *factor = mbrBestHypo[i].GetFactor(m_outputFactorOrder[0]);
// A separator goes before every token except the first.
if (i>0) cout << " ";
cout << *factor;
const Factor *factor = mbrBestHypo[i].GetFactor(StaticData::Instance().GetOutputFactorOrder()[0]);
if (i>0) out << " ";
out << *factor;
}
// Always terminate the line, even for an empty hypothesis; endl also flushes.
cout << endl;
out << endl;
}

View File

@ -35,7 +35,9 @@ POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <fstream>
#include <ostream>
#include <vector>
#include "TypeDef.h"
#include "Sentence.h"
#include "FactorTypeSet.h"
@ -83,9 +85,8 @@ public:
~IOWrapper();
Moses::InputType* GetInput(Moses::InputType *inputType);
void OutputBestHypo(const Moses::Hypothesis *hypo, long translationId, bool reportSegmentation, bool reportAllFactors);
void OutputBestHypo(const std::vector<const Moses::Factor*>& mbrBestHypo, long translationId, bool reportSegmentation, bool reportAllFactors);
void OutputBestHypo(const std::vector<Moses::Word>& mbrBestHypo, long /*translationId*/, bool reportSegmentation, bool reportAllFactors);
void OutputBestHypo(const Moses::Hypothesis *hypo, long translationId, bool reportSegmentation, bool reportAllFactors);
void OutputNBestList(const Moses::TrellisPathList &nBestList, long translationId);
void Backtrack(const Moses::Hypothesis *hypo);
@ -105,3 +106,7 @@ IOWrapper *GetIODevice(const Moses::StaticData &staticData);
bool ReadInput(IOWrapper &ioWrapper, Moses::InputTypeEnum inputType, Moses::InputType*& source);
void OutputSurface(std::ostream &out, const Moses::Hypothesis *hypo, const std::vector<Moses::FactorType> &outputFactorOrder ,bool reportSegmentation, bool reportAllFactors);
void OutputNBest(std::ostream& out, const Moses::TrellisPathList &nBestList, const std::vector<Moses::FactorType>&, long translationId);
// Writes the factors of an MBR-selected hypothesis to `out` as one
// space-separated line. Takes the destination stream as a parameter rather
// than being tied to an IOWrapper instance.
void OutputBestHypo(const std::vector<const Moses::Factor*>& mbrBestHypo, long translationId,
bool reportSegmentation, bool reportAllFactors, std::ostream& out);
// Same as above for a hypothesis given as words; the translation id is
// commented out here because the parameter is unnamed.
void OutputBestHypo(const std::vector<Moses::Word>& mbrBestHypo, long /*translationId*/,
bool reportSegmentation, bool reportAllFactors, std::ostream& out);

View File

@ -492,3 +492,39 @@ vector<Word> calcMBRSol(const TrellisPathList& nBestList, map<Phrase, float>& f
return argmaxTranslation;
}
/**
 * Runs lattice MBR decoding on the search graph held by `manager`.
 *
 * Steps, in order:
 *   1. extract the forward-backward search graph from the manager,
 *   2. prune it with the configured lattice-MBR pruning factor,
 *   3. compute n-gram posteriors over the pruned lattice,
 *   4. build an n-best list of GetMBRSize() entries and pick the MBR solution
 *      from it with calcMBRSol().
 *
 * @param manager  decoder manager whose search graph and n-best list are used
 * @return the word sequence returned by calcMBRSol()
 *
 * The process is terminated with exit(1), after a message on cerr, when
 *   - the lattice itself is requested as the hypothesis set (not implemented), or
 *   - the configured MBR n-best size is zero.
 * Both settings are checked before any lattice work so that a bad
 * configuration fails immediately instead of after pruning and posterior
 * computation.
 */
vector<Word> doLatticeMBR(Manager& manager) {
  const StaticData& staticData = StaticData::Instance();

  // Only the n-best list is supported as the hypothesis set.
  if (staticData.UseLatticeHypSetForLatticeMBR()) {
    cerr << "Using Lattice for Hypothesis set not yet implemented" << endl;
    exit(1);
  }

  // size_t is unsigned, so zero is the only invalid value that can be seen
  // here. The message text is kept identical to the other mbr-size checks.
  const size_t nBestSize = staticData.GetMBRSize();
  if (nBestSize == 0) {
    cerr << "ERROR: negative size for number of MBR candidate translations not allowed (option mbr-size)" << endl;
    exit(1);
  }

  // Out-parameters filled by the manager; `connected` is required by the call
  // but not consulted afterwards.
  std::map<int, bool> connected;
  std::vector<const Hypothesis*> connectedList;
  std::map<const Hypothesis*, std::set<const Hypothesis*> > outgoingHyps;
  std::vector<float> estimatedScores;
  manager.GetForwardBackwardSearchGraph(&connected, &connectedList, &outgoingHyps, &estimatedScores);

  // pruneLatticeFB receives incomingEdges empty; calcNgramPosteriors then reads it.
  std::map<const Hypothesis*, std::vector<Edge> > incomingEdges;
  pruneLatticeFB(connectedList, outgoingHyps, incomingEdges, estimatedScores,
                 staticData.GetLatticeMBRPruningFactor());

  std::map<Phrase, float> ngramPosteriors;
  calcNgramPosteriors(connectedList, incomingEdges, staticData.GetMBRScale(), ngramPosteriors);

  // NOTE(review): the meaning of the trailing `true` is not visible here
  // (presumably "distinct entries only") - confirm against Manager::CalcNBest.
  TrellisPathList nBestList;
  manager.CalcNBest(nBestSize, nBestList, true);
  VERBOSE(2,"size of n-best: " << nBestList.GetSize() << " (" << nBestSize << ")" << endl);
  IFVERBOSE(2) { PrintUserTime("calculated n-best list for MBR decoding"); }

  return calcMBRSol(nBestList, ngramPosteriors, staticData.GetLatticeMBRThetas(),
                    staticData.GetLatticeMBRPrecision(), staticData.GetLatticeMBRPRatio());
}

View File

@ -14,6 +14,7 @@
#include <vector>
#include <set>
#include "Hypothesis.h"
#include "Manager.h"
#include "TrellisPathList.h"
using namespace Moses;
@ -112,3 +113,4 @@ void calcNgramPosteriors(Lattice & connectedHyp, map<const Hypothesis*, vector<E
void GetOutputFactors(const TrellisPath &path, vector <Word> &translation);
void extract_ngrams(const vector<Word >& sentence, map < Phrase, int > & allngrams);
bool ascendingCoverageCmp(const Hypothesis* a, const Hypothesis* b);
// Runs lattice MBR decoding on the search graph held by `manager` and returns
// the selected translation as a sequence of words. The definition calls
// exit(1) when the configuration is unsupported or the MBR n-best size is
// invalid, so this function may not return.
vector<Word> doLatticeMBR(Manager& manager);

View File

@ -168,40 +168,9 @@ int main(int argc, char* argv[])
}
}
else if (staticData.UseLatticeMBR()) {
std::map < int, bool > connected;
std::vector< const Hypothesis *> connectedList;
map<Phrase, float> ngramPosteriors;
std::map < const Hypothesis*, set <const Hypothesis*> > outgoingHyps;
map<const Hypothesis*, vector<Edge> > incomingEdges;
vector< float> estimatedScores;
manager.GetForwardBackwardSearchGraph(&connected, &connectedList, &outgoingHyps, &estimatedScores);
pruneLatticeFB(connectedList, outgoingHyps, incomingEdges, estimatedScores, staticData.GetLatticeMBRPruningFactor());
calcNgramPosteriors(connectedList, incomingEdges, staticData.GetMBRScale(), ngramPosteriors);
vector<Word> mbrBestHypo;
if (!staticData.UseLatticeHypSetForLatticeMBR()) {
size_t nBestSize = staticData.GetMBRSize();
if (nBestSize <= 0)
{
cerr << "ERROR: negative size for number of MBR candidate translations not allowed (option mbr-size)" << endl;
return EXIT_FAILURE;
}
else
{
TrellisPathList nBestList;
manager.CalcNBest(nBestSize, nBestList,true);
VERBOSE(2,"size of n-best: " << nBestList.GetSize() << " (" << nBestSize << ")" << endl);
IFVERBOSE(2) { PrintUserTime("calculated n-best list for MBR decoding"); }
mbrBestHypo = calcMBRSol(nBestList, ngramPosteriors, staticData.GetLatticeMBRThetas(), staticData.GetLatticeMBRPrecision(), staticData.GetLatticeMBRPRatio());
}
}
else {
cerr << "Using Lattice for Hypothesis set not yet implemented" << endl;
return EXIT_FAILURE;
//mbrBestHypo = calcMBRSol(connectedList, ngramPosteriors, staticData.GetLatticeMBRThetas(), staticData.GetLatticeMBRPrecision(), staticData.GetLatticeMBRPRatio());
}
ioWrapper->OutputBestHypo(mbrBestHypo, source->GetTranslationId(), staticData.GetReportSegmentation(),
staticData.GetReportAllFactors());
vector<Word> mbrBestHypo = doLatticeMBR(manager);
OutputBestHypo(mbrBestHypo, source->GetTranslationId(), staticData.GetReportSegmentation(),
staticData.GetReportAllFactors(),cout);
IFVERBOSE(2) { PrintUserTime("finished Lattice MBR decoding"); }
}
// consider top candidate translations to find minimum Bayes risk translation
@ -220,9 +189,9 @@ int main(int argc, char* argv[])
VERBOSE(2,"size of n-best: " << nBestList.GetSize() << " (" << nBestSize << ")" << endl);
IFVERBOSE(2) { PrintUserTime("calculated n-best list for MBR decoding"); }
std::vector<const Factor*> mbrBestHypo = doMBR(nBestList);
ioWrapper->OutputBestHypo(mbrBestHypo, source->GetTranslationId(),
OutputBestHypo(mbrBestHypo, source->GetTranslationId(),
staticData.GetReportSegmentation(),
staticData.GetReportAllFactors());
staticData.GetReportAllFactors(),cout);
IFVERBOSE(2) { PrintUserTime("finished MBR decoding"); }
if (!staticData.GetNBestFilePath().empty()){
//print the all nbest used for MBR (and not the amount passed through the parameter

View File

@ -35,6 +35,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "Hypothesis.h"
#include "IOWrapper.h"
#include "LatticeMBR.h"
#include "Manager.h"
#include "StaticData.h"
#include "ThreadPool.h"
@ -127,29 +128,35 @@ class TranslationTask : public Task {
staticData.GetReportSegmentation(),
staticData.GetReportAllFactors());
IFVERBOSE(1) {
debug << "BEST TRANSLATION: " << *bestHypo << endl;
debug << "BEST TRANSLATION: " << *bestHypo << endl;
}
}
out << endl;
} else {
//MBR decoding
size_t nBestSize = staticData.GetMBRSize();
if (nBestSize <= 0) {
cerr << "ERROR: negative size for number of MBR candidate translations not allowed (option mbr-size)" << endl;
exit(1);
}
if (staticData.UseLatticeMBR()) {
//Lattice MBR decoding
vector<Word> mbrBestHypo = doLatticeMBR(manager);
OutputBestHypo(mbrBestHypo, m_lineNumber, staticData.GetReportSegmentation(),
staticData.GetReportAllFactors(),out);
IFVERBOSE(2) { PrintUserTime("finished Lattice MBR decoding"); }
} else {
//MBR decoding
TrellisPathList nBestList;
manager.CalcNBest(nBestSize, nBestList,true);
VERBOSE(2,"size of n-best: " << nBestList.GetSize() << " (" << nBestSize << ")" << endl);
IFVERBOSE(2) { PrintUserTime("calculated n-best list for MBR decoding"); }
std::vector<const Factor*> mbrBestHypo = doMBR(nBestList);
for (size_t i = 0 ; i < mbrBestHypo.size() ; i++) {
const Factor *factor = mbrBestHypo[i];
if (i>0) out << " ";
out << factor->GetString();
}
out << endl;
OutputBestHypo(mbrBestHypo, m_lineNumber,
staticData.GetReportSegmentation(),
staticData.GetReportAllFactors(),out);
IFVERBOSE(2) { PrintUserTime("finished MBR decoding"); }
}
}
m_outputCollector->Write(m_lineNumber,out.str(),debug.str());

View File

@ -8,7 +8,7 @@ moses_SOURCES = Main.cpp mbr.cpp IOWrapper.cpp TranslationAnalysis.cpp LatticeMB
moses_LDADD = -L$(top_srcdir)/moses/src -lmoses $(BOOST_LDFLAGS) $(BOOST_THREAD_LIB)
moses_DEPENDENCIES = $(top_srcdir)/moses/src/libmoses.a
mosesmt_SOURCES = MainMT.cpp mbr.cpp IOWrapper.cpp TranslationAnalysis.cpp ThreadPool.cpp
mosesmt_SOURCES = MainMT.cpp mbr.cpp IOWrapper.cpp TranslationAnalysis.cpp ThreadPool.cpp LatticeMBR.cpp
mosesmt_LDADD = -L$(top_srcdir)/moses/src $(BOOST_LDFLAGS) -lmoses $(BOOST_THREAD_LIB)
mosesmt_DEPENDENCIES = $(top_srcdir)/moses/src/libmoses.a