move n-best code for phrase-based from IOWrapper to ChartManager

This commit is contained in:
Hieu Hoang 2014-12-02 17:40:53 +00:00
parent 08c57bce87
commit ba7afba9f6
8 changed files with 308 additions and 68 deletions

View File

@ -1,7 +1,61 @@
#include <vector>
#include "StaticData.h"
#include "BaseManager.h"
#include "moses/FF/StatelessFeatureFunction.h"
#include "moses/FF/StatefulFeatureFunction.h"
using namespace std;
namespace Moses
{
/**
 * Write the scores of every tuneable feature function to a stream.
 *
 * Stateful feature functions are visited first, then stateless ones, each in
 * the order returned by its registry. A stateful feature whose description is
 * "BleuScoreFeature" is skipped, as is any feature that is not tuneable.
 *
 * @param features score collection the per-feature values are read from
 * @param out      stream the formatted scores are appended to
 */
void BaseManager::OutputAllFeatureScores(const Moses::ScoreComponentCollection &features
    , std::ostream &out)
{
  typedef std::vector<const StatefulFeatureFunction*> StatefulList;
  typedef std::vector<const StatelessFeatureFunction*> StatelessList;

  // Label of the producer printed most recently. It is threaded through every
  // OutputFeatureScores() call so that both passes share the same state.
  std::string lastName;

  const StatefulList &stateful = StatefulFeatureFunction::GetStatefulFeatureFunctions();
  for (StatefulList::const_iterator it = stateful.begin(); it != stateful.end(); ++it) {
    const StatefulFeatureFunction *ff = *it;
    if (ff->GetScoreProducerDescription() == "BleuScoreFeature") {
      continue;
    }
    if (!ff->IsTuneable()) {
      continue;
    }
    OutputFeatureScores(out, features, ff, lastName);
  }

  const StatelessList &stateless = StatelessFeatureFunction::GetStatelessFeatureFunctions();
  for (StatelessList::const_iterator it = stateless.begin(); it != stateless.end(); ++it) {
    const StatelessFeatureFunction *ff = *it;
    if (!ff->IsTuneable()) {
      continue;
    }
    OutputFeatureScores(out, features, ff, lastName);
  }
}
void BaseManager::OutputFeatureScores( std::ostream& out
, const ScoreComponentCollection &features
, const FeatureFunction *ff
, std::string &lastName )
{
const StaticData &staticData = StaticData::Instance();
bool labeledOutput = staticData.IsLabeledNBestList();
// regular features (not sparse)
if (ff->GetNumScoreComponents() != 0) {
if( labeledOutput && lastName != ff->GetScoreProducerDescription() ) {
lastName = ff->GetScoreProducerDescription();
out << " " << lastName << "=";
}
vector<float> scores = features.GetScoresForProducer( ff );
for (size_t j = 0; j<scores.size(); ++j) {
out << " " << scores[j];
}
}
// sparse features
const FVector scores = features.GetVectorForProducer( ff );
for(FVector::FNVmap::const_iterator i = scores.cbegin(); i != scores.cend(); i++) {
out << " " << i->first << "= " << i->second;
}
}
}

View File

@ -1,9 +1,23 @@
#pragma once
#include <iostream>
#include <string>
#include "ScoreComponentCollection.h"
namespace Moses
{
class ScoreComponentCollection;
class FeatureFunction;
/**
 * Feature-score output helpers, exposed as protected members.
 * NOTE(review): presumably intended as a common base for the decoder manager
 * classes; the derivation itself is not visible in this header.
 */
class BaseManager
{
protected:
  /** Write the scores of every tuneable feature function in `features` to `out`. */
  void OutputAllFeatureScores(const Moses::ScoreComponentCollection &features,
                              std::ostream &out);

