diff --git a/moses/BleuScoreFeature.cpp b/moses/BleuScoreFeature.cpp index 4f8fe9c9c..e8e1d9b4d 100644 --- a/moses/BleuScoreFeature.cpp +++ b/moses/BleuScoreFeature.cpp @@ -70,6 +70,75 @@ void BleuScoreState::AddNgramCountAndMatches(std::vector< size_t >& counts, } } + +BleuScoreFeature::BleuScoreFeature(const std::string &line) +:StatefulFeatureFunction("BleuScoreFeature",1), +m_enabled(true), +m_sentence_bleu(true), +m_simple_history_bleu(false), +m_count_history(BleuScoreState::bleu_order), +m_match_history(BleuScoreState::bleu_order), +m_source_length_history(0), +m_target_length_history(0), +m_ref_length_history(0), +m_scale_by_input_length(true), +m_scale_by_avg_input_length(false), +m_scale_by_inverse_length(false), +m_scale_by_avg_inverse_length(false), +m_scale_by_x(1), +m_historySmoothing(0.9), +m_smoothing_scheme(PLUS_POINT_ONE) {} +{ + vector referenceFiles = m_parameter->GetParam("references"); + if ((!referenceFiles.size() && bleuWeightStr.size()) || (referenceFiles.size() && !bleuWeightStr.size())) { + UserMessage::Add("You cannot use the bleu feature without references, and vice-versa"); + return false; + } + if (!referenceFiles.size()) { + return true; + } + if (bleuWeightStr.size() > 1) { + UserMessage::Add("Can only specify one weight for the bleu feature"); + return false; + } + + float bleuWeight = Scan(bleuWeightStr[0]); + BleuScoreFeature *bleuScoreFeature = new BleuScoreFeature(); + SetWeight(bleuScoreFeature, bleuWeight); + + cerr << "Loading reference file " << referenceFiles[0] << endl; + vector > references(referenceFiles.size()); + for (size_t i =0; i < referenceFiles.size(); ++i) { + ifstream in(referenceFiles[i].c_str()); + if (!in) { + stringstream strme; + strme << "Unable to load references from " << referenceFiles[i]; + UserMessage::Add(strme.str()); + return false; + } + string line; + while (getline(in,line)) { +/* if (GetSearchAlgorithm() == ChartDecoding) { + stringstream tmp; + tmp << " " << line << " "; + line = tmp.str(); + }*/ + references[i].push_back(line); + } + if (i > 0) { + if (references[i].size() != references[i-1].size()) { + UserMessage::Add("Reference files are of different lengths"); + return false; + } + } + in.close(); + } + //Set the references in the bleu feature + bleuScoreFeature->LoadReferences(references); + + +} + void BleuScoreFeature::PrintHistory(std::ostream& out) const { out << "source length history=" << m_source_length_history << endl; out << "target length history=" << m_target_length_history << endl; diff --git a/moses/BleuScoreFeature.h b/moses/BleuScoreFeature.h index 697759e4b..fb6fd944d 100644 --- a/moses/BleuScoreFeature.h +++ b/moses/BleuScoreFeature.h @@ -62,23 +62,7 @@ public: typedef boost::unordered_map RefCounts; typedef boost::unordered_map Matches; - BleuScoreFeature(): - StatefulFeatureFunction("BleuScoreFeature",1), - m_enabled(true), - m_sentence_bleu(true), - m_simple_history_bleu(false), - m_count_history(BleuScoreState::bleu_order), - m_match_history(BleuScoreState::bleu_order), - m_source_length_history(0), - m_target_length_history(0), - m_ref_length_history(0), - m_scale_by_input_length(true), - m_scale_by_avg_input_length(false), - m_scale_by_inverse_length(false), - m_scale_by_avg_inverse_length(false), - m_scale_by_x(1), - m_historySmoothing(0.9), - m_smoothing_scheme(PLUS_POINT_ONE) {} + BleuScoreFeature(const std::string &line); void PrintHistory(std::ostream& out) const; void LoadReferences(const std::vector< std::vector< std::string > > &); diff --git a/moses/StaticData.cpp b/moses/StaticData.cpp index 24231b569..54585ef8c 100644 --- a/moses/StaticData.cpp +++ b/moses/StaticData.cpp @@ -628,6 +628,11 @@ SetWeight(m_unknownWordPenaltyProducer, weightUnknownWord); const vector &weights = m_parameter->GetWeights(feature, featureIndex); SetWeights(model, weights); } + else if (feature == "BleuScoreFeature") { + BleuScoreFeature *model = new BleuScoreFeature(line); + const vector &weights = m_parameter->GetWeights(feature, featureIndex); + SetWeights(model, weights); + } #ifdef HAVE_SYNLM else if (feature == "SyntacticLanguageModel") { SyntacticLanguageModel *model = new SyntacticLanguageModel(line); @@ -647,7 +652,6 @@ SetWeight(m_unknownWordPenaltyProducer, weightUnknownWord); if (!LoadPhraseTables()) return false; if (!LoadDecodeGraphs()) return false; - if (!LoadReferences()) return false; // report individual sparse features in n-best list if (m_parameter->GetParam("report-sparse-features").size() > 0) { @@ -1035,61 +1039,6 @@ bool StaticData::LoadDecodeGraphs() return true; } -bool StaticData::LoadReferences() -{ - vector bleuWeightStr = m_parameter->GetParam("weight-bl"); - vector referenceFiles = m_parameter->GetParam("references"); - if ((!referenceFiles.size() && bleuWeightStr.size()) || (referenceFiles.size() && !bleuWeightStr.size())) { - UserMessage::Add("You cannot use the bleu feature without references, and vice-versa"); - return false; - } - if (!referenceFiles.size()) { - return true; - } - if (bleuWeightStr.size() > 1) { - UserMessage::Add("Can only specify one weight for the bleu feature"); - return false; - } - - float bleuWeight = Scan(bleuWeightStr[0]); - BleuScoreFeature *bleuScoreFeature = new BleuScoreFeature(); - SetWeight(bleuScoreFeature, bleuWeight); - - cerr << "Loading reference file " << referenceFiles[0] << endl; - vector > references(referenceFiles.size()); - for (size_t i =0; i < referenceFiles.size(); ++i) { - ifstream in(referenceFiles[i].c_str()); - if (!in) { - stringstream strme; - strme << "Unable to load references from " << referenceFiles[i]; - UserMessage::Add(strme.str()); - return false; - } - string line; - while (getline(in,line)) { -/* if (GetSearchAlgorithm() == ChartDecoding) { - stringstream tmp; - tmp << " " << line << " "; - line = tmp.str(); - }*/ - references[i].push_back(line); - } - if (i > 0) { - if (references[i].size() != references[i-1].size()) { - UserMessage::Add("Reference files are of different lengths"); - return false; - } - } - in.close(); - } - //Set the references in the bleu feature - bleuScoreFeature->LoadReferences(references); - - return true; -} - - - const TranslationOptionList* StaticData::FindTransOptListInCache(const DecodeGraph &decodeGraph, const Phrase &sourcePhrase) const { std::pair key(decodeGraph.GetPosition(), sourcePhrase); diff --git a/moses/StaticData.h b/moses/StaticData.h index 8ee24f929..1ee1b1381 100644 --- a/moses/StaticData.h +++ b/moses/StaticData.h @@ -231,8 +231,6 @@ protected: bool LoadPhraseTables(); //! load decoding steps bool LoadDecodeGraphs(); - //References used for scoring feature (eg BleuScoreFeature) for online training - bool LoadReferences(); void ReduceTransOptCache() const; bool m_continuePartialTranslation;