diff --git a/src/common/decoder_main.cpp b/src/common/decoder_main.cpp index e9e1fd4d..ebdb8e97 100755 --- a/src/common/decoder_main.cpp +++ b/src/common/decoder_main.cpp @@ -27,10 +27,12 @@ int main(int argc, char* argv[]) { std::size_t lineNum = 0; std::size_t taskCounter = 0; + size_t miniBatch = god.Get<size_t>("mini-batch"); size_t maxiBatch = god.Get<size_t>("maxi-batch"); //std::cerr << "mode=" << god.Get("mode") << std::endl; if (god.Get<bool>("wipo")) { + miniBatch = 1; maxiBatch = 1; } @@ -52,25 +54,31 @@ int main(int argc, char* argv[]) { ThreadPool pool(totalThreads, totalThreads); LOG(info) << "Reading input"; - std::shared_ptr<Sentences> sentences(new Sentences()); + SentencesPtr maxiSentences(new Sentences()); while (std::getline(god.GetInputStream(), in)) { - sentences->push_back(SentencePtr(new Sentence(god, lineNum++, in))); + maxiSentences->push_back(SentencePtr(new Sentence(god, lineNum++, in))); - if (sentences->size() >= maxiBatch) { - pool.enqueue( - [&god,sentences,taskCounter]{ return TranslationTask(god, sentences, taskCounter); } - ); + if (maxiSentences->size() >= maxiBatch) { + maxiSentences->SortByLength(); + while (maxiSentences->size()) { + SentencesPtr miniSentences = maxiSentences->NextMiniBatch(miniBatch); + pool.enqueue( + [&god,miniSentences,taskCounter]{ return TranslationTask(god, miniSentences, taskCounter); } + ); + } - sentences.reset(new Sentences()); + maxiSentences.reset(new Sentences()); taskCounter++; } } - if (sentences->size()) { + maxiSentences->SortByLength(); + while (maxiSentences->size()) { + SentencesPtr miniSentences = maxiSentences->NextMiniBatch(miniBatch); pool.enqueue( - [&god,sentences,taskCounter]{ return TranslationTask(god, sentences, taskCounter); } + [&god,miniSentences,taskCounter]{ return TranslationTask(god, miniSentences, taskCounter); } ); } } diff --git a/src/common/search.h b/src/common/search.h index a93084e0..f372336e 100755 --- a/src/common/search.h +++ b/src/common/search.h @@ -12,6 +12,7 @@ class Search { public: 
Search(const God &god); virtual ~Search(); + std::shared_ptr<Histories> Process(const God &god, const Sentences& sentences); States NewStates() const; diff --git a/src/common/sentence.cpp b/src/common/sentence.cpp index b2c090d0..378b7bae 100755 --- a/src/common/sentence.cpp +++ b/src/common/sentence.cpp @@ -82,5 +82,18 @@ void Sentences::SortByLength() { std::sort(coll_.rbegin(), coll_.rend(), LengthOrderer()); } +SentencesPtr Sentences::NextMiniBatch(size_t batchsize) +{ + SentencesPtr sentences(new Sentences()); + size_t startInd = (batchsize > size()) ? 0 : size() - batchsize; + for (size_t i = startInd; i < size(); ++i) { + SentencePtr sentence = at(i); + sentences->push_back(sentence); + } + + coll_.resize(startInd); + return sentences; +} + } diff --git a/src/common/sentence.h b/src/common/sentence.h index b54d59eb..d23ede66 100755 --- a/src/common/sentence.h +++ b/src/common/sentence.h @@ -27,6 +27,9 @@ class Sentence { using SentencePtr = std::shared_ptr<Sentence>; +////////////////////////////////////////////////////////////////// +class Sentences; +using SentencesPtr = std::shared_ptr<Sentences>; class Sentences { public: @@ -59,6 +62,8 @@ class Sentences { void SortByLength(); + SentencesPtr NextMiniBatch(size_t batchsize); + protected: std::vector<SentencePtr> coll_; size_t taskCounter_; @@ -68,5 +73,6 @@ class Sentences { Sentences(const Sentences &) = delete; }; + } diff --git a/src/common/translation_task.cpp b/src/common/translation_task.cpp index 342dccc9..dc403ce8 100755 --- a/src/common/translation_task.cpp +++ b/src/common/translation_task.cpp @@ -12,45 +12,18 @@ void TranslationTask(const God &god, std::shared_ptr<Sentences> sentences, size_ Search &search = god.GetSearch(); try { - size_t miniBatch; - if (search.GetDeviceInfo().deviceType == CPUDevice) { - miniBatch = 1; - } - else { - miniBatch = god.Get<size_t>("mini-batch"); - } - - Histories allHistories; - sentences->SortByLength(); - - size_t bunchId = 0; - std::shared_ptr<Sentences> decodeSentences(new Sentences(taskCounter, bunchId++)); - - for (size_t i = 
0; i < sentences->size(); ++i) { - decodeSentences->push_back(sentences->at(i)); - - if (decodeSentences->size() >= miniBatch) { - //cerr << "decodeSentences=" << decodeSentences->GetMaxLength() << endl; - assert(decodeSentences->size()); - std::shared_ptr<Histories> histories = search.Process(god, *decodeSentences); - allHistories.Append(*histories.get()); - - decodeSentences.reset(new Sentences(taskCounter, bunchId++)); - } - } - - if (decodeSentences->size()) { - std::shared_ptr<Histories> histories = search.Process(god, *decodeSentences); - allHistories.Append(*histories.get()); - } - - allHistories.SortByLineNum(); - - std::stringstream strm; - Printer(god, allHistories, strm); + std::shared_ptr<Histories> histories = search.Process(god, *sentences); OutputCollector &outputCollector = god.GetOutputCollector(); - outputCollector.Write(taskCounter, strm.str()); + for (size_t i = 0; i < histories->size(); ++i) { + const History &history = *histories->at(i); + size_t lineNum = history.GetLineNum(); + + std::stringstream strm; + Printer(god, history, strm); + + outputCollector.Write(lineNum, strm.str()); + } } #ifdef CUDA catch(thrust::system_error &e)