add OutputNBest() as abstract method to BaseManager. Tighten up framework

This commit is contained in:
Hieu Hoang 2014-12-02 19:09:10 +00:00
parent ba7afba9f6
commit 3da8415095
14 changed files with 140 additions and 125 deletions

View File

@ -90,8 +90,8 @@ int main(int argc, char** argv)
}
// set number of significant decimals in output
IOWrapper::FixPrecision(cout);
IOWrapper::FixPrecision(cerr);
FixPrecision(cout);
FixPrecision(cerr);
// load all the settings into the Parameter class
// (stores them as strings, or array of strings)

View File

@ -9,8 +9,8 @@ using namespace std;
namespace Moses
{
void BaseManager::OutputAllFeatureScores(const Moses::ScoreComponentCollection &features
, std::ostream &out)
void BaseManager::OutputAllFeatureScores(const Moses::ScoreComponentCollection &features,
std::ostream &out) const
{
std::string lastName = "";
const vector<const StatefulFeatureFunction*>& sff = StatefulFeatureFunction::GetStatefulFeatureFunctions();
@ -30,10 +30,10 @@ void BaseManager::OutputAllFeatureScores(const Moses::ScoreComponentCollection &
}
}
void BaseManager::OutputFeatureScores( std::ostream& out
, const ScoreComponentCollection &features
, const FeatureFunction *ff
, std::string &lastName )
void BaseManager::OutputFeatureScores( std::ostream& out,
const ScoreComponentCollection &features,
const FeatureFunction *ff,
std::string &lastName ) const
{
const StaticData &staticData = StaticData::Instance();
bool labeledOutput = staticData.IsLabeledNBestList();
@ -57,6 +57,37 @@ void BaseManager::OutputFeatureScores( std::ostream& out
}
}
/***
* print surface factor only for the given phrase
*/
void BaseManager::OutputSurface(std::ostream &out, const Phrase &phrase,
const std::vector<FactorType> &outputFactorOrder,
bool reportAllFactors) const
{
UTIL_THROW_IF2(outputFactorOrder.size() == 0,
"Cannot be empty phrase");
if (reportAllFactors == true) {
out << phrase;
} else {
size_t size = phrase.GetSize();
for (size_t pos = 0 ; pos < size ; pos++) {
const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[0]);
out << *factor;
UTIL_THROW_IF2(factor == NULL,
"Empty factor 0 at position " << pos);
for (size_t i = 1 ; i < outputFactorOrder.size() ; i++) {
const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[i]);
UTIL_THROW_IF2(factor == NULL,
"Empty factor " << i << " at position " << pos);
out << "|" << *factor;
}
out << " ";
}
}
}
} // namespace

View File

