// File: mosesdecoder/moses/ChartParser.cpp
// 314 lines, 10 KiB, C++ (Raw | Normal View | History)

// $Id$
// vim:tabstop=2
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2010 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include "ChartParser.h"
#include "ChartParserCallback.h"
#include "ChartRuleLookupManager.h"
#include "StaticData.h"
#include "TreeInput.h"
#include "Sentence.h"
#include "DecodeGraph.h"
#include "moses/FF/UnknownWordPenaltyProducer.h"
#include "moses/TranslationModel/PhraseDictionary.h"
#include "moses/TranslationTask.h"
using namespace std;
using namespace Moses;
namespace Moses
{
/** Create the unknown-word handler for one translation task.
 *  The task is stored in m_ttask and reached later through m_ttask.lock()
 *  (see options()).
 */
ChartParserUnknown::ChartParserUnknown(ttasksptr const& ttask)
  : m_ttask(ttask)
{
}
/** Dispose of the one-word source phrases that Process() allocated and
 *  collected in m_unksrcs (handled by the RemoveAllInColl utility).
 */
ChartParserUnknown::~ChartParserUnknown()
{
  RemoveAllInColl(m_unksrcs);
}
/** Options of the translation task this handler belongs to.
 *  NOTE(review): assumes the task is still alive; lock() on an expired
 *  task would yield a null pointer here.
 */
AllOptions::ptr const&
ChartParserUnknown::
options() const
{
  ttasksptr const task = m_ttask.lock();
  return task->options();
}
void
2015-12-10 06:17:36 +03:00
ChartParserUnknown::
Process(const Word &sourceWord, const Range &range, ChartParserCallback &to)
2013-05-29 21:16:15 +04:00
{
// unknown word, add as trans opt
const StaticData &staticData = StaticData::Instance();
2015-12-11 03:00:33 +03:00
const UnknownWordPenaltyProducer &unknownWordPenaltyProducer
= UnknownWordPenaltyProducer::Instance();
2013-05-29 21:16:15 +04:00
size_t isDigit = 0;
2015-12-10 06:17:36 +03:00
if (options()->unk.drop) {
const Factor *f = sourceWord[0]; // TODO hack. shouldn't know which factor is surface
2013-04-29 21:46:48 +04:00
const StringPiece s = f->GetString();
isDigit = s.find_first_of("0123456789");
if (isDigit == string::npos)
isDigit = 0;
else
isDigit = 1;
// modify the starting bitmap
}
2013-05-29 21:16:15 +04:00
Phrase* unksrc = new Phrase(1);
unksrc->AddWord() = sourceWord;
2013-06-24 17:45:20 +04:00
Word &newWord = unksrc->GetWord(0);
newWord.SetIsOOV(true);
m_unksrcs.push_back(unksrc);
2013-05-29 21:16:15 +04:00
// hack. Once the OOV FF is a phrase table, get rid of this
PhraseDictionary *firstPt = NULL;
if (PhraseDictionary::GetColl().size() == 0) {
firstPt = PhraseDictionary::GetColl()[0];
}
//TranslationOption *transOpt;
2015-12-10 06:17:36 +03:00
if (! options()->unk.drop || isDigit) {
// loop
2015-12-10 06:17:36 +03:00
const UnknownLHSList &lhsList = options()->syntax.unknown_lhs; // staticData.GetUnknownLHS();
UnknownLHSList::const_iterator iterLHS;
for (iterLHS = lhsList.begin(); iterLHS != lhsList.end(); ++iterLHS) {
const string &targetLHSStr = iterLHS->first;
float prob = iterLHS->second;
2013-05-29 21:16:15 +04:00
// lhs
//const Word &sourceLHS = staticData.GetInputDefaultNonTerminal();
2013-05-22 14:22:17 +04:00
Word *targetLHS = new Word(true);
2013-05-29 21:16:15 +04:00
2015-12-11 03:00:33 +03:00
targetLHS->CreateFromString(Output, options()->output.factor_order,
targetLHSStr, true);
UTIL_THROW_IF2(targetLHS->GetFactor(0) == NULL, "Null factor for target LHS");
2013-05-29 21:16:15 +04:00
// add to dictionary
TargetPhrase *targetPhrase = new TargetPhrase(firstPt);
Word &targetWord = targetPhrase->AddWord();
targetWord.CreateUnknownWord(sourceWord);
2013-05-29 21:16:15 +04:00
// scores
float unknownScore = FloorScore(TransformScore(prob));
targetPhrase->GetScoreBreakdown().Assign(&unknownWordPenaltyProducer, unknownScore);
targetPhrase->SetTargetLHS(targetLHS);
targetPhrase->SetAlignmentInfo("0-0");
targetPhrase->EvaluateInIsolation(*unksrc);
2015-08-08 02:00:45 +03:00
2015-12-10 06:17:36 +03:00
if (!options()->output.detailed_tree_transrep_filepath.empty() ||
options()->nbest.print_trees || staticData.GetTreeStructure() != NULL) {
2015-11-02 03:00:37 +03:00
std::string prop = "[ ";
prop += (*targetLHS)[0]->GetString().as_string() + " ";
prop += sourceWord[0]->GetString().as_string() + " ]";
targetPhrase->SetProperty("Tree", prop);
}
// chart rule
to.AddPhraseOOV(*targetPhrase, m_cacheTargetPhraseCollection, range);
} // for (iterLHS
} else {
// drop source word. create blank trans opt
float unknownScore = FloorScore(-numeric_limits<float>::infinity());
2013-05-29 21:16:15 +04:00
TargetPhrase *targetPhrase = new TargetPhrase(firstPt);
// loop
2015-12-10 06:17:36 +03:00
const UnknownLHSList &lhsList = options()->syntax.unknown_lhs;//staticData.GetUnknownLHS();
UnknownLHSList::const_iterator iterLHS;
for (iterLHS = lhsList.begin(); iterLHS != lhsList.end(); ++iterLHS) {
const string &targetLHSStr = iterLHS->first;
//float prob = iterLHS->second;
2013-05-29 21:16:15 +04:00
2013-05-22 14:22:17 +04:00
Word *targetLHS = new Word(true);
targetLHS->CreateFromString(Output, staticData.options()->output.factor_order,
2015-12-11 03:00:33 +03:00
targetLHSStr, true);
UTIL_THROW_IF2(targetLHS->GetFactor(0) == NULL, "Null factor for target LHS");
2013-05-29 21:16:15 +04:00
targetPhrase->GetScoreBreakdown().Assign(&unknownWordPenaltyProducer, unknownScore);
2014-08-08 18:59:34 +04:00
targetPhrase->EvaluateInIsolation(*unksrc);
targetPhrase->SetTargetLHS(targetLHS);
// chart rule
to.AddPhraseOOV(*targetPhrase, m_cacheTargetPhraseCollection, range);
}
}
}
/** Prepare chart parsing for one translation task.
 *
 *  Initializes StaticData for this input, builds the span->InputPath matrix
 *  for the source, and creates one rule lookup manager per phrase table. The
 *  manager for table i is limited by the max chart span of decode graph i,
 *  so both collections are expected to have the same length.
 */
ChartParser
::ChartParser(ttasksptr const& ttask, ChartCellCollectionBase &cells)
  : m_ttask(ttask)
  , m_unknown(ttask)
  , m_decodeGraphList(StaticData::Instance().GetDecodeGraphs())
  , m_source(*(ttask->GetSource().get()))
{
  StaticData::Instance().InitializeForInput(ttask);

  CreateInputPaths(m_source);

  const std::vector<PhraseDictionary*> &phraseTables = PhraseDictionary::GetColl();
  assert(phraseTables.size() == m_decodeGraphList.size());

  const std::size_t numTables = phraseTables.size();
  m_ruleLookupManagers.reserve(numTables);

  std::size_t idx = 0;
  while (idx < numTables) {
    // The vector stores non-const pointers, so no const_cast is needed.
    PhraseDictionary *table = phraseTables[idx];
    const std::size_t maxSpan = m_decodeGraphList[idx]->GetMaxChartSpan();
    m_ruleLookupManagers.push_back(table->CreateRuleLookupManager(*this, cells, maxSpan));
    ++idx;
  }
}
/** Tear down per-sentence state: the rule lookup managers, StaticData's
 *  per-sentence data for this task, and every InputPath allocated by
 *  CreateInputPaths().
 */
ChartParser::~ChartParser()
{
  RemoveAllInColl(m_ruleLookupManagers);
  StaticData::Instance().CleanUpAfterSentenceProcessing(m_ttask.lock());

  for (size_t start = 0; start < m_inputPathMatrix.size(); ++start) {
    const std::vector<InputPath*> &pathsFromStart = m_inputPathMatrix[start];
    for (size_t k = 0; k < pathsFromStart.size(); ++k) {
      delete pathsFromStart[k];
    }
  }
}
void ChartParser::Create(const Range &range, ChartParserCallback &to)
2013-05-29 21:16:15 +04:00
{
assert(m_decodeGraphList.size() == m_ruleLookupManagers.size());
2013-05-29 21:16:15 +04:00
std::vector <DecodeGraph*>::const_iterator iterDecodeGraph;
std::vector <ChartRuleLookupManager*>::const_iterator iterRuleLookupManagers = m_ruleLookupManagers.begin();
for (iterDecodeGraph = m_decodeGraphList.begin(); iterDecodeGraph != m_decodeGraphList.end(); ++iterDecodeGraph, ++iterRuleLookupManagers) {
const DecodeGraph &decodeGraph = **iterDecodeGraph;
assert(decodeGraph.GetSize() == 1);
ChartRuleLookupManager &ruleLookupManager = **iterRuleLookupManagers;
size_t maxSpan = decodeGraph.GetMaxChartSpan();
size_t last = m_source.GetSize()-1;
if (maxSpan != 0) {
2015-10-25 16:37:59 +03:00
last = min(last, range.GetStartPos()+maxSpan);
}
2015-10-25 16:37:59 +03:00
if (maxSpan == 0 || range.GetNumWordsCovered() <= maxSpan) {
const InputPath &inputPath = GetInputPath(range);
ruleLookupManager.GetChartRuleCollection(inputPath, last, to);
}
}
2015-12-09 03:00:35 +03:00
if (range.GetNumWordsCovered() == 1
&& range.GetStartPos() != 0
2015-12-08 02:34:57 +03:00
&& range.GetStartPos() != m_source.GetSize()-1) {
2015-12-10 06:17:36 +03:00
bool always = options()->unk.always_create_direct_transopt;
2015-12-08 02:34:57 +03:00
if (to.Empty() || always) {
// create unknown words for 1 word coverage where we don't have any trans options
2015-10-25 16:37:59 +03:00
const Word &sourceWord = m_source.GetWord(range.GetStartPos());
m_unknown.Process(sourceWord, range, to);
}
2013-05-29 21:16:15 +04:00
}
}
2013-05-29 21:16:15 +04:00
void ChartParser::CreateInputPaths(const InputType &input)
{
2013-07-30 18:04:37 +04:00
size_t size = input.GetSize();
2013-08-02 21:24:36 +04:00
m_inputPathMatrix.resize(size);
UTIL_THROW_IF2(input.GetType() != SentenceInput && input.GetType() != TreeInputType,
2015-12-14 02:07:15 +03:00
"Input must be a sentence or a tree, " <<
"not lattice or confusion networks");
2015-12-12 19:23:37 +03:00
TranslationTask const* ttask = m_ttask.lock().get();
2013-07-30 18:04:37 +04:00
for (size_t phaseSize = 1; phaseSize <= size; ++phaseSize) {
for (size_t startPos = 0; startPos < size - phaseSize + 1; ++startPos) {
size_t endPos = startPos + phaseSize -1;
2013-08-02 21:24:36 +04:00
vector<InputPath*> &vec = m_inputPathMatrix[startPos];
2013-07-30 18:04:37 +04:00
2015-10-25 16:37:59 +03:00
Range range(startPos, endPos);
Phrase subphrase(input.GetSubString(Range(startPos, endPos)));
const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);
2013-07-30 18:04:37 +04:00
InputPath *node;
if (range.GetNumWordsCovered() == 1) {
2015-12-12 19:23:37 +03:00
node = new InputPath(ttask, subphrase, labels, range, NULL, NULL);
2013-07-30 18:04:37 +04:00
vec.push_back(node);
} else {
const InputPath &prevNode = GetInputPath(startPos, endPos - 1);
2015-12-12 19:23:37 +03:00
node = new InputPath(ttask, subphrase, labels, range, &prevNode, NULL);
2013-07-30 18:04:37 +04:00
vec.push_back(node);
}
//m_inputPathQueue.push_back(node);
2013-07-30 18:04:37 +04:00
}
}
}
2015-10-25 16:37:59 +03:00
/** InputPath covering \p range; delegates to the (start, end) overload. */
const InputPath &ChartParser::GetInputPath(const Range &range) const
{
  const size_t first = range.GetStartPos();
  const size_t last = range.GetEndPos();
  return GetInputPath(first, last);
}
/** InputPath for the source span [startPos, endPos] (inclusive).
 *
 *  Throws via UTIL_THROW_IF2 if startPos is past the matrix or the span
 *  length is out of range. endPos < startPos wraps the unsigned offset to a
 *  huge value, so it is caught by the second check.
 */
