add OutputLatticeSamples() to BaseManager. Move OutputAlignment() to ChartManager

This commit is contained in:
Hieu Hoang 2014-12-03 12:05:35 +00:00
parent 5bbd30ec12
commit fb25616bdd
7 changed files with 128 additions and 3 deletions

View File

@ -27,6 +27,8 @@ protected:
public:
// outputs
virtual void OutputNBest(OutputCollector *collector) const = 0;
// Pure virtual: every concrete manager has to provide its own implementation.
// NOTE(review): presumably emits lattice samples for the current input through
// `collector` - confirm against the concrete managers.
virtual void OutputLatticeSamples(OutputCollector *collector) const = 0;
};

View File

@ -470,4 +470,106 @@ size_t ChartManager::OutputAlignmentNBest(
return totalTargetSize;
}
/**
 * Write the word alignment of the best hypothesis as a single line.
 *
 * Each alignment point is printed as "<first>-<second> " and the line is
 * terminated with a newline. If there is no best hypothesis, only the newline
 * is written. The line is handed to the collector under the translation id of
 * the current source.
 *
 * \param collector destination for the output; when NULL nothing is done
 */
void ChartManager::OutputAlignment(OutputCollector *collector) const
{
  // without a collector there is nowhere to write to
  if (!collector) {
    return;
  }

  ostringstream alignStream;

  if (const ChartHypothesis *best = GetBestHypothesis()) {
    // gather every alignment point of the derivation, starting at target position 0
    Alignments points;
    OutputAlignment(points, best, 0);

    for (Alignments::const_iterator it = points.begin(); it != points.end(); ++it) {
      alignStream << it->first << "-" << it->second << " ";
    }
  }

  // a line is always emitted, even when it carries no alignment points
  alignStream << endl;
  collector->Write(m_source.GetTranslationId(), alignStream.str());
}
/**
 * Recursively collect the terminal alignment points of a derivation.
 *
 * The target phrase of \p hypo is scanned left to right. A terminal counts as
 * one target word. A non-terminal is expanded by recursing into the previous
 * hypothesis it is linked to; the number of source words that child covers and
 * the number of target words it produced are recorded at the non-terminal's
 * position. ShiftOffsets() then converts those per-rule tables into positions
 * relative to the sentence, and every terminal alignment point of the rule is
 * inserted into \p retAlign using the converted positions.
 *
 * \param retAlign    output container; alignment points are inserted into it
 * \param hypo        hypothesis to process (dereferenced, must not be NULL)
 * \param startTarget target position at which the output of \p hypo begins
 * \return number of target words produced by \p hypo including its children
 *
 * Raises through UTIL_THROW_IF2 when the alignment data is inconsistent:
 * mismatching non-terminal counts, an index outside the tables it is used
 * with, or an alignment point that is already present in \p retAlign.
 */
size_t ChartManager::OutputAlignment(Alignments &retAlign,
                                     const Moses::ChartHypothesis *hypo,
                                     size_t startTarget) const
{
  size_t totalTargetSize = 0;
  size_t startSource = hypo->GetCurrSourceRange().GetStartPos();

  const TargetPhrase &tp = hypo->GetCurrTargetPhrase();

  size_t thisSourceSize = CalcSourceSize(hypo);

  // position of each terminal word in translation rule, irrespective of alignment
  // if non-term, number is undefined
  vector<size_t> sourceOffsets(thisSourceSize, 0);
  vector<size_t> targetOffsets(tp.GetSize(), 0);

  const vector<const ChartHypothesis*> &prevHypos = hypo->GetPrevHypos();

  const AlignmentInfo &aiNonTerm = tp.GetAlignNonTerm();
  // Bound by const reference: no extra copy if the accessor returns a
  // reference, and a returned temporary has its lifetime extended otherwise.
  const vector<size_t> &sourceInd2pos = aiNonTerm.GetSourceIndex2PosMap();
  const AlignmentInfo::NonTermIndexMap &targetPos2SourceInd = aiNonTerm.GetNonTermIndexMap();

  UTIL_THROW_IF2(sourceInd2pos.size() != prevHypos.size(), "Error");

  for (size_t targetPos = 0; targetPos < tp.GetSize(); ++targetPos) {
    if (tp.GetWord(targetPos).IsNonTerminal()) {
      UTIL_THROW_IF2(targetPos >= targetPos2SourceInd.size(), "Error");
      size_t sourceInd = targetPos2SourceInd[targetPos];
      // sourceInd2pos and prevHypos have equal size (checked above), so this
      // single test guards both lookups below
      UTIL_THROW_IF2(sourceInd >= prevHypos.size(),
                     "Non-terminal source index out of range");
      size_t sourcePos = sourceInd2pos[sourceInd];
      UTIL_THROW_IF2(sourcePos >= sourceOffsets.size(),
                     "Non-terminal source position out of range");

      const ChartHypothesis *prevHypo = prevHypos[sourceInd];

      // calc source size
      size_t sourceSize = prevHypo->GetCurrSourceRange().GetNumWordsCovered();
      sourceOffsets[sourcePos] = sourceSize;

      // calc target size.
      // Recursively look thru child hypos
      size_t currStartTarget = startTarget + totalTargetSize;
      size_t targetSize = OutputAlignment(retAlign, prevHypo, currStartTarget);
      targetOffsets[targetPos] = targetSize;

      totalTargetSize += targetSize;
    } else {
      // a terminal contributes exactly one target word
      ++totalTargetSize;
    }
  }

  // convert position within translation rule to absolute position within
  // source sentence / output sentence
  ShiftOffsets(sourceOffsets, startSource);
  ShiftOffsets(targetOffsets, startTarget);

  // get alignments from this hypo
  const AlignmentInfo &aiTerm = tp.GetAlignTerm();

  // add to output arg, offsetting by source & target
  AlignmentInfo::const_iterator iter;
  for (iter = aiTerm.begin(); iter != aiTerm.end(); ++iter) {
    const std::pair<size_t,size_t> &align = *iter;
    size_t relSource = align.first;
    size_t relTarget = align.second;
    UTIL_THROW_IF2(relSource >= sourceOffsets.size(),
                   "Terminal alignment source position out of range");
    UTIL_THROW_IF2(relTarget >= targetOffsets.size(),
                   "Terminal alignment target position out of range");
    size_t absSource = sourceOffsets[relSource];
    size_t absTarget = targetOffsets[relTarget];

    pair<size_t, size_t> alignPoint(absSource, absTarget);
    pair<Alignments::iterator, bool> ret = retAlign.insert(alignPoint);
    // the same point must never be produced twice
    UTIL_THROW_IF2(!ret.second, "Error");
  }

  return totalTargetSize;
}
} // namespace Moses

