added WordTranslationFeature

git-svn-id: http://svn.statmt.org/repository/mira@3930 cc96ff50-19ce-11e0-b349-13d7f0bd23df
This commit is contained in:
pkoehn 2011-08-13 02:40:54 +00:00 committed by Ondrej Bojar
parent c815741145
commit 69fdd15792
8 changed files with 171 additions and 2 deletions

View File

@ -117,6 +117,7 @@ libmoses_la_HEADERS = \
Word.h \
WordConsumed.h \
WordLattice.h \
WordTranslationFeature.h \
WordsBitmap.h \
WordsRange.h \
XmlOption.h
@ -248,6 +249,7 @@ libmoses_la_SOURCES = \
Word.cpp \
WordConsumed.cpp \
WordLattice.cpp \
WordTranslationFeature.cpp \
WordsBitmap.cpp \
WordsRange.cpp \
XmlOption.cpp

View File

@ -145,6 +145,7 @@ Parameter::Parameter()
AddParam("phrase-length-feature", "Count features for source length, target length, both of each phrase");
AddParam("target-word-insertion-feature", "Count feature for each unaligned target word");
AddParam("source-word-deletion-feature", "Count feature for each unaligned source word");
AddParam("word-translation-feature", "Count feature for word translation according to word alignment");
AddParam("report-sparse-features", "Indicate which sparse feature functions should report detailed scores in n-best, instead of aggregate");
AddParam("show-weights", "print feature weights and exit");
}

View File

