diff --git a/moses-chart-cmd/IOWrapper.cpp b/moses-chart-cmd/IOWrapper.cpp index 401636fa5..571f7e32e 100644 --- a/moses-chart-cmd/IOWrapper.cpp +++ b/moses-chart-cmd/IOWrapper.cpp @@ -40,6 +40,7 @@ POSSIBILITY OF SUCH DAMAGE. #include "moses/StaticData.h" #include "moses/DummyScoreProducers.h" #include "moses/InputFileStream.h" +#include "moses/Incremental.h" #include "moses/PhraseDictionary.h" #include "moses/ChartTrellisPathList.h" #include "moses/ChartTrellisPath.h" @@ -377,8 +378,136 @@ void IOWrapper::OutputBestHypo(const ChartHypothesis *hypo, long translationId) m_singleBestOutputCollector->Write(translationId, out.str()); } -void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const ChartHypothesis *bestHypo, const TranslationSystem* system, long translationId) -{ +void IOWrapper::OutputBestHypo(search::Applied applied, long translationId) { + if (!m_singleBestOutputCollector) return; + std::ostringstream out; + IOWrapper::FixPrecision(out); + if (StaticData::Instance().GetOutputHypoScore()) { + out << applied.GetScore() << ' '; + } + Phrase outPhrase; + Incremental::ToPhrase(applied, outPhrase); + // delete 1st & last + CHECK(outPhrase.GetSize() >= 2); + outPhrase.RemoveWord(0); + outPhrase.RemoveWord(outPhrase.GetSize() - 1); + out << outPhrase.GetStringRep(StaticData::Instance().GetOutputFactorOrder()); + out << '\n'; + m_singleBestOutputCollector->Write(translationId, out.str()); +} + +void IOWrapper::OutputBestNone(long translationId) { + if (!m_singleBestOutputCollector) return; + if (StaticData::Instance().GetOutputHypoScore()) { + m_singleBestOutputCollector->Write(translationId, "0 \n"); + } else { + m_singleBestOutputCollector->Write(translationId, "\n"); + } +} + +namespace { + +void OutputSparseFeatureScores(std::ostream& out, const ScoreComponentCollection &features, const FeatureFunction *ff, std::string &lastName) { + const StaticData &staticData = StaticData::Instance(); + bool labeledOutput = staticData.IsLabeledNBestList(); + const FVector scores = features.GetVectorForProducer( ff ); + + // report weighted aggregate + if (! ff->GetSparseFeatureReporting()) { + const FVector &weights = staticData.GetAllWeights().GetScoresVector(); + if (labeledOutput && !boost::contains(ff->GetScoreProducerDescription(), ":")) + out << " " << ff->GetScoreProducerWeightShortName() << ":"; + out << " " << scores.inner_product(weights); + } + + // report each feature + else { + for(FVector::FNVmap::const_iterator i = scores.cbegin(); i != scores.cend(); i++) { + if (i->second != 0) { // do not report zero-valued features + if (labeledOutput) + out << " " << i->first << ":"; + out << " " << i->second; + } + } + } +} + +void WriteFeatures(const TranslationSystem &system, const ScoreComponentCollection &features, std::ostream &out) { + bool labeledOutput = StaticData::Instance().IsLabeledNBestList(); + // lm + const LMList& lml = system.GetLanguageModels(); + if (lml.size() > 0) { + if (labeledOutput) + out << "lm:"; + LMList::const_iterator lmi = lml.begin(); + for (; lmi != lml.end(); ++lmi) { + out << " " << features.GetScoreForProducer(*lmi); + } + } + + std::string lastName = ""; + + // output stateful sparse features + const vector& sff = system.GetStatefulFeatureFunctions(); + for( size_t i=0; iGetNumScoreComponents() == ScoreProducer::unlimited) + OutputSparseFeatureScores(out, features, sff[i], lastName); + + // translation components + const vector& pds = system.GetPhraseDictionaries(); + if (pds.size() > 0) { + for( size_t i=0; iGetNumInputScores(); + vector scores = features.GetScoresForProducer( pds[i] ); + for (size_t j = 0; jGetScoreProducerWeightShortName(j); + out << " " << lastName << ":"; + } + } + out << " " << scores[j]; + } + } + } + + // word penalty + if (labeledOutput) + out << " w:"; + out << " " << features.GetScoreForProducer(system.GetWordPenaltyProducer()); + + // generation + const vector& gds = system.GetGenerationDictionaries(); + if (gds.size() > 0) { + for( size_t i=0; iGetNumInputScores(); + vector scores = features.GetScoresForProducer( gds[i] ); + for (size_t j = 0; jGetScoreProducerWeightShortName(j); + out << " " << lastName << ":"; + } + } + out << " " << scores[j]; + } + } + } + + // output stateless sparse features + lastName = ""; + + const vector& slf = system.GetStatelessFeatureFunctions(); + for( size_t i=0; iGetNumScoreComponents() == ScoreProducer::unlimited) { + OutputSparseFeatureScores(out, features, slf[i], lastName); + } + } +} + +} // namespace + +void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const TranslationSystem* system, long translationId) { std::ostringstream out; // Check if we're writing to std::cout. @@ -387,17 +516,10 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha // preserve existing behaviour, but should probably be done either way. IOWrapper::FixPrecision(out); - // The output from -output-hypo-score is always written to std::cout. - if (StaticData::Instance().GetOutputHypoScore()) { - if (bestHypo != NULL) { - out << bestHypo->GetTotalScore() << " "; - } else { - out << "0 "; - } - } + // Used to check StaticData's GetOutputHypoScore(), but it makes no sense with nbest output. } - bool labeledOutput = StaticData::Instance().IsLabeledNBestList(); + //bool includeAlignment = StaticData::Instance().NBestIncludesAlignment(); bool includeWordAlignment = StaticData::Instance().PrintAlignmentInfoInNbest(); ChartTrellisPathList::const_iterator iter; @@ -421,75 +543,7 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha // before each model type, the corresponding command-line-like name must be emitted // MERT script relies on this - // lm - const LMList& lml = system->GetLanguageModels(); - if (lml.size() > 0) { - if (labeledOutput) - out << "lm:"; - LMList::const_iterator lmi = lml.begin(); - for (; lmi != lml.end(); ++lmi) { - out << " " << path.GetScoreBreakdown().GetScoreForProducer(*lmi); - } - } - - std::string lastName = ""; - - // output stateful sparse features - const vector& sff = system->GetStatefulFeatureFunctions(); - for( size_t i=0; iGetNumScoreComponents() == ScoreProducer::unlimited) - OutputSparseFeatureScores( out, path, sff[i], lastName ); - - // translation components - const vector& pds = system->GetPhraseDictionaries(); - if (pds.size() > 0) { - for( size_t i=0; iGetNumInputScores(); - vector scores = path.GetScoreBreakdown().GetScoresForProducer( pds[i] ); - for (size_t j = 0; jGetScoreProducerWeightShortName(j); - out << " " << lastName << ":"; - } - } - out << " " << scores[j]; - } - } - } - - // word penalty - if (labeledOutput) - out << " w:"; - out << " " << path.GetScoreBreakdown().GetScoreForProducer(system->GetWordPenaltyProducer()); - - // generation - const vector& gds = system->GetGenerationDictionaries(); - if (gds.size() > 0) { - for( size_t i=0; iGetNumInputScores(); - vector scores = path.GetScoreBreakdown().GetScoresForProducer( gds[i] ); - for (size_t j = 0; jGetScoreProducerWeightShortName(j); - out << " " << lastName << ":"; - } - } - out << " " << scores[j]; - } - } - } - - // output stateless sparse features - lastName = ""; - - const vector& slf = system->GetStatelessFeatureFunctions(); - for( size_t i=0; iGetNumScoreComponents() == ScoreProducer::unlimited) { - OutputSparseFeatureScores( out, path, slf[i], lastName ); - } - } + WriteFeatures(*system, path.GetScoreBreakdown(), out); // total out << " ||| " << path.GetTotalScore(); @@ -524,34 +578,33 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha out <Write(translationId, out.str()); } -void IOWrapper::OutputSparseFeatureScores( std::ostream& out, const ChartTrellisPath &path, const FeatureFunction *ff, std::string &lastName ) -{ - const StaticData &staticData = StaticData::Instance(); - bool labeledOutput = staticData.IsLabeledNBestList(); - const FVector scores = path.GetScoreBreakdown().GetVectorForProducer( ff ); - - // report weighted aggregate - if (! ff->GetSparseFeatureReporting()) { - const FVector &weights = staticData.GetAllWeights().GetScoresVector(); - if (labeledOutput && !boost::contains(ff->GetScoreProducerDescription(), ":")) - out << " " << ff->GetScoreProducerWeightShortName() << ":"; - out << " " << scores.inner_product(weights); +void IOWrapper::OutputNBestList(const std::vector &nbest, const TranslationSystem &system, long translationId) { + std::ostringstream out; + // wtf? copied from the original OutputNBestList + if (m_nBestOutputCollector->OutputIsCout()) { + IOWrapper::FixPrecision(out); } - - // report each feature - else { - for(FVector::FNVmap::const_iterator i = scores.cbegin(); i != scores.cend(); i++) { - if (i->second != 0) { // do not report zero-valued features - if (labeledOutput) - out << " " << i->first << ":"; - out << " " << i->second; - } - } + Phrase outputPhrase; + ScoreComponentCollection features; + for (std::vector::const_iterator i = nbest.begin(); i != nbest.end(); ++i) { + Incremental::PhraseAndFeatures(system, *i, outputPhrase, features); + // and + CHECK(outputPhrase.GetSize() >= 2); + outputPhrase.RemoveWord(0); + outputPhrase.RemoveWord(outputPhrase.GetSize() - 1); + out << translationId << " ||| "; + OutputSurface(out, outputPhrase, m_outputFactorOrder, false); + out << " ||| "; + WriteFeatures(system, features, out); + out << " ||| " << i->GetScore() << '\n'; } + out << std::flush; + assert(m_nBestOutputCollector); + m_nBestOutputCollector->Write(translationId, out.str()); } void IOWrapper::FixPrecision(std::ostream &stream, size_t size) diff --git a/moses-chart-cmd/IOWrapper.h b/moses-chart-cmd/IOWrapper.h index dea7355d0..5686d5728 100644 --- a/moses-chart-cmd/IOWrapper.h +++ b/moses-chart-cmd/IOWrapper.h @@ -44,6 +44,7 @@ POSSIBILITY OF SUCH DAMAGE. #include "moses/OutputCollector.h" #include "moses/ChartHypothesis.h" #include "moses/ChartTrellisPath.h" +#include "search/applied.hh" namespace Moses { @@ -92,14 +93,14 @@ public: Moses::InputType* GetInput(Moses::InputType *inputType); void OutputBestHypo(const Moses::ChartHypothesis *hypo, long translationId); + void OutputBestHypo(search::Applied applied, long translationId); void OutputBestHypo(const std::vector& mbrBestHypo, long translationId); - void OutputNBestList(const Moses::ChartTrellisPathList &nBestList, const Moses::ChartHypothesis *bestHypo, const Moses::TranslationSystem* system, long translationId); - void OutputSparseFeatureScores(std::ostream& out, const Moses::ChartTrellisPath &path, const Moses::FeatureFunction *ff, std::string &lastName); + void OutputBestNone(long translationId); + void OutputNBestList(const Moses::ChartTrellisPathList &nBestList, const Moses::TranslationSystem* system, long translationId); + void OutputNBestList(const std::vector &nbest, const Moses::TranslationSystem &system, long translationId); void OutputDetailedTranslationReport(const Moses::ChartHypothesis *hypo, const Moses::Sentence &sentence, long translationId); void Backtrack(const Moses::ChartHypothesis *hypo); - Moses::OutputCollector *ExposeSingleBest() { return m_singleBestOutputCollector; } - void ResetTranslationId(); Moses::OutputCollector *GetSearchGraphOutputCollector() { diff --git a/moses-chart-cmd/Main.cpp b/moses-chart-cmd/Main.cpp index ee8099e3f..278783926 100644 --- a/moses-chart-cmd/Main.cpp +++ b/moses-chart-cmd/Main.cpp @@ -59,7 +59,7 @@ POSSIBILITY OF SUCH DAMAGE. #include "moses/ChartHypothesis.h" #include "moses/ChartTrellisPath.h" #include "moses/ChartTrellisPathList.h" -#include "moses/Incremental/Manager.h" +#include "moses/Incremental.h" #include "util/usage.hh" @@ -91,10 +91,14 @@ public: if (staticData.GetSearchAlgorithm() == ChartIncremental) { Incremental::Manager manager(*m_source, system); - manager.ProcessSentence(); - if (m_ioWrapper.ExposeSingleBest()) { - m_ioWrapper.ExposeSingleBest()->Write(translationId, manager.String() + '\n'); + const std::vector &nbest = manager.ProcessSentence(); + if (!nbest.empty()) { + m_ioWrapper.OutputBestHypo(nbest[0], translationId); + } else { + m_ioWrapper.OutputBestNone(translationId); } + if (staticData.GetNBestSize() > 0) + m_ioWrapper.OutputNBestList(nbest, system, translationId); return; } @@ -125,7 +129,7 @@ public: VERBOSE(2,"WRITING " << nBestSize << " TRANSLATION ALTERNATIVES TO " << staticData.GetNBestFilePath() << endl); ChartTrellisPathList nBestList; manager.CalcNBest(nBestSize, nBestList,staticData.GetDistinctNBest()); - m_ioWrapper.OutputNBestList(nBestList, bestHypo, &system, translationId); + m_ioWrapper.OutputNBestList(nBestList, &system, translationId); IFVERBOSE(2) { PrintUserTime("N-Best Hypotheses Generation Time:"); } diff --git a/moses/ChartCellLabel.h b/moses/ChartCellLabel.h index c44462fcc..9fccf71e9 100644 --- a/moses/ChartCellLabel.h +++ b/moses/ChartCellLabel.h @@ -23,7 +23,7 @@ #include "Word.h" #include "WordsRange.h" -namespace search { class Vertex; class VertexGenerator; } +namespace search { class Vertex; } namespace Moses { @@ -41,7 +41,7 @@ class ChartCellLabel union Stack { const HypoList *cube; // cube pruning const search::Vertex *incr; // incremental search after filling. - search::VertexGenerator *incr_generator; // incremental search during filling. + void *incr_generator; // incremental search during filling. }; diff --git a/moses/Incremental.cpp b/moses/Incremental.cpp new file mode 100644 index 000000000..770b0d67e --- /dev/null +++ b/moses/Incremental.cpp @@ -0,0 +1,296 @@ +#include "moses/Incremental.h" + +#include "moses/ChartCell.h" +#include "moses/ChartParserCallback.h" +#include "moses/FeatureVector.h" +#include "moses/StaticData.h" +#include "moses/TranslationSystem.h" +#include "moses/Util.h" + +#include "lm/model.hh" +#include "search/applied.hh" +#include "search/config.hh" +#include "search/context.hh" +#include "search/edge_generator.hh" +#include "search/rule.hh" +#include "search/vertex_generator.hh" + +#include + +namespace Moses { +namespace Incremental { +namespace { + +// This is called by EdgeGenerator. Route hypotheses to separate vertices for +// each left hand side label, populating ChartCellLabelSet out. +template class HypothesisCallback { + private: + typedef search::VertexGenerator Gen; + public: + HypothesisCallback(search::ContextBase &context, Best &best, ChartCellLabelSet &out, boost::object_pool &vertex_pool) + : context_(context), best_(best), out_(out), vertex_pool_(vertex_pool) {} + + void NewHypothesis(search::PartialEdge partial) { + // Get the LHS, look it up in the output ChartCellLabel, and upcast it. + // It's not part of the union because it would have been ugly to expose template types in ChartCellLabel. + ChartCellLabel::Stack &stack = out_.FindOrInsert(static_cast(partial.GetNote().vp)->GetTargetLHS()); + Gen *entry = static_cast(stack.incr_generator); + if (!entry) { + entry = generator_pool_.construct(context_, *vertex_pool_.construct(), best_); + stack.incr_generator = entry; + } + entry->NewHypothesis(partial); + } + + void FinishedSearch() { + for (ChartCellLabelSet::iterator i(out_.mutable_begin()); i != out_.mutable_end(); ++i) { + ChartCellLabel::Stack &stack = i->second.MutableStack(); + Gen *gen = static_cast(stack.incr_generator); + gen->FinishedSearch(); + stack.incr = &gen->Generating(); + } + } + + private: + search::ContextBase &context_; + + Best &best_; + + ChartCellLabelSet &out_; + + boost::object_pool &vertex_pool_; + boost::object_pool generator_pool_; +}; + +// This is called by the moses parser to collect hypotheses. It converts to my +// edges (search::PartialEdge). +template class Fill : public ChartParserCallback { + public: + Fill(search::Context &context, const std::vector &vocab_mapping, search::Score oov_weight) + : context_(context), vocab_mapping_(vocab_mapping), oov_weight_(oov_weight) {} + + void Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &ignored); + + void AddPhraseOOV(TargetPhrase &phrase, std::list &waste_memory, const WordsRange &range); + + bool Empty() const { return edges_.Empty(); } + + template void Search(Best &best, ChartCellLabelSet &out, boost::object_pool &vertex_pool) { + HypothesisCallback callback(context_, best, out, vertex_pool); + edges_.Search(context_, callback); + } + + // Root: everything into one vertex. + template search::History RootSearch(Best &best) { + search::Vertex vertex; + search::RootVertexGenerator gen(vertex, best); + edges_.Search(context_, gen); + return vertex.BestChild(); + } + + private: + lm::WordIndex Convert(const Word &word) const; + + search::Context &context_; + + const std::vector &vocab_mapping_; + + search::EdgeGenerator edges_; + + const search::Score oov_weight_; +}; + +template void Fill::Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &) { + std::vector vertices; + vertices.reserve(nts.size()); + float below_score = 0.0; + for (StackVec::const_iterator i(nts.begin()); i != nts.end(); ++i) { + vertices.push_back((*i)->GetStack().incr->RootPartial()); + if (vertices.back().Empty()) return; + below_score += vertices.back().Bound(); + } + + std::vector words; + for (TargetPhraseCollection::const_iterator p(targets.begin()); p != targets.end(); ++p) { + words.clear(); + const TargetPhrase &phrase = **p; + const AlignmentInfo::NonTermIndexMap &align = phrase.GetAlignNonTerm().GetNonTermIndexMap(); + search::PartialEdge edge(edges_.AllocateEdge(nts.size())); + + search::PartialVertex *nt = edge.NT(); + for (size_t i = 0; i < phrase.GetSize(); ++i) { + const Word &word = phrase.GetWord(i); + if (word.IsNonTerminal()) { + *(nt++) = vertices[align[i]]; + words.push_back(search::kNonTerminal); + } else { + words.push_back(Convert(word)); + } + } + + edge.SetScore(phrase.GetFutureScore() + below_score); + // prob and oov were already accounted for. + search::ScoreRule(context_.LanguageModel(), words, edge.Between()); + + search::Note note; + note.vp = &phrase; + edge.SetNote(note); + + edges_.AddEdge(edge); + } +} + +template void Fill::AddPhraseOOV(TargetPhrase &phrase, std::list &, const WordsRange &) { + std::vector words; + CHECK(phrase.GetSize() <= 1); + if (phrase.GetSize()) + words.push_back(Convert(phrase.GetWord(0))); + + search::PartialEdge edge(edges_.AllocateEdge(0)); + // Appears to be a bug that FutureScore does not already include language model. + search::ScoreRuleRet scored(search::ScoreRule(context_.LanguageModel(), words, edge.Between())); + edge.SetScore(phrase.GetFutureScore() + scored.prob * context_.LMWeight() + static_cast(scored.oov) * oov_weight_); + + search::Note note; + note.vp = &phrase; + edge.SetNote(note); + + edges_.AddEdge(edge); +} + +// TODO: factors (but chart doesn't seem to support factors anyway). +template lm::WordIndex Fill::Convert(const Word &word) const { + std::size_t factor = word.GetFactor(0)->GetId(); + return (factor >= vocab_mapping_.size() ? 0 : vocab_mapping_[factor]); +} + +struct ChartCellBaseFactory { + ChartCellBase *operator()(size_t startPos, size_t endPos) const { + return new ChartCellBase(startPos, endPos); + } +}; + +} // namespace + +Manager::Manager(const InputType &source, const TranslationSystem &system) : + source_(source), + system_(system), + cells_(source, ChartCellBaseFactory()), + parser_(source, system, cells_), + n_best_(search::NBestConfig(StaticData::Instance().GetNBestSize())) {} + +Manager::~Manager() { + system_.CleanUpAfterSentenceProcessing(source_); +} + +template search::History Manager::PopulateBest(const Model &model, const std::vector &words, Best &out) { + const LanguageModel &abstract = **system_.GetLanguageModels().begin(); + const float oov_weight = abstract.OOVFeatureEnabled() ? abstract.GetOOVWeight() : 0.0; + const StaticData &data = StaticData::Instance(); + search::Config config(abstract.GetWeight(), data.GetCubePruningPopLimit(), search::NBestConfig(data.GetNBestSize())); + search::Context context(config, model); + + size_t size = source_.GetSize(); + boost::object_pool vertex_pool(std::max(size * size / 2, 32)); + + for (size_t width = 1; width < size; ++width) { + for (size_t startPos = 0; startPos <= size-width; ++startPos) { + WordsRange range(startPos, startPos + width - 1); + Fill filler(context, words, oov_weight); + parser_.Create(range, filler); + filler.Search(out, cells_.MutableBase(range).MutableTargetLabelSet(), vertex_pool); + } + } + + WordsRange range(0, size - 1); + Fill filler(context, words, oov_weight); + parser_.Create(range, filler); + return filler.RootSearch(out); +} + +template void Manager::LMCallback(const Model &model, const std::vector &words) { + std::size_t nbest = StaticData::Instance().GetNBestSize(); + if (nbest <= 1) { + search::History ret = PopulateBest(model, words, single_best_); + if (ret) { + backing_for_single_.resize(1); + backing_for_single_[0] = search::Applied(ret); + } else { + backing_for_single_.clear(); + } + completed_nbest_ = &backing_for_single_; + } else { + search::History ret = PopulateBest(model, words, n_best_); + if (ret) { + completed_nbest_ = &n_best_.Extract(ret); + } else { + backing_for_single_.clear(); + completed_nbest_ = &backing_for_single_; + } + } +} + +template void Manager::LMCallback(const lm::ngram::ProbingModel &model, const std::vector &words); +template void Manager::LMCallback(const lm::ngram::RestProbingModel &model, const std::vector &words); +template void Manager::LMCallback(const lm::ngram::TrieModel &model, const std::vector &words); +template void Manager::LMCallback(const lm::ngram::QuantTrieModel &model, const std::vector &words); +template void Manager::LMCallback(const lm::ngram::ArrayTrieModel &model, const std::vector &words); +template void Manager::LMCallback(const lm::ngram::QuantArrayTrieModel &model, const std::vector &words); + +const std::vector &Manager::ProcessSentence() { + const LMList &lms = system_.GetLanguageModels(); + UTIL_THROW_IF(lms.size() != 1, util::Exception, "Incremental search only supports one language model."); + (*lms.begin())->IncrementalCallback(*this); + return *completed_nbest_; +} + +namespace { + +struct NoOp { + void operator()(const TargetPhrase &) const {} +}; +struct AccumScore { + AccumScore(ScoreComponentCollection &out) : out_(&out) {} + void operator()(const TargetPhrase &phrase) { + out_->PlusEquals(phrase.GetScoreBreakdown()); + } + ScoreComponentCollection *out_; +}; +template void AppendToPhrase(const search::Applied final, Phrase &out, Action action) { + assert(final.Valid()); + const TargetPhrase &phrase = *static_cast(final.GetNote().vp); + action(phrase); + const search::Applied *child = final.Children(); + for (std::size_t i = 0; i < phrase.GetSize(); ++i) { + const Word &word = phrase.GetWord(i); + if (word.IsNonTerminal()) { + AppendToPhrase(*child++, out, action); + } else { + out.AddWord(word); + } + } +} + +} // namespace + +void ToPhrase(const search::Applied final, Phrase &out) { + out.Clear(); + AppendToPhrase(final, out, NoOp()); +} + +void PhraseAndFeatures(const TranslationSystem &system, const search::Applied final, Phrase &phrase, ScoreComponentCollection &features) { + phrase.Clear(); + features.ZeroAll(); + AppendToPhrase(final, phrase, AccumScore(features)); + + // If we made it this far, there is only one language model. + float full, ignored_ngram; + std::size_t ignored_oov; + const LanguageModel &model = **system.GetLanguageModels().begin(); + model.CalcScore(phrase, full, ignored_ngram, ignored_oov); + // CalcScore transforms, but EvaluateChart doesn't. + features.Assign(&model, UntransformLMScore(full)); +} + +} // namespace Incremental +} // namespace Moses diff --git a/moses/Incremental.h b/moses/Incremental.h new file mode 100644 index 000000000..4bfc2dae3 --- /dev/null +++ b/moses/Incremental.h @@ -0,0 +1,60 @@ +#pragma once + +#include "lm/word_index.hh" +#include "search/applied.hh" +#include "search/nbest.hh" + +#include "moses/ChartCellCollection.h" +#include "moses/ChartParser.h" + +#include +#include + +namespace Moses { +class ScoreComponentCollection; +class InputType; +class TranslationSystem; +namespace Incremental { + +class Manager { + public: + Manager(const InputType &source, const TranslationSystem &system); + + ~Manager(); + + template void LMCallback(const Model &model, const std::vector &words); + + const std::vector &ProcessSentence(); + + // Call to get the same value as ProcessSentence returned. + const std::vector &Completed() const { + return *completed_nbest_; + } + + private: + template search::History PopulateBest(const Model &model, const std::vector &words, Best &out); + + const InputType &source_; + const TranslationSystem &system_; + ChartCellCollectionBase cells_; + ChartParser parser_; + + // Only one of single_best_ or n_best_ will be used, but it was easier to do this than a template. + search::SingleBest single_best_; + // ProcessSentence returns a reference to a vector. ProcessSentence + // doesn't have one, so this is populated and returned. + std::vector backing_for_single_; + + search::NBest n_best_; + + const std::vector *completed_nbest_; +}; + +// Just get the phrase. +void ToPhrase(const search::Applied final, Phrase &out); +// Get the phrase and the features. +void PhraseAndFeatures(const TranslationSystem &system, const search::Applied final, Phrase &phrase, ScoreComponentCollection &features); + +} // namespace Incremental +} // namespace Moses + diff --git a/moses/Incremental/Fill.cpp b/moses/Incremental/Fill.cpp deleted file mode 100644 index 6f0baba92..000000000 --- a/moses/Incremental/Fill.cpp +++ /dev/null @@ -1,143 +0,0 @@ -#include "Fill.h" - -#include "moses/ChartCellLabel.h" -#include "moses/ChartCellLabelSet.h" -#include "moses/TargetPhraseCollection.h" -#include "moses/TargetPhrase.h" -#include "moses/Word.h" - -#include "lm/model.hh" -#include "search/context.hh" -#include "search/note.hh" -#include "search/rule.hh" -#include "search/vertex.hh" -#include "search/vertex_generator.hh" - -#include - -namespace Moses { -namespace Incremental { - -template Fill::Fill(search::Context &context, const std::vector &vocab_mapping) - : context_(context), vocab_mapping_(vocab_mapping) {} - -template void Fill::Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &) { - std::vector vertices; - vertices.reserve(nts.size()); - float below_score = 0.0; - for (StackVec::const_iterator i(nts.begin()); i != nts.end(); ++i) { - vertices.push_back((*i)->GetStack().incr->RootPartial()); - if (vertices.back().Empty()) return; - below_score += vertices.back().Bound(); - } - - std::vector words; - for (TargetPhraseCollection::const_iterator p(targets.begin()); p != targets.end(); ++p) { - words.clear(); - const TargetPhrase &phrase = **p; - const AlignmentInfo::NonTermIndexMap &align = phrase.GetAlignNonTerm().GetNonTermIndexMap(); - search::PartialEdge edge(edges_.AllocateEdge(nts.size())); - - size_t i = 0; - bool bos = false; - search::PartialVertex *nt = edge.NT(); - if (phrase.GetSize() && !phrase.GetWord(0).IsNonTerminal()) { - lm::WordIndex index = Convert(phrase.GetWord(0)); - if (context_.LanguageModel().GetVocabulary().BeginSentence() == index) { - bos = true; - } else { - words.push_back(index); - } - i = 1; - } - for (; i < phrase.GetSize(); ++i) { - const Word &word = phrase.GetWord(i); - if (word.IsNonTerminal()) { - *(nt++) = vertices[align[i]]; - words.push_back(search::kNonTerminal); - } else { - words.push_back(Convert(word)); - } - } - - edge.SetScore(phrase.GetFutureScore() + below_score); - search::ScoreRule(context_, words, bos, edge.Between()); - - search::Note note; - note.vp = &phrase; - edge.SetNote(note); - - edges_.AddEdge(edge); - } -} - -template void Fill::AddPhraseOOV(TargetPhrase &phrase, std::list &, const WordsRange &) { - std::vector words; - CHECK(phrase.GetSize() <= 1); - if (phrase.GetSize()) - words.push_back(Convert(phrase.GetWord(0))); - - search::PartialEdge edge(edges_.AllocateEdge(0)); - // Appears to be a bug that FutureScore does not already include language model. - edge.SetScore(phrase.GetFutureScore() + search::ScoreRule(context_, words, false, edge.Between())); - - search::Note note; - note.vp = &phrase; - edge.SetNote(note); - - edges_.AddEdge(edge); -} - -namespace { -// Route hypotheses to separate vertices for each left hand side label, populating ChartCellLabelSet out. -class HypothesisCallback { - public: - HypothesisCallback(search::ContextBase &context, ChartCellLabelSet &out, boost::object_pool &vertex_pool) - : context_(context), out_(out), vertex_pool_(vertex_pool) {} - - void NewHypothesis(search::PartialEdge partial) { - search::VertexGenerator *&entry = out_.FindOrInsert(static_cast(partial.GetNote().vp)->GetTargetLHS()).incr_generator; - if (!entry) { - entry = generator_pool_.construct(context_, *vertex_pool_.construct()); - } - entry->NewHypothesis(partial); - } - - void FinishedSearch() { - for (ChartCellLabelSet::iterator i(out_.mutable_begin()); i != out_.mutable_end(); ++i) { - ChartCellLabel::Stack &stack = i->second.MutableStack(); - stack.incr_generator->FinishedSearch(); - stack.incr = &stack.incr_generator->Generating(); - } - } - - private: - search::ContextBase &context_; - - ChartCellLabelSet &out_; - - boost::object_pool &vertex_pool_; - boost::object_pool generator_pool_; -}; -} // namespace - -template void Fill::Search(ChartCellLabelSet &out, boost::object_pool &vertex_pool) { - HypothesisCallback callback(context_, out, vertex_pool); - edges_.Search(context_, callback); -} - -// TODO: factors (but chart doesn't seem to support factors anyway). -template lm::WordIndex Fill::Convert(const Word &word) const { - std::size_t factor = word.GetFactor(0)->GetId(); - return (factor >= vocab_mapping_.size() ? 0 : vocab_mapping_[factor]); -} - -template class Fill; -template class Fill; -template class Fill; -template class Fill; -template class Fill; -template class Fill; - -} // namespace Incremental -} // namespace Moses diff --git a/moses/Incremental/Fill.h b/moses/Incremental/Fill.h deleted file mode 100644 index 0f4059d09..000000000 --- a/moses/Incremental/Fill.h +++ /dev/null @@ -1,54 +0,0 @@ -#pragma once - -#include "moses/ChartParserCallback.h" -#include "moses/StackVec.h" - -#include "lm/word_index.hh" -#include "search/edge_generator.hh" - -#include - -#include -#include - -namespace search { -template class Context; -class Vertex; -} // namespace search - -namespace Moses { -class Word; -class WordsRange; -class TargetPhraseCollection; -class WordsRange; -class ChartCellLabelSet; -class TargetPhrase; - -namespace Incremental { - -// Replacement for ChartTranslationOptionList -// TODO: implement count and score thresholding. -template class Fill : public ChartParserCallback { - public: - Fill(search::Context &context, const std::vector &vocab_mapping); - - void Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &ignored); - - void AddPhraseOOV(TargetPhrase &phrase, std::list &waste_memory, const WordsRange &range); - - bool Empty() const { return edges_.Empty(); } - - void Search(ChartCellLabelSet &out, boost::object_pool &vertex_pool); - - private: - lm::WordIndex Convert(const Word &word) const ; - - search::Context &context_; - - const std::vector &vocab_mapping_; - - search::EdgeGenerator edges_; -}; - -} // namespace Incremental -} // namespace Moses diff --git a/moses/Incremental/Manager.cpp b/moses/Incremental/Manager.cpp deleted file mode 100644 index 7d684540c..000000000 --- a/moses/Incremental/Manager.cpp +++ /dev/null @@ -1,122 +0,0 @@ -#include "Manager.h" - -#include "Fill.h" - -#include "moses/ChartCell.h" -#include "moses/TranslationSystem.h" -#include "moses/StaticData.h" - -#include "search/context.hh" -#include "search/config.hh" -#include "search/weights.hh" - -#include - -namespace Moses { -namespace Incremental { - -namespace { -struct ChartCellBaseFactory { - ChartCellBase *operator()(size_t startPos, size_t endPos) const { - return new ChartCellBase(startPos, endPos); - } -}; -} // namespace - -Manager::Manager(const InputType &source, const TranslationSystem &system) : - source_(source), - system_(system), - cells_(source, ChartCellBaseFactory()), - parser_(source, system, cells_) { - -} - -Manager::~Manager() { - system_.CleanUpAfterSentenceProcessing(source_); -} - -namespace { - -void ConstructString(const search::Final final, std::ostringstream &stream) { - assert(final.Valid()); - const TargetPhrase &phrase = *static_cast(final.GetNote().vp); - size_t child = 0; - for (std::size_t i = 0; i < phrase.GetSize(); ++i) { - const Word &word = phrase.GetWord(i); - if (word.IsNonTerminal()) { - assert(child < final.GetArity()); - ConstructString(final.Children()[child++], stream); - } else { - stream << word[0]->GetString() << ' '; - } - } -} - -void BestString(const ChartCellLabelSet &labels, std::string &out) { - search::Final best; - for (ChartCellLabelSet::const_iterator i = labels.begin(); i != labels.end(); ++i) { - const search::Final child(i->second.GetStack().incr->BestChild()); - if (child.Valid() && (!best.Valid() || (child.GetScore() > best.GetScore()))) { - best = child; - } - } - if (!best.Valid()) { - out.clear(); - return; - } - std::ostringstream stream; - ConstructString(best, stream); - out = stream.str(); - CHECK(out.size() > 9); - // - out.erase(0, 4); - // - out.erase(out.size() - 5); - // Hack: include model score - out += " ||| "; - out += boost::lexical_cast(best.GetScore()); -} - -} // namespace - - -template void Manager::LMCallback(const Model &model, const std::vector &words) { - const LanguageModel &abstract = **system_.GetLanguageModels().begin(); - search::Weights weights( - abstract.GetWeight(), - abstract.OOVFeatureEnabled() ? abstract.GetOOVWeight() : 0.0, - system_.GetWeightWordPenalty()); - search::Config config(weights, StaticData::Instance().GetCubePruningPopLimit()); - search::Context context(config, model); - - size_t size = source_.GetSize(); - - boost::object_pool vertex_pool(std::max(size * size / 2, 32)); - - for (size_t width = 1; width <= size; ++width) { - for (size_t startPos = 0; startPos <= size-width; ++startPos) { - size_t endPos = startPos + width - 1; - WordsRange range(startPos, endPos); - Fill filler(context, words); - parser_.Create(range, filler); - filler.Search(cells_.MutableBase(range).MutableTargetLabelSet(), vertex_pool); - } - } - BestString(cells_.GetBase(WordsRange(0, source_.GetSize() - 1)).GetTargetLabelSet(), output_); -} - -template void Manager::LMCallback(const lm::ngram::ProbingModel &model, const std::vector &words); -template void Manager::LMCallback(const lm::ngram::RestProbingModel &model, const std::vector &words); -template void Manager::LMCallback(const lm::ngram::TrieModel &model, const std::vector &words); -template void Manager::LMCallback(const lm::ngram::QuantTrieModel &model, const std::vector &words); -template void Manager::LMCallback(const lm::ngram::ArrayTrieModel &model, const std::vector &words); -template void Manager::LMCallback(const lm::ngram::QuantArrayTrieModel &model, const std::vector &words); - -void Manager::ProcessSentence() { - const LMList &lms = system_.GetLanguageModels(); - UTIL_THROW_IF(lms.size() != 1, util::Exception, "Incremental search only supports one language model."); - (*lms.begin())->IncrementalCallback(*this); -} - -} // namespace Incremental -} // namespace Moses diff --git a/moses/Incremental/Manager.h b/moses/Incremental/Manager.h deleted file mode 100644 index ac8d76a81..000000000 --- a/moses/Incremental/Manager.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include "lm/word_index.hh" - -#include "moses/ChartCellCollection.h" -#include "moses/ChartParser.h" - -namespace Moses { -class InputType; -class TranslationSystem; -namespace Incremental { - -class Manager { - public: - Manager(const InputType &source, const TranslationSystem &system); - - ~Manager(); - - template void LMCallback(const Model &model, const std::vector &words); - - void ProcessSentence(); - - const std::string &String() const { return output_; } - - private: - const InputType &source_; - const TranslationSystem &system_; - ChartCellCollectionBase cells_; - ChartParser parser_; - - std::string output_; -}; -} // namespace Incremental -} // namespace Moses - diff --git a/moses/Jamfile b/moses/Jamfile index c05a9c6ab..9caa4e788 100644 --- a/moses/Jamfile +++ b/moses/Jamfile @@ -32,7 +32,6 @@ lib moses : CYKPlusParser/*.cpp RuleTable/*.cpp fuzzy-match/*.cpp - Incremental/*.cpp : #exceptions ThreadPool.cpp SyntacticLanguageModel.cpp diff --git a/moses/LM/Ken.cpp b/moses/LM/Ken.cpp index 25e5a00d3..42e517f17 100644 --- a/moses/LM/Ken.cpp +++ b/moses/LM/Ken.cpp @@ -38,7 +38,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA #include "moses/InputFileStream.h" #include "moses/StaticData.h" #include "moses/ChartHypothesis.h" -#include "moses/Incremental/Manager.h" +#include "moses/Incremental.h" #include diff --git a/scripts/training/filter-model-given-input.pl b/scripts/training/filter-model-given-input.pl index df9c528e0..d994fbcef 100755 --- a/scripts/training/filter-model-given-input.pl +++ b/scripts/training/filter-model-given-input.pl @@ -119,7 +119,11 @@ while() { print INI_OUT "2 $source_factor $t $w $new_name.bin$table_flag\n"; } elsif ($binarizer && $phrase_table_impl == 0) { - print INI_OUT "1 $source_factor $t $w $new_name$table_flag\n"; + if ($binarizer =~ /processPhraseTableMin/) { + print INI_OUT "12 $source_factor $t $w $new_name$table_flag\n"; + } else { + print INI_OUT "1 $source_factor $t $w $new_name$table_flag\n"; + } } else { $new_name .= ".gz" if $opt_gzip; print INI_OUT "$phrase_table_impl $source_factor $t $w $new_name$table_flag\n"; @@ -147,7 +151,7 @@ while() { $file =~ s/^.*\/+([^\/]+)/$1/g; my $new_name = "$dir/$file"; - $new_name =~ s/\.gz//; + $new_name =~ s/\.gz//; print INI_OUT "$factors $t $w $new_name\n"; push @TABLE_NEW_NAME,$new_name; @@ -275,11 +279,16 @@ for(my $i=0;$i<=$#TABLE;$i++) { # ... hierarchical translation model if ($opt_hierarchical) { my $cmd = "$binarizer $new_file $new_file.bin"; - print STDERR $cmd."\n"; - print STDERR `$cmd`; + print STDERR $cmd."\n"; + print STDERR `$cmd`; } # ... phrase translation model - else { + elsif ($binarizer =~ /processPhraseTableMin/) { + #compact phrase table + my $cmd = "LC_ALL=C sort -T $dir $new_file > $new_file.sorted; $binarizer -in $new_file.sorted -out $new_file -nscores $TABLE_WEIGHTS[$i]; rm $new_file.sorted"; + print STDERR $cmd."\n"; + print STDERR `$cmd`; + } else { my $cmd = "cat $new_file | LC_ALL=C sort -T $dir | $binarizer -ttable 0 0 - -nscores $TABLE_WEIGHTS[$i] -out $new_file"; print STDERR $cmd."\n"; print STDERR `$cmd`; @@ -289,8 +298,13 @@ for(my $i=0;$i<=$#TABLE;$i++) { else { my $lexbin = $binarizer; $lexbin =~ s/PhraseTable/LexicalTable/; - $lexbin =~ s/^\s*(\S+)\s.+/$1/; # no options - my $cmd = "$lexbin -in $new_file -out $new_file"; + my $cmd; + if ($lexbin =~ /processLexicalTableMin/) { + $cmd = "LC_ALL=C sort -T $dir $new_file > $new_file.sorted; $lexbin -in $new_file.sorted -out $new_file; rm $new_file.sorted"; + } else { + $lexbin =~ s/^\s*(\S+)\s.+/$1/; # no options + $cmd = "$lexbin -in $new_file -out $new_file"; + } print STDERR $cmd."\n"; print STDERR `$cmd`; } diff --git a/search/Jamfile b/search/Jamfile index c00d23828..f6433e0e3 100644 --- a/search/Jamfile +++ b/search/Jamfile @@ -1,5 +1 @@ -fakelib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : .. ; - -import testing ; - -unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ; +fakelib search : edge_generator.cc nbest.cc rule.cc vertex.cc vertex_generator.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : .. ; diff --git a/search/applied.hh b/search/applied.hh new file mode 100644 index 000000000..bd659e5c0 --- /dev/null +++ b/search/applied.hh @@ -0,0 +1,86 @@ +#ifndef SEARCH_APPLIED__ +#define SEARCH_APPLIED__ + +#include "search/edge.hh" +#include "search/header.hh" +#include "util/pool.hh" + +#include + +namespace search { + +// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted. +template class GenericApplied : public Header { + public: + GenericApplied() {} + + GenericApplied(void *location, PartialEdge partial) + : Header(location) { + memcpy(Base(), partial.Base(), kHeaderSize); + Below *child_out = Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = Below(part->End()); + } + + GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) { + SetScore(score); + SetNote(note); + } + + explicit GenericApplied(History from) : Header(from) {} + + + // These are arrays of length GetArity(). + Below *Children() { + return reinterpret_cast(After()); + } + const Below *Children() const { + return reinterpret_cast(After()); + } + + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Below); + } +}; + +// Applied rule that references itself. +class Applied : public GenericApplied { + private: + typedef GenericApplied P; + + public: + Applied() {} + Applied(void *location, PartialEdge partial) : P(location, partial) {} + Applied(History from) : P(from) {} +}; + +// How to build single-best hypotheses. +class SingleBest { + public: + typedef PartialEdge Combine; + + void Add(PartialEdge &existing, PartialEdge add) const { + if (!existing.Valid() || existing.GetScore() < add.GetScore()) + existing = add; + } + + NBestComplete Complete(PartialEdge partial) { + if (!partial.Valid()) + return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY); + void *place_final = pool_.Allocate(Applied::Size(partial.GetArity())); + Applied(place_final, partial); + return NBestComplete( + place_final, + partial.CompletedState(), + partial.GetScore()); + } + + private: + util::Pool pool_; +}; + +} // namespace search + +#endif // SEARCH_APPLIED__ diff --git a/search/config.hh b/search/config.hh index ef8e2354a..ba18c09e9 100644 --- a/search/config.hh +++ b/search/config.hh @@ -1,23 +1,36 @@ #ifndef SEARCH_CONFIG__ #define SEARCH_CONFIG__ -#include "search/weights.hh" -#include "util/string_piece.hh" +#include "search/types.hh" namespace search { +struct NBestConfig { + explicit NBestConfig(unsigned int in_size) { + keep = in_size; + size = in_size; + } + + unsigned int keep, size; +}; + class Config { public: - Config(const Weights &weights, unsigned int pop_limit) : - weights_(weights), pop_limit_(pop_limit) {} + Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) : + lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {} - const Weights &GetWeights() const { return weights_; } + Score LMWeight() const { return lm_weight_; } unsigned int PopLimit() const { return pop_limit_; } + const NBestConfig &GetNBest() const { return nbest_; } + private: - Weights weights_; + Score lm_weight_; + unsigned int pop_limit_; + + NBestConfig nbest_; }; } // namespace search diff --git a/search/context.hh b/search/context.hh index 62163144f..08f21bbf0 100644 --- a/search/context.hh +++ b/search/context.hh @@ -1,30 +1,16 @@ #ifndef SEARCH_CONTEXT__ #define SEARCH_CONTEXT__ -#include "lm/model.hh" #include "search/config.hh" -#include "search/final.hh" -#include "search/types.hh" #include "search/vertex.hh" -#include "util/exception.hh" -#include "util/pool.hh" #include -#include - -#include namespace search { -class Weights; - class ContextBase { public: - explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - - util::Pool &FinalPool() { - return final_pool_; - } + explicit ContextBase(const Config &config) : config_(config) {} VertexNode *NewVertexNode() { VertexNode *ret = vertex_node_pool_.construct(); @@ -36,18 +22,16 @@ class ContextBase { vertex_node_pool_.destroy(node); } - unsigned int PopLimit() const { return pop_limit_; } + unsigned int PopLimit() const { return config_.PopLimit(); } - const Weights &GetWeights() const { return weights_; } + Score LMWeight() const { return config_.LMWeight(); } + + const Config &GetConfig() const { return config_; } private: - util::Pool final_pool_; - boost::object_pool vertex_node_pool_; - unsigned int pop_limit_; - - const Weights &weights_; + Config config_; }; template class Context : public ContextBase { diff --git a/search/edge_generator.cc b/search/edge_generator.cc index 260159b1f..eacf5de5c 100644 --- a/search/edge_generator.cc +++ b/search/edge_generator.cc @@ -1,6 +1,7 @@ #include "search/edge_generator.hh" #include "lm/left.hh" +#include "lm/model.hh" #include "lm/partial.hh" #include "search/context.hh" #include "search/vertex.hh" @@ -38,7 +39,7 @@ template void FastScore(const Context &context, Arity victi *cover = *(cover + 1); } } - update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); + update.SetScore(update.GetScore() + adjustment * context.LMWeight()); } } // namespace diff --git a/search/edge_generator.hh b/search/edge_generator.hh index 582c78b7b..203942c6f 100644 --- a/search/edge_generator.hh +++ b/search/edge_generator.hh @@ -2,7 +2,6 @@ #define SEARCH_EDGE_GENERATOR__ #include "search/edge.hh" -#include "search/note.hh" #include "search/types.hh" #include diff --git a/search/final.hh b/search/final.hh deleted file mode 100644 index 50e62cf2e..000000000 --- a/search/final.hh +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef SEARCH_FINAL__ -#define SEARCH_FINAL__ - -#include "search/header.hh" -#include "util/pool.hh" - -namespace search { - -// A full hypothesis with pointers to children. -class Final : public Header { - public: - Final() {} - - Final(util::Pool &pool, Score score, Arity arity, Note note) - : Header(pool.Allocate(Size(arity)), arity) { - SetScore(score); - SetNote(note); - } - - // These are arrays of length GetArity(). - Final *Children() { - return reinterpret_cast(After()); - } - const Final *Children() const { - return reinterpret_cast(After()); - } - - private: - static std::size_t Size(Arity arity) { - return kHeaderSize + arity * sizeof(const Final); - } -}; - -} // namespace search - -#endif // SEARCH_FINAL__ diff --git a/search/header.hh b/search/header.hh index 25550dbed..69f0eed04 100644 --- a/search/header.hh +++ b/search/header.hh @@ -3,7 +3,6 @@ // Header consisting of Score, Arity, and Note -#include "search/note.hh" #include "search/types.hh" #include @@ -24,6 +23,9 @@ class Header { bool operator<(const Header &other) const { return GetScore() < other.GetScore(); } + bool operator>(const Header &other) const { + return GetScore() > other.GetScore(); + } Arity GetArity() const { return *reinterpret_cast(base_ + sizeof(Score)); @@ -36,9 +38,14 @@ class Header { *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)) = to; } + uint8_t *Base() { return base_; } + const uint8_t *Base() const { return base_; } + protected: Header() : base_(NULL) {} + explicit Header(void *base) : base_(static_cast(base)) {} + Header(void *base, Arity arity) : base_(static_cast(base)) { *reinterpret_cast(base_ + sizeof(Score)) = arity; } diff --git a/search/nbest.cc b/search/nbest.cc new file mode 100644 index 000000000..ec3322c97 --- /dev/null +++ b/search/nbest.cc @@ -0,0 +1,106 @@ +#include "search/nbest.hh" + +#include "util/pool.hh" + +#include +#include +#include + +#include +#include + +namespace search { + +NBestList::NBestList(std::vector &partials, util::Pool &entry_pool, std::size_t keep) { + assert(!partials.empty()); + std::vector::iterator end; + if (partials.size() > keep) { + end = partials.begin() + keep; + std::nth_element(partials.begin(), end, partials.end(), std::greater()); + } else { + end = partials.end(); + } + for (std::vector::const_iterator i(partials.begin()); i != end; ++i) { + queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i)); + } +} + +Score NBestList::TopAfterConstructor() const { + assert(revealed_.empty()); + return queue_.top().GetScore(); +} + +const std::vector &NBestList::Extract(util::Pool &pool, std::size_t n) { + while (revealed_.size() < n && !queue_.empty()) { + MoveTop(pool); + } + return revealed_; +} + +Score NBestList::Visit(util::Pool &pool, std::size_t index) { + if (index + 1 < revealed_.size()) + return revealed_[index + 1].GetScore() - revealed_[index].GetScore(); + if (queue_.empty()) + return -INFINITY; + if (index + 1 == revealed_.size()) + return queue_.top().GetScore() - revealed_[index].GetScore(); + assert(index == revealed_.size()); + + MoveTop(pool); + + if (queue_.empty()) return -INFINITY; + return queue_.top().GetScore() - revealed_[index].GetScore(); +} + +Applied NBestList::Get(util::Pool &pool, std::size_t index) { + assert(index <= revealed_.size()); + if (index == revealed_.size()) MoveTop(pool); + return revealed_[index]; +} + +void NBestList::MoveTop(util::Pool &pool) { + assert(!queue_.empty()); + QueueEntry entry(queue_.top()); + queue_.pop(); + RevealedRef *const children_begin = entry.Children(); + RevealedRef *const children_end = children_begin + entry.GetArity(); + Score basis = entry.GetScore(); + for (RevealedRef *child = children_begin; child != children_end; ++child) { + Score change = child->in_->Visit(pool, child->index_); + if (change != -INFINITY) { + assert(change < 0.001); + QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote()); + std::copy(children_begin, child, new_entry.Children()); + RevealedRef *update = new_entry.Children() + (child - children_begin); + update->in_ = child->in_; + update->index_ = child->index_ + 1; + std::copy(child + 1, children_end, update + 1); + queue_.push(new_entry); + } + // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010. + if (child->index_) break; + } + + // Convert QueueEntry to Applied. This leaves some unused memory. + void *overwrite = entry.Children(); + for (unsigned int i = 0; i < entry.GetArity(); ++i) { + RevealedRef from(*(static_cast(overwrite) + i)); + *(static_cast(overwrite) + i) = from.in_->Get(pool, from.index_); + } + revealed_.push_back(Applied(entry.Base())); +} + +NBestComplete NBest::Complete(std::vector &partials) { + assert(!partials.empty()); + NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep); + return NBestComplete( + list, + partials.front().CompletedState(), // All partials have the same state + list->TopAfterConstructor()); +} + +const std::vector &NBest::Extract(History history) { + return static_cast(history)->Extract(entry_pool_, config_.size); +} + +} // namespace search diff --git a/search/nbest.hh b/search/nbest.hh new file mode 100644 index 000000000..cb7651bc2 --- /dev/null +++ b/search/nbest.hh @@ -0,0 +1,81 @@ +#ifndef SEARCH_NBEST__ +#define SEARCH_NBEST__ + +#include "search/applied.hh" +#include "search/config.hh" +#include "search/edge.hh" + +#include + +#include +#include +#include + +#include + +namespace search { + +class NBestList; + +class NBestList { + private: + class RevealedRef { + public: + explicit RevealedRef(History history) + : in_(static_cast(history)), index_(0) {} + + private: + friend class NBestList; + + NBestList *in_; + std::size_t index_; + }; + + typedef GenericApplied QueueEntry; + + public: + NBestList(std::vector &existing, util::Pool &entry_pool, std::size_t keep); + + Score TopAfterConstructor() const; + + const std::vector &Extract(util::Pool &pool, std::size_t n); + + private: + Score Visit(util::Pool &pool, std::size_t index); + + Applied Get(util::Pool &pool, std::size_t index); + + void MoveTop(util::Pool &pool); + + typedef std::vector Revealed; + Revealed revealed_; + + typedef std::priority_queue Queue; + Queue queue_; +}; + +class NBest { + public: + typedef std::vector Combine; + + explicit NBest(const NBestConfig &config) : config_(config) {} + + void Add(std::vector &existing, PartialEdge addition) const { + existing.push_back(addition); + } + + NBestComplete Complete(std::vector &partials); + + const std::vector &Extract(History root); + + private: + const NBestConfig config_; + + boost::object_pool list_pool_; + + util::Pool entry_pool_; +}; + +} // namespace search + +#endif // SEARCH_NBEST__ diff --git a/search/note.hh b/search/note.hh deleted file mode 100644 index 50bed06ec..000000000 --- a/search/note.hh +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef SEARCH_NOTE__ -#define SEARCH_NOTE__ - -namespace search { - -union Note { - const void *vp; -}; - -} // namespace search - -#endif // SEARCH_NOTE__ diff --git a/search/rule.cc b/search/rule.cc index 5b00207ef..0244a09f7 100644 --- a/search/rule.cc +++ b/search/rule.cc @@ -1,7 +1,7 @@ #include "search/rule.hh" +#include "lm/model.hh" #include "search/context.hh" -#include "search/final.hh" #include @@ -9,35 +9,35 @@ namespace search { -template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing) { - unsigned int oov_count = 0; - float prob = 0.0; - const Model &model = context.LanguageModel(); - const lm::WordIndex oov = model.GetVocabulary().NotFound(); - for (std::vector::const_iterator word = words.begin(); ; ++word) { - lm::ngram::RuleScore scorer(model, *(writing++)); - // TODO: optimize - if (prepend_bos && (word == words.begin())) { - scorer.BeginSentence(); - } - for (; ; ++word) { - if (word == words.end()) { - prob += scorer.Finish(); - return static_cast(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); - } - if (*word == kNonTerminal) break; - if (*word == oov) ++oov_count; +template ScoreRuleRet ScoreRule(const Model &model, const std::vector &words, lm::ngram::ChartState *writing) { + ScoreRuleRet ret; + ret.prob = 0.0; + ret.oov = 0; + const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence(); + lm::ngram::RuleScore scorer(model, *(writing++)); + std::vector::const_iterator word = words.begin(); + if (word != words.end() && *word == bos) { + scorer.BeginSentence(); + ++word; + } + for (; word != words.end(); ++word) { + if (*word == kNonTerminal) { + ret.prob += scorer.Finish(); + scorer.Reset(*(writing++)); + } else { + if (*word == oov) ++ret.oov; scorer.Terminal(*word); } - prob += scorer.Finish(); } + ret.prob += scorer.Finish(); + return ret; } -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); } // namespace search diff --git a/search/rule.hh b/search/rule.hh index 0ce2794db..43ca61625 100644 --- a/search/rule.hh +++ b/search/rule.hh @@ -9,11 +9,16 @@ namespace search { -template class Context; - const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; -template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *state_out); +struct ScoreRuleRet { + Score prob; + unsigned int oov; +}; + +// Pass and normally. +// Indicate non-terminals with kNonTerminal. +template ScoreRuleRet ScoreRule(const Model &model, const std::vector &words, lm::ngram::ChartState *state_out); } // namespace search diff --git a/search/types.hh b/search/types.hh index 06eb5bfa2..f9c849b3f 100644 --- a/search/types.hh +++ b/search/types.hh @@ -3,12 +3,29 @@ #include +namespace lm { namespace ngram { class ChartState; } } + namespace search { typedef float Score; typedef uint32_t Arity; +union Note { + const void *vp; +}; + +typedef void *History; + +struct NBestComplete { + NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score) + : history(in_history), state(&in_state), score(in_score) {} + + History history; + const lm::ngram::ChartState *state; + Score score; +}; + } // namespace search #endif // SEARCH_TYPES__ diff --git a/search/vertex.cc b/search/vertex.cc index 11f4631fa..45842982c 100644 --- a/search/vertex.cc +++ b/search/vertex.cc @@ -19,21 +19,34 @@ struct GreaterByBound : public std::binary_functionSortAndSet(context, parent_ptr); + if (extend_.size() == 1) { + parent_ptr = extend_[0]; + extend_[0]->RecursiveSortAndSet(context, parent_ptr); context.DeleteVertexNode(this); return; } for (std::vector::iterator i = extend_.begin(); i != extend_.end(); ++i) { - (*i)->SortAndSet(context, &*i); + (*i)->RecursiveSortAndSet(context, *i); + } + std::sort(extend_.begin(), extend_.end(), GreaterByBound()); + bound_ = extend_.front()->Bound(); +} + +void VertexNode::SortAndSet(ContextBase &context) { + // This is the root. The root might be empty. + if (extend_.empty()) { + bound_ = -INFINITY; + return; + } + // The root cannot be replaced. There's always one transition. + for (std::vector::iterator i = extend_.begin(); i != extend_.end(); ++i) { + (*i)->RecursiveSortAndSet(context, *i); } std::sort(extend_.begin(), extend_.end(), GreaterByBound()); bound_ = extend_.front()->Bound(); diff --git a/search/vertex.hh b/search/vertex.hh index 52bc1dfe7..10b3339b9 100644 --- a/search/vertex.hh +++ b/search/vertex.hh @@ -2,7 +2,6 @@ #define SEARCH_VERTEX__ #include "lm/left.hh" -#include "search/final.hh" #include "search/types.hh" #include @@ -10,6 +9,7 @@ #include #include +#include #include namespace search { @@ -18,7 +18,7 @@ class ContextBase; class VertexNode { public: - VertexNode() {} + VertexNode() : end_() {} void InitRoot() { extend_.clear(); @@ -26,7 +26,7 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - end_ = Final(); + end_ = History(); } lm::ngram::ChartState &MutableState() { return state_; } @@ -36,20 +36,21 @@ class VertexNode { extend_.push_back(next); } - void SetEnd(Final end) { - assert(!end_.Valid()); + void SetEnd(History end, Score score) { + assert(!end_); end_ = end; + bound_ = score; } - void SortAndSet(ContextBase &context, VertexNode **parent_pointer); + void SortAndSet(ContextBase &context); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_.Valid() && extend_.empty(); + return !end_ && extend_.empty(); } bool Complete() const { - return end_.Valid(); + return end_; } const lm::ngram::ChartState &State() const { return state_; } @@ -64,7 +65,7 @@ class VertexNode { } // Will be invalid unless this is a leaf. - const Final End() const { return end_; } + const History End() const { return end_; } const VertexNode &operator[](size_t index) const { return *extend_[index]; @@ -75,13 +76,15 @@ class VertexNode { } private: + void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); + std::vector extend_; lm::ngram::ChartState state_; bool right_full_; Score bound_; - Final end_; + History end_; }; class PartialVertex { @@ -97,7 +100,7 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } + Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } unsigned char Length() const { return back_->Length(); } @@ -121,7 +124,7 @@ class PartialVertex { return ret; } - const Final End() const { + const History End() const { return back_->End(); } @@ -130,16 +133,18 @@ class PartialVertex { unsigned int index_; }; +template class VertexGenerator; + class Vertex { public: Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } - const Final BestChild() const { + const History BestChild() const { PartialVertex top(RootPartial()); if (top.Empty()) { - return Final(); + return History(); } else { PartialVertex continuation; while (!top.Complete()) { @@ -150,8 +155,8 @@ class Vertex { } private: - friend class VertexGenerator; - + template friend class VertexGenerator; + template friend class RootVertexGenerator; VertexNode root_; }; diff --git a/search/vertex_generator.cc b/search/vertex_generator.cc index e18010c38..73139ffc5 100644 --- a/search/vertex_generator.cc +++ b/search/vertex_generator.cc @@ -11,23 +11,11 @@ namespace search { -VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { - gen.root_.InitRoot(); -} - #if BOOST_VERSION > 104200 namespace { const uint64_t kCompleteAdd = static_cast(-1); -// Parallel structure to VertexNode. -struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map extend; -}; - Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { Trie &next = node.extend[added]; if (!next.under) { @@ -43,19 +31,10 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n return next; } -void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { - Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); - Final *child_out = final.Children(); - const PartialVertex *part = partial.NT(); - const PartialVertex *const part_end_loop = part + partial.GetArity(); - for (; part != part_end_loop; ++part, ++child_out) - *child_out = part->End(); +} // namespace - starter.under->SetEnd(final); -} - -void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { - const lm::ngram::ChartState &state = partial.CompletedState(); +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) { + const lm::ngram::ChartState &state = *end.state; unsigned char left = 0, right = 0; Trie *node = &root; @@ -81,30 +60,9 @@ void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { } node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - CompleteTransition(context, *node, partial); -} - -} // namespace - -#else // BOOST_VERSION - -struct Trie { - VertexNode *under; -}; - -void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { - UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); + node->under->SetEnd(end.history, end.score); } #endif // BOOST_VERSION -void VertexGenerator::FinishedSearch() { - Trie root; - root.under = &gen_.root_; - for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { - AddHypothesis(context_, root, i->second); - } - root.under->SortAndSet(context_, NULL); -} - } // namespace search diff --git a/search/vertex_generator.hh b/search/vertex_generator.hh index 60e86112a..da563c2df 100644 --- a/search/vertex_generator.hh +++ b/search/vertex_generator.hh @@ -2,9 +2,11 @@ #define SEARCH_VERTEX_GENERATOR__ #include "search/edge.hh" +#include "search/types.hh" #include "search/vertex.hh" #include +#include namespace lm { namespace ngram { @@ -15,21 +17,44 @@ class ChartState; namespace search { class ContextBase; -class Final; -class VertexGenerator { +#if BOOST_VERSION > 104200 +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map extend; +}; + +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); + +#endif // BOOST_VERSION + +// Output makes the single-best or n-best list. +template class VertexGenerator { public: - VertexGenerator(ContextBase &context, Vertex &gen); - - void NewHypothesis(PartialEdge partial) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair ret(existing_.insert(std::make_pair(hash_value(state), partial))); - if (!ret.second && ret.first->second < partial) { - ret.first->second = partial; - } + VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { + gen.root_.InitRoot(); } - void FinishedSearch(); + void NewHypothesis(PartialEdge partial) { + nbest_.Add(existing_[hash_value(partial.CompletedState())], partial); + } + + void FinishedSearch() { +#if BOOST_VERSION > 104200 + Trie root; + root.under = &gen_.root_; + for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, nbest_.Complete(i->second)); + } + existing_.clear(); + root.under->SortAndSet(context_); +#else + UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); +#endif + } const Vertex &Generating() const { return gen_; } @@ -38,8 +63,35 @@ class VertexGenerator { Vertex &gen_; - typedef boost::unordered_map Existing; + typedef boost::unordered_map Existing; Existing existing_; + + Output &nbest_; +}; + +// Special case for root vertex: everything should come together into the root +// node. In theory, this should happen naturally due to state collapsing with +// and . If that's the case, VertexGenerator is fine, though it will +// make one connection. +template class RootVertexGenerator { + public: + RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {} + + void NewHypothesis(PartialEdge partial) { + out_.Add(combine_, partial); + } + + void FinishedSearch() { + gen_.root_.InitRoot(); + NBestComplete completed(out_.Complete(combine_)); + gen_.root_.SetEnd(completed.history, completed.score); + } + + private: + Vertex &gen_; + + typename Output::Combine combine_; + Output &out_; }; } // namespace search diff --git a/search/weights.cc b/search/weights.cc deleted file mode 100644 index d65471ad7..000000000 --- a/search/weights.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include "search/weights.hh" -#include "util/tokenize_piece.hh" - -#include - -namespace search { - -namespace { -struct Insert { - void operator()(boost::unordered_map &map, StringPiece name, search::Score score) const { - std::string copy(name.data(), name.size()); - map[copy] = score; - } -}; - -struct DotProduct { - search::Score total; - DotProduct() : total(0.0) {} - - void operator()(const boost::unordered_map &map, StringPiece name, search::Score score) { - boost::unordered_map::const_iterator i(FindStringPiece(map, name)); - if (i != map.end()) - total += score * i->second; - } -}; - -template void Parse(StringPiece text, Map &map, Op &op) { - for (util::TokenIter spaces(text, ' '); spaces; ++spaces) { - util::TokenIter equals(*spaces, '='); - UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); - StringPiece name(*equals); - UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); - char *end; - // Assumes proper termination. - double value = std::strtod(equals->data(), &end); - UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); - UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); - op(map, name, value); - } -} - -} // namespace - -Weights::Weights(StringPiece text) { - Insert op; - Parse(text, map_, op); - lm_ = Steal("LanguageModel"); - oov_ = Steal("OOV"); - word_penalty_ = Steal("WordPenalty"); -} - -Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} - -search::Score Weights::DotNoLM(StringPiece text) const { - DotProduct dot; - Parse(text, map_, dot); - return dot.total; -} - -float Weights::Steal(const std::string &str) { - Map::iterator i(map_.find(str)); - if (i == map_.end()) { - return 0.0; - } else { - float ret = i->second; - map_.erase(i); - return ret; - } -} - -} // namespace search diff --git a/search/weights.hh b/search/weights.hh deleted file mode 100644 index df1c419f0..000000000 --- a/search/weights.hh +++ /dev/null @@ -1,52 +0,0 @@ -// For now, the individual features are not kept. -#ifndef SEARCH_WEIGHTS__ -#define SEARCH_WEIGHTS__ - -#include "search/types.hh" -#include "util/exception.hh" -#include "util/string_piece.hh" - -#include - -#include - -namespace search { - -class WeightParseException : public util::Exception { - public: - WeightParseException() {} - ~WeightParseException() throw() {} -}; - -class Weights { - public: - // Parses weights, sets lm_weight_, removes it from map_. - explicit Weights(StringPiece text); - - // Just the three scores we care about adding. - Weights(Score lm, Score oov, Score word_penalty); - - Score DotNoLM(StringPiece text) const; - - Score LM() const { return lm_; } - - Score OOV() const { return oov_; } - - Score WordPenalty() const { return word_penalty_; } - - // Mostly for testing. - const boost::unordered_map &GetMap() const { return map_; } - - private: - float Steal(const std::string &str); - - typedef boost::unordered_map Map; - - Map map_; - - Score lm_, oov_, word_penalty_; -}; - -} // namespace search - -#endif // SEARCH_WEIGHTS__ diff --git a/search/weights_test.cc b/search/weights_test.cc deleted file mode 100644 index 4811ff060..000000000 --- a/search/weights_test.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include "search/weights.hh" - -#define BOOST_TEST_MODULE WeightTest -#include -#include - -namespace search { -namespace { - -#define CHECK_WEIGHT(value, string) \ - i = parsed.find(string); \ - BOOST_REQUIRE(i != parsed.end()); \ - BOOST_CHECK_CLOSE((value), i->second, 0.001); - -BOOST_AUTO_TEST_CASE(parse) { - // These are not real feature weights. - Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); - const boost::unordered_map &parsed = w.GetMap(); - boost::unordered_map::const_iterator i; - CHECK_WEIGHT(0.0, "rarity"); - CHECK_WEIGHT(0.0, "phrase-SGT"); - CHECK_WEIGHT(9.45117, "phrase-TGS"); - CHECK_WEIGHT(2.33833, "lexical-SGT"); - BOOST_CHECK(parsed.end() == parsed.find("lm")); - BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); - CHECK_WEIGHT(-28.3317, "lexical-TGS"); - CHECK_WEIGHT(5.0, "glue?"); -} - -BOOST_AUTO_TEST_CASE(dot) { - Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); - BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); - BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); - BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); -} - -} // namespace -} // namespace search