More work on outputting HTK lattice format

This commit is contained in:
Lane Schwartz 2013-02-20 11:03:23 -05:00
parent e106e04dc3
commit e7563111de
2 changed files with 150 additions and 6 deletions

View File

@ -53,12 +53,12 @@ using namespace std;
namespace Moses
{
Manager::Manager(size_t lineNumber, InputType const& source, SearchAlgorithm searchAlgorithm, const TranslationSystem* system)
:m_lineNumber(lineNumber)
,m_system(system)
:m_system(system)
,m_transOptColl(source.CreateTranslationOptionCollection(system))
,m_search(Search::CreateSearch(*this, source, searchAlgorithm, *m_transOptColl))
,interrupted_flag(0)
,m_hypoId(0)
,m_lineNumber(lineNumber)
,m_source(source)
{
m_system->InitializeBeforeSentenceProcessing(source);
@ -630,6 +630,140 @@ void Manager::GetSearchGraph(vector<SearchGraphNode>& searchGraph) const
}
void Manager::OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream) const
{
outputSearchGraphStream.setf(std::ios::fixed);
outputSearchGraphStream.precision(6);
const StaticData& staticData = StaticData::Instance();
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
size_t featureIndex = 1;
for (size_t i = 0; i < sff.size(); ++i) {
featureIndex = OutputFeatureWeightsForSLF(featureIndex, sff[i], outputSearchGraphStream);
}
for (size_t i = 0; i < slf.size(); ++i) {
if (slf[i]->GetScoreProducerWeightShortName() != "u" &&
slf[i]->GetScoreProducerWeightShortName() != "tm" &&
slf[i]->GetScoreProducerWeightShortName() != "I" &&
slf[i]->GetScoreProducerWeightShortName() != "g")
{
featureIndex = OutputFeatureWeightsForSLF(featureIndex, slf[i], outputSearchGraphStream);
}
}
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
for( size_t i=0; i<pds.size(); i++ ) {
featureIndex = OutputFeatureWeightsForSLF(featureIndex, pds[i], outputSearchGraphStream);
}
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
for( size_t i=0; i<gds.size(); i++ ) {
featureIndex = OutputFeatureWeightsForSLF(featureIndex, gds[i], outputSearchGraphStream);
}
}
void Manager::OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const
{
outputSearchGraphStream.setf(std::ios::fixed);
outputSearchGraphStream.precision(6);
// outputSearchGraphStream << endl;
// outputSearchGraphStream << (*hypo) << endl;
// const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown();
// outputSearchGraphStream << scoreCollection << endl;
const StaticData& staticData = StaticData::Instance();
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
size_t featureIndex = 1;
for (size_t i = 0; i < sff.size(); ++i) {
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, sff[i], outputSearchGraphStream);
}
for (size_t i = 0; i < slf.size(); ++i) {
if (slf[i]->GetScoreProducerWeightShortName() != "u" &&
slf[i]->GetScoreProducerWeightShortName() != "tm" &&
slf[i]->GetScoreProducerWeightShortName() != "I" &&
slf[i]->GetScoreProducerWeightShortName() != "g")
{
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, slf[i], outputSearchGraphStream);
}
}
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
for( size_t i=0; i<pds.size(); i++ ) {
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, pds[i], outputSearchGraphStream);
}
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
for( size_t i=0; i<gds.size(); i++ ) {
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, gds[i], outputSearchGraphStream);
}
}
size_t Manager::OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
{
size_t numScoreComps = ff->GetNumScoreComponents();
if (numScoreComps != ScoreProducer::unlimited) {
vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
for (size_t i = 0; i < numScoreComps; ++i) {
outputSearchGraphStream << "# " << ff->GetScoreProducerDescription()
<< " " << ff->GetScoreProducerWeightShortName()
<< " " << (i+1) << " of " << numScoreComps << endl
<< "x" << (index+i) << "scale=" << values[i] << endl;
}
return index+numScoreComps;
} else {
cerr << "Sparse features are not supported when outputting HTK standard lattice format" << endl;
assert(false);
return 0;
}
}
size_t Manager::OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
{
// { const FeatureFunction* sp = ff;
// const FVector& m_scores = scoreCollection.GetScoresVector();
// FVector& scores = const_cast<FVector&>(m_scores);
// std::string prefix = sp->GetScoreProducerDescription() + FName::SEP;
// // std::cout << "prefix==" << prefix << endl;
// // cout << "m_scores==" << m_scores << endl;
// // cout << "m_scores.size()==" << m_scores.size() << endl;
// // cout << "m_scores.coreSize()==" << m_scores.coreSize() << endl;
// // cout << "m_scores.cbegin() ?= m_scores.cend()\t" << (m_scores.cbegin() == m_scores.cend()) << endl;
// // for(FVector::FNVmap::const_iterator i = m_scores.cbegin(); i != m_scores.cend(); i++) {
// // std::cout<<prefix << "\t" << (i->first) << "\t" << (i->second) << std::endl;
// // }
// for(int i=0, n=v.size(); i<n; i+=1) {
// // outputSearchGraphStream << prefix << i << "==" << v[i] << std::endl;
// }
// }
// FVector featureValues = scoreCollection.GetVectorForProducer(ff);
// outputSearchGraphStream << featureValues << endl;
const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown();
vector<float> featureValues = scoreCollection.GetScoresForProducer(ff);
size_t numScoreComps = featureValues.size();//featureValues.coreSize();
// if (numScoreComps != ScoreProducer::unlimited) {
// vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
for (size_t i = 0; i < numScoreComps; ++i) {
outputSearchGraphStream << "x" << (index+i) << "=" << ((zeros) ? 0.0 : featureValues[i]) << " ";
}
return index+numScoreComps;
// } else {
// cerr << "Sparse features are not supported when outputting HTK standard lattice format" << endl;
// assert(false);
// return 0;
// }
}
/**! Output search graph in HTK standard lattice format (SLF) */
void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSearchGraphStream) const
{
@ -673,10 +807,12 @@ void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSea
outputSearchGraphStream << "UTTERANCE=Sentence_" << translationId << endl;
outputSearchGraphStream << "VERSION=1.1" << endl;
outputSearchGraphStream << "base=e" << endl;
outputSearchGraphStream << "base=2.71828182845905" << endl;
outputSearchGraphStream << "NODES=" << (numNodes+1) << endl;
outputSearchGraphStream << "LINKS=" << numArcs << endl;
OutputFeatureWeightsForSLF(outputSearchGraphStream);
// const vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();
for (size_t arcNumber = 0, lineNumber = 0; lineNumber < searchGraph.size(); ++lineNumber) {
@ -709,8 +845,11 @@ void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSea
}
outputSearchGraphStream << " E=" << endNode - (x-1) //(startNode + targetWordIndex + 1)
<< " W=" << targetPhrase.GetWord(targetWordIndex)
<< endl;
<< " W=" << targetPhrase.GetWord(targetWordIndex);
OutputFeatureValuesForSLF(thisHypo, (targetWordIndex>0), outputSearchGraphStream);
outputSearchGraphStream << endl;
arcNumber += 1;
}

View File

@ -93,6 +93,11 @@ class Manager
Manager(Manager const&);
void operator=(Manager const&);
const TranslationSystem* m_system;
private:
void OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream) const;
size_t OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const;
void OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const;
size_t OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const;
protected:
// data
// InputType const& m_source; /**< source sentence to be translated */
@ -103,6 +108,7 @@ protected:
size_t interrupted_flag;
std::auto_ptr<SentenceStats> m_sentenceStats;
int m_hypoId; //used to number the hypos as they are created.
size_t m_lineNumber;
void GetConnectedGraph(
std::map< int, bool >* pConnected,
@ -113,7 +119,6 @@ protected:
public:
size_t m_lineNumber;
InputType const& m_source; /**< source sentence to be translated */
Manager(size_t lineNumber, InputType const& source, SearchAlgorithm searchAlgorithm, const TranslationSystem* system);
~Manager();