mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
Merged PR 13476: Add LASER reimplementation and code for embeddings sentences
This reimplements the LASER encoder from: ``` Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond Mikel Artetxe, Holger Schwenk https://arxiv.org/abs/1812.10464 ``` and adds functionality to embed sentences with any Marian encoder, also different from LASER. Some early attempts to train a transformer model with Encoder-Decoder bottle-neck. This is quite early code, so some code-duplication is to be expected. Nevertheless, it's functional and I would like to have it in master as we will slowly put that into production in various places. I will make the code "nicer" as we go along.
This commit is contained in:
parent
7b815cb936
commit
c3fb60cbcd
@ -187,6 +187,13 @@ else(MSVC)
|
|||||||
set(CMAKE_C_FLAGS_PROFUSE "${CMAKE_C_FLAGS_RELEASE} -fprofile-use -fprofile-correction")
|
set(CMAKE_C_FLAGS_PROFUSE "${CMAKE_C_FLAGS_RELEASE} -fprofile-use -fprofile-correction")
|
||||||
endif(MSVC)
|
endif(MSVC)
|
||||||
|
|
||||||
|
# with gcc 7.0 and above we need to mark fallthrough in switch case statements
|
||||||
|
# that can be done in comments for backcompat, but CCACHE removes comments.
|
||||||
|
# -C makes gcc keep comments.
|
||||||
|
if(USE_CCACHE)
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -C")
|
||||||
|
endif()
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Downloading SentencePiece if requested and set to compile with it.
|
# Downloading SentencePiece if requested and set to compile with it.
|
||||||
# Requires all the dependencies imposed by SentencePiece
|
# Requires all the dependencies imposed by SentencePiece
|
||||||
|
85
scripts/laser/laser2marian.py
Normal file
85
scripts/laser/laser2marian.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import numpy as np
|
||||||
|
import sys
|
||||||
|
import yaml
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Convert LASER model to Marian weight file.')
|
||||||
|
parser.add_argument('--laser', help='Path to LASER PyTorch model', required=True)
|
||||||
|
parser.add_argument('--marian', help='Output path for Marian weight file', required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
laser = torch.load(args.laser)
|
||||||
|
|
||||||
|
config = dict()
|
||||||
|
config["type"] = "laser"
|
||||||
|
config["input-types"] = ["sequence"]
|
||||||
|
config["dim-vocabs"] = [laser["params"]["num_embeddings"]]
|
||||||
|
|
||||||
|
config["version"] = "laser2marian.py conversion"
|
||||||
|
|
||||||
|
config["enc-depth"] = laser["params"]["num_layers"]
|
||||||
|
config["enc-cell"] = "lstm"
|
||||||
|
config["dim-emb"] = laser["params"]["embed_dim"]
|
||||||
|
config["dim-rnn"] = laser["params"]["hidden_size"]
|
||||||
|
|
||||||
|
yaml.dump(laser["dictionary"], open(args.marian + ".vocab.yml", "w"))
|
||||||
|
|
||||||
|
marianModel = dict()
|
||||||
|
|
||||||
|
def transposeOrder(mat):
|
||||||
|
matT = np.transpose(mat) # just a view with changed row order
|
||||||
|
return matT.flatten(order="C").reshape(matT.shape) # force row order change and reshape
|
||||||
|
|
||||||
|
def convert(pd, srcs, trg, transpose=True, bias=False, lstm=False):
|
||||||
|
num = pd[srcs[0]].detach().numpy()
|
||||||
|
for i in range(1, len(srcs)):
|
||||||
|
num += pd[srcs[i]].detach().numpy()
|
||||||
|
|
||||||
|
out = num
|
||||||
|
if bias:
|
||||||
|
num = np.atleast_2d(num)
|
||||||
|
else:
|
||||||
|
if transpose:
|
||||||
|
num = transposeOrder(num) # transpose with row order change
|
||||||
|
|
||||||
|
if lstm: # different order in pytorch than marian
|
||||||
|
stateDim = int(num.shape[-1] / 4)
|
||||||
|
i = np.copy(num[:, 0*stateDim:1*stateDim])
|
||||||
|
f = np.copy(num[:, 1*stateDim:2*stateDim])
|
||||||
|
num[:, 0*stateDim:1*stateDim] = f
|
||||||
|
num[:, 1*stateDim:2*stateDim] = i
|
||||||
|
|
||||||
|
marianModel[trg] = num
|
||||||
|
|
||||||
|
for k in laser:
|
||||||
|
print(k)
|
||||||
|
|
||||||
|
for k in laser["model"]:
|
||||||
|
print(k, laser["model"][k].shape)
|
||||||
|
|
||||||
|
convert(laser["model"], ["embed_tokens.weight"], "encoder_Wemb", transpose=False)
|
||||||
|
for i in range(laser["params"]["num_layers"]):
|
||||||
|
convert(laser["model"], [f"lstm.weight_ih_l{i}"], f"encoder_lstm_l{i}_W", lstm=True)
|
||||||
|
convert(laser["model"], [f"lstm.weight_hh_l{i}"], f"encoder_lstm_l{i}_U", lstm=True)
|
||||||
|
convert(laser["model"], [f"lstm.bias_ih_l{i}", f"lstm.bias_hh_l{i}"], f"encoder_lstm_l{i}_b", bias=True, lstm=True) # needs to be summed!
|
||||||
|
|
||||||
|
convert(laser["model"], [f"lstm.weight_ih_l{i}_reverse"], f"encoder_lstm_l{i}_reverse_W", lstm=True)
|
||||||
|
convert(laser["model"], [f"lstm.weight_hh_l{i}_reverse"], f"encoder_lstm_l{i}_reverse_U", lstm=True)
|
||||||
|
convert(laser["model"], [f"lstm.bias_ih_l{i}_reverse", f"lstm.bias_hh_l{i}_reverse"], f"encoder_lstm_l{i}_reverse_b", bias=True, lstm=True) # needs to be summed!
|
||||||
|
|
||||||
|
for m in marianModel:
|
||||||
|
print(m, marianModel[m].shape)
|
||||||
|
|
||||||
|
configYamlStr = yaml.dump(config, default_flow_style=False)
|
||||||
|
desc = list(configYamlStr)
|
||||||
|
npDesc = np.chararray((len(desc),))
|
||||||
|
npDesc[:] = desc
|
||||||
|
npDesc.dtype = np.int8
|
||||||
|
marianModel["special:model.yml"] = npDesc
|
||||||
|
|
||||||
|
print("\nMarian config:")
|
||||||
|
print(configYamlStr)
|
||||||
|
print("Saving Marian model to %s" % (args.marian,))
|
||||||
|
np.savez(args.marian, **marianModel)
|
@ -83,6 +83,7 @@ add_library(marian STATIC
|
|||||||
models/transformer_stub.cpp
|
models/transformer_stub.cpp
|
||||||
|
|
||||||
rescorer/score_collector.cpp
|
rescorer/score_collector.cpp
|
||||||
|
embedder/vector_collector.cpp
|
||||||
|
|
||||||
translator/history.cpp
|
translator/history.cpp
|
||||||
translator/output_collector.cpp
|
translator/output_collector.cpp
|
||||||
|
14
src/command/marian_embedder.cpp
Normal file
14
src/command/marian_embedder.cpp
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
#include "marian.h"
|
||||||
|
|
||||||
|
#include "models/model_task.h"
|
||||||
|
#include "embedder/embedder.h"
|
||||||
|
#include "common/timer.h"
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
using namespace marian;
|
||||||
|
|
||||||
|
auto options = parseOptions(argc, argv, cli::mode::embedding);
|
||||||
|
New<Embed<Embedder>>(options)->run();
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
@ -11,6 +11,7 @@
|
|||||||
// train
|
// train
|
||||||
// decode
|
// decode
|
||||||
// score
|
// score
|
||||||
|
// embed
|
||||||
// vocab
|
// vocab
|
||||||
// convert
|
// convert
|
||||||
// Currently, marian_server is not supported, since it is a special use case with lots of extra dependencies.
|
// Currently, marian_server is not supported, since it is a special use case with lots of extra dependencies.
|
||||||
@ -24,6 +25,9 @@
|
|||||||
#define main mainScorer
|
#define main mainScorer
|
||||||
#include "marian_scorer.cpp"
|
#include "marian_scorer.cpp"
|
||||||
#undef main
|
#undef main
|
||||||
|
#define main mainEmbedder
|
||||||
|
#include "marian_embedder.cpp"
|
||||||
|
#undef main
|
||||||
#define main mainVocab
|
#define main mainVocab
|
||||||
#include "marian_vocab.cpp"
|
#include "marian_vocab.cpp"
|
||||||
#undef main
|
#undef main
|
||||||
@ -44,9 +48,10 @@ int main(int argc, char** argv) {
|
|||||||
if(cmd == "train") return mainTrainer(argc, argv);
|
if(cmd == "train") return mainTrainer(argc, argv);
|
||||||
else if(cmd == "decode") return mainDecoder(argc, argv);
|
else if(cmd == "decode") return mainDecoder(argc, argv);
|
||||||
else if (cmd == "score") return mainScorer(argc, argv);
|
else if (cmd == "score") return mainScorer(argc, argv);
|
||||||
|
else if (cmd == "embed") return mainEmbedder(argc, argv);
|
||||||
else if (cmd == "vocab") return mainVocab(argc, argv);
|
else if (cmd == "vocab") return mainVocab(argc, argv);
|
||||||
else if (cmd == "convert") return mainConv(argc, argv);
|
else if (cmd == "convert") return mainConv(argc, argv);
|
||||||
std::cerr << "Command must be train, decode, score, vocab, or convert." << std::endl;
|
std::cerr << "Command must be train, decode, score, embed, vocab, or convert." << std::endl;
|
||||||
exit(1);
|
exit(1);
|
||||||
} else
|
} else
|
||||||
return mainTrainer(argc, argv);
|
return mainTrainer(argc, argv);
|
||||||
|
@ -91,6 +91,9 @@ ConfigParser::ConfigParser(cli::mode mode)
|
|||||||
case cli::mode::scoring:
|
case cli::mode::scoring:
|
||||||
addOptionsScoring(cli_);
|
addOptionsScoring(cli_);
|
||||||
break;
|
break;
|
||||||
|
case cli::mode::embedding:
|
||||||
|
addOptionsEmbedding(cli_);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
ABORT("wrong CLI mode");
|
ABORT("wrong CLI mode");
|
||||||
break;
|
break;
|
||||||
@ -235,6 +238,8 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
|||||||
8);
|
8);
|
||||||
cli.add<bool>("--transformer-no-projection",
|
cli.add<bool>("--transformer-no-projection",
|
||||||
"Omit linear projection after multi-head attention (transformer)");
|
"Omit linear projection after multi-head attention (transformer)");
|
||||||
|
cli.add<bool>("--transformer-pool",
|
||||||
|
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
|
||||||
cli.add<int>("--transformer-dim-ffn",
|
cli.add<int>("--transformer-dim-ffn",
|
||||||
"Size of position-wise feed-forward network (transformer)",
|
"Size of position-wise feed-forward network (transformer)",
|
||||||
2048);
|
2048);
|
||||||
@ -705,6 +710,45 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
|
|||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ConfigParser::addOptionsEmbedding(cli::CLIWrapper& cli) {
|
||||||
|
auto previous_group = cli.switchGroup("Scorer options");
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
cli.add<bool>("--no-reload",
|
||||||
|
"Do not load existing model specified in --model arg");
|
||||||
|
// TODO: move options like vocabs and train-sets to a separate procedure as they are defined twice
|
||||||
|
cli.add<std::vector<std::string>>("--train-sets,-t",
|
||||||
|
"Paths to corpora to be scored: source target");
|
||||||
|
cli.add<std::string>("--output,-o",
|
||||||
|
"Path to output file, stdout by default",
|
||||||
|
"stdout");
|
||||||
|
cli.add<std::vector<std::string>>("--vocabs,-v",
|
||||||
|
"Paths to vocabulary files have to correspond to --train-sets. "
|
||||||
|
"If this parameter is not supplied we look for vocabulary files source.{yml,json} and target.{yml,json}. "
|
||||||
|
"If these files do not exists they are created");
|
||||||
|
|
||||||
|
cli.add<bool>("--compute-similarity",
|
||||||
|
"Expect two inputs and compute cosine similarity instead of outputting embedding vector");
|
||||||
|
cli.add<bool>("--binary",
|
||||||
|
"Output vectors as binary floats");
|
||||||
|
|
||||||
|
addSuboptionsInputLength(cli);
|
||||||
|
addSuboptionsTSV(cli);
|
||||||
|
addSuboptionsDevices(cli);
|
||||||
|
addSuboptionsBatching(cli);
|
||||||
|
|
||||||
|
cli.add<bool>("--optimize",
|
||||||
|
"Optimize speed aggressively sacrificing memory or precision");
|
||||||
|
cli.add<bool>("--fp16",
|
||||||
|
"Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
|
||||||
|
cli.add<std::vector<std::string>>("--precision",
|
||||||
|
"Mixed precision for inference, set parameter type in expression graph",
|
||||||
|
{"float32"});
|
||||||
|
|
||||||
|
cli.switchGroup(previous_group);
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
void ConfigParser::addSuboptionsDevices(cli::CLIWrapper& cli) {
|
void ConfigParser::addSuboptionsDevices(cli::CLIWrapper& cli) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
cli.add<std::vector<std::string>>("--devices,-d",
|
cli.add<std::vector<std::string>>("--devices,-d",
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
namespace marian {
|
namespace marian {
|
||||||
|
|
||||||
namespace cli {
|
namespace cli {
|
||||||
enum struct mode { training, translation, scoring, server };
|
enum struct mode { training, translation, scoring, server, embedding };
|
||||||
} // namespace cli
|
} // namespace cli
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -129,6 +129,7 @@ private:
|
|||||||
void addOptionsValidation(cli::CLIWrapper&);
|
void addOptionsValidation(cli::CLIWrapper&);
|
||||||
void addOptionsTranslation(cli::CLIWrapper&);
|
void addOptionsTranslation(cli::CLIWrapper&);
|
||||||
void addOptionsScoring(cli::CLIWrapper&);
|
void addOptionsScoring(cli::CLIWrapper&);
|
||||||
|
void addOptionsEmbedding(cli::CLIWrapper&);
|
||||||
|
|
||||||
void addAliases(cli::CLIWrapper&);
|
void addAliases(cli::CLIWrapper&);
|
||||||
|
|
||||||
|
@ -27,6 +27,10 @@ void ConfigValidator::validateOptions(cli::mode mode) const {
|
|||||||
validateOptionsParallelData();
|
validateOptionsParallelData();
|
||||||
validateOptionsScoring();
|
validateOptionsScoring();
|
||||||
break;
|
break;
|
||||||
|
case cli::mode::embedding:
|
||||||
|
validateOptionsParallelData();
|
||||||
|
validateOptionsScoring();
|
||||||
|
break;
|
||||||
case cli::mode::training:
|
case cli::mode::training:
|
||||||
validateOptionsParallelData();
|
validateOptionsParallelData();
|
||||||
validateOptionsTraining();
|
validateOptionsTraining();
|
||||||
|
171
src/embedder/embedder.h
Normal file
171
src/embedder/embedder.h
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "marian.h"
|
||||||
|
|
||||||
|
#include "common/config.h"
|
||||||
|
#include "common/options.h"
|
||||||
|
#include "data/batch_generator.h"
|
||||||
|
#include "data/corpus.h"
|
||||||
|
#include "data/corpus_nbest.h"
|
||||||
|
#include "models/costs.h"
|
||||||
|
#include "models/model_task.h"
|
||||||
|
#include "embedder/vector_collector.h"
|
||||||
|
#include "training/scheduler.h"
|
||||||
|
#include "training/validator.h"
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
using namespace data;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* The tool is used to create output sentence embeddings from available
|
||||||
|
* Marian encoders. With --compute-similiarity and can return the cosine
|
||||||
|
* similarity between two sentences provided from two sources.
|
||||||
|
*/
|
||||||
|
class Embedder {
|
||||||
|
private:
|
||||||
|
Ptr<models::IModel> model_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Embedder(Ptr<Options> options)
|
||||||
|
: model_(createModelFromOptions(options, models::usage::embedding)) {}
|
||||||
|
|
||||||
|
void load(Ptr<ExpressionGraph> graph, const std::string& modelFile) {
|
||||||
|
model_->load(graph, modelFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr build(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
|
||||||
|
auto embedder = std::dynamic_pointer_cast<EncoderPooler>(model_);
|
||||||
|
ABORT_IF(!embedder, "Could not cast to EncoderPooler");
|
||||||
|
return embedder->apply(graph, batch, /*clearGraph=*/true);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Actual Embed task. @TODO: this should be simplified in the future.
|
||||||
|
*/
|
||||||
|
template <class Model>
|
||||||
|
class Embed : public ModelTask {
|
||||||
|
private:
|
||||||
|
Ptr<Options> options_;
|
||||||
|
Ptr<CorpusBase> corpus_;
|
||||||
|
std::vector<Ptr<ExpressionGraph>> graphs_;
|
||||||
|
std::vector<Ptr<Model>> models_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Embed(Ptr<Options> options) : options_(options) {
|
||||||
|
|
||||||
|
options_ = options_->with("inference", true,
|
||||||
|
"shuffle", "none");
|
||||||
|
|
||||||
|
// if a similarity is computed then double the input types and vocabs for
|
||||||
|
// the two encoders that are used in the model.
|
||||||
|
if(options->get<bool>("compute-similarity")) {
|
||||||
|
auto vInputTypes = options_->get<std::vector<std::string>>("input-types");
|
||||||
|
auto vVocabs = options_->get<std::vector<std::string>>("vocabs");
|
||||||
|
auto vDimVocabs = options_->get<std::vector<size_t>>("dim-vocabs");
|
||||||
|
|
||||||
|
vInputTypes.push_back(vInputTypes.back());
|
||||||
|
vVocabs.push_back(vVocabs.back());
|
||||||
|
vDimVocabs.push_back(vDimVocabs.back());
|
||||||
|
|
||||||
|
options_ = options_->with("input-types", vInputTypes,
|
||||||
|
"vocabs", vVocabs,
|
||||||
|
"dim-vocabs", vDimVocabs);
|
||||||
|
}
|
||||||
|
|
||||||
|
corpus_ = New<Corpus>(options_);
|
||||||
|
corpus_->prepare();
|
||||||
|
|
||||||
|
auto devices = Config::getDevices(options_);
|
||||||
|
|
||||||
|
for(auto device : devices) {
|
||||||
|
auto graph = New<ExpressionGraph>(true);
|
||||||
|
|
||||||
|
auto precison = options_->get<std::vector<std::string>>("precision", {"float32"});
|
||||||
|
graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph
|
||||||
|
graph->setDevice(device);
|
||||||
|
graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
|
||||||
|
if (device.type == DeviceType::cpu) {
|
||||||
|
graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
|
||||||
|
}
|
||||||
|
|
||||||
|
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
|
||||||
|
graphs_.push_back(graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto modelFile = options_->get<std::string>("model");
|
||||||
|
|
||||||
|
models_.resize(graphs_.size());
|
||||||
|
ThreadPool pool(graphs_.size(), graphs_.size());
|
||||||
|
for(size_t i = 0; i < graphs_.size(); ++i) {
|
||||||
|
pool.enqueue(
|
||||||
|
[=](size_t j) {
|
||||||
|
models_[j] = New<Model>(options_);
|
||||||
|
models_[j]->load(graphs_[j], modelFile);
|
||||||
|
},
|
||||||
|
i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void run() override {
|
||||||
|
LOG(info, "Embedding");
|
||||||
|
timer::Timer timer;
|
||||||
|
|
||||||
|
auto batchGenerator = New<BatchGenerator<CorpusBase>>(corpus_, options_);
|
||||||
|
batchGenerator->prepare();
|
||||||
|
|
||||||
|
auto output = New<VectorCollector>(options_);
|
||||||
|
|
||||||
|
size_t batchId = 0;
|
||||||
|
std::mutex smutex;
|
||||||
|
{
|
||||||
|
ThreadPool pool(graphs_.size(), graphs_.size());
|
||||||
|
|
||||||
|
for(auto batch : *batchGenerator) {
|
||||||
|
auto task = [=, &smutex](size_t id) {
|
||||||
|
thread_local Ptr<ExpressionGraph> graph;
|
||||||
|
thread_local Ptr<Model> builder;
|
||||||
|
|
||||||
|
if(!graph) {
|
||||||
|
graph = graphs_[id % graphs_.size()];
|
||||||
|
builder = models_[id % graphs_.size()];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto embeddings = builder->build(graph, batch);
|
||||||
|
graph->forward();
|
||||||
|
|
||||||
|
std::vector<float> sentVectors;
|
||||||
|
embeddings->val()->get(sentVectors);
|
||||||
|
|
||||||
|
// collect embedding vector per sentence.
|
||||||
|
// if we compute similarities this is only one similarity per sentence pair.
|
||||||
|
for(size_t i = 0; i < batch->size(); ++i) {
|
||||||
|
int embSize = embeddings->shape()[-1];
|
||||||
|
int beg = i * embSize;
|
||||||
|
int end = (i + 1) * embSize;
|
||||||
|
std::vector<float> sentVector(sentVectors.begin() + beg, sentVectors.begin() + end);
|
||||||
|
output->Write((long)batch->getSentenceIds()[i],
|
||||||
|
sentVector);
|
||||||
|
}
|
||||||
|
|
||||||
|
// progress heartbeat for MS-internal Philly compute cluster
|
||||||
|
// otherwise this job may be killed prematurely if no log for 4 hrs
|
||||||
|
if (getenv("PHILLY_JOB_ID") // this environment variable exists when running on the cluster
|
||||||
|
&& id % 1000 == 0) // hard beat once every 1000 batches
|
||||||
|
{
|
||||||
|
auto progress = id / 10000.f; //fake progress for now, becomes >100 after 1M batches
|
||||||
|
fprintf(stderr, "PROGRESS: %.2f%%\n", progress);
|
||||||
|
fflush(stderr);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pool.enqueue(task, batchId++);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG(info, "Total time: {:.5f}s wall", timer.elapsed());
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace marian
|
71
src/embedder/vector_collector.cpp
Normal file
71
src/embedder/vector_collector.cpp
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
#include "embedder/vector_collector.h"
|
||||||
|
|
||||||
|
#include "common/logging.h"
|
||||||
|
#include "common/utils.h"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <iomanip>
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
// This class manages multi-threaded writing of embedded vectors to stdout or an output file.
|
||||||
|
// It will either output string versions of float vectors or binary equal length versions depending
|
||||||
|
// on its binary_ flag.
|
||||||
|
|
||||||
|
VectorCollector::VectorCollector(const Ptr<Options>& options)
|
||||||
|
: nextId_(0), binary_{options->get<bool>("binary", false)} {
|
||||||
|
if(options->get<std::string>("output") == "stdout")
|
||||||
|
outStrm_.reset(new std::ostream(std::cout.rdbuf()));
|
||||||
|
else
|
||||||
|
outStrm_.reset(new io::OutputFileStream(options->get<std::string>("output")));
|
||||||
|
}
|
||||||
|
|
||||||
|
void VectorCollector::Write(long id, const std::vector<float>& vec) {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
if(id == nextId_) {
|
||||||
|
WriteVector(vec);
|
||||||
|
|
||||||
|
++nextId_;
|
||||||
|
|
||||||
|
typename Outputs::const_iterator iter, iterNext;
|
||||||
|
iter = outputs_.begin();
|
||||||
|
while(iter != outputs_.end()) {
|
||||||
|
long currId = iter->first;
|
||||||
|
|
||||||
|
if(currId == nextId_) {
|
||||||
|
// 1st element in the map is the next
|
||||||
|
WriteVector(iter->second);
|
||||||
|
|
||||||
|
++nextId_;
|
||||||
|
|
||||||
|
// delete current record, move iter on 1
|
||||||
|
iterNext = iter;
|
||||||
|
++iterNext;
|
||||||
|
outputs_.erase(iter);
|
||||||
|
iter = iterNext;
|
||||||
|
} else {
|
||||||
|
// not the next. stop iterating
|
||||||
|
assert(nextId_ < currId);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// save for later
|
||||||
|
outputs_[id] = vec;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void VectorCollector::WriteVector(const std::vector<float>& vec) {
|
||||||
|
if(binary_) {
|
||||||
|
outStrm_->write((char*)vec.data(), vec.size() * sizeof(float));
|
||||||
|
} else {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << std::fixed << std::setprecision(8);
|
||||||
|
for(auto v : vec)
|
||||||
|
*outStrm_ << v << " ";
|
||||||
|
*outStrm_ << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace marian
|
32
src/embedder/vector_collector.h
Normal file
32
src/embedder/vector_collector.h
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common/options.h"
|
||||||
|
#include "common/definitions.h"
|
||||||
|
#include "common/file_stream.h"
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
// This class manages multi-threaded writing of embedded vectors to stdout or an output file.
|
||||||
|
// It will either output string versions of float vectors or binary equal length versions depending
|
||||||
|
// on its binary_ flag.
|
||||||
|
class VectorCollector {
|
||||||
|
public:
|
||||||
|
VectorCollector(const Ptr<Options>& options);
|
||||||
|
virtual void Write(long id, const std::vector<float>& vec);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
long nextId_{0};
|
||||||
|
UPtr<std::ostream> outStrm_;
|
||||||
|
bool binary_; // output binary floating point vectors if set
|
||||||
|
|
||||||
|
std::mutex mutex_;
|
||||||
|
|
||||||
|
typedef std::map<long, std::vector<float>> Outputs;
|
||||||
|
Outputs outputs_;
|
||||||
|
|
||||||
|
virtual void WriteVector(const std::vector<float>& vec);
|
||||||
|
};
|
||||||
|
} // namespace marian
|
@ -416,7 +416,7 @@ protected:
|
|||||||
ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors");
|
ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors");
|
||||||
|
|
||||||
ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete.
|
ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete.
|
||||||
// use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks, mask again to eliminate padding (might be obsolete)
|
// use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask again to eliminate padding (might be obsolete)
|
||||||
auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32);
|
auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32);
|
||||||
|
|
||||||
auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
|
auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
|
||||||
|
@ -51,6 +51,7 @@ EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
|||||||
modelFeatures_.insert("transformer-tied-layers");
|
modelFeatures_.insert("transformer-tied-layers");
|
||||||
modelFeatures_.insert("transformer-guided-alignment-layer");
|
modelFeatures_.insert("transformer-guided-alignment-layer");
|
||||||
modelFeatures_.insert("transformer-train-position-embeddings");
|
modelFeatures_.insert("transformer-train-position-embeddings");
|
||||||
|
modelFeatures_.insert("transformer-pool");
|
||||||
|
|
||||||
modelFeatures_.insert("bert-train-type-embeddings");
|
modelFeatures_.insert("bert-train-type-embeddings");
|
||||||
modelFeatures_.insert("bert-type-vocab-size");
|
modelFeatures_.insert("bert-type-vocab-size");
|
||||||
|
217
src/models/encoder_pooler.h
Normal file
217
src/models/encoder_pooler.h
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "marian.h"
|
||||||
|
|
||||||
|
#include "models/encoder.h"
|
||||||
|
#include "models/pooler.h"
|
||||||
|
#include "models/model_base.h"
|
||||||
|
#include "models/states.h"
|
||||||
|
|
||||||
|
// @TODO: this introduces functionality to use LASER in Marian for the filtering workflow or for use in MS-internal
|
||||||
|
// COSMOS server-farm. There is a lot of code duplication with Classifier and EncoderDecoder and this needs to be fixed.
|
||||||
|
// This will be done after the new layer system has been finished.
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Combines sequence encoders with generic poolers
|
||||||
|
* Can be used to train sequence poolers like language detection, BERT-next-sentence-prediction etc.
|
||||||
|
* Already has support for multi-objective training.
|
||||||
|
*
|
||||||
|
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for deocder/pooler
|
||||||
|
* multi-objective training.
|
||||||
|
*/
|
||||||
|
class EncoderPoolerBase : public models::IModel {
|
||||||
|
public:
|
||||||
|
virtual ~EncoderPoolerBase() {}
|
||||||
|
|
||||||
|
virtual void load(Ptr<ExpressionGraph> graph,
|
||||||
|
const std::string& name,
|
||||||
|
bool markedReloaded = true) override
|
||||||
|
= 0;
|
||||||
|
|
||||||
|
virtual void mmap(Ptr<ExpressionGraph> graph,
|
||||||
|
const void* ptr,
|
||||||
|
bool markedReloaded = true)
|
||||||
|
= 0;
|
||||||
|
|
||||||
|
virtual void save(Ptr<ExpressionGraph> graph,
|
||||||
|
const std::string& name,
|
||||||
|
bool saveTranslatorConfig = false) override
|
||||||
|
= 0;
|
||||||
|
|
||||||
|
virtual void clear(Ptr<ExpressionGraph> graph) override = 0;
|
||||||
|
|
||||||
|
virtual Expr apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, bool) = 0;
|
||||||
|
|
||||||
|
virtual Logits build(Ptr<ExpressionGraph> graph,
|
||||||
|
Ptr<data::Batch> batch,
|
||||||
|
bool clearGraph = true) override {
|
||||||
|
ABORT("Poolers cannot produce Logits");
|
||||||
|
};
|
||||||
|
|
||||||
|
virtual Logits build(Ptr<ExpressionGraph> graph,
|
||||||
|
Ptr<data::CorpusBatch> batch,
|
||||||
|
bool clearGraph = true) {
|
||||||
|
ABORT("Poolers cannot produce Logits");
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual Ptr<Options> getOptions() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class EncoderPooler : public EncoderPoolerBase {
|
||||||
|
protected:
|
||||||
|
Ptr<Options> options_;
|
||||||
|
|
||||||
|
std::string prefix_;
|
||||||
|
|
||||||
|
std::vector<Ptr<EncoderBase>> encoders_;
|
||||||
|
std::vector<Ptr<PoolerBase>> poolers_;
|
||||||
|
|
||||||
|
bool inference_{true};
|
||||||
|
|
||||||
|
std::set<std::string> modelFeatures_;
|
||||||
|
|
||||||
|
Config::YamlNode getModelParameters() {
|
||||||
|
Config::YamlNode modelParams;
|
||||||
|
auto clone = options_->cloneToYamlNode();
|
||||||
|
for(auto& key : modelFeatures_)
|
||||||
|
modelParams[key] = clone[key];
|
||||||
|
|
||||||
|
if(options_->has("original-type"))
|
||||||
|
modelParams["type"] = clone["original-type"];
|
||||||
|
|
||||||
|
modelParams["version"] = buildVersion();
|
||||||
|
return modelParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string getModelParametersAsString() {
|
||||||
|
auto yaml = getModelParameters();
|
||||||
|
YAML::Emitter out;
|
||||||
|
cli::OutputYaml(yaml, out);
|
||||||
|
return std::string(out.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef data::Corpus dataset_type;
|
||||||
|
|
||||||
|
// @TODO: lots of code-duplication with EncoderDecoder
|
||||||
|
EncoderPooler(Ptr<Options> options)
|
||||||
|
: options_(options),
|
||||||
|
prefix_(options->get<std::string>("prefix", "")),
|
||||||
|
inference_(options->get<bool>("inference", false)) {
|
||||||
|
modelFeatures_ = {"type",
|
||||||
|
"dim-vocabs",
|
||||||
|
"dim-emb",
|
||||||
|
"dim-rnn",
|
||||||
|
"enc-cell",
|
||||||
|
"enc-type",
|
||||||
|
"enc-cell-depth",
|
||||||
|
"enc-depth",
|
||||||
|
"dec-depth",
|
||||||
|
"dec-cell",
|
||||||
|
"dec-cell-base-depth",
|
||||||
|
"dec-cell-high-depth",
|
||||||
|
"skip",
|
||||||
|
"layer-normalization",
|
||||||
|
"right-left",
|
||||||
|
"input-types",
|
||||||
|
"special-vocab",
|
||||||
|
"tied-embeddings",
|
||||||
|
"tied-embeddings-src",
|
||||||
|
"tied-embeddings-all"};
|
||||||
|
|
||||||
|
modelFeatures_.insert("transformer-heads");
|
||||||
|
modelFeatures_.insert("transformer-no-projection");
|
||||||
|
modelFeatures_.insert("transformer-dim-ffn");
|
||||||
|
modelFeatures_.insert("transformer-ffn-depth");
|
||||||
|
modelFeatures_.insert("transformer-ffn-activation");
|
||||||
|
modelFeatures_.insert("transformer-dim-aan");
|
||||||
|
modelFeatures_.insert("transformer-aan-depth");
|
||||||
|
modelFeatures_.insert("transformer-aan-activation");
|
||||||
|
modelFeatures_.insert("transformer-aan-nogate");
|
||||||
|
modelFeatures_.insert("transformer-preprocess");
|
||||||
|
modelFeatures_.insert("transformer-postprocess");
|
||||||
|
modelFeatures_.insert("transformer-postprocess-emb");
|
||||||
|
modelFeatures_.insert("transformer-decoder-autoreg");
|
||||||
|
modelFeatures_.insert("transformer-tied-layers");
|
||||||
|
modelFeatures_.insert("transformer-guided-alignment-layer");
|
||||||
|
modelFeatures_.insert("transformer-train-position-embeddings");
|
||||||
|
modelFeatures_.insert("transformer-pool");
|
||||||
|
|
||||||
|
modelFeatures_.insert("bert-train-type-embeddings");
|
||||||
|
modelFeatures_.insert("bert-type-vocab-size");
|
||||||
|
|
||||||
|
modelFeatures_.insert("ulr");
|
||||||
|
modelFeatures_.insert("ulr-trainable-transformation");
|
||||||
|
modelFeatures_.insert("ulr-dim-emb");
|
||||||
|
modelFeatures_.insert("lemma-dim-emb");
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual Ptr<Options> getOptions() override { return options_; }
|
||||||
|
|
||||||
|
std::vector<Ptr<EncoderBase>>& getEncoders() { return encoders_; }
|
||||||
|
std::vector<Ptr<PoolerBase>>& getPoolers() { return poolers_; }
|
||||||
|
|
||||||
|
void push_back(Ptr<EncoderBase> encoder) { encoders_.push_back(encoder); }
|
||||||
|
void push_back(Ptr<PoolerBase> pooler) { poolers_.push_back(pooler); }
|
||||||
|
|
||||||
|
void load(Ptr<ExpressionGraph> graph,
|
||||||
|
const std::string& name,
|
||||||
|
bool markedReloaded) override {
|
||||||
|
graph->load(name, markedReloaded && !opt<bool>("ignore-model-config", false));
|
||||||
|
}
|
||||||
|
|
||||||
|
void mmap(Ptr<ExpressionGraph> graph,
|
||||||
|
const void* ptr,
|
||||||
|
bool markedReloaded) override {
|
||||||
|
graph->mmap(ptr, markedReloaded && !opt<bool>("ignore-model-config", false));
|
||||||
|
}
|
||||||
|
|
||||||
|
void save(Ptr<ExpressionGraph> graph,
|
||||||
|
const std::string& name,
|
||||||
|
bool /*saveModelConfig*/) override {
|
||||||
|
LOG(info, "Saving model weights and runtime parameters to {}", name);
|
||||||
|
graph->save(name , getModelParametersAsString());
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear(Ptr<ExpressionGraph> graph) override {
|
||||||
|
graph->clear();
|
||||||
|
|
||||||
|
for(auto& enc : encoders_)
|
||||||
|
enc->clear();
|
||||||
|
for(auto& pooler : poolers_)
|
||||||
|
pooler->clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T opt(const std::string& key) {
|
||||||
|
return options_->get<T>(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T opt(const std::string& key, const T& def) {
|
||||||
|
return options_->get<T>(key, def);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void set(std::string key, T value) {
|
||||||
|
options_->set(key, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************************/
|
||||||
|
|
||||||
|
virtual Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
|
||||||
|
if(clearGraph)
|
||||||
|
clear(graph);
|
||||||
|
|
||||||
|
std::vector<Ptr<EncoderState>> encoderStates;
|
||||||
|
for(auto& encoder : encoders_)
|
||||||
|
encoderStates.push_back(encoder->build(graph, batch));
|
||||||
|
|
||||||
|
ABORT_IF(poolers_.size() != 1, "Expected exactly one pooler");
|
||||||
|
return poolers_[0]->apply(graph, batch, encoderStates);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace marian
|
71
src/models/laser.h
Normal file
71
src/models/laser.h
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "marian.h"
|
||||||
|
|
||||||
|
#include "layers/constructors.h"
|
||||||
|
#include "rnn/constructors.h"
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
// Re-implements the LASER BiLSTM encoder from:
|
||||||
|
// Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond
|
||||||
|
// Mikel Artetxe, Holger Schwenk
|
||||||
|
// https://arxiv.org/abs/1812.10464
|
||||||
|
|
||||||
|
class EncoderLaser : public EncoderBase {
|
||||||
|
using EncoderBase::EncoderBase;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Expr applyEncoderRNN(Ptr<ExpressionGraph> graph,
|
||||||
|
Expr embeddings,
|
||||||
|
Expr mask) {
|
||||||
|
int depth = opt<int>("enc-depth");
|
||||||
|
float dropoutRnn = inference_ ? 0 : opt<float>("dropout-rnn");
|
||||||
|
|
||||||
|
Expr output = embeddings;
|
||||||
|
|
||||||
|
auto applyRnn = [&](int layer, rnn::dir direction, Expr input, Expr mask) {
|
||||||
|
|
||||||
|
std::string paramPrefix = prefix_ + "_" + opt<std::string>("enc-cell");
|
||||||
|
paramPrefix += "_l" + std::to_string(layer);
|
||||||
|
if(direction == rnn::dir::backward)
|
||||||
|
paramPrefix += "_reverse";
|
||||||
|
|
||||||
|
auto rnnFactory = rnn::rnn()
|
||||||
|
("type", opt<std::string>("enc-cell"))
|
||||||
|
("direction", (int)direction)
|
||||||
|
("dimInput", input->shape()[-1])
|
||||||
|
("dimState", opt<int>("dim-rnn"))
|
||||||
|
("dropout", dropoutRnn)
|
||||||
|
("layer-normalization", opt<bool>("layer-normalization"))
|
||||||
|
("skip", opt<bool>("skip"))
|
||||||
|
.push_back(rnn::cell()("prefix", paramPrefix));
|
||||||
|
|
||||||
|
return rnnFactory.construct(graph)->transduce(input, mask);
|
||||||
|
};
|
||||||
|
|
||||||
|
for(int i = 0; i < depth; ++i) {
|
||||||
|
output = concatenate({applyRnn(i, rnn::dir::forward, output, mask),
|
||||||
|
applyRnn(i, rnn::dir::backward, output, mask)},
|
||||||
|
/*axis =*/ -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual Ptr<EncoderState> build(Ptr<ExpressionGraph> graph,
|
||||||
|
Ptr<data::CorpusBatch> batch) override {
|
||||||
|
graph_ = graph;
|
||||||
|
// select embeddings that occur in the batch
|
||||||
|
Expr batchEmbeddings, batchMask; std::tie
|
||||||
|
(batchEmbeddings, batchMask) = getEmbeddingLayer()->apply((*batch)[batchIndex_]);
|
||||||
|
|
||||||
|
Expr context = applyEncoderRNN(graph_, batchEmbeddings, batchMask);
|
||||||
|
|
||||||
|
return New<EncoderState>(context, batchMask, batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear() override {}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
@ -8,7 +8,7 @@
|
|||||||
namespace marian {
|
namespace marian {
|
||||||
namespace models {
|
namespace models {
|
||||||
|
|
||||||
enum struct usage { raw, training, scoring, translation };
|
enum struct usage { raw, training, scoring, translation, embedding };
|
||||||
}
|
}
|
||||||
} // namespace marian
|
} // namespace marian
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#include "models/amun.h"
|
#include "models/amun.h"
|
||||||
#include "models/nematus.h"
|
#include "models/nematus.h"
|
||||||
#include "models/s2s.h"
|
#include "models/s2s.h"
|
||||||
|
#include "models/laser.h"
|
||||||
#include "models/transformer_factory.h"
|
#include "models/transformer_factory.h"
|
||||||
|
|
||||||
#ifdef CUDNN
|
#ifdef CUDNN
|
||||||
@ -29,6 +30,9 @@ namespace models {
|
|||||||
Ptr<EncoderBase> EncoderFactory::construct(Ptr<ExpressionGraph> graph) {
|
Ptr<EncoderBase> EncoderFactory::construct(Ptr<ExpressionGraph> graph) {
|
||||||
if(options_->get<std::string>("type") == "s2s")
|
if(options_->get<std::string>("type") == "s2s")
|
||||||
return New<EncoderS2S>(graph, options_);
|
return New<EncoderS2S>(graph, options_);
|
||||||
|
|
||||||
|
if(options_->get<std::string>("type") == "laser" || options_->get<std::string>("type") == "laser-sim")
|
||||||
|
return New<EncoderLaser>(graph, options_);
|
||||||
|
|
||||||
#ifdef CUDNN
|
#ifdef CUDNN
|
||||||
if(options_->get<std::string>("type") == "char-s2s")
|
if(options_->get<std::string>("type") == "char-s2s")
|
||||||
@ -61,6 +65,17 @@ Ptr<ClassifierBase> ClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
|
|||||||
ABORT("Unknown classifier type");
|
ABORT("Unknown classifier type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ptr<PoolerBase> PoolerFactory::construct(Ptr<ExpressionGraph> graph) {
|
||||||
|
if(options_->get<std::string>("type") == "max-pooler")
|
||||||
|
return New<MaxPooler>(graph, options_);
|
||||||
|
if(options_->get<std::string>("type") == "slice-pooler")
|
||||||
|
return New<SlicePooler>(graph, options_);
|
||||||
|
else if(options_->get<std::string>("type") == "sim-pooler")
|
||||||
|
return New<SimPooler>(graph, options_);
|
||||||
|
else
|
||||||
|
ABORT("Unknown pooler type");
|
||||||
|
}
|
||||||
|
|
||||||
Ptr<IModel> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) {
|
Ptr<IModel> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) {
|
||||||
Ptr<EncoderDecoder> encdec;
|
Ptr<EncoderDecoder> encdec;
|
||||||
if(options_->get<std::string>("type") == "amun")
|
if(options_->get<std::string>("type") == "amun")
|
||||||
@ -97,9 +112,54 @@ Ptr<IModel> EncoderClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
|
|||||||
return enccls;
|
return enccls;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ptr<IModel> EncoderPoolerFactory::construct(Ptr<ExpressionGraph> graph) {
|
||||||
|
Ptr<EncoderPooler> encpool = New<EncoderPooler>(options_);
|
||||||
|
|
||||||
|
for(auto& ef : encoders_)
|
||||||
|
encpool->push_back(ef(options_).construct(graph));
|
||||||
|
|
||||||
|
for(auto& pl : poolers_)
|
||||||
|
encpool->push_back(pl(options_).construct(graph));
|
||||||
|
|
||||||
|
return encpool;
|
||||||
|
}
|
||||||
|
|
||||||
Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> options) {
|
Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> options) {
|
||||||
Ptr<ExpressionGraph> graph = nullptr; // graph unknown at this stage
|
Ptr<ExpressionGraph> graph = nullptr; // graph unknown at this stage
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
|
||||||
|
if(use == usage::embedding) { // hijacking an EncoderDecoder model for embedding only
|
||||||
|
int dimVocab = options->get<std::vector<int>>("dim-vocabs")[0];
|
||||||
|
|
||||||
|
Ptr<Options> newOptions;
|
||||||
|
if(options->get<bool>("compute-similarity")) {
|
||||||
|
newOptions = options->with("usage", use,
|
||||||
|
"original-type", type,
|
||||||
|
"input-types", std::vector<std::string>({"sequence", "sequence"}),
|
||||||
|
"dim-vocabs", std::vector<int>(2, dimVocab));
|
||||||
|
} else {
|
||||||
|
newOptions = options->with("usage", use,
|
||||||
|
"original-type", type,
|
||||||
|
"input-types", std::vector<std::string>({"sequence"}),
|
||||||
|
"dim-vocabs", std::vector<int>(1, dimVocab));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = New<EncoderPooler>(newOptions);
|
||||||
|
if(options->get<bool>("compute-similarity")) {
|
||||||
|
res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph));
|
||||||
|
res->push_back(models::encoder(newOptions->with("index", 1)).construct(graph));
|
||||||
|
res->push_back(New<SimPooler>(graph, newOptions->with("type", "sim-pooler")));
|
||||||
|
} else {
|
||||||
|
res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph));
|
||||||
|
if(type == "laser")
|
||||||
|
res->push_back(New<MaxPooler>(graph, newOptions->with("type", "max-pooler")));
|
||||||
|
else
|
||||||
|
res->push_back(New<SlicePooler>(graph, newOptions->with("type", "slice-pooler")));
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
if(type == "s2s" || type == "amun" || type == "nematus") {
|
if(type == "s2s" || type == "amun" || type == "nematus") {
|
||||||
return models::encoder_decoder(options->with(
|
return models::encoder_decoder(options->with(
|
||||||
"usage", use,
|
"usage", use,
|
||||||
@ -313,7 +373,7 @@ Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
|
|||||||
else
|
else
|
||||||
ABORT("'usage' parameter 'translation' cannot be applied to model type: {}", type);
|
ABORT("'usage' parameter 'translation' cannot be applied to model type: {}", type);
|
||||||
}
|
}
|
||||||
else if (use == usage::raw)
|
else if (use == usage::raw || use == usage::embedding)
|
||||||
return baseModel;
|
return baseModel;
|
||||||
else
|
else
|
||||||
ABORT("'Usage' parameter must be 'translation' or 'raw'");
|
ABORT("'Usage' parameter must be 'translation' or 'raw'");
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#include "layers/factory.h"
|
#include "layers/factory.h"
|
||||||
#include "models/encoder_decoder.h"
|
#include "models/encoder_decoder.h"
|
||||||
#include "models/encoder_classifier.h"
|
#include "models/encoder_classifier.h"
|
||||||
|
#include "models/encoder_pooler.h"
|
||||||
|
|
||||||
namespace marian {
|
namespace marian {
|
||||||
namespace models {
|
namespace models {
|
||||||
@ -33,6 +34,14 @@ public:
|
|||||||
|
|
||||||
typedef Accumulator<ClassifierFactory> classifier;
|
typedef Accumulator<ClassifierFactory> classifier;
|
||||||
|
|
||||||
|
class PoolerFactory : public Factory {
|
||||||
|
using Factory::Factory;
|
||||||
|
public:
|
||||||
|
virtual Ptr<PoolerBase> construct(Ptr<ExpressionGraph> graph);
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef Accumulator<PoolerFactory> pooler;
|
||||||
|
|
||||||
class EncoderDecoderFactory : public Factory {
|
class EncoderDecoderFactory : public Factory {
|
||||||
using Factory::Factory;
|
using Factory::Factory;
|
||||||
private:
|
private:
|
||||||
@ -77,6 +86,28 @@ public:
|
|||||||
|
|
||||||
typedef Accumulator<EncoderClassifierFactory> encoder_classifier;
|
typedef Accumulator<EncoderClassifierFactory> encoder_classifier;
|
||||||
|
|
||||||
|
class EncoderPoolerFactory : public Factory {
|
||||||
|
using Factory::Factory;
|
||||||
|
private:
|
||||||
|
std::vector<encoder> encoders_;
|
||||||
|
std::vector<pooler> poolers_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Accumulator<EncoderPoolerFactory> push_back(encoder enc) {
|
||||||
|
encoders_.push_back(enc);
|
||||||
|
return Accumulator<EncoderPoolerFactory>(*this);
|
||||||
|
}
|
||||||
|
|
||||||
|
Accumulator<EncoderPoolerFactory> push_back(pooler cls) {
|
||||||
|
poolers_.push_back(cls);
|
||||||
|
return Accumulator<EncoderPoolerFactory>(*this);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual Ptr<IModel> construct(Ptr<ExpressionGraph> graph);
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef Accumulator<EncoderPoolerFactory> encoder_pooler;
|
||||||
|
|
||||||
Ptr<IModel> createBaseModelByType(std::string type, usage, Ptr<Options> options);
|
Ptr<IModel> createBaseModelByType(std::string type, usage, Ptr<Options> options);
|
||||||
|
|
||||||
Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage);
|
Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage);
|
||||||
|
139
src/models/pooler.h
Normal file
139
src/models/pooler.h
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "marian.h"
|
||||||
|
#include "models/states.h"
|
||||||
|
#include "layers/constructors.h"
|
||||||
|
#include "layers/factory.h"
|
||||||
|
|
||||||
|
namespace marian {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simple base class for Poolers to be used in EncoderPooler framework
|
||||||
|
* A pooler takes a encoder state (contextual word embeddings) and produces
|
||||||
|
* a single sentence embedding.
|
||||||
|
*/
|
||||||
|
class PoolerBase : public LayerBase {
|
||||||
|
using LayerBase::LayerBase;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const std::string prefix_{"pooler"};
|
||||||
|
const bool inference_{false};
|
||||||
|
const size_t batchIndex_{0};
|
||||||
|
|
||||||
|
public:
|
||||||
|
PoolerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||||
|
: LayerBase(graph, options),
|
||||||
|
prefix_(options->get<std::string>("prefix", "pooler")),
|
||||||
|
inference_(options->get<bool>("inference", true)),
|
||||||
|
batchIndex_(options->get<size_t>("index", 1)) {} // assume that training input has batch index 0 and labels has 1
|
||||||
|
|
||||||
|
virtual ~PoolerBase() {}
|
||||||
|
|
||||||
|
virtual Expr apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, const std::vector<Ptr<EncoderState>>&) = 0;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T opt(const std::string& key) const {
|
||||||
|
return options_->get<T>(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be used to clear any batch-wise temporary objects if present
|
||||||
|
virtual void clear() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Pool encoder state (contextual word embeddings) via max-pooling along sentence-length dimension.
|
||||||
|
*/
|
||||||
|
class MaxPooler : public PoolerBase {
|
||||||
|
public:
|
||||||
|
MaxPooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||||
|
: PoolerBase(graph, options) {}
|
||||||
|
|
||||||
|
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
|
||||||
|
ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state");
|
||||||
|
|
||||||
|
auto context = encoderStates[0]->getContext();
|
||||||
|
auto batchMask = encoderStates[0]->getMask();
|
||||||
|
|
||||||
|
// do a max pool here
|
||||||
|
Expr logMask = (1.f - batchMask) * -9999.f;
|
||||||
|
Expr maxPool = max(context * batchMask + logMask, /*axis=*/-3);
|
||||||
|
|
||||||
|
return maxPool;
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear() override {}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Pool encoder state (contextual word embeddings) by selecting 1st embedding along sentence-length dimension.
|
||||||
|
*/
|
||||||
|
class SlicePooler : public PoolerBase {
|
||||||
|
public:
|
||||||
|
SlicePooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||||
|
: PoolerBase(graph, options) {}
|
||||||
|
|
||||||
|
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
|
||||||
|
ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state");
|
||||||
|
|
||||||
|
auto context = encoderStates[0]->getContext();
|
||||||
|
auto batchMask = encoderStates[0]->getMask();
|
||||||
|
|
||||||
|
// Corresponds to the way we do this in transformer.h
|
||||||
|
// @TODO: unify this better, this is currently hacky
|
||||||
|
Expr slicePool = slice(context * batchMask, /*axis=*/-3, 0);
|
||||||
|
|
||||||
|
return slicePool;
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear() override {}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Not really a pooler but abusing the interface to compute a similarity of two pooled states
|
||||||
|
*/
|
||||||
|
class SimPooler : public PoolerBase {
|
||||||
|
public:
|
||||||
|
SimPooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||||
|
: PoolerBase(graph, options) {}
|
||||||
|
|
||||||
|
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
|
||||||
|
ABORT_IF(encoderStates.size() != 2, "SimPooler expects exactly two encoder states");
|
||||||
|
|
||||||
|
std::vector<Expr> vecs;
|
||||||
|
for(auto encoderState : encoderStates) {
|
||||||
|
auto context = encoderState->getContext();
|
||||||
|
auto batchMask = encoderState->getMask();
|
||||||
|
|
||||||
|
Expr pool;
|
||||||
|
auto type = options_->get<std::string>("original-type");
|
||||||
|
if(type == "laser") {
|
||||||
|
// LASER models do a max pool here
|
||||||
|
Expr logMask = (1.f - batchMask) * -9999.f;
|
||||||
|
pool = max(context * batchMask + logMask, /*axis=*/-3);
|
||||||
|
} else if(type == "transformer") {
|
||||||
|
// Our own implementation in transformer.h uses a slice of the first element
|
||||||
|
pool = slice(context, -3, 0);
|
||||||
|
} else {
|
||||||
|
// @TODO: make SimPooler take Pooler objects as arguments then it won't need to know this.
|
||||||
|
ABORT("Don't know what type of pooler to use for model type {}", type);
|
||||||
|
}
|
||||||
|
|
||||||
|
vecs.push_back(pool);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto scalars = scalar_product(vecs[0], vecs[1], /*axis*/-1);
|
||||||
|
auto length1 = sqrt(sum(square(vecs[0]), /*axis=*/-1));
|
||||||
|
auto length2 = sqrt(sum(square(vecs[1]), /*axis=*/-1));
|
||||||
|
|
||||||
|
auto cosine = scalars / ( length1 * length2 );
|
||||||
|
|
||||||
|
return cosine;
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear() override {}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
@ -328,6 +328,30 @@ public:
|
|||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reduce the encoder to a single sentence vector, here we just take the contextual embedding of the first word per sentence
|
||||||
|
// Replaces cross-attention in LASER-like models
|
||||||
|
Expr LayerPooling(std::string prefix,
|
||||||
|
Expr input, // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
|
||||||
|
const Expr& values, // [-4: beam depth=1, -3: batch size, -2: max length (src or trg), -1: vector dim]
|
||||||
|
const Expr& mask) { // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
|
||||||
|
int dimModel = input->shape()[-1];
|
||||||
|
auto output = slice(values, -2, 0); // Select first word [-4: beam depth, -3: batch size, -2: 1, -1: vector dim]
|
||||||
|
|
||||||
|
int dimPool = output->shape()[-1];
|
||||||
|
bool project = !opt<bool>("transformer-no-projection");
|
||||||
|
if(project || dimPool != dimModel) {
|
||||||
|
auto Wo = graph_->param(prefix + "_Wo", {dimPool, dimModel}, inits::glorotUniform());
|
||||||
|
auto bo = graph_->param(prefix + "_bo", {1, dimModel}, inits::zeros());
|
||||||
|
output = affine(output, Wo, bo); // [-4: beam depth, -3: batch size, -2: 1, -1: vector dim]
|
||||||
|
}
|
||||||
|
|
||||||
|
auto opsPost = opt<std::string>("transformer-postprocess");
|
||||||
|
output = postProcess(prefix + "_Wo", opsPost, output, input, 0.f);
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Expr LayerAttention(std::string prefix,
|
Expr LayerAttention(std::string prefix,
|
||||||
Expr input, // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
|
Expr input, // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
|
||||||
const Expr& keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
|
const Expr& keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
|
||||||
@ -790,14 +814,21 @@ public:
|
|||||||
saveAttentionWeights = i == attLayer;
|
saveAttentionWeights = i == attLayer;
|
||||||
}
|
}
|
||||||
|
|
||||||
query = LayerAttention(prefix,
|
if(options_->get<bool>("transformer-pool", false)) {
|
||||||
|
query = LayerPooling(prefix,
|
||||||
query,
|
query,
|
||||||
encoderContexts[j], // keys
|
|
||||||
encoderContexts[j], // values
|
encoderContexts[j], // values
|
||||||
encoderMasks[j],
|
encoderMasks[j]);
|
||||||
opt<int>("transformer-heads"),
|
} else {
|
||||||
/*cache=*/true,
|
query = LayerAttention(prefix,
|
||||||
saveAttentionWeights);
|
query,
|
||||||
|
encoderContexts[j], // keys
|
||||||
|
encoderContexts[j], // values
|
||||||
|
encoderMasks[j],
|
||||||
|
opt<int>("transformer-heads"),
|
||||||
|
/*cache=*/true,
|
||||||
|
saveAttentionWeights);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -651,7 +651,7 @@ public:
|
|||||||
using LSTM = FastLSTM;
|
using LSTM = FastLSTM;
|
||||||
|
|
||||||
/******************************************************************************/
|
/******************************************************************************/
|
||||||
// Experimentak cells, use with care
|
// Experimental cells, use with care
|
||||||
|
|
||||||
template <class CellType>
|
template <class CellType>
|
||||||
class Multiplicative : public CellType {
|
class Multiplicative : public CellType {
|
||||||
|
@ -1248,67 +1248,104 @@ void SetSparse(float* out,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void LSTMCellForward(Tensor out_, std::vector<Tensor> inputs) {
|
// should be implemented via slicing and elementwise
|
||||||
|
template <typename FType>
|
||||||
|
void LSTMCellForwardTyped(Tensor out_, const std::vector<Tensor>& inputs) {
|
||||||
int rows = out_->shape().elements() / out_->shape()[-1];
|
int rows = out_->shape().elements() / out_->shape()[-1];
|
||||||
int cols = out_->shape()[-1];
|
|
||||||
|
|
||||||
float* out = out_->data();
|
int fVecSize = sizeof(FType) / sizeof(float);
|
||||||
const float* cell = inputs[0]->data();
|
int cols = out_->shape()[-1] / fVecSize;
|
||||||
const float* xW = inputs[1]->data();
|
|
||||||
const float* sU = inputs[2]->data();
|
FType* out = out_->data<FType>();
|
||||||
const float* b = inputs[3]->data();
|
const FType* cell = inputs[0]->data<FType>();
|
||||||
|
const FType* xW = inputs[1]->data<FType>();
|
||||||
|
const FType* sU = inputs[2]->data<FType>();
|
||||||
|
const FType* b = inputs[3]->data<FType>();
|
||||||
const float* mask = inputs.size() > 4 ? inputs[4]->data() : nullptr;
|
const float* mask = inputs.size() > 4 ? inputs[4]->data() : nullptr;
|
||||||
|
|
||||||
|
using fop = functional::Ops<FType>;
|
||||||
|
|
||||||
for(int j = 0; j < rows; ++j) {
|
for(int j = 0; j < rows; ++j) {
|
||||||
float m = !mask || mask[j];
|
float m = !mask || mask[j];
|
||||||
|
|
||||||
float* rowOut = out + j * cols;
|
FType* rowOut = out + j * cols;
|
||||||
const float* rowCell = cell + j * cols;
|
const FType* rowCell = cell + j * cols;
|
||||||
|
|
||||||
const float* xWrow = xW + j * cols * 4;
|
const FType* xWrow = xW + j * cols * 4;
|
||||||
const float* sUrow = sU + j * cols * 4;
|
const FType* sUrow = sU + j * cols * 4;
|
||||||
|
|
||||||
for(int i = 0; i < cols; ++i) {
|
for(int i = 0; i < cols; ++i) {
|
||||||
float gf = functional::Ops<float>::sigmoid(xWrow[i] + sUrow[i] + b[i]);
|
FType gf = fop::sigmoid(fop::add(fop::add(xWrow[i], sUrow[i]), b[i]));
|
||||||
|
|
||||||
int k = i + cols;
|
int k = i + cols;
|
||||||
float gi = functional::Ops<float>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
|
FType gi = fop::sigmoid(fop::add(fop::add(xWrow[k], sUrow[k]), b[k]));
|
||||||
|
|
||||||
int l = i + 2 * cols;
|
int l = i + 2 * cols;
|
||||||
float gc = std::tanh(xWrow[l] + sUrow[l] + b[l]);
|
FType gc = fop::tanh(fop::add(fop::add(xWrow[l], sUrow[l]), b[l]));
|
||||||
|
|
||||||
float cout = gf * rowCell[i] + gi * gc;
|
FType cout = fop::add(fop::mul(gf, rowCell[i]), fop::mul(gi, gc));
|
||||||
rowOut[i] = m * cout + (1 - m) * rowCell[i];
|
rowOut[i] = fop::add(fop::mul(m, cout), fop::mul(fop::sub(1.f, m), rowCell[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void LSTMOutputForward(Tensor out_, std::vector<Tensor> inputs) {
|
void LSTMCellForward(Tensor out, std::vector<Tensor> inputs) {
|
||||||
int rows = out_->shape().elements() / out_->shape()[-1];
|
int cols = out->shape()[-1];
|
||||||
int cols = out_->shape()[-1];
|
#ifdef __AVX__
|
||||||
|
if(cols % 8 == 0)
|
||||||
|
LSTMCellForwardTyped<float32x8>(out, inputs);
|
||||||
|
else
|
||||||
|
#endif
|
||||||
|
if(cols % 4 == 0)
|
||||||
|
LSTMCellForwardTyped<float32x4>(out, inputs);
|
||||||
|
else
|
||||||
|
LSTMCellForwardTyped<float>(out, inputs);
|
||||||
|
}
|
||||||
|
|
||||||
float* out = out_->data();
|
template <typename FType>
|
||||||
const float* cell = inputs[0]->data();
|
void LSTMOutputForwardTyped(Tensor out_, const std::vector<Tensor>& inputs) {
|
||||||
const float* xW = inputs[1]->data();
|
int rows = out_->shape().elements() / out_->shape()[-1];
|
||||||
const float* sU = inputs[2]->data();
|
|
||||||
const float* b = inputs[3]->data();
|
int fVecSize = sizeof(FType) / sizeof(float);
|
||||||
|
int cols = out_->shape()[-1] / fVecSize;
|
||||||
|
|
||||||
|
FType* out = out_->data<FType>();
|
||||||
|
const FType* cell = inputs[0]->data<FType>();
|
||||||
|
const FType* xW = inputs[1]->data<FType>();
|
||||||
|
const FType* sU = inputs[2]->data<FType>();
|
||||||
|
const FType* b = inputs[3]->data<FType>();
|
||||||
|
|
||||||
|
using fop = functional::Ops<FType>;
|
||||||
|
|
||||||
for(int j = 0; j < rows; ++j) {
|
for(int j = 0; j < rows; ++j) {
|
||||||
float* rowOut = out + j * cols;
|
FType* rowOut = out + j * cols;
|
||||||
const float* rowCell = cell + j * cols;
|
const FType* rowCell = cell + j * cols;
|
||||||
|
|
||||||
const float* xWrow = xW + j * cols * 4;
|
const FType* xWrow = xW + j * cols * 4;
|
||||||
const float* sUrow = sU + j * cols * 4;
|
const FType* sUrow = sU + j * cols * 4;
|
||||||
|
|
||||||
for(int i = 0; i < cols; ++i) {
|
for(int i = 0; i < cols; ++i) {
|
||||||
int k = i + 3 * cols;
|
int k = i + 3 * cols;
|
||||||
float go = functional::Ops<float>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
|
FType go = fop::sigmoid(fop::add(fop::add(xWrow[k], sUrow[k]), b[k]));
|
||||||
|
rowOut[i] = fop::mul(go, fop::tanh(rowCell[i]));
|
||||||
rowOut[i] = go * std::tanh(rowCell[i]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LSTMOutputForward(Tensor out, std::vector<Tensor> inputs) {
|
||||||
|
int cols = out->shape()[-1];
|
||||||
|
|
||||||
|
#ifdef __AVX__
|
||||||
|
if(cols % 8 == 0)
|
||||||
|
LSTMOutputForwardTyped<float32x8>(out, inputs);
|
||||||
|
else
|
||||||
|
#endif
|
||||||
|
if(cols % 4 == 0)
|
||||||
|
LSTMOutputForwardTyped<float32x4>(out, inputs);
|
||||||
|
else
|
||||||
|
LSTMOutputForwardTyped<float>(out, inputs);
|
||||||
|
}
|
||||||
|
|
||||||
void LSTMCellBackward(std::vector<Tensor> outputs,
|
void LSTMCellBackward(std::vector<Tensor> outputs,
|
||||||
std::vector<Tensor> inputs,
|
std::vector<Tensor> inputs,
|
||||||
Tensor adj_) {
|
Tensor adj_) {
|
||||||
|
@ -77,7 +77,7 @@ Ptr<data::BatchStats> GraphGroup::collectStats(Ptr<ExpressionGraph> graph,
|
|||||||
} else {
|
} else {
|
||||||
end = current - 1;
|
end = current - 1;
|
||||||
}
|
}
|
||||||
} while(end - start > step);
|
} while(end - start > step); // @TODO: better replace with `end >= start` to remove the step here
|
||||||
|
|
||||||
maxBatch = start;
|
maxBatch = start;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user