@ -8,16 +8,25 @@ namespace Moses
{
class ScoreComponentCollection;
class FeatureFunction;
class OutputCollector;
class BaseManager
{
protected:
void OutputAllFeatureScores(const Moses::ScoreComponentCollection &features
, std::ostream &out);
void OutputFeatureScores( std::ostream& out
, const ScoreComponentCollection &features
, const FeatureFunction *ff
, std::string &lastName );
void OutputAllFeatureScores(const Moses::ScoreComponentCollection &features,
std::ostream &out) const;
void OutputFeatureScores( std::ostream& out,
const ScoreComponentCollection &features,
const FeatureFunction *ff,
std::string &lastName ) const;
void OutputSurface(std::ostream &out,
const Phrase &phrase,
const std::vector<FactorType> &outputFactorOrder,
bool reportAllFactors) const;
public:
// outputs
virtual void OutputNBest(OutputCollector *collector) const = 0;
};

View File

@ -298,7 +298,7 @@ void ChartManager::OutputSearchGraphMoses(std::ostream &outputSearchGraphStream)
WriteSearchGraph(writer);
}
void ChartManager::OutputNBest(OutputCollector *collector)
void ChartManager::OutputNBest(OutputCollector *collector) const
{
const StaticData &staticData = StaticData::Instance();
size_t nBestSize = staticData.GetNBestSize();
@ -316,15 +316,9 @@ void ChartManager::OutputNBest(OutputCollector *collector)
}
void FixPrecision(std::ostream &stream, size_t size = 3)
{
stream.setf(std::ios::fixed);
stream.precision(size);
}
void ChartManager::OutputNBestList(OutputCollector *collector,
const ChartKBestExtractor::KBestVec &nBestList,
long translationId)
long translationId) const
{
const StaticData &staticData = StaticData::Instance();
const std::vector<Moses::FactorType> &outputFactorOrder = staticData.GetOutputFactorOrder();
@ -386,36 +380,7 @@ void ChartManager::OutputNBestList(OutputCollector *collector,
collector->Write(translationId, out.str());
}
/***
* print surface factor only for the given phrase
*/
void ChartManager::OutputSurface(std::ostream &out, const Phrase &phrase, const std::vector<FactorType> &outputFactorOrder, bool reportAllFactors)
{
UTIL_THROW_IF2(outputFactorOrder.size() == 0,
"Cannot be empty phrase");
if (reportAllFactors == true) {
out << phrase;
} else {
size_t size = phrase.GetSize();
for (size_t pos = 0 ; pos < size ; pos++) {
const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[0]);
out << *factor;
UTIL_THROW_IF2(factor == NULL,
"Empty factor 0 at position " << pos);
for (size_t i = 1 ; i < outputFactorOrder.size() ; i++) {
const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[i]);
UTIL_THROW_IF2(factor == NULL,
"Empty factor " << i << " at position " << pos);
out << "|" << *factor;
}
out << " ";
}
}
}
size_t ChartManager::CalcSourceSize(const Moses::ChartHypothesis *hypo)
size_t ChartManager::CalcSourceSize(const Moses::ChartHypothesis *hypo) const
{
size_t ret = hypo->GetCurrSourceRange().GetNumWordsCovered();
const std::vector<const ChartHypothesis*> &prevHypos = hypo->GetPrevHypos();
@ -429,7 +394,7 @@ size_t ChartManager::CalcSourceSize(const Moses::ChartHypothesis *hypo)
size_t ChartManager::OutputAlignmentNBest(
Alignments &retAlign,
const Moses::ChartKBestExtractor::Derivation &derivation,
size_t startTarget)
size_t startTarget) const
{
const ChartHypothesis &hypo = derivation.edge.head->hypothesis;

View File

@ -41,7 +41,6 @@ namespace Moses
class ChartHypothesis;
class ChartSearchGraphWriter;
class OutputCollector;
/** Holds everything you need to decode 1 sentence with the hierachical/syntax decoder
*/
@ -68,15 +67,14 @@ private:
void OutputNBestList(OutputCollector *collector,
const ChartKBestExtractor::KBestVec &nBestList,
long translationId);
void OutputSurface(std::ostream &out, const Phrase &phrase, const std::vector<FactorType> &outputFactorOrder, bool reportAllFactors);
size_t CalcSourceSize(const Moses::ChartHypothesis *hypo);
long translationId) const;
size_t CalcSourceSize(const Moses::ChartHypothesis *hypo) const;
size_t OutputAlignmentNBest(Alignments &retAlign,
const Moses::ChartKBestExtractor::Derivation &derivation,
size_t startTarget);
size_t startTarget) const;
template <class T>
void ShiftOffsets(std::vector<T> &offsets, T shift)
void ShiftOffsets(std::vector<T> &offsets, T shift) const
{
T currPos = shift;
for (size_t i = 0; i < offsets.size(); ++i) {
@ -138,7 +136,7 @@ public:
const ChartParser &GetParser() const { return m_parser; }
// outputs
void OutputNBest(OutputCollector *collector);
void OutputNBest(OutputCollector *collector) const;
};
}

View File