const InputPath &ChartParser::GetInputPath(size_t startPos, size_t endPos) const
{
  // Check the row before indexing it; operator[] past the end is undefined.
  UTIL_THROW_IF2(startPos >= m_inputPathMatrix.size(),
                 "Out of bound: start position " << startPos);
  size_t offset = endPos - startPos;
  UTIL_THROW_IF2(offset >= m_inputPathMatrix[startPos].size(),
                 "Out of bound: " << offset);
  return *m_inputPathMatrix[startPos][offset];
}
/** Mutable InputPath for the source span [startPos, endPos] (inclusive).
 *
 *  Throws via UTIL_THROW_IF2 if startPos is past the matrix or the span
 *  length is out of range. endPos < startPos wraps the unsigned offset to a
 *  huge value, so it is caught by the second check.
 */
InputPath &ChartParser::GetInputPath(size_t startPos, size_t endPos)
{
  // Check the row before indexing it; operator[] past the end is undefined.
  UTIL_THROW_IF2(startPos >= m_inputPathMatrix.size(),
                 "Out of bound: start position " << startPos);
  size_t offset = endPos - startPos;
  UTIL_THROW_IF2(offset >= m_inputPathMatrix[startPos].size(),
                 "Out of bound: " << offset);
  return *m_inputPathMatrix[startPos][offset];
}
/*
const Sentence &ChartParser::GetSentence() const {
const Sentence &sentence = static_cast<const Sentence&>(m_source);
return sentence;
}
*/
/** Number of positions in the source being parsed (m_source.GetSize()). */
size_t ChartParser::GetSize() const
{
  const size_t sourceLength = m_source.GetSize();
  return sourceLength;
}
long ChartParser::GetTranslationId() const
{
2013-08-07 17:18:12 +04:00
return m_source.GetTranslationId();
}
2015-12-10 06:17:36 +03:00
2015-12-11 03:00:33 +03:00
/** Options of the translation task this parser was created for.
 *  NOTE(review): assumes the task is still alive; lock() on an expired
 *  task would yield a null pointer here.
 */
AllOptions::ptr const&
ChartParser::
options() const
{
  ttasksptr const task = m_ttask.lock();
  return task->options();
}
} // namespace Moses