mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-26 21:42:19 +03:00
Refactoring of lattice mbr code so that it can be used with mosesmt
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@2907 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
parent
7f45cd12f9
commit
11cb44ba5b
@ -227,27 +227,27 @@ void IOWrapper::Backtrack(const Hypothesis *hypo){
|
||||
}
|
||||
}
|
||||
|
||||
void IOWrapper::OutputBestHypo(const std::vector<const Factor*>& mbrBestHypo, long /*translationId*/, bool reportSegmentation, bool reportAllFactors)
|
||||
void OutputBestHypo(const std::vector<const Factor*>& mbrBestHypo, long /*translationId*/, bool reportSegmentation, bool reportAllFactors, ostream& out)
|
||||
{
|
||||
for (size_t i = 0 ; i < mbrBestHypo.size() ; i++)
|
||||
{
|
||||
const Factor *factor = mbrBestHypo[i];
|
||||
if (i>0) cout << " ";
|
||||
cout << factor->GetString();
|
||||
if (i>0) out << " ";
|
||||
out << factor->GetString();
|
||||
}
|
||||
cout << endl;
|
||||
out << endl;
|
||||
}
|
||||
|
||||
void IOWrapper::OutputBestHypo(const std::vector<Word>& mbrBestHypo, long /*translationId*/, bool reportSegmentation, bool reportAllFactors)
|
||||
void OutputBestHypo(const std::vector<Word>& mbrBestHypo, long /*translationId*/, bool reportSegmentation, bool reportAllFactors, ostream& out)
|
||||
{
|
||||
|
||||
for (size_t i = 0 ; i < mbrBestHypo.size() ; i++)
|
||||
{
|
||||
const Factor *factor = mbrBestHypo[i].GetFactor(m_outputFactorOrder[0]);
|
||||
if (i>0) cout << " ";
|
||||
cout << *factor;
|
||||
const Factor *factor = mbrBestHypo[i].GetFactor(StaticData::Instance().GetOutputFactorOrder()[0]);
|
||||
if (i>0) out << " ";
|
||||
out << *factor;
|
||||
}
|
||||
cout << endl;
|
||||
out << endl;
|
||||
}
|
||||
|
||||
|
||||
|
@ -35,7 +35,9 @@ POSSIBILITY OF SUCH DAMAGE.
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include "TypeDef.h"
|
||||
#include "Sentence.h"
|
||||
#include "FactorTypeSet.h"
|
||||
@ -83,9 +85,8 @@ public:
|
||||
~IOWrapper();
|
||||
|
||||
Moses::InputType* GetInput(Moses::InputType *inputType);
|
||||
void OutputBestHypo(const Moses::Hypothesis *hypo, long translationId, bool reportSegmentation, bool reportAllFactors);
|
||||
void OutputBestHypo(const std::vector<const Moses::Factor*>& mbrBestHypo, long translationId, bool reportSegmentation, bool reportAllFactors);
|
||||
void OutputBestHypo(const std::vector<Moses::Word>& mbrBestHypo, long /*translationId*/, bool reportSegmentation, bool reportAllFactors);
|
||||
|
||||
void OutputBestHypo(const Moses::Hypothesis *hypo, long translationId, bool reportSegmentation, bool reportAllFactors);
|
||||
void OutputNBestList(const Moses::TrellisPathList &nBestList, long translationId);
|
||||
void Backtrack(const Moses::Hypothesis *hypo);
|
||||
|
||||
@ -105,3 +106,7 @@ IOWrapper *GetIODevice(const Moses::StaticData &staticData);
|
||||
bool ReadInput(IOWrapper &ioWrapper, Moses::InputTypeEnum inputType, Moses::InputType*& source);
|
||||
void OutputSurface(std::ostream &out, const Moses::Hypothesis *hypo, const std::vector<Moses::FactorType> &outputFactorOrder ,bool reportSegmentation, bool reportAllFactors);
|
||||
void OutputNBest(std::ostream& out, const Moses::TrellisPathList &nBestList, const std::vector<Moses::FactorType>&, long translationId);
|
||||
void OutputBestHypo(const std::vector<const Moses::Factor*>& mbrBestHypo, long translationId,
|
||||
bool reportSegmentation, bool reportAllFactors, std::ostream& out);
|
||||
void OutputBestHypo(const std::vector<Moses::Word>& mbrBestHypo, long /*translationId*/,
|
||||
bool reportSegmentation, bool reportAllFactors, std::ostream& out);
|
||||
|
@ -492,3 +492,39 @@ vector<Word> calcMBRSol(const TrellisPathList& nBestList, map<Phrase, float>& f
|
||||
return argmaxTranslation;
|
||||
}
|
||||
|
||||
/**
 * Lattice MBR decoding entry point shared by the decoder front ends.
 *
 * Pulls the forward/backward search graph out of `manager`, prunes it,
 * computes n-gram posteriors over the pruned graph and then asks
 * calcMBRSol() to choose a translation from an n-best list of
 * GetMBRSize() candidates.
 *
 * @param manager  decoder manager that has already searched the sentence
 * @return         the word sequence returned by calcMBRSol()
 *
 * The process is terminated with exit(1) when the lattice itself is
 * requested as the hypothesis set (not implemented) or when the configured
 * n-best size is zero.
 */
