mosesdecoder/moses/FF/SparseHieroReorderingFeature.cpp

225 lines
8.3 KiB
C++
Raw Normal View History

2013-09-10 11:58:45 +04:00
#include <iostream>
2013-09-10 14:20:14 +04:00
#include "moses/ChartHypothesis.h"
#include "moses/ChartManager.h"
2013-09-13 19:52:42 +04:00
#include "moses/FactorCollection.h"
#include "moses/Sentence.h"
2013-09-10 14:20:14 +04:00
2013-09-12 21:55:10 +04:00
#include "util/exception.hh"
2015-10-03 14:17:02 +03:00
#include "util/string_stream.hh"
2013-09-12 21:55:10 +04:00
2013-09-14 20:22:43 +04:00
#include "SparseHieroReorderingFeature.h"
2013-09-10 11:58:45 +04:00
using namespace std;
namespace Moses
{
2013-09-14 20:22:43 +04:00
SparseHieroReorderingFeature::SparseHieroReorderingFeature(const std::string &line)
:StatelessFeatureFunction(0, line),
2015-01-14 14:07:42 +03:00
m_type(SourceCombined),
m_sourceFactor(0),
m_targetFactor(0),
m_sourceVocabFile(""),
m_targetVocabFile("")
2013-09-10 11:58:45 +04:00
{
2013-09-12 21:55:10 +04:00
/*
Configuration of features.
factor - Which factor should it apply to
type - what type of sparse reordering feature. e.g. block (modelled on Matthias
Huck's EAMT 2012 features)
word - which words to include, e.g. src_bdry, src_all, tgt_bdry , ...
vocab - vocab file to limit it to
orientation - e.g. lr, etc.
*/
2013-09-10 11:58:45 +04:00
cerr << "Constructing a Sparse Reordering feature" << endl;
2013-09-12 21:55:10 +04:00
ReadParameters();
2013-09-13 19:52:42 +04:00
m_otherFactor = FactorCollection::Instance().AddFactor("##OTHER##");
2013-09-12 21:55:10 +04:00
LoadVocabulary(m_sourceVocabFile, m_sourceVocab);
LoadVocabulary(m_targetVocabFile, m_targetVocab);
}
2015-01-14 14:07:42 +03:00
void SparseHieroReorderingFeature::SetParameter(const std::string& key, const std::string& value)
{
2013-09-12 21:55:10 +04:00
if (key == "input-factor") {
m_sourceFactor = Scan<FactorType>(value);
} else if (key == "output-factor") {
m_targetFactor = Scan<FactorType>(value);
} else if (key == "input-vocab-file") {
m_sourceVocabFile = value;
} else if (key == "output-vocab-file") {
m_targetVocabFile = value;
2013-09-13 19:52:42 +04:00
} else if (key == "type") {
if (value == "SourceCombined") {
m_type = SourceCombined;
} else if (value == "SourceLeft") {
m_type = SourceLeft;
} else if (value == "SourceRight") {
m_type = SourceRight;
} else {
UTIL_THROW(util::Exception, "Unknown sparse reordering type " << value);
}
2013-09-12 21:55:10 +04:00
} else {
FeatureFunction::SetParameter(key, value);
}
}
2013-09-14 20:22:43 +04:00
void SparseHieroReorderingFeature::LoadVocabulary(const std::string& filename, Vocab& vocab)
2013-09-12 21:55:10 +04:00
{
if (filename.empty()) return;
ifstream in(filename.c_str());
UTIL_THROW_IF(!in, util::Exception, "Unable to open vocab file: " << filename);
2013-09-13 19:52:42 +04:00
string line;
while(getline(in,line)) {
2015-01-14 14:07:42 +03:00
vocab.insert(FactorCollection::Instance().AddFactor(line));
2013-09-13 19:52:42 +04:00
}
in.close();
2013-09-10 11:58:45 +04:00
}
2015-01-14 14:07:42 +03:00
const Factor* SparseHieroReorderingFeature::GetFactor(const Word& word, const Vocab& vocab, FactorType factorType) const
{
2013-09-13 19:52:42 +04:00
const Factor* factor = word.GetFactor(factorType);
if (vocab.size() && vocab.find(factor) == vocab.end()) return m_otherFactor;
return factor;
}
void SparseHieroReorderingFeature::EvaluateWhenApplied(
2013-09-10 14:20:14 +04:00
const ChartHypothesis& cur_hypo ,
2013-09-10 11:58:45 +04:00
ScoreComponentCollection* accumulator) const
{
2013-09-10 14:20:14 +04:00
// get index map for underlying hypotheses
2013-09-13 17:44:30 +04:00
//const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
// cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap();
2015-01-14 14:07:42 +03:00
2013-09-12 21:55:10 +04:00
//The Huck features. For a rule with source side:
// abXcdXef
//We first have to split into blocks:
// ab X cd X ef
//Then we extract features based in the boundary words of the neighbouring blocks
2015-01-14 14:07:42 +03:00
//For the block pair, we use the right word of the left block, and the left
2013-09-12 21:55:10 +04:00
//word of the right block.
2015-01-14 14:07:42 +03:00
//Need to get blocks, and their alignment. Each block has a word range (on the
2013-09-13 11:48:44 +04:00
// on the source), a non-terminal flag, and a set of alignment points in the target phrase
2013-09-13 17:44:30 +04:00
//We need to be able to map source word position to target word position, as
//much as possible (don't need interior of non-terminals). The alignment info
2015-01-14 14:07:42 +03:00
//objects just give us the mappings between *rule* positions. So if we can
2013-09-13 17:44:30 +04:00
//map source word position to source rule position, and target rule position
//to target word position, then we can map right through.
size_t sourceStart = cur_hypo.GetCurrSourceRange().GetStartPos();
size_t sourceSize = cur_hypo.GetCurrSourceRange().GetNumWordsCovered();
2015-10-25 16:37:59 +03:00
vector<Range> sourceNTSpans;
2013-09-13 11:48:44 +04:00
for (size_t prevHypoId = 0; prevHypoId < cur_hypo.GetPrevHypos().size(); ++prevHypoId) {
sourceNTSpans.push_back(cur_hypo.GetPrevHypo(prevHypoId)->GetCurrSourceRange());
}
2013-09-13 17:44:30 +04:00
//put in source order. Is this necessary?
2015-01-14 14:07:42 +03:00
sort(sourceNTSpans.begin(), sourceNTSpans.end());
2013-09-13 17:44:30 +04:00
//cerr << "Source NTs: ";
//for (size_t i = 0; i < sourceNTSpans.size(); ++i) cerr << sourceNTSpans[i] << " ";
//cerr << endl;
2013-09-13 11:48:44 +04:00
2015-10-25 16:37:59 +03:00
typedef pair<Range,bool> Block;//flag indicates NT
2015-01-14 14:07:42 +03:00
vector<Block> sourceBlocks;
2013-09-13 17:44:30 +04:00
sourceBlocks.push_back(Block(cur_hypo.GetCurrSourceRange(),false));
2015-10-25 16:37:59 +03:00
for (vector<Range>::const_iterator i = sourceNTSpans.begin();
2015-01-14 14:07:42 +03:00
i != sourceNTSpans.end(); ++i) {
2015-10-25 16:37:59 +03:00
const Range& prevHypoRange = *i;
2013-09-13 17:44:30 +04:00
Block lastBlock = sourceBlocks.back();
sourceBlocks.pop_back();
2013-09-13 11:48:44 +04:00
//split this range into before NT, NT and after NT
2013-09-13 17:44:30 +04:00
if (prevHypoRange.GetStartPos() > lastBlock.first.GetStartPos()) {
2015-10-25 16:37:59 +03:00
sourceBlocks.push_back(Block(Range(lastBlock.first.GetStartPos(),prevHypoRange.GetStartPos()-1),false));
2013-09-13 11:48:44 +04:00
}
2013-09-13 17:44:30 +04:00
sourceBlocks.push_back(Block(prevHypoRange,true));
if (prevHypoRange.GetEndPos() < lastBlock.first.GetEndPos()) {
2015-10-25 16:37:59 +03:00
sourceBlocks.push_back(Block(Range(prevHypoRange.GetEndPos()+1,lastBlock.first.GetEndPos()), false));
2013-09-13 11:48:44 +04:00
}
}
2013-09-13 19:52:42 +04:00
/*
2013-09-13 17:44:30 +04:00
cerr << "Source Blocks: ";
for (size_t i = 0; i < sourceBlocks.size(); ++i) cerr << sourceBlocks[i].first << " "
<< (sourceBlocks[i].second ? "NT" : "T") << " ";
2013-09-13 11:48:44 +04:00
cerr << endl;
2013-09-13 19:52:42 +04:00
*/
2013-09-13 11:48:44 +04:00
2013-09-13 17:44:30 +04:00
//Mapping from source word to target rule position
vector<size_t> sourceWordToTargetRulePos(sourceSize);
map<size_t,size_t> alignMap;
alignMap.insert(
cur_hypo.GetCurrTargetPhrase().GetAlignTerm().begin(),
cur_hypo.GetCurrTargetPhrase().GetAlignTerm().end());
alignMap.insert(
cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().begin(),
cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().end());
//vector<size_t> alignMapTerm = cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm()
size_t sourceRulePos = 0;
//cerr << "SW->RP ";
2015-01-14 14:07:42 +03:00
for (vector<Block>::const_iterator sourceBlockIt = sourceBlocks.begin();
sourceBlockIt != sourceBlocks.end(); ++sourceBlockIt) {
2013-09-13 17:44:30 +04:00
for (size_t sourceWordPos = sourceBlockIt->first.GetStartPos();
2015-01-14 14:07:42 +03:00
sourceWordPos <= sourceBlockIt->first.GetEndPos(); ++sourceWordPos) {
2013-09-13 17:44:30 +04:00
sourceWordToTargetRulePos[sourceWordPos - sourceStart] = alignMap[sourceRulePos];
2015-01-14 14:07:42 +03:00
// cerr << sourceWordPos - sourceStart << "-" << alignMap[sourceRulePos] << " ";
2013-09-13 17:44:30 +04:00
if (! sourceBlockIt->second) {
//T
++sourceRulePos;
}
}
if ( sourceBlockIt->second) {
//NT
++sourceRulePos;
}
}
//cerr << endl;
//Iterate through block pairs
2015-01-14 14:07:42 +03:00
const Sentence& sentence =
2015-10-18 15:30:02 +03:00
static_cast<const Sentence&>(cur_hypo.GetManager().GetSource());
2013-09-13 17:44:30 +04:00
//const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();
for (size_t i = 0; i < sourceBlocks.size()-1; ++i) {
Block& leftSourceBlock = sourceBlocks[i];
Block& rightSourceBlock = sourceBlocks[i+1];
size_t sourceLeftBoundaryPos = leftSourceBlock.first.GetEndPos();
size_t sourceRightBoundaryPos = rightSourceBlock.first.GetStartPos();
const Word& sourceLeftBoundaryWord = sentence.GetWord(sourceLeftBoundaryPos);
const Word& sourceRightBoundaryWord = sentence.GetWord(sourceRightBoundaryPos);
sourceLeftBoundaryPos -= sourceStart;
sourceRightBoundaryPos -= sourceStart;
2015-01-14 14:07:42 +03:00
2013-09-13 17:44:30 +04:00
// Need to figure out where these map to on the target.
2015-01-14 14:07:42 +03:00
size_t targetLeftRulePos =
2013-09-13 17:44:30 +04:00
sourceWordToTargetRulePos[sourceLeftBoundaryPos];
2015-01-14 14:07:42 +03:00
size_t targetRightRulePos =
2013-09-13 17:44:30 +04:00
sourceWordToTargetRulePos[sourceRightBoundaryPos];
bool isMonotone = true;
if ((sourceLeftBoundaryPos < sourceRightBoundaryPos &&
2015-01-14 14:07:42 +03:00
targetLeftRulePos > targetRightRulePos) ||
((sourceLeftBoundaryPos > sourceRightBoundaryPos &&
targetLeftRulePos < targetRightRulePos))) {
2013-09-13 17:44:30 +04:00
isMonotone = false;
}
2015-10-03 14:17:02 +03:00
util::StringStream buf;
2013-09-13 21:21:52 +04:00
buf << "h_"; //sparse reordering, Huck
2013-09-13 19:52:42 +04:00
if (m_type == SourceLeft || m_type == SourceCombined) {
buf << GetFactor(sourceLeftBoundaryWord,m_sourceVocab,m_sourceFactor)->GetString();
buf << "_";
2013-09-12 01:10:23 +04:00
}
2013-09-13 19:52:42 +04:00
if (m_type == SourceRight || m_type == SourceCombined) {
2015-01-14 14:07:42 +03:00
buf << GetFactor(sourceRightBoundaryWord,m_sourceVocab,m_sourceFactor)->GetString();
2013-09-13 19:52:42 +04:00
buf << "_";
}
buf << (isMonotone ? "M" : "S");
2013-09-13 21:21:52 +04:00
accumulator->PlusEquals(this,buf.str(), 1);
2013-09-13 19:52:42 +04:00
}
// cerr << endl;
2013-09-10 11:58:45 +04:00
}
}