  /**
   * Write the scores contributed by a single feature function `ff`.
   * `lastName` carries the most recently printed producer label between calls.
   */
  void OutputFeatureScores(std::ostream &out,
                           const ScoreComponentCollection &features,
                           const FeatureFunction *ff,
                           std::string &lastName);
};

View File

@ -30,9 +30,10 @@
#include "DecodeStep.h"
#include "TreeInput.h"
#include "moses/FF/WordPenaltyProducer.h"
#include "moses/OutputCollector.h"
#include "moses/ChartKBestExtractor.h"
using namespace std;
using namespace Moses;
namespace Moses
{
@ -297,4 +298,211 @@ void ChartManager::OutputSearchGraphMoses(std::ostream &outputSearchGraphStream)
WriteSearchGraph(writer);
}
/**
 * Extract the n-best derivations for the current sentence and write them out.
 *
 * Does nothing when the configured n-best size is 0. Otherwise the list is
 * computed with CalcNBest() and formatted by OutputNBestList().
 *
 * @param collector destination passed on to OutputNBestList()
 */
void ChartManager::OutputNBest(OutputCollector *collector)
{
  const StaticData &config = StaticData::Instance();
  const size_t requested = config.GetNBestSize();
  if (requested == 0) {
    // n-best output was not requested
    return;
  }

  const size_t translationId = m_source.GetTranslationId();

  VERBOSE(2,"WRITING " << requested << " TRANSLATION ALTERNATIVES TO " << config.GetNBestFilePath() << endl);

  std::vector<boost::shared_ptr<ChartKBestExtractor::Derivation> > derivations;
  CalcNBest(requested, derivations, config.GetDistinctNBest());
  OutputNBestList(collector, derivations, translationId);

  IFVERBOSE(2) {
    PrintUserTime("N-Best Hypotheses Generation Time:");
  }
}
/**
 * Switch a stream to fixed-point notation with a given number of decimals.
 *
 * @param stream stream whose formatting state is modified
 * @param size   number of digits after the decimal point (default 3)
 */
void FixPrecision(std::ostream &stream, size_t size = 3)
{
  // The two settings are independent of each other.
  stream.precision(static_cast<std::streamsize>(size));
  stream.setf(std::ios_base::fixed);
}
/**
 * Format an n-best list and hand the resulting text to the output collector.
 *
 * One line is produced per derivation:
 *   <translationId> ||| <surface> ||| <feature scores> ||| <total score>
 * followed by " ||| <word alignments>" and/or " ||| <tree>" when the
 * corresponding StaticData options are switched on. The leading and trailing
 * words of each yield (<s> and </s>) are stripped before printing.
 *
 * @param collector     destination for the formatted list; must not be NULL
 * @param nBestList     derivations to print, in the order given
 * @param translationId id written at the start of every line and used as the
 *                      key for collector->Write()
 *
 * Throws (via UTIL_THROW_IF2) if a derivation's yield has fewer than 2 words.
 */
void ChartManager::OutputNBestList(OutputCollector *collector,
                                   const ChartKBestExtractor::KBestVec &nBestList,
                                   long translationId)
{
  // Check the collector before its first use; it is dereferenced a few lines
  // below to decide on the number formatting.
  assert(collector);

  const StaticData &staticData = StaticData::Instance();
  const std::vector<Moses::FactorType> &outputFactorOrder = staticData.GetOutputFactorOrder();

  std::ostringstream out;

  if (collector->OutputIsCout()) {
    // Set precision only if we're writing the n-best list to cout. This is to
    // preserve existing behaviour, but should probably be done either way.
    FixPrecision(out);
  }

  const bool includeWordAlignment = staticData.PrintAlignmentInfoInNbest();
  const bool printNBestTrees = staticData.PrintNBestTrees();

  for (ChartKBestExtractor::KBestVec::const_iterator p = nBestList.begin();
       p != nBestList.end(); ++p) {
    const ChartKBestExtractor::Derivation &derivation = **p;

    // get the derivation's target-side yield
    Phrase outputPhrase = ChartKBestExtractor::GetOutputPhrase(derivation);

    // delete <s> and </s>
    UTIL_THROW_IF2(outputPhrase.GetSize() < 2,
                   "Output phrase should have contained at least 2 words (beginning and end-of-sentence)");
    outputPhrase.RemoveWord(0);
    outputPhrase.RemoveWord(outputPhrase.GetSize() - 1);

    // print the translation ID, surface factors, and scores
    out << translationId << " ||| ";
    OutputSurface(out, outputPhrase, outputFactorOrder, false);
    out << " ||| ";
    OutputAllFeatureScores(derivation.scoreBreakdown, out);
    out << " ||| " << derivation.score;

    // optionally, print word alignments
    if (includeWordAlignment) {
      out << " ||| ";
      Alignments align;
      OutputAlignmentNBest(align, derivation, 0);
      for (Alignments::const_iterator q = align.begin(); q != align.end();
           ++q) {
        out << q->first << "-" << q->second << " ";
      }
    }

    // optionally, print tree
    if (printNBestTrees) {
      TreePointer tree = ChartKBestExtractor::GetOutputTree(derivation);
      out << " ||| " << tree->GetString();
    }

    // The buffer is an in-memory string stream, so a plain newline gives the
    // same text as std::endl without the redundant flush.
    out << '\n';
  }

  collector->Write(translationId, out.str());
}
/***
* print surface factor only for the given phrase
*/
void ChartManager::OutputSurface(std::ostream &out, const Phrase &phrase, const std::vector<FactorType> &outputFactorOrder, bool reportAllFactors)
{
UTIL_THROW_IF2(outputFactorOrder.size() == 0,
"Cannot be empty phrase");
if (reportAllFactors == true) {
out << phrase;
} else {
size_t size = phrase.GetSize();
for (size_t pos = 0 ; pos < size ; pos++) {
const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[0]);
out << *factor;
UTIL_THROW_IF2(factor == NULL,
"Empty factor 0 at position " << pos);
for (size_t i = 1 ; i < outputFactorOrder.size() ; i++) {
const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[i]);
UTIL_THROW_IF2(factor == NULL,
"Empty factor " << i << " at position " << pos);
out << "|" << *factor;
}
out << " ";
}
}
}
/**
 * Size of the source side of the rule applied by a hypothesis.
 *
 * Starts from the number of source words covered by the hypothesis and, for
 * every previous (child) hypothesis, subtracts the child's covered word count
 * less one, so that each child span counts as a single position.
 *
 * @param hypo hypothesis to measure; must not be NULL
 * @return the computed number of source positions
 */
