add Decode to API framework

This commit is contained in:
Hieu Hoang 2014-12-05 17:59:53 +00:00
parent 0d8e20980e
commit 23ca29a2ea
13 changed files with 71 additions and 58 deletions

View File

@ -283,7 +283,7 @@ public:
stringstream in(source + "\n");
tinput.Read(in,inputFactorOrder);
ChartManager manager(tinput);
manager.ProcessSentence();
manager.Decode();
const ChartHypothesis *hypo = manager.GetBestHypothesis();
outputChartHypo(out,hypo);
if (addGraphInfo) {
@ -302,7 +302,7 @@ public:
stringstream in(source + "\n");
sentence.Read(in,inputFactorOrder);
Manager manager(sentence, staticData.GetSearchAlgorithm());
manager.ProcessSentence();
manager.Decode();
const Hypothesis* hypo = manager.GetBestHypothesis();
vector<xmlrpc_c::value> alignInfo;

View File

@ -144,7 +144,7 @@ vector< vector<const Word*> > MosesDecoder::runDecoder(const std::string& source
{
// run the decoder
m_manager = new Moses::Manager(*m_sentence, search);
m_manager->ProcessSentence();
m_manager->Decode();
TrellisPathList nBestList;
m_manager->CalcNBest(nBestSize, nBestList, distinct);
@ -221,7 +221,7 @@ vector< vector<const Word*> > MosesDecoder::runChartDecoder(const std::string& s
{
// run the decoder
m_chartManager = new ChartManager(*m_sentence);
m_chartManager->ProcessSentence();
m_chartManager->Decode();
ChartKBestExtractor::KBestVec nBestList;
m_chartManager->CalcNBest(nBestSize, nBestList, distinct);

View File

@ -182,7 +182,7 @@ int main(int argc, char* argv[])
source->SetTranslationId(lineCount);
Manager manager(*source, staticData.GetSearchAlgorithm());
manager.ProcessSentence();
manager.Decode();
TrellisPathList nBestList;
manager.CalcNBest(nBestSize, nBestList,true);
//grid search

View File

@ -46,6 +46,7 @@ protected:
}
public:
virtual void Decode() = 0;
// outputs
virtual void OutputNBest(OutputCollector *collector) const = 0;
virtual void OutputLatticeSamples(OutputCollector *collector) const = 0;

View File

@ -64,7 +64,7 @@ ChartManager::~ChartManager()
}
//! decode the sentence. This contains the main laps. Basically, the CKY++ algorithm
void ChartManager::ProcessSentence()
void ChartManager::Decode()
{
VERBOSE(1,"Translating: " << m_source << endl);
@ -597,6 +597,18 @@ void ChartManager::OutputDetailedTranslationReport(
OutputTranslationOptions(out, applicationContext, hypo, sentence, translationId);
collector->Write(translationId, out.str());
//DIMw
const StaticData &staticData = StaticData::Instance();
if (staticData.IsDetailedAllTranslationReportingEnabled()) {
const Sentence &sentence = dynamic_cast<const Sentence &>(m_source);
size_t nBestSize = staticData.GetNBestSize();
std::vector<boost::shared_ptr<ChartKBestExtractor::Derivation> > nBestList;
CalcNBest(nBestSize, nBestList, staticData.GetDistinctNBest());
OutputDetailedAllTranslationReport(collector, nBestList, sentence, translationId);
}
}
void ChartManager::OutputTranslationOptions(std::ostream &out,
@ -753,4 +765,36 @@ void ChartManager::OutputSearchGraph(OutputCollector *collector) const
}
}
//DIMw
void ChartManager::OutputDetailedAllTranslationReport(
OutputCollector *collector,
const std::vector<boost::shared_ptr<Moses::ChartKBestExtractor::Derivation> > &nBestList,
const Sentence &sentence,
long translationId) const
{
std::ostringstream out;
ApplicationContext applicationContext;
const ChartCellCollection& cells = GetChartCellCollection();
size_t size = GetSource().GetSize();
for (size_t width = 1; width <= size; ++width) {
for (size_t startPos = 0; startPos <= size-width; ++startPos) {
size_t endPos = startPos + width - 1;
WordsRange range(startPos, endPos);
const ChartCell& cell = cells.Get(range);
const HypoList* hyps = cell.GetAllSortedHypotheses();
out << "Chart Cell [" << startPos << ".." << endPos << "]" << endl;
HypoList::const_iterator iter;
size_t c = 1;
for (iter = hyps->begin(); iter != hyps->end(); ++iter) {
out << "----------------Item " << c++ << " ---------------------"
<< endl;
OutputTranslationOptions(out, applicationContext, *iter,
sentence, translationId);
}
}
}
collector->Write(translationId, out.str());
}
} // namespace Moses

View File

@ -96,11 +96,16 @@ private:
const ChartHypothesis *hypo,
const Sentence &sentence,
long translationId) const;
void OutputDetailedAllTranslationReport(
OutputCollector *collector,
const std::vector<boost::shared_ptr<Moses::ChartKBestExtractor::Derivation> > &nBestList,
const Sentence &sentence,
long translationId) const;
public:
ChartManager(InputType const& source);
~ChartManager();
void ProcessSentence();
void Decode();
void AddXmlChartOptions();
const ChartHypothesis *GetBestHypothesis() const;
void CalcNBest(size_t n, std::vector<boost::shared_ptr<ChartKBestExtractor::Derivation> > &nBestList, bool onlyDistinct=false) const;

View File

@ -533,39 +533,6 @@ void IOWrapper::OutputSurface(std::ostream &out, const Phrase &phrase, const std
//DIMw
void IOWrapper::OutputDetailedAllTranslationReport(
const std::vector<boost::shared_ptr<Moses::ChartKBestExtractor::Derivation> > &nBestList,
const ChartManager &manager,
const Sentence &sentence,
long translationId)
{
std::ostringstream out;
ApplicationContext applicationContext;
const ChartCellCollection& cells = manager.GetChartCellCollection();
size_t size = manager.GetSource().GetSize();
for (size_t width = 1; width <= size; ++width) {
for (size_t startPos = 0; startPos <= size-width; ++startPos) {
size_t endPos = startPos + width - 1;
WordsRange range(startPos, endPos);
const ChartCell& cell = cells.Get(range);
const HypoList* hyps = cell.GetAllSortedHypotheses();
out << "Chart Cell [" << startPos << ".." << endPos << "]" << endl;
HypoList::const_iterator iter;
size_t c = 1;
for (iter = hyps->begin(); iter != hyps->end(); ++iter) {
out << "----------------Item " << c++ << " ---------------------"
<< endl;
OutputTranslationOptions(out, applicationContext, *iter,
sentence, translationId);
}
}
}
UTIL_THROW_IF2(m_detailedTranslationCollector == NULL,
"No output file for details specified");
m_detailedTranslationCollector->Write(translationId, out.str());
}
//////////////////////////////////////////////////////////////////////////
/***

View File

@ -177,8 +177,6 @@ public:
void OutputBestNone(long translationId);
void OutputDetailedAllTranslationReport(const std::vector<boost::shared_ptr<Moses::ChartKBestExtractor::Derivation> > &nBestList, const Moses::ChartManager &manager, const Moses::Sentence &sentence, long translationId);
// phrase-based
void OutputBestSurface(std::ostream &out, const Moses::Hypothesis *hypo, const std::vector<Moses::FactorType> &outputFactorOrder, char reportSegmentation, bool reportAllFactors);
void OutputLatticeMBRNBest(std::ostream& out, const std::vector<LatticeMBRSolution>& solutions,long translationId);

View File

@ -273,9 +273,13 @@ template void Manager::LMCallback<lm::ngram::QuantTrieModel>(const lm::ngram::Qu
template void Manager::LMCallback<lm::ngram::ArrayTrieModel>(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::QuantArrayTrieModel>(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
const std::vector<search::Applied> &Manager::ProcessSentence()
void Manager::Decode()
{
LanguageModel::GetFirstLM().IncrementalCallback(*this);
}
const std::vector<search::Applied> &Manager::GetNBest() const
{
return *completed_nbest_;
}

View File

@ -30,7 +30,9 @@ public:
template <class Model> void LMCallback(const Model &model, const std::vector<lm::WordIndex> &words);
const std::vector<search::Applied> &ProcessSentence();
void Decode();
const std::vector<search::Applied> &GetNBest() const;
// Call to get the same value as ProcessSentence returned.
const std::vector<search::Applied> &Completed() const {

View File

@ -79,7 +79,7 @@ Manager::~Manager()
* Main decoder loop that translates a sentence by expanding
* hypotheses stack by stack, until the end of the sentence.
*/
void Manager::ProcessSentence()
void Manager::Decode()
{
// initialize statistics
ResetSentenceStats(m_source);

View File

@ -151,7 +151,7 @@ public:
~Manager();
const TranslationOptionCollection* getSntTranslationOptions();
void ProcessSentence();
void Decode();
const Hypothesis *GetBestHypothesis() const;
const Hypothesis *GetActualBestHypothesis() const;
void CalcNBest(size_t count, TrellisPathList &ret,bool onlyDistinct=0) const;

View File

@ -83,7 +83,7 @@ void TranslationTask::RunPb()
initTime.start();
Manager manager(*m_source,staticData.GetSearchAlgorithm());
VERBOSE(1, "Line " << m_source->GetTranslationId() << ": Initialize search took " << initTime << " seconds total" << endl);
manager.ProcessSentence();
manager.Decode();
// we are done with search, let's look what we got
Timer additionalReportingTime;
@ -301,7 +301,8 @@ void TranslationTask::RunChart()
if (staticData.GetSearchAlgorithm() == ChartIncremental) {
Incremental::Manager manager(*m_source);
const std::vector<search::Applied> &nbest = manager.ProcessSentence();
manager.Decode();
const std::vector<search::Applied> &nbest = manager.GetNBest();
if (!nbest.empty()) {
m_ioWrapper.OutputBestHypo(nbest[0], translationId);
@ -318,7 +319,7 @@ void TranslationTask::RunChart()
}
ChartManager manager(*m_source);
manager.ProcessSentence();
manager.Decode();
UTIL_THROW_IF2(staticData.UseMBR(), "Cannot use MBR");
@ -340,15 +341,6 @@ void TranslationTask::RunChart()
manager.OutputDetailedTreeFragmentsTranslationReport(m_ioWrapper.GetDetailTreeFragmentsOutputCollector());
manager.OutputUnknowns(m_ioWrapper.GetUnknownsCollector());
//DIMw
if (staticData.IsDetailedAllTranslationReportingEnabled()) {
const Sentence &sentence = dynamic_cast<const Sentence &>(*m_source);
size_t nBestSize = staticData.GetNBestSize();
std::vector<boost::shared_ptr<ChartKBestExtractor::Derivation> > nBestList;
manager.CalcNBest(nBestSize, nBestList, staticData.GetDistinctNBest());
m_ioWrapper.OutputDetailedAllTranslationReport(nBestList, manager, sentence, translationId);
}
// n-best
manager.OutputNBest(m_ioWrapper.GetNBestOutputCollector());