Merged PR 13476: Add LASER reimplementation and code for embeddings sentences

This reimplements the LASER encoder from:
```
Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond
Mikel Artetxe, Holger Schwenk
https://arxiv.org/abs/1812.10464
```

and adds functionality to embed sentences with any Marian encoder, also different from LASER. Some early attempts to train a transformer model with Encoder-Decoder bottle-neck. This is quite early code, so some code-duplication is to be expected. Nevertheless, it's functional and I would like to have it in master as we will slowly put that into production in various places. I will make the code "nicer" as we go along.
This commit is contained in:
Martin Junczys-Dowmunt 2020-06-24 01:54:27 +00:00
parent 7b815cb936
commit c3fb60cbcd
23 changed files with 1066 additions and 44 deletions

View File

@ -187,6 +187,13 @@ else(MSVC)
set(CMAKE_C_FLAGS_PROFUSE "${CMAKE_C_FLAGS_RELEASE} -fprofile-use -fprofile-correction") set(CMAKE_C_FLAGS_PROFUSE "${CMAKE_C_FLAGS_RELEASE} -fprofile-use -fprofile-correction")
endif(MSVC) endif(MSVC)
# with gcc 7.0 and above we need to mark fallthrough in switch case statements
# that can be done in comments for backcompat, but CCACHE removes comments.
# -C makes gcc keep comments.
if(USE_CCACHE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -C")
endif()
############################################################################### ###############################################################################
# Downloading SentencePiece if requested and set to compile with it. # Downloading SentencePiece if requested and set to compile with it.
# Requires all the dependencies imposed by SentencePiece # Requires all the dependencies imposed by SentencePiece

View File

@ -0,0 +1,85 @@
import numpy as np
import sys
import yaml
import argparse
import torch
parser = argparse.ArgumentParser(description='Convert LASER model to Marian weight file.')
parser.add_argument('--laser', help='Path to LASER PyTorch model', required=True)
parser.add_argument('--marian', help='Output path for Marian weight file', required=True)
args = parser.parse_args()
laser = torch.load(args.laser)
config = dict()
config["type"] = "laser"
config["input-types"] = ["sequence"]
config["dim-vocabs"] = [laser["params"]["num_embeddings"]]
config["version"] = "laser2marian.py conversion"
config["enc-depth"] = laser["params"]["num_layers"]
config["enc-cell"] = "lstm"
config["dim-emb"] = laser["params"]["embed_dim"]
config["dim-rnn"] = laser["params"]["hidden_size"]
yaml.dump(laser["dictionary"], open(args.marian + ".vocab.yml", "w"))
marianModel = dict()
def transposeOrder(mat):
matT = np.transpose(mat) # just a view with changed row order
return matT.flatten(order="C").reshape(matT.shape) # force row order change and reshape
def convert(pd, srcs, trg, transpose=True, bias=False, lstm=False):
num = pd[srcs[0]].detach().numpy()
for i in range(1, len(srcs)):
num += pd[srcs[i]].detach().numpy()
out = num
if bias:
num = np.atleast_2d(num)
else:
if transpose:
num = transposeOrder(num) # transpose with row order change
if lstm: # different order in pytorch than marian
stateDim = int(num.shape[-1] / 4)
i = np.copy(num[:, 0*stateDim:1*stateDim])
f = np.copy(num[:, 1*stateDim:2*stateDim])
num[:, 0*stateDim:1*stateDim] = f
num[:, 1*stateDim:2*stateDim] = i
marianModel[trg] = num
for k in laser:
print(k)
for k in laser["model"]:
print(k, laser["model"][k].shape)
convert(laser["model"], ["embed_tokens.weight"], "encoder_Wemb", transpose=False)
for i in range(laser["params"]["num_layers"]):
convert(laser["model"], [f"lstm.weight_ih_l{i}"], f"encoder_lstm_l{i}_W", lstm=True)
convert(laser["model"], [f"lstm.weight_hh_l{i}"], f"encoder_lstm_l{i}_U", lstm=True)
convert(laser["model"], [f"lstm.bias_ih_l{i}", f"lstm.bias_hh_l{i}"], f"encoder_lstm_l{i}_b", bias=True, lstm=True) # needs to be summed!
convert(laser["model"], [f"lstm.weight_ih_l{i}_reverse"], f"encoder_lstm_l{i}_reverse_W", lstm=True)
convert(laser["model"], [f"lstm.weight_hh_l{i}_reverse"], f"encoder_lstm_l{i}_reverse_U", lstm=True)
convert(laser["model"], [f"lstm.bias_ih_l{i}_reverse", f"lstm.bias_hh_l{i}_reverse"], f"encoder_lstm_l{i}_reverse_b", bias=True, lstm=True) # needs to be summed!
for m in marianModel:
print(m, marianModel[m].shape)
configYamlStr = yaml.dump(config, default_flow_style=False)
desc = list(configYamlStr)
npDesc = np.chararray((len(desc),))
npDesc[:] = desc
npDesc.dtype = np.int8
marianModel["special:model.yml"] = npDesc
print("\nMarian config:")
print(configYamlStr)
print("Saving Marian model to %s" % (args.marian,))
np.savez(args.marian, **marianModel)

View File

@ -83,6 +83,7 @@ add_library(marian STATIC
models/transformer_stub.cpp models/transformer_stub.cpp
rescorer/score_collector.cpp rescorer/score_collector.cpp
embedder/vector_collector.cpp
translator/history.cpp translator/history.cpp
translator/output_collector.cpp translator/output_collector.cpp

View File

@ -0,0 +1,14 @@
#include "marian.h"
#include "models/model_task.h"
#include "embedder/embedder.h"
#include "common/timer.h"
int main(int argc, char** argv) {
using namespace marian;
auto options = parseOptions(argc, argv, cli::mode::embedding);
New<Embed<Embedder>>(options)->run();
return 0;
}

View File

@ -11,6 +11,7 @@
// train // train
// decode // decode
// score // score
// embed
// vocab // vocab
// convert // convert
// Currently, marian_server is not supported, since it is a special use case with lots of extra dependencies. // Currently, marian_server is not supported, since it is a special use case with lots of extra dependencies.
@ -24,6 +25,9 @@
#define main mainScorer #define main mainScorer
#include "marian_scorer.cpp" #include "marian_scorer.cpp"
#undef main #undef main
#define main mainEmbedder
#include "marian_embedder.cpp"
#undef main
#define main mainVocab #define main mainVocab
#include "marian_vocab.cpp" #include "marian_vocab.cpp"
#undef main #undef main
@ -44,9 +48,10 @@ int main(int argc, char** argv) {
if(cmd == "train") return mainTrainer(argc, argv); if(cmd == "train") return mainTrainer(argc, argv);
else if(cmd == "decode") return mainDecoder(argc, argv); else if(cmd == "decode") return mainDecoder(argc, argv);
else if (cmd == "score") return mainScorer(argc, argv); else if (cmd == "score") return mainScorer(argc, argv);
else if (cmd == "embed") return mainEmbedder(argc, argv);
else if (cmd == "vocab") return mainVocab(argc, argv); else if (cmd == "vocab") return mainVocab(argc, argv);
else if (cmd == "convert") return mainConv(argc, argv); else if (cmd == "convert") return mainConv(argc, argv);
std::cerr << "Command must be train, decode, score, vocab, or convert." << std::endl; std::cerr << "Command must be train, decode, score, embed, vocab, or convert." << std::endl;
exit(1); exit(1);
} else } else
return mainTrainer(argc, argv); return mainTrainer(argc, argv);

View File

@ -91,6 +91,9 @@ ConfigParser::ConfigParser(cli::mode mode)
case cli::mode::scoring: case cli::mode::scoring:
addOptionsScoring(cli_); addOptionsScoring(cli_);
break; break;
case cli::mode::embedding:
addOptionsEmbedding(cli_);
break;
default: default:
ABORT("wrong CLI mode"); ABORT("wrong CLI mode");
break; break;
@ -235,6 +238,8 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
8); 8);
cli.add<bool>("--transformer-no-projection", cli.add<bool>("--transformer-no-projection",
"Omit linear projection after multi-head attention (transformer)"); "Omit linear projection after multi-head attention (transformer)");
cli.add<bool>("--transformer-pool",
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
cli.add<int>("--transformer-dim-ffn", cli.add<int>("--transformer-dim-ffn",
"Size of position-wise feed-forward network (transformer)", "Size of position-wise feed-forward network (transformer)",
2048); 2048);
@ -705,6 +710,45 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
// clang-format on // clang-format on
} }
void ConfigParser::addOptionsEmbedding(cli::CLIWrapper& cli) {
auto previous_group = cli.switchGroup("Scorer options");
// clang-format off
cli.add<bool>("--no-reload",
"Do not load existing model specified in --model arg");
// TODO: move options like vocabs and train-sets to a separate procedure as they are defined twice
cli.add<std::vector<std::string>>("--train-sets,-t",
"Paths to corpora to be scored: source target");
cli.add<std::string>("--output,-o",
"Path to output file, stdout by default",
"stdout");
cli.add<std::vector<std::string>>("--vocabs,-v",
"Paths to vocabulary files have to correspond to --train-sets. "
"If this parameter is not supplied we look for vocabulary files source.{yml,json} and target.{yml,json}. "
"If these files do not exists they are created");
cli.add<bool>("--compute-similarity",
"Expect two inputs and compute cosine similarity instead of outputting embedding vector");
cli.add<bool>("--binary",
"Output vectors as binary floats");
addSuboptionsInputLength(cli);
addSuboptionsTSV(cli);
addSuboptionsDevices(cli);
addSuboptionsBatching(cli);
cli.add<bool>("--optimize",
"Optimize speed aggressively sacrificing memory or precision");
cli.add<bool>("--fp16",
"Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
cli.add<std::vector<std::string>>("--precision",
"Mixed precision for inference, set parameter type in expression graph",
{"float32"});
cli.switchGroup(previous_group);
// clang-format on
}
void ConfigParser::addSuboptionsDevices(cli::CLIWrapper& cli) { void ConfigParser::addSuboptionsDevices(cli::CLIWrapper& cli) {
// clang-format off // clang-format off
cli.add<std::vector<std::string>>("--devices,-d", cli.add<std::vector<std::string>>("--devices,-d",

View File

@ -14,7 +14,7 @@
namespace marian { namespace marian {
namespace cli { namespace cli {
enum struct mode { training, translation, scoring, server }; enum struct mode { training, translation, scoring, server, embedding };
} // namespace cli } // namespace cli
/** /**
@ -129,6 +129,7 @@ private:
void addOptionsValidation(cli::CLIWrapper&); void addOptionsValidation(cli::CLIWrapper&);
void addOptionsTranslation(cli::CLIWrapper&); void addOptionsTranslation(cli::CLIWrapper&);
void addOptionsScoring(cli::CLIWrapper&); void addOptionsScoring(cli::CLIWrapper&);
void addOptionsEmbedding(cli::CLIWrapper&);
void addAliases(cli::CLIWrapper&); void addAliases(cli::CLIWrapper&);

View File

@ -27,6 +27,10 @@ void ConfigValidator::validateOptions(cli::mode mode) const {
validateOptionsParallelData(); validateOptionsParallelData();
validateOptionsScoring(); validateOptionsScoring();
break; break;
case cli::mode::embedding:
validateOptionsParallelData();
validateOptionsScoring();
break;
case cli::mode::training: case cli::mode::training:
validateOptionsParallelData(); validateOptionsParallelData();
validateOptionsTraining(); validateOptionsTraining();

171
src/embedder/embedder.h Normal file
View File

@ -0,0 +1,171 @@
#pragma once
#include "marian.h"
#include "common/config.h"
#include "common/options.h"
#include "data/batch_generator.h"
#include "data/corpus.h"
#include "data/corpus_nbest.h"
#include "models/costs.h"
#include "models/model_task.h"
#include "embedder/vector_collector.h"
#include "training/scheduler.h"
#include "training/validator.h"
namespace marian {
using namespace data;
/*
* The tool is used to create output sentence embeddings from available
* Marian encoders. With --compute-similiarity and can return the cosine
* similarity between two sentences provided from two sources.
*/
class Embedder {
private:
Ptr<models::IModel> model_;
public:
Embedder(Ptr<Options> options)
: model_(createModelFromOptions(options, models::usage::embedding)) {}
void load(Ptr<ExpressionGraph> graph, const std::string& modelFile) {
model_->load(graph, modelFile);
}
Expr build(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
auto embedder = std::dynamic_pointer_cast<EncoderPooler>(model_);
ABORT_IF(!embedder, "Could not cast to EncoderPooler");
return embedder->apply(graph, batch, /*clearGraph=*/true);
}
};
/*
* Actual Embed task. @TODO: this should be simplified in the future.
*/
template <class Model>
class Embed : public ModelTask {
private:
Ptr<Options> options_;
Ptr<CorpusBase> corpus_;
std::vector<Ptr<ExpressionGraph>> graphs_;
std::vector<Ptr<Model>> models_;
public:
Embed(Ptr<Options> options) : options_(options) {
options_ = options_->with("inference", true,
"shuffle", "none");
// if a similarity is computed then double the input types and vocabs for
// the two encoders that are used in the model.
if(options->get<bool>("compute-similarity")) {
auto vInputTypes = options_->get<std::vector<std::string>>("input-types");
auto vVocabs = options_->get<std::vector<std::string>>("vocabs");
auto vDimVocabs = options_->get<std::vector<size_t>>("dim-vocabs");
vInputTypes.push_back(vInputTypes.back());
vVocabs.push_back(vVocabs.back());
vDimVocabs.push_back(vDimVocabs.back());
options_ = options_->with("input-types", vInputTypes,
"vocabs", vVocabs,
"dim-vocabs", vDimVocabs);
}
corpus_ = New<Corpus>(options_);
corpus_->prepare();
auto devices = Config::getDevices(options_);
for(auto device : devices) {
auto graph = New<ExpressionGraph>(true);
auto precison = options_->get<std::vector<std::string>>("precision", {"float32"});
graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph
graph->setDevice(device);
graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
if (device.type == DeviceType::cpu) {
graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
}
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);
}
auto modelFile = options_->get<std::string>("model");
models_.resize(graphs_.size());
ThreadPool pool(graphs_.size(), graphs_.size());
for(size_t i = 0; i < graphs_.size(); ++i) {
pool.enqueue(
[=](size_t j) {
models_[j] = New<Model>(options_);
models_[j]->load(graphs_[j], modelFile);
},
i);
}
}
void run() override {
LOG(info, "Embedding");
timer::Timer timer;
auto batchGenerator = New<BatchGenerator<CorpusBase>>(corpus_, options_);
batchGenerator->prepare();
auto output = New<VectorCollector>(options_);
size_t batchId = 0;
std::mutex smutex;
{
ThreadPool pool(graphs_.size(), graphs_.size());
for(auto batch : *batchGenerator) {
auto task = [=, &smutex](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local Ptr<Model> builder;
if(!graph) {
graph = graphs_[id % graphs_.size()];
builder = models_[id % graphs_.size()];
}
auto embeddings = builder->build(graph, batch);
graph->forward();
std::vector<float> sentVectors;
embeddings->val()->get(sentVectors);
// collect embedding vector per sentence.
// if we compute similarities this is only one similarity per sentence pair.
for(size_t i = 0; i < batch->size(); ++i) {
int embSize = embeddings->shape()[-1];
int beg = i * embSize;
int end = (i + 1) * embSize;
std::vector<float> sentVector(sentVectors.begin() + beg, sentVectors.begin() + end);
output->Write((long)batch->getSentenceIds()[i],
sentVector);
}
// progress heartbeat for MS-internal Philly compute cluster
// otherwise this job may be killed prematurely if no log for 4 hrs
if (getenv("PHILLY_JOB_ID") // this environment variable exists when running on the cluster
&& id % 1000 == 0) // hard beat once every 1000 batches
{
auto progress = id / 10000.f; //fake progress for now, becomes >100 after 1M batches
fprintf(stderr, "PROGRESS: %.2f%%\n", progress);
fflush(stderr);
}
};
pool.enqueue(task, batchId++);
}
}
LOG(info, "Total time: {:.5f}s wall", timer.elapsed());
}
};
} // namespace marian

View File

@ -0,0 +1,71 @@
#include "embedder/vector_collector.h"
#include "common/logging.h"
#include "common/utils.h"
#include <iostream>
#include <iomanip>
namespace marian {
// This class manages multi-threaded writing of embedded vectors to stdout or an output file.
// It will either output string versions of float vectors or binary equal length versions depending
// on its binary_ flag.
VectorCollector::VectorCollector(const Ptr<Options>& options)
: nextId_(0), binary_{options->get<bool>("binary", false)} {
if(options->get<std::string>("output") == "stdout")
outStrm_.reset(new std::ostream(std::cout.rdbuf()));
else
outStrm_.reset(new io::OutputFileStream(options->get<std::string>("output")));
}
void VectorCollector::Write(long id, const std::vector<float>& vec) {
std::lock_guard<std::mutex> lock(mutex_);
if(id == nextId_) {
WriteVector(vec);
++nextId_;
typename Outputs::const_iterator iter, iterNext;
iter = outputs_.begin();
while(iter != outputs_.end()) {
long currId = iter->first;
if(currId == nextId_) {
// 1st element in the map is the next
WriteVector(iter->second);
++nextId_;
// delete current record, move iter on 1
iterNext = iter;
++iterNext;
outputs_.erase(iter);
iter = iterNext;
} else {
// not the next. stop iterating
assert(nextId_ < currId);
break;
}
}
} else {
// save for later
outputs_[id] = vec;
}
}
void VectorCollector::WriteVector(const std::vector<float>& vec) {
if(binary_) {
outStrm_->write((char*)vec.data(), vec.size() * sizeof(float));
} else {
std::stringstream ss;
ss << std::fixed << std::setprecision(8);
for(auto v : vec)
*outStrm_ << v << " ";
*outStrm_ << std::endl;
}
}
} // namespace marian

View File

@ -0,0 +1,32 @@
#pragma once
#include "common/options.h"
#include "common/definitions.h"
#include "common/file_stream.h"
#include <map>
#include <mutex>
namespace marian {
// This class manages multi-threaded writing of embedded vectors to stdout or an output file.
// It will either output string versions of float vectors or binary equal length versions depending
// on its binary_ flag.
class VectorCollector {
public:
VectorCollector(const Ptr<Options>& options);
virtual void Write(long id, const std::vector<float>& vec);
protected:
long nextId_{0};
UPtr<std::ostream> outStrm_;
bool binary_; // output binary floating point vectors if set
std::mutex mutex_;
typedef std::map<long, std::vector<float>> Outputs;
Outputs outputs_;
virtual void WriteVector(const std::vector<float>& vec);
};
} // namespace marian

View File

@ -416,7 +416,7 @@ protected:
ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors"); ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors");
ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete. ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete.
// use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks, mask again to eliminate padding (might be obsolete) // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask again to eliminate padding (might be obsolete)
auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32); auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32);
auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) { auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {

View File

@ -51,6 +51,7 @@ EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)
modelFeatures_.insert("transformer-tied-layers"); modelFeatures_.insert("transformer-tied-layers");
modelFeatures_.insert("transformer-guided-alignment-layer"); modelFeatures_.insert("transformer-guided-alignment-layer");
modelFeatures_.insert("transformer-train-position-embeddings"); modelFeatures_.insert("transformer-train-position-embeddings");
modelFeatures_.insert("transformer-pool");
modelFeatures_.insert("bert-train-type-embeddings"); modelFeatures_.insert("bert-train-type-embeddings");
modelFeatures_.insert("bert-type-vocab-size"); modelFeatures_.insert("bert-type-vocab-size");

217
src/models/encoder_pooler.h Normal file
View File

@ -0,0 +1,217 @@
#pragma once
#include "marian.h"
#include "models/encoder.h"
#include "models/pooler.h"
#include "models/model_base.h"
#include "models/states.h"
// @TODO: this introduces functionality to use LASER in Marian for the filtering workflow or for use in MS-internal
// COSMOS server-farm. There is a lot of code duplication with Classifier and EncoderDecoder and this needs to be fixed.
// This will be done after the new layer system has been finished.
namespace marian {
/**
* Combines sequence encoders with generic poolers
* Can be used to train sequence poolers like language detection, BERT-next-sentence-prediction etc.
* Already has support for multi-objective training.
*
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for deocder/pooler
* multi-objective training.
*/
class EncoderPoolerBase : public models::IModel {
public:
virtual ~EncoderPoolerBase() {}
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override
= 0;
virtual void mmap(Ptr<ExpressionGraph> graph,
const void* ptr,
bool markedReloaded = true)
= 0;
virtual void save(Ptr<ExpressionGraph> graph,
const std::string& name,
bool saveTranslatorConfig = false) override
= 0;
virtual void clear(Ptr<ExpressionGraph> graph) override = 0;
virtual Expr apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, bool) = 0;
virtual Logits build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
ABORT("Poolers cannot produce Logits");
};
virtual Logits build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) {
ABORT("Poolers cannot produce Logits");
}
virtual Ptr<Options> getOptions() = 0;
};
class EncoderPooler : public EncoderPoolerBase {
protected:
Ptr<Options> options_;
std::string prefix_;
std::vector<Ptr<EncoderBase>> encoders_;
std::vector<Ptr<PoolerBase>> poolers_;
bool inference_{true};
std::set<std::string> modelFeatures_;
Config::YamlNode getModelParameters() {
Config::YamlNode modelParams;
auto clone = options_->cloneToYamlNode();
for(auto& key : modelFeatures_)
modelParams[key] = clone[key];
if(options_->has("original-type"))
modelParams["type"] = clone["original-type"];
modelParams["version"] = buildVersion();
return modelParams;
}
std::string getModelParametersAsString() {
auto yaml = getModelParameters();
YAML::Emitter out;
cli::OutputYaml(yaml, out);
return std::string(out.c_str());
}
public:
typedef data::Corpus dataset_type;
// @TODO: lots of code-duplication with EncoderDecoder
EncoderPooler(Ptr<Options> options)
: options_(options),
prefix_(options->get<std::string>("prefix", "")),
inference_(options->get<bool>("inference", false)) {
modelFeatures_ = {"type",
"dim-vocabs",
"dim-emb",
"dim-rnn",
"enc-cell",
"enc-type",
"enc-cell-depth",
"enc-depth",
"dec-depth",
"dec-cell",
"dec-cell-base-depth",
"dec-cell-high-depth",
"skip",
"layer-normalization",
"right-left",
"input-types",
"special-vocab",
"tied-embeddings",
"tied-embeddings-src",
"tied-embeddings-all"};
modelFeatures_.insert("transformer-heads");
modelFeatures_.insert("transformer-no-projection");
modelFeatures_.insert("transformer-dim-ffn");
modelFeatures_.insert("transformer-ffn-depth");
modelFeatures_.insert("transformer-ffn-activation");
modelFeatures_.insert("transformer-dim-aan");
modelFeatures_.insert("transformer-aan-depth");
modelFeatures_.insert("transformer-aan-activation");
modelFeatures_.insert("transformer-aan-nogate");
modelFeatures_.insert("transformer-preprocess");
modelFeatures_.insert("transformer-postprocess");
modelFeatures_.insert("transformer-postprocess-emb");
modelFeatures_.insert("transformer-decoder-autoreg");
modelFeatures_.insert("transformer-tied-layers");
modelFeatures_.insert("transformer-guided-alignment-layer");
modelFeatures_.insert("transformer-train-position-embeddings");
modelFeatures_.insert("transformer-pool");
modelFeatures_.insert("bert-train-type-embeddings");
modelFeatures_.insert("bert-type-vocab-size");
modelFeatures_.insert("ulr");
modelFeatures_.insert("ulr-trainable-transformation");
modelFeatures_.insert("ulr-dim-emb");
modelFeatures_.insert("lemma-dim-emb");
}
virtual Ptr<Options> getOptions() override { return options_; }
std::vector<Ptr<EncoderBase>>& getEncoders() { return encoders_; }
std::vector<Ptr<PoolerBase>>& getPoolers() { return poolers_; }
void push_back(Ptr<EncoderBase> encoder) { encoders_.push_back(encoder); }
void push_back(Ptr<PoolerBase> pooler) { poolers_.push_back(pooler); }
void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded) override {
graph->load(name, markedReloaded && !opt<bool>("ignore-model-config", false));
}
void mmap(Ptr<ExpressionGraph> graph,
const void* ptr,
bool markedReloaded) override {
graph->mmap(ptr, markedReloaded && !opt<bool>("ignore-model-config", false));
}
void save(Ptr<ExpressionGraph> graph,
const std::string& name,
bool /*saveModelConfig*/) override {
LOG(info, "Saving model weights and runtime parameters to {}", name);
graph->save(name , getModelParametersAsString());
}
void clear(Ptr<ExpressionGraph> graph) override {
graph->clear();
for(auto& enc : encoders_)
enc->clear();
for(auto& pooler : poolers_)
pooler->clear();
}
template <typename T>
T opt(const std::string& key) {
return options_->get<T>(key);
}
template <typename T>
T opt(const std::string& key, const T& def) {
return options_->get<T>(key, def);
}
template <typename T>
void set(std::string key, T value) {
options_->set(key, value);
}
/*********************************************************************/
virtual Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
if(clearGraph)
clear(graph);
std::vector<Ptr<EncoderState>> encoderStates;
for(auto& encoder : encoders_)
encoderStates.push_back(encoder->build(graph, batch));
ABORT_IF(poolers_.size() != 1, "Expected exactly one pooler");
return poolers_[0]->apply(graph, batch, encoderStates);
}
};
} // namespace marian

