mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-24 04:12:47 +03:00
parse chart compression for faster CYK+ parsing with syntax systems.
This commit is contained in:
parent
2d73f6f803
commit
2a46e8ccea
@ -124,6 +124,15 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
const ChartCellLabel *Find(size_t idx) const {
|
||||
try {
|
||||
return m_map.at(idx);
|
||||
}
|
||||
catch (const std::out_of_range& oor) {
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
ChartCellLabel::Stack &FindOrInsert(const Word &w) {
|
||||
size_t idx = w[0]->GetId();
|
||||
if (! ChartCellExists(idx)) {
|
||||
|
@ -50,6 +50,19 @@ protected:
|
||||
StackVec m_stackVec;
|
||||
};
|
||||
|
||||
// struct that caches cellLabel, its end position and score for quicker lookup
|
||||
struct ChartCellCache
|
||||
{
|
||||
ChartCellCache(size_t endPos, const ChartCellLabel* cellLabel, float score)
|
||||
: endPos(endPos)
|
||||
, cellLabel(cellLabel)
|
||||
, score(score) {}
|
||||
|
||||
size_t endPos;
|
||||
const ChartCellLabel* cellLabel;
|
||||
float score;
|
||||
};
|
||||
|
||||
} // namespace Moses
|
||||
|
||||
#endif
|
||||
|
@ -22,10 +22,12 @@
|
||||
|
||||
#include "moses/ChartParser.h"
|
||||
#include "moses/InputType.h"
|
||||
#include "moses/Terminal.h"
|
||||
#include "moses/ChartParserCallback.h"
|
||||
#include "moses/StaticData.h"
|
||||
#include "moses/NonTerminal.h"
|
||||
#include "moses/ChartCellCollection.h"
|
||||
#include "moses/FactorCollection.h"
|
||||
#include "moses/TranslationModel/PhraseDictionaryMemory.h"
|
||||
|
||||
using namespace std;
|
||||
@ -59,9 +61,13 @@ void ChartRuleLookupManagerMemory::GetChartRuleCollection(
|
||||
|
||||
m_lastPos = lastPos;
|
||||
m_stackVec.clear();
|
||||
m_stackScores.clear();
|
||||
m_outColl = &outColl;
|
||||
m_unaryPos = absEndPos-1; // rules ending in this position are unary and should not be added to collection
|
||||
|
||||
// create/update data structure to quickly look up all chart cells that match start position and label.
|
||||
UpdateCompressedMatrix(startPos, absEndPos, lastPos);
|
||||
|
||||
const PhraseDictionaryNodeMemory &rootNode = m_ruleTable.GetRootNode();
|
||||
|
||||
// size-1 terminal rules
|
||||
@ -77,7 +83,7 @@ void ChartRuleLookupManagerMemory::GetChartRuleCollection(
|
||||
}
|
||||
// all rules starting with nonterminal
|
||||
else if (absEndPos > startPos) {
|
||||
GetNonTerminalExtension(&rootNode, startPos, absEndPos-1);
|
||||
GetNonTerminalExtension(&rootNode, startPos);
|
||||
// all (non-unary) rules starting with terminal
|
||||
if (absEndPos == startPos+1) {
|
||||
GetTerminalExtension(&rootNode, absEndPos-1);
|
||||
@ -94,21 +100,87 @@ void ChartRuleLookupManagerMemory::GetChartRuleCollection(
|
||||
|
||||
}
|
||||
|
||||
// Create/update compressed matrix that stores all valid ChartCellLabels for a given start position and label.
|
||||
void ChartRuleLookupManagerMemory::UpdateCompressedMatrix(size_t startPos,
|
||||
size_t origEndPos,
|
||||
size_t lastPos) {
|
||||
|
||||
std::vector<size_t> endPosVec;
|
||||
size_t numNonTerms = FactorCollection::Instance().GetNumNonTerminals();
|
||||
m_compressedMatrixVec.resize(lastPos+1);
|
||||
|
||||
// we only need to update cell at [startPos, origEndPos-1] for initial lookup
|
||||
if (startPos < origEndPos) {
|
||||
endPosVec.push_back(origEndPos-1);
|
||||
}
|
||||
|
||||
// update all cells starting from startPos+1 for lookup of rule extensions
|
||||
else if (startPos == origEndPos)
|
||||
{
|
||||
startPos++;
|
||||
for (size_t endPos = startPos; endPos <= lastPos; endPos++) {
|
||||
endPosVec.push_back(endPos);
|
||||
}
|
||||
//re-use data structure for cells with later start position, but remove chart cells that would break max-chart-span
|
||||
for (size_t pos = startPos+1; pos <= lastPos; pos++) {
|
||||
CompressedMatrix & cellMatrix = m_compressedMatrixVec[pos];
|
||||
cellMatrix.resize(numNonTerms);
|
||||
for (size_t i = 0; i < numNonTerms; i++) {
|
||||
if (!cellMatrix[i].empty() && cellMatrix[i].back().endPos > lastPos) {
|
||||
cellMatrix[i].pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (startPos > lastPos) {
|
||||
return;
|
||||
}
|
||||
|
||||
// populate compressed matrix with all chart cells that start at current start position
|
||||
CompressedMatrix & cellMatrix = m_compressedMatrixVec[startPos];
|
||||
cellMatrix.clear();
|
||||
cellMatrix.resize(numNonTerms);
|
||||
for (std::vector<size_t>::iterator p = endPosVec.begin(); p != endPosVec.end(); ++p) {
|
||||
|
||||
size_t endPos = *p;
|
||||
// target non-terminal labels for the span
|
||||
const ChartCellLabelSet &targetNonTerms = GetTargetLabelSet(startPos, endPos);
|
||||
|
||||
if (targetNonTerms.GetSize() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
#if !defined(UNLABELLED_SOURCE)
|
||||
// source non-terminal labels for the span
|
||||
const InputPath &inputPath = GetParser().GetInputPath(startPos, endPos);
|
||||
const std::vector<bool> &sourceNonTermArray = inputPath.GetNonTerminalArray();
|
||||
|
||||
// can this ever be true? Moses seems to pad the non-terminal set of the input with [X]
|
||||
if (inputPath.GetNonTerminalSet().size() == 0) {
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
|
||||
for (size_t i = 0; i < numNonTerms; i++) {
|
||||
const ChartCellLabel *cellLabel = targetNonTerms.Find(i);
|
||||
if (cellLabel != NULL) {
|
||||
float score = cellLabel->GetBestScore(m_outColl);
|
||||
cellMatrix[i].push_back(ChartCellCache(endPos, cellLabel, score));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if a (partial) rule matches, add it to list completed rules (if non-unary and non-empty), and try find expansions that have this partial rule as prefix.
|
||||
void ChartRuleLookupManagerMemory::AddAndExtend(
|
||||
const PhraseDictionaryNodeMemory *node,
|
||||
size_t endPos,
|
||||
const ChartCellLabel *cellLabel) {
|
||||
|
||||
// add backpointer
|
||||
if (cellLabel != NULL) {
|
||||
m_stackVec.push_back(cellLabel);
|
||||
}
|
||||
size_t endPos) {
|
||||
|
||||
const TargetPhraseCollection &tpc = node->GetTargetPhraseCollection();
|
||||
// add target phrase collection (except if rule is empty or unary)
|
||||
if (!tpc.IsEmpty() && endPos != m_unaryPos) {
|
||||
m_completedRules[endPos].Add(tpc, m_stackVec, *m_outColl);
|
||||
m_completedRules[endPos].Add(tpc, m_stackVec, m_stackScores, *m_outColl);
|
||||
}
|
||||
|
||||
// get all further extensions of rule (until reaching end of sentence or max-chart-span)
|
||||
@ -117,18 +189,12 @@ void ChartRuleLookupManagerMemory::AddAndExtend(
|
||||
GetTerminalExtension(node, endPos+1);
|
||||
}
|
||||
if (!node->GetNonTerminalMap().empty()) {
|
||||
for (size_t newEndPos = endPos+1; newEndPos <= m_lastPos; newEndPos++) {
|
||||
GetNonTerminalExtension(node, endPos+1, newEndPos);
|
||||
}
|
||||
GetNonTerminalExtension(node, endPos+1);
|
||||
}
|
||||
}
|
||||
|
||||
// remove backpointer
|
||||
if (cellLabel != NULL) {
|
||||
m_stackVec.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// search all possible terminal extensions of a partial rule (pointed at by node) at a given position
|
||||
// recursively try to expand partial rules into full rules up to m_lastPos.
|
||||
void ChartRuleLookupManagerMemory::GetTerminalExtension(
|
||||
@ -142,9 +208,10 @@ void ChartRuleLookupManagerMemory::GetTerminalExtension(
|
||||
if (terminals.size() < 5) {
|
||||
for (PhraseDictionaryNodeMemory::TerminalMap::const_iterator iter = terminals.begin(); iter != terminals.end(); ++iter) {
|
||||
const Word & word = iter->first;
|
||||
if (word == sourceWord) {
|
||||
if (TerminalEqualityPred()(word, sourceWord)) {
|
||||
const PhraseDictionaryNodeMemory *child = & iter->second;
|
||||
AddAndExtend(child, pos, NULL);
|
||||
AddAndExtend(child, pos);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -152,39 +219,26 @@ void ChartRuleLookupManagerMemory::GetTerminalExtension(
|
||||
else {
|
||||
const PhraseDictionaryNodeMemory *child = node->GetChild(sourceWord);
|
||||
if (child != NULL) {
|
||||
AddAndExtend(child, pos, NULL);
|
||||
AddAndExtend(child, pos);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// search all nonterminal possible nonterminal extensions of a partial rule (pointed at by node) for a given span (StartPos, endPos).
|
||||
// search all nonterminal possible nonterminal extensions of a partial rule (pointed at by node) for a variable span (starting from startPos).
|
||||
// recursively try to expand partial rules into full rules up to m_lastPos.
|
||||
void ChartRuleLookupManagerMemory::GetNonTerminalExtension(
|
||||
const PhraseDictionaryNodeMemory *node,
|
||||
size_t startPos,
|
||||
size_t endPos) {
|
||||
size_t startPos) {
|
||||
|
||||
// target non-terminal labels for the span
|
||||
const ChartCellLabelSet &targetNonTerms = GetTargetLabelSet(startPos, endPos);
|
||||
|
||||
if (targetNonTerms.GetSize() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
#if !defined(UNLABELLED_SOURCE)
|
||||
// source non-terminal labels for the span
|
||||
const InputPath &inputPath = GetParser().GetInputPath(startPos, endPos);
|
||||
const std::vector<bool> &sourceNonTermArray = inputPath.GetNonTerminalArray();
|
||||
|
||||
// can this ever be true? Moses seems to pad the non-terminal set of the input with [X]
|
||||
if (inputPath.GetNonTerminalSet().size() == 0) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
const CompressedMatrix &compressedMatrix = m_compressedMatrixVec[startPos];
|
||||
|
||||
// non-terminal labels in phrase dictionary node
|
||||
const PhraseDictionaryNodeMemory::NonTerminalMap & nonTermMap = node->GetNonTerminalMap();
|
||||
|
||||
// make room for back pointer
|
||||
m_stackVec.push_back(NULL);
|
||||
m_stackScores.push_back(0);
|
||||
|
||||
// loop over possible expansions of the rule
|
||||
PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator p;
|
||||
PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator end = nonTermMap.end();
|
||||
@ -193,37 +247,32 @@ void ChartRuleLookupManagerMemory::GetNonTerminalExtension(
|
||||
#if defined(UNLABELLED_SOURCE)
|
||||
const Word &targetNonTerm = p->first;
|
||||
#else
|
||||
const PhraseDictionaryNodeMemory::NonTerminalMapKey &key = p->first;
|
||||
const Word &sourceNonTerm = key.first;
|
||||
// check if source label matches
|
||||
if (! sourceNonTermArray[sourceNonTerm[0]->GetId()]) {
|
||||
continue;
|
||||
}
|
||||
const Word &targetNonTerm = key.second;
|
||||
const Word &targetNonTerm = p->first.second;
|
||||
#endif
|
||||
const PhraseDictionaryNodeMemory *child = &p->second;
|
||||
//soft matching of NTs
|
||||
if (m_isSoftMatching && !m_softMatchingMap[targetNonTerm[0]->GetId()].empty()) {
|
||||
const std::vector<Word>& softMatches = m_softMatchingMap[targetNonTerm[0]->GetId()];
|
||||
for (std::vector<Word>::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
|
||||
const ChartCellLabel *cellLabel = targetNonTerms.Find(*softMatch);
|
||||
if (cellLabel == NULL) {
|
||||
continue;
|
||||
const CompressedColumn &matches = compressedMatrix[(*softMatch)[0]->GetId()];
|
||||
for (CompressedColumn::const_iterator match = matches.begin(); match != matches.end(); ++match) {
|
||||
m_stackVec.back() = match->cellLabel;
|
||||
m_stackScores.back() = match->score;
|
||||
AddAndExtend(child, match->endPos);
|
||||
}
|
||||
// create new rule
|
||||
const PhraseDictionaryNodeMemory &child = p->second;
|
||||
AddAndExtend(&child, endPos, cellLabel);
|
||||
}
|
||||
} // end of soft matches lookup
|
||||
|
||||
const ChartCellLabel *cellLabel = targetNonTerms.Find(targetNonTerm);
|
||||
if (cellLabel == NULL) {
|
||||
continue;
|
||||
const CompressedColumn &matches = compressedMatrix[targetNonTerm[0]->GetId()];
|
||||
for (CompressedColumn::const_iterator match = matches.begin(); match != matches.end(); ++match) {
|
||||
m_stackVec.back() = match->cellLabel;
|
||||
m_stackScores.back() = match->score;
|
||||
AddAndExtend(child, match->endPos);
|
||||
}
|
||||
// create new rule
|
||||
const PhraseDictionaryNodeMemory &child = p->second;
|
||||
AddAndExtend(&child, endPos, cellLabel);
|
||||
}
|
||||
// remove last back pointer
|
||||
m_stackVec.pop_back();
|
||||
m_stackScores.pop_back();
|
||||
}
|
||||
|
||||
|
||||
} // namespace Moses
|
||||
|
@ -40,6 +40,10 @@ class WordsRange;
|
||||
class ChartRuleLookupManagerMemory : public ChartRuleLookupManagerCYKPlus
|
||||
{
|
||||
public:
|
||||
typedef std::vector<ChartCellCache> CompressedColumn;
|
||||
typedef std::vector<CompressedColumn> CompressedMatrix;
|
||||
|
||||
|
||||
ChartRuleLookupManagerMemory(const ChartParser &parser,
|
||||
const ChartCellCollectionBase &cellColl,
|
||||
const PhraseDictionaryMemory &ruleTable);
|
||||
@ -53,19 +57,21 @@ public:
|
||||
|
||||
private:
|
||||
|
||||
void GetTerminalExtension(
|
||||
void GetTerminalExtension(
|
||||
const PhraseDictionaryNodeMemory *node,
|
||||
size_t pos);
|
||||
|
||||
void GetNonTerminalExtension(
|
||||
void GetNonTerminalExtension(
|
||||
const PhraseDictionaryNodeMemory *node,
|
||||
size_t startPos,
|
||||
size_t endPos);
|
||||
size_t startPos);
|
||||
|
||||
void AddAndExtend(
|
||||
const PhraseDictionaryNodeMemory *node,
|
||||
size_t endPos);
|
||||
|
||||
void UpdateCompressedMatrix(size_t startPos,
|
||||
size_t endPos,
|
||||
const ChartCellLabel *cellLabel);
|
||||
size_t lastPos);
|
||||
|
||||
const PhraseDictionaryMemory &m_ruleTable;
|
||||
|
||||
@ -80,8 +86,13 @@ void GetNonTerminalExtension(
|
||||
size_t m_unaryPos;
|
||||
|
||||
StackVec m_stackVec;
|
||||
std::vector<float> m_stackScores;
|
||||
std::vector<const Word*> m_sourceWords;
|
||||
ChartParserCallback* m_outColl;
|
||||
|
||||
std::vector<CompressedMatrix> m_compressedMatrixVec;
|
||||
|
||||
|
||||
};
|
||||
|
||||
} // namespace Moses
|
||||
|
@ -22,10 +22,12 @@
|
||||
|
||||
#include "moses/ChartParser.h"
|
||||
#include "moses/InputType.h"
|
||||
#include "moses/Terminal.h"
|
||||
#include "moses/ChartParserCallback.h"
|
||||
#include "moses/StaticData.h"
|
||||
#include "moses/NonTerminal.h"
|
||||
#include "moses/ChartCellCollection.h"
|
||||
#include "moses/FactorCollection.h"
|
||||
#include "moses/TranslationModel/RuleTable/PhraseDictionaryFuzzyMatch.h"
|
||||
|
||||
using namespace std;
|
||||
@ -59,9 +61,13 @@ void ChartRuleLookupManagerMemoryPerSentence::GetChartRuleCollection(
|
||||
|
||||
m_lastPos = lastPos;
|
||||
m_stackVec.clear();
|
||||
m_stackScores.clear();
|
||||
m_outColl = &outColl;
|
||||
m_unaryPos = absEndPos-1; // rules ending in this position are unary and should not be added to collection
|
||||
|
||||
// create/update data structure to quickly look up all chart cells that match start position and label.
|
||||
UpdateCompressedMatrix(startPos, absEndPos, lastPos);
|
||||
|
||||
const PhraseDictionaryNodeMemory &rootNode = m_ruleTable.GetRootNode(GetParser().GetTranslationId());
|
||||
|
||||
// size-1 terminal rules
|
||||
@ -77,7 +83,7 @@ void ChartRuleLookupManagerMemoryPerSentence::GetChartRuleCollection(
|
||||
}
|
||||
// all rules starting with nonterminal
|
||||
else if (absEndPos > startPos) {
|
||||
GetNonTerminalExtension(&rootNode, startPos, absEndPos-1);
|
||||
GetNonTerminalExtension(&rootNode, startPos);
|
||||
// all (non-unary) rules starting with terminal
|
||||
if (absEndPos == startPos+1) {
|
||||
GetTerminalExtension(&rootNode, absEndPos-1);
|
||||
@ -94,21 +100,87 @@ void ChartRuleLookupManagerMemoryPerSentence::GetChartRuleCollection(
|
||||
|
||||
}
|
||||
|
||||
// Create/update compressed matrix that stores all valid ChartCellLabels for a given start position and label.
|
||||
void ChartRuleLookupManagerMemoryPerSentence::UpdateCompressedMatrix(size_t startPos,
|
||||
size_t origEndPos,
|
||||
size_t lastPos) {
|
||||
|
||||
std::vector<size_t> endPosVec;
|
||||
size_t numNonTerms = FactorCollection::Instance().GetNumNonTerminals();
|
||||
m_compressedMatrixVec.resize(lastPos+1);
|
||||
|
||||
// we only need to update cell at [startPos, origEndPos-1] for initial lookup
|
||||
if (startPos < origEndPos) {
|
||||
endPosVec.push_back(origEndPos-1);
|
||||
}
|
||||
|
||||
// update all cells starting from startPos+1 for lookup of rule extensions
|
||||
else if (startPos == origEndPos)
|
||||
{
|
||||
startPos++;
|
||||
for (size_t endPos = startPos; endPos <= lastPos; endPos++) {
|
||||
endPosVec.push_back(endPos);
|
||||
}
|
||||
//re-use data structure for cells with later start position, but remove chart cells that would break max-chart-span
|
||||
for (size_t pos = startPos+1; pos <= lastPos; pos++) {
|
||||
CompressedMatrix & cellMatrix = m_compressedMatrixVec[pos];
|
||||
cellMatrix.resize(numNonTerms);
|
||||
for (size_t i = 0; i < numNonTerms; i++) {
|
||||
if (!cellMatrix[i].empty() && cellMatrix[i].back().endPos > lastPos) {
|
||||
cellMatrix[i].pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (startPos > lastPos) {
|
||||
return;
|
||||
}
|
||||
|
||||
// populate compressed matrix with all chart cells that start at current start position
|
||||
CompressedMatrix & cellMatrix = m_compressedMatrixVec[startPos];
|
||||
cellMatrix.clear();
|
||||
cellMatrix.resize(numNonTerms);
|
||||
for (std::vector<size_t>::iterator p = endPosVec.begin(); p != endPosVec.end(); ++p) {
|
||||
|
||||
size_t endPos = *p;
|
||||
// target non-terminal labels for the span
|
||||
const ChartCellLabelSet &targetNonTerms = GetTargetLabelSet(startPos, endPos);
|
||||
|
||||
if (targetNonTerms.GetSize() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
#if !defined(UNLABELLED_SOURCE)
|
||||
// source non-terminal labels for the span
|
||||
const InputPath &inputPath = GetParser().GetInputPath(startPos, endPos);
|
||||
const std::vector<bool> &sourceNonTermArray = inputPath.GetNonTerminalArray();
|
||||
|
||||
// can this ever be true? Moses seems to pad the non-terminal set of the input with [X]
|
||||
if (inputPath.GetNonTerminalSet().size() == 0) {
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
|
||||
for (size_t i = 0; i < numNonTerms; i++) {
|
||||
const ChartCellLabel *cellLabel = targetNonTerms.Find(i);
|
||||
if (cellLabel != NULL) {
|
||||
float score = cellLabel->GetBestScore(m_outColl);
|
||||
cellMatrix[i].push_back(ChartCellCache(endPos, cellLabel, score));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if a (partial) rule matches, add it to list completed rules (if non-unary and non-empty), and try find expansions that have this partial rule as prefix.
|
||||
void ChartRuleLookupManagerMemoryPerSentence::AddAndExtend(
|
||||
const PhraseDictionaryNodeMemory *node,
|
||||
size_t endPos,
|
||||
const ChartCellLabel *cellLabel) {
|
||||
|
||||
// add backpointer
|
||||
if (cellLabel != NULL) {
|
||||
m_stackVec.push_back(cellLabel);
|
||||
}
|
||||
size_t endPos) {
|
||||
|
||||
const TargetPhraseCollection &tpc = node->GetTargetPhraseCollection();
|
||||
// add target phrase collection (except if rule is empty or unary)
|
||||
if (!tpc.IsEmpty() && endPos != m_unaryPos) {
|
||||
m_completedRules[endPos].Add(tpc, m_stackVec, *m_outColl);
|
||||
m_completedRules[endPos].Add(tpc, m_stackVec, m_stackScores, *m_outColl);
|
||||
}
|
||||
|
||||
// get all further extensions of rule (until reaching end of sentence or max-chart-span)
|
||||
@ -117,18 +189,12 @@ void ChartRuleLookupManagerMemoryPerSentence::AddAndExtend(
|
||||
GetTerminalExtension(node, endPos+1);
|
||||
}
|
||||
if (!node->GetNonTerminalMap().empty()) {
|
||||
for (size_t newEndPos = endPos+1; newEndPos <= m_lastPos; newEndPos++) {
|
||||
GetNonTerminalExtension(node, endPos+1, newEndPos);
|
||||
}
|
||||
GetNonTerminalExtension(node, endPos+1);
|
||||
}
|
||||
}
|
||||
|
||||
// remove backpointer
|
||||
if (cellLabel != NULL) {
|
||||
m_stackVec.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// search all possible terminal extensions of a partial rule (pointed at by node) at a given position
|
||||
// recursively try to expand partial rules into full rules up to m_lastPos.
|
||||
void ChartRuleLookupManagerMemoryPerSentence::GetTerminalExtension(
|
||||
@ -142,9 +208,10 @@ void ChartRuleLookupManagerMemoryPerSentence::GetTerminalExtension(
|
||||
if (terminals.size() < 5) {
|
||||
for (PhraseDictionaryNodeMemory::TerminalMap::const_iterator iter = terminals.begin(); iter != terminals.end(); ++iter) {
|
||||
const Word & word = iter->first;
|
||||
if (word == sourceWord) {
|
||||
if (TerminalEqualityPred()(word, sourceWord)) {
|
||||
const PhraseDictionaryNodeMemory *child = & iter->second;
|
||||
AddAndExtend(child, pos, NULL);
|
||||
AddAndExtend(child, pos);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -152,39 +219,26 @@ void ChartRuleLookupManagerMemoryPerSentence::GetTerminalExtension(
|
||||
else {
|
||||
const PhraseDictionaryNodeMemory *child = node->GetChild(sourceWord);
|
||||
if (child != NULL) {
|
||||
AddAndExtend(child, pos, NULL);
|
||||
AddAndExtend(child, pos);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// search all nonterminal possible nonterminal extensions of a partial rule (pointed at by node) for a given span (StartPos, endPos).
|
||||
// search all nonterminal possible nonterminal extensions of a partial rule (pointed at by node) for a variable span (starting from startPos).
|
||||
// recursively try to expand partial rules into full rules up to m_lastPos.
|
||||
void ChartRuleLookupManagerMemoryPerSentence::GetNonTerminalExtension(
|
||||
const PhraseDictionaryNodeMemory *node,
|
||||
size_t startPos,
|
||||
size_t endPos) {
|
||||
size_t startPos) {
|
||||
|
||||
// target non-terminal labels for the span
|
||||
const ChartCellLabelSet &targetNonTerms = GetTargetLabelSet(startPos, endPos);
|
||||
|
||||
if (targetNonTerms.GetSize() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
#if !defined(UNLABELLED_SOURCE)
|
||||
// source non-terminal labels for the span
|
||||
const InputPath &inputPath = GetParser().GetInputPath(startPos, endPos);
|
||||
const std::vector<bool> &sourceNonTermArray = inputPath.GetNonTerminalArray();
|
||||
|
||||
// can this ever be true? Moses seems to pad the non-terminal set of the input with [X]
|
||||
if (inputPath.GetNonTerminalSet().size() == 0) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
const CompressedMatrix &compressedMatrix = m_compressedMatrixVec[startPos];
|
||||
|
||||
// non-terminal labels in phrase dictionary node
|
||||
const PhraseDictionaryNodeMemory::NonTerminalMap & nonTermMap = node->GetNonTerminalMap();
|
||||
|
||||
// make room for back pointer
|
||||
m_stackVec.push_back(NULL);
|
||||
m_stackScores.push_back(0);
|
||||
|
||||
// loop over possible expansions of the rule
|
||||
PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator p;
|
||||
PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator end = nonTermMap.end();
|
||||
@ -193,38 +247,32 @@ void ChartRuleLookupManagerMemoryPerSentence::GetNonTerminalExtension(
|
||||
#if defined(UNLABELLED_SOURCE)
|
||||
const Word &targetNonTerm = p->first;
|
||||
#else
|
||||
const PhraseDictionaryNodeMemory::NonTerminalMapKey &key = p->first;
|
||||
const Word &sourceNonTerm = key.first;
|
||||
// check if source label matches
|
||||
if (! sourceNonTermArray[sourceNonTerm[0]->GetId()]) {
|
||||
continue;
|
||||
}
|
||||
const Word &targetNonTerm = key.second;
|
||||
const Word &targetNonTerm = p->first.second;
|
||||
#endif
|
||||
|
||||
const PhraseDictionaryNodeMemory *child = &p->second;
|
||||
//soft matching of NTs
|
||||
if (m_isSoftMatching && !m_softMatchingMap[targetNonTerm[0]->GetId()].empty()) {
|
||||
const std::vector<Word>& softMatches = m_softMatchingMap[targetNonTerm[0]->GetId()];
|
||||
for (std::vector<Word>::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
|
||||
const ChartCellLabel *cellLabel = targetNonTerms.Find(*softMatch);
|
||||
if (cellLabel == NULL) {
|
||||
continue;
|
||||
const CompressedColumn &matches = compressedMatrix[(*softMatch)[0]->GetId()];
|
||||
for (CompressedColumn::const_iterator match = matches.begin(); match != matches.end(); ++match) {
|
||||
m_stackVec.back() = match->cellLabel;
|
||||
m_stackScores.back() = match->score;
|
||||
AddAndExtend(child, match->endPos);
|
||||
}
|
||||
// create new rule
|
||||
const PhraseDictionaryNodeMemory &child = p->second;
|
||||
AddAndExtend(&child, endPos, cellLabel);
|
||||
}
|
||||
} // end of soft matches lookup
|
||||
|
||||
const ChartCellLabel *cellLabel = targetNonTerms.Find(targetNonTerm);
|
||||
if (cellLabel == NULL) {
|
||||
continue;
|
||||
const CompressedColumn &matches = compressedMatrix[targetNonTerm[0]->GetId()];
|
||||
for (CompressedColumn::const_iterator match = matches.begin(); match != matches.end(); ++match) {
|
||||
m_stackVec.back() = match->cellLabel;
|
||||
m_stackScores.back() = match->score;
|
||||
AddAndExtend(child, match->endPos);
|
||||
}
|
||||
// create new rule
|
||||
const PhraseDictionaryNodeMemory &child = p->second;
|
||||
AddAndExtend(&child, endPos, cellLabel);
|
||||
}
|
||||
// remove last back pointer
|
||||
m_stackVec.pop_back();
|
||||
m_stackScores.pop_back();
|
||||
}
|
||||
|
||||
|
||||
} // namespace Moses
|
||||
|
@ -40,6 +40,9 @@ class WordsRange;
|
||||
class ChartRuleLookupManagerMemoryPerSentence : public ChartRuleLookupManagerCYKPlus
|
||||
{
|
||||
public:
|
||||
typedef std::vector<ChartCellCache> CompressedColumn;
|
||||
typedef std::vector<CompressedColumn> CompressedMatrix;
|
||||
|
||||
ChartRuleLookupManagerMemoryPerSentence(const ChartParser &parser,
|
||||
const ChartCellCollectionBase &cellColl,
|
||||
const PhraseDictionaryFuzzyMatch &ruleTable);
|
||||
@ -53,19 +56,21 @@ public:
|
||||
|
||||
private:
|
||||
|
||||
void GetTerminalExtension(
|
||||
void GetTerminalExtension(
|
||||
const PhraseDictionaryNodeMemory *node,
|
||||
size_t pos);
|
||||
|
||||
void GetNonTerminalExtension(
|
||||
void GetNonTerminalExtension(
|
||||
const PhraseDictionaryNodeMemory *node,
|
||||
size_t startPos,
|
||||
size_t endPos);
|
||||
size_t startPos);
|
||||
|
||||
void AddAndExtend(
|
||||
const PhraseDictionaryNodeMemory *node,
|
||||
size_t endPos);
|
||||
|
||||
void UpdateCompressedMatrix(size_t startPos,
|
||||
size_t endPos,
|
||||
const ChartCellLabel *cellLabel);
|
||||
size_t lastPos);
|
||||
|
||||
const PhraseDictionaryFuzzyMatch &m_ruleTable;
|
||||
|
||||
@ -80,8 +85,12 @@ void GetNonTerminalExtension(
|
||||
size_t m_unaryPos;
|
||||
|
||||
StackVec m_stackVec;
|
||||
std::vector<float> m_stackScores;
|
||||
std::vector<const Word*> m_sourceWords;
|
||||
ChartParserCallback* m_outColl;
|
||||
|
||||
std::vector<CompressedMatrix> m_compressedMatrixVec;
|
||||
|
||||
};
|
||||
|
||||
} // namespace Moses
|
||||
|
@ -77,4 +77,47 @@ void CompletedRuleCollection::Add(const TargetPhraseCollection &tpc,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// copies some functionality (pruning) from ChartTranslationOptionList::Add
|
||||
void CompletedRuleCollection::Add(const TargetPhraseCollection &tpc,
|
||||
const StackVec &stackVec,
|
||||
const std::vector<float> &stackScores,
|
||||
const ChartParserCallback &outColl)
|
||||
{
|
||||
if (tpc.IsEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const TargetPhrase &targetPhrase = **(tpc.begin());
|
||||
float score = std::accumulate(stackScores.begin(), stackScores.end(), targetPhrase.GetFutureScore());
|
||||
|
||||
// If the rule limit has already been reached then don't add the option
|
||||
// unless it is better than at least one existing option.
|
||||
if (m_collection.size() > m_ruleLimit && score < m_scoreThreshold) {
|
||||
return;
|
||||
}
|
||||
|
||||
CompletedRule *completedRule = new CompletedRule(tpc, stackVec, score);
|
||||
m_collection.push_back(completedRule);
|
||||
|
||||
// If the rule limit hasn't been exceeded then update the threshold.
|
||||
if (m_collection.size() <= m_ruleLimit) {
|
||||
m_scoreThreshold = (score < m_scoreThreshold) ? score : m_scoreThreshold;
|
||||
}
|
||||
|
||||
// Prune if bursting
|
||||
if (m_collection.size() == m_ruleLimit * 2) {
|
||||
NTH_ELEMENT4(m_collection.begin(),
|
||||
m_collection.begin() + m_ruleLimit - 1,
|
||||
m_collection.end(),
|
||||
CompletedRuleOrdered());
|
||||
m_scoreThreshold = m_collection[m_ruleLimit-1]->GetScoreEstimate();
|
||||
for (size_t i = 0 + m_ruleLimit; i < m_collection.size(); i++) {
|
||||
delete m_collection[i];
|
||||
|
||||
}
|
||||
m_collection.resize(m_ruleLimit);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -22,6 +22,7 @@
|
||||
#define moses_CompletedRuleCollectionS_h
|
||||
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
|
||||
#include "moses/StackVec.h"
|
||||
#include "moses/TargetPhraseCollection.h"
|
||||
@ -105,6 +106,11 @@ public:
|
||||
const StackVec &stackVec,
|
||||
const ChartParserCallback &outColl);
|
||||
|
||||
void Add(const TargetPhraseCollection &tpc,
|
||||
const StackVec &stackVec,
|
||||
const std::vector<float> &stackScores,
|
||||
const ChartParserCallback &outColl);
|
||||
|
||||
private:
|
||||
std::vector<CompletedRule*> m_collection;
|
||||
float m_scoreThreshold;
|
||||
|
Loading…
Reference in New Issue
Block a user