Implementation of lattice sampling (Chatterjee and Cancedda, EMNLP 2010)

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4295 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
bhaddow 2011-10-04 15:46:24 +00:00
parent 23d9a9b55e
commit 84d73700af
11 changed files with 268 additions and 33 deletions

View File

@ -237,6 +237,7 @@ void pruneLatticeFB(Lattice & connectedHyp, map < const Hypothesis*, set <const
const ArcList *arcList = succHyp->GetArcList();
if (arcList != NULL) {
ArcList::const_iterator iterArcList;
//QUESTION: What happens if there's more than one loserPrevHypo?
for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
const Hypothesis *loserHypo = *iterArcList;
const Hypothesis* loserPrevHypo = loserHypo->GetPrevHypo();

View File

@ -19,17 +19,7 @@
using namespace Moses;
/**
 * Adds two values that are held in log space.
 *
 * \param log_a log of the first value
 * \param log_b log of the second value
 * \return log(exp(log_a) + exp(log_b)), computed relative to the larger
 *         operand so that exp() is only ever applied to a non-positive number.
 */
template<class T>
T log_sum (T log_a, T log_b)
{
  // Equal operands need no exp() at all. Handling them up front also covers
  // two equal infinities (e.g. both terms have zero probability), where the
  // difference below would be inf - inf = NaN.
  if (log_a == log_b) {
    return log_a + log ( T(2) );
  }
  const T larger  = (log_a < log_b) ? log_b : log_a;
  const T smaller = (log_a < log_b) ? log_a : log_b;
  return larger + log ( 1 + exp ( smaller - larger ));
}
class Edge;

View File

@ -72,11 +72,13 @@ public:
/**
 * Builds the task for one input sentence. The constructor only stores its
 * arguments in the corresponding members; nothing is dereferenced here.
 *
 * \param lineNumber number of the sentence being translated
 * \param source the input to translate
 * \param outputCollector collector for translations
 * \param nbestCollector collector for n-best lists
 * \param latticeSamplesCollector collector for lattice samples
 * \param wordGraphCollector collector for word graphs
 * \param searchGraphCollector collector for search graphs
 * \param detailedTranslationCollector collector for detailed translation reports
 * \param alignmentInfoCollector collector for alignment information
 */
TranslationTask(size_t lineNumber,
                InputType* source,
                OutputCollector* outputCollector,
                OutputCollector* nbestCollector,
                OutputCollector* latticeSamplesCollector,
                OutputCollector* wordGraphCollector,
                OutputCollector* searchGraphCollector,
                OutputCollector* detailedTranslationCollector,
                OutputCollector* alignmentInfoCollector )
  : m_source(source)
  , m_lineNumber(lineNumber)
  , m_outputCollector(outputCollector)
  , m_nbestCollector(nbestCollector)
  , m_latticeSamplesCollector(latticeSamplesCollector)
  , m_wordGraphCollector(wordGraphCollector)
  , m_searchGraphCollector(searchGraphCollector)
  , m_detailedTranslationCollector(detailedTranslationCollector)
  , m_alignmentInfoCollector(alignmentInfoCollector)
{
}
@ -240,6 +242,15 @@ public:
m_nbestCollector->Write(m_lineNumber, out.str());
}
//lattice samples
if (m_latticeSamplesCollector) {
TrellisPathList latticeSamples;
ostringstream out;
manager.CalcLatticeSamples(staticData.GetLatticeSamplesSize(), latticeSamples);
OutputNBest(out,latticeSamples, staticData.GetOutputFactorOrder(), manager.GetTranslationSystem(), m_lineNumber);
m_latticeSamplesCollector->Write(m_lineNumber, out.str());
}
// detailed translation reporting
if (m_detailedTranslationCollector) {
ostringstream out;
@ -264,6 +275,7 @@ private:
size_t m_lineNumber;
OutputCollector* m_outputCollector;
OutputCollector* m_nbestCollector;
OutputCollector* m_latticeSamplesCollector;
OutputCollector* m_wordGraphCollector;
OutputCollector* m_searchGraphCollector;
OutputCollector* m_detailedTranslationCollector;
@ -352,6 +364,9 @@ int main(int argc, char** argv)
const StaticData& staticData = StaticData::Instance();
//initialise random numbers
srand(time(NULL));
// set up read/writing class
IOWrapper* ioWrapper = GetIODevice(staticData);
if (!ioWrapper) {
@ -380,21 +395,43 @@ int main(int argc, char** argv)
// because multithreading may return sentences in shuffled order
auto_ptr<OutputCollector> outputCollector; // for translations
auto_ptr<OutputCollector> nbestCollector; // for n-best lists
auto_ptr<OutputCollector> latticeSamplesCollector; //for lattice samples
auto_ptr<ofstream> nbestOut;
auto_ptr<ofstream> latticeSamplesOut;
size_t nbestSize = staticData.GetNBestSize();
string nbestFile = staticData.GetNBestFilePath();
bool output1best = true;
if (nbestSize) {
if (nbestFile == "-" || nbestFile == "/dev/stdout") {
// nbest to stdout, no 1-best
nbestCollector.reset(new OutputCollector());
output1best = false;
} else {
// nbest to file, 1-best to stdout
nbestOut.reset(new ofstream(nbestFile.c_str()));
assert(nbestOut->good());
if (!nbestOut->good()) {
TRACE_ERR("ERROR: Failed to open " << nbestFile << " for nbest lists" << endl);
exit(1);
}
nbestCollector.reset(new OutputCollector(nbestOut.get()));
outputCollector.reset(new OutputCollector());
}
} else {
}
size_t latticeSamplesSize = staticData.GetLatticeSamplesSize();
string latticeSamplesFile = staticData.GetLatticeSamplesFilePath();
if (latticeSamplesSize) {
if (latticeSamplesFile == "-" || latticeSamplesFile == "/dev/stdout") {
latticeSamplesCollector.reset(new OutputCollector());
output1best = false;
} else {
latticeSamplesOut.reset(new ofstream(latticeSamplesFile.c_str()));
if (!latticeSamplesOut->good()) {
TRACE_ERR("ERROR: Failed to open " << latticeSamplesFile << " for lattice samples" << endl);
exit(1);
}
latticeSamplesCollector.reset(new OutputCollector(latticeSamplesOut.get()));
}
}
if (output1best) {
outputCollector.reset(new OutputCollector());
}
@ -437,7 +474,9 @@ int main(int argc, char** argv)
// set up task of translating one sentence
TranslationTask* task =
new TranslationTask(lineCount,source, outputCollector.get(),
nbestCollector.get(), wordGraphCollector.get(),
nbestCollector.get(),
latticeSamplesCollector.get(),
wordGraphCollector.get(),
searchGraphCollector.get(),
detailedTranslationCollector.get(),
alignmentInfoCollector.get() );

View File

@ -25,6 +25,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include <ext/hash_set>
#endif
#include <algorithm>
#include <limits>
#include <cmath>
#include "Manager.h"
@ -257,6 +258,152 @@ void Manager::CalcNBest(size_t count, TrellisPathList &ret,bool onlyDistinct) co
}
}
struct SGNReverseCompare {
bool operator() (const SearchGraphNode& s1, const SearchGraphNode& s2) const {
return s1.hypo->GetId() > s2.hypo->GetId();
}
};
/**
 * Implements lattice sampling, as in Chatterjee & Cancedda, EMNLP 2010.
 *
 * Works on the search graph returned by GetSearchGraph():
 *  1. collects, for every hypothesis, its outgoing hypotheses and the score
 *     of each connecting edge;
 *  2. computes sigma for every hypothesis (see below) in reverse topological
 *     order;
 *  3. draws paths from the start hypothesis (id 0), choosing each successor
 *     with probability proportional to exp(edge score + sigma(successor)).
 *
 * \param count number of paths to draw
 * \param ret   receives one newly allocated TrellisPath per drawn path via
 *              ret.Add() (presumably ret takes ownership -- confirm against
 *              TrellisPathList). Left untouched if the search graph is empty.
 **/
void Manager::CalcLatticeSamples(size_t count, TrellisPathList &ret) const {

  vector<SearchGraphNode> searchGraph;
  GetSearchGraph(searchGraph);

  //Nothing to sample from. This also keeps searchGraph.back() below well defined.
  if (searchGraph.empty()) {
    return;
  }

  //Calculation of the sigmas of each hypothesis and edge. In C&C notation this is
  //the "log of the cumulative unnormalized probability of all the paths in the
  // lattice for the hypothesis to a final node"
  typedef pair<int, int> Edge;
  map<const Hypothesis*, float> sigmas;
  map<Edge, float> edgeScores;
  map<const Hypothesis*, set<const Hypothesis*> > outgoingHyps;
  map<int,const Hypothesis*> idToHyp;
  map<int,float> fscores;

  //Iterating through the hypos in reverse order of id gives a reverse
  //topological order. We rely on the fact that hypo ids are given out
  //sequentially, as the search proceeds.
  //NB: Could just sort by stack.
  sort(searchGraph.begin(), searchGraph.end(), SGNReverseCompare());

  //first task is to fill in the outgoing hypos and edge scores.
  for (vector<SearchGraphNode>::const_iterator i = searchGraph.begin();
       i != searchGraph.end(); ++i) {
    const Hypothesis* hypo = i->hypo;
    idToHyp[hypo->GetId()] = hypo;
    fscores[hypo->GetId()] = i->fscore;
    if (hypo->GetId()) {
      //back to current
      const Hypothesis* prevHypo = i->hypo->GetPrevHypo();
      outgoingHyps[prevHypo].insert(hypo);
      edgeScores[Edge(prevHypo->GetId(),hypo->GetId())] =
        hypo->GetScore() - prevHypo->GetScore();
    }
    //forward from current
    if (i->forward >= 0) {
      //The forward hypo has a higher id, so it was already visited above.
      map<int,const Hypothesis*>::const_iterator idToHypIter = idToHyp.find(i->forward);
      assert(idToHypIter != idToHyp.end());
      const Hypothesis* nextHypo = idToHypIter->second;
      outgoingHyps[hypo].insert(nextHypo);
      map<int,float>::const_iterator fscoreIter = fscores.find(nextHypo->GetId());
      assert(fscoreIter != fscores.end());
      edgeScores[Edge(hypo->GetId(),nextHypo->GetId())] =
        i->fscore - fscoreIter->second;
    }
  }

  //then run through again to calculate sigmas
  for (vector<SearchGraphNode>::const_iterator i = searchGraph.begin();
       i != searchGraph.end(); ++i) {

    if (i->forward == -1) {
      sigmas[i->hypo] = 0;
    } else {
      map<const Hypothesis*, set<const Hypothesis*> >::const_iterator outIter =
        outgoingHyps.find(i->hypo);
      assert(outIter != outgoingHyps.end());
      float sigma = 0;
      //0 is a legitimate log score, so "no term accumulated yet" is tracked
      //with an explicit flag rather than by comparing sigma with 0.
      bool haveTerm = false;
      for (set<const Hypothesis*>::const_iterator j = outIter->second.begin();
           j != outIter->second.end(); ++j) {
        map<const Hypothesis*, float>::const_iterator succIter = sigmas.find(*j);
        assert(succIter != sigmas.end());
        map<Edge,float>::const_iterator edgeScoreIter =
          edgeScores.find(Edge(i->hypo->GetId(),(*j)->GetId()));
        assert(edgeScoreIter != edgeScores.end());
        float term = edgeScoreIter->second + succIter->second; // Add sigma(*j)
        if (!haveTerm) {
          sigma = term;
          haveTerm = true;
        } else {
          sigma = log_sum(sigma,term);
        }
      }
      sigmas[i->hypo] = sigma;
    }
  }

  //The actual sampling!
  //After the descending-id sort the start hypothesis (id 0) is last.
  const Hypothesis* startHypo = searchGraph.back().hypo;
  assert(startHypo->GetId() == 0);
  for (size_t i = 0; i < count; ++i) {
    vector<const Hypothesis*> path;
    path.push_back(startHypo);
    while(1) {
      map<const Hypothesis*, set<const Hypothesis*> >::const_iterator outIter =
        outgoingHyps.find(path.back());
      if (outIter == outgoingHyps.end() || !outIter->second.size()) {
        //end of the path
        break;
      }
      //score the possibles
      vector<const Hypothesis*> candidates;
      vector<float> candidateScores;
      float scoreTotal = 0;
      for (set<const Hypothesis*>::const_iterator j = outIter->second.begin();
           j != outIter->second.end(); ++j) {
        candidates.push_back(*j);
        assert(sigmas.find(*j) != sigmas.end());
        Edge edge(path.back()->GetId(),(*j)->GetId());
        assert(edgeScores.find(edge) != edgeScores.end());
        candidateScores.push_back(sigmas[*j] + edgeScores[edge]);
        //The first candidate seeds the total; a running total of exactly 0
        //must not be mistaken for "empty".
        if (candidateScores.size() == 1) {
          scoreTotal = candidateScores.back();
        } else {
          scoreTotal = log_sum(candidateScores.back(), scoreTotal);
        }
      }

      //normalise: turn the scores into log probabilities
      for (size_t k = 0; k < candidateScores.size(); ++k) {
        candidateScores[k] -= scoreTotal;
      }

      //draw the sample: take the first candidate whose cumulative log
      //probability reaches the log of a uniform draw from [0,1]
      //NOTE(review): rand() is not guaranteed to be thread-safe; confirm this
      //is acceptable if this method runs on several threads at once.
      float random = log(static_cast<float>(rand())/RAND_MAX);
      size_t position = 1;
      float sum = candidateScores[0];
      for (; position < candidateScores.size() && sum < random; ++position) {
        sum = log_sum(sum,candidateScores[position]);
      }
      const Hypothesis* chosen = candidates[position-1];
      path.push_back(chosen);
    }

    //Convert the hypos to TrellisPath
    ret.Add(new TrellisPath(path));
  }

}

View File

@ -130,6 +130,7 @@ public:
const Hypothesis *GetBestHypothesis() const;
const Hypothesis *GetActualBestHypothesis() const;
void CalcNBest(size_t count, TrellisPathList &ret,bool onlyDistinct=0) const;
void CalcLatticeSamples(size_t count, TrellisPathList &ret) const;
void PrintAllDerivations(long translationId, std::ostream& outputStream) const;
void printDivergentHypothesis(long translationId, const Hypothesis* hypo, const std::vector <const TargetPhrase*> & remainingPhrases, float remainingScore , std::ostream& outputStream) const;
void printThisHypothesis(long translationId, const Hypothesis* hypo, const std::vector <const TargetPhrase* > & remainingPhrases, float remainingScore , std::ostream& outputStream) const;

View File

@ -62,6 +62,7 @@ Parameter::Parameter()
AddParam("max-trans-opt-per-coverage", "maximum number of translation options per input span (after applying mapping steps)");
AddParam("max-phrase-length", "maximum phrase length (default 20)");
AddParam("n-best-list", "file and size of n-best-list to be generated; specify - as the file in order to write to STDOUT");
AddParam("lattice-samples", "generate samples from lattice, in same format as nbest list. Uses the file and size arguments, as in n-best-list");
AddParam("n-best-factor", "factor to compute the maximum number of contenders (=factor*nbest-size). value 0 means infinity, i.e. no threshold. default is 0");
AddParam("print-all-derivations", "to print all derivations in search graph");
AddParam("output-factors", "list of factors in the output");

View File

@ -81,8 +81,8 @@ StaticData::StaticData()
,m_numInputScores(0)
,m_detailedTranslationReportingFilePath()
,m_onlyDistinctNBest(false)
,m_lmEnableOOVFeature(false)
,m_factorDelimiter("|") // default delimiter between factors
,m_lmEnableOOVFeature(false)
,m_isAlwaysCreateDirectTranslationOption(false)
{
m_maxFactorIdx[0] = 0; // source side
@ -169,7 +169,7 @@ bool StaticData::LoadData(Parameter *parameter)
m_nBestSize = Scan<size_t>( m_parameter->GetParam("n-best-list")[1] );
m_onlyDistinctNBest=(m_parameter->GetParam("n-best-list").size()>2 && m_parameter->GetParam("n-best-list")[2]=="distinct");
} else if (m_parameter->GetParam("n-best-list").size() == 1) {
UserMessage::Add(string("ERROR: wrong format for switch -n-best-list file size"));
UserMessage::Add(string("wrong format for switch -n-best-list file size"));
return false;
} else {
m_nBestSize = 0;
@ -180,6 +180,17 @@ bool StaticData::LoadData(Parameter *parameter)
m_nBestFactor = 20;
}
//lattice samples
if (m_parameter->GetParam("lattice-samples").size() ==2 ) {
m_latticeSamplesFilePath = m_parameter->GetParam("lattice-samples")[0];
m_latticeSamplesSize = Scan<size_t>(m_parameter->GetParam("lattice-samples")[1]);
} else if (m_parameter->GetParam("lattice-samples").size() != 0 ) {
UserMessage::Add(string("wrong format for switch -lattice-samples file size"));
return false;
} else {
m_latticeSamplesSize = 0;
}
// word graph
if (m_parameter->GetParam("output-word-graph").size() == 2)
m_outputWordGraph = true;

View File

@ -112,6 +112,7 @@ protected:
m_maxHypoStackSize //! hypothesis-stack size that triggers pruning
, m_minHypoStackDiversity //! minimum number of hypothesis in stack for each source word coverage
, m_nBestSize
, m_latticeSamplesSize
, m_nBestFactor
, m_maxNoTransOptPerCoverage
, m_maxNoPartTransOpt
@ -121,7 +122,7 @@ protected:
std::string
m_constraintFileName;
std::string m_nBestFilePath;
std::string m_nBestFilePath, m_latticeSamplesFilePath;
bool m_fLMsLoaded, m_labeledNBestList,m_nBestIncludesAlignment;
bool m_dropUnknown; //! false = treat unknown words as unknowns, and translate them as themselves; true = drop (ignore) them
bool m_wordDeletionEnabled;
@ -407,12 +408,20 @@ public:
return m_nBestFilePath;
}
bool IsNBestEnabled() const {
return (!m_nBestFilePath.empty()) || m_mbr || m_useLatticeMBR || m_outputSearchGraph || m_useConsensusDecoding
return (!m_nBestFilePath.empty()) || m_mbr || m_useLatticeMBR || m_outputSearchGraph || m_useConsensusDecoding || !m_latticeSamplesFilePath.empty()
#ifdef HAVE_PROTOBUF
|| m_outputSearchGraphPB
#endif
;
}
size_t GetLatticeSamplesSize() const {
return m_latticeSamplesSize;
}
const std::string& GetLatticeSamplesFilePath() const {
return m_latticeSamplesFilePath;
}
size_t GetNBestFactor() const {
return m_nBestFactor;
}