71
src/models/laser.h Normal file
View File

@ -0,0 +1,71 @@
#pragma once
#include "marian.h"
#include "layers/constructors.h"
#include "rnn/constructors.h"
namespace marian {
// Re-implements the LASER BiLSTM encoder from:
// Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond
// Mikel Artetxe, Holger Schwenk
// https://arxiv.org/abs/1812.10464
class EncoderLaser : public EncoderBase {
using EncoderBase::EncoderBase;
public:
Expr applyEncoderRNN(Ptr<ExpressionGraph> graph,
Expr embeddings,
Expr mask) {
int depth = opt<int>("enc-depth");
float dropoutRnn = inference_ ? 0 : opt<float>("dropout-rnn");
Expr output = embeddings;
auto applyRnn = [&](int layer, rnn::dir direction, Expr input, Expr mask) {
std::string paramPrefix = prefix_ + "_" + opt<std::string>("enc-cell");
paramPrefix += "_l" + std::to_string(layer);
if(direction == rnn::dir::backward)
paramPrefix += "_reverse";
auto rnnFactory = rnn::rnn()
("type", opt<std::string>("enc-cell"))
("direction", (int)direction)
("dimInput", input->shape()[-1])
("dimState", opt<int>("dim-rnn"))
("dropout", dropoutRnn)
("layer-normalization", opt<bool>("layer-normalization"))
("skip", opt<bool>("skip"))
.push_back(rnn::cell()("prefix", paramPrefix));
return rnnFactory.construct(graph)->transduce(input, mask);
};
for(int i = 0; i < depth; ++i) {
output = concatenate({applyRnn(i, rnn::dir::forward, output, mask),
applyRnn(i, rnn::dir::backward, output, mask)},
/*axis =*/ -1);
}
return output;
}
virtual Ptr<EncoderState> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch) override {
graph_ = graph;
// select embeddings that occur in the batch
Expr batchEmbeddings, batchMask; std::tie
(batchEmbeddings, batchMask) = getEmbeddingLayer()->apply((*batch)[batchIndex_]);
Expr context = applyEncoderRNN(graph_, batchEmbeddings, batchMask);
return New<EncoderState>(context, batchMask, batch);
}
void clear() override {}
};
}