@ -258,12 +258,6 @@ GetInput(InputType* inputType)
}
}
void IOWrapper::FixPrecision(std::ostream &stream, size_t size)
{
stream.setf(std::ios::fixed);
stream.precision(size);
}
std::map<size_t, const Factor*> IOWrapper::GetPlaceholders(const Hypothesis &hypo, FactorType placeholderFactor)
{
const InputPath &inputPath = hypo.GetTranslationOption().GetInputPath();
@ -628,34 +622,6 @@ void IOWrapper::OutputTreeFragmentsTranslationOptions(std::ostream &out, Applica
}
}
void IOWrapper::OutputNBestList(const std::vector<search::Applied> &nbest, long translationId)
{
std::ostringstream out;
// wtf? copied from the original OutputNBestList
if (m_nBestOutputCollector->OutputIsCout()) {
FixPrecision(out);
}
Phrase outputPhrase;
ScoreComponentCollection features;
for (std::vector<search::Applied>::const_iterator i = nbest.begin(); i != nbest.end(); ++i) {
Incremental::PhraseAndFeatures(*i, outputPhrase, features);
// <s> and </s>
UTIL_THROW_IF2(outputPhrase.GetSize() < 2,
"Output phrase should have contained at least 2 words (beginning and end-of-sentence)");
outputPhrase.RemoveWord(0);
outputPhrase.RemoveWord(outputPhrase.GetSize() - 1);
out << translationId << " ||| ";
OutputSurface(out, outputPhrase, *m_outputFactorOrder, false);
out << " ||| ";
OutputAllFeatureScores(features, out);
out << " ||| " << i->GetScore() << '\n';
}
out << std::flush;
assert(m_nBestOutputCollector);
m_nBestOutputCollector->Write(translationId, out.str());
}
/***
* print surface factor only for the given phrase
*/
@ -1353,7 +1319,7 @@ void IOWrapper::OutputBestHypo(const Syntax::SHyperedge *best,
return;
}
std::ostringstream out;
IOWrapper::FixPrecision(out);
FixPrecision(out);
if (best == NULL) {
VERBOSE(1, "NO BEST TRANSLATION" << std::endl);
if (StaticData::Instance().GetOutputHypoScore()) {
@ -1383,7 +1349,7 @@ void IOWrapper::OutputNBestList(
if (m_nBestOutputCollector->OutputIsCout()) {
// Set precision only if we're writing the n-best list to cout. This is to
// preserve existing behaviour, but should probably be done either way.
IOWrapper::FixPrecision(out);
FixPrecision(out);
}
bool includeWordAlignment =

View File

@ -157,8 +157,6 @@ protected:
}
public:
static void FixPrecision(std::ostream &, size_t size=3);
IOWrapper();
~IOWrapper();
@ -209,7 +207,6 @@ public:
void OutputBestNone(long translationId);
void OutputNBestList(const std::vector<boost::shared_ptr<Moses::ChartKBestExtractor::Derivation> > &nBestList, long translationId);
void OutputNBestList(const std::vector<search::Applied> &nbest, long translationId);
void OutputNBestList(const Moses::Syntax::KBestExtractor::KBestVec &nBestList, long translationId);
void OutputDetailedTranslationReport(const Moses::ChartHypothesis *hypo, const Moses::Sentence &sentence, long translationId);

View File

@ -8,6 +8,7 @@
#include "moses/StaticData.h"
#include "moses/Util.h"
#include "moses/LM/Base.h"
#include "moses/OutputCollector.h"
#include "lm/model.hh"
#include "search/applied.hh"
@ -278,6 +279,47 @@ const std::vector<search::Applied> &Manager::ProcessSentence()
return *completed_nbest_;
}
void Manager::OutputNBest(OutputCollector *collector) const
{
if (collector == NULL) {
return;
}
OutputNBestList(collector, *completed_nbest_, source_.GetTranslationId());
}
void Manager::OutputNBestList(OutputCollector *collector, const std::vector<search::Applied> &nbest, long translationId) const
{
const StaticData &staticData = StaticData::Instance();
const std::vector<Moses::FactorType> &outputFactorOrder = staticData.GetOutputFactorOrder();
std::ostringstream out;
// wtf? copied from the original OutputNBestList
if (collector->OutputIsCout()) {
FixPrecision(out);
}
Phrase outputPhrase;
ScoreComponentCollection features;
for (std::vector<search::Applied>::const_iterator i = nbest.begin(); i != nbest.end(); ++i) {
Incremental::PhraseAndFeatures(*i, outputPhrase, features);
// <s> and </s>
UTIL_THROW_IF2(outputPhrase.GetSize() < 2,
"Output phrase should have contained at least 2 words (beginning and end-of-sentence)");
outputPhrase.RemoveWord(0);
outputPhrase.RemoveWord(outputPhrase.GetSize() - 1);
out << translationId << " ||| ";
OutputSurface(out, outputPhrase, outputFactorOrder, false);
out << " ||| ";
OutputAllFeatureScores(features, out);
out << " ||| " << i->GetScore() << '\n';
}
out << std::flush;
assert(collector);
collector->Write(translationId, out.str());
}
namespace
{

View File

@ -37,6 +37,10 @@ public:
return *completed_nbest_;
}
// output
void OutputNBest(OutputCollector *collector) const;
private:
template <class Model, class Best> search::History PopulateBest(const Model &model, const std::vector<lm::WordIndex> &words, Best &out);
@ -53,6 +57,9 @@ private:
search::NBest n_best_;
const std::vector<search::Applied> *completed_nbest_;
// outputs
void OutputNBestList(OutputCollector *collector, const std::vector<search::Applied> &nbest, long translationId) const;
};
// Just get the phrase.

