fix server behaviour

This commit is contained in:
Marcin Junczys-Dowmunt 2022-02-07 08:09:54 -08:00
parent aafe8fb5ca
commit a365bb5ce9
3 changed files with 5 additions and 7 deletions

View File

@@ -13,7 +13,7 @@ void TextIterator::increment() {
} }
bool TextIterator::equal(TextIterator const& other) const { bool TextIterator::equal(TextIterator const& other) const {
return this->pos_ == other.pos_ || (this->tup_.empty() && other.tup_.empty()); return this->pos_ == other.pos_ || (!this->tup_.valid() && !other.tup_.valid());
} }
const SentenceTuple& TextIterator::dereference() const { const SentenceTuple& TextIterator::dereference() const {
@@ -59,7 +59,7 @@ SentenceTuple TextInput::next() {
if(tup.size() == files_.size()) // check if each input file provided an example if(tup.size() == files_.size()) // check if each input file provided an example
return SentenceTuple(tup); return SentenceTuple(tup);
else if(tup.size() == 0) // if no file provided examples we are done else if(tup.size() == 0) // if no file provided examples we are done
return SentenceTuple(); return SentenceTupleImpl(); // return an empty tuple if above test does not pass();
else // neither all nor none => we have at least on missing entry else // neither all nor none => we have at least on missing entry
ABORT("There are missing entries in the text tuples."); ABORT("There are missing entries in the text tuples.");
} }

View File

@@ -37,12 +37,10 @@ private:
bool maxLengthCrop_{false}; bool maxLengthCrop_{false};
public: public:
typedef SentenceTuple Sample;
TextInput(std::vector<std::string> inputs, std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options); TextInput(std::vector<std::string> inputs, std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options);
virtual ~TextInput() {} virtual ~TextInput() {}
Sample next() override; SentenceTuple next() override;
void shuffle() override {} void shuffle() override {}
void reset() override {} void reset() override {}
@@ -52,7 +50,7 @@ public:
// TODO: There are half dozen functions called toBatch(), which are very // TODO: There are half dozen functions called toBatch(), which are very
// similar. Factor them. // similar. Factor them.
batch_ptr toBatch(const std::vector<Sample>& batchVector) override { batch_ptr toBatch(const std::vector<SentenceTuple>& batchVector) override {
size_t batchSize = batchVector.size(); size_t batchSize = batchVector.size();
std::vector<size_t> sentenceIds; std::vector<size_t> sentenceIds;

View File

@@ -330,7 +330,7 @@ public:
? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1)) ? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
: std::vector<std::string>({input}); : std::vector<std::string>({input});
auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_); auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_); data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_, nullptr, /*runAsync=*/false);
auto collector = New<StringCollector>(options_->get<bool>("quiet-translation", false)); auto collector = New<StringCollector>(options_->get<bool>("quiet-translation", false));
auto printer = New<OutputPrinter>(options_, trgVocab_); auto printer = New<OutputPrinter>(options_, trgVocab_);