Output feature weights to a separate file when producing hypergraph

This commit is contained in:
Lane Schwartz 2013-02-22 16:20:03 -05:00
parent 285661fec7
commit 4adeb7e338
5 changed files with 76 additions and 60 deletions

View File

@ -207,6 +207,15 @@ InputType*IOWrapper::GetInput(InputType* inputType)
return file;
}
ofstream* IOWrapper::GetOutputSearchGraphHypergraphWeightsStream() {
const StaticData &staticData = StaticData::Instance();
stringstream fileName;
fileName << staticData.GetParam("output-search-graph-hypergraph")[1];
std::ofstream *file = new std::ofstream;
file->open(fileName.str().c_str());
return file;
}
/***
* print surface factor only for the given phrase
*/

View File

@ -119,6 +119,7 @@ public:
std::ofstream *GetOutputSearchGraphSLFStream(size_t sentenceNumber);
std::ofstream *GetOutputSearchGraphHypergraphStream(size_t sentenceNumber);
std::ofstream *GetOutputSearchGraphHypergraphWeightsStream();
std::ostream &GetDetailedTranslationReportingStream() {
assert (m_detailedTranslationReportingStream);

View File

@ -333,6 +333,7 @@ public:
}
delete m_searchGraphSLFStream;
delete m_searchGraphHypergraphStream;
delete m_source;
}
@ -406,6 +407,63 @@ static void ShowWeights()
}
size_t OutputFeatureWeightsForHypergraph(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream)
{
size_t numScoreComps = ff->GetNumScoreComponents();
if (numScoreComps != ScoreProducer::unlimited) {
vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
if (numScoreComps > 1) {
for (size_t i = 0; i < numScoreComps; ++i) {
outputSearchGraphStream << ff->GetScoreProducerWeightShortName()
<< i
<< "=" << values[i] << endl;
}
} else {
outputSearchGraphStream << ff->GetScoreProducerWeightShortName()
<< "=" << values[0] << endl;
}
return index+numScoreComps;
} else {
cerr << "Sparse features are not yet supported when outputting hypergraph format" << endl;
assert(false);
return 0;
}
}
void OutputFeatureWeightsForHypergraph(std::ostream &outputSearchGraphStream)
{
outputSearchGraphStream.setf(std::ios::fixed);
outputSearchGraphStream.precision(6);
const StaticData& staticData = StaticData::Instance();
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
size_t featureIndex = 1;
for (size_t i = 0; i < sff.size(); ++i) {
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, sff[i], outputSearchGraphStream);
}
for (size_t i = 0; i < slf.size(); ++i) {
if (slf[i]->GetScoreProducerWeightShortName() != "u" &&
slf[i]->GetScoreProducerWeightShortName() != "tm" &&
slf[i]->GetScoreProducerWeightShortName() != "I" &&
slf[i]->GetScoreProducerWeightShortName() != "g")
{
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, slf[i], outputSearchGraphStream);
}
}
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
for( size_t i=0; i<pds.size(); i++ ) {
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, pds[i], outputSearchGraphStream);
}
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
for( size_t i=0; i<gds.size(); i++ ) {
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, gds[i], outputSearchGraphStream);
}
}
} //namespace
/** main function of the command line version of the decoder **/
@ -469,6 +527,14 @@ int main(int argc, char** argv)
TRACE_ERR(weights);
TRACE_ERR("\n");
}
if (staticData.GetOutputSearchGraphHypergraph() && staticData.GetParam("output-search-graph-hypergraph").size() > 1) {
ofstream* weightsOut = ioWrapper->GetOutputSearchGraphHypergraphWeightsStream();
OutputFeatureWeightsForHypergraph(*weightsOut);
weightsOut->flush();
weightsOut->close();
delete weightsOut;
}
// initialize output streams
// note: we can't just write to STDOUT or files

View File

@ -663,40 +663,6 @@ void Manager::OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream)
}
void Manager::OutputFeatureWeightsForHypergraph(std::ostream &outputSearchGraphStream) const
{
outputSearchGraphStream.setf(std::ios::fixed);
outputSearchGraphStream.precision(6);
const StaticData& staticData = StaticData::Instance();
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
size_t featureIndex = 1;
for (size_t i = 0; i < sff.size(); ++i) {
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, sff[i], outputSearchGraphStream);
}
for (size_t i = 0; i < slf.size(); ++i) {
if (slf[i]->GetScoreProducerWeightShortName() != "u" &&
slf[i]->GetScoreProducerWeightShortName() != "tm" &&
slf[i]->GetScoreProducerWeightShortName() != "I" &&
slf[i]->GetScoreProducerWeightShortName() != "g")
{
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, slf[i], outputSearchGraphStream);
}
}
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
for( size_t i=0; i<pds.size(); i++ ) {
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, pds[i], outputSearchGraphStream);
}
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
for( size_t i=0; i<gds.size(); i++ ) {
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, gds[i], outputSearchGraphStream);
}
}
void Manager::OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const
{
outputSearchGraphStream.setf(std::ios::fixed);
@ -788,30 +754,6 @@ size_t Manager::OutputFeatureWeightsForSLF(size_t index, const FeatureFunction*
}
}
size_t Manager::OutputFeatureWeightsForHypergraph(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
{
size_t numScoreComps = ff->GetNumScoreComponents();
if (numScoreComps != ScoreProducer::unlimited) {
vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
if (numScoreComps > 1) {
for (size_t i = 0; i < numScoreComps; ++i) {
outputSearchGraphStream << ff->GetScoreProducerWeightShortName()
<< i
<< "=" << values[i] << endl;
}
} else {
outputSearchGraphStream << ff->GetScoreProducerWeightShortName()
<< "=" << values[0] << endl;
}
return index+numScoreComps;
} else {
cerr << "Sparse features are not yet supported when outputting hypergraph format" << endl;
assert(false);
return 0;
}
}
size_t Manager::OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
{

View File

@ -102,8 +102,6 @@ private:
size_t OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const;
// Helper functions to output search graph in the hypergraph format of Kenneth Heafield's lazy hypergraph decoder
void OutputFeatureWeightsForHypergraph(std::ostream &outputSearchGraphStream) const;
size_t OutputFeatureWeightsForHypergraph(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const;
void OutputFeatureValuesForHypergraph(const Hypothesis* hypo, std::ostream &outputSearchGraphStream) const;
size_t OutputFeatureValuesForHypergraph(size_t index, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const;