Add NonTerminalSet variable to InputPath

This commit is contained in:
Hieu Hoang 2013-08-02 15:54:49 +01:00
parent d1d07d5923
commit 0596c3e9e4
13 changed files with 72 additions and 22 deletions

View File

@ -189,16 +189,17 @@ void ChartParser::CreateInputPaths(const InputType &input)
size_t endPos = startPos + phaseSize -1;
vector<InputPath*> &vec = m_targetPhrasesfromPt[startPos];
Phrase subphrase(input.GetSubString(WordsRange(startPos, endPos)));
WordsRange range(startPos, endPos);
Phrase subphrase(input.GetSubString(WordsRange(startPos, endPos)));
const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);
InputPath *node;
if (range.GetNumWordsCovered() == 1) {
node = new InputPath(subphrase, range, NULL, NULL);
node = new InputPath(subphrase, labels, range, NULL, NULL);
vec.push_back(node);
} else {
const InputPath &prevNode = GetInputPath(startPos, endPos - 1);
node = new InputPath(subphrase, range, &prevNode, NULL);
node = new InputPath(subphrase, labels, range, &prevNode, NULL);
vec.push_back(node);
}
@ -207,6 +208,13 @@ void ChartParser::CreateInputPaths(const InputType &input)
}
}
//! Read-only lookup of the InputPath stored for the span [startPos, endPos].
//! Paths are grouped by start position; inside a group the path is indexed by
//! (endPos - startPos). The index is CHECKed against the group's size.
const InputPath &ChartParser::GetInputPath(size_t startPos, size_t endPos) const
{
  const size_t offset = endPos - startPos;
  CHECK(offset < m_targetPhrasesfromPt[startPos].size());

  const std::vector<InputPath*> &pathsFromStart = m_targetPhrasesfromPt[startPos];
  const InputPath *path = pathsFromStart[offset];
  return *path;
}
InputPath &ChartParser::GetInputPath(size_t startPos, size_t endPos)
{
size_t offset = endPos - startPos;
@ -219,4 +227,10 @@ const Sentence &ChartParser::GetSentence() const {
return sentence;
}
//! Size of the parser's input, forwarded from m_source.GetSize().
size_t ChartParser::GetSize() const
{
  const size_t inputSize = m_source.GetSize();
  return inputSize;
}
} // namespace Moses

View File

@ -63,6 +63,8 @@ public:
//! the sentence being decoded
const Sentence &GetSentence() const;
size_t GetSize() const;
const InputPath &GetInputPath(size_t startPos, size_t endPos) const;
private:
ChartParserUnknown m_unknown;

View File