View File

@ -1448,7 +1448,7 @@ SentenceStats& Manager::GetSentenceStats() const
}
void Manager::OutputNBest(OutputCollector *collector)
void Manager::OutputNBest(OutputCollector *collector) const
{
const StaticData &staticData = StaticData::Instance();
@ -1467,7 +1467,7 @@ void Manager::OutputNBest(std::ostream& out
, const Moses::TrellisPathList &nBestList
, const std::vector<Moses::FactorType>& outputFactorOrder
, long translationId
, char reportSegmentation)
, char reportSegmentation) const
{
const StaticData &staticData = StaticData::Instance();
bool reportAllFactors = staticData.GetReportAllFactorsNBest();
@ -1542,7 +1542,7 @@ void Manager::OutputNBest(std::ostream& out
* print surface factor only for the given phrase
*/
void Manager::OutputSurface(std::ostream &out, const Hypothesis &edge, const std::vector<FactorType> &outputFactorOrder,
char reportSegmentation, bool reportAllFactors)
char reportSegmentation, bool reportAllFactors) const
{
UTIL_THROW_IF2(outputFactorOrder.size() == 0,
"Must specific at least 1 output factor");
@ -1614,7 +1614,7 @@ void Manager::OutputSurface(std::ostream &out, const Hypothesis &edge, const std
}
}
void Manager::OutputAlignment(ostream &out, const AlignmentInfo &ai, size_t sourceOffset, size_t targetOffset)
void Manager::OutputAlignment(ostream &out, const AlignmentInfo &ai, size_t sourceOffset, size_t targetOffset) const
{
typedef std::vector< const std::pair<size_t,size_t>* > AlignVec;
AlignVec alignments = ai.GetSortedAlignments();
@ -1627,7 +1627,7 @@ void Manager::OutputAlignment(ostream &out, const AlignmentInfo &ai, size_t sour
}
void Manager::OutputInput(std::ostream& os, const Hypothesis* hypo)
void Manager::OutputInput(std::ostream& os, const Hypothesis* hypo) const
{
size_t len = hypo->GetInput().GetSize();
std::vector<const Phrase*> inp_phrases(len, 0);
@ -1636,7 +1636,7 @@ void Manager::OutputInput(std::ostream& os, const Hypothesis* hypo)
if (inp_phrases[i]) os << *inp_phrases[i];
}
void Manager::OutputInput(std::vector<const Phrase*>& map, const Hypothesis* hypo)
void Manager::OutputInput(std::vector<const Phrase*>& map, const Hypothesis* hypo) const
{
if (hypo->GetPrevHypo()) {
OutputInput(map, hypo->GetPrevHypo());
@ -1644,7 +1644,7 @@ void Manager::OutputInput(std::vector<const Phrase*>& map, const Hypothesis* hyp
}
}
std::map<size_t, const Factor*> Manager::GetPlaceholders(const Hypothesis &hypo, FactorType placeholderFactor)
std::map<size_t, const Factor*> Manager::GetPlaceholders(const Hypothesis &hypo, FactorType placeholderFactor) const
{
const InputPath &inputPath = hypo.GetTranslationOption().GetInputPath();
const Phrase &inputPhrase = inputPath.GetPhrase();
@ -1664,7 +1664,7 @@ std::map<size_t, const Factor*> Manager::GetPlaceholders(const Hypothesis &hypo,
return ret;
}
void Manager::OutputLatticeSamples(OutputCollector *collector)
void Manager::OutputLatticeSamples(OutputCollector *collector) const
{
const StaticData &staticData = StaticData::Instance();
if (collector) {

View File

@ -42,7 +42,6 @@ namespace Moses
class SentenceStats;
class TrellisPath;
class TranslationOptionCollection;
class OutputCollector;
/** Used to output the search graph */
struct SearchGraphNode {
@ -134,13 +133,13 @@ protected:
, const Moses::TrellisPathList &nBestList
, const std::vector<Moses::FactorType>& outputFactorOrder
, long translationId
, char reportSegmentation);
, char reportSegmentation) const;
void OutputSurface(std::ostream &out, const Hypothesis &edge, const std::vector<FactorType> &outputFactorOrder,
char reportSegmentation, bool reportAllFactors);
void OutputAlignment(std::ostream &out, const AlignmentInfo &ai, size_t sourceOffset, size_t targetOffset);
void OutputInput(std::ostream& os, const Hypothesis* hypo);
void OutputInput(std::vector<const Phrase*>& map, const Hypothesis* hypo);
std::map<size_t, const Factor*> GetPlaceholders(const Hypothesis &hypo, FactorType placeholderFactor);
char reportSegmentation, bool reportAllFactors) const;
void OutputAlignment(std::ostream &out, const AlignmentInfo &ai, size_t sourceOffset, size_t targetOffset) const;
void OutputInput(std::ostream& os, const Hypothesis* hypo) const;
void OutputInput(std::vector<const Phrase*>& map, const Hypothesis* hypo) const;
std::map<size_t, const Factor*> GetPlaceholders(const Hypothesis &hypo, FactorType placeholderFactor) const;
public:
InputType const& m_source; /**< source sentence to be translated */
@ -186,8 +185,8 @@ public:
std::vector< const Hypothesis* >* pConnectedList, std::map < const Hypothesis*, std::set < const Hypothesis* > >* pOutgoingHyps, std::vector< float>* pFwdBwdScores) const;
// outputs
void OutputNBest(OutputCollector *collector);
void OutputLatticeSamples(OutputCollector *collector);
void OutputNBest(OutputCollector *collector) const;
void OutputLatticeSamples(OutputCollector *collector) const;
};
}

View File

@ -92,7 +92,7 @@ void TranslationTask::RunPb()
// output word graph
if (m_ioWrapper.GetWordGraphCollector()) {
ostringstream out;
fix(out,PRECISION);
FixPrecision(out,PRECISION);
manager.GetWordGraph(m_source->GetTranslationId(), out);
m_ioWrapper.GetWordGraphCollector()->Write(m_source->GetTranslationId(), out.str());
}
@ -100,7 +100,7 @@ void TranslationTask::RunPb()
// output search graph
if (m_ioWrapper.GetSearchGraphOutputCollector()) {
ostringstream out;
fix(out,PRECISION);
FixPrecision(out,PRECISION);
manager.OutputSearchGraph(m_source->GetTranslationId(), out);
m_ioWrapper.GetSearchGraphOutputCollector()->Write(m_source->GetTranslationId(), out.str());
@ -128,7 +128,7 @@ void TranslationTask::RunPb()
file->open(fileName.str().c_str());
if (file->is_open() && file->good()) {
ostringstream out;
fix(out,PRECISION);
FixPrecision(out,PRECISION);
manager.OutputSearchGraphAsSLF(m_source->GetTranslationId(), out);
*file << out.str();
file -> flush();
@ -149,7 +149,7 @@ void TranslationTask::RunPb()
if (m_ioWrapper.GetSingleBestOutputCollector()) {
ostringstream out;
ostringstream debug;
fix(debug,PRECISION);
FixPrecision(debug,PRECISION);
// all derivations - send them to debug stream
if (staticData.PrintAllDerivations()) {
@ -283,7 +283,7 @@ void TranslationTask::RunPb()
// detailed translation reporting
if (m_ioWrapper.GetDetailedTranslationCollector()) {
ostringstream out;
fix(out,PRECISION);
FixPrecision(out,PRECISION);
TranslationAnalysis::PrintTranslationAnalysis(out, manager.GetBestHypothesis());
m_ioWrapper.GetDetailedTranslationCollector()->Write(m_source->GetTranslationId(),out.str());
}
@ -348,8 +348,9 @@ void TranslationTask::RunChart()
} else {
m_ioWrapper.OutputBestNone(translationId);
}
if (staticData.GetNBestSize() > 0)
m_ioWrapper.OutputNBestList(nbest, translationId);
manager.OutputNBest(m_ioWrapper.GetNBestOutputCollector());
return;
}

View File

@ -220,7 +220,7 @@ void PrintFeatureWeight(const FeatureFunction* ff)
void ShowWeights()
{
fix(cout,6);
FixPrecision(cout,6);
const vector<const StatelessFeatureFunction*>& slf = StatelessFeatureFunction::GetStatelessFeatureFunctions();
const vector<const StatefulFeatureFunction*>& sff = StatefulFeatureFunction::GetStatefulFeatureFunctions();

View File

@ -478,7 +478,7 @@ T log_sum (T log_a, T log_b)
}
/** Enforce rounding */
inline void fix(std::ostream& stream, size_t size)
inline void FixPrecision(std::ostream& stream, size_t size = 3)
{
stream.setf(std::ios::fixed);
stream.precision(size);