distribute minibatches

This commit is contained in:
Hieu Hoang 2017-02-09 15:07:42 +00:00
parent 6ca1011070
commit 8ff1350fd1
5 changed files with 47 additions and 46 deletions

View File

@ -27,10 +27,12 @@ int main(int argc, char* argv[]) {
std::size_t lineNum = 0;
std::size_t taskCounter = 0;
size_t miniBatch = god.Get<size_t>("mini-batch");
size_t maxiBatch = god.Get<size_t>("maxi-batch");
//std::cerr << "mode=" << god.Get("mode") << std::endl;
if (god.Get<bool>("wipo")) {
miniBatch = 1;
maxiBatch = 1;
}
@ -52,25 +54,31 @@ int main(int argc, char* argv[]) {
ThreadPool pool(totalThreads, totalThreads);
LOG(info) << "Reading input";
std::shared_ptr<Sentences> sentences(new Sentences());
SentencesPtr maxiSentences(new Sentences());
while (std::getline(god.GetInputStream(), in)) {
sentences->push_back(SentencePtr(new Sentence(god, lineNum++, in)));
maxiSentences->push_back(SentencePtr(new Sentence(god, lineNum++, in)));
if (sentences->size() >= maxiBatch) {
pool.enqueue(
[&god,sentences,taskCounter]{ return TranslationTask(god, sentences, taskCounter); }
);
if (maxiSentences->size() >= maxiBatch) {
maxiSentences->SortByLength();
while (maxiSentences->size()) {
SentencesPtr miniSentences = maxiSentences->NextMiniBatch(miniBatch);
pool.enqueue(
[&god,miniSentences,taskCounter]{ return TranslationTask(god, miniSentences, taskCounter); }
);
}
sentences.reset(new Sentences());
maxiSentences.reset(new Sentences());
taskCounter++;
}
}
if (sentences->size()) {
maxiSentences->SortByLength();
while (maxiSentences->size()) {
SentencesPtr miniSentences = maxiSentences->NextMiniBatch(miniBatch);
pool.enqueue(
[&god,sentences,taskCounter]{ return TranslationTask(god, sentences, taskCounter); }
[&god,miniSentences,taskCounter]{ return TranslationTask(god, miniSentences, taskCounter); }
);
}
}

View File

@ -12,6 +12,7 @@ class Search {
public:
Search(const God &god);
virtual ~Search();
std::shared_ptr<Histories> Process(const God &god, const Sentences& sentences);
States NewStates() const;

View File

@ -82,5 +82,18 @@ void Sentences::SortByLength() {
std::sort(coll_.rbegin(), coll_.rend(), LengthOrderer());
}
SentencesPtr Sentences::NextMiniBatch(size_t batchsize)
{
SentencesPtr sentences(new Sentences());
size_t startInd = (batchsize > size()) ? 0 : size() - batchsize;
for (size_t i = startInd; i < size(); ++i) {
SentencePtr sentence = at(i);
sentences->push_back(sentence);
}
coll_.resize(startInd);
return sentences;
}
}

View File

@ -27,6 +27,9 @@ class Sentence {
using SentencePtr = std::shared_ptr<Sentence>;
//////////////////////////////////////////////////////////////////
class Sentences;
using SentencesPtr = std::shared_ptr<Sentences>;
class Sentences {
public:
@ -59,6 +62,8 @@ class Sentences {
void SortByLength();
SentencesPtr NextMiniBatch(size_t batchsize);
protected:
std::vector<SentencePtr> coll_;
size_t taskCounter_;
@ -68,5 +73,6 @@ class Sentences {
Sentences(const Sentences &) = delete;
};
}

View File

@ -12,45 +12,18 @@ void TranslationTask(const God &god, std::shared_ptr<Sentences> sentences, size_
Search &search = god.GetSearch();
try {
size_t miniBatch;
if (search.GetDeviceInfo().deviceType == CPUDevice) {
miniBatch = 1;
}
else {
miniBatch = god.Get<size_t>("mini-batch");
}
Histories allHistories;
sentences->SortByLength();
size_t bunchId = 0;
std::shared_ptr<Sentences> decodeSentences(new Sentences(taskCounter, bunchId++));
for (size_t i = 0; i < sentences->size(); ++i) {
decodeSentences->push_back(sentences->at(i));
if (decodeSentences->size() >= miniBatch) {
//cerr << "decodeSentences=" << decodeSentences->GetMaxLength() << endl;
assert(decodeSentences->size());
std::shared_ptr<Histories> histories = search.Process(god, *decodeSentences);
allHistories.Append(*histories.get());
decodeSentences.reset(new Sentences(taskCounter, bunchId++));
}
}
if (decodeSentences->size()) {
std::shared_ptr<Histories> histories = search.Process(god, *decodeSentences);
allHistories.Append(*histories.get());
}
allHistories.SortByLineNum();
std::stringstream strm;
Printer(god, allHistories, strm);
std::shared_ptr<Histories> histories = search.Process(god, *sentences);
OutputCollector &outputCollector = god.GetOutputCollector();
outputCollector.Write(taskCounter, strm.str());
for (size_t i = 0; i < histories->size(); ++i) {
const History &history = *histories->at(i);
size_t lineNum = history.GetLineNum();
std::stringstream strm;
Printer(god, history, strm);
outputCollector.Write(lineNum, strm.str());
}
}
#ifdef CUDA
catch(thrust::system_error &e)