Merge branch 'master' of github.com:moses-smt/mosesdecoder
commit 90b251a50f
@@ -40,6 +40,7 @@ POSSIBILITY OF SUCH DAMAGE.
#include "moses/StaticData.h"
#include "moses/DummyScoreProducers.h"
#include "moses/InputFileStream.h"
#include "moses/Incremental.h"
#include "moses/PhraseDictionary.h"
#include "moses/ChartTrellisPathList.h"
#include "moses/ChartTrellisPath.h"
@@ -377,8 +378,136 @@ void IOWrapper::OutputBestHypo(const ChartHypothesis *hypo, long translationId)
  m_singleBestOutputCollector->Write(translationId, out.str());
}

void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const ChartHypothesis *bestHypo, const TranslationSystem* system, long translationId)
{
void IOWrapper::OutputBestHypo(search::Applied applied, long translationId) {
  if (!m_singleBestOutputCollector) return;
  std::ostringstream out;
  IOWrapper::FixPrecision(out);
  if (StaticData::Instance().GetOutputHypoScore()) {
    out << applied.GetScore() << ' ';
  }
  Phrase outPhrase;
  Incremental::ToPhrase(applied, outPhrase);
  // delete 1st & last
  CHECK(outPhrase.GetSize() >= 2);
  outPhrase.RemoveWord(0);
  outPhrase.RemoveWord(outPhrase.GetSize() - 1);
  out << outPhrase.GetStringRep(StaticData::Instance().GetOutputFactorOrder());
  out << '\n';
  m_singleBestOutputCollector->Write(translationId, out.str());
}

void IOWrapper::OutputBestNone(long translationId) {
  if (!m_singleBestOutputCollector) return;
  if (StaticData::Instance().GetOutputHypoScore()) {
    m_singleBestOutputCollector->Write(translationId, "0 \n");
  } else {
    m_singleBestOutputCollector->Write(translationId, "\n");
  }
}

namespace {

void OutputSparseFeatureScores(std::ostream& out, const ScoreComponentCollection &features, const FeatureFunction *ff, std::string &lastName) {
  const StaticData &staticData = StaticData::Instance();
  bool labeledOutput = staticData.IsLabeledNBestList();
  const FVector scores = features.GetVectorForProducer(ff);

  // report weighted aggregate
  if (!ff->GetSparseFeatureReporting()) {
    const FVector &weights = staticData.GetAllWeights().GetScoresVector();
    if (labeledOutput && !boost::contains(ff->GetScoreProducerDescription(), ":"))
      out << " " << ff->GetScoreProducerWeightShortName() << ":";
    out << " " << scores.inner_product(weights);
  }

  // report each feature
  else {
    for (FVector::FNVmap::const_iterator i = scores.cbegin(); i != scores.cend(); i++) {
      if (i->second != 0) { // do not report zero-valued features
        if (labeledOutput)
          out << " " << i->first << ":";
        out << " " << i->second;
      }
    }
  }
}

void WriteFeatures(const TranslationSystem &system, const ScoreComponentCollection &features, std::ostream &out) {
  bool labeledOutput = StaticData::Instance().IsLabeledNBestList();
  // lm
  const LMList& lml = system.GetLanguageModels();
  if (lml.size() > 0) {
    if (labeledOutput)
      out << "lm:";
    LMList::const_iterator lmi = lml.begin();
    for (; lmi != lml.end(); ++lmi) {
      out << " " << features.GetScoreForProducer(*lmi);
    }
  }

  std::string lastName = "";

  // output stateful sparse features
  const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
  for (size_t i = 0; i < sff.size(); i++)
    if (sff[i]->GetNumScoreComponents() == ScoreProducer::unlimited)
      OutputSparseFeatureScores(out, features, sff[i], lastName);

  // translation components
  const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
  if (pds.size() > 0) {
    for (size_t i = 0; i < pds.size(); i++) {
      size_t pd_numinputscore = pds[i]->GetNumInputScores();
      vector<float> scores = features.GetScoresForProducer(pds[i]);
      for (size_t j = 0; j < scores.size(); ++j) {
        if (labeledOutput && (i == 0)) {
          if ((j == 0) || (j == pd_numinputscore)) {
            lastName = pds[i]->GetScoreProducerWeightShortName(j);
            out << " " << lastName << ":";
          }
        }
        out << " " << scores[j];
      }
    }
  }

  // word penalty
  if (labeledOutput)
    out << " w:";
  out << " " << features.GetScoreForProducer(system.GetWordPenaltyProducer());

  // generation
  const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
  if (gds.size() > 0) {
    for (size_t i = 0; i < gds.size(); i++) {
      size_t pd_numinputscore = gds[i]->GetNumInputScores();
      vector<float> scores = features.GetScoresForProducer(gds[i]);
      for (size_t j = 0; j < scores.size(); ++j) {
        if (labeledOutput && (i == 0)) {
          if ((j == 0) || (j == pd_numinputscore)) {
            lastName = gds[i]->GetScoreProducerWeightShortName(j);
            out << " " << lastName << ":";
          }
        }
        out << " " << scores[j];
      }
    }
  }

  // output stateless sparse features
  lastName = "";

  const vector<const StatelessFeatureFunction*>& slf = system.GetStatelessFeatureFunctions();
  for (size_t i = 0; i < slf.size(); i++) {
    if (slf[i]->GetNumScoreComponents() == ScoreProducer::unlimited) {
      OutputSparseFeatureScores(out, features, slf[i], lastName);
    }
  }
}

} // namespace

void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const TranslationSystem* system, long translationId) {
  std::ostringstream out;

  // Check if we're writing to std::cout.
@@ -387,17 +516,10 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha
  // preserve existing behaviour, but should probably be done either way.
  IOWrapper::FixPrecision(out);

  // The output from -output-hypo-score is always written to std::cout.
  if (StaticData::Instance().GetOutputHypoScore()) {
    if (bestHypo != NULL) {
      out << bestHypo->GetTotalScore() << " ";
    } else {
      out << "0 ";
    }
  }
  // Used to check StaticData's GetOutputHypoScore(), but it makes no sense with nbest output.
}

  bool labeledOutput = StaticData::Instance().IsLabeledNBestList();
  //bool includeAlignment = StaticData::Instance().NBestIncludesAlignment();
  bool includeWordAlignment = StaticData::Instance().PrintAlignmentInfoInNbest();

  ChartTrellisPathList::const_iterator iter;
@@ -421,75 +543,7 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha
    // before each model type, the corresponding command-line-like name must be emitted
    // MERT script relies on this

    // lm
    const LMList& lml = system->GetLanguageModels();
    if (lml.size() > 0) {
      if (labeledOutput)
        out << "lm:";
      LMList::const_iterator lmi = lml.begin();
      for (; lmi != lml.end(); ++lmi) {
        out << " " << path.GetScoreBreakdown().GetScoreForProducer(*lmi);
      }
    }

    std::string lastName = "";

    // output stateful sparse features
    const vector<const StatefulFeatureFunction*>& sff = system->GetStatefulFeatureFunctions();
    for (size_t i = 0; i < sff.size(); i++)
      if (sff[i]->GetNumScoreComponents() == ScoreProducer::unlimited)
        OutputSparseFeatureScores(out, path, sff[i], lastName);

    // translation components
    const vector<PhraseDictionaryFeature*>& pds = system->GetPhraseDictionaries();
    if (pds.size() > 0) {
      for (size_t i = 0; i < pds.size(); i++) {
        size_t pd_numinputscore = pds[i]->GetNumInputScores();
        vector<float> scores = path.GetScoreBreakdown().GetScoresForProducer(pds[i]);
        for (size_t j = 0; j < scores.size(); ++j) {
          if (labeledOutput && (i == 0)) {
            if ((j == 0) || (j == pd_numinputscore)) {
              lastName = pds[i]->GetScoreProducerWeightShortName(j);
              out << " " << lastName << ":";
            }
          }
          out << " " << scores[j];
        }
      }
    }

    // word penalty
    if (labeledOutput)
      out << " w:";
    out << " " << path.GetScoreBreakdown().GetScoreForProducer(system->GetWordPenaltyProducer());

    // generation
    const vector<GenerationDictionary*>& gds = system->GetGenerationDictionaries();
    if (gds.size() > 0) {
      for (size_t i = 0; i < gds.size(); i++) {
        size_t pd_numinputscore = gds[i]->GetNumInputScores();
        vector<float> scores = path.GetScoreBreakdown().GetScoresForProducer(gds[i]);
        for (size_t j = 0; j < scores.size(); ++j) {
          if (labeledOutput && (i == 0)) {
            if ((j == 0) || (j == pd_numinputscore)) {
              lastName = gds[i]->GetScoreProducerWeightShortName(j);
              out << " " << lastName << ":";
            }
          }
          out << " " << scores[j];
        }
      }
    }

    // output stateless sparse features
    lastName = "";

    const vector<const StatelessFeatureFunction*>& slf = system->GetStatelessFeatureFunctions();
    for (size_t i = 0; i < slf.size(); i++) {
      if (slf[i]->GetNumScoreComponents() == ScoreProducer::unlimited) {
        OutputSparseFeatureScores(out, path, slf[i], lastName);
      }
    }
    WriteFeatures(*system, path.GetScoreBreakdown(), out);

    // total
    out << " ||| " << path.GetTotalScore();
@@ -528,34 +582,33 @@ void IOWrapper::OutputNBestList(const ChartTrellisPathList &nBestList, const Cha

  out << std::flush;

  CHECK(m_nBestOutputCollector);
  assert(m_nBestOutputCollector);
  m_nBestOutputCollector->Write(translationId, out.str());
}

void IOWrapper::OutputSparseFeatureScores(std::ostream& out, const ChartTrellisPath &path, const FeatureFunction *ff, std::string &lastName)
{
  const StaticData &staticData = StaticData::Instance();
  bool labeledOutput = staticData.IsLabeledNBestList();
  const FVector scores = path.GetScoreBreakdown().GetVectorForProducer(ff);

  // report weighted aggregate
  if (!ff->GetSparseFeatureReporting()) {
    const FVector &weights = staticData.GetAllWeights().GetScoresVector();
    if (labeledOutput && !boost::contains(ff->GetScoreProducerDescription(), ":"))
      out << " " << ff->GetScoreProducerWeightShortName() << ":";
    out << " " << scores.inner_product(weights);
void IOWrapper::OutputNBestList(const std::vector<search::Applied> &nbest, const TranslationSystem &system, long translationId) {
  std::ostringstream out;
  // wtf? copied from the original OutputNBestList
  if (m_nBestOutputCollector->OutputIsCout()) {
    IOWrapper::FixPrecision(out);
  }

  // report each feature
  else {
    for (FVector::FNVmap::const_iterator i = scores.cbegin(); i != scores.cend(); i++) {
      if (i->second != 0) { // do not report zero-valued features
        if (labeledOutput)
          out << " " << i->first << ":";
        out << " " << i->second;
      }
    }
  Phrase outputPhrase;
  ScoreComponentCollection features;
  for (std::vector<search::Applied>::const_iterator i = nbest.begin(); i != nbest.end(); ++i) {
    Incremental::PhraseAndFeatures(system, *i, outputPhrase, features);
    // <s> and </s>
    CHECK(outputPhrase.GetSize() >= 2);
    outputPhrase.RemoveWord(0);
    outputPhrase.RemoveWord(outputPhrase.GetSize() - 1);
    out << translationId << " ||| ";
    OutputSurface(out, outputPhrase, m_outputFactorOrder, false);
    out << " ||| ";
    WriteFeatures(system, features, out);
    out << " ||| " << i->GetScore() << '\n';
  }
  out << std::flush;
  assert(m_nBestOutputCollector);
  m_nBestOutputCollector->Write(translationId, out.str());
}

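Each hypothesis line emitted above follows the usual Moses n-best format, translationId ||| surface ||| feature scores ||| total score; an illustrative line (values invented):

    0 ||| das ist ein kleines haus ||| lm: -24.1 w: -5 ||| -8.32
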
void IOWrapper::FixPrecision(std::ostream &stream, size_t size)

@@ -45,6 +45,7 @@ POSSIBILITY OF SUCH DAMAGE.
#include "moses/OutputCollector.h"
#include "moses/ChartHypothesis.h"
#include "moses/ChartTrellisPath.h"
#include "search/applied.hh"

namespace Moses
{
@@ -94,14 +95,14 @@ public:

  Moses::InputType* GetInput(Moses::InputType *inputType);
  void OutputBestHypo(const Moses::ChartHypothesis *hypo, long translationId);
  void OutputBestHypo(search::Applied applied, long translationId);
  void OutputBestHypo(const std::vector<const Moses::Factor*>& mbrBestHypo, long translationId);
  void OutputNBestList(const Moses::ChartTrellisPathList &nBestList, const Moses::ChartHypothesis *bestHypo, const Moses::TranslationSystem* system, long translationId);
  void OutputSparseFeatureScores(std::ostream& out, const Moses::ChartTrellisPath &path, const Moses::FeatureFunction *ff, std::string &lastName);
  void OutputBestNone(long translationId);
  void OutputNBestList(const Moses::ChartTrellisPathList &nBestList, const Moses::TranslationSystem* system, long translationId);
  void OutputNBestList(const std::vector<search::Applied> &nbest, const Moses::TranslationSystem &system, long translationId);
  void OutputDetailedTranslationReport(const Moses::ChartHypothesis *hypo, const Moses::Sentence &sentence, long translationId);
  void Backtrack(const Moses::ChartHypothesis *hypo);

  Moses::OutputCollector *ExposeSingleBest() { return m_singleBestOutputCollector; }

  void ResetTranslationId();

  Moses::OutputCollector *GetSearchGraphOutputCollector() {

@@ -59,7 +59,7 @@ POSSIBILITY OF SUCH DAMAGE.
#include "moses/ChartHypothesis.h"
#include "moses/ChartTrellisPath.h"
#include "moses/ChartTrellisPathList.h"
#include "moses/Incremental/Manager.h"
#include "moses/Incremental.h"

#include "util/usage.hh"

@@ -91,10 +91,14 @@ public:

    if (staticData.GetSearchAlgorithm() == ChartIncremental) {
      Incremental::Manager manager(*m_source, system);
      manager.ProcessSentence();
      if (m_ioWrapper.ExposeSingleBest()) {
        m_ioWrapper.ExposeSingleBest()->Write(translationId, manager.String() + '\n');
      const std::vector<search::Applied> &nbest = manager.ProcessSentence();
      if (!nbest.empty()) {
        m_ioWrapper.OutputBestHypo(nbest[0], translationId);
      } else {
        m_ioWrapper.OutputBestNone(translationId);
      }
      if (staticData.GetNBestSize() > 0)
        m_ioWrapper.OutputNBestList(nbest, system, translationId);
      return;
    }

@@ -125,7 +129,7 @@ public:
    VERBOSE(2,"WRITING " << nBestSize << " TRANSLATION ALTERNATIVES TO " << staticData.GetNBestFilePath() << endl);
    ChartTrellisPathList nBestList;
    manager.CalcNBest(nBestSize, nBestList, staticData.GetDistinctNBest());
    m_ioWrapper.OutputNBestList(nBestList, bestHypo, &system, translationId);
    m_ioWrapper.OutputNBestList(nBestList, &system, translationId);
    IFVERBOSE(2) {
      PrintUserTime("N-Best Hypotheses Generation Time:");
    }

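For orientation, a minimal sketch of the control flow these hunks introduce, assembled only from calls that appear in the diff; the wrapper function and its signature are illustrative, not part of the commit:

    #include "moses/Incremental.h"
    #include "moses/StaticData.h"

    // Hypothetical driver (not in the commit): decode one sentence with the
    // incremental search and emit the 1-best plus the n-best list.
    void DecodeIncremental(const Moses::InputType &source,
                           const Moses::TranslationSystem &system,
                           IOWrapper &io, long translationId) {
      Moses::Incremental::Manager manager(source, system);
      // ProcessSentence now returns the n-best list instead of a single string.
      const std::vector<search::Applied> &nbest = manager.ProcessSentence();
      if (!nbest.empty()) {
        io.OutputBestHypo(nbest[0], translationId);        // 1-best output
      } else {
        io.OutputBestNone(translationId);                  // search found nothing
      }
      if (Moses::StaticData::Instance().GetNBestSize() > 0)
        io.OutputNBestList(nbest, system, translationId);  // full n-best list
    }
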
@@ -23,7 +23,7 @@
#include "Word.h"
#include "WordsRange.h"

namespace search { class Vertex; class VertexGenerator; }
namespace search { class Vertex; }

namespace Moses
{
@@ -41,7 +41,7 @@ class ChartCellLabel
  union Stack {
    const HypoList *cube; // cube pruning
    const search::Vertex *incr; // incremental search after filling.
    search::VertexGenerator *incr_generator; // incremental search during filling.
    void *incr_generator; // incremental search during filling.
  };

@@ -401,8 +401,8 @@ TargetPhraseVectorPtr PhraseDecoder::DecodeCollection(
  if(m_phraseDictionary.m_useAlignmentInfo)
  {
    // reconstruct the alignment data based on the alignment of the subphrase
    for(AlignmentInfo::const_iterator it = subTp.GetAlignNonTerm().begin();
        it != subTp.GetAlignNonTerm().end(); it++)
    for(AlignmentInfo::const_iterator it = subTp.GetAlignTerm().begin();
        it != subTp.GetAlignTerm().end(); it++)
    {
      alignment.insert(AlignPointSizeT(srcStart + it->first,
                                       targetPhrase->GetSize() + it->second));
@@ -455,8 +455,9 @@ TargetPhraseVectorPtr PhraseDecoder::DecodeCollection(

  if(state == Add)
  {
    if(m_phraseDictionary.m_useAlignmentInfo)
    if(m_phraseDictionary.m_useAlignmentInfo) {
      targetPhrase->SetAlignTerm(alignment);
    }

    if(m_coding == PREnc)
    {

@@ -51,7 +51,7 @@ bool PhraseDictionaryCompact::Load(const std::vector<FactorType> &input
{
  m_input = &input;
  m_output = &output;
  m_weight = &weight;
  m_weight = new std::vector<float>(weight);
  m_tableLimit = tableLimit;
  m_languageModels = &languageModels;
  m_weightWP = weightWP;

moses/Incremental.cpp (new file, 296 lines)
@@ -0,0 +1,296 @@
#include "moses/Incremental.h"
|
||||
|
||||
#include "moses/ChartCell.h"
|
||||
#include "moses/ChartParserCallback.h"
|
||||
#include "moses/FeatureVector.h"
|
||||
#include "moses/StaticData.h"
|
||||
#include "moses/TranslationSystem.h"
|
||||
#include "moses/Util.h"
|
||||
|
||||
#include "lm/model.hh"
|
||||
#include "search/applied.hh"
|
||||
#include "search/config.hh"
|
||||
#include "search/context.hh"
|
||||
#include "search/edge_generator.hh"
|
||||
#include "search/rule.hh"
|
||||
#include "search/vertex_generator.hh"
|
||||
|
||||
#include <boost/lexical_cast.hpp>
|
||||
|
||||
namespace Moses {
|
||||
namespace Incremental {
|
||||
namespace {
|
||||
|
||||
// This is called by EdgeGenerator. Route hypotheses to separate vertices for
|
||||
// each left hand side label, populating ChartCellLabelSet out.
|
||||
template <class Best> class HypothesisCallback {
|
||||
private:
|
||||
typedef search::VertexGenerator<Best> Gen;
|
||||
public:
|
||||
HypothesisCallback(search::ContextBase &context, Best &best, ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool)
|
||||
: context_(context), best_(best), out_(out), vertex_pool_(vertex_pool) {}
|
||||
|
||||
void NewHypothesis(search::PartialEdge partial) {
|
||||
// Get the LHS, look it up in the output ChartCellLabel, and upcast it.
|
||||
// It's not part of the union because it would have been ugly to expose template types in ChartCellLabel.
|
||||
ChartCellLabel::Stack &stack = out_.FindOrInsert(static_cast<const TargetPhrase *>(partial.GetNote().vp)->GetTargetLHS());
|
||||
Gen *entry = static_cast<Gen*>(stack.incr_generator);
|
||||
if (!entry) {
|
||||
entry = generator_pool_.construct(context_, *vertex_pool_.construct(), best_);
|
||||
stack.incr_generator = entry;
|
||||
}
|
||||
entry->NewHypothesis(partial);
|
||||
}
|
||||
|
||||
void FinishedSearch() {
|
||||
for (ChartCellLabelSet::iterator i(out_.mutable_begin()); i != out_.mutable_end(); ++i) {
|
||||
ChartCellLabel::Stack &stack = i->second.MutableStack();
|
||||
Gen *gen = static_cast<Gen*>(stack.incr_generator);
|
||||
gen->FinishedSearch();
|
||||
stack.incr = &gen->Generating();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
search::ContextBase &context_;
|
||||
|
||||
Best &best_;
|
||||
|
||||
ChartCellLabelSet &out_;
|
||||
|
||||
boost::object_pool<search::Vertex> &vertex_pool_;
|
||||
boost::object_pool<Gen> generator_pool_;
|
||||
};
|
||||
|
||||
// This is called by the moses parser to collect hypotheses. It converts to my
// edges (search::PartialEdge).
template <class Model> class Fill : public ChartParserCallback {
 public:
  Fill(search::Context<Model> &context, const std::vector<lm::WordIndex> &vocab_mapping, search::Score oov_weight)
    : context_(context), vocab_mapping_(vocab_mapping), oov_weight_(oov_weight) {}

  void Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &ignored);

  void AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &waste_memory, const WordsRange &range);

  bool Empty() const { return edges_.Empty(); }

  template <class Best> void Search(Best &best, ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool) {
    HypothesisCallback<Best> callback(context_, best, out, vertex_pool);
    edges_.Search(context_, callback);
  }

  // Root: everything into one vertex.
  template <class Best> search::History RootSearch(Best &best) {
    search::Vertex vertex;
    search::RootVertexGenerator<Best> gen(vertex, best);
    edges_.Search(context_, gen);
    return vertex.BestChild();
  }

 private:
  lm::WordIndex Convert(const Word &word) const;

  search::Context<Model> &context_;

  const std::vector<lm::WordIndex> &vocab_mapping_;

  search::EdgeGenerator edges_;

  const search::Score oov_weight_;
};

template <class Model> void Fill<Model>::Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &) {
  std::vector<search::PartialVertex> vertices;
  vertices.reserve(nts.size());
  float below_score = 0.0;
  for (StackVec::const_iterator i(nts.begin()); i != nts.end(); ++i) {
    vertices.push_back((*i)->GetStack().incr->RootPartial());
    if (vertices.back().Empty()) return;
    below_score += vertices.back().Bound();
  }

  std::vector<lm::WordIndex> words;
  for (TargetPhraseCollection::const_iterator p(targets.begin()); p != targets.end(); ++p) {
    words.clear();
    const TargetPhrase &phrase = **p;
    const AlignmentInfo::NonTermIndexMap &align = phrase.GetAlignNonTerm().GetNonTermIndexMap();
    search::PartialEdge edge(edges_.AllocateEdge(nts.size()));

    search::PartialVertex *nt = edge.NT();
    for (size_t i = 0; i < phrase.GetSize(); ++i) {
      const Word &word = phrase.GetWord(i);
      if (word.IsNonTerminal()) {
        *(nt++) = vertices[align[i]];
        words.push_back(search::kNonTerminal);
      } else {
        words.push_back(Convert(word));
      }
    }

    edge.SetScore(phrase.GetFutureScore() + below_score);
    // prob and oov were already accounted for.
    search::ScoreRule(context_.LanguageModel(), words, edge.Between());

    search::Note note;
    note.vp = &phrase;
    edge.SetNote(note);

    edges_.AddEdge(edge);
  }
}

template <class Model> void Fill<Model>::AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &, const WordsRange &) {
  std::vector<lm::WordIndex> words;
  CHECK(phrase.GetSize() <= 1);
  if (phrase.GetSize())
    words.push_back(Convert(phrase.GetWord(0)));

  search::PartialEdge edge(edges_.AllocateEdge(0));
  // Appears to be a bug that FutureScore does not already include language model.
  search::ScoreRuleRet scored(search::ScoreRule(context_.LanguageModel(), words, edge.Between()));
  edge.SetScore(phrase.GetFutureScore() + scored.prob * context_.LMWeight() + static_cast<search::Score>(scored.oov) * oov_weight_);

  search::Note note;
  note.vp = &phrase;
  edge.SetNote(note);

  edges_.AddEdge(edge);
}

// TODO: factors (but chart doesn't seem to support factors anyway).
template <class Model> lm::WordIndex Fill<Model>::Convert(const Word &word) const {
  std::size_t factor = word.GetFactor(0)->GetId();
  return (factor >= vocab_mapping_.size() ? 0 : vocab_mapping_[factor]);
}

struct ChartCellBaseFactory {
  ChartCellBase *operator()(size_t startPos, size_t endPos) const {
    return new ChartCellBase(startPos, endPos);
  }
};

} // namespace

Manager::Manager(const InputType &source, const TranslationSystem &system) :
  source_(source),
  system_(system),
  cells_(source, ChartCellBaseFactory()),
  parser_(source, system, cells_),
  n_best_(search::NBestConfig(StaticData::Instance().GetNBestSize())) {}

Manager::~Manager() {
  system_.CleanUpAfterSentenceProcessing(source_);
}

template <class Model, class Best> search::History Manager::PopulateBest(const Model &model, const std::vector<lm::WordIndex> &words, Best &out) {
  const LanguageModel &abstract = **system_.GetLanguageModels().begin();
  const float oov_weight = abstract.OOVFeatureEnabled() ? abstract.GetOOVWeight() : 0.0;
  const StaticData &data = StaticData::Instance();
  search::Config config(abstract.GetWeight(), data.GetCubePruningPopLimit(), search::NBestConfig(data.GetNBestSize()));
  search::Context<Model> context(config, model);

  size_t size = source_.GetSize();
  boost::object_pool<search::Vertex> vertex_pool(std::max<size_t>(size * size / 2, 32));

  for (size_t width = 1; width < size; ++width) {
    for (size_t startPos = 0; startPos <= size-width; ++startPos) {
      WordsRange range(startPos, startPos + width - 1);
      Fill<Model> filler(context, words, oov_weight);
      parser_.Create(range, filler);
      filler.Search(out, cells_.MutableBase(range).MutableTargetLabelSet(), vertex_pool);
    }
  }

  WordsRange range(0, size - 1);
  Fill<Model> filler(context, words, oov_weight);
  parser_.Create(range, filler);
  return filler.RootSearch(out);
}

template <class Model> void Manager::LMCallback(const Model &model, const std::vector<lm::WordIndex> &words) {
  std::size_t nbest = StaticData::Instance().GetNBestSize();
  if (nbest <= 1) {
    search::History ret = PopulateBest(model, words, single_best_);
    if (ret) {
      backing_for_single_.resize(1);
      backing_for_single_[0] = search::Applied(ret);
    } else {
      backing_for_single_.clear();
    }
    completed_nbest_ = &backing_for_single_;
  } else {
    search::History ret = PopulateBest(model, words, n_best_);
    if (ret) {
      completed_nbest_ = &n_best_.Extract(ret);
    } else {
      backing_for_single_.clear();
      completed_nbest_ = &backing_for_single_;
    }
  }
}

template void Manager::LMCallback<lm::ngram::ProbingModel>(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::RestProbingModel>(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::TrieModel>(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::QuantTrieModel>(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::ArrayTrieModel>(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::QuantArrayTrieModel>(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words);

const std::vector<search::Applied> &Manager::ProcessSentence() {
  const LMList &lms = system_.GetLanguageModels();
  UTIL_THROW_IF(lms.size() != 1, util::Exception, "Incremental search only supports one language model.");
  (*lms.begin())->IncrementalCallback(*this);
  return *completed_nbest_;
}

namespace {

struct NoOp {
  void operator()(const TargetPhrase &) const {}
};
struct AccumScore {
  AccumScore(ScoreComponentCollection &out) : out_(&out) {}
  void operator()(const TargetPhrase &phrase) {
    out_->PlusEquals(phrase.GetScoreBreakdown());
  }
  ScoreComponentCollection *out_;
};
template <class Action> void AppendToPhrase(const search::Applied final, Phrase &out, Action action) {
  assert(final.Valid());
  const TargetPhrase &phrase = *static_cast<const TargetPhrase*>(final.GetNote().vp);
  action(phrase);
  const search::Applied *child = final.Children();
  for (std::size_t i = 0; i < phrase.GetSize(); ++i) {
    const Word &word = phrase.GetWord(i);
    if (word.IsNonTerminal()) {
      AppendToPhrase(*child++, out, action);
    } else {
      out.AddWord(word);
    }
  }
}

} // namespace

void ToPhrase(const search::Applied final, Phrase &out) {
  out.Clear();
  AppendToPhrase(final, out, NoOp());
}

void PhraseAndFeatures(const TranslationSystem &system, const search::Applied final, Phrase &phrase, ScoreComponentCollection &features) {
  phrase.Clear();
  features.ZeroAll();
  AppendToPhrase(final, phrase, AccumScore(features));

  // If we made it this far, there is only one language model.
  float full, ignored_ngram;
  std::size_t ignored_oov;
  const LanguageModel &model = **system.GetLanguageModels().begin();
  model.CalcScore(phrase, full, ignored_ngram, ignored_oov);
  // CalcScore transforms, but EvaluateChart doesn't.
  features.Assign(&model, UntransformLMScore(full));
}

} // namespace Incremental
} // namespace Moses

moses/Incremental.h (new file, 60 lines)
@@ -0,0 +1,60 @@
#pragma once

#include "lm/word_index.hh"
#include "search/applied.hh"
#include "search/nbest.hh"

#include "moses/ChartCellCollection.h"
#include "moses/ChartParser.h"

#include <vector>
#include <string>

namespace Moses {
class ScoreComponentCollection;
class InputType;
class TranslationSystem;
namespace Incremental {

class Manager {
 public:
  Manager(const InputType &source, const TranslationSystem &system);

  ~Manager();

  template <class Model> void LMCallback(const Model &model, const std::vector<lm::WordIndex> &words);

  const std::vector<search::Applied> &ProcessSentence();

  // Call to get the same value as ProcessSentence returned.
  const std::vector<search::Applied> &Completed() const {
    return *completed_nbest_;
  }

 private:
  template <class Model, class Best> search::History PopulateBest(const Model &model, const std::vector<lm::WordIndex> &words, Best &out);

  const InputType &source_;
  const TranslationSystem &system_;
  ChartCellCollectionBase cells_;
  ChartParser parser_;

  // Only one of single_best_ or n_best_ will be used, but it was easier to do this than a template.
  search::SingleBest single_best_;
  // ProcessSentence returns a reference to a vector. SingleBest
  // doesn't have one, so this is populated and returned.
  std::vector<search::Applied> backing_for_single_;

  search::NBest n_best_;

  const std::vector<search::Applied> *completed_nbest_;
};

// Just get the phrase.
void ToPhrase(const search::Applied final, Phrase &out);
// Get the phrase and the features.
void PhraseAndFeatures(const TranslationSystem &system, const search::Applied final, Phrase &phrase, ScoreComponentCollection &features);

} // namespace Incremental
} // namespace Moses

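A small usage sketch for the two free functions declared above, mirroring what IOWrapper::OutputBestHypo does with them in this commit; the helper itself is hypothetical:

    #include <iostream>
    #include "moses/Incremental.h"
    #include "moses/ScoreComponentCollection.h"
    #include "moses/StaticData.h"

    // Hypothetical helper (not in the commit): print one completed hypothesis.
    void PrintApplied(const Moses::TranslationSystem &system, search::Applied applied) {
      Moses::Phrase phrase;
      Moses::ScoreComponentCollection features;
      Moses::Incremental::PhraseAndFeatures(system, applied, phrase, features);
      // The phrase still carries <s> and </s>; strip them, as IOWrapper does.
      phrase.RemoveWord(0);
      phrase.RemoveWord(phrase.GetSize() - 1);
      std::cout << phrase.GetStringRep(Moses::StaticData::Instance().GetOutputFactorOrder())
                << " ||| " << applied.GetScore() << '\n';
    }
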
moses/Incremental/Fill.cpp (deleted file, 143 lines)
@@ -1,143 +0,0 @@
#include "Fill.h"

#include "moses/ChartCellLabel.h"
#include "moses/ChartCellLabelSet.h"
#include "moses/TargetPhraseCollection.h"
#include "moses/TargetPhrase.h"
#include "moses/Word.h"

#include "lm/model.hh"
#include "search/context.hh"
#include "search/note.hh"
#include "search/rule.hh"
#include "search/vertex.hh"
#include "search/vertex_generator.hh"

#include <math.h>

namespace Moses {
namespace Incremental {

template <class Model> Fill<Model>::Fill(search::Context<Model> &context, const std::vector<lm::WordIndex> &vocab_mapping)
  : context_(context), vocab_mapping_(vocab_mapping) {}

template <class Model> void Fill<Model>::Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &) {
  std::vector<search::PartialVertex> vertices;
  vertices.reserve(nts.size());
  float below_score = 0.0;
  for (StackVec::const_iterator i(nts.begin()); i != nts.end(); ++i) {
    vertices.push_back((*i)->GetStack().incr->RootPartial());
    if (vertices.back().Empty()) return;
    below_score += vertices.back().Bound();
  }

  std::vector<lm::WordIndex> words;
  for (TargetPhraseCollection::const_iterator p(targets.begin()); p != targets.end(); ++p) {
    words.clear();
    const TargetPhrase &phrase = **p;
    const AlignmentInfo::NonTermIndexMap &align = phrase.GetAlignNonTerm().GetNonTermIndexMap();
    search::PartialEdge edge(edges_.AllocateEdge(nts.size()));

    size_t i = 0;
    bool bos = false;
    search::PartialVertex *nt = edge.NT();
    if (phrase.GetSize() && !phrase.GetWord(0).IsNonTerminal()) {
      lm::WordIndex index = Convert(phrase.GetWord(0));
      if (context_.LanguageModel().GetVocabulary().BeginSentence() == index) {
        bos = true;
      } else {
        words.push_back(index);
      }
      i = 1;
    }
    for (; i < phrase.GetSize(); ++i) {
      const Word &word = phrase.GetWord(i);
      if (word.IsNonTerminal()) {
        *(nt++) = vertices[align[i]];
        words.push_back(search::kNonTerminal);
      } else {
        words.push_back(Convert(word));
      }
    }

    edge.SetScore(phrase.GetFutureScore() + below_score);
    search::ScoreRule(context_, words, bos, edge.Between());

    search::Note note;
    note.vp = &phrase;
    edge.SetNote(note);

    edges_.AddEdge(edge);
  }
}

template <class Model> void Fill<Model>::AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &, const WordsRange &) {
  std::vector<lm::WordIndex> words;
  CHECK(phrase.GetSize() <= 1);
  if (phrase.GetSize())
    words.push_back(Convert(phrase.GetWord(0)));

  search::PartialEdge edge(edges_.AllocateEdge(0));
  // Appears to be a bug that FutureScore does not already include language model.
  edge.SetScore(phrase.GetFutureScore() + search::ScoreRule(context_, words, false, edge.Between()));

  search::Note note;
  note.vp = &phrase;
  edge.SetNote(note);

  edges_.AddEdge(edge);
}

namespace {
// Route hypotheses to separate vertices for each left hand side label, populating ChartCellLabelSet out.
class HypothesisCallback {
 public:
  HypothesisCallback(search::ContextBase &context, ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool)
    : context_(context), out_(out), vertex_pool_(vertex_pool) {}

  void NewHypothesis(search::PartialEdge partial) {
    search::VertexGenerator *&entry = out_.FindOrInsert(static_cast<const TargetPhrase *>(partial.GetNote().vp)->GetTargetLHS()).incr_generator;
    if (!entry) {
      entry = generator_pool_.construct(context_, *vertex_pool_.construct());
    }
    entry->NewHypothesis(partial);
  }

  void FinishedSearch() {
    for (ChartCellLabelSet::iterator i(out_.mutable_begin()); i != out_.mutable_end(); ++i) {
      ChartCellLabel::Stack &stack = i->second.MutableStack();
      stack.incr_generator->FinishedSearch();
      stack.incr = &stack.incr_generator->Generating();
    }
  }

 private:
  search::ContextBase &context_;

  ChartCellLabelSet &out_;

  boost::object_pool<search::Vertex> &vertex_pool_;
  boost::object_pool<search::VertexGenerator> generator_pool_;
};
} // namespace

template <class Model> void Fill<Model>::Search(ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool) {
  HypothesisCallback callback(context_, out, vertex_pool);
  edges_.Search(context_, callback);
}

// TODO: factors (but chart doesn't seem to support factors anyway).
template <class Model> lm::WordIndex Fill<Model>::Convert(const Word &word) const {
  std::size_t factor = word.GetFactor(0)->GetId();
  return (factor >= vocab_mapping_.size() ? 0 : vocab_mapping_[factor]);
}

template class Fill<lm::ngram::ProbingModel>;
template class Fill<lm::ngram::RestProbingModel>;
template class Fill<lm::ngram::TrieModel>;
template class Fill<lm::ngram::QuantTrieModel>;
template class Fill<lm::ngram::ArrayTrieModel>;
template class Fill<lm::ngram::QuantArrayTrieModel>;

} // namespace Incremental
} // namespace Moses

moses/Incremental/Fill.h (deleted file, 54 lines)
@@ -1,54 +0,0 @@
#pragma once

#include "moses/ChartParserCallback.h"
#include "moses/StackVec.h"

#include "lm/word_index.hh"
#include "search/edge_generator.hh"

#include <boost/pool/object_pool.hpp>

#include <list>
#include <vector>

namespace search {
template <class Model> class Context;
class Vertex;
} // namespace search

namespace Moses {
class Word;
class WordsRange;
class TargetPhraseCollection;
class WordsRange;
class ChartCellLabelSet;
class TargetPhrase;

namespace Incremental {

// Replacement for ChartTranslationOptionList
// TODO: implement count and score thresholding.
template <class Model> class Fill : public ChartParserCallback {
 public:
  Fill(search::Context<Model> &context, const std::vector<lm::WordIndex> &vocab_mapping);

  void Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &ignored);

  void AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &waste_memory, const WordsRange &range);

  bool Empty() const { return edges_.Empty(); }

  void Search(ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool);

 private:
  lm::WordIndex Convert(const Word &word) const;

  search::Context<Model> &context_;

  const std::vector<lm::WordIndex> &vocab_mapping_;

  search::EdgeGenerator edges_;
};

} // namespace Incremental
} // namespace Moses

moses/Incremental/Manager.cpp (deleted file, 122 lines)
@@ -1,122 +0,0 @@
#include "Manager.h"

#include "Fill.h"

#include "moses/ChartCell.h"
#include "moses/TranslationSystem.h"
#include "moses/StaticData.h"

#include "search/context.hh"
#include "search/config.hh"
#include "search/weights.hh"

#include <boost/lexical_cast.hpp>

namespace Moses {
namespace Incremental {

namespace {
struct ChartCellBaseFactory {
  ChartCellBase *operator()(size_t startPos, size_t endPos) const {
    return new ChartCellBase(startPos, endPos);
  }
};
} // namespace

Manager::Manager(const InputType &source, const TranslationSystem &system) :
  source_(source),
  system_(system),
  cells_(source, ChartCellBaseFactory()),
  parser_(source, system, cells_) {

}

Manager::~Manager() {
  system_.CleanUpAfterSentenceProcessing(source_);
}

namespace {

void ConstructString(const search::Final final, std::ostringstream &stream) {
  assert(final.Valid());
  const TargetPhrase &phrase = *static_cast<const TargetPhrase*>(final.GetNote().vp);
  size_t child = 0;
  for (std::size_t i = 0; i < phrase.GetSize(); ++i) {
    const Word &word = phrase.GetWord(i);
    if (word.IsNonTerminal()) {
      assert(child < final.GetArity());
      ConstructString(final.Children()[child++], stream);
    } else {
      stream << word[0]->GetString() << ' ';
    }
  }
}

void BestString(const ChartCellLabelSet &labels, std::string &out) {
  search::Final best;
  for (ChartCellLabelSet::const_iterator i = labels.begin(); i != labels.end(); ++i) {
    const search::Final child(i->second.GetStack().incr->BestChild());
    if (child.Valid() && (!best.Valid() || (child.GetScore() > best.GetScore()))) {
      best = child;
    }
  }
  if (!best.Valid()) {
    out.clear();
    return;
  }
  std::ostringstream stream;
  ConstructString(best, stream);
  out = stream.str();
  CHECK(out.size() > 9);
  // <s>
  out.erase(0, 4);
  // </s>
  out.erase(out.size() - 5);
  // Hack: include model score
  out += " ||| ";
  out += boost::lexical_cast<std::string>(best.GetScore());
}

} // namespace

template <class Model> void Manager::LMCallback(const Model &model, const std::vector<lm::WordIndex> &words) {
  const LanguageModel &abstract = **system_.GetLanguageModels().begin();
  search::Weights weights(
    abstract.GetWeight(),
    abstract.OOVFeatureEnabled() ? abstract.GetOOVWeight() : 0.0,
    system_.GetWeightWordPenalty());
  search::Config config(weights, StaticData::Instance().GetCubePruningPopLimit());
  search::Context<Model> context(config, model);

  size_t size = source_.GetSize();

  boost::object_pool<search::Vertex> vertex_pool(std::max<size_t>(size * size / 2, 32));

  for (size_t width = 1; width <= size; ++width) {
    for (size_t startPos = 0; startPos <= size-width; ++startPos) {
      size_t endPos = startPos + width - 1;
      WordsRange range(startPos, endPos);
      Fill<Model> filler(context, words);
      parser_.Create(range, filler);
      filler.Search(cells_.MutableBase(range).MutableTargetLabelSet(), vertex_pool);
    }
  }
  BestString(cells_.GetBase(WordsRange(0, source_.GetSize() - 1)).GetTargetLabelSet(), output_);
}

template void Manager::LMCallback<lm::ngram::ProbingModel>(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::RestProbingModel>(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::TrieModel>(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::QuantTrieModel>(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::ArrayTrieModel>(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::QuantArrayTrieModel>(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words);

void Manager::ProcessSentence() {
  const LMList &lms = system_.GetLanguageModels();
  UTIL_THROW_IF(lms.size() != 1, util::Exception, "Incremental search only supports one language model.");
  (*lms.begin())->IncrementalCallback(*this);
}

} // namespace Incremental
} // namespace Moses

moses/Incremental/Manager.h (deleted file, 35 lines)
@@ -1,35 +0,0 @@
#pragma once

#include "lm/word_index.hh"

#include "moses/ChartCellCollection.h"
#include "moses/ChartParser.h"

namespace Moses {
class InputType;
class TranslationSystem;
namespace Incremental {

class Manager {
 public:
  Manager(const InputType &source, const TranslationSystem &system);

  ~Manager();

  template <class Model> void LMCallback(const Model &model, const std::vector<lm::WordIndex> &words);

  void ProcessSentence();

  const std::string &String() const { return output_; }

 private:
  const InputType &source_;
  const TranslationSystem &system_;
  ChartCellCollectionBase cells_;
  ChartParser parser_;

  std::string output_;
};
} // namespace Incremental
} // namespace Moses

@@ -32,7 +32,6 @@ lib moses :
CYKPlusParser/*.cpp
RuleTable/*.cpp
fuzzy-match/*.cpp
Incremental/*.cpp
: #exceptions
ThreadPool.cpp
SyntacticLanguageModel.cpp

@@ -38,7 +38,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "moses/InputFileStream.h"
#include "moses/StaticData.h"
#include "moses/ChartHypothesis.h"
#include "moses/Incremental/Manager.h"
#include "moses/Incremental.h"

#include <boost/shared_ptr.hpp>

@@ -119,7 +119,11 @@ while(<INI>) {
    print INI_OUT "2 $source_factor $t $w $new_name.bin$table_flag\n";
  }
  elsif ($binarizer && $phrase_table_impl == 0) {
    print INI_OUT "1 $source_factor $t $w $new_name$table_flag\n";
    if ($binarizer =~ /processPhraseTableMin/) {
      print INI_OUT "12 $source_factor $t $w $new_name$table_flag\n";
    } else {
      print INI_OUT "1 $source_factor $t $w $new_name$table_flag\n";
    }
  } else {
    $new_name .= ".gz" if $opt_gzip;
    print INI_OUT "$phrase_table_impl $source_factor $t $w $new_name$table_flag\n";
@@ -147,7 +151,7 @@ while(<INI>) {

    $file =~ s/^.*\/+([^\/]+)/$1/g;
    my $new_name = "$dir/$file";
    $new_name =~ s/\.gz//;
    $new_name =~ s/\.gz//;
    print INI_OUT "$factors $t $w $new_name\n";
    push @TABLE_NEW_NAME,$new_name;

@@ -275,11 +279,16 @@ for(my $i=0;$i<=$#TABLE;$i++) {
  # ... hierarchical translation model
  if ($opt_hierarchical) {
    my $cmd = "$binarizer $new_file $new_file.bin";
    print STDERR $cmd."\n";
    print STDERR `$cmd`;
  }
  # ... phrase translation model
  else {
  elsif ($binarizer =~ /processPhraseTableMin/) {
    #compact phrase table
    my $cmd = "LC_ALL=C sort -T $dir $new_file > $new_file.sorted; $binarizer -in $new_file.sorted -out $new_file -nscores $TABLE_WEIGHTS[$i]; rm $new_file.sorted";
    print STDERR $cmd."\n";
    print STDERR `$cmd`;
  } else {
    my $cmd = "cat $new_file | LC_ALL=C sort -T $dir | $binarizer -ttable 0 0 - -nscores $TABLE_WEIGHTS[$i] -out $new_file";
    print STDERR $cmd."\n";
    print STDERR `$cmd`;
@@ -289,8 +298,13 @@ for(my $i=0;$i<=$#TABLE;$i++) {
  else {
    my $lexbin = $binarizer;
    $lexbin =~ s/PhraseTable/LexicalTable/;
    $lexbin =~ s/^\s*(\S+)\s.+/$1/; # no options
    my $cmd = "$lexbin -in $new_file -out $new_file";
    my $cmd;
    if ($lexbin =~ /processLexicalTableMin/) {
      $cmd = "LC_ALL=C sort -T $dir $new_file > $new_file.sorted; $lexbin -in $new_file.sorted -out $new_file; rm $new_file.sorted";
    } else {
      $lexbin =~ s/^\s*(\S+)\s.+/$1/; # no options
      $cmd = "$lexbin -in $new_file -out $new_file";
    }
    print STDERR $cmd."\n";
    print STDERR `$cmd`;
  }

@@ -1,5 +1 @@
fakelib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;

import testing ;

unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ;
fakelib search : edge_generator.cc nbest.cc rule.cc vertex.cc vertex_generator.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;

search/applied.hh (new file, 86 lines)
@@ -0,0 +1,86 @@
#ifndef SEARCH_APPLIED__
#define SEARCH_APPLIED__

#include "search/edge.hh"
#include "search/header.hh"
#include "util/pool.hh"

#include <math.h>

namespace search {

// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted.
template <class Below> class GenericApplied : public Header {
 public:
  GenericApplied() {}

  GenericApplied(void *location, PartialEdge partial)
    : Header(location) {
    memcpy(Base(), partial.Base(), kHeaderSize);
    Below *child_out = Children();
    const PartialVertex *part = partial.NT();
    const PartialVertex *const part_end_loop = part + partial.GetArity();
    for (; part != part_end_loop; ++part, ++child_out)
      *child_out = Below(part->End());
  }

  GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) {
    SetScore(score);
    SetNote(note);
  }

  explicit GenericApplied(History from) : Header(from) {}

  // These are arrays of length GetArity().
  Below *Children() {
    return reinterpret_cast<Below*>(After());
  }
  const Below *Children() const {
    return reinterpret_cast<const Below*>(After());
  }

  static std::size_t Size(Arity arity) {
    return kHeaderSize + arity * sizeof(const Below);
  }
};

// Applied rule that references itself.
class Applied : public GenericApplied<Applied> {
 private:
  typedef GenericApplied<Applied> P;

 public:
  Applied() {}
  Applied(void *location, PartialEdge partial) : P(location, partial) {}
  Applied(History from) : P(from) {}
};

// How to build single-best hypotheses.
class SingleBest {
 public:
  typedef PartialEdge Combine;

  void Add(PartialEdge &existing, PartialEdge add) const {
    if (!existing.Valid() || existing.GetScore() < add.GetScore())
      existing = add;
  }

  NBestComplete Complete(PartialEdge partial) {
    if (!partial.Valid())
      return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY);
    void *place_final = pool_.Allocate(Applied::Size(partial.GetArity()));
    Applied(place_final, partial);
    return NBestComplete(
      place_final,
      partial.CompletedState(),
      partial.GetScore());
  }

 private:
  util::Pool pool_;
};

} // namespace search

#endif // SEARCH_APPLIED__
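Children() makes an Applied a compact hypothesis tree; a sketch of a depth-first walk in the style of AppendToPhrase from moses/Incremental.cpp in this commit (the counting function is illustrative only):

    #include <cstddef>
    #include "search/applied.hh"

    // Sketch: count the nodes of an applied-rule tree by recursing through
    // the Children() array, which has GetArity() entries.
    std::size_t CountNodes(const search::Applied &node) {
      std::size_t total = 1;
      for (std::size_t i = 0; i < node.GetArity(); ++i)
        total += CountNodes(node.Children()[i]);
      return total;
    }
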
@@ -1,23 +1,36 @@
#ifndef SEARCH_CONFIG__
#define SEARCH_CONFIG__

#include "search/weights.hh"
#include "util/string_piece.hh"
#include "search/types.hh"

namespace search {

struct NBestConfig {
  explicit NBestConfig(unsigned int in_size) {
    keep = in_size;
    size = in_size;
  }

  unsigned int keep, size;
};

class Config {
 public:
  Config(const Weights &weights, unsigned int pop_limit) :
    weights_(weights), pop_limit_(pop_limit) {}
  Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) :
    lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {}

  const Weights &GetWeights() const { return weights_; }
  Score LMWeight() const { return lm_weight_; }

  unsigned int PopLimit() const { return pop_limit_; }

  const NBestConfig &GetNBest() const { return nbest_; }

 private:
  Weights weights_;
  Score lm_weight_;

  unsigned int pop_limit_;

  NBestConfig nbest_;
};

} // namespace search
@@ -1,30 +1,16 @@
#ifndef SEARCH_CONTEXT__
#define SEARCH_CONTEXT__

#include "lm/model.hh"
#include "search/config.hh"
#include "search/final.hh"
#include "search/types.hh"
#include "search/vertex.hh"
#include "util/exception.hh"
#include "util/pool.hh"

#include <boost/pool/object_pool.hpp>
#include <boost/ptr_container/ptr_vector.hpp>

#include <vector>

namespace search {

class Weights;

class ContextBase {
 public:
  explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {}

  util::Pool &FinalPool() {
    return final_pool_;
  }
  explicit ContextBase(const Config &config) : config_(config) {}

  VertexNode *NewVertexNode() {
    VertexNode *ret = vertex_node_pool_.construct();
@@ -36,18 +22,16 @@ class ContextBase {
    vertex_node_pool_.destroy(node);
  }

  unsigned int PopLimit() const { return pop_limit_; }
  unsigned int PopLimit() const { return config_.PopLimit(); }

  const Weights &GetWeights() const { return weights_; }
  Score LMWeight() const { return config_.LMWeight(); }

  const Config &GetConfig() const { return config_; }

 private:
  util::Pool final_pool_;

  boost::object_pool<VertexNode> vertex_node_pool_;

  unsigned int pop_limit_;

  const Weights &weights_;
  Config config_;
};

template <class Model> class Context : public ContextBase {

@@ -1,6 +1,7 @@
#include "search/edge_generator.hh"

#include "lm/left.hh"
#include "lm/model.hh"
#include "lm/partial.hh"
#include "search/context.hh"
#include "search/vertex.hh"
@@ -38,7 +39,7 @@ template <class Model> void FastScore(const Context<Model> &context, Arity victi
      *cover = *(cover + 1);
    }
  }
  update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM());
  update.SetScore(update.GetScore() + adjustment * context.LMWeight());
}

} // namespace

@@ -2,7 +2,6 @@
#define SEARCH_EDGE_GENERATOR__

#include "search/edge.hh"
#include "search/note.hh"
#include "search/types.hh"

#include <queue>

search/final.hh (deleted file, 36 lines)
@@ -1,36 +0,0 @@
#ifndef SEARCH_FINAL__
#define SEARCH_FINAL__

#include "search/header.hh"
#include "util/pool.hh"

namespace search {

// A full hypothesis with pointers to children.
class Final : public Header {
 public:
  Final() {}

  Final(util::Pool &pool, Score score, Arity arity, Note note)
    : Header(pool.Allocate(Size(arity)), arity) {
    SetScore(score);
    SetNote(note);
  }

  // These are arrays of length GetArity().
  Final *Children() {
    return reinterpret_cast<Final*>(After());
  }
  const Final *Children() const {
    return reinterpret_cast<const Final*>(After());
  }

 private:
  static std::size_t Size(Arity arity) {
    return kHeaderSize + arity * sizeof(const Final);
  }
};

} // namespace search

#endif // SEARCH_FINAL__
@ -3,7 +3,6 @@

// Header consisting of Score, Arity, and Note

#include "search/note.hh"
#include "search/types.hh"

#include <stdint.h>
@ -24,6 +23,9 @@ class Header {
    bool operator<(const Header &other) const {
      return GetScore() < other.GetScore();
    }
    bool operator>(const Header &other) const {
      return GetScore() > other.GetScore();
    }

    Arity GetArity() const {
      return *reinterpret_cast<const Arity*>(base_ + sizeof(Score));
@ -36,9 +38,14 @@ class Header {
      *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to;
    }

    uint8_t *Base() { return base_; }
    const uint8_t *Base() const { return base_; }

  protected:
    Header() : base_(NULL) {}

    explicit Header(void *base) : base_(static_cast<uint8_t*>(base)) {}

    Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) {
      *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity;
    }
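The casts above imply a fixed layout for pooled hypothesis records; a sketch of that layout (padding ignored, offsets inferred from the accessors rather than stated anywhere in this diff):

    // base_ + 0                             -> Score  (float)
    // base_ + sizeof(Score)                 -> Arity  (uint32_t)
    // base_ + sizeof(Score) + sizeof(Arity) -> Note   (union holding a pointer)
    // base_ + kHeaderSize                   -> payload returned by After(),
    //                                          e.g. arity-many child records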
106 search/nbest.cc Normal file
@ -0,0 +1,106 @@
#include "search/nbest.hh"

#include "util/pool.hh"

#include <algorithm>
#include <functional>
#include <queue>

#include <assert.h>
#include <math.h>

namespace search {

NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) {
  assert(!partials.empty());
  std::vector<PartialEdge>::iterator end;
  if (partials.size() > keep) {
    end = partials.begin() + keep;
    std::nth_element(partials.begin(), end, partials.end(), std::greater<PartialEdge>());
  } else {
    end = partials.end();
  }
  for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) {
    queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i));
  }
}

Score NBestList::TopAfterConstructor() const {
  assert(revealed_.empty());
  return queue_.top().GetScore();
}

const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) {
  while (revealed_.size() < n && !queue_.empty()) {
    MoveTop(pool);
  }
  return revealed_;
}

Score NBestList::Visit(util::Pool &pool, std::size_t index) {
  if (index + 1 < revealed_.size())
    return revealed_[index + 1].GetScore() - revealed_[index].GetScore();
  if (queue_.empty())
    return -INFINITY;
  if (index + 1 == revealed_.size())
    return queue_.top().GetScore() - revealed_[index].GetScore();
  assert(index == revealed_.size());

  MoveTop(pool);

  if (queue_.empty()) return -INFINITY;
  return queue_.top().GetScore() - revealed_[index].GetScore();
}

Applied NBestList::Get(util::Pool &pool, std::size_t index) {
  assert(index <= revealed_.size());
  if (index == revealed_.size()) MoveTop(pool);
  return revealed_[index];
}

void NBestList::MoveTop(util::Pool &pool) {
  assert(!queue_.empty());
  QueueEntry entry(queue_.top());
  queue_.pop();
  RevealedRef *const children_begin = entry.Children();
  RevealedRef *const children_end = children_begin + entry.GetArity();
  Score basis = entry.GetScore();
  for (RevealedRef *child = children_begin; child != children_end; ++child) {
    Score change = child->in_->Visit(pool, child->index_);
    if (change != -INFINITY) {
      assert(change < 0.001);
      QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote());
      std::copy(children_begin, child, new_entry.Children());
      RevealedRef *update = new_entry.Children() + (child - children_begin);
      update->in_ = child->in_;
      update->index_ = child->index_ + 1;
      std::copy(child + 1, children_end, update + 1);
      queue_.push(new_entry);
    }
    // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010.
    if (child->index_) break;
  }

  // Convert QueueEntry to Applied.  This leaves some unused memory.
  void *overwrite = entry.Children();
  for (unsigned int i = 0; i < entry.GetArity(); ++i) {
    RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i));
    *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_);
  }
  revealed_.push_back(Applied(entry.Base()));
}

NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) {
  assert(!partials.empty());
  NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep);
  return NBestComplete(
      list,
      partials.front().CompletedState(), // All partials have the same state
      list->TopAfterConstructor());
}

const std::vector<Applied> &NBest::Extract(History history) {
  return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size);
}

} // namespace search
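The queue discipline in MoveTop is lazy k-best extraction: each popped entry spawns successors that advance exactly one child's rank, and the break on a non-zero index keeps the frontier duplicate-free. A minimal standalone sketch of that successor step, with a made-up Bump() standing in for the score deltas that Visit() returns in the real code:

    #include <cstddef>
    #include <queue>
    #include <vector>

    // Toy stand-in for per-child deltas; a real delta is the (non-positive)
    // score gap between ranks r and r+1 of a child's n-best list.
    float Bump(std::size_t child, std::size_t rank) {
      static const float kDelta[2][3] = {{-0.5f, -1.0f, -2.0f}, {-0.25f, -0.75f, -3.0f}};
      return kDelta[child % 2][rank % 3];
    }

    struct Derivation {
      float score;
      std::vector<std::size_t> ranks;  // rank chosen in each child's n-best list
      bool operator<(const Derivation &other) const { return score < other.score; }
    };

    void PushSuccessors(std::priority_queue<Derivation> &queue, const Derivation &top) {
      for (std::size_t child = 0; child < top.ranks.size(); ++child) {
        Derivation next(top);
        ++next.ranks[child];
        next.score += Bump(child, top.ranks[child]);
        queue.push(next);
        // Gesmundo & Henderson: stop after the first advanced child so each
        // combination of ranks is enqueued exactly once.
        if (top.ranks[child]) break;
      }
    }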
81 search/nbest.hh Normal file
@ -0,0 +1,81 @@
#ifndef SEARCH_NBEST__
#define SEARCH_NBEST__

#include "search/applied.hh"
#include "search/config.hh"
#include "search/edge.hh"

#include <boost/pool/object_pool.hpp>

#include <cstddef>
#include <queue>
#include <vector>

#include <assert.h>

namespace search {

class NBestList;

class NBestList {
  private:
    class RevealedRef {
      public:
        explicit RevealedRef(History history)
          : in_(static_cast<NBestList*>(history)), index_(0) {}

      private:
        friend class NBestList;

        NBestList *in_;
        std::size_t index_;
    };

    typedef GenericApplied<RevealedRef> QueueEntry;

  public:
    NBestList(std::vector<PartialEdge> &existing, util::Pool &entry_pool, std::size_t keep);

    Score TopAfterConstructor() const;

    const std::vector<Applied> &Extract(util::Pool &pool, std::size_t n);

  private:
    Score Visit(util::Pool &pool, std::size_t index);

    Applied Get(util::Pool &pool, std::size_t index);

    void MoveTop(util::Pool &pool);

    typedef std::vector<Applied> Revealed;
    Revealed revealed_;

    typedef std::priority_queue<QueueEntry> Queue;
    Queue queue_;
};

class NBest {
  public:
    typedef std::vector<PartialEdge> Combine;

    explicit NBest(const NBestConfig &config) : config_(config) {}

    void Add(std::vector<PartialEdge> &existing, PartialEdge addition) const {
      existing.push_back(addition);
    }

    NBestComplete Complete(std::vector<PartialEdge> &partials);

    const std::vector<Applied> &Extract(History root);

  private:
    const NBestConfig config_;

    boost::object_pool<NBestList> list_pool_;

    util::Pool entry_pool_;
};

} // namespace search

#endif // SEARCH_NBEST__
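NBest is one model of the Output concept that the generators later in this diff template on: Combine accumulates hypotheses per language-model state, Add folds one in, and Complete turns the accumulated set into an NBestComplete. A hedged sketch of what the single-best counterpart could look like (the name, the Valid() check, and the comparison are illustrative, not quoted from this commit):

    struct SingleBest {
      typedef PartialEdge Combine;  // keep only the top hypothesis per state

      void Add(Combine &existing, PartialEdge addition) const {
        // PartialEdge orders by score through Header's comparison operators.
        if (!existing.Valid() || existing < addition) existing = addition;
      }

      NBestComplete Complete(Combine &partial);  // wrap the winner as History
    };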
@ -1,12 +0,0 @@
#ifndef SEARCH_NOTE__
#define SEARCH_NOTE__

namespace search {

union Note {
  const void *vp;
};

} // namespace search

#endif // SEARCH_NOTE__
@ -1,7 +1,7 @@
#include "search/rule.hh"

#include "lm/model.hh"
#include "search/context.hh"
#include "search/final.hh"

#include <ostream>

@ -9,35 +9,35 @@

namespace search {

template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) {
  unsigned int oov_count = 0;
  float prob = 0.0;
  const Model &model = context.LanguageModel();
  const lm::WordIndex oov = model.GetVocabulary().NotFound();
  for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
    lm::ngram::RuleScore<Model> scorer(model, *(writing++));
    // TODO: optimize
    if (prepend_bos && (word == words.begin())) {
      scorer.BeginSentence();
    }
    for (; ; ++word) {
      if (word == words.end()) {
        prob += scorer.Finish();
        return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM();
      }
      if (*word == kNonTerminal) break;
      if (*word == oov) ++oov_count;
      scorer.Terminal(*word);
    }
    prob += scorer.Finish();
  }
}
template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing) {
  ScoreRuleRet ret;
  ret.prob = 0.0;
  ret.oov = 0;
  const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence();
  lm::ngram::RuleScore<Model> scorer(model, *(writing++));
  std::vector<lm::WordIndex>::const_iterator word = words.begin();
  if (word != words.end() && *word == bos) {
    scorer.BeginSentence();
    ++word;
  }
  for (; word != words.end(); ++word) {
    if (*word == kNonTerminal) {
      ret.prob += scorer.Finish();
      scorer.Reset(*(writing++));
    } else {
      if (*word == oov) ++ret.oov;
      scorer.Terminal(*word);
    }
  }
  ret.prob += scorer.Finish();
  return ret;
}

template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);

} // namespace search
@ -9,11 +9,16 @@

namespace search {

template <class Model> class Context;

const lm::WordIndex kNonTerminal = lm::kMaxWordIndex;

template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out);
struct ScoreRuleRet {
  Score prob;
  unsigned int oov;
};

// Pass <s> and </s> normally.
// Indicate non-terminals with kNonTerminal.
template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *state_out);

} // namespace search
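The new signature drops prepend_bos in favor of an explicit <s> token in the word sequence, and writes one ChartState per contiguous terminal span. A minimal call sketch under that reading (the rule, the word "said", and the surrounding setup are illustrative):

    // Score the rule "<s> [X] said": one leading terminal span, one gap, one
    // trailing terminal span, so ScoreRule fills two ChartStates.
    std::vector<lm::WordIndex> words;
    words.push_back(model.GetVocabulary().BeginSentence());
    words.push_back(kNonTerminal);  // the [X] gap
    words.push_back(model.GetVocabulary().Index("said"));

    lm::ngram::ChartState states[2];
    ScoreRuleRet got = ScoreRule(model, words, states);
    // got.prob is the summed LM log probability; got.oov counts unknown words.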
@ -3,12 +3,29 @@

#include <stdint.h>

namespace lm { namespace ngram { class ChartState; } }

namespace search {

typedef float Score;

typedef uint32_t Arity;

union Note {
  const void *vp;
};

typedef void *History;

struct NBestComplete {
  NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score)
    : history(in_history), state(&in_state), score(in_score) {}

  History history;
  const lm::ngram::ChartState *state;
  Score score;
};

} // namespace search

#endif // SEARCH_TYPES__
@ -19,21 +19,34 @@ struct GreaterByBound : public std::binary_function<const VertexNode *, const Ve

} // namespace

void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) {
void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) {
  if (Complete()) {
    assert(end_.Valid());
    assert(end_);
    assert(extend_.empty());
    bound_ = end_.GetScore();
    return;
  }
  if (extend_.size() == 1 && parent_ptr) {
    *parent_ptr = extend_[0];
    extend_[0]->SortAndSet(context, parent_ptr);
  if (extend_.size() == 1) {
    parent_ptr = extend_[0];
    extend_[0]->RecursiveSortAndSet(context, parent_ptr);
    context.DeleteVertexNode(this);
    return;
  }
  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
    (*i)->SortAndSet(context, &*i);
    (*i)->RecursiveSortAndSet(context, *i);
  }
  std::sort(extend_.begin(), extend_.end(), GreaterByBound());
  bound_ = extend_.front()->Bound();
}

void VertexNode::SortAndSet(ContextBase &context) {
  // This is the root.  The root might be empty.
  if (extend_.empty()) {
    bound_ = -INFINITY;
    return;
  }
  // The root cannot be replaced.  There's always one transition.
  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
    (*i)->RecursiveSortAndSet(context, *i);
  }
  std::sort(extend_.begin(), extend_.end(), GreaterByBound());
  bound_ = extend_.front()->Bound();
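The split separates the root (which is never replaced) from interior nodes, and RecursiveSortAndSet compresses single-child chains by splicing the lone child into the parent's slot. A toy sketch of that compression on a generic tree (names and the bare delete are illustrative; the diff recycles nodes through DeleteVertexNode instead):

    #include <vector>

    struct Node {
      std::vector<Node*> kids;
    };

    // slot is the parent's pointer to this node; a chain a->b->c with no
    // branching collapses so the parent points straight at c.
    void Collapse(Node *&slot) {
      Node *n = slot;
      while (n->kids.size() == 1) {
        Node *only = n->kids[0];
        delete n;  // the spliced-out interior node is no longer reachable
        n = only;
      }
      slot = n;
      for (std::vector<Node*>::iterator i = n->kids.begin(); i != n->kids.end(); ++i)
        Collapse(*i);
    }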
@ -2,7 +2,6 @@
#define SEARCH_VERTEX__

#include "lm/left.hh"
#include "search/final.hh"
#include "search/types.hh"

#include <boost/unordered_set.hpp>
@ -10,6 +9,7 @@
#include <queue>
#include <vector>

#include <math.h>
#include <stdint.h>

namespace search {
@ -18,7 +18,7 @@ class ContextBase;

class VertexNode {
  public:
    VertexNode() {}
    VertexNode() : end_() {}

    void InitRoot() {
      extend_.clear();
@ -26,7 +26,7 @@ class VertexNode {
      state_.left.length = 0;
      state_.right.length = 0;
      right_full_ = false;
      end_ = Final();
      end_ = History();
    }

    lm::ngram::ChartState &MutableState() { return state_; }
@ -36,20 +36,21 @@ class VertexNode {
      extend_.push_back(next);
    }

    void SetEnd(Final end) {
      assert(!end_.Valid());
    void SetEnd(History end, Score score) {
      assert(!end_);
      end_ = end;
      bound_ = score;
    }

    void SortAndSet(ContextBase &context, VertexNode **parent_pointer);
    void SortAndSet(ContextBase &context);

    // Should only happen to a root node when the entire vertex is empty.
    bool Empty() const {
      return !end_.Valid() && extend_.empty();
      return !end_ && extend_.empty();
    }

    bool Complete() const {
      return end_.Valid();
      return end_;
    }

    const lm::ngram::ChartState &State() const { return state_; }
@ -64,7 +65,7 @@ class VertexNode {
    }

    // Will be invalid unless this is a leaf.
    const Final End() const { return end_; }
    const History End() const { return end_; }

    const VertexNode &operator[](size_t index) const {
      return *extend_[index];
@ -75,13 +76,15 @@ class VertexNode {
    }

  private:
    void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent);

    std::vector<VertexNode*> extend_;

    lm::ngram::ChartState state_;
    bool right_full_;

    Score bound_;
    Final end_;
    History end_;
};

class PartialVertex {
@ -97,7 +100,7 @@ class PartialVertex {
    const lm::ngram::ChartState &State() const { return back_->State(); }
    bool RightFull() const { return back_->RightFull(); }

    Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); }
    Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }

    unsigned char Length() const { return back_->Length(); }

@ -121,7 +124,7 @@ class PartialVertex {
      return ret;
    }

    const Final End() const {
    const History End() const {
      return back_->End();
    }

@ -130,16 +133,18 @@ class PartialVertex {
    unsigned int index_;
};

template <class Output> class VertexGenerator;

class Vertex {
  public:
    Vertex() {}

    PartialVertex RootPartial() const { return PartialVertex(root_); }

    const Final BestChild() const {
    const History BestChild() const {
      PartialVertex top(RootPartial());
      if (top.Empty()) {
        return Final();
        return History();
      } else {
        PartialVertex continuation;
        while (!top.Complete()) {
@ -150,8 +155,8 @@ class Vertex {
    }

  private:
    friend class VertexGenerator;

    template <class Output> friend class VertexGenerator;
    template <class Output> friend class RootVertexGenerator;
    VertexNode root_;
};
@ -11,23 +11,11 @@

namespace search {

VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) {
  gen.root_.InitRoot();
}

#if BOOST_VERSION > 104200
namespace {

const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);

// Parallel structure to VertexNode.
struct Trie {
  Trie() : under(NULL) {}

  VertexNode *under;
  boost::unordered_map<uint64_t, Trie> extend;
};

Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
  Trie &next = node.extend[added];
  if (!next.under) {
@ -43,19 +31,10 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n
  return next;
}

void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) {
  Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote());
  Final *child_out = final.Children();
  const PartialVertex *part = partial.NT();
  const PartialVertex *const part_end_loop = part + partial.GetArity();
  for (; part != part_end_loop; ++part, ++child_out)
    *child_out = part->End();
} // namespace

  starter.under->SetEnd(final);
}

void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
  const lm::ngram::ChartState &state = partial.CompletedState();
void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) {
  const lm::ngram::ChartState &state = *end.state;

  unsigned char left = 0, right = 0;
  Trie *node = &root;
@ -81,30 +60,9 @@ void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
  }

  node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
  CompleteTransition(context, *node, partial);
}

} // namespace

#else // BOOST_VERSION

struct Trie {
  VertexNode *under;
};

void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
  UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search.");
  node->under->SetEnd(end.history, end.score);
}

#endif // BOOST_VERSION

void VertexGenerator::FinishedSearch() {
  Trie root;
  root.under = &gen_.root_;
  for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) {
    AddHypothesis(context_, root, i->second);
  }
  root.under->SortAndSet(context_, NULL);
}

} // namespace search
@ -2,9 +2,11 @@
#define SEARCH_VERTEX_GENERATOR__

#include "search/edge.hh"
#include "search/types.hh"
#include "search/vertex.hh"

#include <boost/unordered_map.hpp>
#include <boost/version.hpp>

namespace lm {
namespace ngram {
@ -15,21 +17,44 @@ class ChartState;
namespace search {

class ContextBase;
class Final;

class VertexGenerator {
#if BOOST_VERSION > 104200
// Parallel structure to VertexNode.
struct Trie {
  Trie() : under(NULL) {}

  VertexNode *under;
  boost::unordered_map<uint64_t, Trie> extend;
};

void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end);

#endif // BOOST_VERSION

// Output makes the single-best or n-best list.
template <class Output> class VertexGenerator {
  public:
    VertexGenerator(ContextBase &context, Vertex &gen);

    void NewHypothesis(PartialEdge partial) {
      const lm::ngram::ChartState &state = partial.CompletedState();
      std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial)));
      if (!ret.second && ret.first->second < partial) {
        ret.first->second = partial;
      }
    VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {
      gen.root_.InitRoot();
    }

    void FinishedSearch();
    void NewHypothesis(PartialEdge partial) {
      nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);
    }

    void FinishedSearch() {
#if BOOST_VERSION > 104200
      Trie root;
      root.under = &gen_.root_;
      for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) {
        AddHypothesis(context_, root, nbest_.Complete(i->second));
      }
      existing_.clear();
      root.under->SortAndSet(context_);
#else
      UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search.");
#endif
    }

    const Vertex &Generating() const { return gen_; }

@ -38,8 +63,35 @@ class VertexGenerator {

    Vertex &gen_;

    typedef boost::unordered_map<uint64_t, PartialEdge> Existing;
    typedef boost::unordered_map<uint64_t, typename Output::Combine> Existing;
    Existing existing_;

    Output &nbest_;
};

// Special case for root vertex: everything should come together into the root
// node.  In theory, this should happen naturally due to state collapsing with
// <s> and </s>.  If that's the case, VertexGenerator is fine, though it will
// make one connection.
template <class Output> class RootVertexGenerator {
  public:
    RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {}

    void NewHypothesis(PartialEdge partial) {
      out_.Add(combine_, partial);
    }

    void FinishedSearch() {
      gen_.root_.InitRoot();
      NBestComplete completed(out_.Complete(combine_));
      gen_.root_.SetEnd(completed.history, completed.score);
    }

  private:
    Vertex &gen_;

    typename Output::Combine combine_;
    Output &out_;
};

} // namespace search
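Putting the pieces together: the generator merges hypotheses by completed LM state, defers scoring structure to the Output policy, and FinishedSearch builds and sorts the vertex trie. A hedged wiring sketch (the driver that produces PartialEdges is elided; the function name is illustrative):

    #include "search/nbest.hh"
    #include "search/vertex.hh"
    #include "search/vertex_generator.hh"

    #include <vector>

    namespace search {

    // Feed finished hypotheses for one chart cell into a vertex, with NBest
    // as the Output policy; afterwards Vertex::BestChild() yields a History
    // that NBest::Extract turns into up to config.size Applied derivations.
    void BuildVertex(ContextBase &context, NBest &nbest, Vertex &vertex,
                     std::vector<PartialEdge> &hypotheses) {
      VertexGenerator<NBest> gen(context, vertex, nbest);
      for (std::vector<PartialEdge>::iterator i = hypotheses.begin();
           i != hypotheses.end(); ++i) {
        gen.NewHypothesis(*i);
      }
      gen.FinishedSearch();  // merge by state, build the trie, sort by bound
    }

    } // namespace search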
@ -1,71 +0,0 @@
#include "search/weights.hh"
#include "util/tokenize_piece.hh"

#include <cstdlib>

namespace search {

namespace {
struct Insert {
  void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const {
    std::string copy(name.data(), name.size());
    map[copy] = score;
  }
};

struct DotProduct {
  search::Score total;
  DotProduct() : total(0.0) {}

  void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) {
    boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name));
    if (i != map.end())
      total += score * i->second;
  }
};

template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) {
  for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) {
    util::TokenIter<util::SingleCharacter> equals(*spaces, '=');
    UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces);
    StringPiece name(*equals);
    UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces);
    char *end;
    // Assumes proper termination.
    double value = std::strtod(equals->data(), &end);
    UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals);
    UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces);
    op(map, name, value);
  }
}

} // namespace

Weights::Weights(StringPiece text) {
  Insert op;
  Parse<Map, Insert>(text, map_, op);
  lm_ = Steal("LanguageModel");
  oov_ = Steal("OOV");
  word_penalty_ = Steal("WordPenalty");
}

Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {}

search::Score Weights::DotNoLM(StringPiece text) const {
  DotProduct dot;
  Parse<const Map, DotProduct>(text, map_, dot);
  return dot.total;
}

float Weights::Steal(const std::string &str) {
  Map::iterator i(map_.find(str));
  if (i == map_.end()) {
    return 0.0;
  } else {
    float ret = i->second;
    map_.erase(i);
    return ret;
  }
}

} // namespace search
@ -1,52 +0,0 @@
// For now, the individual features are not kept.
#ifndef SEARCH_WEIGHTS__
#define SEARCH_WEIGHTS__

#include "search/types.hh"
#include "util/exception.hh"
#include "util/string_piece.hh"

#include <boost/unordered_map.hpp>

#include <string>

namespace search {

class WeightParseException : public util::Exception {
  public:
    WeightParseException() {}
    ~WeightParseException() throw() {}
};

class Weights {
  public:
    // Parses weights, sets lm_weight_, removes it from map_.
    explicit Weights(StringPiece text);

    // Just the three scores we care about adding.
    Weights(Score lm, Score oov, Score word_penalty);

    Score DotNoLM(StringPiece text) const;

    Score LM() const { return lm_; }

    Score OOV() const { return oov_; }

    Score WordPenalty() const { return word_penalty_; }

    // Mostly for testing.
    const boost::unordered_map<std::string, Score> &GetMap() const { return map_; }

  private:
    float Steal(const std::string &str);

    typedef boost::unordered_map<std::string, Score> Map;

    Map map_;

    Score lm_, oov_, word_penalty_;
};

} // namespace search

#endif // SEARCH_WEIGHTS__
@ -1,38 +0,0 @@
#include "search/weights.hh"

#define BOOST_TEST_MODULE WeightTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>

namespace search {
namespace {

#define CHECK_WEIGHT(value, string) \
  i = parsed.find(string); \
  BOOST_REQUIRE(i != parsed.end()); \
  BOOST_CHECK_CLOSE((value), i->second, 0.001);

BOOST_AUTO_TEST_CASE(parse) {
  // These are not real feature weights.
  Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
  const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap();
  boost::unordered_map<std::string, search::Score>::const_iterator i;
  CHECK_WEIGHT(0.0, "rarity");
  CHECK_WEIGHT(0.0, "phrase-SGT");
  CHECK_WEIGHT(9.45117, "phrase-TGS");
  CHECK_WEIGHT(2.33833, "lexical-SGT");
  BOOST_CHECK(parsed.end() == parsed.find("lm"));
  BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001);
  CHECK_WEIGHT(-28.3317, "lexical-TGS");
  CHECK_WEIGHT(5.0, "glue?");
}

BOOST_AUTO_TEST_CASE(dot) {
  Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
  BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001);
  BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001);
  BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001);
}

} // namespace
} // namespace search