From c3fb60cbcd4f99ecd51adee835a071af41d77b9e Mon Sep 17 00:00:00 2001 From: Martin Junczys-Dowmunt Date: Wed, 24 Jun 2020 01:54:27 +0000 Subject: [PATCH] Merged PR 13476: Add LASER reimplementation and code for embeddings sentences This reimplements the LASER encoder from: ``` Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond Mikel Artetxe, Holger Schwenk https://arxiv.org/abs/1812.10464 ``` and adds functionality to embed sentences with any Marian encoder, also different from LASER. Some early attempts to train a transformer model with Encoder-Decoder bottle-neck. This is quite early code, so some code-duplication is to be expected. Nevertheless, it's functional and I would like to have it in master as we will slowly put that into production in various places. I will make the code "nicer" as we go along. --- CMakeLists.txt | 7 + scripts/laser/laser2marian.py | 85 +++++++++++ src/CMakeLists.txt | 1 + src/command/marian_embedder.cpp | 14 ++ src/command/marian_main.cpp | 7 +- src/common/config_parser.cpp | 44 ++++++ src/common/config_parser.h | 3 +- src/common/config_validator.cpp | 4 + src/embedder/embedder.h | 171 +++++++++++++++++++++ src/embedder/vector_collector.cpp | 71 +++++++++ src/embedder/vector_collector.h | 32 ++++ src/layers/loss.h | 2 +- src/models/encoder_decoder.cpp | 1 + src/models/encoder_pooler.h | 217 +++++++++++++++++++++++++++ src/models/laser.h | 71 +++++++++ src/models/model_base.h | 2 +- src/models/model_factory.cpp | 62 +++++++- src/models/model_factory.h | 31 ++++ src/models/pooler.h | 139 +++++++++++++++++ src/models/transformer.h | 43 +++++- src/rnn/cells.h | 2 +- src/tensors/cpu/tensor_operators.cpp | 99 ++++++++---- src/training/graph_group.cpp | 2 +- 23 files changed, 1066 insertions(+), 44 deletions(-) create mode 100644 scripts/laser/laser2marian.py create mode 100644 src/command/marian_embedder.cpp create mode 100644 src/embedder/embedder.h create mode 100644 src/embedder/vector_collector.cpp create mode 100644 src/embedder/vector_collector.h create mode 100644 src/models/encoder_pooler.h create mode 100644 src/models/laser.h create mode 100644 src/models/pooler.h diff --git a/CMakeLists.txt b/CMakeLists.txt index d6c36a8e..0e1a319f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -187,6 +187,13 @@ else(MSVC) set(CMAKE_C_FLAGS_PROFUSE "${CMAKE_C_FLAGS_RELEASE} -fprofile-use -fprofile-correction") endif(MSVC) +# with gcc 7.0 and above we need to mark fallthrough in switch case statements +# that can be done in comments for backcompat, but CCACHE removes comments. +# -C makes gcc keep comments. +if(USE_CCACHE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -C") +endif() + ############################################################################### # Downloading SentencePiece if requested and set to compile with it. # Requires all the dependencies imposed by SentencePiece diff --git a/scripts/laser/laser2marian.py b/scripts/laser/laser2marian.py new file mode 100644 index 00000000..e5a65d90 --- /dev/null +++ b/scripts/laser/laser2marian.py @@ -0,0 +1,85 @@ +import numpy as np +import sys +import yaml +import argparse + +import torch + +parser = argparse.ArgumentParser(description='Convert LASER model to Marian weight file.') +parser.add_argument('--laser', help='Path to LASER PyTorch model', required=True) +parser.add_argument('--marian', help='Output path for Marian weight file', required=True) +args = parser.parse_args() + +laser = torch.load(args.laser) + +config = dict() +config["type"] = "laser" +config["input-types"] = ["sequence"] +config["dim-vocabs"] = [laser["params"]["num_embeddings"]] + +config["version"] = "laser2marian.py conversion" + +config["enc-depth"] = laser["params"]["num_layers"] +config["enc-cell"] = "lstm" +config["dim-emb"] = laser["params"]["embed_dim"] +config["dim-rnn"] = laser["params"]["hidden_size"] + +yaml.dump(laser["dictionary"], open(args.marian + ".vocab.yml", "w")) + +marianModel = dict() + +def transposeOrder(mat): + matT = np.transpose(mat) # just a view with changed row order + return matT.flatten(order="C").reshape(matT.shape) # force row order change and reshape + +def convert(pd, srcs, trg, transpose=True, bias=False, lstm=False): + num = pd[srcs[0]].detach().numpy() + for i in range(1, len(srcs)): + num += pd[srcs[i]].detach().numpy() + + out = num + if bias: + num = np.atleast_2d(num) + else: + if transpose: + num = transposeOrder(num) # transpose with row order change + + if lstm: # different order in pytorch than marian + stateDim = int(num.shape[-1] / 4) + i = np.copy(num[:, 0*stateDim:1*stateDim]) + f = np.copy(num[:, 1*stateDim:2*stateDim]) + num[:, 0*stateDim:1*stateDim] = f + num[:, 1*stateDim:2*stateDim] = i + + marianModel[trg] = num + +for k in laser: + print(k) + +for k in laser["model"]: + print(k, laser["model"][k].shape) + +convert(laser["model"], ["embed_tokens.weight"], "encoder_Wemb", transpose=False) +for i in range(laser["params"]["num_layers"]): + convert(laser["model"], [f"lstm.weight_ih_l{i}"], f"encoder_lstm_l{i}_W", lstm=True) + convert(laser["model"], [f"lstm.weight_hh_l{i}"], f"encoder_lstm_l{i}_U", lstm=True) + convert(laser["model"], [f"lstm.bias_ih_l{i}", f"lstm.bias_hh_l{i}"], f"encoder_lstm_l{i}_b", bias=True, lstm=True) # needs to be summed! + + convert(laser["model"], [f"lstm.weight_ih_l{i}_reverse"], f"encoder_lstm_l{i}_reverse_W", lstm=True) + convert(laser["model"], [f"lstm.weight_hh_l{i}_reverse"], f"encoder_lstm_l{i}_reverse_U", lstm=True) + convert(laser["model"], [f"lstm.bias_ih_l{i}_reverse", f"lstm.bias_hh_l{i}_reverse"], f"encoder_lstm_l{i}_reverse_b", bias=True, lstm=True) # needs to be summed! + +for m in marianModel: + print(m, marianModel[m].shape) + +configYamlStr = yaml.dump(config, default_flow_style=False) +desc = list(configYamlStr) +npDesc = np.chararray((len(desc),)) +npDesc[:] = desc +npDesc.dtype = np.int8 +marianModel["special:model.yml"] = npDesc + +print("\nMarian config:") +print(configYamlStr) +print("Saving Marian model to %s" % (args.marian,)) +np.savez(args.marian, **marianModel) \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cc0d3d44..af376d8a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -83,6 +83,7 @@ add_library(marian STATIC models/transformer_stub.cpp rescorer/score_collector.cpp + embedder/vector_collector.cpp translator/history.cpp translator/output_collector.cpp diff --git a/src/command/marian_embedder.cpp b/src/command/marian_embedder.cpp new file mode 100644 index 00000000..6754fc6f --- /dev/null +++ b/src/command/marian_embedder.cpp @@ -0,0 +1,14 @@ +#include "marian.h" + +#include "models/model_task.h" +#include "embedder/embedder.h" +#include "common/timer.h" + +int main(int argc, char** argv) { + using namespace marian; + + auto options = parseOptions(argc, argv, cli::mode::embedding); + New>(options)->run(); + + return 0; +} diff --git a/src/command/marian_main.cpp b/src/command/marian_main.cpp index a2a19145..dcdea466 100644 --- a/src/command/marian_main.cpp +++ b/src/command/marian_main.cpp @@ -11,6 +11,7 @@ // train // decode // score +// embed // vocab // convert // Currently, marian_server is not supported, since it is a special use case with lots of extra dependencies. @@ -24,6 +25,9 @@ #define main mainScorer #include "marian_scorer.cpp" #undef main +#define main mainEmbedder +#include "marian_embedder.cpp" +#undef main #define main mainVocab #include "marian_vocab.cpp" #undef main @@ -44,9 +48,10 @@ int main(int argc, char** argv) { if(cmd == "train") return mainTrainer(argc, argv); else if(cmd == "decode") return mainDecoder(argc, argv); else if (cmd == "score") return mainScorer(argc, argv); + else if (cmd == "embed") return mainEmbedder(argc, argv); else if (cmd == "vocab") return mainVocab(argc, argv); else if (cmd == "convert") return mainConv(argc, argv); - std::cerr << "Command must be train, decode, score, vocab, or convert." << std::endl; + std::cerr << "Command must be train, decode, score, embed, vocab, or convert." << std::endl; exit(1); } else return mainTrainer(argc, argv); diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 2338cb45..14cd109b 100755 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -91,6 +91,9 @@ ConfigParser::ConfigParser(cli::mode mode) case cli::mode::scoring: addOptionsScoring(cli_); break; + case cli::mode::embedding: + addOptionsEmbedding(cli_); + break; default: ABORT("wrong CLI mode"); break; @@ -235,6 +238,8 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) { 8); cli.add("--transformer-no-projection", "Omit linear projection after multi-head attention (transformer)"); + cli.add("--transformer-pool", + "Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)"); cli.add("--transformer-dim-ffn", "Size of position-wise feed-forward network (transformer)", 2048); @@ -705,6 +710,45 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) { // clang-format on } +void ConfigParser::addOptionsEmbedding(cli::CLIWrapper& cli) { + auto previous_group = cli.switchGroup("Scorer options"); + + // clang-format off + cli.add("--no-reload", + "Do not load existing model specified in --model arg"); + // TODO: move options like vocabs and train-sets to a separate procedure as they are defined twice + cli.add>("--train-sets,-t", + "Paths to corpora to be scored: source target"); + cli.add("--output,-o", + "Path to output file, stdout by default", + "stdout"); + cli.add>("--vocabs,-v", + "Paths to vocabulary files have to correspond to --train-sets. " + "If this parameter is not supplied we look for vocabulary files source.{yml,json} and target.{yml,json}. " + "If these files do not exists they are created"); + + cli.add("--compute-similarity", + "Expect two inputs and compute cosine similarity instead of outputting embedding vector"); + cli.add("--binary", + "Output vectors as binary floats"); + + addSuboptionsInputLength(cli); + addSuboptionsTSV(cli); + addSuboptionsDevices(cli); + addSuboptionsBatching(cli); + + cli.add("--optimize", + "Optimize speed aggressively sacrificing memory or precision"); + cli.add("--fp16", + "Shortcut for mixed precision inference with float16, corresponds to: --precision float16"); + cli.add>("--precision", + "Mixed precision for inference, set parameter type in expression graph", + {"float32"}); + + cli.switchGroup(previous_group); + // clang-format on +} + void ConfigParser::addSuboptionsDevices(cli::CLIWrapper& cli) { // clang-format off cli.add>("--devices,-d", diff --git a/src/common/config_parser.h b/src/common/config_parser.h index 798ec622..933bbb59 100644 --- a/src/common/config_parser.h +++ b/src/common/config_parser.h @@ -14,7 +14,7 @@ namespace marian { namespace cli { -enum struct mode { training, translation, scoring, server }; +enum struct mode { training, translation, scoring, server, embedding }; } // namespace cli /** @@ -129,6 +129,7 @@ private: void addOptionsValidation(cli::CLIWrapper&); void addOptionsTranslation(cli::CLIWrapper&); void addOptionsScoring(cli::CLIWrapper&); + void addOptionsEmbedding(cli::CLIWrapper&); void addAliases(cli::CLIWrapper&); diff --git a/src/common/config_validator.cpp b/src/common/config_validator.cpp index cf46f738..44d7eec6 100644 --- a/src/common/config_validator.cpp +++ b/src/common/config_validator.cpp @@ -27,6 +27,10 @@ void ConfigValidator::validateOptions(cli::mode mode) const { validateOptionsParallelData(); validateOptionsScoring(); break; + case cli::mode::embedding: + validateOptionsParallelData(); + validateOptionsScoring(); + break; case cli::mode::training: validateOptionsParallelData(); validateOptionsTraining(); diff --git a/src/embedder/embedder.h b/src/embedder/embedder.h new file mode 100644 index 00000000..cb78aec0 --- /dev/null +++ b/src/embedder/embedder.h @@ -0,0 +1,171 @@ +#pragma once + +#include "marian.h" + +#include "common/config.h" +#include "common/options.h" +#include "data/batch_generator.h" +#include "data/corpus.h" +#include "data/corpus_nbest.h" +#include "models/costs.h" +#include "models/model_task.h" +#include "embedder/vector_collector.h" +#include "training/scheduler.h" +#include "training/validator.h" + +namespace marian { + +using namespace data; + +/* + * The tool is used to create output sentence embeddings from available + * Marian encoders. With --compute-similiarity and can return the cosine + * similarity between two sentences provided from two sources. + */ +class Embedder { +private: + Ptr model_; + +public: + Embedder(Ptr options) + : model_(createModelFromOptions(options, models::usage::embedding)) {} + + void load(Ptr graph, const std::string& modelFile) { + model_->load(graph, modelFile); + } + + Expr build(Ptr graph, Ptr batch) { + auto embedder = std::dynamic_pointer_cast(model_); + ABORT_IF(!embedder, "Could not cast to EncoderPooler"); + return embedder->apply(graph, batch, /*clearGraph=*/true); + } +}; + +/* + * Actual Embed task. @TODO: this should be simplified in the future. + */ +template +class Embed : public ModelTask { +private: + Ptr options_; + Ptr corpus_; + std::vector> graphs_; + std::vector> models_; + +public: + Embed(Ptr options) : options_(options) { + + options_ = options_->with("inference", true, + "shuffle", "none"); + + // if a similarity is computed then double the input types and vocabs for + // the two encoders that are used in the model. + if(options->get("compute-similarity")) { + auto vInputTypes = options_->get>("input-types"); + auto vVocabs = options_->get>("vocabs"); + auto vDimVocabs = options_->get>("dim-vocabs"); + + vInputTypes.push_back(vInputTypes.back()); + vVocabs.push_back(vVocabs.back()); + vDimVocabs.push_back(vDimVocabs.back()); + + options_ = options_->with("input-types", vInputTypes, + "vocabs", vVocabs, + "dim-vocabs", vDimVocabs); + } + + corpus_ = New(options_); + corpus_->prepare(); + + auto devices = Config::getDevices(options_); + + for(auto device : devices) { + auto graph = New(true); + + auto precison = options_->get>("precision", {"float32"}); + graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph + graph->setDevice(device); + graph->getBackend()->setClip(options_->get("clip-gemm")); + if (device.type == DeviceType::cpu) { + graph->getBackend()->setOptimized(options_->get("optimize")); + } + + graph->reserveWorkspaceMB(options_->get("workspace")); + graphs_.push_back(graph); + } + + auto modelFile = options_->get("model"); + + models_.resize(graphs_.size()); + ThreadPool pool(graphs_.size(), graphs_.size()); + for(size_t i = 0; i < graphs_.size(); ++i) { + pool.enqueue( + [=](size_t j) { + models_[j] = New(options_); + models_[j]->load(graphs_[j], modelFile); + }, + i); + } + } + + void run() override { + LOG(info, "Embedding"); + timer::Timer timer; + + auto batchGenerator = New>(corpus_, options_); + batchGenerator->prepare(); + + auto output = New(options_); + + size_t batchId = 0; + std::mutex smutex; + { + ThreadPool pool(graphs_.size(), graphs_.size()); + + for(auto batch : *batchGenerator) { + auto task = [=, &smutex](size_t id) { + thread_local Ptr graph; + thread_local Ptr builder; + + if(!graph) { + graph = graphs_[id % graphs_.size()]; + builder = models_[id % graphs_.size()]; + } + + auto embeddings = builder->build(graph, batch); + graph->forward(); + + std::vector sentVectors; + embeddings->val()->get(sentVectors); + + // collect embedding vector per sentence. + // if we compute similarities this is only one similarity per sentence pair. + for(size_t i = 0; i < batch->size(); ++i) { + int embSize = embeddings->shape()[-1]; + int beg = i * embSize; + int end = (i + 1) * embSize; + std::vector sentVector(sentVectors.begin() + beg, sentVectors.begin() + end); + output->Write((long)batch->getSentenceIds()[i], + sentVector); + } + + // progress heartbeat for MS-internal Philly compute cluster + // otherwise this job may be killed prematurely if no log for 4 hrs + if (getenv("PHILLY_JOB_ID") // this environment variable exists when running on the cluster + && id % 1000 == 0) // hard beat once every 1000 batches + { + auto progress = id / 10000.f; //fake progress for now, becomes >100 after 1M batches + fprintf(stderr, "PROGRESS: %.2f%%\n", progress); + fflush(stderr); + } + }; + + pool.enqueue(task, batchId++); + } + } + LOG(info, "Total time: {:.5f}s wall", timer.elapsed()); + } + +}; + +} // namespace marian diff --git a/src/embedder/vector_collector.cpp b/src/embedder/vector_collector.cpp new file mode 100644 index 00000000..c1caf2f7 --- /dev/null +++ b/src/embedder/vector_collector.cpp @@ -0,0 +1,71 @@ +#include "embedder/vector_collector.h" + +#include "common/logging.h" +#include "common/utils.h" + +#include +#include + +namespace marian { + +// This class manages multi-threaded writing of embedded vectors to stdout or an output file. +// It will either output string versions of float vectors or binary equal length versions depending +// on its binary_ flag. + +VectorCollector::VectorCollector(const Ptr& options) + : nextId_(0), binary_{options->get("binary", false)} { + if(options->get("output") == "stdout") + outStrm_.reset(new std::ostream(std::cout.rdbuf())); + else + outStrm_.reset(new io::OutputFileStream(options->get("output"))); + } + +void VectorCollector::Write(long id, const std::vector& vec) { + std::lock_guard lock(mutex_); + if(id == nextId_) { + WriteVector(vec); + + ++nextId_; + + typename Outputs::const_iterator iter, iterNext; + iter = outputs_.begin(); + while(iter != outputs_.end()) { + long currId = iter->first; + + if(currId == nextId_) { + // 1st element in the map is the next + WriteVector(iter->second); + + ++nextId_; + + // delete current record, move iter on 1 + iterNext = iter; + ++iterNext; + outputs_.erase(iter); + iter = iterNext; + } else { + // not the next. stop iterating + assert(nextId_ < currId); + break; + } + } + + } else { + // save for later + outputs_[id] = vec; + } +} + +void VectorCollector::WriteVector(const std::vector& vec) { + if(binary_) { + outStrm_->write((char*)vec.data(), vec.size() * sizeof(float)); + } else { + std::stringstream ss; + ss << std::fixed << std::setprecision(8); + for(auto v : vec) + *outStrm_ << v << " "; + *outStrm_ << std::endl; + } +} + +} // namespace marian diff --git a/src/embedder/vector_collector.h b/src/embedder/vector_collector.h new file mode 100644 index 00000000..b7f57c6c --- /dev/null +++ b/src/embedder/vector_collector.h @@ -0,0 +1,32 @@ +#pragma once + +#include "common/options.h" +#include "common/definitions.h" +#include "common/file_stream.h" + +#include +#include + +namespace marian { + +// This class manages multi-threaded writing of embedded vectors to stdout or an output file. +// It will either output string versions of float vectors or binary equal length versions depending +// on its binary_ flag. +class VectorCollector { +public: + VectorCollector(const Ptr& options); + virtual void Write(long id, const std::vector& vec); + +protected: + long nextId_{0}; + UPtr outStrm_; + bool binary_; // output binary floating point vectors if set + + std::mutex mutex_; + + typedef std::map> Outputs; + Outputs outputs_; + + virtual void WriteVector(const std::vector& vec); +}; +} // namespace marian diff --git a/src/layers/loss.h b/src/layers/loss.h index 315eda38..7dc6f469 100755 --- a/src/layers/loss.h +++ b/src/layers/loss.h @@ -416,7 +416,7 @@ protected: ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors"); ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete. - // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks, mask again to eliminate padding (might be obsolete) + // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask again to eliminate padding (might be obsolete) auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32); auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) { diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp index 09856a9c..e57e2ead 100755 --- a/src/models/encoder_decoder.cpp +++ b/src/models/encoder_decoder.cpp @@ -51,6 +51,7 @@ EncoderDecoder::EncoderDecoder(Ptr graph, Ptr options) modelFeatures_.insert("transformer-tied-layers"); modelFeatures_.insert("transformer-guided-alignment-layer"); modelFeatures_.insert("transformer-train-position-embeddings"); + modelFeatures_.insert("transformer-pool"); modelFeatures_.insert("bert-train-type-embeddings"); modelFeatures_.insert("bert-type-vocab-size"); diff --git a/src/models/encoder_pooler.h b/src/models/encoder_pooler.h new file mode 100644 index 00000000..16e45ea2 --- /dev/null +++ b/src/models/encoder_pooler.h @@ -0,0 +1,217 @@ +#pragma once + +#include "marian.h" + +#include "models/encoder.h" +#include "models/pooler.h" +#include "models/model_base.h" +#include "models/states.h" + +// @TODO: this introduces functionality to use LASER in Marian for the filtering workflow or for use in MS-internal +// COSMOS server-farm. There is a lot of code duplication with Classifier and EncoderDecoder and this needs to be fixed. +// This will be done after the new layer system has been finished. + +namespace marian { + +/** + * Combines sequence encoders with generic poolers + * Can be used to train sequence poolers like language detection, BERT-next-sentence-prediction etc. + * Already has support for multi-objective training. + * + * @TODO: this should probably be unified somehow with EncoderDecoder which could allow for deocder/pooler + * multi-objective training. + */ +class EncoderPoolerBase : public models::IModel { +public: + virtual ~EncoderPoolerBase() {} + + virtual void load(Ptr graph, + const std::string& name, + bool markedReloaded = true) override + = 0; + + virtual void mmap(Ptr graph, + const void* ptr, + bool markedReloaded = true) + = 0; + + virtual void save(Ptr graph, + const std::string& name, + bool saveTranslatorConfig = false) override + = 0; + + virtual void clear(Ptr graph) override = 0; + + virtual Expr apply(Ptr, Ptr, bool) = 0; + + virtual Logits build(Ptr graph, + Ptr batch, + bool clearGraph = true) override { + ABORT("Poolers cannot produce Logits"); + }; + + virtual Logits build(Ptr graph, + Ptr batch, + bool clearGraph = true) { + ABORT("Poolers cannot produce Logits"); + } + + virtual Ptr getOptions() = 0; +}; + +class EncoderPooler : public EncoderPoolerBase { +protected: + Ptr options_; + + std::string prefix_; + + std::vector> encoders_; + std::vector> poolers_; + + bool inference_{true}; + + std::set modelFeatures_; + + Config::YamlNode getModelParameters() { + Config::YamlNode modelParams; + auto clone = options_->cloneToYamlNode(); + for(auto& key : modelFeatures_) + modelParams[key] = clone[key]; + + if(options_->has("original-type")) + modelParams["type"] = clone["original-type"]; + + modelParams["version"] = buildVersion(); + return modelParams; + } + + std::string getModelParametersAsString() { + auto yaml = getModelParameters(); + YAML::Emitter out; + cli::OutputYaml(yaml, out); + return std::string(out.c_str()); + } + +public: + typedef data::Corpus dataset_type; + + // @TODO: lots of code-duplication with EncoderDecoder + EncoderPooler(Ptr options) + : options_(options), + prefix_(options->get("prefix", "")), + inference_(options->get("inference", false)) { + modelFeatures_ = {"type", + "dim-vocabs", + "dim-emb", + "dim-rnn", + "enc-cell", + "enc-type", + "enc-cell-depth", + "enc-depth", + "dec-depth", + "dec-cell", + "dec-cell-base-depth", + "dec-cell-high-depth", + "skip", + "layer-normalization", + "right-left", + "input-types", + "special-vocab", + "tied-embeddings", + "tied-embeddings-src", + "tied-embeddings-all"}; + + modelFeatures_.insert("transformer-heads"); + modelFeatures_.insert("transformer-no-projection"); + modelFeatures_.insert("transformer-dim-ffn"); + modelFeatures_.insert("transformer-ffn-depth"); + modelFeatures_.insert("transformer-ffn-activation"); + modelFeatures_.insert("transformer-dim-aan"); + modelFeatures_.insert("transformer-aan-depth"); + modelFeatures_.insert("transformer-aan-activation"); + modelFeatures_.insert("transformer-aan-nogate"); + modelFeatures_.insert("transformer-preprocess"); + modelFeatures_.insert("transformer-postprocess"); + modelFeatures_.insert("transformer-postprocess-emb"); + modelFeatures_.insert("transformer-decoder-autoreg"); + modelFeatures_.insert("transformer-tied-layers"); + modelFeatures_.insert("transformer-guided-alignment-layer"); + modelFeatures_.insert("transformer-train-position-embeddings"); + modelFeatures_.insert("transformer-pool"); + + modelFeatures_.insert("bert-train-type-embeddings"); + modelFeatures_.insert("bert-type-vocab-size"); + + modelFeatures_.insert("ulr"); + modelFeatures_.insert("ulr-trainable-transformation"); + modelFeatures_.insert("ulr-dim-emb"); + modelFeatures_.insert("lemma-dim-emb"); + } + + virtual Ptr getOptions() override { return options_; } + + std::vector>& getEncoders() { return encoders_; } + std::vector>& getPoolers() { return poolers_; } + + void push_back(Ptr encoder) { encoders_.push_back(encoder); } + void push_back(Ptr pooler) { poolers_.push_back(pooler); } + + void load(Ptr graph, + const std::string& name, + bool markedReloaded) override { + graph->load(name, markedReloaded && !opt("ignore-model-config", false)); + } + + void mmap(Ptr graph, + const void* ptr, + bool markedReloaded) override { + graph->mmap(ptr, markedReloaded && !opt("ignore-model-config", false)); + } + + void save(Ptr graph, + const std::string& name, + bool /*saveModelConfig*/) override { + LOG(info, "Saving model weights and runtime parameters to {}", name); + graph->save(name , getModelParametersAsString()); + } + + void clear(Ptr graph) override { + graph->clear(); + + for(auto& enc : encoders_) + enc->clear(); + for(auto& pooler : poolers_) + pooler->clear(); + } + + template + T opt(const std::string& key) { + return options_->get(key); + } + + template + T opt(const std::string& key, const T& def) { + return options_->get(key, def); + } + + template + void set(std::string key, T value) { + options_->set(key, value); + } + + /*********************************************************************/ + + virtual Expr apply(Ptr graph, Ptr batch, bool clearGraph) override { + if(clearGraph) + clear(graph); + + std::vector> encoderStates; + for(auto& encoder : encoders_) + encoderStates.push_back(encoder->build(graph, batch)); + + ABORT_IF(poolers_.size() != 1, "Expected exactly one pooler"); + return poolers_[0]->apply(graph, batch, encoderStates); + } +}; + +} // namespace marian diff --git a/src/models/laser.h b/src/models/laser.h new file mode 100644 index 00000000..1bffe47b --- /dev/null +++ b/src/models/laser.h @@ -0,0 +1,71 @@ +#pragma once + +#include "marian.h" + +#include "layers/constructors.h" +#include "rnn/constructors.h" + +namespace marian { + +// Re-implements the LASER BiLSTM encoder from: +// Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond +// Mikel Artetxe, Holger Schwenk +// https://arxiv.org/abs/1812.10464 + +class EncoderLaser : public EncoderBase { + using EncoderBase::EncoderBase; + +public: + Expr applyEncoderRNN(Ptr graph, + Expr embeddings, + Expr mask) { + int depth = opt("enc-depth"); + float dropoutRnn = inference_ ? 0 : opt("dropout-rnn"); + + Expr output = embeddings; + + auto applyRnn = [&](int layer, rnn::dir direction, Expr input, Expr mask) { + + std::string paramPrefix = prefix_ + "_" + opt("enc-cell"); + paramPrefix += "_l" + std::to_string(layer); + if(direction == rnn::dir::backward) + paramPrefix += "_reverse"; + + auto rnnFactory = rnn::rnn() + ("type", opt("enc-cell")) + ("direction", (int)direction) + ("dimInput", input->shape()[-1]) + ("dimState", opt("dim-rnn")) + ("dropout", dropoutRnn) + ("layer-normalization", opt("layer-normalization")) + ("skip", opt("skip")) + .push_back(rnn::cell()("prefix", paramPrefix)); + + return rnnFactory.construct(graph)->transduce(input, mask); + }; + + for(int i = 0; i < depth; ++i) { + output = concatenate({applyRnn(i, rnn::dir::forward, output, mask), + applyRnn(i, rnn::dir::backward, output, mask)}, + /*axis =*/ -1); + } + + return output; + } + + virtual Ptr build(Ptr graph, + Ptr batch) override { + graph_ = graph; + // select embeddings that occur in the batch + Expr batchEmbeddings, batchMask; std::tie + (batchEmbeddings, batchMask) = getEmbeddingLayer()->apply((*batch)[batchIndex_]); + + Expr context = applyEncoderRNN(graph_, batchEmbeddings, batchMask); + + return New(context, batchMask, batch); + } + + void clear() override {} +}; + +} \ No newline at end of file diff --git a/src/models/model_base.h b/src/models/model_base.h index 5f76b380..09f3b734 100644 --- a/src/models/model_base.h +++ b/src/models/model_base.h @@ -8,7 +8,7 @@ namespace marian { namespace models { -enum struct usage { raw, training, scoring, translation }; +enum struct usage { raw, training, scoring, translation, embedding }; } } // namespace marian diff --git a/src/models/model_factory.cpp b/src/models/model_factory.cpp index 3d9e1e27..586a36de 100755 --- a/src/models/model_factory.cpp +++ b/src/models/model_factory.cpp @@ -10,6 +10,7 @@ #include "models/amun.h" #include "models/nematus.h" #include "models/s2s.h" +#include "models/laser.h" #include "models/transformer_factory.h" #ifdef CUDNN @@ -29,6 +30,9 @@ namespace models { Ptr EncoderFactory::construct(Ptr graph) { if(options_->get("type") == "s2s") return New(graph, options_); + + if(options_->get("type") == "laser" || options_->get("type") == "laser-sim") + return New(graph, options_); #ifdef CUDNN if(options_->get("type") == "char-s2s") @@ -61,6 +65,17 @@ Ptr ClassifierFactory::construct(Ptr graph) { ABORT("Unknown classifier type"); } +Ptr PoolerFactory::construct(Ptr graph) { + if(options_->get("type") == "max-pooler") + return New(graph, options_); + if(options_->get("type") == "slice-pooler") + return New(graph, options_); + else if(options_->get("type") == "sim-pooler") + return New(graph, options_); + else + ABORT("Unknown pooler type"); +} + Ptr EncoderDecoderFactory::construct(Ptr graph) { Ptr encdec; if(options_->get("type") == "amun") @@ -97,9 +112,54 @@ Ptr EncoderClassifierFactory::construct(Ptr graph) { return enccls; } +Ptr EncoderPoolerFactory::construct(Ptr graph) { + Ptr encpool = New(options_); + + for(auto& ef : encoders_) + encpool->push_back(ef(options_).construct(graph)); + + for(auto& pl : poolers_) + encpool->push_back(pl(options_).construct(graph)); + + return encpool; +} + Ptr createBaseModelByType(std::string type, usage use, Ptr options) { Ptr graph = nullptr; // graph unknown at this stage // clang-format off + + if(use == usage::embedding) { // hijacking an EncoderDecoder model for embedding only + int dimVocab = options->get>("dim-vocabs")[0]; + + Ptr newOptions; + if(options->get("compute-similarity")) { + newOptions = options->with("usage", use, + "original-type", type, + "input-types", std::vector({"sequence", "sequence"}), + "dim-vocabs", std::vector(2, dimVocab)); + } else { + newOptions = options->with("usage", use, + "original-type", type, + "input-types", std::vector({"sequence"}), + "dim-vocabs", std::vector(1, dimVocab)); + } + + auto res = New(newOptions); + if(options->get("compute-similarity")) { + res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph)); + res->push_back(models::encoder(newOptions->with("index", 1)).construct(graph)); + res->push_back(New(graph, newOptions->with("type", "sim-pooler"))); + } else { + res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph)); + if(type == "laser") + res->push_back(New(graph, newOptions->with("type", "max-pooler"))); + else + res->push_back(New(graph, newOptions->with("type", "slice-pooler"))); + } + + return res; + } + if(type == "s2s" || type == "amun" || type == "nematus") { return models::encoder_decoder(options->with( "usage", use, @@ -313,7 +373,7 @@ Ptr createModelFromOptions(Ptr options, usage use) { else ABORT("'usage' parameter 'translation' cannot be applied to model type: {}", type); } - else if (use == usage::raw) + else if (use == usage::raw || use == usage::embedding) return baseModel; else ABORT("'Usage' parameter must be 'translation' or 'raw'"); diff --git a/src/models/model_factory.h b/src/models/model_factory.h index 1840df8f..5403b966 100755 --- a/src/models/model_factory.h +++ b/src/models/model_factory.h @@ -5,6 +5,7 @@ #include "layers/factory.h" #include "models/encoder_decoder.h" #include "models/encoder_classifier.h" +#include "models/encoder_pooler.h" namespace marian { namespace models { @@ -33,6 +34,14 @@ public: typedef Accumulator classifier; +class PoolerFactory : public Factory { + using Factory::Factory; +public: + virtual Ptr construct(Ptr graph); +}; + +typedef Accumulator pooler; + class EncoderDecoderFactory : public Factory { using Factory::Factory; private: @@ -77,6 +86,28 @@ public: typedef Accumulator encoder_classifier; +class EncoderPoolerFactory : public Factory { + using Factory::Factory; +private: + std::vector encoders_; + std::vector poolers_; + +public: + Accumulator push_back(encoder enc) { + encoders_.push_back(enc); + return Accumulator(*this); + } + + Accumulator push_back(pooler cls) { + poolers_.push_back(cls); + return Accumulator(*this); + } + + virtual Ptr construct(Ptr graph); +}; + +typedef Accumulator encoder_pooler; + Ptr createBaseModelByType(std::string type, usage, Ptr options); Ptr createModelFromOptions(Ptr options, usage); diff --git a/src/models/pooler.h b/src/models/pooler.h new file mode 100644 index 00000000..572c0ab3 --- /dev/null +++ b/src/models/pooler.h @@ -0,0 +1,139 @@ +#pragma once + +#include "marian.h" +#include "models/states.h" +#include "layers/constructors.h" +#include "layers/factory.h" + +namespace marian { + +/** + * Simple base class for Poolers to be used in EncoderPooler framework + * A pooler takes a encoder state (contextual word embeddings) and produces + * a single sentence embedding. + */ +class PoolerBase : public LayerBase { + using LayerBase::LayerBase; + +protected: + const std::string prefix_{"pooler"}; + const bool inference_{false}; + const size_t batchIndex_{0}; + +public: + PoolerBase(Ptr graph, Ptr options) + : LayerBase(graph, options), + prefix_(options->get("prefix", "pooler")), + inference_(options->get("inference", true)), + batchIndex_(options->get("index", 1)) {} // assume that training input has batch index 0 and labels has 1 + + virtual ~PoolerBase() {} + + virtual Expr apply(Ptr, Ptr, const std::vector>&) = 0; + + template + T opt(const std::string& key) const { + return options_->get(key); + } + + // Should be used to clear any batch-wise temporary objects if present + virtual void clear() = 0; +}; + +/** + * Pool encoder state (contextual word embeddings) via max-pooling along sentence-length dimension. + */ +class MaxPooler : public PoolerBase { +public: + MaxPooler(Ptr graph, Ptr options) + : PoolerBase(graph, options) {} + + Expr apply(Ptr graph, Ptr batch, const std::vector>& encoderStates) override { + ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state"); + + auto context = encoderStates[0]->getContext(); + auto batchMask = encoderStates[0]->getMask(); + + // do a max pool here + Expr logMask = (1.f - batchMask) * -9999.f; + Expr maxPool = max(context * batchMask + logMask, /*axis=*/-3); + + return maxPool; + } + + void clear() override {} + +}; + +/** + * Pool encoder state (contextual word embeddings) by selecting 1st embedding along sentence-length dimension. + */ +class SlicePooler : public PoolerBase { +public: + SlicePooler(Ptr graph, Ptr options) + : PoolerBase(graph, options) {} + + Expr apply(Ptr graph, Ptr batch, const std::vector>& encoderStates) override { + ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state"); + + auto context = encoderStates[0]->getContext(); + auto batchMask = encoderStates[0]->getMask(); + + // Corresponds to the way we do this in transformer.h + // @TODO: unify this better, this is currently hacky + Expr slicePool = slice(context * batchMask, /*axis=*/-3, 0); + + return slicePool; + } + + void clear() override {} + +}; + +/** + * Not really a pooler but abusing the interface to compute a similarity of two pooled states + */ +class SimPooler : public PoolerBase { +public: + SimPooler(Ptr graph, Ptr options) + : PoolerBase(graph, options) {} + + Expr apply(Ptr graph, Ptr batch, const std::vector>& encoderStates) override { + ABORT_IF(encoderStates.size() != 2, "SimPooler expects exactly two encoder states"); + + std::vector vecs; + for(auto encoderState : encoderStates) { + auto context = encoderState->getContext(); + auto batchMask = encoderState->getMask(); + + Expr pool; + auto type = options_->get("original-type"); + if(type == "laser") { + // LASER models do a max pool here + Expr logMask = (1.f - batchMask) * -9999.f; + pool = max(context * batchMask + logMask, /*axis=*/-3); + } else if(type == "transformer") { + // Our own implementation in transformer.h uses a slice of the first element + pool = slice(context, -3, 0); + } else { + // @TODO: make SimPooler take Pooler objects as arguments then it won't need to know this. + ABORT("Don't know what type of pooler to use for model type {}", type); + } + + vecs.push_back(pool); + } + + auto scalars = scalar_product(vecs[0], vecs[1], /*axis*/-1); + auto length1 = sqrt(sum(square(vecs[0]), /*axis=*/-1)); + auto length2 = sqrt(sum(square(vecs[1]), /*axis=*/-1)); + + auto cosine = scalars / ( length1 * length2 ); + + return cosine; + } + + void clear() override {} + +}; + +} \ No newline at end of file diff --git a/src/models/transformer.h b/src/models/transformer.h index 4fea94d0..32711d36 100755 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -328,6 +328,30 @@ public: return output; } + // Reduce the encoder to a single sentence vector, here we just take the contextual embedding of the first word per sentence + // Replaces cross-attention in LASER-like models + Expr LayerPooling(std::string prefix, + Expr input, // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim] + const Expr& values, // [-4: beam depth=1, -3: batch size, -2: max length (src or trg), -1: vector dim] + const Expr& mask) { // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length] + int dimModel = input->shape()[-1]; + auto output = slice(values, -2, 0); // Select first word [-4: beam depth, -3: batch size, -2: 1, -1: vector dim] + + int dimPool = output->shape()[-1]; + bool project = !opt("transformer-no-projection"); + if(project || dimPool != dimModel) { + auto Wo = graph_->param(prefix + "_Wo", {dimPool, dimModel}, inits::glorotUniform()); + auto bo = graph_->param(prefix + "_bo", {1, dimModel}, inits::zeros()); + output = affine(output, Wo, bo); // [-4: beam depth, -3: batch size, -2: 1, -1: vector dim] + } + + auto opsPost = opt("transformer-postprocess"); + output = postProcess(prefix + "_Wo", opsPost, output, input, 0.f); + + return output; + } + + Expr LayerAttention(std::string prefix, Expr input, // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim] const Expr& keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim] @@ -790,14 +814,21 @@ public: saveAttentionWeights = i == attLayer; } - query = LayerAttention(prefix, + if(options_->get("transformer-pool", false)) { + query = LayerPooling(prefix, query, - encoderContexts[j], // keys encoderContexts[j], // values - encoderMasks[j], - opt("transformer-heads"), - /*cache=*/true, - saveAttentionWeights); + encoderMasks[j]); + } else { + query = LayerAttention(prefix, + query, + encoderContexts[j], // keys + encoderContexts[j], // values + encoderMasks[j], + opt("transformer-heads"), + /*cache=*/true, + saveAttentionWeights); + } } } diff --git a/src/rnn/cells.h b/src/rnn/cells.h index 9fbc8852..cddfd26e 100755 --- a/src/rnn/cells.h +++ b/src/rnn/cells.h @@ -651,7 +651,7 @@ public: using LSTM = FastLSTM; /******************************************************************************/ -// Experimentak cells, use with care +// Experimental cells, use with care template class Multiplicative : public CellType { diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp index 5f56e634..39e2e20d 100755 --- a/src/tensors/cpu/tensor_operators.cpp +++ b/src/tensors/cpu/tensor_operators.cpp @@ -1248,67 +1248,104 @@ void SetSparse(float* out, } } -void LSTMCellForward(Tensor out_, std::vector inputs) { +// should be implemented via slicing and elementwise +template +void LSTMCellForwardTyped(Tensor out_, const std::vector& inputs) { int rows = out_->shape().elements() / out_->shape()[-1]; - int cols = out_->shape()[-1]; - float* out = out_->data(); - const float* cell = inputs[0]->data(); - const float* xW = inputs[1]->data(); - const float* sU = inputs[2]->data(); - const float* b = inputs[3]->data(); + int fVecSize = sizeof(FType) / sizeof(float); + int cols = out_->shape()[-1] / fVecSize; + + FType* out = out_->data(); + const FType* cell = inputs[0]->data(); + const FType* xW = inputs[1]->data(); + const FType* sU = inputs[2]->data(); + const FType* b = inputs[3]->data(); const float* mask = inputs.size() > 4 ? inputs[4]->data() : nullptr; + using fop = functional::Ops; + for(int j = 0; j < rows; ++j) { float m = !mask || mask[j]; - float* rowOut = out + j * cols; - const float* rowCell = cell + j * cols; + FType* rowOut = out + j * cols; + const FType* rowCell = cell + j * cols; - const float* xWrow = xW + j * cols * 4; - const float* sUrow = sU + j * cols * 4; + const FType* xWrow = xW + j * cols * 4; + const FType* sUrow = sU + j * cols * 4; for(int i = 0; i < cols; ++i) { - float gf = functional::Ops::sigmoid(xWrow[i] + sUrow[i] + b[i]); + FType gf = fop::sigmoid(fop::add(fop::add(xWrow[i], sUrow[i]), b[i])); int k = i + cols; - float gi = functional::Ops::sigmoid(xWrow[k] + sUrow[k] + b[k]); + FType gi = fop::sigmoid(fop::add(fop::add(xWrow[k], sUrow[k]), b[k])); int l = i + 2 * cols; - float gc = std::tanh(xWrow[l] + sUrow[l] + b[l]); + FType gc = fop::tanh(fop::add(fop::add(xWrow[l], sUrow[l]), b[l])); - float cout = gf * rowCell[i] + gi * gc; - rowOut[i] = m * cout + (1 - m) * rowCell[i]; + FType cout = fop::add(fop::mul(gf, rowCell[i]), fop::mul(gi, gc)); + rowOut[i] = fop::add(fop::mul(m, cout), fop::mul(fop::sub(1.f, m), rowCell[i])); } } } -void LSTMOutputForward(Tensor out_, std::vector inputs) { - int rows = out_->shape().elements() / out_->shape()[-1]; - int cols = out_->shape()[-1]; +void LSTMCellForward(Tensor out, std::vector inputs) { + int cols = out->shape()[-1]; +#ifdef __AVX__ + if(cols % 8 == 0) + LSTMCellForwardTyped(out, inputs); + else +#endif + if(cols % 4 == 0) + LSTMCellForwardTyped(out, inputs); + else + LSTMCellForwardTyped(out, inputs); +} - float* out = out_->data(); - const float* cell = inputs[0]->data(); - const float* xW = inputs[1]->data(); - const float* sU = inputs[2]->data(); - const float* b = inputs[3]->data(); +template +void LSTMOutputForwardTyped(Tensor out_, const std::vector& inputs) { + int rows = out_->shape().elements() / out_->shape()[-1]; + + int fVecSize = sizeof(FType) / sizeof(float); + int cols = out_->shape()[-1] / fVecSize; + + FType* out = out_->data(); + const FType* cell = inputs[0]->data(); + const FType* xW = inputs[1]->data(); + const FType* sU = inputs[2]->data(); + const FType* b = inputs[3]->data(); + + using fop = functional::Ops; for(int j = 0; j < rows; ++j) { - float* rowOut = out + j * cols; - const float* rowCell = cell + j * cols; + FType* rowOut = out + j * cols; + const FType* rowCell = cell + j * cols; - const float* xWrow = xW + j * cols * 4; - const float* sUrow = sU + j * cols * 4; + const FType* xWrow = xW + j * cols * 4; + const FType* sUrow = sU + j * cols * 4; for(int i = 0; i < cols; ++i) { int k = i + 3 * cols; - float go = functional::Ops::sigmoid(xWrow[k] + sUrow[k] + b[k]); - - rowOut[i] = go * std::tanh(rowCell[i]); + FType go = fop::sigmoid(fop::add(fop::add(xWrow[k], sUrow[k]), b[k])); + rowOut[i] = fop::mul(go, fop::tanh(rowCell[i])); } } } +void LSTMOutputForward(Tensor out, std::vector inputs) { + int cols = out->shape()[-1]; + +#ifdef __AVX__ + if(cols % 8 == 0) + LSTMOutputForwardTyped(out, inputs); + else +#endif + if(cols % 4 == 0) + LSTMOutputForwardTyped(out, inputs); + else + LSTMOutputForwardTyped(out, inputs); +} + void LSTMCellBackward(std::vector outputs, std::vector inputs, Tensor adj_) { diff --git a/src/training/graph_group.cpp b/src/training/graph_group.cpp index 616bb991..1eba08cf 100644 --- a/src/training/graph_group.cpp +++ b/src/training/graph_group.cpp @@ -77,7 +77,7 @@ Ptr GraphGroup::collectStats(Ptr graph, } else { end = current - 1; } - } while(end - start > step); + } while(end - start > step); // @TODO: better replace with `end >= start` to remove the step here maxBatch = start; }