@ -53,7 +53,6 @@ void SourceWordDeletionFeature::Evaluate(const TargetPhrase& targetPhrase,
for(size_t i=0; i<sourceLength; i++) {
if (!aligned[i]) {
const string &word = targetPhrase.GetSourcePhrase().GetWord(i).GetFactor(m_factorType)->GetString();
stringstream featureName;
if (!m_unrestricted && m_vocab.find( word ) == m_vocab.end()) {
accumulator->PlusEquals(this,"OTHER",1);
}

View File

@ -41,6 +41,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "PhraseLengthFeature.h"
#include "TargetWordInsertionFeature.h"
#include "SourceWordDeletionFeature.h"
#include "WordTranslationFeature.h"
#include "UserMessage.h"
#include "TranslationOption.h"
#include "TargetBigramFeature.h"
@ -78,6 +79,7 @@ StaticData::StaticData()
,m_phraseLengthFeature(NULL)
,m_targetWordInsertionFeature(NULL)
,m_sourceWordDeletionFeature(NULL)
,m_wordTranslationFeature(NULL)
,m_numLinkParams(1)
,m_fLMsLoaded(false)
,m_sourceStartPosMattersForRecombination(false)
@ -467,6 +469,7 @@ bool StaticData::LoadData(Parameter *parameter)
if (!LoadPhraseLengthFeature()) return false;
if (!LoadTargetWordInsertionFeature()) return false;
if (!LoadSourceWordDeletionFeature()) return false;
if (!LoadWordTranslationFeature()) return false;
// report individual sparse features in n-best list
if (m_parameter->GetParam("report-sparse-features").size() > 0) {
@ -484,6 +487,8 @@ bool StaticData::LoadData(Parameter *parameter)
m_targetWordInsertionFeature->SetSparseFeatureReporting();
if (m_sourceWordDeletionFeature && name.compare(m_sourceWordDeletionFeature->GetScoreProducerWeightShortName()) == 0)
m_sourceWordDeletionFeature->SetSparseFeatureReporting();
if (m_wordTranslationFeature && name.compare(m_wordTranslationFeature->GetScoreProducerWeightShortName()) == 0)
m_wordTranslationFeature->SetSparseFeatureReporting();
}
}
@ -586,6 +591,9 @@ bool StaticData::LoadData(Parameter *parameter)
if (m_sourceWordDeletionFeature) {
m_translationSystems.find(config[0])->second.AddFeatureFunction(m_sourceWordDeletionFeature);
}
if (m_wordTranslationFeature) {
m_translationSystems.find(config[0])->second.AddFeatureFunction(m_wordTranslationFeature);
}
}
//Load extra feature weights
@ -667,7 +675,9 @@ StaticData::~StaticData()
delete m_phrasePairFeature;
delete m_phraseBoundaryFeature;
delete m_phraseLengthFeature;
delete m_targetWordInsertionFeature;
delete m_sourceWordDeletionFeature;
delete m_wordTranslationFeature;
//delete m_parameter;
@ -1483,6 +1493,47 @@ bool StaticData::LoadSourceWordDeletionFeature()
return true;
}
bool StaticData::LoadWordTranslationFeature()
{
const vector<string> &parameters = m_parameter->GetParam("word-translation-feature");
if (parameters.empty())
return true;
if (parameters.size() != 1) {
UserMessage::Add("Can only have one word-translation-feature");
return false;
}
vector<string> tokens = Tokenize(parameters[0]);
if (tokens.size() != 2 && tokens.size() != 4) {
UserMessage::Add("Format of word translation feature parameter is: --word-translation-feature <factor-src> <factor-tgt> [filename-src filename-tgt]");
return false;
}
if (!m_UseAlignmentInfo) {
UserMessage::Add("Word translation feature needs word alignments in phrase table.");
return false;
}
// set factor
FactorType factorIdSource = Scan<size_t>(tokens[0]);
FactorType factorIdTarget = Scan<size_t>(tokens[1]);
m_wordTranslationFeature = new WordTranslationFeature(factorIdSource,factorIdTarget);
// load word list for restricted feature set
if (tokens.size() == 4) {
string filenameSource = tokens[2];
string filenameTarget = tokens[3];
cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl;
if (!m_wordTranslationFeature->Load(filenameSource, filenameTarget)) {
UserMessage::Add("Unable to load word lists for word translation feature from files " + filenameSource + " and " + filenameTarget);
return false;
}
}
return true;
}
const TranslationOptionList* StaticData::FindTransOptListInCache(const DecodeGraph &decodeGraph, const Phrase &sourcePhrase) const
{
std::pair<size_t, Phrase> key(decodeGraph.GetPosition(), sourcePhrase);

View File

@ -63,6 +63,7 @@ class BleuScoreFeature;
class PhraseLengthFeature;
class TargetWordInsertionFeature;
class SourceWordDeletionFeature;
class WordTranslationFeature;
class GenerationDictionary;
class DistortionScoreProducer;
class DecodeStep;
@ -99,6 +100,7 @@ protected:
PhraseLengthFeature* m_phraseLengthFeature;
TargetWordInsertionFeature* m_targetWordInsertionFeature;
SourceWordDeletionFeature* m_sourceWordDeletionFeature;
WordTranslationFeature* m_wordTranslationFeature;
float
m_beamWidth,
m_earlyDiscardingThreshold,
@ -246,6 +248,7 @@ protected:
bool LoadPhraseLengthFeature();
bool LoadTargetWordInsertionFeature();
bool LoadSourceWordDeletionFeature();
bool LoadWordTranslationFeature();
void ReduceTransOptCache() const;
bool m_continuePartialTranslation;

View File

@ -53,7 +53,6 @@ void TargetWordInsertionFeature::Evaluate(const TargetPhrase& targetPhrase,
for(size_t i=0; i<targetLength; i++) {
if (!aligned[i]) {
const string &word = targetPhrase.GetWord(i).GetFactor(m_factorType)->GetString();
stringstream featureName;
if (!m_unrestricted && m_vocab.find( word ) == m_vocab.end()) {
accumulator->PlusEquals(this,"OTHER",1);
}

View File

@ -0,0 +1,71 @@
#include <sstream>
#include "WordTranslationFeature.h"
#include "Phrase.h"
#include "TargetPhrase.h"
#include "Hypothesis.h"
#include "ScoreComponentCollection.h"
namespace Moses {
using namespace std;
bool WordTranslationFeature::Load(const std::string &filePathSource, const std::string &filePathTarget)
{
// restricted source word vocabulary
ifstream inFileSource(filePathSource.c_str());
if (!inFileSource)
{
cerr << "could not open file " << filePathSource << endl;
return false;
}
std::string line;
while (getline(inFileSource, line)) {
m_vocabSource.insert(line);
}
inFileSource.close();
// restricted target word vocabulary
ifstream inFileTarget(filePathTarget.c_str());
if (!inFileTarget)
{
cerr << "could not open file " << filePathTarget << endl;
return false;
}
while (getline(inFileTarget, line)) {
m_vocabTarget.insert(line);
}
inFileTarget.close();
m_unrestricted = false;
return true;
}
void WordTranslationFeature::Evaluate(const TargetPhrase& targetPhrase,
ScoreComponentCollection* accumulator) const
{
const AlignmentInfo &alignment = targetPhrase.GetAlignmentInfo();
// process aligned words
for (AlignmentInfo::const_iterator alignmentPoint = alignment.begin(); alignmentPoint != alignment.end(); alignmentPoint++) {
// look up words
const string &sourceWord = targetPhrase.GetSourcePhrase().GetWord(alignmentPoint->first).GetFactor(m_factorTypeSource)->GetString();
const string &targetWord = targetPhrase.GetWord(alignmentPoint->second).GetFactor(m_factorTypeTarget)->GetString();
bool sourceExists = m_vocabSource.find( sourceWord ) != m_vocabSource.end();
bool targetExists = m_vocabTarget.find( targetWord ) != m_vocabTarget.end();
// no feature if both words are not in restricted vocabularies
if (m_unrestricted || sourceExists || targetExists) {
// construct feature name
stringstream featureName;
featureName << (sourceExists ? sourceWord : "OTHER");
featureName << "|";
featureName << (targetExists ? targetWord : "OTHER");
accumulator->PlusEquals(this,featureName.str(),1);
}
}
}
}

View File

@ -0,0 +1,43 @@
#ifndef moses_WordTranslationFeature_h
#define moses_WordTranslationFeature_h
#include <string>
#include <map>
#include "FeatureFunction.h"
#include "FactorCollection.h"
namespace Moses
{
/** Sets the features for word translation
*/
class WordTranslationFeature : public StatelessFeatureFunction {
private:
std::set<std::string> m_vocabSource;
std::set<std::string> m_vocabTarget;
FactorType m_factorTypeSource;
FactorType m_factorTypeTarget;
bool m_unrestricted;
public:
WordTranslationFeature(FactorType factorTypeSource = 0, FactorType factorTypeTarget = 0):
StatelessFeatureFunction("wt"),
m_factorTypeSource(factorTypeSource),
m_factorTypeTarget(factorTypeTarget),
m_unrestricted(true)
{}
bool Load(const std::string &filePathSource, const std::string &filePathTarget);
void Evaluate(const TargetPhrase& cur_phrase,
ScoreComponentCollection* accumulator) const;
// basic properties
size_t GetNumScoreComponents() const { return ScoreProducer::unlimited; }
std::string GetScoreProducerWeightShortName() const { return "wt"; }
size_t GetNumInputScores() const { return 0; }
};
}
#endif // moses_WordTranslationFeature_h