size_t ChartManager::CalcSourceSize(const Moses::ChartHypothesis *hypo)
{
  typedef std::vector<const ChartHypothesis*> HypoList;

  size_t positions = hypo->GetCurrSourceRange().GetNumWordsCovered();
  const HypoList &children = hypo->GetPrevHypos();

  for (HypoList::const_iterator child = children.begin(); child != children.end(); ++child) {
    const size_t covered = (*child)->GetCurrSourceRange().GetNumWordsCovered();
    // collapse the child's whole span into one slot
    positions -= (covered - 1);
  }

  return positions;
}
/**
 * Collect the word alignment points of a derivation, recursing into its
 * sub-derivations.
 *
 * Terminal alignments of the rule applied at this node are converted from
 * rule-relative positions to absolute source / target positions and inserted
 * into retAlign. Non-terminals are handled by recursing into the matching
 * sub-derivation, whose target span starts where the output produced so far
 * ends.
 *
 * @param retAlign    output: set of (source, target) alignment points
 * @param derivation  derivation node to process
 * @param startTarget absolute target position at which this node's output begins
 * @return number of target words produced by this node, including everything
 *         produced by its sub-derivations
 *
 * Throws (via UTIL_THROW_IF2) if the number of non-terminal source positions
 * differs from the number of sub-derivations, if a non-terminal target
 * position is outside the non-terminal index map, or if an alignment point
 * is inserted twice.
 */
