fix server behaviour

This commit is contained in:
Marcin Junczys-Dowmunt 2022-02-07 08:09:54 -08:00
parent aafe8fb5ca
commit a365bb5ce9
3 changed files with 5 additions and 7 deletions

View File

@ -13,7 +13,7 @@ void TextIterator::increment() {
}
bool TextIterator::equal(TextIterator const& other) const {
return this->pos_ == other.pos_ || (this->tup_.empty() && other.tup_.empty());
return this->pos_ == other.pos_ || (!this->tup_.valid() && !other.tup_.valid());
}
const SentenceTuple& TextIterator::dereference() const {
@ -59,7 +59,7 @@ SentenceTuple TextInput::next() {
if(tup.size() == files_.size()) // check if each input file provided an example
return SentenceTuple(tup);
else if(tup.size() == 0) // if no file provided examples we are done
return SentenceTuple();
return SentenceTupleImpl(); // return an empty tuple if above test does not pass();
else // neither all nor none => we have at least on missing entry
ABORT("There are missing entries in the text tuples.");
}

View File

@ -37,12 +37,10 @@ private:
bool maxLengthCrop_{false};
public:
typedef SentenceTuple Sample;
TextInput(std::vector<std::string> inputs, std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options);
virtual ~TextInput() {}
Sample next() override;
SentenceTuple next() override;
void shuffle() override {}
void reset() override {}
@ -52,7 +50,7 @@ public:
// TODO: There are half dozen functions called toBatch(), which are very
// similar. Factor them.
batch_ptr toBatch(const std::vector<Sample>& batchVector) override {
batch_ptr toBatch(const std::vector<SentenceTuple>& batchVector) override {
size_t batchSize = batchVector.size();
std::vector<size_t> sentenceIds;

View File

@ -330,7 +330,7 @@ public:
? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
: std::vector<std::string>({input});
auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_);
data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_, nullptr, /*runAsync=*/false);
auto collector = New<StringCollector>(options_->get<bool>("quiet-translation", false));
auto printer = New<OutputPrinter>(options_, trgVocab_);