mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-26 21:42:19 +03:00
More work on outputting HTK lattice format
This commit is contained in:
parent
e106e04dc3
commit
e7563111de
@ -53,12 +53,12 @@ using namespace std;
|
||||
namespace Moses
|
||||
{
|
||||
Manager::Manager(size_t lineNumber, InputType const& source, SearchAlgorithm searchAlgorithm, const TranslationSystem* system)
|
||||
:m_lineNumber(lineNumber)
|
||||
,m_system(system)
|
||||
:m_system(system)
|
||||
,m_transOptColl(source.CreateTranslationOptionCollection(system))
|
||||
,m_search(Search::CreateSearch(*this, source, searchAlgorithm, *m_transOptColl))
|
||||
,interrupted_flag(0)
|
||||
,m_hypoId(0)
|
||||
,m_lineNumber(lineNumber)
|
||||
,m_source(source)
|
||||
{
|
||||
m_system->InitializeBeforeSentenceProcessing(source);
|
||||
@ -630,6 +630,140 @@ void Manager::GetSearchGraph(vector<SearchGraphNode>& searchGraph) const
|
||||
|
||||
}
|
||||
|
||||
void Manager::OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream) const
|
||||
{
|
||||
outputSearchGraphStream.setf(std::ios::fixed);
|
||||
outputSearchGraphStream.precision(6);
|
||||
|
||||
const StaticData& staticData = StaticData::Instance();
|
||||
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
|
||||
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
|
||||
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
|
||||
size_t featureIndex = 1;
|
||||
for (size_t i = 0; i < sff.size(); ++i) {
|
||||
featureIndex = OutputFeatureWeightsForSLF(featureIndex, sff[i], outputSearchGraphStream);
|
||||
}
|
||||
for (size_t i = 0; i < slf.size(); ++i) {
|
||||
if (slf[i]->GetScoreProducerWeightShortName() != "u" &&
|
||||
slf[i]->GetScoreProducerWeightShortName() != "tm" &&
|
||||
slf[i]->GetScoreProducerWeightShortName() != "I" &&
|
||||
slf[i]->GetScoreProducerWeightShortName() != "g")
|
||||
{
|
||||
featureIndex = OutputFeatureWeightsForSLF(featureIndex, slf[i], outputSearchGraphStream);
|
||||
}
|
||||
}
|
||||
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
|
||||
for( size_t i=0; i<pds.size(); i++ ) {
|
||||
featureIndex = OutputFeatureWeightsForSLF(featureIndex, pds[i], outputSearchGraphStream);
|
||||
}
|
||||
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
|
||||
for( size_t i=0; i<gds.size(); i++ ) {
|
||||
featureIndex = OutputFeatureWeightsForSLF(featureIndex, gds[i], outputSearchGraphStream);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
void Manager::OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const
|
||||
{
|
||||
outputSearchGraphStream.setf(std::ios::fixed);
|
||||
outputSearchGraphStream.precision(6);
|
||||
|
||||
// outputSearchGraphStream << endl;
|
||||
// outputSearchGraphStream << (*hypo) << endl;
|
||||
// const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown();
|
||||
// outputSearchGraphStream << scoreCollection << endl;
|
||||
|
||||
const StaticData& staticData = StaticData::Instance();
|
||||
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
|
||||
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
|
||||
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
|
||||
size_t featureIndex = 1;
|
||||
for (size_t i = 0; i < sff.size(); ++i) {
|
||||
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, sff[i], outputSearchGraphStream);
|
||||
}
|
||||
for (size_t i = 0; i < slf.size(); ++i) {
|
||||
if (slf[i]->GetScoreProducerWeightShortName() != "u" &&
|
||||
slf[i]->GetScoreProducerWeightShortName() != "tm" &&
|
||||
slf[i]->GetScoreProducerWeightShortName() != "I" &&
|
||||
slf[i]->GetScoreProducerWeightShortName() != "g")
|
||||
{
|
||||
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, slf[i], outputSearchGraphStream);
|
||||
}
|
||||
}
|
||||
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
|
||||
for( size_t i=0; i<pds.size(); i++ ) {
|
||||
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, pds[i], outputSearchGraphStream);
|
||||
}
|
||||
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
|
||||
for( size_t i=0; i<gds.size(); i++ ) {
|
||||
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, gds[i], outputSearchGraphStream);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
size_t Manager::OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
|
||||
{
|
||||
size_t numScoreComps = ff->GetNumScoreComponents();
|
||||
if (numScoreComps != ScoreProducer::unlimited) {
|
||||
vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
|
||||
for (size_t i = 0; i < numScoreComps; ++i) {
|
||||
outputSearchGraphStream << "# " << ff->GetScoreProducerDescription()
|
||||
<< " " << ff->GetScoreProducerWeightShortName()
|
||||
<< " " << (i+1) << " of " << numScoreComps << endl
|
||||
<< "x" << (index+i) << "scale=" << values[i] << endl;
|
||||
}
|
||||
return index+numScoreComps;
|
||||
} else {
|
||||
cerr << "Sparse features are not supported when outputting HTK standard lattice format" << endl;
|
||||
assert(false);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
size_t Manager::OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
|
||||
{
|
||||
|
||||
// { const FeatureFunction* sp = ff;
|
||||
// const FVector& m_scores = scoreCollection.GetScoresVector();
|
||||
// FVector& scores = const_cast<FVector&>(m_scores);
|
||||
// std::string prefix = sp->GetScoreProducerDescription() + FName::SEP;
|
||||
// // std::cout << "prefix==" << prefix << endl;
|
||||
// // cout << "m_scores==" << m_scores << endl;
|
||||
// // cout << "m_scores.size()==" << m_scores.size() << endl;
|
||||
// // cout << "m_scores.coreSize()==" << m_scores.coreSize() << endl;
|
||||
// // cout << "m_scores.cbegin() ?= m_scores.cend()\t" << (m_scores.cbegin() == m_scores.cend()) << endl;
|
||||
|
||||
|
||||
// // for(FVector::FNVmap::const_iterator i = m_scores.cbegin(); i != m_scores.cend(); i++) {
|
||||
// // std::cout<<prefix << "\t" << (i->first) << "\t" << (i->second) << std::endl;
|
||||
// // }
|
||||
// for(int i=0, n=v.size(); i<n; i+=1) {
|
||||
// // outputSearchGraphStream << prefix << i << "==" << v[i] << std::endl;
|
||||
|
||||
// }
|
||||
// }
|
||||
|
||||
// FVector featureValues = scoreCollection.GetVectorForProducer(ff);
|
||||
// outputSearchGraphStream << featureValues << endl;
|
||||
const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown();
|
||||
|
||||
vector<float> featureValues = scoreCollection.GetScoresForProducer(ff);
|
||||
size_t numScoreComps = featureValues.size();//featureValues.coreSize();
|
||||
// if (numScoreComps != ScoreProducer::unlimited) {
|
||||
// vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
|
||||
for (size_t i = 0; i < numScoreComps; ++i) {
|
||||
outputSearchGraphStream << "x" << (index+i) << "=" << ((zeros) ? 0.0 : featureValues[i]) << " ";
|
||||
}
|
||||
return index+numScoreComps;
|
||||
// } else {
|
||||
// cerr << "Sparse features are not supported when outputting HTK standard lattice format" << endl;
|
||||
// assert(false);
|
||||
// return 0;
|
||||
// }
|
||||
}
|
||||
|
||||
/** Output search graph in HTK standard lattice format (SLF) */
|
||||
void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSearchGraphStream) const
|
||||
{
|
||||
@ -673,10 +807,12 @@ void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSea
|
||||
|
||||
outputSearchGraphStream << "UTTERANCE=Sentence_" << translationId << endl;
|
||||
outputSearchGraphStream << "VERSION=1.1" << endl;
|
||||
outputSearchGraphStream << "base=e" << endl;
|
||||
outputSearchGraphStream << "base=2.71828182845905" << endl;
|
||||
outputSearchGraphStream << "NODES=" << (numNodes+1) << endl;
|
||||
outputSearchGraphStream << "LINKS=" << numArcs << endl;
|
||||
|
||||
OutputFeatureWeightsForSLF(outputSearchGraphStream);
|
||||
|
||||
// const vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();
|
||||
|
||||
for (size_t arcNumber = 0, lineNumber = 0; lineNumber < searchGraph.size(); ++lineNumber) {
|
||||
@ -709,8 +845,11 @@ void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSea
|
||||
}
|
||||
|
||||
outputSearchGraphStream << " E=" << endNode - (x-1) //(startNode + targetWordIndex + 1)
|
||||
<< " W=" << targetPhrase.GetWord(targetWordIndex)
|
||||
<< endl;
|
||||
<< " W=" << targetPhrase.GetWord(targetWordIndex);
|
||||
|
||||
OutputFeatureValuesForSLF(thisHypo, (targetWordIndex>0), outputSearchGraphStream);
|
||||
|
||||
outputSearchGraphStream << endl;
|
||||
|
||||
arcNumber += 1;
|
||||
}
|
||||
|
@ -93,6 +93,11 @@ class Manager
|
||||
Manager(Manager const&);
|
||||
void operator=(Manager const&);
|
||||
const TranslationSystem* m_system;
|
||||
private:
|
||||
void OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream) const;
|
||||
size_t OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const;
|
||||
void OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const;
|
||||
size_t OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const;
|
||||
protected:
|
||||
// data
|
||||
// InputType const& m_source; /**< source sentence to be translated */
|
||||
@ -103,6 +108,7 @@ protected:
|
||||
size_t interrupted_flag;
|
||||
std::auto_ptr<SentenceStats> m_sentenceStats;
|
||||
int m_hypoId; //used to number the hypos as they are created.
|
||||
size_t m_lineNumber;
|
||||
|
||||
void GetConnectedGraph(
|
||||
std::map< int, bool >* pConnected,
|
||||
@ -113,7 +119,6 @@ protected:
|
||||
|
||||
|
||||
public:
|
||||
size_t m_lineNumber;
|
||||
InputType const& m_source; /**< source sentence to be translated */
|
||||
Manager(size_t lineNumber, InputType const& source, SearchAlgorithm searchAlgorithm, const TranslationSystem* system);
|
||||
~Manager();
|
||||
|
Loading…
Reference in New Issue
Block a user