Add Moses::Syntax::Manager class

Sits between Moses::BaseManager and S2T::Manager, F2S::Manager, etc.
This commit is contained in:
Phil Williams 2014-12-09 15:47:55 +00:00
parent a0b6b6a341
commit 030ea19e6c
6 changed files with 280 additions and 234 deletions

View File

@ -5,6 +5,7 @@
#include <queue>
#include <vector>
#include <boost/unordered_map.hpp>
#include <boost/unordered_set.hpp>
#include <boost/weak_ptr.hpp>

206
moses/Syntax/Manager.cpp Normal file
View File

@ -0,0 +1,206 @@
#include "Manager.h"
#include <sstream>
#include "moses/OutputCollector.h"
#include "moses/StaticData.h"
#include "PVertex.h"
namespace Moses
{
namespace Syntax
{
// Construct the shared Syntax manager state.  The input sentence is simply
// forwarded to the Moses::BaseManager constructor; nothing else is
// initialised here.
Manager::Manager(const InputType &source) : Moses::BaseManager(source) {}
// Extract the k-best derivations for the current input and hand them to
// OutputNBestList() for printing.  The list size and the distinct-only flag
// are read from StaticData.  Does nothing when collector is null.
void Manager::OutputNBest(OutputCollector *collector) const
{
  if (!collector) {
    return;
  }
  const StaticData &config = StaticData::Instance();
  const long sentenceId = m_source.GetTranslationId();
  KBestExtractor::KBestVec derivations;
  ExtractKBest(config.GetNBestSize(), derivations, config.GetDistinctNBest());
  OutputNBestList(collector, derivations, sentenceId);
}
// Write every word stored in m_oovs to collector as a single line, keyed by
// the translation ID of the current input.  Does nothing when collector is
// null.
void Manager::OutputUnknowns(OutputCollector *collector) const
{
  if (!collector) {
    return;
  }
  typedef std::set<Moses::Word>::const_iterator OovIterator;
  const long sentenceId = m_source.GetTranslationId();
  std::ostringstream line;
  for (OovIterator oov = m_oovs.begin(); oov != m_oovs.end(); ++oov) {
    line << *oov;
  }
  line << std::endl;
  collector->Write(sentenceId, line.str());
}
// Format an n-best list and write it to collector in one call.
//
// Each derivation produces one line of the form
//   id ||| surface ||| feature scores ||| total score [||| alignment] [||| tree]
// where the alignment and tree fields are only emitted when the
// corresponding StaticData options are switched on.
//
// collector      destination for the formatted list; must not be null
// nBestList      derivations to print, in the order they should appear
// translationId  ID printed at the start of every line and used as the key
//                for collector->Write()
//
// Throws (via UTIL_THROW_IF2) if a derivation's yield has fewer than two
// words, since the sentence boundary markers are expected to be present.
void Manager::OutputNBestList(OutputCollector *collector,
                              const KBestExtractor::KBestVec &nBestList,
                              long translationId) const
{
  // Check the precondition up front: collector is dereferenced a few lines
  // below, so asserting only at the end would never catch a null pointer.
  assert(collector);

  const StaticData &staticData = StaticData::Instance();

  const std::vector<FactorType> &outputFactorOrder =
    staticData.GetOutputFactorOrder();

  std::ostringstream out;

  if (collector->OutputIsCout()) {
    // Set precision only if we're writing the n-best list to cout.  This is to
    // preserve existing behaviour, but should probably be done either way.
    FixPrecision(out);
  }

  const bool includeWordAlignment = staticData.PrintAlignmentInfoInNbest();
  const bool printNBestTrees = staticData.PrintNBestTrees();

  for (KBestExtractor::KBestVec::const_iterator p = nBestList.begin();
       p != nBestList.end(); ++p) {
    const KBestExtractor::Derivation &derivation = **p;

    // get the derivation's target-side yield
    Phrase outputPhrase = KBestExtractor::GetOutputPhrase(derivation);

    // delete <s> and </s>
    UTIL_THROW_IF2(outputPhrase.GetSize() < 2,
                   "Output phrase should have contained at least 2 words (beginning and end-of-sentence)");
    outputPhrase.RemoveWord(0);
    outputPhrase.RemoveWord(outputPhrase.GetSize() - 1);

    // print the translation ID, surface factors, and scores
    out << translationId << " ||| ";
    OutputSurface(out, outputPhrase, outputFactorOrder, false);
    out << " ||| ";
    OutputAllFeatureScores(derivation.scoreBreakdown, out);
    out << " ||| " << derivation.score;

    // optionally, print word alignments
    if (includeWordAlignment) {
      out << " ||| ";
      Alignments align;
      OutputAlignmentNBest(align, derivation, 0);
      for (Alignments::const_iterator q = align.begin(); q != align.end();
           ++q) {
        out << q->first << "-" << q->second << " ";
      }
    }

    // optionally, print tree
    if (printNBestTrees) {
      TreePointer tree = KBestExtractor::GetOutputTree(derivation);
      out << " ||| " << tree->GetString();
    }

    // The whole list is buffered in a string stream, so a plain newline is
    // enough; there is nothing to flush per entry.
    out << '\n';
  }

  collector->Write(translationId, out.str());
}
// Recursively collect the word alignment points of a derivation.
//
// retAlign     set that receives the (source position, target position)
//              pairs; points from subderivations are added to the same set
// derivation   derivation whose rule and subderivations are processed
// startTarget  position in the output at which this derivation's target
//              yield starts (the top-level caller passes 0)
//
// Returns the number of target words covered by the derivation: one per
// terminal of the rule plus the sizes returned for each subderivation.
//
// Throws (via UTIL_THROW_IF2) if the non-terminal alignment information does
// not match the subderivations, or if an alignment point is produced twice.
std::size_t Manager::OutputAlignmentNBest(
  Alignments &retAlign,
  const KBestExtractor::Derivation &derivation,
  std::size_t startTarget) const
{
  const SHyperedge &shyperedge = derivation.edge->shyperedge;

  std::size_t totalTargetSize = 0;
  std::size_t startSource = shyperedge.head->pvertex->span.GetStartPos();

  const TargetPhrase &tp = *(shyperedge.translation);

  std::size_t thisSourceSize = CalcSourceSize(derivation);

  // position of each terminal word in translation rule, irrespective of
  // alignment if non-term, number is undefined
  std::vector<std::size_t> sourceOffsets(thisSourceSize, 0);
  std::vector<std::size_t> targetOffsets(tp.GetSize(), 0);

  const AlignmentInfo &aiNonTerm = shyperedge.translation->GetAlignNonTerm();
  std::vector<std::size_t> sourceInd2pos = aiNonTerm.GetSourceIndex2PosMap();
  const AlignmentInfo::NonTermIndexMap &targetPos2SourceInd =
    aiNonTerm.GetNonTermIndexMap();

  UTIL_THROW_IF2(sourceInd2pos.size() != derivation.subderivations.size(),
                 "Error");

  for (std::size_t targetPos = 0; targetPos < tp.GetSize(); ++targetPos) {
    if (tp.GetWord(targetPos).IsNonTerminal()) {
      UTIL_THROW_IF2(targetPos >= targetPos2SourceInd.size(), "Error");
      std::size_t sourceInd = targetPos2SourceInd[targetPos];
      // sourceInd indexes both sourceInd2pos and the subderivations (which
      // were checked above to have equal size), so validate it once here
      // rather than reading past the end of either.
      UTIL_THROW_IF2(sourceInd >= sourceInd2pos.size(), "Error");
      std::size_t sourcePos = sourceInd2pos[sourceInd];

      const KBestExtractor::Derivation &subderivation =
        *derivation.subderivations[sourceInd];

      // calc source size
      std::size_t sourceSize =
        subderivation.edge->head->svertex.pvertex->span.GetNumWordsCovered();
      sourceOffsets[sourcePos] = sourceSize;

      // calc target size.
      // Recursively look thru child hypos
      std::size_t currStartTarget = startTarget + totalTargetSize;
      std::size_t targetSize = OutputAlignmentNBest(retAlign, subderivation,
                               currStartTarget);
      targetOffsets[targetPos] = targetSize;

      totalTargetSize += targetSize;
    } else {
      // A terminal contributes exactly one target word.
      ++totalTargetSize;
    }
  }

  // convert position within translation rule to absolute position within
  // source sentence / output sentence
  // (ShiftOffsets is defined outside this file.)
  ShiftOffsets(sourceOffsets, startSource);
  ShiftOffsets(targetOffsets, startTarget);

  // get alignments from this hypo
  const AlignmentInfo &aiTerm = shyperedge.translation->GetAlignTerm();

  // add to output arg, offsetting by source & target
  AlignmentInfo::const_iterator iter;
  for (iter = aiTerm.begin(); iter != aiTerm.end(); ++iter) {
    const std::pair<std::size_t, std::size_t> &align = *iter;
    std::size_t relSource = align.first;
    std::size_t relTarget = align.second;
    std::size_t absSource = sourceOffsets[relSource];
    std::size_t absTarget = targetOffsets[relTarget];

    std::pair<std::size_t, std::size_t> alignPoint(absSource, absTarget);
    std::pair<Alignments::iterator, bool> ret = retAlign.insert(alignPoint);
    // Each alignment point is expected to be new.
    UTIL_THROW_IF2(!ret.second, "Error");
  }

  return totalTargetSize;
}
// Compute the source-side size of the rule applied at the root of d.
//
// Starts from the number of words covered by the head vertex's span and, for
// every tail vertex, subtracts that vertex's covered word count minus one, so
// that each tail vertex ends up counting as a single position.
std::size_t Manager::CalcSourceSize(const KBestExtractor::Derivation &d) const
{
  const SHyperedge &edge = d.edge->shyperedge;
  std::size_t size = edge.head->pvertex->span.GetNumWordsCovered();
  for (std::size_t child = 0; child != edge.tail.size(); ++child) {
    const std::size_t covered =
      edge.tail[child]->pvertex->span.GetNumWordsCovered();
    size -= (covered - 1);
  }
  return size;
}
} // Syntax
} // Moses

