mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Support tab-separated inputs in marian-server (#649)
* Enable multi-source input in marian-server * Add function converting multi-line tab-separated textual input Co-authored-by: Tomasz Dwojak <t.dwojak@amu.edu.pl>
This commit is contained in:
parent
9d4cc7b13d
commit
41de4f3d30
@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Decoding multi-source models in marian-server with --tsv
|
||||
- GitHub workflows on Ubuntu, Windows, and MacOS
|
||||
- LSH indexing to replace short list
|
||||
- ONNX support for transformer models
|
||||
|
@ -23,8 +23,11 @@ const SentenceTuple& TextIterator::dereference() const {
|
||||
TextInput::TextInput(std::vector<std::string> inputs,
|
||||
std::vector<Ptr<Vocab>> vocabs,
|
||||
Ptr<Options> options)
|
||||
: DatasetBase(inputs, options), vocabs_(vocabs) {
|
||||
// note: inputs are automatically stored in the inherited variable named paths_, but these are
|
||||
: DatasetBase(inputs, options),
|
||||
vocabs_(vocabs),
|
||||
maxLength_(options_->get<size_t>("max-length")),
|
||||
maxLengthCrop_(options_->get<bool>("max-length-crop")) {
|
||||
// Note: inputs are automatically stored in the inherited variable named paths_, but these are
|
||||
// texts not paths!
|
||||
for(const auto& text : paths_)
|
||||
files_.emplace_back(new std::istringstream(text));
|
||||
@ -42,6 +45,10 @@ SentenceTuple TextInput::next() {
|
||||
std::string line;
|
||||
if(io::getline(*files_[i], line)) {
|
||||
Words words = vocabs_[i]->encode(line, /*addEOS =*/ true, /*inference =*/ inference_);
|
||||
if(this->maxLengthCrop_ && words.size() > this->maxLength_) {
|
||||
words.resize(maxLength_);
|
||||
words.back() = vocabs_.back()->getEosId(); // note: this will not work with class-labels
|
||||
}
|
||||
if(words.empty())
|
||||
words.push_back(Word::ZERO); // @TODO: What is this for? @BUGBUG: addEOS=true, so this can never happen, right?
|
||||
tup.push_back(words);
|
||||
|
@ -33,6 +33,9 @@ private:
|
||||
|
||||
size_t pos_{0};
|
||||
|
||||
size_t maxLength_{0};
|
||||
bool maxLengthCrop_{false};
|
||||
|
||||
public:
|
||||
typedef SentenceTuple Sample;
|
||||
|
||||
|
@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "data/batch_generator.h"
|
||||
#include "data/corpus.h"
|
||||
#include "data/shortlist.h"
|
||||
@ -245,7 +247,11 @@ public:
|
||||
}
|
||||
|
||||
std::string run(const std::string& input) override {
|
||||
auto corpus_ = New<data::TextInput>(std::vector<std::string>({input}), srcVocabs_, options_);
|
||||
// split tab-separated input into fields if necessary
|
||||
auto inputs = options_->get<bool>("tsv", false)
|
||||
? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
|
||||
: std::vector<std::string>({input});
|
||||
auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
|
||||
data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_);
|
||||
|
||||
auto collector = New<StringCollector>();
|
||||
@ -258,7 +264,6 @@ public:
|
||||
ThreadPool threadPool_(numDevices_, numDevices_);
|
||||
|
||||
for(auto batch : batchGenerator) {
|
||||
|
||||
auto task = [=](size_t id) {
|
||||
thread_local Ptr<ExpressionGraph> graph;
|
||||
thread_local std::vector<Ptr<Scorer>> scorers;
|
||||
@ -287,5 +292,30 @@ public:
|
||||
auto translations = collector->collect(options_->get<bool>("n-best"));
|
||||
return utils::join(translations, "\n");
|
||||
}
|
||||
|
||||
private:
|
||||
// Converts a multi-line input with tab-separated source(s) and target sentences into separate lists
|
||||
// of sentences from source(s) and target sides, e.g.
|
||||
// "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"]
|
||||
std::vector<std::string> convertTsvToLists(const std::string& inputText, size_t numFields) {
|
||||
std::vector<std::string> outputFields(numFields);
|
||||
|
||||
std::string line;
|
||||
std::vector<std::string> lineFields(numFields);
|
||||
std::istringstream inputStream(inputText);
|
||||
bool first = true;
|
||||
while(std::getline(inputStream, line)) {
|
||||
utils::splitTsv(line, lineFields, numFields);
|
||||
for(size_t i = 0; i < numFields; ++i) {
|
||||
if(!first)
|
||||
outputFields[i] += "\n"; // join sentences with a new line sign
|
||||
outputFields[i] += lineFields[i];
|
||||
}
|
||||
if(first)
|
||||
first = false;
|
||||
}
|
||||
|
||||
return outputFields;
|
||||
}
|
||||
};
|
||||
} // namespace marian
|
||||
|
Loading…
Reference in New Issue
Block a user