@ -51,6 +51,8 @@ public:
return m_cellCollection.GetBase(WordsRange(begin, end)).GetTargetLabelSet();
}
//! read-only access to the parser held in m_parser
const ChartParser &GetParser() const {
  return m_parser;
}
const Sentence &GetSentence() const;
const ChartCellLabel &GetSourceAt(size_t at) const {

View File

@ -64,6 +64,11 @@ ConfusionNet::ConfusionNet()
: InputType()
{
stats.createOne();
const StaticData& staticData = StaticData::Instance();
if (staticData.IsChart()) {
m_defaultLabelSet.insert(StaticData::Instance().GetInputDefaultNonTerminal());
}
}
ConfusionNet::~ConfusionNet()
{

View File

@ -7,6 +7,7 @@
#include <iostream>
#include "Word.h"
#include "InputType.h"
#include "NonTerminal.h"
namespace Moses
{
@ -25,6 +26,7 @@ public:
protected:
std::vector<Column> data;
NonTerminalSet m_defaultLabelSet;
bool ReadFormat0(std::istream&,const std::vector<FactorType>& factorOrder);
bool ReadFormat1(std::istream&,const std::vector<FactorType>& factorOrder);
@ -71,8 +73,7 @@ public:
TranslationOptionCollection* CreateTranslationOptionCollection() const;
/** Source non-terminal labels for a span of the confusion network.
 *
 *  The span arguments are ignored: every span gets the same set,
 *  m_defaultLabelSet. The constructor inserts the input default
 *  non-terminal into that set when the decoder is in chart mode;
 *  otherwise the set stays empty.
 *
 *  The previous CHECK(false) guard and the heap-allocated placeholder
 *  return were dropped: the guard made the method unusable for callers
 *  that now build InputPaths from it, and the placeholder leaked.
 */
const NonTerminalSet &GetLabelSet(size_t /*startPos*/, size_t /*endPos*/) const {
  return m_defaultLabelSet;
}
};

View File

@ -10,10 +10,11 @@ using namespace std;
namespace Moses
{
InputPath::InputPath(const Phrase &phrase, const WordsRange &range, const InputPath *prevNode
InputPath::InputPath(const Phrase &phrase, const NonTerminalSet &sourceNonTerms, const WordsRange &range, const InputPath *prevNode
,const ScoreComponentCollection *inputScore)
:m_prevNode(prevNode)
,m_phrase(phrase)
,m_sourceNonTerms(sourceNonTerms)
,m_range(range)
,m_inputScore(inputScore)
{

View File

@ -5,6 +5,7 @@
#include <list>
#include "Phrase.h"
#include "WordsRange.h"
#include "NonTerminal.h"
namespace Moses
{
@ -33,6 +34,7 @@ protected:
WordsRange m_range;
const ScoreComponentCollection *m_inputScore;
std::map<const PhraseDictionary*, std::pair<const TargetPhraseCollection*, const void*> > m_targetPhrases;
const NonTerminalSet m_sourceNonTerms;
std::vector<size_t> m_placeholders;
@ -44,13 +46,16 @@ public:
, m_inputScore(NULL) {
}
InputPath(const Phrase &phrase, const WordsRange &range, const InputPath *prevNode
InputPath(const Phrase &phrase, const NonTerminalSet &sourceNonTerms, const WordsRange &range, const InputPath *prevNode
,const ScoreComponentCollection *inputScore);
~InputPath();
//! read-only access to the phrase stored in m_phrase
const Phrase &GetPhrase() const {
return m_phrase;
}
//! read-only access to m_sourceNonTerms, the non-terminal set that the
//! (phrase, sourceNonTerms, range, prevNode, inputScore) constructor copies
//! from its sourceNonTerms argument
const NonTerminalSet &GetNonTerminalSet() const {
return m_sourceNonTerms;
}
//! read-only access to the words range stored in m_range
const WordsRange &GetWordsRange() const {
return m_range;
}

View File

@ -1,3 +1,20 @@
#include "NonTerminal.h"
using namespace std;
namespace Moses {
std::ostream& operator<<(std::ostream &out, const NonTerminalSet &obj)
{
NonTerminalSet::const_iterator iter;
for (iter = obj.begin(); iter != obj.end(); ++iter) {
const Word &word = *iter;
out << word << " ";
}
return out;
}
}

View File

@ -22,6 +22,7 @@
#include "Factor.h"
#include "Word.h"
#include <iostream>
#include <boost/functional/hash.hpp>
#include <boost/unordered_set.hpp>
@ -61,4 +62,6 @@ typedef boost::unordered_set<Word,
NonTerminalHasher,
NonTerminalEqualityPred> NonTerminalSet;
std::ostream& operator<<(std::ostream&, const NonTerminalSet&);
} // namespace Moses

View File

@ -17,6 +17,7 @@
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <iostream>
#include "ChartRuleLookupManagerMemory.h"
#include "DotChartInMemory.h"
@ -28,6 +29,8 @@
#include "moses/ChartCellCollection.h"
#include "moses/TranslationModel/PhraseDictionaryMemory.h"
using namespace std;
namespace Moses
{
@ -40,8 +43,7 @@ ChartRuleLookupManagerMemory::ChartRuleLookupManagerMemory(
{
CHECK(m_dottedRuleColls.size() == 0);
const Sentence &src = parser.GetSentence();
size_t sourceSize = src.GetSize();
size_t sourceSize = parser.GetSize();
m_dottedRuleColls.resize(sourceSize);
const PhraseDictionaryNodeMemory &rootNode = m_ruleTable.GetRootNode();
@ -178,8 +180,8 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
DottedRuleColl & dottedRuleColl)
{
// source non-terminal labels for the remainder
const NonTerminalSet &sourceNonTerms =
GetSentence().GetLabelSet(startPos, endPos);
const InputPath &inputPath = GetParser().GetInputPath(startPos, endPos);
const NonTerminalSet &sourceNonTerms = inputPath.GetNonTerminalSet();
// target non-terminal labels for the remainder
const ChartCellLabelSet &targetNonTerms = GetTargetLabelSet(startPos, endPos);

View File

@ -33,6 +33,7 @@ TranslationOptionCollectionConfusionNet::TranslationOptionCollectionConfusionNet
InputPathList &list = vec.back();
WordsRange range(startPos, startPos);
const NonTerminalSet &labels = input.GetLabelSet(startPos, startPos);
const ConfusionNet::Column &col = input.GetColumn(startPos);
for (size_t i = 0; i < col.size(); ++i) {
@ -44,7 +45,7 @@ TranslationOptionCollectionConfusionNet::TranslationOptionCollectionConfusionNet
ScoreComponentCollection *inputScore = new ScoreComponentCollection();
inputScore->Assign(inputFeature, scores);
InputPath *node = new InputPath(subphrase, range, NULL, inputScore);
InputPath *node = new InputPath(subphrase, labels, range, NULL, inputScore);
list.push_back(node);
m_phraseDictionaryQueue.push_back(node);
@ -55,13 +56,14 @@ TranslationOptionCollectionConfusionNet::TranslationOptionCollectionConfusionNet
for (size_t phaseSize = 2; phaseSize <= size; ++phaseSize) {
for (size_t startPos = 0; startPos < size - phaseSize + 1; ++startPos) {
size_t endPos = startPos + phaseSize -1;
WordsRange range(startPos, endPos);
const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);
vector<InputPathList> &vec = m_targetPhrasesfromPt[startPos];
vec.push_back(InputPathList());
InputPathList &list = vec.back();
// loop thru every previous path
const InputPathList &prevNodes = GetInputPathList(startPos, endPos - 1);
@ -88,7 +90,7 @@ TranslationOptionCollectionConfusionNet::TranslationOptionCollectionConfusionNet
ScoreComponentCollection *inputScore = new ScoreComponentCollection(*prevInputScore);
inputScore->PlusEquals(inputFeature, scores);
InputPath *node = new InputPath(subphrase, range, &prevNode, inputScore);
InputPath *node = new InputPath(subphrase, labels, range, &prevNode, inputScore);
list.push_back(node);
m_phraseDictionaryQueue.push_back(node);

View File

@ -42,16 +42,17 @@ TranslationOptionCollectionText::TranslationOptionCollectionText(Sentence const
size_t endPos = startPos + phaseSize -1;
vector<InputPath*> &vec = m_targetPhrasesfromPt[startPos];
Phrase subphrase(input.GetSubString(WordsRange(startPos, endPos)));
WordsRange range(startPos, endPos);
Phrase subphrase(input.GetSubString(WordsRange(startPos, endPos)));
const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);
InputPath *node;
if (range.GetNumWordsCovered() == 1) {
node = new InputPath(subphrase, range, NULL, NULL);
node = new InputPath(subphrase, labels, range, NULL, NULL);
vec.push_back(node);
} else {
const InputPath &prevNode = GetInputPath(startPos, endPos - 1);
node = new InputPath(subphrase, range, &prevNode, NULL);
node = new InputPath(subphrase, labels, range, &prevNode, NULL);
vec.push_back(node);
}

View File

@ -40,11 +40,6 @@ public:
*/
void GetAsEdgeMatrix(std::vector<std::vector<bool> >& edges) const;
/** Label sets are not provided by this input type: any call trips
 *  CHECK(false). The span arguments are unused.
 *
 *  The return statement exists only so the function is well-formed. It
 *  hands back a shared, empty set rather than allocating a new one per call
 *  that could never be freed.
 */
const NonTerminalSet &GetLabelSet(size_t /*startPos*/, size_t /*endPos*/) const {
  CHECK(false);
  // Only reached if CHECK does not terminate the program.
  static const NonTerminalSet emptyLabelSet;
  return emptyLabelSet;
}
};
}