58
moses/Syntax/Manager.h Normal file
View File

@ -0,0 +1,58 @@
#pragma once
#include "moses/InputType.h"
#include "moses/BaseManager.h"
#include "KBestExtractor.h"
namespace Moses
{
namespace Syntax
{
// Common base class for Moses::Syntax managers.
//
// Implements the parts of the Moses::BaseManager output interface that are
// identical for every Syntax manager (n-best and unknown-word output) and
// stubs out the parts that none of them support.  A derived class supplies
// the search itself and must implement ExtractKBest().
class Manager : public BaseManager
{
public:
  // explicit: a manager should never be created by implicit conversion from
  // an InputType.
  explicit Manager(const InputType &);

  // Virtual functions from Moses::BaseManager that are implemented the same
  // way for all Syntax managers.
  void OutputNBest(OutputCollector *collector) const;
  void OutputUnknowns(OutputCollector *collector) const;

  // Virtual functions from Moses::BaseManager that are no-ops for all Syntax
  // managers.  The parameters are left unnamed because they are deliberately
  // ignored (this also avoids unused-parameter warnings).
  void OutputLatticeSamples(OutputCollector *) const {}
  void OutputAlignment(OutputCollector *) const {}
  void OutputDetailedTreeFragmentsTranslationReport(
    OutputCollector *) const {}
  void OutputWordGraph(OutputCollector *) const {}
  void OutputSearchGraph(OutputCollector *) const {}
  void OutputSearchGraphSLF() const {}
  void OutputSearchGraphHypergraph() const {}