View File

@ -72,6 +72,9 @@ private:
size_t OutputAlignmentNBest(Alignments &retAlign,
const Moses::ChartKBestExtractor::Derivation &derivation,
size_t startTarget) const;
// Recursive helper: inserts the alignment points of `hypo` and of the
// hypotheses it was built from into `retAlign`. `startTarget` is the target
// position at which the output of `hypo` begins. Returns the number of target
// words produced by `hypo` including its children.
size_t OutputAlignment(Alignments &retAlign,
const Moses::ChartHypothesis *hypo,
size_t startTarget) const;
template <class T>
void ShiftOffsets(std::vector<T> &offsets, T shift) const
@ -137,6 +140,10 @@ public:
// outputs
void OutputNBest(OutputCollector *collector) const;
// Empty implementation: `collector` is ignored and nothing is written.
void OutputLatticeSamples(OutputCollector *collector) const
{}
void OutputAlignment(OutputCollector *collector) const;
};
}

View File

@ -60,6 +60,9 @@ private:
// outputs
void OutputNBestList(OutputCollector *collector, const std::vector<search::Applied> &nbest, long translationId) const;
// Empty implementation: `collector` is ignored and nothing is written.
void OutputLatticeSamples(OutputCollector *collector) const
{}
};
// Just get the phrase.

View File

@ -1678,4 +1678,12 @@ void Manager::OutputLatticeSamples(OutputCollector *collector) const
}
// Empty stub: nothing is written to `collector`.
// NOTE(review): presumably a placeholder until alignment output is moved into
// the manager - confirm before relying on it.
void Manager::OutputAlignment(OutputCollector *collector) const
{
}
// Empty stub: `collector`, `lineNo` and `hypo` are all ignored.
// NOTE(review): presumably meant to write the alignment of `hypo` for input
// line `lineNo` once implemented - confirm.
void Manager::OutputAlignment(OutputCollector* collector, size_t lineNo , const Hypothesis *hypo) const
{
}
}

View File

@ -140,6 +140,7 @@ protected:
void OutputInput(std::ostream& os, const Hypothesis* hypo) const;
void OutputInput(std::vector<const Phrase*>& map, const Hypothesis* hypo) const;
std::map<size_t, const Factor*> GetPlaceholders(const Hypothesis &hypo, FactorType placeholderFactor) const;
void OutputAlignment(OutputCollector* collector, size_t lineNo , const Hypothesis *hypo) const;
public:
InputType const& m_source; /**< source sentence to be translated */
@ -186,7 +187,9 @@ public:
// outputs
void OutputNBest(OutputCollector *collector) const;
void OutputAlignment(OutputCollector *collector) const;
void OutputLatticeSamples(OutputCollector *collector) const;
};
}

View File

@ -194,6 +194,8 @@ void TranslationTask::RunPb()
}
m_ioWrapper.OutputAlignment(m_ioWrapper.GetAlignmentInfoCollector(), m_source->GetTranslationId(), bestHypo);
manager.OutputAlignment(m_ioWrapper.GetAlignmentInfoCollector());
IFVERBOSE(1) {
debug << "BEST TRANSLATION: " << *bestHypo << endl;
}
@ -372,9 +374,7 @@ void TranslationTask::RunChart()
PrintUserTime("Best Hypothesis Generation Time:");
}
if (!staticData.GetAlignmentOutputFile().empty()) {
m_ioWrapper.OutputAlignment(translationId, bestHypo);
}
manager.OutputAlignment(m_ioWrapper.GetAlignmentInfoCollector());
if (staticData.IsDetailedTranslationReportingEnabled()) {
const Sentence &sentence = dynamic_cast<const Sentence &>(*m_source);