View File

@ -8,7 +8,7 @@
namespace marian { namespace marian {
namespace models { namespace models {
enum struct usage { raw, training, scoring, translation }; enum struct usage { raw, training, scoring, translation, embedding };
} }
} // namespace marian } // namespace marian

View File

@ -10,6 +10,7 @@
#include "models/amun.h" #include "models/amun.h"
#include "models/nematus.h" #include "models/nematus.h"
#include "models/s2s.h" #include "models/s2s.h"
#include "models/laser.h"
#include "models/transformer_factory.h" #include "models/transformer_factory.h"
#ifdef CUDNN #ifdef CUDNN
@ -29,6 +30,9 @@ namespace models {
Ptr<EncoderBase> EncoderFactory::construct(Ptr<ExpressionGraph> graph) { Ptr<EncoderBase> EncoderFactory::construct(Ptr<ExpressionGraph> graph) {
if(options_->get<std::string>("type") == "s2s") if(options_->get<std::string>("type") == "s2s")
return New<EncoderS2S>(graph, options_); return New<EncoderS2S>(graph, options_);
if(options_->get<std::string>("type") == "laser" || options_->get<std::string>("type") == "laser-sim")
return New<EncoderLaser>(graph, options_);
#ifdef CUDNN #ifdef CUDNN
if(options_->get<std::string>("type") == "char-s2s") if(options_->get<std::string>("type") == "char-s2s")
@ -61,6 +65,17 @@ Ptr<ClassifierBase> ClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
ABORT("Unknown classifier type"); ABORT("Unknown classifier type");
} }
Ptr<PoolerBase> PoolerFactory::construct(Ptr<ExpressionGraph> graph) {
if(options_->get<std::string>("type") == "max-pooler")
return New<MaxPooler>(graph, options_);
if(options_->get<std::string>("type") == "slice-pooler")
return New<SlicePooler>(graph, options_);
else if(options_->get<std::string>("type") == "sim-pooler")
return New<SimPooler>(graph, options_);
else
ABORT("Unknown pooler type");
}
Ptr<IModel> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) { Ptr<IModel> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) {
Ptr<EncoderDecoder> encdec; Ptr<EncoderDecoder> encdec;
if(options_->get<std::string>("type") == "amun") if(options_->get<std::string>("type") == "amun")
@ -97,9 +112,54 @@ Ptr<IModel> EncoderClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
return enccls; return enccls;
} }
Ptr<IModel> EncoderPoolerFactory::construct(Ptr<ExpressionGraph> graph) {
Ptr<EncoderPooler> encpool = New<EncoderPooler>(options_);
for(auto& ef : encoders_)
encpool->push_back(ef(options_).construct(graph));
for(auto& pl : poolers_)
encpool->push_back(pl(options_).construct(graph));
return encpool;
}
Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> options) { Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> options) {
Ptr<ExpressionGraph> graph = nullptr; // graph unknown at this stage Ptr<ExpressionGraph> graph = nullptr; // graph unknown at this stage
// clang-format off // clang-format off
if(use == usage::embedding) { // hijacking an EncoderDecoder model for embedding only
int dimVocab = options->get<std::vector<int>>("dim-vocabs")[0];
Ptr<Options> newOptions;
if(options->get<bool>("compute-similarity")) {
newOptions = options->with("usage", use,
"original-type", type,
"input-types", std::vector<std::string>({"sequence", "sequence"}),
"dim-vocabs", std::vector<int>(2, dimVocab));
} else {
newOptions = options->with("usage", use,
"original-type", type,
"input-types", std::vector<std::string>({"sequence"}),
"dim-vocabs", std::vector<int>(1, dimVocab));
}
auto res = New<EncoderPooler>(newOptions);
if(options->get<bool>("compute-similarity")) {
res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph));
res->push_back(models::encoder(newOptions->with("index", 1)).construct(graph));
res->push_back(New<SimPooler>(graph, newOptions->with("type", "sim-pooler")));
} else {
res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph));
if(type == "laser")
res->push_back(New<MaxPooler>(graph, newOptions->with("type", "max-pooler")));
else
res->push_back(New<SlicePooler>(graph, newOptions->with("type", "slice-pooler")));
}
return res;
}
if(type == "s2s" || type == "amun" || type == "nematus") { if(type == "s2s" || type == "amun" || type == "nematus") {
return models::encoder_decoder(options->with( return models::encoder_decoder(options->with(
"usage", use, "usage", use,
@ -313,7 +373,7 @@ Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
else else
ABORT("'usage' parameter 'translation' cannot be applied to model type: {}", type); ABORT("'usage' parameter 'translation' cannot be applied to model type: {}", type);
} }
else if (use == usage::raw) else if (use == usage::raw || use == usage::embedding)
return baseModel; return baseModel;
else else
ABORT("'Usage' parameter must be 'translation' or 'raw'"); ABORT("'Usage' parameter must be 'translation' or 'raw'");

View File

@ -5,6 +5,7 @@
#include "layers/factory.h" #include "layers/factory.h"
#include "models/encoder_decoder.h" #include "models/encoder_decoder.h"
#include "models/encoder_classifier.h" #include "models/encoder_classifier.h"
#include "models/encoder_pooler.h"
namespace marian { namespace marian {
namespace models { namespace models {
@ -33,6 +34,14 @@ public:
typedef Accumulator<ClassifierFactory> classifier; typedef Accumulator<ClassifierFactory> classifier;
class PoolerFactory : public Factory {
using Factory::Factory;
public:
virtual Ptr<PoolerBase> construct(Ptr<ExpressionGraph> graph);
};
typedef Accumulator<PoolerFactory> pooler;
class EncoderDecoderFactory : public Factory { class EncoderDecoderFactory : public Factory {
using Factory::Factory; using Factory::Factory;
private: private:
@ -77,6 +86,28 @@ public:
typedef Accumulator<EncoderClassifierFactory> encoder_classifier; typedef Accumulator<EncoderClassifierFactory> encoder_classifier;
class EncoderPoolerFactory : public Factory {
using Factory::Factory;
private:
std::vector<encoder> encoders_;
std::vector<pooler> poolers_;
public:
Accumulator<EncoderPoolerFactory> push_back(encoder enc) {
encoders_.push_back(enc);
return Accumulator<EncoderPoolerFactory>(*this);
}
Accumulator<EncoderPoolerFactory> push_back(pooler cls) {
poolers_.push_back(cls);
return Accumulator<EncoderPoolerFactory>(*this);
}
virtual Ptr<IModel> construct(Ptr<ExpressionGraph> graph);
};
typedef Accumulator<EncoderPoolerFactory> encoder_pooler;
Ptr<IModel> createBaseModelByType(std::string type, usage, Ptr<Options> options); Ptr<IModel> createBaseModelByType(std::string type, usage, Ptr<Options> options);
Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage); Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage);

139
src/models/pooler.h Normal file
View File

@ -0,0 +1,139 @@
#pragma once
#include "marian.h"
#include "models/states.h"
#include "layers/constructors.h"
#include "layers/factory.h"
namespace marian {
/**
* Simple base class for Poolers to be used in EncoderPooler framework
* A pooler takes a encoder state (contextual word embeddings) and produces
* a single sentence embedding.
*/
class PoolerBase : public LayerBase {
using LayerBase::LayerBase;
protected:
const std::string prefix_{"pooler"};
const bool inference_{false};
const size_t batchIndex_{0};
public:
PoolerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: LayerBase(graph, options),
prefix_(options->get<std::string>("prefix", "pooler")),
inference_(options->get<bool>("inference", true)),
batchIndex_(options->get<size_t>("index", 1)) {} // assume that training input has batch index 0 and labels has 1
virtual ~PoolerBase() {}
virtual Expr apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, const std::vector<Ptr<EncoderState>>&) = 0;
template <typename T>
T opt(const std::string& key) const {
return options_->get<T>(key);
}
// Should be used to clear any batch-wise temporary objects if present
virtual void clear() = 0;
};
/**
* Pool encoder state (contextual word embeddings) via max-pooling along sentence-length dimension.
*/
class MaxPooler : public PoolerBase {
public:
MaxPooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: PoolerBase(graph, options) {}
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state");
auto context = encoderStates[0]->getContext();
auto batchMask = encoderStates[0]->getMask();
// do a max pool here
Expr logMask = (1.f - batchMask) * -9999.f;
Expr maxPool = max(context * batchMask + logMask, /*axis=*/-3);
return maxPool;
}
void clear() override {}
};
/**
* Pool encoder state (contextual word embeddings) by selecting 1st embedding along sentence-length dimension.
*/
class SlicePooler : public PoolerBase {
public:
SlicePooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: PoolerBase(graph, options) {}
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state");
auto context = encoderStates[0]->getContext();
auto batchMask = encoderStates[0]->getMask();
// Corresponds to the way we do this in transformer.h
// @TODO: unify this better, this is currently hacky
Expr slicePool = slice(context * batchMask, /*axis=*/-3, 0);
return slicePool;
}
void clear() override {}
};
/**
* Not really a pooler but abusing the interface to compute a similarity of two pooled states
*/
class SimPooler : public PoolerBase {
public:
SimPooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: PoolerBase(graph, options) {}
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
ABORT_IF(encoderStates.size() != 2, "SimPooler expects exactly two encoder states");
std::vector<Expr> vecs;
for(auto encoderState : encoderStates) {
auto context = encoderState->getContext();
auto batchMask = encoderState->getMask();
Expr pool;
auto type = options_->get<std::string>("original-type");
if(type == "laser") {
// LASER models do a max pool here
Expr logMask = (1.f - batchMask) * -9999.f;
pool = max(context * batchMask + logMask, /*axis=*/-3);
} else if(type == "transformer") {
// Our own implementation in transformer.h uses a slice of the first element
pool = slice(context, -3, 0);
} else {
// @TODO: make SimPooler take Pooler objects as arguments then it won't need to know this.
ABORT("Don't know what type of pooler to use for model type {}", type);
}
vecs.push_back(pool);
}
auto scalars = scalar_product(vecs[0], vecs[1], /*axis*/-1);
auto length1 = sqrt(sum(square(vecs[0]), /*axis=*/-1));
auto length2 = sqrt(sum(square(vecs[1]), /*axis=*/-1));
auto cosine = scalars / ( length1 * length2 );
return cosine;
}
void clear() override {}
};
}

View File

@ -328,6 +328,30 @@ public:
return output; return output;
} }
// Reduce the encoder to a single sentence vector, here we just take the contextual embedding of the first word per sentence
// Replaces cross-attention in LASER-like models
Expr LayerPooling(std::string prefix,
Expr input, // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
const Expr& values, // [-4: beam depth=1, -3: batch size, -2: max length (src or trg), -1: vector dim]
const Expr& mask) { // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
int dimModel = input->shape()[-1];
auto output = slice(values, -2, 0); // Select first word [-4: beam depth, -3: batch size, -2: 1, -1: vector dim]
int dimPool = output->shape()[-1];
bool project = !opt<bool>("transformer-no-projection");
if(project || dimPool != dimModel) {
auto Wo = graph_->param(prefix + "_Wo", {dimPool, dimModel}, inits::glorotUniform());
auto bo = graph_->param(prefix + "_bo", {1, dimModel}, inits::zeros());
output = affine(output, Wo, bo); // [-4: beam depth, -3: batch size, -2: 1, -1: vector dim]
}
auto opsPost = opt<std::string>("transformer-postprocess");
output = postProcess(prefix + "_Wo", opsPost, output, input, 0.f);
return output;
}
Expr LayerAttention(std::string prefix, Expr LayerAttention(std::string prefix,
Expr input, // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim] Expr input, // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
const Expr& keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim] const Expr& keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
@ -790,14 +814,21 @@ public:
saveAttentionWeights = i == attLayer; saveAttentionWeights = i == attLayer;
} }
query = LayerAttention(prefix, if(options_->get<bool>("transformer-pool", false)) {
query = LayerPooling(prefix,
query, query,
encoderContexts[j], // keys
encoderContexts[j], // values encoderContexts[j], // values
encoderMasks[j], encoderMasks[j]);
opt<int>("transformer-heads"), } else {
/*cache=*/true, query = LayerAttention(prefix,
saveAttentionWeights); query,
encoderContexts[j], // keys
encoderContexts[j], // values
encoderMasks[j],
opt<int>("transformer-heads"),
/*cache=*/true,
saveAttentionWeights);
}
} }
} }

