bleu feature

This commit is contained in:
Hieu Hoang 2013-01-29 10:54:09 +00:00
parent 00ebc1e5ab
commit 967ac6542a
4 changed files with 75 additions and 75 deletions

View File

@ -70,6 +70,75 @@ void BleuScoreState::AddNgramCountAndMatches(std::vector< size_t >& counts,
}
}
/**
 * Construct the BLEU score feature and load its reference translations.
 *
 * All tuning flags and history accumulators are set to their defaults in the
 * initializer list. The body then reads the "references" parameter (a list of
 * reference file names), reads every file line by line and hands the result
 * to LoadReferences() on this object.
 *
 * The feature weight is NOT handled here: the code that constructs the
 * feature assigns the weights itself right after construction.
 *
 * @param line  feature specification passed in by the caller; it is not
 *              parsed by this constructor.
 * @throws std::ios_base::failure if a reference file cannot be opened or if
 *         the reference files do not all contain the same number of lines.
 *         A constructor cannot return a status value, so the failure is
 *         recorded through UserMessage::Add() and then raised.
 */
BleuScoreFeature::BleuScoreFeature(const std::string &line)
  :StatefulFeatureFunction("BleuScoreFeature",1),
   m_enabled(true),
   m_sentence_bleu(true),
   m_simple_history_bleu(false),
   m_count_history(BleuScoreState::bleu_order),
   m_match_history(BleuScoreState::bleu_order),
   m_source_length_history(0),
   m_target_length_history(0),
   m_ref_length_history(0),
   m_scale_by_input_length(true),
   m_scale_by_avg_input_length(false),
   m_scale_by_inverse_length(false),
   m_scale_by_avg_inverse_length(false),
   m_scale_by_x(1),
   m_historySmoothing(0.9),
   m_smoothing_scheme(PLUS_POINT_ONE)
{
  // NOTE(review): `m_parameter` is carried over unchanged from
  // StaticData::LoadReferences(), where this code used to live. Confirm that
  // the parameter object is reachable from this class.
  vector<string> referenceFiles = m_parameter->GetParam("references");

  // Nothing to load when no reference files were configured.
  if (referenceFiles.empty()) {
    return;
  }

  cerr << "Loading reference file " << referenceFiles[0] << endl;

  // references[i][j] is line j of reference file i.
  vector<vector<string> > references(referenceFiles.size());
  for (size_t i = 0; i < referenceFiles.size(); ++i) {
    ifstream in(referenceFiles[i].c_str());
    if (!in) {
      stringstream strme;
      strme << "Unable to load references from " << referenceFiles[i];
      UserMessage::Add(strme.str());
      throw std::ios_base::failure(strme.str());
    }

    // Named refLine so that it does not shadow the constructor parameter.
    string refLine;
    while (getline(in, refLine)) {
      /*  if (GetSearchAlgorithm() == ChartDecoding) {
            stringstream tmp;
            tmp << "<s> " << refLine << " </s>";
            refLine = tmp.str();
          }*/
      references[i].push_back(refLine);
    }

    // Every reference file must provide one line per input sentence, so all
    // files have to be of equal length.
    if (i > 0 && references[i].size() != references[i-1].size()) {
      UserMessage::Add("Reference files are of different lengths");
      throw std::ios_base::failure("Reference files are of different lengths");
    }
    // `in` is closed by its destructor at the end of each iteration.
  }

  // Set the references on this feature (not on a second, temporary instance).
  LoadReferences(references);
}
void BleuScoreFeature::PrintHistory(std::ostream& out) const {
out << "source length history=" << m_source_length_history << endl;
out << "target length history=" << m_target_length_history << endl;

View File

@ -62,23 +62,7 @@ public:
typedef boost::unordered_map<size_t, RefValue > RefCounts;
typedef boost::unordered_map<size_t, NGrams> Matches;
/**
 * Default constructor.
 *
 * Passes the name "BleuScoreFeature" and the value 1 to the
 * StatefulFeatureFunction base (presumably the number of score components;
 * confirm against the base class) and puts every option into its default
 * state. No reference translations are loaded here.
 */
BleuScoreFeature():
StatefulFeatureFunction("BleuScoreFeature",1),
// Feature switched on, sentence-level BLEU selected, simple-history variant off.
m_enabled(true),
m_sentence_bleu(true),
m_simple_history_bleu(false),
// Both n-gram histories are constructed from BleuScoreState::bleu_order
// (presumably one slot per n-gram order; confirm against the member types).
m_count_history(BleuScoreState::bleu_order),
m_match_history(BleuScoreState::bleu_order),
// Length histories start at zero.
m_source_length_history(0),
m_target_length_history(0),
m_ref_length_history(0),
// Of the scaling options only "scale by input length" is on by default.
m_scale_by_input_length(true),
m_scale_by_avg_input_length(false),
m_scale_by_inverse_length(false),
m_scale_by_avg_inverse_length(false),
m_scale_by_x(1),
m_historySmoothing(0.9),
m_smoothing_scheme(PLUS_POINT_ONE) {}
//! Construct the feature from `line` (presumably the feature's line from the
//! decoder configuration; confirm with the code that creates the feature).
BleuScoreFeature(const std::string &line);
//! Write the accumulated history values (e.g. source and target length
//! history) to `out`.
void PrintHistory(std::ostream& out) const;
//! Install the reference translations. Callers build the argument with one
//! inner vector per reference file, holding that file's lines in file order.
void LoadReferences(const std::vector< std::vector< std::string > > &);

View File

@ -628,6 +628,11 @@ SetWeight(m_unknownWordPenaltyProducer, weightUnknownWord);
const vector<float> &weights = m_parameter->GetWeights(feature, featureIndex);
SetWeights(model, weights);
}
else if (feature == "BleuScoreFeature") {
BleuScoreFeature *model = new BleuScoreFeature(line);
const vector<float> &weights = m_parameter->GetWeights(feature, featureIndex);
SetWeights(model, weights);
}
#ifdef HAVE_SYNLM
else if (feature == "SyntacticLanguageModel") {
SyntacticLanguageModel *model = new SyntacticLanguageModel(line);
@ -647,7 +652,6 @@ SetWeight(m_unknownWordPenaltyProducer, weightUnknownWord);
if (!LoadPhraseTables()) return false;
if (!LoadDecodeGraphs()) return false;
if (!LoadReferences()) return false;
// report individual sparse features in n-best list
if (m_parameter->GetParam("report-sparse-features").size() > 0) {
@ -1035,61 +1039,6 @@ bool StaticData::LoadDecodeGraphs()
return true;
}
bool StaticData::LoadReferences()
{
vector<string> bleuWeightStr = m_parameter->GetParam("weight-bl");
vector<string> referenceFiles = m_parameter->GetParam("references");
if ((!referenceFiles.size() && bleuWeightStr.size()) || (referenceFiles.size() && !bleuWeightStr.size())) {
UserMessage::Add("You cannot use the bleu feature without references, and vice-versa");
return false;
}
if (!referenceFiles.size()) {
return true;
}
if (bleuWeightStr.size() > 1) {
UserMessage::Add("Can only specify one weight for the bleu feature");
return false;
}
float bleuWeight = Scan<float>(bleuWeightStr[0]);
BleuScoreFeature *bleuScoreFeature = new BleuScoreFeature();
SetWeight(bleuScoreFeature, bleuWeight);
cerr << "Loading reference file " << referenceFiles[0] << endl;
vector<vector<string> > references(referenceFiles.size());
for (size_t i =0; i < referenceFiles.size(); ++i) {
ifstream in(referenceFiles[i].c_str());
if (!in) {
stringstream strme;
strme << "Unable to load references from " << referenceFiles[i];
UserMessage::Add(strme.str());
return false;
}
string line;
while (getline(in,line)) {
/* if (GetSearchAlgorithm() == ChartDecoding) {
stringstream tmp;
tmp << "<s> " << line << " </s>";
line = tmp.str();
}*/
references[i].push_back(line);
}
if (i > 0) {
if (references[i].size() != references[i-1].size()) {
UserMessage::Add("Reference files are of different lengths");
return false;
}
}
in.close();
}
//Set the references in the bleu feature
bleuScoreFeature->LoadReferences(references);
return true;
}
const TranslationOptionList* StaticData::FindTransOptListInCache(const DecodeGraph &decodeGraph, const Phrase &sourcePhrase) const
{
std::pair<size_t, Phrase> key(decodeGraph.GetPosition(), sourcePhrase);

View File

@ -231,8 +231,6 @@ protected:
bool LoadPhraseTables();
//! load decoding steps
bool LoadDecodeGraphs();
//! load references used by scoring features (e.g. BleuScoreFeature) for online training
bool LoadReferences();
void ReduceTransOptCache() const;
bool m_continuePartialTranslation;