  // Syntax-specific virtual functions that derived classes must implement.

  // Fill kBestList with up to k derivations.  OutputNBest() passes the
  // StaticData distinct-n-best setting as onlyDistinct.
  virtual void ExtractKBest(
    std::size_t k,
    std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList,
    bool onlyDistinct=false) const = 0;

protected:
  // Words printed by OutputUnknowns().  This class only reads the set; it is
  // protected so that derived classes can fill it.
  std::set<Word> m_oovs;

private:
  // Syntax-specific helper functions used to implement OutputNBest.

  // Format nBestList and write it to collector under translationId.
  void OutputNBestList(OutputCollector *collector,
                       const KBestExtractor::KBestVec &nBestList,
                       long translationId) const;

  // Recursively add the alignment points of d to retAlign; returns the
  // number of target words covered by d.
  std::size_t OutputAlignmentNBest(Alignments &retAlign,
                                   const KBestExtractor::Derivation &d,
                                   std::size_t startTarget) const;

  // Source-side size of the rule applied at the root of d.
  std::size_t CalcSourceSize(const KBestExtractor::Derivation &d) const;
};
} // Syntax
} // Moses

View File

@ -2,6 +2,7 @@
#include <iostream>
#include <sstream>
#include "moses/DecodeGraph.h"
#include "moses/StaticData.h"
#include "moses/Syntax/BoundedPriorityContainer.h"
@ -14,8 +15,8 @@
#include "moses/Syntax/SVertexRecombinationOrderer.h"
#include "moses/Syntax/SymbolEqualityPred.h"
#include "moses/Syntax/SymbolHasher.h"
#include "DerivationWriter.h"
#include "DerivationWriter.h"
#include "OovHandler.h"
#include "PChart.h"
#include "RuleTrie.h"
@ -30,7 +31,7 @@ namespace S2T
template<typename Parser>
Manager<Parser>::Manager(const InputType &source)
: BaseManager(source)
: Syntax::Manager(source)
, m_pchart(source.GetSize(), Parser::RequiresCompressedChart())
, m_schart(source.GetSize())
{
@ -44,7 +45,7 @@ void Manager<Parser>::InitializeCharts()
const Word &terminal = m_source.GetWord(i);
// PVertex
PVertex tmp(WordsRange(i,i), m_source.GetWord(i));
PVertex tmp(WordsRange(i,i), terminal);
PVertex &pvertex = m_pchart.AddVertex(tmp);
// SVertex
@ -262,6 +263,7 @@ const SHyperedge *Manager<Parser>::GetBestSHyperedge() const
}
assert(stacks.Size() == 1);
const std::vector<boost::shared_ptr<SVertex> > &stack = stacks.Begin()->second;
// TODO Throw exception if stack is empty? Or return 0?
return stack[0]->best;
}
@ -284,6 +286,7 @@ void Manager<Parser>::ExtractKBest(
}
assert(stacks.Size() == 1);
const std::vector<boost::shared_ptr<SVertex> > &stack = stacks.Begin()->second;
// TODO Throw exception if stack is empty? Or return 0?
KBestExtractor extractor;
@ -386,212 +389,17 @@ void Manager<Parser>::RecombineAndSort(const std::vector<SHyperedge*> &buffer,
}
// Extract the k-best derivations for the current input and pass them to
// OutputNBestList().  The list size and the distinct-only flag are read from
// StaticData.  Does nothing when collector is null.
template<typename Parser>
void Manager<Parser>::OutputNBest(OutputCollector *collector) const
{
  if (!collector) {
    return;
  }
  const StaticData &config = StaticData::Instance();
  const long sentenceId = m_source.GetTranslationId();
  Syntax::KBestExtractor::KBestVec derivations;
  ExtractKBest(config.GetNBestSize(), derivations, config.GetDistinctNBest());
  OutputNBestList(collector, derivations, sentenceId);
}
template<typename Parser>
void Manager<Parser>::OutputDetailedTranslationReport(OutputCollector *collector) const
void Manager<Parser>::OutputDetailedTranslationReport(
OutputCollector *collector) const
{
const SHyperedge *best = GetBestSHyperedge();
if (best == NULL || collector == NULL) {
return;
return;
}
long translationId = m_source.GetTranslationId();
std::ostringstream out;
Syntax::S2T::DerivationWriter::Write(*best, translationId, out);
DerivationWriter::Write(*best, translationId, out);
collector->Write(translationId, out.str());
}
// Write every word stored in m_oovs to collector as a single line, keyed by
// the translation ID of the current input.  Does nothing when collector is
// null.
template<typename Parser>
void Manager<Parser>::OutputUnknowns(OutputCollector *collector) const
{
  if (!collector) {
    return;
  }
  typedef std::set<Moses::Word>::const_iterator OovIterator;
  const long sentenceId = m_source.GetTranslationId();
  std::ostringstream line;
  for (OovIterator oov = m_oovs.begin(); oov != m_oovs.end(); ++oov) {
    line << *oov;
  }
  line << std::endl;
  collector->Write(sentenceId, line.str());
}
// Format an n-best list and write it to collector in one call.
//
// Each derivation produces one line of the form
//   id ||| surface ||| feature scores ||| total score [||| alignment] [||| tree]
// where the alignment and tree fields are only emitted when the
// corresponding StaticData options are switched on.
//
// collector must not be null.  Throws (via UTIL_THROW_IF2) if a derivation's
// yield has fewer than two words, since the sentence boundary markers are
// expected to be present.
template<typename Parser>
void Manager<Parser>::OutputNBestList(OutputCollector *collector,
                                      const Syntax::KBestExtractor::KBestVec &nBestList,
                                      long translationId) const
{
  // Check the precondition up front: collector is dereferenced a few lines
  // below, so asserting only at the end would never catch a null pointer.
  assert(collector);

  const StaticData &staticData = StaticData::Instance();

  const std::vector<Moses::FactorType> &outputFactorOrder =
    staticData.GetOutputFactorOrder();

  std::ostringstream out;

  if (collector->OutputIsCout()) {
    // Set precision only if we're writing the n-best list to cout.  This is to
    // preserve existing behaviour, but should probably be done either way.
    FixPrecision(out);
  }

  const bool includeWordAlignment = staticData.PrintAlignmentInfoInNbest();
  // Reuse the reference obtained above instead of a second Instance() call.
  const bool printNBestTrees = staticData.PrintNBestTrees();

  for (Syntax::KBestExtractor::KBestVec::const_iterator p = nBestList.begin();
       p != nBestList.end(); ++p) {
    const Syntax::KBestExtractor::Derivation &derivation = **p;

    // get the derivation's target-side yield
    Phrase outputPhrase = Syntax::KBestExtractor::GetOutputPhrase(derivation);

    // delete <s> and </s>
    UTIL_THROW_IF2(outputPhrase.GetSize() < 2,
                   "Output phrase should have contained at least 2 words (beginning and end-of-sentence)");
    outputPhrase.RemoveWord(0);
    outputPhrase.RemoveWord(outputPhrase.GetSize() - 1);

    // print the translation ID, surface factors, and scores
    out << translationId << " ||| ";
    OutputSurface(out, outputPhrase, outputFactorOrder, false);
    out << " ||| ";
    OutputAllFeatureScores(derivation.scoreBreakdown, out);
    out << " ||| " << derivation.score;

    // optionally, print word alignments
    if (includeWordAlignment) {
      out << " ||| ";
      Alignments align;
      OutputAlignmentNBest(align, derivation, 0);
      for (Alignments::const_iterator q = align.begin(); q != align.end();
           ++q) {
        out << q->first << "-" << q->second << " ";
      }
    }

    // optionally, print tree
    if (printNBestTrees) {
      TreePointer tree = Syntax::KBestExtractor::GetOutputTree(derivation);
      out << " ||| " << tree->GetString();
    }

    // The whole list is buffered in a string stream, so a plain newline is
    // enough; there is nothing to flush per entry.
    out << '\n';
  }

  collector->Write(translationId, out.str());
}
// Recursively collect the word alignment points of a derivation.
//
// retAlign     set that receives the (source position, target position)
//              pairs; points from subderivations are added to the same set
// derivation   derivation whose rule and subderivations are processed
// startTarget  position in the output at which this derivation's target
//              yield starts (the top-level caller passes 0)
//
// Returns the number of target words covered by the derivation: one per
// terminal of the rule plus the sizes returned for each subderivation.
//
// Throws (via UTIL_THROW_IF2) if the non-terminal alignment information does
// not match the subderivations, or if an alignment point is produced twice.
template<typename Parser>
size_t Manager<Parser>::OutputAlignmentNBest(
  Alignments &retAlign,
  const Syntax::KBestExtractor::Derivation &derivation,
  size_t startTarget) const
{
  const Syntax::SHyperedge &shyperedge = derivation.edge->shyperedge;

  size_t totalTargetSize = 0;
  size_t startSource = shyperedge.head->pvertex->span.GetStartPos();

  const TargetPhrase &tp = *(shyperedge.translation);

  size_t thisSourceSize = CalcSourceSize(derivation);

  // position of each terminal word in translation rule, irrespective of alignment
  // if non-term, number is undefined
  std::vector<size_t> sourceOffsets(thisSourceSize, 0);
  std::vector<size_t> targetOffsets(tp.GetSize(), 0);

  const AlignmentInfo &aiNonTerm = shyperedge.translation->GetAlignNonTerm();
  std::vector<size_t> sourceInd2pos = aiNonTerm.GetSourceIndex2PosMap();
  const AlignmentInfo::NonTermIndexMap &targetPos2SourceInd =
    aiNonTerm.GetNonTermIndexMap();

  UTIL_THROW_IF2(sourceInd2pos.size() != derivation.subderivations.size(),
                 "Error");

  for (size_t targetPos = 0; targetPos < tp.GetSize(); ++targetPos) {
    if (tp.GetWord(targetPos).IsNonTerminal()) {
      UTIL_THROW_IF2(targetPos >= targetPos2SourceInd.size(), "Error");
      size_t sourceInd = targetPos2SourceInd[targetPos];
      // sourceInd indexes both sourceInd2pos and the subderivations (which
      // were checked above to have equal size), so validate it once here
      // rather than reading past the end of either.
      UTIL_THROW_IF2(sourceInd >= sourceInd2pos.size(), "Error");
      size_t sourcePos = sourceInd2pos[sourceInd];

      const Moses::Syntax::KBestExtractor::Derivation &subderivation =
        *derivation.subderivations[sourceInd];

      // calc source size
      size_t sourceSize =
        subderivation.edge->head->svertex.pvertex->span.GetNumWordsCovered();
      sourceOffsets[sourcePos] = sourceSize;

      // calc target size.
      // Recursively look thru child hypos
      size_t currStartTarget = startTarget + totalTargetSize;
      size_t targetSize = OutputAlignmentNBest(retAlign, subderivation,
                          currStartTarget);
      targetOffsets[targetPos] = targetSize;

      totalTargetSize += targetSize;
    } else {
      // A terminal contributes exactly one target word.
      ++totalTargetSize;
    }
  }

  // convert position within translation rule to absolute position within
  // source sentence / output sentence
  // (ShiftOffsets is defined outside this file.)
  ShiftOffsets(sourceOffsets, startSource);
  ShiftOffsets(targetOffsets, startTarget);

  // get alignments from this hypo
  const AlignmentInfo &aiTerm = shyperedge.translation->GetAlignTerm();

  // add to output arg, offsetting by source & target
  AlignmentInfo::const_iterator iter;
  for (iter = aiTerm.begin(); iter != aiTerm.end(); ++iter) {
    const std::pair<size_t,size_t> &align = *iter;
    size_t relSource = align.first;
    size_t relTarget = align.second;
    size_t absSource = sourceOffsets[relSource];
    size_t absTarget = targetOffsets[relTarget];

    std::pair<size_t, size_t> alignPoint(absSource, absTarget);
    std::pair<Alignments::iterator, bool> ret = retAlign.insert(alignPoint);
    // Each alignment point is expected to be new.
    UTIL_THROW_IF2(!ret.second, "Error");
  }

  return totalTargetSize;
}
// Compute the source-side size of the rule applied at the root of d.
//
// Starts from the number of words covered by the head vertex's span and, for
// every tail vertex, subtracts that vertex's covered word count minus one, so
// that each tail vertex ends up counting as a single position.
template<typename Parser>
size_t Manager<Parser>::CalcSourceSize(const Syntax::KBestExtractor::Derivation &d) const
{
  using namespace Moses::Syntax;
  const Syntax::SHyperedge &edge = d.edge->shyperedge;
  size_t size = edge.head->pvertex->span.GetNumWordsCovered();
  for (size_t child = 0; child != edge.tail.size(); ++child) {
    const size_t covered =
      edge.tail[child]->pvertex->span.GetNumWordsCovered();
    size -= (covered - 1);
  }
  return size;
}
} // S2T

