mosesdecoder/moses/FF/TargetNgramFeature.h
2015-12-11 01:09:22 +00:00

240 lines
7.3 KiB
C++

#ifndef moses_TargetNgramFeature_h
#define moses_TargetNgramFeature_h
#include <string>
#include <map>
#include <boost/unordered_set.hpp>
#include "StatefulFeatureFunction.h"
#include "moses/FF/FFState.h"
#include "moses/Word.h"
#include "moses/FactorCollection.h"
#include "moses/LM/SingleFactor.h"
#include "moses/ChartHypothesis.h"
#include "moses/ChartManager.h"
#include "util/string_stream.hh"
namespace Moses
{
/** Feature-function state that stores a sequence of words.
 *
 * hash() and operator==() are only declared here; their definitions live
 * outside this header.
 */
class TargetNgramState : public FFState
{
public:
  /** Create a state with no words. */
  TargetNgramState() {}

  /** Create a state holding a copy of \p words. */
  TargetNgramState(const std::vector<Word> &words): m_words(words) {}

  /** Read-only access to the stored words.
   *
   * Returned by const reference so that a call does not copy the whole
   * vector; the reference stays valid for as long as this state is alive.
   */
  const std::vector<Word> &GetWords() const {
    return m_words;
  }

  size_t hash() const;
  virtual bool operator==(const FFState& other) const;

private:
  std::vector<Word> m_words;
};
/** Feature-function state for a chart hypothesis.
 *
 * Stores up to (order - 1) words from the start (prefix) and from the end
 * (suffix) of the target string built by a chart hypothesis, the number of
 * target terminals, and the source span / input size the hypothesis was
 * created for.
 */
class TargetNgramChartState : public FFState
{
private:
  Phrase m_contextPrefix, m_contextSuffix;
  size_t m_numTargetTerminals; // This isn't really correct except for the surviving hypothesis
  size_t m_startPos, m_endPos, m_inputSize;

  /** Construct the prefix string of up to specified size
   * \param hypo hypothesis whose target-side rule is walked left to right
   * \param featureId index used to fetch this feature's state from sub-hypotheses
   * \param ret prefix string (words are appended)
   * \param size maximum size (typically max lm context window)
   * \return how many more words could still have been added
   */
  size_t CalcPrefix(const ChartHypothesis &hypo, const int featureId, Phrase &ret, size_t size) const {
    // Nothing was requested. Without this guard the unsigned decrement below
    // would wrap around and the loop would collect the whole string.
    if (size == 0)
      return 0;

    const TargetPhrase &target = hypo.GetCurrTargetPhrase();
    const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
      target.GetAlignNonTerm().GetNonTermIndexMap();

    // loop over the rule that is being applied
    for (size_t pos = 0; pos < target.GetSize(); ++pos) {
      const Word &word = target.GetWord(pos);

      if (word.IsNonTerminal()) {
        // for non-terminals, retrieve the words from the underlying hypothesis
        const size_t nonTermInd = nonTermIndexMap[pos];
        const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
        const TargetNgramChartState *prevState =
          static_cast<const TargetNgramChartState*>(prevHypo->GetFFState(featureId));
        size = prevState->CalcPrefix(*prevHypo, featureId, ret, size);
      } else {
        // for words, add word
        ret.AddWord(word);
        --size;
      }

      // finish when maximum length reached
      if (size == 0)
        break;
    }
    return size;
  }

  /** Construct the suffix phrase of up to specified size
   * will always be called after the construction of prefix phrase
   * \param hypo hypothesis whose target-side rule is walked right to left
   * \param featureId index used to fetch this feature's state from sub-hypotheses
   * \param ret suffix phrase (words are prepended)
   * \param size maximum size of suffix
   * \return how many more words could still have been added
   */
  size_t CalcSuffix(const ChartHypothesis &hypo, int featureId, Phrase &ret, size_t size) const {
    const size_t prefixSize = m_contextPrefix.GetSize();
    assert(prefixSize <= m_numTargetTerminals);

    // Nothing was requested; also protects the unsigned decrement below.
    if (size == 0)
      return 0;

    // special handling for small hypotheses
    // does the prefix match the entire hypothesis string? -> just copy prefix
    if (prefixSize == m_numTargetTerminals) {
      const size_t maxCount = std::min(prefixSize, size);
      // take the last maxCount words of the prefix, walking backwards
      for (size_t ind = 0; ind < maxCount; ++ind) {
        ret.PrependWord(m_contextPrefix.GetWord(prefixSize - 1 - ind));
      }
      return size - maxCount;
    }

    // construct suffix analogous to prefix
    // (bind by reference: copying the whole target phrase is unnecessary)
    const TargetPhrase &targetPhrase = hypo.GetCurrTargetPhrase();
    // Non-terminal positions are looked up in the non-terminal alignment,
    // exactly as CalcPrefix does.
    const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
      targetPhrase.GetAlignNonTerm().GetNonTermIndexMap();

    // walk the rule from its last symbol to its first
    for (size_t pos = targetPhrase.GetSize(); pos-- > 0; ) {
      const Word &word = targetPhrase.GetWord(pos);

      if (word.IsNonTerminal()) {
        const size_t nonTermInd = nonTermIndexMap[pos];
        const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
        const TargetNgramChartState *prevState =
          static_cast<const TargetNgramChartState*>(prevHypo->GetFFState(featureId));
        size = prevState->CalcSuffix(*prevHypo, featureId, ret, size);
      } else {
        ret.PrependWord(word);
        --size;
      }

      if (size == 0)
        break;
    }
    return size;
  }

public:
  /** Build the state for \p hypo.
   * \param hypo hypothesis being scored
   * \param featureId index used to fetch this feature's state from sub-hypotheses
   * \param order prefix and suffix each hold at most (order - 1) words
   */
  TargetNgramChartState(const ChartHypothesis &hypo, int featureId, size_t order)
    :m_contextPrefix(order - 1),
     m_contextSuffix(order - 1) {
    m_numTargetTerminals = hypo.GetCurrTargetPhrase().GetNumTerminals();

    const Range &range = hypo.GetCurrSourceRange();
    m_startPos = range.GetStartPos();
    m_endPos = range.GetEndPos();
    m_inputSize = hypo.GetManager().GetSource().GetSize();

    const std::vector<const ChartHypothesis*> &prevHypos = hypo.GetPrevHypos();
    for (std::vector<const ChartHypothesis*>::const_iterator i = prevHypos.begin(); i != prevHypos.end(); ++i) {
      // keep count of words (= length of generated string)
      m_numTargetTerminals += static_cast<const TargetNgramChartState*>((*i)->GetFFState(featureId))->GetNumTargetTerminals();
    }

    // the suffix computation reads m_contextPrefix, so the prefix must come first
    CalcPrefix(hypo, featureId, m_contextPrefix, order - 1);
    CalcSuffix(hypo, featureId, m_contextSuffix, order - 1);
  }