View File

@ -651,7 +651,7 @@ public:
using LSTM = FastLSTM; using LSTM = FastLSTM;
/******************************************************************************/ /******************************************************************************/
// Experimentak cells, use with care // Experimental cells, use with care
template <class CellType> template <class CellType>
class Multiplicative : public CellType { class Multiplicative : public CellType {

View File

@ -1248,67 +1248,104 @@ void SetSparse(float* out,
} }
} }
void LSTMCellForward(Tensor out_, std::vector<Tensor> inputs) { // should be implemented via slicing and elementwise
template <typename FType>
void LSTMCellForwardTyped(Tensor out_, const std::vector<Tensor>& inputs) {
int rows = out_->shape().elements() / out_->shape()[-1]; int rows = out_->shape().elements() / out_->shape()[-1];
int cols = out_->shape()[-1];
float* out = out_->data(); int fVecSize = sizeof(FType) / sizeof(float);
const float* cell = inputs[0]->data(); int cols = out_->shape()[-1] / fVecSize;
const float* xW = inputs[1]->data();
const float* sU = inputs[2]->data(); FType* out = out_->data<FType>();
const float* b = inputs[3]->data(); const FType* cell = inputs[0]->data<FType>();
const FType* xW = inputs[1]->data<FType>();
const FType* sU = inputs[2]->data<FType>();
const FType* b = inputs[3]->data<FType>();
const float* mask = inputs.size() > 4 ? inputs[4]->data() : nullptr; const float* mask = inputs.size() > 4 ? inputs[4]->data() : nullptr;
using fop = functional::Ops<FType>;
for(int j = 0; j < rows; ++j) { for(int j = 0; j < rows; ++j) {
float m = !mask || mask[j]; float m = !mask || mask[j];
float* rowOut = out + j * cols; FType* rowOut = out + j * cols;
const float* rowCell = cell + j * cols; const FType* rowCell = cell + j * cols;
const float* xWrow = xW + j * cols * 4; const FType* xWrow = xW + j * cols * 4;
const float* sUrow = sU + j * cols * 4; const FType* sUrow = sU + j * cols * 4;
for(int i = 0; i < cols; ++i) { for(int i = 0; i < cols; ++i) {
float gf = functional::Ops<float>::sigmoid(xWrow[i] + sUrow[i] + b[i]); FType gf = fop::sigmoid(fop::add(fop::add(xWrow[i], sUrow[i]), b[i]));
int k = i + cols; int k = i + cols;
float gi = functional::Ops<float>::sigmoid(xWrow[k] + sUrow[k] + b[k]); FType gi = fop::sigmoid(fop::add(fop::add(xWrow[k], sUrow[k]), b[k]));
int l = i + 2 * cols; int l = i + 2 * cols;
float gc = std::tanh(xWrow[l] + sUrow[l] + b[l]); FType gc = fop::tanh(fop::add(fop::add(xWrow[l], sUrow[l]), b[l]));
float cout = gf * rowCell[i] + gi * gc; FType cout = fop::add(fop::mul(gf, rowCell[i]), fop::mul(gi, gc));
rowOut[i] = m * cout + (1 - m) * rowCell[i]; rowOut[i] = fop::add(fop::mul(m, cout), fop::mul(fop::sub(1.f, m), rowCell[i]));
} }
} }
} }
void LSTMOutputForward(Tensor out_, std::vector<Tensor> inputs) { void LSTMCellForward(Tensor out, std::vector<Tensor> inputs) {
int rows = out_->shape().elements() / out_->shape()[-1]; int cols = out->shape()[-1];
int cols = out_->shape()[-1]; #ifdef __AVX__
if(cols % 8 == 0)
LSTMCellForwardTyped<float32x8>(out, inputs);
else
#endif
if(cols % 4 == 0)
LSTMCellForwardTyped<float32x4>(out, inputs);
else
LSTMCellForwardTyped<float>(out, inputs);
}
float* out = out_->data(); template <typename FType>
const float* cell = inputs[0]->data(); void LSTMOutputForwardTyped(Tensor out_, const std::vector<Tensor>& inputs) {
const float* xW = inputs[1]->data(); int rows = out_->shape().elements() / out_->shape()[-1];
const float* sU = inputs[2]->data();
const float* b = inputs[3]->data(); int fVecSize = sizeof(FType) / sizeof(float);
int cols = out_->shape()[-1] / fVecSize;
FType* out = out_->data<FType>();
const FType* cell = inputs[0]->data<FType>();
const FType* xW = inputs[1]->data<FType>();
const FType* sU = inputs[2]->data<FType>();
const FType* b = inputs[3]->data<FType>();
using fop = functional::Ops<FType>;
for(int j = 0; j < rows; ++j) { for(int j = 0; j < rows; ++j) {
float* rowOut = out + j * cols; FType* rowOut = out + j * cols;
const float* rowCell = cell + j * cols; const FType* rowCell = cell + j * cols;
const float* xWrow = xW + j * cols * 4; const FType* xWrow = xW + j * cols * 4;
const float* sUrow = sU + j * cols * 4; const FType* sUrow = sU + j * cols * 4;
for(int i = 0; i < cols; ++i) { for(int i = 0; i < cols; ++i) {
int k = i + 3 * cols; int k = i + 3 * cols;
float go = functional::Ops<float>::sigmoid(xWrow[k] + sUrow[k] + b[k]); FType go = fop::sigmoid(fop::add(fop::add(xWrow[k], sUrow[k]), b[k]));
rowOut[i] = fop::mul(go, fop::tanh(rowCell[i]));
rowOut[i] = go * std::tanh(rowCell[i]);
} }
} }
} }
void LSTMOutputForward(Tensor out, std::vector<Tensor> inputs) {
int cols = out->shape()[-1];
#ifdef __AVX__
if(cols % 8 == 0)
LSTMOutputForwardTyped<float32x8>(out, inputs);
else
#endif
if(cols % 4 == 0)
LSTMOutputForwardTyped<float32x4>(out, inputs);
else
LSTMOutputForwardTyped<float>(out, inputs);
}
void LSTMCellBackward(std::vector<Tensor> outputs, void LSTMCellBackward(std::vector<Tensor> outputs,
std::vector<Tensor> inputs, std::vector<Tensor> inputs,
Tensor adj_) { Tensor adj_) {

View File

@ -77,7 +77,7 @@ Ptr<data::BatchStats> GraphGroup::collectStats(Ptr<ExpressionGraph> graph,
} else { } else {
end = current - 1; end = current - 1;
} }
} while(end - start > step); } while(end - start > step); // @TODO: better replace with `end >= start` to remove the step here
maxBatch = start; maxBatch = start;
} }