move all feature functions out of StaticData

This commit is contained in:
Hieu Hoang 2013-12-07 00:21:06 +00:00
parent a5d0296699
commit ba209202ef
13 changed files with 52 additions and 49 deletions

View File

@ -48,7 +48,7 @@ void ChartParserUnknown::Process(const Word &sourceWord, const WordsRange &range
{
// unknown word, add as trans opt
const StaticData &staticData = StaticData::Instance();
const UnknownWordPenaltyProducer *unknownWordPenaltyProducer = staticData.GetUnknownWordPenaltyProducer();
const UnknownWordPenaltyProducer &unknownWordPenaltyProducer = UnknownWordPenaltyProducer::Instance();
size_t isDigit = 0;
if (staticData.GetDropUnknown()) {
@ -93,7 +93,7 @@ void ChartParserUnknown::Process(const Word &sourceWord, const WordsRange &range
// scores
float unknownScore = FloorScore(TransformScore(prob));
targetPhrase->GetScoreBreakdown().Assign(unknownWordPenaltyProducer, unknownScore);
targetPhrase->GetScoreBreakdown().Assign(&unknownWordPenaltyProducer, unknownScore);
targetPhrase->Evaluate(*unksrc);
targetPhrase->SetTargetLHS(targetLHS);
@ -121,7 +121,7 @@ void ChartParserUnknown::Process(const Word &sourceWord, const WordsRange &range
targetLHS->CreateFromString(Output, staticData.GetOutputFactorOrder(), targetLHSStr, true);
UTIL_THROW_IF2(targetLHS->GetFactor(0) == NULL, "Null factor for target LHS");
targetPhrase->GetScoreBreakdown().Assign(unknownWordPenaltyProducer, unknownScore);
targetPhrase->GetScoreBreakdown().Assign(&unknownWordPenaltyProducer, unknownScore);
targetPhrase->Evaluate(*unksrc);
targetPhrase->SetTargetLHS(targetLHS);

View File

@ -69,8 +69,7 @@ ConfusionNet::ConfusionNet()
if (staticData.IsChart()) {
m_defaultLabelSet.insert(StaticData::Instance().GetInputDefaultNonTerminal());
}
UTIL_THROW_IF2(StaticData::Instance().GetInputFeature() == NULL,
"Input feature must be specified");
UTIL_THROW_IF2(&InputFeature::Instance() == NULL, "Input feature must be specified");
}
ConfusionNet::~ConfusionNet()
{
@ -131,9 +130,9 @@ bool ConfusionNet::ReadFormat0(std::istream& in,
Clear();
const StaticData &staticData = StaticData::Instance();
const InputFeature *inputFeature = staticData.GetInputFeature();
size_t numInputScores = inputFeature->GetNumInputScores();
size_t numRealWordCount = inputFeature->GetNumRealWordsInInput();
const InputFeature &inputFeature = InputFeature::Instance();
size_t numInputScores = inputFeature.GetNumInputScores();
size_t numRealWordCount = inputFeature.GetNumRealWordsInInput();
size_t totalCount = numInputScores + numRealWordCount;
bool addRealWordCount = (numRealWordCount > 0);

View File

@ -10,10 +10,15 @@ using namespace std;
namespace Moses
{
InputFeature *InputFeature::s_instance = NULL;
InputFeature::InputFeature(const std::string &line)
:StatelessFeatureFunction(line)
{
ReadParameters();
UTIL_THROW_IF2(s_instance, "Can only have 1 input feature");
s_instance = this;
}
void InputFeature::Load()

View File

@ -10,11 +10,20 @@ namespace Moses
class InputFeature : public StatelessFeatureFunction
{
protected:
static InputFeature *s_instance;
size_t m_numInputScores;
size_t m_numRealWordCount;
bool m_legacy;
public:
static const InputFeature& Instance() {
return *s_instance;
}
static InputFeature& InstanceNonConst() {
return *s_instance;
}
InputFeature(const std::string &line);
void Load();

View File

@ -1,16 +1,22 @@
#include "UnknownWordPenaltyProducer.h"
#include <vector>
#include <string>
#include "UnknownWordPenaltyProducer.h"
#include "util/exception.hh"
using namespace std;
namespace Moses
{
UnknownWordPenaltyProducer *UnknownWordPenaltyProducer::s_instance = NULL;
UnknownWordPenaltyProducer::UnknownWordPenaltyProducer(const std::string &line)
: StatelessFeatureFunction(1, line)
{
m_tuneable = false;
ReadParameters();
UTIL_THROW_IF2(s_instance, "Can only have 1 unknown word penalty feature");
s_instance = this;
}
std::vector<float> UnknownWordPenaltyProducer::DefaultWeights() const

View File

@ -13,7 +13,17 @@ class WordsRange;
/** unknown word penalty */
class UnknownWordPenaltyProducer : public StatelessFeatureFunction
{
protected:
static UnknownWordPenaltyProducer *s_instance;
public:
static const UnknownWordPenaltyProducer& Instance() {
return *s_instance;
}
static UnknownWordPenaltyProducer& InstanceNonConst() {
return *s_instance;
}
UnknownWordPenaltyProducer(const std::string &line);
bool IsUseable(const FactorMask &mask) const {

View File

@ -45,7 +45,7 @@ protected:
distinctE(0) {
m_numInputScores = 0;
const StaticData &staticData = StaticData::Instance();
m_inputFeature = staticData.GetInputFeature();
m_inputFeature = &InputFeature::Instance();
if (m_inputFeature) {
const PhraseDictionary *firstPt = PhraseDictionary::GetColl()[0];

View File

@ -57,8 +57,6 @@ StaticData StaticData::s_instance;
StaticData::StaticData()
:m_sourceStartPosMattersForRecombination(false)
,m_inputType(SentenceInput)
,m_unknownWordPenaltyProducer(NULL)
,m_inputFeature(NULL)
,m_detailedTranslationReportingFilePath()
,m_detailedTreeFragmentsTranslationReportingFilePath()
,m_onlyDistinctNBest(false)
@ -876,7 +874,7 @@ float StaticData::GetWeightWordPenalty() const
float StaticData::GetWeightUnknownWordPenalty() const
{
return GetWeight(m_unknownWordPenaltyProducer);
return GetWeight(&UnknownWordPenaltyProducer::Instance());
}
void StaticData::InitializeForInput(const InputType& source) const
@ -908,16 +906,6 @@ void StaticData::LoadFeatureFunctions()
if (PhraseDictionary *ffCast = dynamic_cast<PhraseDictionary*>(ff)) {
doLoad = false;
} else if (const GenerationDictionary *ffCast
= dynamic_cast<const GenerationDictionary*>(ff)) {
// do nothing
} else if (UnknownWordPenaltyProducer *ffCast
= dynamic_cast<UnknownWordPenaltyProducer*>(ff)) {
UTIL_THROW_IF2(m_unknownWordPenaltyProducer, "Only 1 unknown word penalty allowed"); // max 1 feature;
m_unknownWordPenaltyProducer = ffCast;
} else if (const InputFeature *ffCast = dynamic_cast<const InputFeature*>(ff)) {
UTIL_THROW_IF2(m_inputFeature, "Only 1 input feature allowed"); // max 1 input feature;
m_inputFeature = ffCast;
}
if (doLoad) {

View File

@ -49,8 +49,6 @@ namespace Moses
class InputType;
class DecodeGraph;
class DecodeStep;
class UnknownWordPenaltyProducer;
class InputFeature;
typedef std::pair<std::string, float> UnknownLHSEntry;
typedef std::vector<UnknownLHSEntry> UnknownLHSList;
@ -113,8 +111,6 @@ protected:
InputTypeEnum m_inputType;
mutable size_t m_verboseLevel;
UnknownWordPenaltyProducer *m_unknownWordPenaltyProducer;
const InputFeature *m_inputFeature;
bool m_reportSegmentation;
bool m_reportSegmentationEnriched;
@ -418,14 +414,6 @@ public:
return m_searchAlgorithm == ChartDecoding || m_searchAlgorithm == ChartIncremental;
}
const UnknownWordPenaltyProducer *GetUnknownWordPenaltyProducer() const {
return m_unknownWordPenaltyProducer;
}
const InputFeature *GetInputFeature() const {
return m_inputFeature;
}
const ScoreComponentCollection& GetAllWeights() const {
return m_allWeights;
}

View File

@ -208,7 +208,7 @@ void TranslationOptionCollection::ProcessOneUnknownWord(const InputPath &inputPa
const ScorePair *inputScores)
{
const StaticData &staticData = StaticData::Instance();
const UnknownWordPenaltyProducer *unknownWordPenaltyProducer = staticData.GetUnknownWordPenaltyProducer();
const UnknownWordPenaltyProducer &unknownWordPenaltyProducer = UnknownWordPenaltyProducer::Instance();
float unknownScore = FloorScore(TransformScore(0));
const Word &sourceWord = inputPath.GetPhrase().GetWord(0);
@ -259,7 +259,7 @@ void TranslationOptionCollection::ProcessOneUnknownWord(const InputPath &inputPa
}
targetPhrase.GetScoreBreakdown().Assign(unknownWordPenaltyProducer, unknownScore);
targetPhrase.GetScoreBreakdown().Assign(&unknownWordPenaltyProducer, unknownScore);
// source phrase
const Phrase &sourcePhrase = inputPath.GetPhrase();
@ -523,14 +523,14 @@ void TranslationOptionCollection::SetInputScore(const InputPath &inputPath, Part
return;
}
const InputFeature *inputFeature = StaticData::Instance().GetInputFeature();
const InputFeature &inputFeature = InputFeature::Instance();
const std::vector<TranslationOption*> &transOpts = oldPtoc.GetList();
for (size_t i = 0; i < transOpts.size(); ++i) {
TranslationOption &transOpt = *transOpts[i];
ScoreComponentCollection &scores = transOpt.GetScoreBreakdown();
scores.PlusEquals(inputFeature, *inputScore);
scores.PlusEquals(&inputFeature, *inputScore);
}
}

View File

@ -22,9 +22,8 @@ TranslationOptionCollectionConfusionNet::TranslationOptionCollectionConfusionNet
, size_t maxNoTransOptPerCoverage, float translationOptionThreshold)
: TranslationOptionCollection(input, maxNoTransOptPerCoverage, translationOptionThreshold)
{
const InputFeature *inputFeature = StaticData::Instance().GetInputFeature();
UTIL_THROW_IF2(inputFeature == NULL,
"Input feature must be specified");
const InputFeature &inputFeature = InputFeature::Instance();
UTIL_THROW_IF2(&inputFeature == NULL, "Input feature must be specified");
size_t inputSize = input.GetSize();
m_inputPathMatrix.resize(inputSize);

View File

@ -26,9 +26,8 @@ TranslationOptionCollectionLattice::TranslationOptionCollectionLattice(
UTIL_THROW_IF2(StaticData::Instance().GetUseLegacyPT(),
"Not for models using the legqacy binary phrase table");
const InputFeature *inputFeature = StaticData::Instance().GetInputFeature();
UTIL_THROW_IF2(inputFeature == NULL,
"Input feature must be specified");
const InputFeature &inputFeature = InputFeature::Instance();
UTIL_THROW_IF2(&inputFeature == NULL, "Input feature must be specified");
size_t maxPhraseLength = StaticData::Instance().GetMaxPhraseLength();
size_t size = input.GetSize();

View File

@ -12,7 +12,7 @@ namespace Moses
{
WordLattice::WordLattice()
{
UTIL_THROW_IF2(StaticData::Instance().GetInputFeature() == NULL,
UTIL_THROW_IF2(&InputFeature::Instance() == NULL,
"Input feature must be specified");
}
@ -52,9 +52,9 @@ void WordLattice::Print(std::ostream& out) const
int WordLattice::InitializeFromPCNDataType(const PCN::CN& cn, const std::vector<FactorType>& factorOrder, const std::string& debug_line)
{
const StaticData &staticData = StaticData::Instance();
const InputFeature *inputFeature = staticData.GetInputFeature();
size_t numInputScores = inputFeature->GetNumInputScores();
size_t numRealWordCount = inputFeature->GetNumRealWordsInInput();
const InputFeature &inputFeature = InputFeature::Instance();
size_t numInputScores = inputFeature.GetNumInputScores();
size_t numRealWordCount = inputFeature.GetNumRealWordsInInput();
size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();