mosesdecoder/moses/FF/WordTranslationFeature.cpp

366 lines
13 KiB
C++
Raw Normal View History

#include <sstream>
#include <boost/algorithm/string.hpp>
#include "WordTranslationFeature.h"
2013-05-24 21:02:49 +04:00
#include "moses/Phrase.h"
#include "moses/TargetPhrase.h"
#include "moses/Hypothesis.h"
#include "moses/ChartHypothesis.h"
#include "moses/ScoreComponentCollection.h"
#include "moses/TranslationOption.h"
#include "moses/UserMessage.h"
#include "moses/InputPath.h"
#include "util/string_piece_hash.hh"
#include "util/exception.hh"
using namespace std;
2013-05-29 21:16:15 +04:00
namespace Moses
{
/**
 * Builds the word-translation feature from its configuration line.
 *
 * Defaults: unrestricted vocabulary, simple word-pair features on, all
 * context / punctuation / domain options off. ReadParameters() then applies
 * the key=value pairs from `line` (see SetParameter). Progress is reported on
 * std::cerr.
 */
WordTranslationFeature::WordTranslationFeature(const std::string &line)
  : StatelessFeatureFunction(0, line)
  , m_unrestricted(true)
  , m_simple(true)
  , m_sourceContext(false)
  , m_targetContext(false)
  , m_ignorePunctuation(false)
  , m_domainTrigger(false)
{
  std::cerr << "Initializing word translation feature.. " << endl;
  ReadParameters();

  // Echo every enabled feature template on the same progress line.
  if (m_simple)        std::cerr << "using simple word translations.. ";
  if (m_sourceContext) std::cerr << "using source context.. ";
  if (m_targetContext) std::cerr << "using target context.. ";
  if (m_domainTrigger) std::cerr << "using domain triggers.. ";

  if (m_ignorePunctuation) {
    std::cerr << "ignoring punctuation for triggers.. ";
    // Each byte of this literal is registered as a character that marks a
    // word as "punctuation" when it is the word's first byte (digits count
    // too). Assuming this file is saved as UTF-8, the non-ASCII symbols are
    // multi-byte, so it is their individual bytes that land in the table.
    static const char kPunctuation[] = "\"'!?¿·()#_,.:;•&@/\\0123456789~=";
    for (const char *c = kPunctuation; *c != '\0'; ++c) {
      m_punctuationHash[*c] = 1;
    }
  }
  std::cerr << "done." << std::endl;

  // TODO not sure about this -- sparse-producer registration from the old
  // implementation has not been ported:
  /*
  if (weight[0] != 1) {
    AddSparseProducer(wordTranslationFeature);
    cerr << "wt sparse producer weight: " << weight[0] << endl;
    if (m_mira)
      m_metaFeatureProducer = new MetaFeatureProducer("wt");
  }
  if (m_parameter->GetParam("report-sparse-features").size() > 0) {
    wordTranslationFeature->SetSparseFeatureReporting();
  }
  */
}
/**
 * Applies one key=value option from the feature's configuration line.
 * Keys this feature does not recognise are forwarded to
 * StatelessFeatureFunction::SetParameter.
 */
void WordTranslationFeature::SetParameter(const std::string& key, const std::string& value)
{
  // Which factor of the source / target words is used to build features.
  if (key == "input-factor") {
    m_factorTypeSource = Scan<FactorType>(value);
    return;
  }
  if (key == "output-factor") {
    m_factorTypeTarget = Scan<FactorType>(value);
    return;
  }

  // Boolean switches selecting the feature templates.
  if (key == "simple") {
    m_simple = Scan<bool>(value);
    return;
  }
  if (key == "source-context") {
    m_sourceContext = Scan<bool>(value);
    return;
  }
  if (key == "target-context") {
    m_targetContext = Scan<bool>(value);
    return;
  }
  if (key == "ignore-punctuation") {
    m_ignorePunctuation = Scan<bool>(value);
    return;
  }
  if (key == "domain-trigger") {
    m_domainTrigger = Scan<bool>(value);
    return;
  }

  // Accepted but ignored. TODO: texttype is not used.
  if (key == "texttype") {
    return;
  }

  // Word-list files; they are read by Load().
  if (key == "source-path") {
    m_filePathSource = value;
    return;
  }
  if (key == "target-path") {
    m_filePathTarget = value;
    return;
  }

  StatelessFeatureFunction::SetParameter(key, value);
}
void WordTranslationFeature::Load()
{
// load word list for restricted feature set
if (m_filePathSource.empty()) {
return;
} //else if (tokens.size() == 8) {
cerr << "loading word translation word lists from " << m_filePathSource << " and " << m_filePathTarget << endl;
if (m_domainTrigger) {
// domain trigger terms for each input document
ifstream inFileSource(m_filePathSource.c_str());
UTIL_THROW_IF(!inFileSource, util::Exception, "could not open file " << m_filePathSource);
2013-05-29 21:16:15 +04:00
std::string line;
while (getline(inFileSource, line)) {
2013-05-29 21:16:15 +04:00
m_vocabDomain.resize(m_vocabDomain.size() + 1);
vector<string> termVector;
boost::split(termVector, line, boost::is_any_of("\t "));
for (size_t i=0; i < termVector.size(); ++i)
m_vocabDomain.back().insert(termVector[i]);
}
2013-05-29 21:16:15 +04:00
inFileSource.close();
2013-05-29 21:16:15 +04:00
} else {
// restricted source word vocabulary
ifstream inFileSource(m_filePathSource.c_str());
UTIL_THROW_IF(!inFileSource, util::Exception, "could not open file " << m_filePathSource);
2013-05-29 21:16:15 +04:00
std::string line;
while (getline(inFileSource, line)) {
m_vocabSource.insert(line);
}
2013-05-29 21:16:15 +04:00
inFileSource.close();
2013-05-29 21:16:15 +04:00
// restricted target word vocabulary
ifstream inFileTarget(m_filePathTarget.c_str());
UTIL_THROW_IF(!inFileTarget, util::Exception, "could not open file " << m_filePathTarget);
2013-05-29 21:16:15 +04:00
while (getline(inFileTarget, line)) {
m_vocabTarget.insert(line);
}
2013-05-29 21:16:15 +04:00
inFileTarget.close();
2013-05-29 21:16:15 +04:00
m_unrestricted = false;
}
}
void WordTranslationFeature::Evaluate
2013-08-23 17:25:25 +04:00
(const Hypothesis& hypo,
2013-05-29 21:16:15 +04:00
ScoreComponentCollection* accumulator) const
{
2013-08-23 17:25:25 +04:00
const Sentence& input = static_cast<const Sentence&>(hypo.GetInput());
const TranslationOption& transOpt = hypo.GetTranslationOption();
const TargetPhrase& targetPhrase = hypo.GetCurrTargetPhrase();
const AlignmentInfo &alignment = targetPhrase.GetAlignTerm();
// process aligned words
for (AlignmentInfo::const_iterator alignmentPoint = alignment.begin(); alignmentPoint != alignment.end(); alignmentPoint++) {
const Phrase& sourcePhrase = transOpt.GetInputPath().GetPhrase();
int sourceIndex = alignmentPoint->first;
int targetIndex = alignmentPoint->second;
Word ws = sourcePhrase.GetWord(sourceIndex);
if (m_factorTypeSource == 0 && ws.IsNonTerminal()) continue;
Word wt = targetPhrase.GetWord(targetIndex);
if (m_factorTypeSource == 0 && wt.IsNonTerminal()) continue;
StringPiece sourceWord = ws.GetFactor(m_factorTypeSource)->GetString();
StringPiece targetWord = wt.GetFactor(m_factorTypeTarget)->GetString();
if (m_ignorePunctuation) {
// check if source or target are punctuation
2013-04-29 21:46:48 +04:00
char firstChar = sourceWord[0];
CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
if(charIterator != m_punctuationHash.end())
2013-05-29 21:16:15 +04:00
continue;
2013-04-29 21:46:48 +04:00
firstChar = targetWord[0];
charIterator = m_punctuationHash.find( firstChar );
if(charIterator != m_punctuationHash.end())
continue;
}
2012-03-07 21:56:29 +04:00
if (!m_unrestricted) {
if (FindStringPiece(m_vocabSource, sourceWord) == m_vocabSource.end())
2013-05-29 21:16:15 +04:00
sourceWord = "OTHER";
if (FindStringPiece(m_vocabTarget, targetWord) == m_vocabTarget.end())
2013-05-29 21:16:15 +04:00
targetWord = "OTHER";
2012-03-07 21:56:29 +04:00
}
if (m_simple) {
// construct feature name
stringstream featureName;
featureName << m_description << "_";
featureName << sourceWord;
featureName << "~";
featureName << targetWord;
accumulator->SparsePlusEquals(featureName.str(), 1);
}
if (m_domainTrigger && !m_sourceContext) {
const bool use_topicid = input.GetUseTopicId();
const bool use_topicid_prob = input.GetUseTopicIdAndProb();
if (use_topicid || use_topicid_prob) {
2013-05-29 21:16:15 +04:00
if(use_topicid) {
// use topicid as trigger
const long topicid = input.GetTopicId();
stringstream feature;
feature << m_description << "_";
if (topicid == -1)
feature << "unk";
else
feature << topicid;
feature << "_";
feature << sourceWord;
feature << "~";
feature << targetWord;
accumulator->SparsePlusEquals(feature.str(), 1);
} else {
// use topic probabilities
const vector<string> &topicid_prob = *(input.GetTopicIdAndProb());
if (atol(topicid_prob[0].c_str()) == -1) {
stringstream feature;
feature << m_description << "_unk_";
feature << sourceWord;
feature << "~";
feature << targetWord;
accumulator->SparsePlusEquals(feature.str(), 1);
} else {
for (size_t i=0; i+1 < topicid_prob.size(); i+=2) {
stringstream feature;
feature << m_description << "_";
feature << topicid_prob[i];
feature << "_";
feature << sourceWord;
feature << "~";
feature << targetWord;
accumulator->SparsePlusEquals(feature.str(), atof((topicid_prob[i+1]).c_str()));
}
}
}
} else {
// range over domain trigger words (keywords)
const long docid = input.GetDocumentId();
for (boost::unordered_set<std::string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) {
string sourceTrigger = *p;
stringstream feature;
feature << m_description << "_";
feature << sourceTrigger;
feature << "_";
feature << sourceWord;
feature << "~";
feature << targetWord;
accumulator->SparsePlusEquals(feature.str(), 1);
}
}
}
if (m_sourceContext) {
2013-08-23 17:25:25 +04:00
size_t globalSourceIndex = hypo.GetTranslationOption().GetStartPos() + sourceIndex;
if (!m_domainTrigger && globalSourceIndex == 0) {
2013-05-29 21:16:15 +04:00
// add <s> trigger feature for source
stringstream feature;
feature << m_description << "_";
feature << "<s>,";
feature << sourceWord;
feature << "~";
feature << targetWord;
accumulator->SparsePlusEquals(feature.str(), 1);
}
// range over source words to get context
for(size_t contextIndex = 0; contextIndex < input.GetSize(); contextIndex++ ) {
2013-05-29 21:16:15 +04:00
if (contextIndex == globalSourceIndex) continue;
StringPiece sourceTrigger = input.GetWord(contextIndex).GetFactor(m_factorTypeSource)->GetString();
if (m_ignorePunctuation) {
// check if trigger is punctuation
char firstChar = sourceTrigger[0];
CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
if(charIterator != m_punctuationHash.end())
continue;
}
const long docid = input.GetDocumentId();
bool sourceTriggerExists = false;
if (m_domainTrigger)
sourceTriggerExists = FindStringPiece(m_vocabDomain[docid], sourceTrigger ) != m_vocabDomain[docid].end();
else if (!m_unrestricted)
sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();
if (m_domainTrigger) {
if (sourceTriggerExists) {
stringstream feature;
feature << m_description << "_";
feature << sourceTrigger;
feature << "_";
feature << sourceWord;
feature << "~";
feature << targetWord;
accumulator->SparsePlusEquals(feature.str(), 1);
}
} else if (m_unrestricted || sourceTriggerExists) {
stringstream feature;
feature << m_description << "_";
if (contextIndex < globalSourceIndex) {
feature << sourceTrigger;
feature << ",";
feature << sourceWord;
} else {
feature << sourceWord;
feature << ",";
feature << sourceTrigger;
}
feature << "~";
feature << targetWord;
accumulator->SparsePlusEquals(feature.str(), 1);
}
}
}
if (m_targetContext) {
throw runtime_error("Can't use target words outside current translation option in a stateless feature");
/*
2013-05-29 21:16:15 +04:00
size_t globalTargetIndex = cur_hypo.GetCurrTargetWordsRange().GetStartPos() + targetIndex;
if (globalTargetIndex == 0) {
// add <s> trigger feature for source
stringstream feature;
feature << "wt_";
feature << sourceWord;
feature << "~";
feature << "<s>,";
feature << targetWord;
accumulator->SparsePlusEquals(feature.str(), 1);
}
// range over target words (up to current position) to get context
for(size_t contextIndex = 0; contextIndex < globalTargetIndex; contextIndex++ ) {
string targetTrigger = cur_hypo.GetWord(contextIndex).GetFactor(m_factorTypeTarget)->GetString();
if (m_ignorePunctuation) {
// check if trigger is punctuation
char firstChar = targetTrigger.at(0);
CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
if(charIterator != m_punctuationHash.end())
continue;
}
bool targetTriggerExists = false;
if (!m_unrestricted)
targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();
if (m_unrestricted || targetTriggerExists) {
stringstream feature;
feature << "wt_";
feature << sourceWord;
feature << "~";
feature << targetTrigger;
feature << ",";
feature << targetWord;
accumulator->SparsePlusEquals(feature.str(), 1);
}
}*/
}
}
}
void WordTranslationFeature::EvaluateChart(
2013-08-23 18:00:47 +04:00
const ChartHypothesis &hypo,
2013-05-29 21:16:15 +04:00
ScoreComponentCollection* accumulator) const
{
2013-08-09 21:17:18 +04:00
UTIL_THROW(util::Exception, "Need source phrase. Can't be arsed at the moment");
}
bool WordTranslationFeature::IsUseable(const FactorMask &mask) const
{
2013-05-30 15:51:40 +04:00
bool ret = mask[m_factorTypeTarget];
return ret;
}
}