View File

@ -41,6 +41,25 @@ TrellisPath::TrellisPath(const Hypothesis *hypo)
}
}
void TrellisPath::InitScore() {
m_totalScore = m_path[0]->GetWinningHypo()->GetTotalScore();
m_scoreBreakdown= m_path[0]->GetWinningHypo()->GetScoreBreakdown();
//calc score
size_t sizePath = m_path.size();
for (size_t pos = 0 ; pos < sizePath ; pos++) {
const Hypothesis *hypo = m_path[pos];
const Hypothesis *winningHypo = hypo->GetWinningHypo();
if (hypo != winningHypo) {
m_totalScore = m_totalScore - winningHypo->GetTotalScore() + hypo->GetTotalScore();
m_scoreBreakdown.MinusEquals(winningHypo->GetScoreBreakdown());
m_scoreBreakdown.PlusEquals(hypo->GetScoreBreakdown());
}
}
}
TrellisPath::TrellisPath(const TrellisPath &copy, size_t edgeIndex, const Hypothesis *arc)
:m_prevEdgeChanged(edgeIndex)
{
@ -60,22 +79,20 @@ TrellisPath::TrellisPath(const TrellisPath &copy, size_t edgeIndex, const Hypoth
prevHypo = prevHypo->GetPrevHypo();
}
// Calc score
m_totalScore = m_path[0]->GetWinningHypo()->GetTotalScore();
m_scoreBreakdown= m_path[0]->GetWinningHypo()->GetScoreBreakdown();
size_t sizePath = m_path.size();
for (size_t pos = 0 ; pos < sizePath ; pos++) {
const Hypothesis *hypo = m_path[pos];
const Hypothesis *winningHypo = hypo->GetWinningHypo();
if (hypo != winningHypo) {
m_totalScore = m_totalScore - winningHypo->GetTotalScore() + hypo->GetTotalScore();
m_scoreBreakdown.MinusEquals(winningHypo->GetScoreBreakdown());
m_scoreBreakdown.PlusEquals(hypo->GetScoreBreakdown());
}
}
InitScore();
}
/**
 * Builds a path from an explicit list of hypotheses.
 * m_path stores the hypotheses in the reverse of the order given in edges,
 * and the scores are then derived from it by InitScore().
 */
