mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-11 09:54:22 +03:00
parent
011b7b4289
commit
9de8796f33
@ -12,7 +12,6 @@
|
||||
#include "common/filter.h"
|
||||
#include "common/processor/bpe.h"
|
||||
#include "common/utils.h"
|
||||
#include "common/search.h"
|
||||
|
||||
#include "scorer.h"
|
||||
#include "loader_factory.h"
|
||||
@ -233,16 +232,3 @@ void God::CleanUp() {
|
||||
loader.reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
Search &God::GetSearch(size_t taskCounter)
|
||||
{
|
||||
Search *obj;
|
||||
obj = search_.get();
|
||||
if (obj == NULL) {
|
||||
obj = new Search(*this, taskCounter);
|
||||
search_.reset(obj);
|
||||
}
|
||||
assert(obj);
|
||||
return *obj;
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
#include <boost/thread/tss.hpp>
|
||||
|
||||
#include "common/processor/processor.h"
|
||||
#include "common/config.h"
|
||||
@ -22,7 +21,6 @@ class Weights;
|
||||
class Vocab;
|
||||
class Filter;
|
||||
class InputFileStream;
|
||||
class Search;
|
||||
|
||||
class God {
|
||||
public:
|
||||
@ -66,8 +64,6 @@ class God {
|
||||
|
||||
void LoadWeights(const std::string& path);
|
||||
|
||||
Search &GetSearch(size_t taskCounter);
|
||||
|
||||
private:
|
||||
void LoadScorers();
|
||||
void LoadFiltering();
|
||||
@ -93,5 +89,4 @@ class God {
|
||||
std::unique_ptr<InputFileStream> inputStream_;
|
||||
OutputCollector outputCollector_;
|
||||
|
||||
mutable boost::thread_specific_ptr<Search> search_;
|
||||
};
|
||||
|
@ -5,7 +5,11 @@
|
||||
#include "printer.h"
|
||||
|
||||
void TranslationTask(God &god, std::shared_ptr<Sentences> sentences, size_t taskCounter, size_t maxBatchSize) {
|
||||
Search &search = god.GetSearch(taskCounter);
|
||||
thread_local std::unique_ptr<Search> search;
|
||||
if(!search) {
|
||||
LOG(info) << "Created Search for thread " << std::this_thread::get_id();
|
||||
search.reset(new Search(god, taskCounter));
|
||||
}
|
||||
|
||||
try {
|
||||
Histories allHistories;
|
||||
@ -19,7 +23,7 @@ void TranslationTask(God &god, std::shared_ptr<Sentences> sentences, size_t task
|
||||
|
||||
if (decodeSentences->size() >= maxBatchSize) {
|
||||
assert(decodeSentences->size());
|
||||
std::shared_ptr<Histories> histories = search.Decode(god, *decodeSentences);
|
||||
std::shared_ptr<Histories> histories = search->Decode(god, *decodeSentences);
|
||||
allHistories.Append(*histories.get());
|
||||
|
||||
decodeSentences.reset(new Sentences(taskCounter, bunchId++));
|
||||
@ -27,7 +31,7 @@ void TranslationTask(God &god, std::shared_ptr<Sentences> sentences, size_t task
|
||||
}
|
||||
|
||||
if (decodeSentences->size()) {
|
||||
std::shared_ptr<Histories> histories = search.Decode(god, *decodeSentences);
|
||||
std::shared_ptr<Histories> histories = search->Decode(god, *decodeSentences);
|
||||
allHistories.Append(*histories.get());
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user