Abstract ChartCellCollection, extract source labels

This commit is contained in:
Kenneth Heafield 2012-10-02 16:34:25 +01:00
parent a9524f079f
commit 542502e6fd
5 changed files with 86 additions and 57 deletions

View File

@ -41,6 +41,8 @@ ChartCellBase::ChartCellBase(size_t startPos, size_t endPos) :
m_coverage(startPos, endPos), m_coverage(startPos, endPos),
m_targetLabelSet(m_coverage) {} m_targetLabelSet(m_coverage) {}
ChartCellBase::~ChartCellBase() {}
/** Constructor /** Constructor
* \param startPos endPos range of this cell * \param startPos endPos range of this cell
* \param manager pointer back to the manager * \param manager pointer back to the manager
@ -49,12 +51,10 @@ ChartCell::ChartCell(size_t startPos, size_t endPos, ChartManager &manager) :
ChartCellBase(startPos, endPos), m_manager(manager) { ChartCellBase(startPos, endPos), m_manager(manager) {
const StaticData &staticData = StaticData::Instance(); const StaticData &staticData = StaticData::Instance();
m_nBestIsEnabled = staticData.IsNBestEnabled(); m_nBestIsEnabled = staticData.IsNBestEnabled();
if (startPos == endPos) {
const Word &sourceWord = manager.GetSource().GetWord(startPos);
m_sourceWordLabel.reset(new ChartCellLabel(m_coverage, sourceWord));
}
} }
ChartCell::~ChartCell() {}
/** Add the given hypothesis to the cell. /** Add the given hypothesis to the cell.
* Returns true if added, false if not. Maybe it already exists in the collection or score falls below threshold etc. * Returns true if added, false if not. Maybe it already exists in the collection or score falls below threshold etc.
* This function just calls the correspondind AddHypothesis() in ChartHypothesisCollection * This function just calls the correspondind AddHypothesis() in ChartHypothesisCollection

View File

@ -48,9 +48,19 @@ class ChartCellBase {
public: public:
ChartCellBase(size_t startPos, size_t endPos); ChartCellBase(size_t startPos, size_t endPos);
virtual ~ChartCellBase();
//! @todo what is a m_sourceWordLabel?
const ChartCellLabelSet &GetTargetLabelSet() const {
return m_targetLabelSet;
}
const WordsRange &GetCoverage() const {
return m_coverage;
}
protected: protected:
WordsRange m_coverage; WordsRange m_coverage;
boost::scoped_ptr<ChartCellLabel> m_sourceWordLabel;
ChartCellLabelSet m_targetLabelSet; ChartCellLabelSet m_targetLabelSet;
}; };
@ -97,17 +107,6 @@ public:
const ChartHypothesis *GetBestHypothesis() const; const ChartHypothesis *GetBestHypothesis() const;
//! @todo what is a m_sourceWordLabel?
const ChartCellLabel &GetSourceWordLabel() const {
CHECK(m_coverage.GetNumWordsCovered() == 1);
return *m_sourceWordLabel;
}
//! @todo what is a m_sourceWordLabel?
const ChartCellLabelSet &GetTargetLabelSet() const {
return m_targetLabelSet;
}
void CleanupArcList(); void CleanupArcList();
void OutputSizes(std::ostream &out) const; void OutputSizes(std::ostream &out) const;

View File

@ -23,37 +23,32 @@
#include "InputType.h" #include "InputType.h"
#include "WordsRange.h" #include "WordsRange.h"
namespace Moses namespace Moses {
{
ChartCellCollectionBase::~ChartCellCollectionBase() {
m_source.clear();
for (std::vector<std::vector<ChartCellBase*> >::iterator i = m_cells.begin(); i != m_cells.end(); ++i)
RemoveAllInColl(*i);
}
class CubeCellFactory {
public:
explicit CubeCellFactory(ChartManager &manager) : m_manager(manager) {}
ChartCell *operator()(size_t start, size_t end) const {
return new ChartCell(start, end, m_manager);
}
private:
ChartManager &m_manager;
};
/** Costructor /** Costructor
\param input the input sentence \param input the input sentence
\param manager reference back to the manager \param manager reference back to the manager
*/ */
ChartCellCollection::ChartCellCollection(const InputType &input, ChartManager &manager) ChartCellCollection::ChartCellCollection(const InputType &input, ChartManager &manager)
:m_hypoStackColl(input.GetSize()) :ChartCellCollectionBase(input, CubeCellFactory(manager)) {}
{
size_t size = input.GetSize();
for (size_t startPos = 0; startPos < size; ++startPos) {
InnerCollType &inner = m_hypoStackColl[startPos];
inner.resize(size - startPos);
size_t ind = 0;
for (size_t endPos = startPos ; endPos < size; ++endPos) {
ChartCell *cell = new ChartCell(startPos, endPos, manager);
inner[ind] = cell;
++ind;
}
}
}
ChartCellCollection::~ChartCellCollection()
{
OuterCollType::iterator iter;
for (iter = m_hypoStackColl.begin(); iter != m_hypoStackColl.end(); ++iter) {
InnerCollType &inner = *iter;
RemoveAllInColl(inner);
}
}
} // namespace } // namespace