size_t ChartManager::OutputAlignmentNBest(
  Alignments &retAlign,
  const Moses::ChartKBestExtractor::Derivation &derivation,
  size_t startTarget)
{
  const ChartHypothesis &hypo = derivation.edge.head->hypothesis;

  size_t totalTargetSize = 0;
  size_t startSource = hypo.GetCurrSourceRange().GetStartPos();

  const TargetPhrase &tp = hypo.GetCurrTargetPhrase();

  size_t thisSourceSize = CalcSourceSize(&hypo);

  // position of each terminal word in translation rule, irrespective of alignment
  // if non-term, number is undefined
  std::vector<size_t> sourceOffsets(thisSourceSize, 0);
  std::vector<size_t> targetOffsets(tp.GetSize(), 0);

  const AlignmentInfo &aiNonTerm = tp.GetAlignNonTerm();
  std::vector<size_t> sourceInd2pos = aiNonTerm.GetSourceIndex2PosMap();
  const AlignmentInfo::NonTermIndexMap &targetPos2SourceInd = aiNonTerm.GetNonTermIndexMap();

  UTIL_THROW_IF2(sourceInd2pos.size() != derivation.subderivations.size(),
                 "Error");

  for (size_t targetPos = 0; targetPos < tp.GetSize(); ++targetPos) {
    if (tp.GetWord(targetPos).IsNonTerminal()) {
      UTIL_THROW_IF2(targetPos >= targetPos2SourceInd.size(), "Error");
      size_t sourceInd = targetPos2SourceInd[targetPos];
      size_t sourcePos = sourceInd2pos[sourceInd];

      const Moses::ChartKBestExtractor::Derivation &subderivation =
        *derivation.subderivations[sourceInd];

      // calc source size
      size_t sourceSize = subderivation.edge.head->hypothesis.GetCurrSourceRange().GetNumWordsCovered();
      sourceOffsets[sourcePos] = sourceSize;

      // calc target size.
      // Recursively look thru child hypos
      size_t currStartTarget = startTarget + totalTargetSize;
      size_t targetSize = OutputAlignmentNBest(retAlign, subderivation,
                          currStartTarget);
      targetOffsets[targetPos] = targetSize;

      totalTargetSize += targetSize;
    } else {
      // a terminal contributes exactly one target word
      ++totalTargetSize;
    }
  }

  // convert position within translation rule to absolute position within
  // source sentence / output sentence
  ShiftOffsets(sourceOffsets, startSource);
  ShiftOffsets(targetOffsets, startTarget);

  // get alignments from this hypo
  const AlignmentInfo &aiTerm = tp.GetAlignTerm();

  // add to output arg, offsetting by source & target
  AlignmentInfo::const_iterator iter;
  for (iter = aiTerm.begin(); iter != aiTerm.end(); ++iter) {
    const std::pair<size_t,size_t> &align = *iter;
    size_t relSource = align.first;
    size_t relTarget = align.second;
    size_t absSource = sourceOffsets[relSource];
    size_t absTarget = targetOffsets[relTarget];

    std::pair<size_t, size_t> alignPoint(absSource, absTarget);
    std::pair<Alignments::iterator, bool> ret = retAlign.insert(alignPoint);
    // every alignment point is expected to be new
    UTIL_THROW_IF2(!ret.second, "Error");
  }

  return totalTargetSize;
}
} // namespace Moses

View File

