It works!

This commit is contained in:
Kenneth Heafield 2012-10-12 14:25:39 +01:00
parent e75b51bb73
commit e33e0ffd61
2 changed files with 42 additions and 34 deletions

View File

@ -91,7 +91,7 @@ public:
Incremental::Manager manager(*m_source, system);
manager.ProcessSentence();
if (m_ioWrapper.ExposeSingleBest()) {
m_ioWrapper.ExposeSingleBest()->Write(lineNumber, manager.String());
m_ioWrapper.ExposeSingleBest()->Write(lineNumber, manager.String() + '\n');
}
return;
}

View File

@ -35,6 +35,46 @@ Manager::~Manager() {
system_.CleanUpAfterSentenceProcessing(source_);
}
namespace {
void ConstructString(const search::Final &final, std::ostringstream &stream) {
const TargetPhrase &phrase = static_cast<const Edge&>(final.From()).GetMoses();
size_t child = 0;
for (std::size_t i = 0; i < phrase.GetSize(); ++i) {
const Word &word = phrase.GetWord(i);
if (word.IsNonTerminal()) {
ConstructString(*final.Children()[child++], stream);
} else {
stream << word[0]->GetString() << ' ';
}
}
}
void BestString(const ChartCellLabelSet &labels, std::string &out) {
const search::Final *best = NULL;
for (ChartCellLabelSet::const_iterator i = labels.begin(); i != labels.end(); ++i) {
const search::Final *child = i->second.GetStack().incr->BestChild();
if (child && (!best || (child->Bound() > best->Bound()))) {
best = child;
}
}
if (!best) {
out.clear();
return;
}
std::ostringstream stream;
ConstructString(*best, stream);
out = stream.str();
CHECK(out.size() > 9);
// <s>
out.erase(0, 4);
// </s>
out.erase(out.size() - 5);
}
} // namespace
template <class Model> void Manager::LMCallback(const Model &model, const std::vector<lm::WordIndex> &words) {
const LanguageModel &abstract = **system_.GetLanguageModels().begin();
search::Weights weights(
@ -54,6 +94,7 @@ template <class Model> void Manager::LMCallback(const Model &model, const std::v
filler.Search(cells_.MutableBase(range).MutableTargetLabelSet());
}
}
BestString(cells_.GetBase(WordsRange(0, source_.GetSize() - 1)).GetTargetLabelSet(), output_);
}
template void Manager::LMCallback<lm::ngram::ProbingModel>(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words);
@ -63,43 +104,10 @@ template void Manager::LMCallback<lm::ngram::QuantTrieModel>(const lm::ngram::Qu
template void Manager::LMCallback<lm::ngram::ArrayTrieModel>(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::QuantArrayTrieModel>(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
namespace {
void ConstructString(const search::Final &final, std::ostringstream &stream) {
const TargetPhrase &phrase = static_cast<const Edge&>(final.From()).GetMoses();
size_t child = 0;
for (std::size_t i = 0; i < phrase.GetSize(); ++i) {
const Word &word = phrase.GetWord(i);
if (word.IsNonTerminal()) {
ConstructString(*final.Children()[child++], stream);
} else {
stream << word[0]->GetString() << ' ';
}
}
}
} // namespace
void Manager::ProcessSentence() {
const LMList &lms = system_.GetLanguageModels();
UTIL_THROW_IF(lms.size() != 1, util::Exception, "Incremental search only supports one language model.");
(*lms.begin())->IncrementalCallback(*this);
const ChartCellLabelSet &labels = cells_.GetBase(WordsRange(0, source_.GetSize() - 1)).GetTargetLabelSet();
const search::Final *best = NULL;
for (ChartCellLabelSet::const_iterator i = labels.begin(); i != labels.end(); ++i) {
const search::Final *child = i->second.GetStack().incr->BestChild();
if (child && (!best || (child->Bound() > best->Bound()))) {
best = child;
}
}
if (!best) {
output_.clear();
return;
}
std::ostringstream stream;
ConstructString(*best, stream);
output_ = stream.str();
}
} // namespace Incremental