From a365bb5ce99135eab29ffe378e0c6c9fb9bf0c1b Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Mon, 7 Feb 2022 08:09:54 -0800 Subject: [PATCH] fix server behaviour --- src/data/text_input.cpp | 4 ++-- src/data/text_input.h | 6 ++---- src/translator/translator.h | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/data/text_input.cpp b/src/data/text_input.cpp index b1f4cdd4..196cf421 100644 --- a/src/data/text_input.cpp +++ b/src/data/text_input.cpp @@ -13,7 +13,7 @@ void TextIterator::increment() { } bool TextIterator::equal(TextIterator const& other) const { - return this->pos_ == other.pos_ || (this->tup_.empty() && other.tup_.empty()); + return this->pos_ == other.pos_ || (!this->tup_.valid() && !other.tup_.valid()); } const SentenceTuple& TextIterator::dereference() const { @@ -59,7 +59,7 @@ SentenceTuple TextInput::next() { if(tup.size() == files_.size()) // check if each input file provided an example return SentenceTuple(tup); else if(tup.size() == 0) // if no file provided examples we are done - return SentenceTuple(); + return SentenceTupleImpl(); // return an empty tuple if above test does not pass(); else // neither all nor none => we have at least on missing entry ABORT("There are missing entries in the text tuples."); } diff --git a/src/data/text_input.h b/src/data/text_input.h index b08a4fdc..98d991bc 100644 --- a/src/data/text_input.h +++ b/src/data/text_input.h @@ -37,12 +37,10 @@ private: bool maxLengthCrop_{false}; public: - typedef SentenceTuple Sample; - TextInput(std::vector inputs, std::vector> vocabs, Ptr options); virtual ~TextInput() {} - Sample next() override; + SentenceTuple next() override; void shuffle() override {} void reset() override {} @@ -52,7 +50,7 @@ public: // TODO: There are half dozen functions called toBatch(), which are very // similar. Factor them. - batch_ptr toBatch(const std::vector& batchVector) override { + batch_ptr toBatch(const std::vector& batchVector) override { size_t batchSize = batchVector.size(); std::vector sentenceIds; diff --git a/src/translator/translator.h b/src/translator/translator.h index 0621fc8c..75b5070b 100644 --- a/src/translator/translator.h +++ b/src/translator/translator.h @@ -330,7 +330,7 @@ public: ? convertTsvToLists(input, options_->get("tsv-fields", 1)) : std::vector({input}); auto corpus_ = New(inputs, srcVocabs_, options_); - data::BatchGenerator batchGenerator(corpus_, options_); + data::BatchGenerator batchGenerator(corpus_, options_, nullptr, /*runAsync=*/false); auto collector = New(options_->get("quiet-translation", false)); auto printer = New(options_, trgVocab_);