@ -32,6 +32,7 @@
#include "ChartParser.h"
#include "ChartKBestExtractor.h"
#include "BaseManager.h"
#include "moses/Syntax/KBestExtractor.h"
#include <boost/shared_ptr.hpp>
@ -40,6 +41,7 @@ namespace Moses
class ChartHypothesis;
class ChartSearchGraphWriter;
class OutputCollector;
/** Holds everything you need to decode 1 sentence with the hierarchical/syntax decoder
*/
@ -61,6 +63,32 @@ private:
const ChartHypothesis *hypo, std::map<unsigned,bool> &reachable , size_t* winners, size_t* losers) const;
void WriteSearchGraph(const ChartSearchGraphWriter& writer) const;
// output
// Set of (source position, target position) word alignment points.
typedef std::set< std::pair<size_t, size_t> > Alignments;
// Formats every derivation in nBestList as one n-best line and passes the
// resulting text to collector->Write() under translationId.
void OutputNBestList(OutputCollector *collector,
const ChartKBestExtractor::KBestVec &nBestList,
long translationId);
// Prints the phrase to out: the whole phrase if reportAllFactors is true,
// otherwise only the factors listed in outputFactorOrder, "|"-separated,
// with a space after each word.
void OutputSurface(std::ostream &out, const Phrase &phrase, const std::vector<FactorType> &outputFactorOrder, bool reportAllFactors);
// Number of source words covered by hypo, with each previous hypothesis'
// span counted as a single position.
size_t CalcSourceSize(const Moses::ChartHypothesis *hypo);
// Recursively adds the alignment points of derivation (whose output starts
// at absolute target position startTarget) to retAlign; returns the number
// of target words the derivation produces.
size_t OutputAlignmentNBest(Alignments &retAlign,
const Moses::ChartKBestExtractor::Derivation &derivation,
size_t startTarget);
/**
 * Convert rule-relative slot sizes into absolute positions, in place.
 *
 * Walks the vector with a running position that starts at `shift`. An entry
 * equal to 0 is replaced by the running position, which then advances by one.
 * A non-zero entry is left untouched and advances the running position by
 * its own value.
 *
 * @param offsets in/out: per-slot values, 0 marking slots to be filled in
 * @param shift   position assigned to the first slot
 */
template <class T>
void ShiftOffsets(std::vector<T> &offsets, T shift)
{
  T cursor = shift;
  for (typename std::vector<T>::iterator slot = offsets.begin(); slot != offsets.end(); ++slot) {
    if (*slot != 0) {
      // non-zero entry: skip over the span it describes
      cursor += *slot;
      continue;
    }
    *slot = cursor;
    ++cursor;
  }
}
public:
ChartManager(InputType const& source);
~ChartManager();
@ -109,6 +137,8 @@ public:
/** Read-only access to the m_parser member. */
const ChartParser &GetParser() const
{
  return m_parser;
}
// outputs
// If an n-best size greater than 0 is configured, computes the n-best list
// for this sentence and writes it through the given collector.
void OutputNBest(OutputCollector *collector);
};
}

View File

@ -1614,54 +1614,6 @@ void Manager::OutputSurface(std::ostream &out, const Hypothesis &edge, const std
}
}
/**
 * Write the scores of every tuneable feature function to a stream.
 *
 * Stateful feature functions are visited first, then stateless ones, each in
 * the order returned by its registry. A stateful feature whose description is
 * "BleuScoreFeature" is skipped, as is any feature that is not tuneable.
 *
 * @param features score collection the per-feature values are read from
 * @param out      stream the formatted scores are appended to
 */