vector<Word> doLatticeMBR(Manager& manager) {
  const StaticData& config = StaticData::Instance();

  // Search-graph containers filled in by the manager.
  std::map < int, bool > reachable;
  std::vector< const Hypothesis *> lattice;
  std::map < const Hypothesis*, set <const Hypothesis*> > successors;
  vector< float> scoreEstimates;
  manager.GetForwardBackwardSearchGraph(&reachable, &lattice, &successors, &scoreEstimates);

  // Prune the graph; the incoming-edge map is handed to the posterior step.
  map<const Hypothesis*, vector<Edge> > predecessorEdges;
  pruneLatticeFB(lattice, successors, predecessorEdges, scoreEstimates, config.GetLatticeMBRPruningFactor());

  map<Phrase, float> posteriors;
  calcNgramPosteriors(lattice, predecessorEdges, config.GetMBRScale(), posteriors);

  // Only the n-best list is supported as the hypothesis set. This check
  // deliberately stays after the graph work above, as in the original flow.
  if (staticDataUsesLatticeHypSet: config.UseLatticeHypSetForLatticeMBR()) {
    cerr << "Using Lattice for Hypothesis set not yet implemented" << endl;
    exit(1);
  }

  // nBestSize is unsigned, so zero is the only rejectable value here.
  size_t nBestSize = config.GetMBRSize();
  if (nBestSize == 0) {
    cerr << "ERROR: negative size for number of MBR candidate translations not allowed (option mbr-size)" << endl;
    exit(1);
  }

  TrellisPathList candidates;
  manager.CalcNBest(nBestSize, candidates,true);
  VERBOSE(2,"size of n-best: " << candidates.GetSize() << " (" << nBestSize << ")" << endl);
  IFVERBOSE(2) { PrintUserTime("calculated n-best list for MBR decoding"); }

  return calcMBRSol(candidates, posteriors, config.GetLatticeMBRThetas(),
                    config.GetLatticeMBRPrecision(), config.GetLatticeMBRPRatio());
}
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include "Hypothesis.h"
|
||||
#include "Manager.h"
|
||||
#include "TrellisPathList.h"
|
||||
|
||||
using namespace Moses;
|
||||
@ -112,3 +113,4 @@ void calcNgramPosteriors(Lattice & connectedHyp, map<const Hypothesis*, vector<E
|
||||
void GetOutputFactors(const TrellisPath &path, vector <Word> &translation);
|
||||
void extract_ngrams(const vector<Word >& sentence, map < Phrase, int > & allngrams);
|
||||
bool ascendingCoverageCmp(const Hypothesis* a, const Hypothesis* b);
|
||||
vector<Word> doLatticeMBR(Manager& manager);
|
||||
|
@ -168,40 +168,9 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
else if (staticData.UseLatticeMBR()) {
|
||||
std::map < int, bool > connected;
|
||||
std::vector< const Hypothesis *> connectedList;
|
||||
map<Phrase, float> ngramPosteriors;
|
||||
std::map < const Hypothesis*, set <const Hypothesis*> > outgoingHyps;
|
||||
map<const Hypothesis*, vector<Edge> > incomingEdges;
|
||||
vector< float> estimatedScores;
|
||||
manager.GetForwardBackwardSearchGraph(&connected, &connectedList, &outgoingHyps, &estimatedScores);
|
||||
pruneLatticeFB(connectedList, outgoingHyps, incomingEdges, estimatedScores, staticData.GetLatticeMBRPruningFactor());
|
||||
calcNgramPosteriors(connectedList, incomingEdges, staticData.GetMBRScale(), ngramPosteriors);
|
||||
vector<Word> mbrBestHypo;
|
||||
if (!staticData.UseLatticeHypSetForLatticeMBR()) {
|
||||
size_t nBestSize = staticData.GetMBRSize();
|
||||
if (nBestSize <= 0)
|
||||
{
|
||||
cerr << "ERROR: negative size for number of MBR candidate translations not allowed (option mbr-size)" << endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
else
|
||||
{
|
||||
TrellisPathList nBestList;
|
||||
manager.CalcNBest(nBestSize, nBestList,true);
|
||||
VERBOSE(2,"size of n-best: " << nBestList.GetSize() << " (" << nBestSize << ")" << endl);
|
||||
IFVERBOSE(2) { PrintUserTime("calculated n-best list for MBR decoding"); }
|
||||
mbrBestHypo = calcMBRSol(nBestList, ngramPosteriors, staticData.GetLatticeMBRThetas(), staticData.GetLatticeMBRPrecision(), staticData.GetLatticeMBRPRatio());
|
||||
}
|
||||
}
|
||||
else {
|
||||
cerr << "Using Lattice for Hypothesis set not yet implemented" << endl;
|
||||
return EXIT_FAILURE;
|
||||
//mbrBestHypo = calcMBRSol(connectedList, ngramPosteriors, staticData.GetLatticeMBRThetas(), staticData.GetLatticeMBRPrecision(), staticData.GetLatticeMBRPRatio());
|
||||
}
|
||||
|
||||
ioWrapper->OutputBestHypo(mbrBestHypo, source->GetTranslationId(), staticData.GetReportSegmentation(),
|
||||
staticData.GetReportAllFactors());
|
||||
vector<Word> mbrBestHypo = doLatticeMBR(manager);
|
||||
OutputBestHypo(mbrBestHypo, source->GetTranslationId(), staticData.GetReportSegmentation(),
|
||||
staticData.GetReportAllFactors(),cout);
|
||||
IFVERBOSE(2) { PrintUserTime("finished Lattice MBR decoding"); }
|
||||
}
|
||||
// consider top candidate translations to find minimum Bayes risk translation
|
||||
@ -220,9 +189,9 @@ int main(int argc, char* argv[])
|
||||
VERBOSE(2,"size of n-best: " << nBestList.GetSize() << " (" << nBestSize << ")" << endl);
|
||||
IFVERBOSE(2) { PrintUserTime("calculated n-best list for MBR decoding"); }
|
||||
std::vector<const Factor*> mbrBestHypo = doMBR(nBestList);
|
||||
ioWrapper->OutputBestHypo(mbrBestHypo, source->GetTranslationId(),
|
||||
OutputBestHypo(mbrBestHypo, source->GetTranslationId(),
|
||||
staticData.GetReportSegmentation(),
|
||||
staticData.GetReportAllFactors());
|
||||
staticData.GetReportAllFactors(),cout);
|
||||
IFVERBOSE(2) { PrintUserTime("finished MBR decoding"); }
|
||||
if (!staticData.GetNBestFilePath().empty()){
|
||||
//print the all nbest used for MBR (and not the amount passed through the parameter
|
||||
|
@ -35,6 +35,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
#include "Hypothesis.h"
|
||||
#include "IOWrapper.h"
|
||||
#include "LatticeMBR.h"
|
||||
#include "Manager.h"
|
||||
#include "StaticData.h"
|
||||
#include "ThreadPool.h"
|
||||
@ -127,29 +128,35 @@ class TranslationTask : public Task {
|
||||
staticData.GetReportSegmentation(),
|
||||
staticData.GetReportAllFactors());
|
||||
IFVERBOSE(1) {
|
||||
debug << "BEST TRANSLATION: " << *bestHypo << endl;
|
||||
debug << "BEST TRANSLATION: " << *bestHypo << endl;
|
||||
}
|
||||
}
|
||||
out << endl;
|
||||
} else {
|
||||
//MBR decoding
|
||||
size_t nBestSize = staticData.GetMBRSize();
|
||||
if (nBestSize <= 0) {
|
||||
cerr << "ERROR: negative size for number of MBR candidate translations not allowed (option mbr-size)" << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (staticData.UseLatticeMBR()) {
|
||||
//Lattice MBR decoding
|
||||
vector<Word> mbrBestHypo = doLatticeMBR(manager);
|
||||
OutputBestHypo(mbrBestHypo, m_lineNumber, staticData.GetReportSegmentation(),
|
||||
staticData.GetReportAllFactors(),out);
|
||||
IFVERBOSE(2) { PrintUserTime("finished Lattice MBR decoding"); }
|
||||
} else {
|
||||
//MBR decoding
|
||||
TrellisPathList nBestList;
|
||||
manager.CalcNBest(nBestSize, nBestList,true);
|
||||
VERBOSE(2,"size of n-best: " << nBestList.GetSize() << " (" << nBestSize << ")" << endl);
|
||||
IFVERBOSE(2) { PrintUserTime("calculated n-best list for MBR decoding"); }
|
||||
std::vector<const Factor*> mbrBestHypo = doMBR(nBestList);
|
||||
for (size_t i = 0 ; i < mbrBestHypo.size() ; i++) {
|
||||
const Factor *factor = mbrBestHypo[i];
|
||||
if (i>0) out << " ";
|
||||
out << factor->GetString();
|
||||
}
|
||||
out << endl;
|
||||
OutputBestHypo(mbrBestHypo, m_lineNumber,
|
||||
staticData.GetReportSegmentation(),
|
||||
staticData.GetReportAllFactors(),out);
|
||||
IFVERBOSE(2) { PrintUserTime("finished MBR decoding"); }
|
||||
|
||||
}
|
||||
}
|
||||
m_outputCollector->Write(m_lineNumber,out.str(),debug.str());
|
||||
|
@ -8,7 +8,7 @@ moses_SOURCES = Main.cpp mbr.cpp IOWrapper.cpp TranslationAnalysis.cpp LatticeMB
|
||||
moses_LDADD = -L$(top_srcdir)/moses/src -lmoses $(BOOST_LDFLAGS) $(BOOST_THREAD_LIB)
|
||||
moses_DEPENDENCIES = $(top_srcdir)/moses/src/libmoses.a
|
||||
|
||||
mosesmt_SOURCES = MainMT.cpp mbr.cpp IOWrapper.cpp TranslationAnalysis.cpp ThreadPool.cpp
|
||||
mosesmt_SOURCES = MainMT.cpp mbr.cpp IOWrapper.cpp TranslationAnalysis.cpp ThreadPool.cpp LatticeMBR.cpp
|
||||
mosesmt_LDADD = -L$(top_srcdir)/moses/src $(BOOST_LDFLAGS) -lmoses $(BOOST_THREAD_LIB)
|
||||
mosesmt_DEPENDENCIES = $(top_srcdir)/moses/src/libmoses.a
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user