View File

@ -1,13 +1,15 @@
#pragma once
#include <set>
#include <vector>
#include <boost/shared_ptr.hpp>
#include "moses/InputType.h"
#include "moses/BaseManager.h"
#include "moses/Syntax/KBestExtractor.h"
#include "moses/Syntax/Manager.h"
#include "moses/Syntax/SVertexStack.h"
#include "moses/Word.h"
#include "OovHandler.h"
#include "ParserCallback.h"
@ -19,14 +21,13 @@ namespace Moses
namespace Syntax
{
class SDerivation;
struct SHyperedge;
namespace S2T
{
template<typename Parser>
class Manager : public BaseManager
class Manager : public Syntax::Manager
{
public:
Manager(const InputType &);
@ -41,25 +42,7 @@ class Manager : public BaseManager
std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList,
bool onlyDistinct=false) const;
const std::set<Word> &GetUnknownWords() const { return m_oovs; }
void OutputNBest(OutputCollector *collector) const;
void OutputLatticeSamples(OutputCollector *collector) const
{}
void OutputAlignment(OutputCollector *collector) const
{}
void OutputDetailedTranslationReport(OutputCollector *collector) const;
void OutputUnknowns(OutputCollector *collector) const;
void OutputDetailedTreeFragmentsTranslationReport(OutputCollector *collector) const
{}
void OutputWordGraph(OutputCollector *collector) const
{}
void OutputSearchGraph(OutputCollector *collector) const
{}
void OutputSearchGraphSLF() const
{}
void OutputSearchGraphHypergraph() const
{}
private:
void FindOovs(const PChart &, std::set<Word> &, std::size_t);
@ -74,19 +57,8 @@ class Manager : public BaseManager
PChart m_pchart;
SChart m_schart;
std::set<Word> m_oovs;
boost::shared_ptr<typename Parser::RuleTrie> m_oovRuleTrie;
std::vector<boost::shared_ptr<Parser> > m_parsers;
// output
void OutputNBestList(OutputCollector *collector,
const Moses::Syntax::KBestExtractor::KBestVec &nBestList,
long translationId) const;
std::size_t OutputAlignmentNBest(Alignments &retAlign,
const Moses::Syntax::KBestExtractor::Derivation &derivation,
std::size_t startTarget) const;
size_t CalcSourceSize(const Syntax::KBestExtractor::Derivation &d) const;
};
} // S2T

View File

@ -4,6 +4,7 @@
#include <boost/shared_ptr.hpp>
#include "moses/Phrase.h"
#include "moses/Syntax/RuleTableFF.h"
#include "moses/TargetPhrase.h"
#include "moses/Word.h"