/***********************************************************************
  Moses - factored phrase-based language decoder
  Copyright (C) 2011 University of Edinburgh

  This library is free software; you can redistribute it and/or
  modify it under the terms of the GNU Lesser General Public
  License as published by the Free Software Foundation; either
  version 2.1 of the License, or (at your option) any later version.

  This library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public
  License along with this library; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 ***********************************************************************/

#include "ChartRuleLookupManagerOnDisk.h"

#include <algorithm>
#include <cassert>

#include "PhraseDictionaryOnDisk.h"
#include "StaticData.h"
#include "DotChartOnDisk.h"
#include "ChartTranslationOptionList.h"
#include "../../OnDiskPt/src/TargetPhraseCollection.h"

using namespace std;

namespace Moses
{

// Creates one dotted-rule stack per source position of `sentence`.
// The stack for position `ind` gets (sourceSize - ind + 1) slots.
// Slot 0 is seeded with a dotted rule that points at the root source node
// of the on-disk phrase table (m_dbWrapper.GetRootSourceNode()).
// The stacks are owned by m_expandableDottedRuleListVec and freed in the
// destructor.
ChartRuleLookupManagerOnDisk::ChartRuleLookupManagerOnDisk(
  const InputType &sentence,
  const ChartCellCollection &cellColl,
  const PhraseDictionaryOnDisk &dictionary,
  OnDiskPt::OnDiskWrapper &dbWrapper,
  const LMList *languageModels,
  const WordPenaltyProducer *wpProducer,
  const std::vector<FactorType> &inputFactorsVec,
  const std::vector<FactorType> &outputFactorsVec,
  const std::vector<float> &weight,
  const std::string &filePath)
  : ChartRuleLookupManager(sentence, cellColl)
  , m_dictionary(dictionary)
  , m_dbWrapper(dbWrapper)
  , m_languageModels(languageModels)
  , m_wpProducer(wpProducer)
  , m_inputFactorsVec(inputFactorsVec)
  , m_outputFactorsVec(outputFactorsVec)
  , m_weight(weight)
  , m_filePath(filePath)
{
  assert(m_expandableDottedRuleListVec.size() == 0);

  const size_t sourceSize = sentence.GetSize();
  m_expandableDottedRuleListVec.resize(sourceSize);

  for (size_t ind = 0; ind < m_expandableDottedRuleListVec.size(); ++ind) {
    DottedRuleOnDisk *initDottedRule = new DottedRuleOnDisk(m_dbWrapper.GetRootSourceNode());

    DottedRuleStackOnDisk *processedStack = new DottedRuleStackOnDisk(sourceSize - ind + 1);
    processedStack->Add(0, initDottedRule); // init rule: stores the top node in the tree

    m_expandableDottedRuleListVec[ind] = processedStack;
  }
}

// Releases everything this object allocated:
// - the cached Moses TargetPhraseCollections;
// - the per-position dotted-rule stacks;
// - the PhraseNodes retained in m_sourcePhraseNode, which dotted rules
//   created in GetChartRuleCollection() point at.
ChartRuleLookupManagerOnDisk::~ChartRuleLookupManagerOnDisk()
{
  std::map<UINT64, const TargetPhraseCollection*>::const_iterator iterCache;
  for (iterCache = m_cache.begin(); iterCache != m_cache.end(); ++iterCache) {
    delete iterCache->second;
  }
  m_cache.clear();

  RemoveAllInColl(m_expandableDottedRuleListVec);
  RemoveAllInColl(m_sourcePhraseNode);
}

// Collects the rules whose source side covers `range` and adds them to
// outColl. CreateChartRules(rulesLimit) is then called on outColl.
//
// Works on the dotted-rule stack of range.GetStartPos(). Each saved dotted
// rule is extended by one more source symbol:
//   * a terminal: the input word at absEndPos, when the rule's next symbol
//     starts exactly there. The result goes to slot relEndPos + 1.
//   * a non-terminal: each (source label, chart constituent label) pair
//     over [startPos, endPos]. The result goes to slot stackInd.
// Rules in slot relEndPos + 1 (rules whose last symbol ends at absEndPos)
// are then processed if they are not yet marked Done(). Each is matched
// against the source LHS labels of the whole range. The matching target
// phrase collections are read from disk, cached in m_cache by file
// position, and added to outColl.
//
// Ownership as used here: ConvertFromMoses() and GetChild() return heap
// objects that are deleted below. The exception is PhraseNodes that a new
// dotted rule refers to; those are parked in m_sourcePhraseNode and freed
// by the destructor.
void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
  const WordsRange &range,
  bool adhereTableLimit,
  ChartTranslationOptionList &outColl)
{
  const StaticData &staticData = StaticData::Instance();
  const size_t rulesLimit = staticData.GetRuleLimit();

  const size_t relEndPos = range.GetEndPos() - range.GetStartPos();
  const size_t absEndPos = range.GetEndPos();

  // MAIN LOOP. create list of nodes of target phrases
  DottedRuleStackOnDisk &expandableDottedRuleList = *m_expandableDottedRuleListVec[range.GetStartPos()];

  // sort saved nodes so that the nodes with the most counts are done first
  expandableDottedRuleList.SortSavedNodes();

  size_t numDerivations = 0;
  const size_t maxDerivations = 999999; // staticData.GetMaxDerivations();
  bool overThreshold = true;

  // Source LHS labels of the whole range. They do not depend on any loop
  // variable, so they are fetched once here instead of once per rule.
  const NonTerminalSet &rangeLHSSet = GetSentence().GetLabelSet(range.GetStartPos(), range.GetEndPos());

  const DottedRuleStackOnDisk::SavedNodeColl &savedNodeColl = expandableDottedRuleList.GetSavedNodeColl();

  // NOTE(review): size() is deliberately re-read on every iteration. This
  // loop calls expandableDottedRuleList.Add(); presumably that appends to
  // savedNodeColl, so new dotted rules are expanded in this same pass. Do
  // not cache the size or switch to iterators without confirming this in
  // DotChartOnDisk.
  // NOTE(review): overThreshold is only refreshed on a cache miss (below).
  // maxDerivations is a placeholder, so the early exit can only trigger
  // after 999999 derivations.
  for (size_t ind = 0;
       ind < savedNodeColl.size() && (numDerivations < maxDerivations || overThreshold);
       ++ind) {
    const SavedNodeOnDisk &savedNode = *savedNodeColl[ind];

    const DottedRuleOnDisk &prevDottedRule = savedNode.GetDottedRule();
    const OnDiskPt::PhraseNode &prevNode = prevDottedRule.GetLastNode();
    const CoveredChartSpan *prevCoveredChartSpan = prevDottedRule.GetLastCoveredChartSpan();

    // First source position not yet covered by this dotted rule.
    const size_t startPos = (prevCoveredChartSpan == NULL)
                            ? range.GetStartPos()
                            : prevCoveredChartSpan->GetWordsRange().GetEndPos() + 1;

    // search for terminal symbol
    if (startPos == absEndPos) {
      const Word &sourceWord = GetSentence().GetWord(absEndPos);
      OnDiskPt::Word *sourceWordBerkeleyDb = m_dbWrapper.ConvertFromMoses(Input, m_inputFactorsVec, sourceWord);

      if (sourceWordBerkeleyDb != NULL) {
        const OnDiskPt::PhraseNode *node = prevNode.GetChild(*sourceWordBerkeleyDb, m_dbWrapper);
        if (node != NULL) {
          // TODO figure out why source word is needed from node, not from sentence
          // prob to do with factors or non-term
          //const Word &sourceWord = node->GetSourceWord();
          CoveredChartSpan *newCoveredChartSpan = new CoveredChartSpan(absEndPos, absEndPos, sourceWord, prevCoveredChartSpan);
          DottedRuleOnDisk *dottedRule = new DottedRuleOnDisk(*node, newCoveredChartSpan);
          expandableDottedRuleList.Add(relEndPos + 1, dottedRule);

          // keep the node alive (the dotted rule refers to it); freed in dtor
          m_sourcePhraseNode.push_back(node);
        }

        delete sourceWordBerkeleyDb;
      }
    }

    // search for non-terminals
    if (startPos > absEndPos) {
      continue; // rule already reaches the end of the range
    }

    size_t endPos;
    size_t stackInd;
    if (startPos == range.GetStartPos() && range.GetEndPos() > range.GetStartPos()) {
      // First symbol of a multi-word range: the non-terminal stops one word
      // short of the range end, so it never covers the whole range. The
      // extended rule therefore goes to slot relEndPos.
      endPos = absEndPos - 1;
      stackInd = relEndPos;
    } else {
      endPos = absEndPos;
      stackInd = relEndPos + 1;
    }

    // get target nonterminals in this span from chart
    const NonTerminalSet &chartNonTermSet = GetCellCollection().GetConstituentLabelSet(WordsRange(startPos, endPos));

    // go through each SOURCE lhs.
    // Span-length limits (maxSyntaxSpan / maxDefaultSpan) are not applied here.
    const NonTerminalSet &sourceLHSSet = GetSentence().GetLabelSet(startPos, endPos);

    NonTerminalSet::const_iterator iterSourceLHS;
    for (iterSourceLHS = sourceLHSSet.begin(); iterSourceLHS != sourceLHSSet.end(); ++iterSourceLHS) {
      const Word &sourceLHS = *iterSourceLHS;

      OnDiskPt::Word *sourceLHSBerkeleyDb = m_dbWrapper.ConvertFromMoses(Input, m_inputFactorsVec, sourceLHS);
      if (sourceLHSBerkeleyDb == NULL) {
        continue; // vocab not in pt. node definitely won't be in there
      }

      const OnDiskPt::PhraseNode *sourceNode = prevNode.GetChild(*sourceLHSBerkeleyDb, m_dbWrapper);
      delete sourceLHSBerkeleyDb;

      if (sourceNode == NULL) {
        continue; // didn't find source node
      }

      // go through each TARGET lhs
      NonTerminalSet::const_iterator iterChartNonTerm;
      for (iterChartNonTerm = chartNonTermSet.begin(); iterChartNonTerm != chartNonTermSet.end(); ++iterChartNonTerm) {
        const Word &chartNonTerm = *iterChartNonTerm;

        OnDiskPt::Word *chartNonTermBerkeleyDb = m_dbWrapper.ConvertFromMoses(Output, m_outputFactorsVec, chartNonTerm);
        if (chartNonTermBerkeleyDb == NULL) {
          continue;
        }

        const OnDiskPt::PhraseNode *node = sourceNode->GetChild(*chartNonTermBerkeleyDb, m_dbWrapper);
        delete chartNonTermBerkeleyDb;

        if (node == NULL) {
          continue;
        }

        // found matching entry
        CoveredChartSpan *newCoveredChartSpan = new CoveredChartSpan(startPos, endPos, chartNonTerm, prevCoveredChartSpan);
        DottedRuleOnDisk *dottedRule = new DottedRuleOnDisk(*node, newCoveredChartSpan);
        expandableDottedRuleList.Add(stackInd, dottedRule);

        // keep the node alive (the dotted rule refers to it); freed in dtor
        m_sourcePhraseNode.push_back(node);
      } // for (iterChartNonTerm

      delete sourceNode;
    } // for (iterSourceLHS

    // return list of target phrases
    DottedRuleCollOnDisk &nodes = expandableDottedRuleList.Get(relEndPos + 1);

    // source LHS
    DottedRuleCollOnDisk::const_iterator iterDottedRuleColl;
    for (iterDottedRuleColl = nodes.begin(); iterDottedRuleColl != nodes.end(); ++iterDottedRuleColl) {
      // node of last source word
      const DottedRuleOnDisk &completedRule = **iterDottedRuleColl;
      if (completedRule.Done()) {
        continue; // already emitted in an earlier iteration
      }
      completedRule.Done(true);

      const CoveredChartSpan *coveredChartSpan = completedRule.GetLastCoveredChartSpan();
      assert(coveredChartSpan);

      const OnDiskPt::PhraseNode &lastNode = completedRule.GetLastNode();

      // get node for each source LHS
      NonTerminalSet::const_iterator iterLabelSet;
      for (iterLabelSet = rangeLHSSet.begin(); iterLabelSet != rangeLHSSet.end(); ++iterLabelSet) {
        const Word &sourceLHS = *iterLabelSet;

        OnDiskPt::Word *sourceLHSBerkeleyDb = m_dbWrapper.ConvertFromMoses(Input, m_inputFactorsVec, sourceLHS);
        if (sourceLHSBerkeleyDb == NULL) {
          continue;
        }

        const OnDiskPt::PhraseNode *node = lastNode.GetChild(*sourceLHSBerkeleyDb, m_dbWrapper);
        if (node) {
          const TargetPhraseCollection *targetPhraseCollection = NULL;

          const UINT64 tpCollFilePos = node->GetValue();
          std::map<UINT64, const TargetPhraseCollection*>::const_iterator iterCache = m_cache.find(tpCollFilePos);
          if (iterCache == m_cache.end()) {
            // not in cache: load from disk, convert to Moses, and remember it
            overThreshold = node->GetCount(0) > staticData.GetRuleCountThreshold();

            const OnDiskPt::TargetPhraseCollection *tpcollBerkeleyDb = node->GetTargetPhraseCollection(m_dictionary.GetTableLimit(), m_dbWrapper);

            targetPhraseCollection = tpcollBerkeleyDb->ConvertToMoses(m_inputFactorsVec
                                     , m_outputFactorsVec
                                     , m_dictionary
                                     , m_weight
                                     , m_wpProducer
                                     , *m_languageModels
                                     , m_filePath
                                     , m_dbWrapper.GetVocab());

            delete tpcollBerkeleyDb;
            m_cache[tpCollFilePos] = targetPhraseCollection; // owned by m_cache; freed in dtor
          } else {
            // just get out of cache
            targetPhraseCollection = iterCache->second;
          }

          assert(targetPhraseCollection);
          outColl.Add(*targetPhraseCollection, *coveredChartSpan, GetCellCollection(), adhereTableLimit, rulesLimit);

          numDerivations++;
        } // if (node)

        delete node;
        delete sourceLHSBerkeleyDb;
      }
    }
  } // for (size_t ind = 0; ind < savedNodeColl.size(); ++ind)

  outColl.CreateChartRules(rulesLimit);
}

} // namespace Moses