View File

@ -20,37 +20,72 @@
***********************************************************************/ ***********************************************************************/
#pragma once #pragma once
#include "InputType.h"
#include "ChartCell.h" #include "ChartCell.h"
#include "WordsRange.h" #include "WordsRange.h"
#include <boost/ptr_container/ptr_vector.hpp>
namespace Moses namespace Moses
{ {
class InputType; class InputType;
class ChartManager; class ChartManager;
class ChartCellCollectionBase {
public:
template <class Factory> ChartCellCollectionBase(const InputType &input, const Factory &factory) :
m_cells(input.GetSize()) {
size_t size = input.GetSize();
for (size_t startPos = 0; startPos < size; ++startPos) {
std::vector<ChartCellBase*> &inner = m_cells[startPos];
inner.reserve(size - startPos);
for (size_t endPos = startPos; endPos < size; ++endPos) {
inner.push_back(factory(startPos, endPos));
}
/* Hack: ChartCellLabel shouldn't need to know its span, but the parser
* gets it from there :-(. The span is actually stored as a reference,
* which needs to point somewhere, so I have it refer to the ChartCell.
*/
m_source.push_back(new ChartCellLabel(inner[0]->GetCoverage(), input.GetWord(startPos)));
}
}
virtual ~ChartCellCollectionBase();
const ChartCellBase &GetBase(const WordsRange &coverage) const {
return *m_cells[coverage.GetStartPos()][coverage.GetEndPos() - coverage.GetStartPos()];
}
const ChartCellLabel &GetSourceWordLabel(size_t at) const {
return m_source[at];
}
protected:
// There's nothing mutable in ChartCellLabel base, so no public method.
ChartCellBase &MutableBase(const WordsRange &coverage) {
return *m_cells[coverage.GetStartPos()][coverage.GetEndPos() - coverage.GetStartPos()];
}
private:
std::vector<std::vector<ChartCellBase*> > m_cells;
boost::ptr_vector<ChartCellLabel> m_source;
};
/** Hold all the chart cells for 1 input sentence. A variable of this type is held by the ChartManager /** Hold all the chart cells for 1 input sentence. A variable of this type is held by the ChartManager
*/ */
class ChartCellCollection class ChartCellCollection : public ChartCellCollectionBase {
{ public:
public: ChartCellCollection(const InputType &input, ChartManager &manager);
typedef std::vector<ChartCell*> InnerCollType;
typedef std::vector<InnerCollType> OuterCollType;
protected:
OuterCollType m_hypoStackColl;
public:
ChartCellCollection(const InputType &input, ChartManager &manager);
~ChartCellCollection();
//! get a chart cell for a particular range //! get a chart cell for a particular range
ChartCell &Get(const WordsRange &coverage) { ChartCell &Get(const WordsRange &coverage) {
return *m_hypoStackColl[coverage.GetStartPos()][coverage.GetEndPos() - coverage.GetStartPos()]; return static_cast<ChartCell&>(MutableBase(coverage));
} }
//! get a chart cell for a particular range //! get a chart cell for a particular range
const ChartCell &Get(const WordsRange &coverage) const { const ChartCell &Get(const WordsRange &coverage) const {
return *m_hypoStackColl[coverage.GetStartPos()][coverage.GetEndPos() - coverage.GetStartPos()]; return static_cast<const ChartCell&>(GetBase(coverage));
} }
}; };

View File

@ -56,7 +56,7 @@ public:
} }
const ChartCellLabel &GetSourceAt(size_t at) const { const ChartCellLabel &GetSourceAt(size_t at) const {
return m_cellCollection.Get(WordsRange(at, at)).GetSourceWordLabel(); return m_cellCollection.GetSourceWordLabel(at);
} }
/** abstract function. Return a vector of translation options for given a range in the input sentence /** abstract function. Return a vector of translation options for given a range in the input sentence