mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-26 05:14:36 +03:00
168 lines
4.1 KiB
C++
168 lines
4.1 KiB
C++
#ifndef LM_FILTER_THREAD_H
|
|
#define LM_FILTER_THREAD_H
|
|
|
|
#include "util/thread_pool.hh"
|
|
|
|
#include <boost/utility/in_place_factory.hpp>
|
|
|
|
#include <deque>
|
|
#include <stack>
|
|
|
|
namespace lm {
|
|
|
|
template <class OutputBuffer> class ThreadBatch {
|
|
public:
|
|
ThreadBatch() {}
|
|
|
|
void Reserve(size_t size) {
|
|
input_.Reserve(size);
|
|
output_.Reserve(size);
|
|
}
|
|
|
|
// File reading thread.
|
|
InputBuffer &Fill(uint64_t sequence) {
|
|
sequence_ = sequence;
|
|
// Why wait until now to clear instead of after output? free in the same
|
|
// thread as allocated.
|
|
input_.Clear();
|
|
return input_;
|
|
}
|
|
|
|
// Filter worker thread.
|
|
template <class Filter> void CallFilter(Filter &filter) {
|
|
input_.CallFilter(filter, output_);
|
|
}
|
|
|
|
uint64_t Sequence() const { return sequence_; }
|
|
|
|
// File writing thread.
|
|
template <class RealOutput> void Flush(RealOutput &output) {
|
|
output_.Flush(output);
|
|
}
|
|
|
|
private:
|
|
InputBuffer input_;
|
|
OutputBuffer output_;
|
|
|
|
uint64_t sequence_;
|
|
};
|
|
|
|
template <class Batch, class Filter> class FilterWorker {
|
|
public:
|
|
typedef Batch *Request;
|
|
|
|
FilterWorker(const Filter &filter, util::PCQueue<Request> &done) : filter_(filter), done_(done) {}
|
|
|
|
void operator()(Request request) {
|
|
request->CallFilter(filter_);
|
|
done_.Produce(request);
|
|
}
|
|
|
|
private:
|
|
Filter filter_;
|
|
|
|
util::PCQueue<Request> &done_;
|
|
};
|
|
|
|
// There should only be one OutputWorker.
|
|
template <class Batch, class Output> class OutputWorker {
|
|
public:
|
|
typedef Batch *Request;
|
|
|
|
OutputWorker(Output &output, util::PCQueue<Request> &done) : output_(output), done_(done), base_sequence_(0) {}
|
|
|
|
void operator()(Request request) {
|
|
assert(request->Sequence() >= base_sequence_);
|
|
// Assemble the output in order.
|
|
uint64_t pos = request->Sequence() - base_sequence_;
|
|
if (pos >= ordering_.size()) {
|
|
ordering_.resize(pos + 1, NULL);
|
|
}
|
|
ordering_[pos] = request;
|
|
while (!ordering_.empty() && ordering_.front()) {
|
|
ordering_.front()->Flush(output_);
|
|
done_.Produce(ordering_.front());
|
|
ordering_.pop_front();
|
|
++base_sequence_;
|
|
}
|
|
}
|
|
|
|
private:
|
|
Output &output_;
|
|
|
|
util::PCQueue<Request> &done_;
|
|
|
|
std::deque<Request> ordering_;
|
|
|
|
uint64_t base_sequence_;
|
|
};
|
|
|
|
template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable {
|
|
private:
|
|
typedef ThreadBatch<OutputBuffer> Batch;
|
|
|
|
public:
|
|
Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output)
|
|
: batch_size_(batch_size), queue_size_(queue),
|
|
batches_(queue),
|
|
to_read_(queue),
|
|
output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL),
|
|
filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL),
|
|
sequence_(0) {
|
|
for (size_t i = 0; i < queue; ++i) {
|
|
batches_[i].Reserve(batch_size);
|
|
local_read_.push(&batches_[i]);
|
|
}
|
|
NewInput();
|
|
}
|
|
|
|
void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) {
|
|
input_->AddNGram(ngram, line, output);
|
|
if (input_->Size() == batch_size_) {
|
|
FlushInput();
|
|
NewInput();
|
|
}
|
|
}
|
|
|
|
void Flush() {
|
|
FlushInput();
|
|
while (local_read_.size() < queue_size_) {
|
|
MoveRead();
|
|
}
|
|
NewInput();
|
|
}
|
|
|
|
private:
|
|
void FlushInput() {
|
|
if (input_->Empty()) return;
|
|
filter_.Produce(local_read_.top());
|
|
local_read_.pop();
|
|
if (local_read_.empty()) MoveRead();
|
|
}
|
|
|
|
void NewInput() {
|
|
input_ = &local_read_.top()->Fill(sequence_++);
|
|
}
|
|
|
|
void MoveRead() {
|
|
local_read_.push(to_read_.Consume());
|
|
}
|
|
|
|
const size_t batch_size_;
|
|
const size_t queue_size_;
|
|
|
|
std::vector<Batch> batches_;
|
|
|
|
util::PCQueue<Batch*> to_read_;
|
|
std::stack<Batch*> local_read_;
|
|
util::ThreadPool<OutputWorker<Batch, RealOutput> > output_;
|
|
util::ThreadPool<FilterWorker<Batch, Filter> > filter_;
|
|
|
|
uint64_t sequence_;
|
|
InputBuffer *input_;
|
|
};
|
|
|
|
} // namespace lm
|
|
|
|
#endif // LM_FILTER_THREAD_H
|