moses_chart: reduce memory use for rule lookup by decreasing the amount

of state information duplicated between CoveredChartSpan objects.

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4050 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
pjwilliams 2011-06-29 13:38:11 +00:00
parent 7fe3143feb
commit 7e288fae98
12 changed files with 209 additions and 74 deletions

View File

@ -34,7 +34,6 @@ class CellCollection
public:
virtual ~CellCollection()
{}
virtual const NonTerminalSet &GetConstituentLabelSet(const WordsRange &coverage) const = 0;
};
}

View File

@ -30,6 +30,7 @@
#include "StaticData.h"
#include "ChartTranslationOption.h"
#include "ChartTranslationOptionList.h"
#include "ChartManager.h"
using namespace std;
@ -39,10 +40,21 @@ extern bool g_debug;
/** Build the chart cell for the span given by startPos and endPos.
 *  The target label set is bound to this cell's own m_coverage. When the
 *  span is a single position (startPos == endPos) a label for the source
 *  word at that position is heap-allocated; otherwise m_sourceWordLabel
 *  remains NULL.
 *  \param startPos first source position of the span
 *  \param endPos last source position of the span
 *  \param manager chart manager the cell keeps a reference to; also used
 *         here to read the source word
 */
ChartCell::ChartCell(size_t startPos, size_t endPos, ChartManager &manager)
  : m_coverage(startPos, endPos)
  , m_sourceWordLabel(NULL)
  , m_targetLabelSet(m_coverage)
  , m_manager(manager)
{
  m_nBestIsEnabled = StaticData::Instance().IsNBestEnabled();

  // spans where start and end differ get no source word label
  if (startPos != endPos) {
    return;
  }

  const Word &wordAtPos = manager.GetSource().GetWord(startPos);
  m_sourceWordLabel = new ChartCellLabel(m_coverage, wordAtPos);
}
/** Free the source word label, if the constructor allocated one. */
ChartCell::~ChartCell()
{
// m_sourceWordLabel is initialised to NULL and only set when startPos == endPos;
// deleting a null pointer is a no-op, so no check is needed here.
delete m_sourceWordLabel;
}
/** Get all hypotheses in the cell that have the specified constituent label */
@ -105,11 +117,10 @@ void ChartCell::ProcessSentence(const ChartTranslationOptionList &transOptList
void ChartCell::SortHypotheses()
{
// sort each mini cells & fill up target lhs list
assert(m_constituentLabelSet.empty());
assert(m_targetLabelSet.Empty());
std::map<Word, ChartHypothesisCollection>::iterator iter;
for (iter = m_hypoColl.begin(); iter != m_hypoColl.end(); ++iter) {
m_constituentLabelSet.insert(iter->first);
m_targetLabelSet.Add(iter->first);
ChartHypothesisCollection &coll = iter->second;
coll.SortHypotheses();
}
@ -136,14 +147,6 @@ const ChartHypothesis *ChartCell::GetBestHypothesis() const
return ret;
}
/** Report whether the cell has a hypothesis collection stored under the
 *  given constituent label.
 */
bool ChartCell::ConstituentLabelExists(const Word &constituentLabel) const
{
  // m_hypoColl is keyed by label, so a successful find() answers the question
  return m_hypoColl.find(constituentLabel) != m_hypoColl.end();
}
void ChartCell::CleanupArcList()
{
// only necessary if n-best calculations are enabled

View File

@ -30,7 +30,9 @@
#include "NonTerminal.h"
#include "ChartHypothesis.h"
#include "ChartHypothesisCollection.h"
#include "CoveredChartSpan.h"
#include "RuleCube.h"
#include "ChartCellLabelSet.h"
namespace Moses
{
@ -46,15 +48,18 @@ public:
protected:
std::map<Word, ChartHypothesisCollection> m_hypoColl;
NonTerminalSet m_constituentLabelSet;
WordsRange m_coverage;
ChartCellLabel *m_sourceWordLabel;
ChartCellLabelSet m_targetLabelSet;
bool m_nBestIsEnabled; /**< flag to determine whether to keep track of old arcs */
ChartManager &m_manager;
public:
ChartCell(size_t startPos, size_t endPos, ChartManager &manager);
~ChartCell();
void ProcessSentence(const ChartTranslationOptionList &transOptList
,const ChartCellCollection &allChartCells);
@ -67,9 +72,13 @@ public:
const ChartHypothesis *GetBestHypothesis() const;
bool ConstituentLabelExists(const Word &constituentLabel) const;
const NonTerminalSet &GetConstituentLabelSet() const {
return m_constituentLabelSet;
/** Return the label held in m_sourceWordLabel.
 *  Must only be called on a cell whose coverage is exactly one word; this is
 *  asserted, and the pointer is dereferenced without a null check.
 */
const ChartCellLabel &GetSourceWordLabel() const {
assert(m_coverage.GetNumWordsCovered() == 1);
return *m_sourceWordLabel;
}
/** Read-only access to the cell's target label set (m_targetLabelSet). */
const ChartCellLabelSet &GetTargetLabelSet() const {
return m_targetLabelSet;
}
void CleanupArcList();

View File

@ -48,12 +48,6 @@ public:
/** Return the chart cell for the given span.
 *  Cells are indexed first by the span's start position and then by the
 *  offset of its end position from the start (end - start).
 */
const ChartCell &Get(const WordsRange &coverage) const {
return *m_hypoStackColl[coverage.GetStartPos()][coverage.GetEndPos() - coverage.GetStartPos()];
}
/** Constituent label set of the cell covering the given span */
const NonTerminalSet &GetConstituentLabelSet(const WordsRange &coverage) const {
  // delegate to the cell looked up for this span
  return Get(coverage).GetConstituentLabelSet();
}
};
}

View File

@ -0,0 +1,58 @@
/***********************************************************************
Moses - statistical machine translation system
Copyright (C) 2006-2011 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#pragma once
#if HAVE_CONFIG_H
#include "config.h"
#endif
#include "Word.h"
#include "WordsRange.h"
namespace Moses
{
class Word;
/** Pairs a Word label with the WordsRange it applies to.
 *
 *  Both parts are held by reference, not copied: whoever constructs a
 *  ChartCellLabel must keep the referenced WordsRange and Word alive for as
 *  long as the label (or any copy of it) is in use.
 */
class ChartCellLabel
{
public:
  /** Bind the label to an existing coverage and word (no copies are made). */
  ChartCellLabel(const WordsRange &coverage, const Word &label)
    : m_coverage(coverage)
    , m_label(label)
  {}

  /** Span this label applies to. */
  const WordsRange &GetCoverage() const {
    return m_coverage;
  }

  /** The word acting as label. */
  const Word &GetLabel() const {
    return m_label;
  }

  /** Order by coverage; labels with equal coverage are ordered by word. */
  bool operator<(const ChartCellLabel &other) const {
    const bool sameSpan = (m_coverage == other.m_coverage);
    return sameSpan ? (m_label < other.m_label)
                    : (m_coverage < other.m_coverage);
  }

private:
  const WordsRange &m_coverage;
  const Word &m_label;
};
}

View File

@ -0,0 +1,67 @@
/***********************************************************************
Moses - statistical machine translation system
Copyright (C) 2006-2011 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#pragma once
#if HAVE_CONFIG_H
#include "config.h"
#endif
#include "ChartCellLabel.h"
#include <set>
namespace Moses
{
class ChartCellLabelSet
{
private:
typedef std::set<ChartCellLabel> SetType;
public:
typedef SetType::const_iterator const_iterator;
ChartCellLabelSet(const WordsRange &coverage) : m_coverage(coverage) {}
const_iterator begin() const { return m_set.begin(); }
const_iterator end() const { return m_set.end(); }
void Add(const Word &w)
{
ChartCellLabel edge(m_coverage, w);
m_set.insert(edge);
}
bool Empty() const { return m_set.empty(); }
size_t GetSize() const { return m_set.size(); }
const ChartCellLabel *Find(const Word &w) const
{
SetType::const_iterator p = m_set.find(ChartCellLabel(m_coverage, w));
return p == m_set.end() ? 0 : &(*p);
}
private:
const WordsRange &m_coverage;
SetType m_set;
};
}

View File

@ -78,6 +78,8 @@ void ChartRuleLookupManagerMemory::GetChartRuleCollection(
DottedRuleColl &dottedRuleCol = *m_dottedRuleColls[range.GetStartPos()];
const DottedRuleList &expandableDottedRuleList = dottedRuleCol.GetExpandableDottedRuleList();
const ChartCellLabel &sourceWordLabel = GetCellCollection().Get(WordsRange(absEndPos, absEndPos)).GetSourceWordLabel();
// loop through the rules
// (note that expandableDottedRuleList can be expanded as the loop runs
// through calls to ExtendPartialRuleApplication())
@ -99,7 +101,7 @@ void ChartRuleLookupManagerMemory::GetChartRuleCollection(
// look up in rule dictionary, if the current rule can be extended
// with the source word in the last position
const Word &sourceWord = GetSentence().GetWord(absEndPos);
const Word &sourceWord = sourceWordLabel.GetLabel();
const PhraseDictionaryNodeSCFG *node = prevNode.GetChild(sourceWord);
// if we found a new rule -> create it and add it to the list
@ -107,14 +109,11 @@ void ChartRuleLookupManagerMemory::GetChartRuleCollection(
// create the rule
#ifdef USE_BOOST_POOL
CoveredChartSpan *newCoveredChartSpan = m_coveredChartSpanPool.malloc();
new (newCoveredChartSpan) CoveredChartSpan(absEndPos, absEndPos, sourceWord,
prevCoveredChartSpan);
new (newCoveredChartSpan) CoveredChartSpan(sourceWordLabel, prevCoveredChartSpan);
DottedRule *dottedRule = m_dottedRulePool.malloc();
new (dottedRule) DottedRule(*node, newCoveredChartSpan);
#else
CoveredChartSpan *newCoveredChartSpan = new CoveredChartSpan(absEndPos, absEndPos,
sourceWord,
prevCoveredChartSpan);
CoveredChartSpan *newCoveredChartSpan = new CoveredChartSpan(sourceWordLabel, prevCoveredChartSpan);
DottedRule *dottedRule = new DottedRule(*node,
newCoveredChartSpan);
#endif
@ -175,6 +174,9 @@ void ChartRuleLookupManagerMemory::GetChartRuleCollection(
GetCellCollection(), adhereTableLimit, rulesLimit);
}
}
dottedRuleCol.Clear(relEndPos+1);
outColl.CreateChartRules(rulesLimit);
}
@ -195,8 +197,8 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
GetSentence().GetLabelSet(startPos, endPos);
// target non-terminal labels for the remainder
const NonTerminalSet &targetNonTerms =
GetCellCollection().GetConstituentLabelSet(WordsRange(startPos, endPos));
const ChartCellLabelSet &targetNonTerms =
GetCellCollection().Get(WordsRange(startPos, endPos)).GetTargetLabelSet();
const PhraseDictionaryNodeSCFG::NonTerminalMap & nonTermMap =
node.GetNonTerminalMap();
@ -206,7 +208,7 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
return;
}
const size_t numSourceNonTerms = sourceNonTerms.size();
const size_t numTargetNonTerms = targetNonTerms.size();
const size_t numTargetNonTerms = targetNonTerms.GetSize();
const size_t numCombinations = numSourceNonTerms * numTargetNonTerms;
// We can search by either:
@ -225,14 +227,14 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
const Word & sourceNonTerm = *p;
// loop over possible target non-terminal labels (as found in chart)
NonTerminalSet::const_iterator q = targetNonTerms.begin();
NonTerminalSet::const_iterator tEnd = targetNonTerms.end();
ChartCellLabelSet::const_iterator q = targetNonTerms.begin();
ChartCellLabelSet::const_iterator tEnd = targetNonTerms.end();
for (; q != tEnd; ++q) {
const Word & targetNonTerm = *q;
const ChartCellLabel &cellLabel = *q;
// try to match both source and target non-terminal
const PhraseDictionaryNodeSCFG * child =
node.GetChild(sourceNonTerm, targetNonTerm);
node.GetChild(sourceNonTerm, cellLabel.GetLabel());
// nothing found? then we are done
if (child == NULL) {
@ -242,14 +244,11 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
// create new rule
#ifdef USE_BOOST_POOL
CoveredChartSpan *wc = m_coveredChartSpanPool.malloc();
new (wc) CoveredChartSpan(startPos, endPos, targetNonTerm,
prevCoveredChartSpan);
new (wc) CoveredChartSpan(cellLabel, prevCoveredChartSpan);
DottedRule *rule = m_dottedRulePool.malloc();
new (rule) DottedRule(*child, wc);
#else
CoveredChartSpan * wc = new CoveredChartSpan(startPos, endPos,
targetNonTerm,
prevCoveredChartSpan);
CoveredChartSpan * wc = new CoveredChartSpan(cellLabel, prevCoveredChartSpan);
DottedRule * rule = new DottedRule(*child, wc);
#endif
dottedRuleColl.Add(stackInd, rule);
@ -270,7 +269,8 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
continue;
}
const Word & targetNonTerm = key.second;
if (targetNonTerms.find(targetNonTerm) == targetNonTerms.end()) {
const ChartCellLabel *cellLabel = targetNonTerms.Find(targetNonTerm);
if (!cellLabel) {
continue;
}
@ -278,14 +278,11 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
const PhraseDictionaryNodeSCFG & child = p->second;
#ifdef USE_BOOST_POOL
CoveredChartSpan *wc = m_coveredChartSpanPool.malloc();
new (wc) CoveredChartSpan(startPos, endPos, targetNonTerm,
prevCoveredChartSpan);
new (wc) CoveredChartSpan(*cellLabel, prevCoveredChartSpan);
DottedRule *rule = m_dottedRulePool.malloc();
new (rule) DottedRule(child, wc);
#else
CoveredChartSpan * wc = new CoveredChartSpan(startPos, endPos,
targetNonTerm,
prevCoveredChartSpan);
CoveredChartSpan * wc = new CoveredChartSpan(*cellLabel, prevCoveredChartSpan);
DottedRule * rule = new DottedRule(child, wc);
#endif
dottedRuleColl.Add(stackInd, rule);

View File

@ -99,6 +99,8 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
const DottedRuleStackOnDisk::SavedNodeColl &savedNodeColl = expandableDottedRuleList.GetSavedNodeColl();
//cerr << "savedNodeColl=" << savedNodeColl.size() << " ";
const ChartCellLabel &sourceWordLabel = GetCellCollection().Get(WordsRange(absEndPos, absEndPos)).GetSourceWordLabel();
for (size_t ind = 0; ind < (savedNodeColl.size()) ; ++ind) {
const SavedNodeOnDisk &savedNode = *savedNodeColl[ind];
@ -109,8 +111,7 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
// search for terminal symbol
if (startPos == absEndPos) {
const Word &sourceWord = GetSentence().GetWord(absEndPos);
OnDiskPt::Word *sourceWordBerkeleyDb = m_dbWrapper.ConvertFromMoses(Input, m_inputFactorsVec, sourceWord);
OnDiskPt::Word *sourceWordBerkeleyDb = m_dbWrapper.ConvertFromMoses(Input, m_inputFactorsVec, sourceWordLabel.GetLabel());
if (sourceWordBerkeleyDb != NULL) {
const OnDiskPt::PhraseNode *node = prevNode.GetChild(*sourceWordBerkeleyDb, m_dbWrapper);
@ -118,9 +119,7 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
// TODO figure out why source word is needed from node, not from sentence
// prob to do with factors or non-term
//const Word &sourceWord = node->GetSourceWord();
CoveredChartSpan *newCoveredChartSpan = new CoveredChartSpan(absEndPos, absEndPos
, sourceWord
, prevCoveredChartSpan);
CoveredChartSpan *newCoveredChartSpan = new CoveredChartSpan(sourceWordLabel, prevCoveredChartSpan);
DottedRuleOnDisk *dottedRule = new DottedRuleOnDisk(*node, newCoveredChartSpan);
expandableDottedRuleList.Add(relEndPos+1, dottedRule);
@ -148,7 +147,8 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
// size_t nonTermNumWordsCovered = endPos - startPos + 1;
// get target nonterminals in this span from chart
const NonTerminalSet &chartNonTermSet = GetCellCollection().GetConstituentLabelSet(WordsRange(startPos, endPos));
const ChartCellLabelSet &chartNonTermSet =
GetCellCollection().Get(WordsRange(startPos, endPos)).GetTargetLabelSet();
//const Word &defaultSourceNonTerm = staticData.GetInputDefaultNonTerminal()
// ,&defaultTargetNonTerm = staticData.GetOutputDefaultNonTerminal();
@ -174,9 +174,9 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
continue; // didn't find source node
// go through each TARGET lhs
NonTerminalSet::const_iterator iterChartNonTerm;
ChartCellLabelSet::const_iterator iterChartNonTerm;
for (iterChartNonTerm = chartNonTermSet.begin(); iterChartNonTerm != chartNonTermSet.end(); ++iterChartNonTerm) {
const Word &chartNonTerm = *iterChartNonTerm;
const ChartCellLabel &cellLabel = *iterChartNonTerm;
//cerr << sourceLHS << " " << defaultSourceNonTerm << " " << chartNonTerm << " " << defaultTargetNonTerm << endl;
@ -186,7 +186,7 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
if (doSearch) {
OnDiskPt::Word *chartNonTermBerkeleyDb = m_dbWrapper.ConvertFromMoses(Output, m_outputFactorsVec, chartNonTerm);
OnDiskPt::Word *chartNonTermBerkeleyDb = m_dbWrapper.ConvertFromMoses(Output, m_outputFactorsVec, cellLabel.GetLabel());
if (chartNonTermBerkeleyDb == NULL)
continue;
@ -199,10 +199,7 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
// found matching entry
//const Word &sourceWord = node->GetSourceWord();
CoveredChartSpan *newCoveredChartSpan = new CoveredChartSpan(startPos, endPos
, chartNonTerm
, prevCoveredChartSpan);
CoveredChartSpan *newCoveredChartSpan = new CoveredChartSpan(cellLabel, prevCoveredChartSpan);
DottedRuleOnDisk *dottedRule = new DottedRuleOnDisk(*node, newCoveredChartSpan);
expandableDottedRuleList.Add(stackInd, dottedRule);

View File

@ -182,6 +182,9 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW
ChartTranslationOptionList &transOptColl = GetTranslationOptionList(sourcePos, sourcePos);
const WordsRange &range = transOptColl.GetSourceRange();
const ChartCell &chartCell = m_hypoStackColl.Get(range);
const ChartCellLabel &sourceWordLabel = chartCell.GetSourceWordLabel();
size_t isDigit = 0;
if (staticData.GetDropUnknown()) {
const Factor *f = sourceWord[0]; // TODO hack. shouldn't know which factor is surface
@ -204,7 +207,7 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW
std::vector<CoveredChartSpan*> *coveredChartSpanList = new std::vector<CoveredChartSpan*>();
m_coveredChartSpanCache.push_back(coveredChartSpanList);
CoveredChartSpan *wc = new CoveredChartSpan(sourcePos, sourcePos, sourceWord, NULL);
CoveredChartSpan *wc = new CoveredChartSpan(sourceWordLabel, NULL);
coveredChartSpanList->push_back(wc);
assert(coveredChartSpanList->size());
@ -273,7 +276,7 @@ void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceW
// words consumed
std::vector<CoveredChartSpan*> *coveredChartSpanList = new std::vector<CoveredChartSpan*>;
m_coveredChartSpanCache.push_back(coveredChartSpanList);
coveredChartSpanList->push_back(new CoveredChartSpan(sourcePos, sourcePos, sourceWord, NULL));
coveredChartSpanList->push_back(new CoveredChartSpan(sourceWordLabel, NULL));
// chart rule
assert(coveredChartSpanList->size());

View File

@ -26,7 +26,8 @@ namespace Moses
std::ostream& operator<<(std::ostream &out, const CoveredChartSpan &coveredChartSpan)
{
out << coveredChartSpan.m_coverage << "=" << coveredChartSpan.m_sourceWord << " ";
out << coveredChartSpan.GetWordsRange()
<< "=" << coveredChartSpan.GetSourceWord() << " ";
if (coveredChartSpan.m_prevCoveredChartSpan)
out << " " << *coveredChartSpan.m_prevCoveredChartSpan;

View File

@ -21,6 +21,7 @@
#pragma once
#include <iostream>
#include "ChartCellLabel.h"
#include "WordsRange.h"
#include "Word.h"
@ -32,27 +33,23 @@ class CoveredChartSpan
friend std::ostream& operator<<(std::ostream&, const CoveredChartSpan&);
protected:
WordsRange m_coverage;
const Word &m_sourceWord; // can be non-term or term
const ChartCellLabel &m_cellLabel;
const CoveredChartSpan *m_prevCoveredChartSpan;
public:
CoveredChartSpan(); // not implemented
CoveredChartSpan(size_t startPos, size_t endPos, const Word &sourceWord, const CoveredChartSpan *prevCoveredChartSpan)
:m_coverage(startPos, endPos)
,m_sourceWord(sourceWord)
CoveredChartSpan(const ChartCellLabel &cellLabel,
const CoveredChartSpan *prevCoveredChartSpan)
:m_cellLabel(cellLabel)
,m_prevCoveredChartSpan(prevCoveredChartSpan)
{}
const WordsRange &GetWordsRange() const {
return m_coverage;
return m_cellLabel.GetCoverage();
}
const Word &GetSourceWord() const {
return m_sourceWord;
}
WordsRange &GetWordsRange() {
return m_coverage;
return m_cellLabel.GetLabel();
}
bool IsNonTerminal() const {
return m_sourceWord.IsNonTerminal();
return m_cellLabel.GetLabel().IsNonTerminal();
}
const CoveredChartSpan *GetPrevCoveredChartSpan() const {
@ -64,7 +61,7 @@ public:
if (IsNonTerminal() < compare.IsNonTerminal())
return true;
else if (IsNonTerminal() == compare.IsNonTerminal())
return m_coverage < compare.m_coverage;
return m_cellLabel.GetCoverage() < compare.m_cellLabel.GetCoverage();
return false;
}

View File

@ -19,6 +19,10 @@
***********************************************************************/
#pragma once
#if HAVE_CONFIG_H
#include "config.h"
#endif
#include <vector>
#include <cassert>
#include "PhraseDictionaryNodeSCFG.h"
@ -112,6 +116,12 @@ public:
}
}
/** Empty the container stored at m_coll[pos].
 *  This only happens in builds that define USE_BOOST_POOL; in all other
 *  builds the call does nothing.
 *  NOTE(review): presumably the pool owns the elements in that configuration,
 *  so dropping them here does not leak -- confirm against the destructor.
 */
void Clear(size_t pos) {
#ifdef USE_BOOST_POOL
  m_coll[pos].clear();
#else
  (void)pos; // parameter is unused without the pool; avoids an unused-parameter warning
#endif
}
/** Read-only access to m_expandableDottedRuleList. */
const DottedRuleList &GetExpandableDottedRuleList() const {
return m_expandableDottedRuleList;
}