void Manager::OutputAllFeatureScores(const Moses::ScoreComponentCollection &features
                                     , std::ostream &out)
{
  typedef std::vector<const StatefulFeatureFunction*> StatefulList;
  typedef std::vector<const StatelessFeatureFunction*> StatelessList;

  // Label of the producer printed most recently. It is threaded through every
  // OutputFeatureScores() call so that both passes share the same state.
  std::string lastName;

  const StatefulList &stateful = StatefulFeatureFunction::GetStatefulFeatureFunctions();
  for (StatefulList::const_iterator it = stateful.begin(); it != stateful.end(); ++it) {
    const StatefulFeatureFunction *ff = *it;
    if (ff->GetScoreProducerDescription() == "BleuScoreFeature") {
      continue;
    }
    if (!ff->IsTuneable()) {
      continue;
    }
    OutputFeatureScores(out, features, ff, lastName);
  }

  const StatelessList &stateless = StatelessFeatureFunction::GetStatelessFeatureFunctions();
  for (StatelessList::const_iterator it = stateless.begin(); it != stateless.end(); ++it) {
    const StatelessFeatureFunction *ff = *it;
    if (!ff->IsTuneable()) {
      continue;
    }
    OutputFeatureScores(out, features, ff, lastName);
  }
}
void Manager::OutputFeatureScores( std::ostream& out
, const ScoreComponentCollection &features
, const FeatureFunction *ff
, std::string &lastName )
{
const StaticData &staticData = StaticData::Instance();
bool labeledOutput = staticData.IsLabeledNBestList();
// regular features (not sparse)
if (ff->GetNumScoreComponents() != 0) {
if( labeledOutput && lastName != ff->GetScoreProducerDescription() ) {
lastName = ff->GetScoreProducerDescription();
out << " " << lastName << "=";
}
vector<float> scores = features.GetScoresForProducer( ff );
for (size_t j = 0; j<scores.size(); ++j) {
out << " " << scores[j];
}
}
// sparse features
const FVector scores = features.GetVectorForProducer( ff );
for(FVector::FNVmap::const_iterator i = scores.cbegin(); i != scores.cend(); i++) {
out << " " << i->first << "= " << i->second;
}
}
void Manager::OutputAlignment(ostream &out, const AlignmentInfo &ai, size_t sourceOffset, size_t targetOffset)
{
typedef std::vector< const std::pair<size_t,size_t>* > AlignVec;

View File

@ -137,12 +137,6 @@ protected:
, char reportSegmentation);
void OutputSurface(std::ostream &out, const Hypothesis &edge, const std::vector<FactorType> &outputFactorOrder,
char reportSegmentation, bool reportAllFactors);
// Writes the scores of every tuneable feature function in `features` to
// `out`: stateful features first (skipping "BleuScoreFeature"), then stateless.
void OutputAllFeatureScores(const Moses::ScoreComponentCollection &features
, std::ostream &out);
// Writes the dense and sparse scores of the single feature function `ff`.
// `lastName` holds the most recently printed producer label and is updated
// when a new label is written (labeled n-best output only).
void OutputFeatureScores( std::ostream& out
, const ScoreComponentCollection &features
, const FeatureFunction *ff
, std::string &lastName );
void OutputAlignment(std::ostream &out, const AlignmentInfo &ai, size_t sourceOffset, size_t targetOffset);
void OutputInput(std::ostream& os, const Hypothesis* hypo);
void OutputInput(std::vector<const Phrase*>& map, const Hypothesis* hypo);

View File

@ -398,16 +398,7 @@ void TranslationTask::RunChart()
}
// n-best
size_t nBestSize = staticData.GetNBestSize();
if (nBestSize > 0) {
VERBOSE(2,"WRITING " << nBestSize << " TRANSLATION ALTERNATIVES TO " << staticData.GetNBestFilePath() << endl);
std::vector<boost::shared_ptr<ChartKBestExtractor::Derivation> > nBestList;
manager.CalcNBest(nBestSize, nBestList,staticData.GetDistinctNBest());
m_ioWrapper.OutputNBestList(nBestList, translationId);
IFVERBOSE(2) {
PrintUserTime("N-Best Hypotheses Generation Time:");
}
}
manager.OutputNBest(m_ioWrapper.GetNBestOutputCollector());
if (staticData.GetOutputSearchGraph()) {
std::ostringstream out;

View File

@ -18,8 +18,6 @@
***********************************************************************/
#pragma once
#ifndef EXTRACT_GHKM_ALIGNMENT_H_
#define EXTRACT_GHKM_ALIGNMENT_H_
#include <string>
#include <utility>
@ -39,4 +37,3 @@ void FlipAlignment(Alignment &);
} // namespace GHKM
} // namespace Moses
#endif