mosesdecoder/moses/Incremental.cpp

336 lines
11 KiB
C++
Raw Normal View History

2013-05-27 18:54:50 +04:00
#include <stdexcept>
#include "moses/Incremental.h"
2012-11-12 23:56:18 +04:00
#include "moses/ChartCell.h"
2012-11-15 22:04:07 +04:00
#include "moses/ChartParserCallback.h"
#include "moses/FeatureVector.h"
2012-11-12 23:56:18 +04:00
#include "moses/StaticData.h"
#include "moses/Util.h"
2013-05-27 18:54:50 +04:00
#include "moses/LM/Base.h"
2012-11-15 22:04:07 +04:00
#include "lm/model.hh"
#include "search/applied.hh"
#include "search/config.hh"
2012-11-15 22:04:07 +04:00
#include "search/context.hh"
#include "search/edge_generator.hh"
#include "search/rule.hh"
#include "search/vertex_generator.hh"
2012-10-12 17:38:07 +04:00
#include <boost/lexical_cast.hpp>
2013-05-29 21:16:15 +04:00
namespace Moses
{
namespace Incremental
{
namespace
{
2012-11-15 22:04:07 +04:00
// This is called by EdgeGenerator. Route hypotheses to separate vertices for
2013-05-29 21:16:15 +04:00
// each left hand side label, populating ChartCellLabelSet out.
template <class Best> class HypothesisCallback
{
private:
typedef search::VertexGenerator<Best> Gen;
public:
HypothesisCallback(search::ContextBase &context, Best &best, ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool)
: context_(context), best_(best), out_(out), vertex_pool_(vertex_pool) {}
void NewHypothesis(search::PartialEdge partial) {
// Get the LHS, look it up in the output ChartCellLabel, and upcast it.
// It's not part of the union because it would have been ugly to expose template types in ChartCellLabel.
ChartCellLabel::Stack &stack = out_.FindOrInsert(static_cast<const TargetPhrase *>(partial.GetNote().vp)->GetTargetLHS());
Gen *entry = static_cast<Gen*>(stack.incr_generator);
if (!entry) {
entry = generator_pool_.construct(boost::ref(context_), boost::ref(*vertex_pool_.construct()), boost::ref(best_));
2013-05-29 21:16:15 +04:00
stack.incr_generator = entry;
2012-11-15 22:04:07 +04:00
}
2013-05-29 21:16:15 +04:00
entry->NewHypothesis(partial);
}
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
void FinishedSearch() {
for (ChartCellLabelSet::iterator i(out_.mutable_begin()); i != out_.mutable_end(); ++i) {
if ((*i) == NULL) {
continue;
}
ChartCellLabel::Stack &stack = (*i)->MutableStack();
2013-05-29 21:16:15 +04:00
Gen *gen = static_cast<Gen*>(stack.incr_generator);
gen->FinishedSearch();
stack.incr = &gen->Generating();
2012-11-15 22:04:07 +04:00
}
2013-05-29 21:16:15 +04:00
}
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
private:
search::ContextBase &context_;
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
Best &best_;
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
ChartCellLabelSet &out_;
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
boost::object_pool<search::Vertex> &vertex_pool_;
boost::object_pool<Gen> generator_pool_;
2012-11-15 22:04:07 +04:00
};
// This is called by the moses parser to collect hypotheses. It converts to my
2013-05-29 21:16:15 +04:00
// edges (search::PartialEdge).
template <class Model> class Fill : public ChartParserCallback
{
public:
Fill(search::Context<Model> &context, const std::vector<lm::WordIndex> &vocab_mapping, search::Score oov_weight)
: context_(context), vocab_mapping_(vocab_mapping), oov_weight_(oov_weight) {}
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
void Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &ignored);
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
void AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &waste_memory, const WordsRange &range);
2012-11-15 22:04:07 +04:00
2014-03-26 15:23:23 +04:00
float GetBestScore(const ChartCellLabel *chartCell) const;
2013-05-29 21:16:15 +04:00
bool Empty() const {
return edges_.Empty();
}
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
template <class Best> void Search(Best &best, ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool) {
HypothesisCallback<Best> callback(context_, best, out, vertex_pool);
edges_.Search(context_, callback);
}
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
// Root: everything into one vertex.
template <class Best> search::History RootSearch(Best &best) {
search::Vertex vertex;
search::RootVertexGenerator<Best> gen(vertex, best);
edges_.Search(context_, gen);
return vertex.BestChild();
}
2013-08-16 00:14:04 +04:00
void Evaluate(const InputType &input, const InputPath &inputPath) {
// TODO for input lattice
}
2013-05-29 21:16:15 +04:00
private:
lm::WordIndex Convert(const Word &word) const;
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
search::Context<Model> &context_;
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
const std::vector<lm::WordIndex> &vocab_mapping_;
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
search::EdgeGenerator edges_;
2012-11-15 22:04:07 +04:00
2013-05-29 21:16:15 +04:00
const search::Score oov_weight_;
2012-11-15 22:04:07 +04:00
};
// Convert every target phrase in `targets` into a search::PartialEdge and
// queue it on edges_.  `nts` supplies the chart labels of the non-terminal
// children; `range` is recorded on each edge.
template <class Model> void Fill<Model>::Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &range)
{
  // One partial vertex per child, plus the accumulated GetBestScore of the children.
  std::vector<search::PartialVertex> child_vertices;
  child_vertices.reserve(nts.size());
  float below_score = 0.0;
  for (StackVec::const_iterator child(nts.begin()); child != nts.end(); ++child) {
    child_vertices.push_back((*child)->GetStack().incr->RootAlternate());
    below_score += (*child)->GetBestScore(this);
  }

  // Reused across phrases; cleared at the top of each iteration.
  std::vector<lm::WordIndex> lm_words;
  for (TargetPhraseCollection::const_iterator target(targets.begin()); target != targets.end(); ++target) {
    lm_words.clear();
    const TargetPhrase &phrase = **target;
    const AlignmentInfo::NonTermIndexMap &align = phrase.GetAlignNonTerm().GetNonTermIndexMap();
    search::PartialEdge edge(edges_.AllocateEdge(nts.size()));

    // Walk the target side: terminals become LM word ids, non-terminals
    // become a placeholder and claim the next NT slot of the edge.
    search::PartialVertex *next_nt = edge.NT();
    for (size_t pos = 0; pos < phrase.GetSize(); ++pos) {
      const Word &word = phrase.GetWord(pos);
      if (!word.IsNonTerminal()) {
        lm_words.push_back(Convert(word));
        continue;
      }
      *next_nt = child_vertices[align[pos]];
      ++next_nt;
      lm_words.push_back(search::kNonTerminal);
    }

    edge.SetScore(phrase.GetFutureScore() + below_score);
    // prob and oov were already accounted for, so the ScoreRule result is discarded.
    search::ScoreRule(context_.LanguageModel(), lm_words, edge.Between());

    search::Note note;
    note.vp = &phrase;
    edge.SetNote(note);
    edge.SetRange(range);

    edges_.AddEdge(edge);
  }
}
// Queue an edge for an out-of-vocabulary phrase.  The phrase must hold zero
// or one word; a longer phrase trips UTIL_THROW_IF2.  The list argument is unused.
template <class Model> void Fill<Model>::AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &, const WordsRange &range)
{
  UTIL_THROW_IF2(phrase.GetSize() > 1,
                 "OOV target phrase should be 0 or 1 word in length");

  std::vector<lm::WordIndex> lm_words;
  if (phrase.GetSize() != 0) {
    lm_words.push_back(Convert(phrase.GetWord(0)));
  }

  search::PartialEdge edge(edges_.AllocateEdge(0));
  // Appears to be a bug that FutureScore does not already include language model,
  // so the LM probability and OOV penalty are added here.
  search::ScoreRuleRet rule_score(search::ScoreRule(context_.LanguageModel(), lm_words, edge.Between()));
  edge.SetScore(
    phrase.GetFutureScore()
    + rule_score.prob * context_.LMWeight()
    + static_cast<search::Score>(rule_score.oov) * oov_weight_);

  search::Note note;
  note.vp = &phrase;
  edge.SetNote(note);
  edge.SetRange(range);

  edges_.AddEdge(edge);
}
2014-03-26 15:23:23 +04:00
// For pruning: report the bound of the root alternate of the label's stack.
// Throws (via UTIL_THROW_IF2) when that root vertex is empty.
template <class Model> float Fill<Model>::GetBestScore(const ChartCellLabel *chartCell) const
{
  search::PartialVertex root = chartCell->GetStack().incr->RootAlternate();
  UTIL_THROW_IF2(root.Empty(), "hypothesis with empty stack");
  return root.Bound();
}
2013-05-29 21:16:15 +04:00
// Translate a Moses word into a language model vocabulary index using
// factor 0.  Factor ids beyond the end of the mapping table yield index 0.
// TODO: factors (but chart doesn't seem to support factors anyway).
template <class Model> lm::WordIndex Fill<Model>::Convert(const Word &word) const
{
  const std::size_t factor_id = word.GetFactor(0)->GetId();
  if (factor_id >= vocab_mapping_.size()) {
    return 0;
  }
  return vocab_mapping_[factor_id];
}
// Factory functor that heap-allocates a ChartCellBase for the given span.
// NOTE(review): ownership of the returned pointer is presumably taken by the
// cell collection that invokes the factory -- confirm against its destructor.
struct ChartCellBaseFactory {
  ChartCellBase *operator()(size_t startPos, size_t endPos) const {
    ChartCellBase *cell = new ChartCellBase(startPos, endPos);
    return cell;
  }
};
2012-11-15 22:04:07 +04:00
} // namespace
2013-05-11 17:13:26 +04:00
// Set up the chart cells and the parser over `source`, and configure the
// n-best collector with the n-best size from StaticData.
Manager::Manager(const InputType &source)
  : source_(source),
    cells_(source, ChartCellBaseFactory()),
    parser_(source, cells_),
    n_best_(search::NBestConfig(StaticData::Instance().GetNBestSize()))
{
}
2013-05-29 21:16:15 +04:00
// Nothing to release explicitly; members clean up after themselves.
Manager::~Manager() {}
2013-05-29 21:16:15 +04:00
// Run the incremental search over every span of the source sentence, from
// the rightmost start position to the leftmost, and finish with a root
// search over the full span.
//
// model: language model used to build the search context.
// words: table passed to each Fill as its vocabulary mapping.
// out:   receiver of the best hypotheses (single-best or n-best collector).
// Returns the best child of the root vertex as produced by RootSearch.
//
// NOTE(review): assumes source_.GetSize() >= 1; with an empty source the
// final range computation `size - 1` would wrap -- confirm callers never
// pass an empty input.
template <class Model, class Best> search::History Manager::PopulateBest(const Model &model, const std::vector<lm::WordIndex> &words, Best &out)
{
  const LanguageModel &abstract = LanguageModel::GetFirstLM();
  const float oov_weight = abstract.OOVFeatureEnabled() ? abstract.GetOOVWeight() : 0.0;
  const StaticData &data = StaticData::Instance();
  // The LM weight is multiplied by ln(10) before being handed to the search config.
  search::Config config(abstract.GetWeight() * M_LN10, data.GetCubePruningPopLimit(), search::NBestConfig(data.GetNBestSize()));
  search::Context<Model> context(config, model);

  const size_t size = source_.GetSize();
  boost::object_pool<search::Vertex> vertex_pool(std::max<size_t>(size * size / 2, 32));

  // Reverse iteration with an unsigned index: startPos takes the values
  // size-1, size-2, ..., 0 without narrowing size_t to int or mixing signed
  // and unsigned operands in the comparisons below.
  for (size_t startPos = size; startPos-- > 0; ) {
    for (size_t width = 1; width <= size - startPos; ++width) {
      // full range uses RootSearch
      if (startPos == 0 && startPos + width == size) {
        break;
      }
      WordsRange range(startPos, startPos + width - 1);
      Fill<Model> filler(context, words, oov_weight);
      parser_.Create(range, filler);
      filler.Search(out, cells_.MutableBase(range).MutableTargetLabelSet(), vertex_pool);
    }
  }

  WordsRange range(0, size - 1);
  Fill<Model> filler(context, words, oov_weight);
  parser_.Create(range, filler);
  return filler.RootSearch(out);
}
2013-05-29 21:16:15 +04:00
template <class Model> void Manager::LMCallback(const Model &model, const std::vector<lm::WordIndex> &words)
{
std::size_t nbest = StaticData::Instance().GetNBestSize();
if (nbest <= 1) {
search::History ret = PopulateBest(model, words, single_best_);
if (ret) {
backing_for_single_.resize(1);
backing_for_single_[0] = search::Applied(ret);
} else {
backing_for_single_.clear();
}
completed_nbest_ = &backing_for_single_;
} else {
search::History ret = PopulateBest(model, words, n_best_);
if (ret) {
completed_nbest_ = &n_best_.Extract(ret);
} else {
backing_for_single_.clear();
completed_nbest_ = &backing_for_single_;
}
}
}
2012-10-12 16:53:08 +04:00
// Explicit instantiations of Manager::LMCallback, one per lm::ngram model
// type listed here, so the template definition above can stay in this file.
template void Manager::LMCallback<lm::ngram::ProbingModel>(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::RestProbingModel>(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::TrieModel>(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::QuantTrieModel>(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::ArrayTrieModel>(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::QuantArrayTrieModel>(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
2013-05-29 21:16:15 +04:00
const std::vector<search::Applied> &Manager::ProcessSentence()
{
LanguageModel::GetFirstLM().IncrementalCallback(*this);
return *completed_nbest_;
}
2013-05-29 21:16:15 +04:00
namespace
{
// Action for AppendToPhrase that does nothing with each visited target phrase.
struct NoOp {
void operator()(const TargetPhrase &) const {}
};
struct AccumScore {
AccumScore(ScoreComponentCollection &out) : out_(&out) {}
void operator()(const TargetPhrase &phrase) {
out_->PlusEquals(phrase.GetScoreBreakdown());
}
ScoreComponentCollection *out_;
};
2013-05-29 21:16:15 +04:00
// Walk the derivation rooted at `final`, appending terminal words to `out`
// in target order and invoking `action` on the TargetPhrase of every node
// visited.  Non-terminals recurse into the node's children, consumed left
// to right.
template <class Action> void AppendToPhrase(const search::Applied final, Phrase &out, Action action)
{
  assert(final.Valid());
  const TargetPhrase &phrase = *static_cast<const TargetPhrase*>(final.GetNote().vp);
  action(phrase);

  const search::Applied *next_child = final.Children();
  for (std::size_t pos = 0; pos < phrase.GetSize(); ++pos) {
    const Word &word = phrase.GetWord(pos);
    if (!word.IsNonTerminal()) {
      out.AddWord(word);
      continue;
    }
    AppendToPhrase(*next_child, out, action);
    ++next_child;
  }
}
} // namespace
2013-05-29 21:16:15 +04:00
// Replace the contents of `out` with the surface words of the derivation `final`.
void ToPhrase(const search::Applied final, Phrase &out)
{
  out.Clear();
  NoOp ignore_phrases;
  AppendToPhrase(final, out, ignore_phrases);
}
2013-05-29 21:16:15 +04:00
// Rebuild both the surface phrase and the feature scores of the derivation
// `final`.  `phrase` and `features` are reset first; rule scores are summed
// over the derivation and the language model score is then recomputed on the
// whole phrase and assigned to the LM's feature.
void PhraseAndFeatures(const search::Applied final, Phrase &phrase, ScoreComponentCollection &features)
{
  phrase.Clear();
  features.ZeroAll();

  AccumScore accumulate(features);
  AppendToPhrase(final, phrase, accumulate);

  // If we made it this far, there is only one language model.
  const LanguageModel &model = LanguageModel::GetFirstLM();
  float full_score, unused_ngram_score;
  std::size_t unused_oov_count;
  model.CalcScore(phrase, full_score, unused_ngram_score, unused_oov_count);
  // CalcScore transforms, but EvaluateWhenApplied doesn't.
  features.Assign(&model, full_score);
}
} // namespace Incremental
} // namespace Moses