mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
distribute minibatches
This commit is contained in:
parent
6ca1011070
commit
8ff1350fd1
@ -27,10 +27,12 @@ int main(int argc, char* argv[]) {
|
||||
std::size_t lineNum = 0;
|
||||
std::size_t taskCounter = 0;
|
||||
|
||||
size_t miniBatch = god.Get<size_t>("mini-batch");
|
||||
size_t maxiBatch = god.Get<size_t>("maxi-batch");
|
||||
//std::cerr << "mode=" << god.Get("mode") << std::endl;
|
||||
|
||||
if (god.Get<bool>("wipo")) {
|
||||
miniBatch = 1;
|
||||
maxiBatch = 1;
|
||||
}
|
||||
|
||||
@ -52,25 +54,31 @@ int main(int argc, char* argv[]) {
|
||||
ThreadPool pool(totalThreads, totalThreads);
|
||||
LOG(info) << "Reading input";
|
||||
|
||||
std::shared_ptr<Sentences> sentences(new Sentences());
|
||||
SentencesPtr maxiSentences(new Sentences());
|
||||
|
||||
while (std::getline(god.GetInputStream(), in)) {
|
||||
sentences->push_back(SentencePtr(new Sentence(god, lineNum++, in)));
|
||||
maxiSentences->push_back(SentencePtr(new Sentence(god, lineNum++, in)));
|
||||
|
||||
if (sentences->size() >= maxiBatch) {
|
||||
pool.enqueue(
|
||||
[&god,sentences,taskCounter]{ return TranslationTask(god, sentences, taskCounter); }
|
||||
);
|
||||
if (maxiSentences->size() >= maxiBatch) {
|
||||
maxiSentences->SortByLength();
|
||||
while (maxiSentences->size()) {
|
||||
SentencesPtr miniSentences = maxiSentences->NextMiniBatch(miniBatch);
|
||||
pool.enqueue(
|
||||
[&god,miniSentences,taskCounter]{ return TranslationTask(god, miniSentences, taskCounter); }
|
||||
);
|
||||
}
|
||||
|
||||
sentences.reset(new Sentences());
|
||||
maxiSentences.reset(new Sentences());
|
||||
taskCounter++;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (sentences->size()) {
|
||||
maxiSentences->SortByLength();
|
||||
while (maxiSentences->size()) {
|
||||
SentencesPtr miniSentences = maxiSentences->NextMiniBatch(miniBatch);
|
||||
pool.enqueue(
|
||||
[&god,sentences,taskCounter]{ return TranslationTask(god, sentences, taskCounter); }
|
||||
[&god,miniSentences,taskCounter]{ return TranslationTask(god, miniSentences, taskCounter); }
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ class Search {
|
||||
public:
|
||||
Search(const God &god);
|
||||
virtual ~Search();
|
||||
|
||||
std::shared_ptr<Histories> Process(const God &god, const Sentences& sentences);
|
||||
|
||||
States NewStates() const;
|
||||
|
@ -82,5 +82,18 @@ void Sentences::SortByLength() {
|
||||
std::sort(coll_.rbegin(), coll_.rend(), LengthOrderer());
|
||||
}
|
||||
|
||||
/**
 * Splits the trailing sentences of this collection off into a new batch.
 *
 * Up to `batchsize` sentences are taken from the end of the collection, in
 * their current order, appended to a freshly created Sentences object and
 * removed from this one. If fewer than `batchsize` sentences are stored,
 * all of them are taken and this collection becomes empty.
 *
 * @param batchsize Maximum number of sentences to move into the new batch.
 *                  A value of 0 is treated as 1 (see below).
 * @return A new collection holding the removed sentences; it is empty only
 *         when this collection was already empty.
 */
SentencesPtr Sentences::NextMiniBatch(size_t batchsize)
{
  SentencesPtr sentences(new Sentences());

  // A batch size of 0 would remove nothing, so callers draining the
  // collection with "while (size()) NextMiniBatch(n)" would spin forever.
  // Guarantee progress by always taking at least one sentence.
  if (batchsize == 0) {
    batchsize = 1;
  }

  // Index of the first sentence that belongs to the new batch.
  const size_t startInd = (batchsize > size()) ? 0 : size() - batchsize;

  for (size_t i = startInd; i < size(); ++i) {
    sentences->push_back(at(i));
  }

  // Drop the sentences that were handed over to the new batch.
  coll_.resize(startInd);

  return sentences;
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -27,6 +27,9 @@ class Sentence {
|
||||
|
||||
using SentencePtr = std::shared_ptr<Sentence>;
|
||||
|
||||
//////////////////////////////////////////////////////////////////
|
||||
class Sentences;
|
||||
using SentencesPtr = std::shared_ptr<Sentences>;
|
||||
|
||||
class Sentences {
|
||||
public:
|
||||
@ -59,6 +62,8 @@ class Sentences {
|
||||
|
||||
void SortByLength();
|
||||
|
||||
SentencesPtr NextMiniBatch(size_t batchsize);
|
||||
|
||||
protected:
|
||||
std::vector<SentencePtr> coll_;
|
||||
size_t taskCounter_;
|
||||
@ -68,5 +73,6 @@ class Sentences {
|
||||
Sentences(const Sentences &) = delete;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
@ -12,45 +12,18 @@ void TranslationTask(const God &god, std::shared_ptr<Sentences> sentences, size_
|
||||
Search &search = god.GetSearch();
|
||||
|
||||
try {
|
||||
size_t miniBatch;
|
||||
if (search.GetDeviceInfo().deviceType == CPUDevice) {
|
||||
miniBatch = 1;
|
||||
}
|
||||
else {
|
||||
miniBatch = god.Get<size_t>("mini-batch");
|
||||
}
|
||||
|
||||
Histories allHistories;
|
||||
sentences->SortByLength();
|
||||
|
||||
size_t bunchId = 0;
|
||||
std::shared_ptr<Sentences> decodeSentences(new Sentences(taskCounter, bunchId++));
|
||||
|
||||
for (size_t i = 0; i < sentences->size(); ++i) {
|
||||
decodeSentences->push_back(sentences->at(i));
|
||||
|
||||
if (decodeSentences->size() >= miniBatch) {
|
||||
//cerr << "decodeSentences=" << decodeSentences->GetMaxLength() << endl;
|
||||
assert(decodeSentences->size());
|
||||
std::shared_ptr<Histories> histories = search.Process(god, *decodeSentences);
|
||||
allHistories.Append(*histories.get());
|
||||
|
||||
decodeSentences.reset(new Sentences(taskCounter, bunchId++));
|
||||
}
|
||||
}
|
||||
|
||||
if (decodeSentences->size()) {
|
||||
std::shared_ptr<Histories> histories = search.Process(god, *decodeSentences);
|
||||
allHistories.Append(*histories.get());
|
||||
}
|
||||
|
||||
allHistories.SortByLineNum();
|
||||
|
||||
std::stringstream strm;
|
||||
Printer(god, allHistories, strm);
|
||||
std::shared_ptr<Histories> histories = search.Process(god, *sentences);
|
||||
|
||||
OutputCollector &outputCollector = god.GetOutputCollector();
|
||||
outputCollector.Write(taskCounter, strm.str());
|
||||
for (size_t i = 0; i < histories->size(); ++i) {
|
||||
const History &history = *histories->at(i);
|
||||
size_t lineNum = history.GetLineNum();
|
||||
|
||||
std::stringstream strm;
|
||||
Printer(god, history, strm);
|
||||
|
||||
outputCollector.Write(lineNum, strm.str());
|
||||
}
|
||||
}
|
||||
#ifdef CUDA
|
||||
catch(thrust::system_error &e)
|
||||
|
Loading…
Reference in New Issue
Block a user