Support tab-separated inputs in marian-server (#649)

* Enable multi-source input in marian-server
* Add function converting multi-line tab-separated textual input

Co-authored-by: Tomasz Dwojak <t.dwojak@amu.edu.pl>
This commit is contained in:
Roman Grundkiewicz 2020-07-26 17:39:12 +01:00 committed by GitHub
parent 9d4cc7b13d
commit 41de4f3d30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 45 additions and 4 deletions

View File

@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
- Decoding multi-source models in marian-server with --tsv
- GitHub workflows on Ubuntu, Windows, and MacOS
- LSH indexing to replace short list
- ONNX support for transformer models

View File

@ -23,8 +23,11 @@ const SentenceTuple& TextIterator::dereference() const {
TextInput::TextInput(std::vector<std::string> inputs,
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> options)
: DatasetBase(inputs, options), vocabs_(vocabs) {
// note: inputs are automatically stored in the inherited variable named paths_, but these are
: DatasetBase(inputs, options),
vocabs_(vocabs),
maxLength_(options_->get<size_t>("max-length")),
maxLengthCrop_(options_->get<bool>("max-length-crop")) {
// Note: inputs are automatically stored in the inherited variable named paths_, but these are
// texts not paths!
for(const auto& text : paths_)
files_.emplace_back(new std::istringstream(text));
@ -42,6 +45,10 @@ SentenceTuple TextInput::next() {
std::string line;
if(io::getline(*files_[i], line)) {
Words words = vocabs_[i]->encode(line, /*addEOS =*/ true, /*inference =*/ inference_);
if(this->maxLengthCrop_ && words.size() > this->maxLength_) {
words.resize(maxLength_);
words.back() = vocabs_.back()->getEosId(); // note: this will not work with class-labels
}
if(words.empty())
words.push_back(Word::ZERO); // @TODO: What is this for? @BUGBUG: addEOS=true, so this can never happen, right?
tup.push_back(words);

View File

@ -33,6 +33,9 @@ private:
size_t pos_{0};
size_t maxLength_{0};
bool maxLengthCrop_{false};
public:
typedef SentenceTuple Sample;

View File

@ -1,5 +1,7 @@
#pragma once
#include <string>
#include "data/batch_generator.h"
#include "data/corpus.h"
#include "data/shortlist.h"
@ -245,7 +247,11 @@ public:
}
std::string run(const std::string& input) override {
auto corpus_ = New<data::TextInput>(std::vector<std::string>({input}), srcVocabs_, options_);
// split tab-separated input into fields if necessary
auto inputs = options_->get<bool>("tsv", false)
? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
: std::vector<std::string>({input});
auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_);
auto collector = New<StringCollector>();
@ -258,7 +264,6 @@ public:
ThreadPool threadPool_(numDevices_, numDevices_);
for(auto batch : batchGenerator) {
auto task = [=](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local std::vector<Ptr<Scorer>> scorers;
@ -287,5 +292,30 @@ public:
auto translations = collector->collect(options_->get<bool>("n-best"));
return utils::join(translations, "\n");
}
private:
// Converts a multi-line input with tab-separated source(s) and target sentences into separate lists
// of sentences from source(s) and target sides, e.g.
// "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"]
std::vector<std::string> convertTsvToLists(const std::string& inputText, size_t numFields) {
std::vector<std::string> outputFields(numFields);
std::string line;
std::vector<std::string> lineFields(numFields);
std::istringstream inputStream(inputText);
bool first = true;
while(std::getline(inputStream, line)) {
utils::splitTsv(line, lineFields, numFields);
for(size_t i = 0; i < numFields; ++i) {
if(!first)
outputFields[i] += "\n"; // join sentences with a new line sign
outputFields[i] += lineFields[i];
}
if(first)
first = false;
}
return outputFields;
}
};
} // namespace marian