  /** Number of target terminals counted for this hypothesis and its sub-hypotheses. */
  size_t GetNumTargetTerminals() const {
    return m_numTargetTerminals;
  }

  /** Leading context words collected by CalcPrefix. */
  const Phrase &GetPrefix() const {
    return m_contextPrefix;
  }

  /** Trailing context words collected by CalcSuffix. */
  const Phrase &GetSuffix() const {
    return m_contextSuffix;
  }

  /** Hash over the source span, the input size and, where applicable, the
   * prefix and suffix. Kept in agreement with operator==.
   */
  size_t hash() const {
    size_t ret = m_startPos;
    boost::hash_combine(ret, m_endPos);
    boost::hash_combine(ret, m_inputSize);

    // prefix
    if (m_startPos > 0) { // not for "<s> ..."
      boost::hash_combine(ret, hash_value(GetPrefix()));
    }

    // suffix
    if (m_endPos < m_inputSize - 1) { // not for "... </s>"
      boost::hash_combine(ret, hash_value(GetSuffix()));
    }

    return ret;
  }

  /** Equality used together with hash().
   * \param o state to compare with; it is static_cast to TargetNgramChartState
   */
  virtual bool operator==(const FFState& o) const {
    const TargetNgramChartState &other =
      static_cast<const TargetNgramChartState &>( o );

    // hash() mixes in the span and the input size, so two states may only be
    // equal when they agree on them; this also keeps the comparison symmetric.
    if (m_startPos != other.m_startPos ||
        m_endPos != other.m_endPos ||
        m_inputSize != other.m_inputSize)
      return false;

    // prefix
    if (m_startPos > 0) { // not for "<s> ..."
      if (GetPrefix() != other.GetPrefix())
        return false;
    }

    // suffix
    if (m_endPos < m_inputSize - 1) { // not for "... </s>"
      if (GetSuffix() != other.GetSuffix())
        return false;
    }

    return true;
  }
};
/** Sets the features of observed ngrams.
*
* Stateful feature function. This header only declares the interface; every
* member function below is defined outside this file. Evaluation is declared
* for both Hypothesis and ChartHypothesis.
*/
class TargetNgramFeature : public StatefulFeatureFunction
{
public:
// Construct from a configuration line.
// NOTE(review): the format of `line` is not visible in this header.
TargetNgramFeature(const std::string &line);
// NOTE(review): presumably reads m_file into m_vocab - confirm in the implementation.
void Load(AllOptions::ptr const& opts);
bool IsUseable(const FactorMask &mask) const;
// State to start from for the given input.
virtual const FFState* EmptyHypothesisState(const InputType &input) const;
// Evaluation for a Hypothesis: receives the previous state and an
// accumulator for scores, and returns a state pointer.
virtual FFState* EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state,
ScoreComponentCollection* accumulator) const;
// Evaluation for a ChartHypothesis: featureId is the index used to fetch
// this feature's state from hypotheses (see TargetNgramChartState).
virtual FFState* EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureId,
ScoreComponentCollection* accumulator) const;
// Set one configuration parameter given as a key/value pair of strings.
void SetParameter(const std::string& key, const std::string& value);
private:
// NOTE(review): the meanings below are inferred from member names only -
// confirm against the implementation file.
FactorType m_factorType; // presumably the factor the n-grams are built from
Word m_bos; // presumably the beginning-of-sentence word
boost::unordered_set<std::string> m_vocab; // presumably the words allowed in n-grams
size_t m_n; // presumably the n-gram order
bool m_lower_ngrams; // presumably whether lower-order n-grams are produced too
std::string m_file; // presumably the path of a vocabulary file
std::string m_baseName; // presumably a prefix for feature names
// Helper taking one word, a skip flag (in/out) and the n-gram string being built.
void appendNgram(const Word& word, bool& skip, util::StringStream& ngram) const;
// Helpers operating on a window of word pointers and a score accumulator;
// numberOfStartPos / numberOfEndPos default to 1 and offset defaults to 0.
void MakePrefixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator,
size_t numberOfStartPos = 1, size_t offset = 0) const;
void MakeSuffixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator,
size_t numberOfEndPos = 1, size_t offset = 0) const;
};
}
#endif // moses_TargetNgramFeature_h