TrellisPath::TrellisPath(const vector<const Hypothesis*> edges)
  :m_prevEdgeChanged(NOT_FOUND)
{
  // fill m_path back-to-front from edges
  m_path.assign(edges.rbegin(), edges.rend());
  InitScore();
}
void TrellisPath::CreateDeviantPaths(TrellisPathCollection &pathColl) const
{
const size_t sizePath = m_path.size();

View File

@ -41,6 +41,7 @@ class TrellisPathList;
class TrellisPath
{
friend std::ostream& operator<<(std::ostream&, const TrellisPath&);
friend class Manager;
protected:
std::vector<const Hypothesis *> m_path; //< list of hypotheses/arcs
@ -51,7 +52,13 @@ protected:
ScoreComponentCollection m_scoreBreakdown;
float m_totalScore;
//Used by Manager::CalcLatticeSamples()
TrellisPath(const std::vector<const Hypothesis*> edges);
void InitScore();
public:
TrellisPath(); // not implemented
//! create path OF pure hypo

View File

@ -346,6 +346,18 @@ double GetUserTime();
// dump SGML parser for <seg> tags
std::map<std::string, std::string> ProcessAndStripSGML(std::string &line);
/**
 * Adds two values that are held in log space.
 *
 * \param log_a log of the first value
 * \param log_b log of the second value
 * \return log(exp(log_a) + exp(log_b)), computed relative to the larger
 *         operand so that exp() is only ever applied to a non-positive number.
 */
template<class T>
T log_sum (T log_a, T log_b)
{
  // Equal operands need no exp() at all. Handling them up front also covers
  // two equal infinities (e.g. both terms have zero probability), where the
  // difference below would be inf - inf = NaN.
  if (log_a == log_b) {
    return log_a + log ( T(2) );
  }
  const T larger  = (log_a < log_b) ? log_b : log_a;
  const T smaller = (log_a < log_b) ? log_a : log_b;
  return larger + log ( 1 + exp ( smaller - larger ));
}
}
#endif