This commit is contained in:
Ulrich Germann 2015-10-28 00:28:16 +00:00
commit 951bebb037
14 changed files with 87 additions and 73 deletions

View File

@ -129,7 +129,7 @@ public:
// Fallback: scoreA < scoreB == false, non-deterministic sort
return false;
}
return (phrA->Compare(*phrB) < 0);
return (phrA->Compare(*phrB) > 0);
}
}
};

View File

@ -151,7 +151,7 @@ EvaluateWhenApplied(StatefulFeatureFunction const& sfff, int state_idx)
// ttasksptr const& ttask = manager.GetTtask();
FFState const* prev = m_prevHypo ? m_prevHypo->m_ffStates[state_idx] : NULL;
m_ffStates[state_idx]
= sfff.EvaluateWhenApplied(*this, prev, &m_currScoreBreakdown);
= sfff.EvaluateWhenApplied(*this, prev, &m_currScoreBreakdown);
}
}

View File

@ -265,7 +265,7 @@ LanguageModelIRST::
CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const
{
bool isContextAdaptive
= m_lmtb->getLanguageModelType() == _IRSTLM_LMCONTEXTDEPENDENT;
= m_lmtb->getLanguageModelType() == _IRSTLM_LMCONTEXTDEPENDENT;
fullScore = 0;
ngramScore = 0;
@ -347,7 +347,7 @@ EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps,
ScoreComponentCollection *out) const
{
bool isContextAdaptive
= m_lmtb->getLanguageModelType() == _IRSTLM_LMCONTEXTDEPENDENT;
= m_lmtb->getLanguageModelType() == _IRSTLM_LMCONTEXTDEPENDENT;
if (!hypo.GetCurrTargetLength()) {
std::auto_ptr<IRSTLMState> ret(new IRSTLMState(ps));
@ -387,18 +387,17 @@ EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps,
position = (const int) begin+1;
float score;
#ifdef IRSTLM_CONTEXT_DEPENDENT
if (CW)
{
score = m_lmtb->clprob(codes,m_lmtb_size,*CW,NULL,NULL,&msp);
while (position < adjust_end) {
for (idx=1; idx<m_lmtb_size; idx++) {
codes[idx-1] = codes[idx];
}
codes[idx-1] = GetLmID(hypo.GetWord(position));
score += m_lmtb->clprob(codes,m_lmtb_size,*CW,NULL,NULL,&msp);
++position;
if (CW) {
score = m_lmtb->clprob(codes,m_lmtb_size,*CW,NULL,NULL,&msp);
while (position < adjust_end) {
for (idx=1; idx<m_lmtb_size; idx++) {
codes[idx-1] = codes[idx];
}
} else {
codes[idx-1] = GetLmID(hypo.GetWord(position));
score += m_lmtb->clprob(codes,m_lmtb_size,*CW,NULL,NULL,&msp);
++position;
}
} else {
#endif
score = m_lmtb->clprob(codes,m_lmtb_size,NULL,NULL,&msp);
position = (const int) begin+1;
@ -433,9 +432,9 @@ EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps,
if (CW) score += m_lmtb->clprob(codes,m_lmtb_size,*CW,NULL,NULL,&msp);
else
#else
score += m_lmtb->clprob(codes,m_lmtb_size,NULL,NULL,&msp);
score += m_lmtb->clprob(codes,m_lmtb_size,NULL,NULL,&msp);
#endif
} else {
} else {
// need to set the LM state
if (adjust_end < end) { //the LMstate of this target phrase refers to the last m_lmtb_size-1 words

View File

@ -44,7 +44,7 @@ public:
// Fallback: compare pointers, non-deterministic sort
return A < B;
}
return (phrA->Compare(*phrB) < 0);
return (phrA->Compare(*phrB) > 0);
}
}
};

View File

@ -287,8 +287,8 @@ void Manager<RuleMatcher>::RecombineAndSort(
// any 'duplicate' vertices are deleted.
// TODO Set?
typedef boost::unordered_map<SVertex *, SVertex *,
SVertexRecombinationHasher,
SVertexRecombinationEqualityPred> Map;
SVertexRecombinationHasher,
SVertexRecombinationEqualityPred> Map;
Map map;
for (std::vector<SHyperedge*>::const_iterator p = buffer.begin();
p != buffer.end(); ++p) {

View File

@ -351,8 +351,8 @@ void Manager<Parser>::RecombineAndSort(const std::vector<SHyperedge*> &buffer,
// any 'duplicate' vertices are deleted.
// TODO Set?
typedef boost::unordered_map<SVertex *, SVertex *,
SVertexRecombinationHasher,
SVertexRecombinationEqualityPred> Map;
SVertexRecombinationHasher,
SVertexRecombinationEqualityPred> Map;
Map map;
for (std::vector<SHyperedge*>::const_iterator p = buffer.begin();
p != buffer.end(); ++p) {

View File

@ -11,7 +11,7 @@ namespace Syntax
class SVertexRecombinationEqualityPred
{
public:
public:
bool operator()(const SVertex *v1, const SVertex *v2) const {
assert(v1->states.size() == v2->states.size());
for (std::size_t i = 0; i < v1->states.size(); ++i) {

View File

@ -11,7 +11,7 @@ namespace Syntax
class SVertexRecombinationHasher
{
public:
public:
std::size_t operator()(const SVertex *v) const {
std::size_t seed = 0;
for (std::vector<FFState*>::const_iterator p = v->states.begin();

View File

@ -247,8 +247,8 @@ void Manager<RuleMatcher>::RecombineAndSort(
// any 'duplicate' vertices are deleted.
// TODO Set?
typedef boost::unordered_map<SVertex *, SVertex *,
SVertexRecombinationHasher,
SVertexRecombinationEqualityPred> Map;
SVertexRecombinationHasher,
SVertexRecombinationEqualityPred> Map;
Map map;
for (std::vector<SHyperedge*>::const_iterator p = buffer.begin();
p != buffer.end(); ++p) {

View File

@ -33,7 +33,8 @@ namespace Moses
PhraseDictionaryGroup::PhraseDictionaryGroup(const string &line)
: PhraseDictionary(line, true),
m_numModels(0),
m_restrict(false)
m_restrict(false),
m_specifiedZeros(false)
{
ReadParameters();
}
@ -45,6 +46,9 @@ void PhraseDictionaryGroup::SetParameter(const string& key, const string& value)
m_numModels = m_memberPDStrs.size();
} else if (key == "restrict") {
m_restrict = Scan<bool>(value);
} else if (key == "zeros") {
m_specifiedZeros = true;
m_zeros = Scan<float>(Tokenize(value, ","));
} else {
PhraseDictionary::SetParameter(key, value);
}
@ -67,10 +71,20 @@ void PhraseDictionaryGroup::Load()
}
}
UTIL_THROW_IF2(!pdFound,
"Could not find component phrase table " << pdName);
"Could not find member phrase table " << pdName);
}
UTIL_THROW_IF2(componentWeights != m_numScoreComponents,
"Total number of component model scores is unequal to specified number of scores");
"Total number of member model scores is unequal to specified number of scores");
// Determine "zero" scores for features
if (m_specifiedZeros) {
UTIL_THROW_IF2(m_zeros.size() != m_numScoreComponents,
"Number of specified zeros is unequal to number of member model scores");
} else {
// Default is all 0 (as opposed to e.g. -99 or similar to approximate log(0)
// or a smoothed "not in model" score)
m_zeros = vector<float>(m_numScoreComponents, 0);
}
}
void PhraseDictionaryGroup::GetTargetPhraseCollectionBatch(
@ -150,7 +164,7 @@ CreateTargetPhraseCollection(const ttasksptr& ttask, const Phrase& src) const
phrase->GetScoreBreakdown().ZeroDenseFeatures(&pd);
// Add phrase entry
allPhrases.push_back(phrase);
allScores[targetPhrase] = vector<float>(m_numScoreComponents, 0);
allScores[targetPhrase] = vector<float>(m_zeros);
}
vector<float>& scores = allScores.find(targetPhrase)->second;

View File

@ -70,6 +70,8 @@ protected:
std::vector<PhraseDictionary*> m_memberPDs;
size_t m_numModels;
bool m_restrict;
bool m_specifiedZeros;
std::vector<float> m_zeros;
std::vector<FeatureFunction*> m_pdFeature;
typedef std::vector<TargetPhraseCollection::shared_ptr > PhraseCache;

View File

@ -173,14 +173,13 @@ interpret_dlt()
if (m_source->GetType() != SentenceInput) return;
Sentence const& snt = static_cast<Sentence const&>(*m_source);
typedef std::map<std::string,std::string> dltmap_t;
BOOST_FOREACH(dltmap_t const& M, snt.GetDltMeta())
{
dltmap_t::const_iterator i = M.find("type");
if (i == M.end() || i->second != "adaptive-lm") continue;
dltmap_t::const_iterator j = M.find("context-weights");
if (j == M.end()) continue;
SetContextWeights(j->second);
}
BOOST_FOREACH(dltmap_t const& M, snt.GetDltMeta()) {
dltmap_t::const_iterator i = M.find("type");
if (i == M.end() || i->second != "adaptive-lm") continue;
dltmap_t::const_iterator j = M.find("context-weights");
if (j == M.end()) continue;
SetContextWeights(j->second);
}
}