From 19e71f8061c26c517aa751cd4e7b66a1d75f0bcd Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Thu, 14 Apr 2016 20:49:38 +0200 Subject: [PATCH] towards n-best-lists --- CMakeLists.txt | 4 +- src/common/vocab.h | 39 +++++++++--- src/decoder/decoder_main.cu | 43 +++++++++++-- src/decoder/search.h | 123 ++++++++++++++++++------------------ src/rescorer/nbest.cpp | 10 ++- src/rescorer/nbest.h | 2 +- 6 files changed, 134 insertions(+), 87 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5fa647eb..76e90b45 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,9 +2,9 @@ cmake_minimum_required(VERSION 3.1.0) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) project(amunn CXX) -SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -fPIC -funroll-loops -Wno-unused-result -Wno-deprecated") +SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated") #SET(CUDA_PROPAGATE_HOST_FLAGS OFF) -SET(CUDA_NVCC_FLAGS " -std=c++11 -g -O3 -fPIC -arch=sm_35 -lineinfo --use_fast_math") +SET(CUDA_NVCC_FLAGS " -std=c++11 -g -O3 -arch=sm_35 -lineinfo --use_fast_math") #SET(CUDA_VERBOSE_BUILD ON) include_directories(${amunn_SOURCE_DIR}) diff --git a/src/common/vocab.h b/src/common/vocab.h index f70c3069..f8997041 100644 --- a/src/common/vocab.h +++ b/src/common/vocab.h @@ -4,6 +4,10 @@ #include #include #include +#include + +#include "types.h" +#include "utils.h" class Vocab { public: @@ -15,8 +19,6 @@ class Vocab { str2id_[line] = c++; id2str_.push_back(line); } - //str2id_[""] = c; - //id2str_.push_back(""); } size_t operator[](const std::string& word) const { @@ -27,15 +29,32 @@ class Vocab { return 1; } - inline std::vector Encode(const std::vector& sentence, bool addEOS=false) const { - std::vector indexes; - for (auto& word: sentence) { - indexes.push_back((*this)[word]); + Sentence operator()(const std::vector& lineTokens, bool addEOS = true) const { + Sentence words(lineTokens.size()); + std::transform(lineTokens.begin(), lineTokens.end(), words.begin(), + [&](const std::string& w) { return (*this)[w]; }); + if(addEOS) + words.push_back(EOS); + return words; + } + + Sentence operator()(const std::string& line, bool addEOS = true) const { + std::vector lineTokens; + Split(line, lineTokens, " "); + return (*this)(lineTokens, addEOS); + } + + std::string operator()(const Sentence& sentence, bool ignoreEOS = true) const { + std::stringstream line; + for(size_t i = 0; i < sentence.size(); ++i) { + if(sentence[i] != EOS || !ignoreEOS) { + if(i > 0) { + line << " "; + } + line << (*this)[sentence[i]]; + } } - if (addEOS) { - indexes.push_back((*this)[""]); - } - return indexes; + return line.str(); } diff --git a/src/decoder/decoder_main.cu b/src/decoder/decoder_main.cu index 2dfe0d2f..292021f8 100644 --- a/src/decoder/decoder_main.cu +++ b/src/decoder/decoder_main.cu @@ -57,6 +57,33 @@ void ProgramOptions(int argc, char *argv[], } } +class BPE { + public: + BPE(const std::string& sep = "@@ ") + : sep_(sep) {} + + std::string split(const std::string& line) { + return line; + } + + std::string unsplit(const std::string& line) { + std::string joined = line; + size_t pos = joined.find(sep_); + while(pos != std::string::npos) { + joined.erase(pos, sep_.size()); + pos = joined.find(sep_, pos); + } + return joined; + } + + operator bool() const { + return true; + } + + private: + std::string sep_; +}; + int main(int argc, char* argv[]) { std::string modelPath, srcVocabPath, trgVocabPath; size_t device = 0; @@ -70,17 +97,23 @@ int main(int argc, char* argv[]) { Vocab trgVocab(trgVocabPath); std::cerr << "done." << std::endl; - Search search(model, srcVocab, trgVocab); + Search search(model); std::cerr << "Translating...\n"; std::ios_base::sync_with_stdio(false); - std::string line; + BPE bpe; + boost::timer::cpu_timer timer; - while(std::getline(std::cin, line)) { - auto result = search.Decode(line, beamSize); - std::cout << result << std::endl; + std::string in; + while(std::getline(std::cin, in)) { + Sentence sentence = bpe ? srcVocab(bpe.split(in)) : srcVocab(in); + History history = search.Decode(sentence, beamSize); + std::string out = trgVocab(history.Top().first); + if(bpe) + out = bpe.unsplit(out); + std::cout << out << std::endl; } std::cerr << timer.format() << std::endl; return 0; diff --git a/src/decoder/search.h b/src/decoder/search.h index 4b73675e..e674183d 100644 --- a/src/decoder/search.h +++ b/src/decoder/search.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -16,24 +17,70 @@ #include #include +#include "types.h" #include "matrix.h" #include "dl4mt.h" -#include "vocab.h" #include "hypothesis.h" #include "utils.h" -#define EOL "" +typedef std::vector Beam; +typedef std::pair Result; +typedef std::vector NBestList; + +class History { + private: + struct HypothesisCoord { + bool operator<(const HypothesisCoord& hc) const { + return cost < hc.cost; + } + + size_t i; + size_t j; + float cost; + }; + + public: + void Add(const Beam& beam, bool last = false) { + for(size_t j = 0; j < beam.size(); ++j) + if(beam[j].GetWord() == EOS || last) + topHyps_.push({ history_.size(), j, beam[j].GetCost() }); + history_.push_back(beam); + } + + size_t size() const { + return history_.size(); + } + + NBestList NBest(size_t n) { + + } + + Result Top() const { + Sentence targetWords; + auto bestHypCoord = topHyps_.top(); + size_t start = bestHypCoord.i; + size_t j = bestHypCoord.j; + for(int i = start; i >= 0; i--) { + auto& bestHyp = history_[i][j]; + targetWords.push_back(bestHyp.GetWord()); + j = bestHyp.GetPrevStateIndex(); + } + + std::reverse(targetWords.begin(), targetWords.end()); + return Result(targetWords, history_[bestHypCoord.i][bestHypCoord.j]); + } + + private: + std::vector history_; + std::priority_queue topHyps_; + +}; class Search { - typedef std::vector Beam; - typedef std::vector History; - private: const Weights& model_; Encoder encoder_; Decoder decoder_; - const Vocab svcb_; - const Vocab tvcb_; mblas::Matrix State_, NextState_, BeamState_; mblas::Matrix Embeddings_, NextEmbeddings_; @@ -41,24 +88,14 @@ class Search { mblas::Matrix SourceContext_; public: - Search(const Weights& model, const Vocab& svcb, const Vocab tvcb) + Search(const Weights& model) : model_(model), encoder_(model_), - decoder_(model_), - svcb_(svcb), tvcb_(tvcb) + decoder_(model_) {} - std::string Decode(const std::string& source, size_t beamSize = 12) { - // this should happen somewhere else - std::vector sourceSplit; - Split(source, sourceSplit, " "); - std::vector sourceWords(sourceSplit.size()); - std::transform(sourceSplit.begin(), sourceSplit.end(), sourceWords.begin(), - [&](const std::string& w) { return svcb_[w]; }); - sourceWords.push_back(svcb_[EOL]); - + History Decode(const Sentence sourceWords, size_t beamSize = 12) { encoder_.GetContext(sourceWords, SourceContext_); - decoder_.EmptyState(State_, SourceContext_, 1); decoder_.EmptyEmbedding(Embeddings_, 1); @@ -72,13 +109,13 @@ class Search { Beam hyps; BestHyps(hyps, prevHyps, Probs_, beamSize); - history.push_back(hyps); + history.Add(hyps, history.size() + 1 == sourceWords.size() * 3); Beam survivors; std::vector beamWords; std::vector beamStateIds; for(auto& h : hyps) { - if(h.GetWord() != tvcb_[EOL]) { + if(h.GetWord() != EOS) { survivors.push_back(h); beamWords.push_back(h.GetWord()); beamStateIds.push_back(h.GetPrevStateIndex()); @@ -98,7 +135,7 @@ class Search { } while(history.size() < sourceWords.size() * 3); - return FindBest(history); + return history; } void BestHyps(Beam& bestHyps, const Beam& prevHyps, mblas::Matrix& Probs, const size_t beamSize) { @@ -131,44 +168,4 @@ class Search { bestHyps.emplace_back(wordIndex, hypIndex, cost); } } - - std::string FindBest(const History& history) { - std::vector targetWords; - - size_t best = 0; - size_t beamSize = 0; - float bestCost = std::numeric_limits::lowest(); - - for(auto b = history.rbegin(); b != history.rend(); b++) { - if(b->size() > beamSize) { - beamSize = b->size(); - for(size_t i = 0; i < beamSize; ++i) { - if(b == history.rbegin() || (*b)[i].GetWord() == tvcb_[EOL]) { - if((*b)[i].GetCost() > bestCost) { - best = i; - bestCost = (*b)[i].GetCost(); - targetWords.clear(); - } - } - } - } - - auto& bestHyp = (*b)[best]; - targetWords.push_back(bestHyp.GetWord()); - best = bestHyp.GetPrevStateIndex(); - } - - std::reverse(targetWords.begin(), targetWords.end()); - std::stringstream translation; - for(size_t i = 0; i < targetWords.size(); ++i) { - if(tvcb_[targetWords[i]] != EOL) { - if(i > 0) { - translation << " "; - } - translation << tvcb_[targetWords[i]]; - } - } - return translation.str(); - } - }; \ No newline at end of file diff --git a/src/rescorer/nbest.cpp b/src/rescorer/nbest.cpp index e7135fa6..e3185092 100644 --- a/src/rescorer/nbest.cpp +++ b/src/rescorer/nbest.cpp @@ -39,9 +39,7 @@ std::vector NBest::GetTokens(const size_t index) const { } std::vector NBest::GetEncodedTokens(const size_t index) const { - std::vector tokens; - Split(srcSentences_[index], tokens); - return srcVocab_->Encode(tokens, true); + return (*srcVocab_)(srcSentences_[index]); } void NBest::Parse_(const std::string& path) { @@ -76,10 +74,10 @@ inline std::vector< std::vector< std::string > > NBest::SplitBatch(std::vector>& batch) const { +inline Batch NBest::EncodeBatch(const std::vector& batch) const { Batch encodedBatch; for (auto& sentence: batch) { - encodedBatch.push_back(trgVocab_->Encode(sentence, true)); + encodedBatch.push_back((*trgVocab_)(sentence)); } return encodedBatch; } @@ -103,7 +101,7 @@ inline Batch NBest::MaskAndTransposeBatch(const Batch& batch) const { Batch NBest::ProcessBatch(std::vector& batch) const { - return MaskAndTransposeBatch(EncodeBatch(SplitBatch(batch))); + return MaskAndTransposeBatch(EncodeBatch(batch)); } std::vector NBest::GetBatches(const size_t index) const { diff --git a/src/rescorer/nbest.h b/src/rescorer/nbest.h index 7c91aeec..d5ad6eac 100644 --- a/src/rescorer/nbest.h +++ b/src/rescorer/nbest.h @@ -41,7 +41,7 @@ class NBest { std::vector> SplitBatch(std::vector& batch) const; void ParseInputFile(const std::string& path); - Batch EncodeBatch(const std::vector>& batch) const; + Batch EncodeBatch(const std::vector& batch) const; Batch MaskAndTransposeBatch(const Batch& batch) const;