// $Id$ // vim:tabstop=2 /*********************************************************************** Moses - factored phrase-based language decoder Copyright (C) 2006 University of Edinburgh This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You should have received a copy of the GNU Lesser General Public License along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA ***********************************************************************/ #ifndef moses_PhraseTableCreator_h #define moses_PhraseTableCreator_h #include #include #include #include #include #include #include "moses/InputFileStream.h" #include "moses/ThreadPool.h" #include "moses/Util.h" #include "BlockHashIndex.h" #include "StringVector.h" #include "StringVectorTemp.h" #include "CanonicalHuffman.h" namespace Moses { typedef std::pair AlignPoint; template class Counter { public: typedef boost::unordered_map FreqMap; typedef typename FreqMap::iterator iterator; typedef typename FreqMap::mapped_type mapped_type; typedef typename FreqMap::value_type value_type; private: #ifdef WITH_THREADS boost::mutex m_mutex; #endif FreqMap m_freqMap; size_t m_maxSize; std::vector m_bestVec; struct FreqSorter { bool operator()(const value_type& a, const value_type& b) const { if(a.second > b.second) return true; // Check impact on translation quality! if(a.second == b.second && a.first > b.first) return true; return false; } }; public: Counter() : m_maxSize(0) {} iterator Begin() { return m_freqMap.begin(); } iterator End() { return m_freqMap.end(); } void Increase(DataType data) { #ifdef WITH_THREADS boost::mutex::scoped_lock lock(m_mutex); #endif m_freqMap[data]++; } void IncreaseBy(DataType data, size_t num) { #ifdef WITH_THREADS boost::mutex::scoped_lock lock(m_mutex); #endif m_freqMap[data] += num; } mapped_type& operator[](DataType data) { return m_freqMap[data]; } size_t Size() { #ifdef WITH_THREADS boost::mutex::scoped_lock lock(m_mutex); #endif return m_freqMap.size(); } void Quantize(size_t maxSize) { #ifdef WITH_THREADS boost::mutex::scoped_lock lock(m_mutex); #endif m_maxSize = maxSize; std::vector > freqVec; freqVec.insert(freqVec.begin(), m_freqMap.begin(), m_freqMap.end()); std::sort(freqVec.begin(), freqVec.end(), FreqSorter()); for(size_t i = 0; i < freqVec.size() && i < m_maxSize; i++) m_bestVec.push_back(freqVec[i].first); std::sort(m_bestVec.begin(), m_bestVec.end()); FreqMap t_freqMap; for(typename std::vector >::iterator it = freqVec.begin(); it != freqVec.end(); it++) { DataType closest = LowerBound(it->first); t_freqMap[closest] += it->second; } m_freqMap.swap(t_freqMap); } void Clear() { #ifdef WITH_THREADS boost::mutex::scoped_lock lock(m_mutex); #endif m_freqMap.clear(); } DataType LowerBound(DataType data) { if(m_maxSize == 0 || m_bestVec.size() == 0) return data; else { typename std::vector::iterator it = std::lower_bound(m_bestVec.begin(), m_bestVec.end(), data); if(it != m_bestVec.end()) return *it; else return m_bestVec.back(); } } }; class PackedItem { private: long m_line; std::string m_sourcePhrase; std::string m_packedTargetPhrase; size_t m_rank; float m_score; public: PackedItem(long line, std::string sourcePhrase, std::string packedTargetPhrase, size_t rank, float m_score = 0); long GetLine() const; const std::string& GetSrc() const; const std::string& GetTrg() const; size_t GetRank() const; float GetScore() const; }; bool operator<(const PackedItem &pi1, const PackedItem &pi2); class PhraseTableCreator { public: enum Coding { None, REnc, PREnc }; private: std::string m_inPath; std::string m_outPath; std::string m_tempfilePath; std::FILE* m_outFile; size_t m_numScoreComponent; size_t m_sortScoreIndex; size_t m_warnMe; Coding m_coding; size_t m_orderBits; size_t m_fingerPrintBits; bool m_useAlignmentInfo; bool m_multipleScoreTrees; size_t m_quantize; size_t m_maxRank; static std::string m_phraseStopSymbol; static std::string m_separator; #ifdef WITH_THREADS size_t m_threads; boost::mutex m_mutex; #endif BlockHashIndex m_srcHash; BlockHashIndex m_rnkHash; size_t m_maxPhraseLength; std::vector m_ranks; typedef std::pair SrcTrg; typedef std::pair SrcTrgString; typedef std::pair SrcTrgProb; struct SrcTrgProbSorter { bool operator()(const SrcTrgProb& a, const SrcTrgProb& b) const { if(a.first.first < b.first.first) return true; if(a.first.first == b.first.first && a.second > b.second) return true; if(a.first.first == b.first.first && a.second == b.second && a.first.second < b.first.second) return true; return false; } }; std::vector m_lexicalTableIndex; std::vector m_lexicalTable; StringVectorTemp* m_encodedTargetPhrases; StringVector* m_compressedTargetPhrases; boost::unordered_map m_targetSymbolsMap; boost::unordered_map m_sourceSymbolsMap; typedef Counter SymbolCounter; typedef Counter ScoreCounter; typedef Counter AlignCounter; typedef CanonicalHuffman SymbolTree; typedef CanonicalHuffman ScoreTree; typedef CanonicalHuffman AlignTree; SymbolCounter m_symbolCounter; SymbolTree* m_symbolTree; AlignCounter m_alignCounter; AlignTree* m_alignTree; std::vector m_scoreCounters; std::vector m_scoreTrees; std::priority_queue m_queue; long m_lastFlushedLine; long m_lastFlushedSourceNum; std::string m_lastFlushedSourcePhrase; std::vector m_lastSourceRange; std::priority_queue > m_rankQueue; std::vector m_lastCollection; void Save(); void PrintInfo(); void AddSourceSymbolId(std::string& symbol); unsigned GetSourceSymbolId(std::string& symbol); void AddTargetSymbolId(std::string& symbol); unsigned GetTargetSymbolId(std::string& symbol); unsigned GetOrAddTargetSymbolId(std::string& symbol); unsigned GetRank(unsigned srcIdx, unsigned trgIdx); unsigned EncodeREncSymbol1(unsigned symbol); unsigned EncodeREncSymbol2(unsigned position, unsigned rank); unsigned EncodeREncSymbol3(unsigned rank); unsigned EncodePREncSymbol1(unsigned symbol); unsigned EncodePREncSymbol2(int lOff, int rOff, unsigned rank); void EncodeTargetPhraseNone(std::vector& t, std::ostream& os); void EncodeTargetPhraseREnc(std::vector& s, std::vector& t, std::set& a, std::ostream& os); void EncodeTargetPhrasePREnc(std::vector& s, std::vector& t, std::set& a, size_t ownRank, std::ostream& os); void EncodeScores(std::vector& scores, std::ostream& os); void EncodeAlignment(std::set& alignment, std::ostream& os); std::string MakeSourceKey(std::string&); std::string MakeSourceTargetKey(std::string&, std::string&); void LoadLexicalTable(std::string filePath); void CreateRankHash(); void EncodeTargetPhrases(); void CalcHuffmanCodes(); void CompressTargetPhrases(); void AddRankedLine(PackedItem& pi); void FlushRankedQueue(bool force = false); std::string EncodeLine(std::vector& tokens, size_t ownRank); void AddEncodedLine(PackedItem& pi); void FlushEncodedQueue(bool force = false); std::string CompressEncodedCollection(std::string encodedCollection); void AddCompressedCollection(PackedItem& pi); void FlushCompressedQueue(bool force = false); public: PhraseTableCreator(std::string inPath, std::string outPath, std::string tempfilePath, size_t numScoreComponent = 5, size_t sortScoreIndex = 2, Coding coding = PREnc, size_t orderBits = 10, size_t fingerPrintBits = 16, bool useAlignmentInfo = false, bool multipleScoreTrees = true, size_t quantize = 0, size_t maxRank = 100, bool warnMe = true #ifdef WITH_THREADS , size_t threads = 2 #endif ); ~PhraseTableCreator(); friend class RankingTask; friend class EncodingTask; friend class CompressionTask; }; class RankingTask { private: #ifdef WITH_THREADS static boost::mutex m_mutex; static boost::mutex m_fileMutex; #endif static size_t m_lineNum; InputFileStream& m_inFile; PhraseTableCreator& m_creator; public: RankingTask(InputFileStream& inFile, PhraseTableCreator& creator); void operator()(); }; class EncodingTask { private: #ifdef WITH_THREADS static boost::mutex m_mutex; static boost::mutex m_fileMutex; #endif static size_t m_lineNum; static size_t m_sourcePhraseNum; static std::string m_lastSourcePhrase; InputFileStream& m_inFile; PhraseTableCreator& m_creator; public: EncodingTask(InputFileStream& inFile, PhraseTableCreator& creator); void operator()(); }; class CompressionTask { private: #ifdef WITH_THREADS static boost::mutex m_mutex; #endif static size_t m_collectionNum; StringVectorTemp& m_encodedCollections; PhraseTableCreator& m_creator; public: CompressionTask(StringVectorTemp& encodedCollections, PhraseTableCreator& creator); void operator()(); }; } #endif