mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 19842: Adapt LSH to work with Leaf
Small changes to make the LSH work with Leaf server and QuickSand.
This commit is contained in:
parent
42f0b8b74b
commit
8e88071ae8
@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
- Integrate a shortlist converter (which can convert a text lexical shortlist to a binary shortlist) into marian-conv with --shortlist option
|
||||
|
||||
### Fixed
|
||||
- Various fixes to enable LSH in Quicksand
|
||||
- Added support to MPIWrapper::bcast (and similar) for count of type size_t
|
||||
- Adding new validation metrics when training is restarted and --reset-valid-stalled is used
|
||||
- Missing depth-scaling in transformer FFN
|
||||
|
2
src/3rd_party/sentencepiece
vendored
2
src/3rd_party/sentencepiece
vendored
@ -1 +1 @@
|
||||
Subproject commit 5bafa8e8c3391bbe9721a16e986408341f95774c
|
||||
Subproject commit 28f9eb890f62907406c629acd2f04ca9b71442c9
|
@ -4,6 +4,8 @@
|
||||
|
||||
#include "spdlog/spdlog.h"
|
||||
|
||||
// set to 1 to use for debugging if no loggers can be created
|
||||
#define LOG_TO_STDERR 0
|
||||
|
||||
namespace marian {
|
||||
void logCallStack(size_t skipLevels);
|
||||
@ -149,6 +151,9 @@ class Config;
|
||||
|
||||
template <class... Args>
|
||||
void checkedLog(std::string logger, std::string level, Args... args) {
|
||||
#if LOG_TO_STDERR
|
||||
std::cerr << "[" << level << "] " << fmt::format(args...) << std::endl;
|
||||
#else
|
||||
Logger log = spdlog::get(logger);
|
||||
if(!log) {
|
||||
return;
|
||||
@ -169,6 +174,7 @@ void checkedLog(std::string logger, std::string level, Args... args) {
|
||||
else {
|
||||
log->warn("Unknown log level '{}' for logger '{}'", level, logger);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void createLoggers(const marian::Config* options = nullptr);
|
||||
|
@ -77,9 +77,9 @@ void Shortlist::createCachedTensors(Expr weights,
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize)
|
||||
LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize, bool abortIfDynamic)
|
||||
: Shortlist(std::vector<WordIndex>()),
|
||||
k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
|
||||
k_(k), nbits_(nbits), lemmaSize_(lemmaSize), abortIfDynamic_(abortIfDynamic) {
|
||||
}
|
||||
|
||||
WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
|
||||
@ -99,7 +99,7 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
|
||||
ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
|
||||
"LSH index (--output-approx-knn) currently not implemented for GPU");
|
||||
|
||||
indicesExpr_ = callback(lsh::search(input, weights, k_, nbits_, (int)lemmaSize_),
|
||||
indicesExpr_ = callback(lsh::search(input, weights, k_, nbits_, (int)lemmaSize_, abortIfDynamic_),
|
||||
[this](Expr node) {
|
||||
node->val()->get(indices_); // set the value of the field indices_ whenever the graph traverses this node
|
||||
});
|
||||
@ -135,12 +135,12 @@ void LSHShortlist::createCachedTensors(Expr weights,
|
||||
}
|
||||
}
|
||||
|
||||
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits, size_t lemmaSize)
|
||||
: k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
|
||||
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits, size_t lemmaSize, bool abortIfDynamic)
|
||||
: k_(k), nbits_(nbits), lemmaSize_(lemmaSize), abortIfDynamic_(abortIfDynamic) {
|
||||
}
|
||||
|
||||
Ptr<Shortlist> LSHShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
|
||||
return New<LSHShortlist>(k_, nbits_, lemmaSize_);
|
||||
return New<LSHShortlist>(k_, nbits_, lemmaSize_, abortIfDynamic_);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -175,7 +175,7 @@ QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options,
|
||||
int32_t header_magic_number = *get<int32_t>(current);
|
||||
ABORT_IF(header_magic_number != MAGIC_NUMBER, "Trying to mmap Quicksand shortlist but encountered wrong magic number");
|
||||
|
||||
auto config = ::quicksand::ParameterTree::FromBinaryReader(current);
|
||||
auto config = marian::quicksand::ParameterTree::FromBinaryReader(current);
|
||||
use16bit_ = config->GetBoolReq("use_16_bit");
|
||||
|
||||
LOG(info, "[data] Mapping Quicksand shortlist from {}", fname);
|
||||
@ -275,7 +275,7 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
|
||||
if (lshOpts.size()) {
|
||||
assert(lshOpts.size() == 2);
|
||||
size_t lemmaSize = trgVocab->lemmaSize();
|
||||
return New<LSHShortlistGenerator>(lshOpts[0], lshOpts[1], lemmaSize);
|
||||
return New<LSHShortlistGenerator>(lshOpts[0], lshOpts[1], lemmaSize, /*abortIfDynamic=*/false);
|
||||
}
|
||||
else {
|
||||
std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
|
||||
|
@ -74,6 +74,8 @@ private:
|
||||
int k_; // number of candidates returned from each input
|
||||
int nbits_; // length of hash
|
||||
size_t lemmaSize_; // vocab size
|
||||
bool abortIfDynamic_; // if true disallow dynamic allocation for encoded weights and rotation matrix (only allow use of pre-allocated parameters)
|
||||
|
||||
static Ptr<faiss::IndexLSH> index_; // LSH index to store all possible candidates
|
||||
static std::mutex mutex_;
|
||||
|
||||
@ -84,7 +86,7 @@ private:
|
||||
int k);
|
||||
|
||||
public:
|
||||
LSHShortlist(int k, int nbits, size_t lemmaSize);
|
||||
LSHShortlist(int k, int nbits, size_t lemmaSize, bool abortIfDynamic = false);
|
||||
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
|
||||
@ -97,8 +99,10 @@ private:
|
||||
int k_;
|
||||
int nbits_;
|
||||
size_t lemmaSize_;
|
||||
bool abortIfDynamic_;
|
||||
|
||||
public:
|
||||
LSHShortlistGenerator(int k, int nbits, size_t lemmaSize);
|
||||
LSHShortlistGenerator(int k, int nbits, size_t lemmaSize, bool abortIfDynamic = false);
|
||||
Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
|
||||
};
|
||||
|
||||
|
@ -155,7 +155,7 @@ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows
|
||||
return lambda({encodedQuery, encodedWeights}, kShape, Type::uint32, search);
|
||||
}
|
||||
|
||||
Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows) {
|
||||
Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows, bool abortIfDynamic) {
|
||||
int dim = weights->shape()[-1];
|
||||
|
||||
Expr rotMat = nullptr;
|
||||
@ -164,6 +164,7 @@ Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows) {
|
||||
if(rotMat) {
|
||||
LOG_ONCE(info, "Reusing parameter LSH rotation matrix {} with shape {}", rotMat->name(), rotMat->shape());
|
||||
} else {
|
||||
ABORT_IF(abortIfDynamic, "Dynamic creation of LSH rotation matrix prohibited");
|
||||
LOG_ONCE(info, "Creating ad-hoc rotation matrix with shape {}", Shape({dim, nBits}));
|
||||
rotMat = rotator(weights, nBits);
|
||||
}
|
||||
@ -173,6 +174,7 @@ Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows) {
|
||||
if(encodedWeights) {
|
||||
LOG_ONCE(info, "Reusing parameter LSH code matrix {} with shape {}", encodedWeights->name(), encodedWeights->shape());
|
||||
} else {
|
||||
ABORT_IF(abortIfDynamic, "Dynamic creation of LSH code matrix prohibited");
|
||||
LOG_ONCE(info, "Creating ad-hoc code matrix with shape {}", Shape({weights->shape()[-2], lsh::bytesPerVector(nBits)}));
|
||||
encodedWeights = encode(weights, rotMat);
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ namespace lsh {
|
||||
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0);
|
||||
|
||||
// same as above, but performs encoding on the fly
|
||||
Expr search(Expr query, Expr weights, int k, int nbits, int firstNRows = 0);
|
||||
Expr search(Expr query, Expr weights, int k, int nbits, int firstNRows = 0, bool abortIfDynamic = false);
|
||||
|
||||
// These are helper functions for encoding the LSH into the binary Marian model, used by marian-conv
|
||||
void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBits);
|
||||
|
@ -78,7 +78,7 @@ public:
|
||||
graph_->setDevice(deviceId, device_);
|
||||
|
||||
#if MKL_FOUND
|
||||
mkl_set_num_threads(options->get<int>("mkl-threads", 1));
|
||||
mkl_set_num_threads(options_->get<int>("mkl-threads", 1));
|
||||
#endif
|
||||
|
||||
std::vector<std::string> models
|
||||
@ -114,6 +114,9 @@ public:
|
||||
for(auto scorer : scorers_) {
|
||||
scorer->init(graph_);
|
||||
}
|
||||
|
||||
// run parameter init once, this is required for graph_->get("parameter name") to work correctly
|
||||
graph_->forward();
|
||||
}
|
||||
|
||||
void setWorkspace(uint8_t* data, size_t size) override { device_->set(data, size); }
|
||||
@ -121,8 +124,21 @@ public:
|
||||
QSNBestBatch decode(const QSBatch& qsBatch,
|
||||
size_t maxLength,
|
||||
const std::unordered_set<WordIndex>& shortlist) override {
|
||||
if(shortlist.size() > 0) {
|
||||
auto shortListGen = New<data::FakeShortlistGenerator>(shortlist);
|
||||
|
||||
std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn", {});
|
||||
ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");
|
||||
ABORT_IF(lshOpts.size() == 2 && shortlist.size() > 0, "LSH and shortlist cannot be used at the same time");
|
||||
|
||||
if(lshOpts.size() == 2 || shortlist.size() > 0) {
|
||||
Ptr<data::ShortlistGenerator> shortListGen;
|
||||
// both ShortListGenerators are thin wrappers, hence no problem with calling this per query
|
||||
if(lshOpts.size() == 2) {
|
||||
// Setting abortIfDynamic to true disallows memory allocation for LSH parameters, this is specifically for use in Quicksand.
|
||||
// If we want to use the LSH in Quicksand we need to create a binary model that contains the LSH parameters via conversion.
|
||||
shortListGen = New<data::LSHShortlistGenerator>(lshOpts[0], lshOpts[1], vocabs_[1]->lemmaSize(), /*abortIfDynamic=*/true);
|
||||
} else {
|
||||
shortListGen = New<data::FakeShortlistGenerator>(shortlist);
|
||||
}
|
||||
for(auto scorer : scorers_)
|
||||
scorer->setShortlistGenerator(shortListGen);
|
||||
}
|
||||
@ -249,7 +265,7 @@ DecoderCpuAvxVersion parseCpuAvxVersion(std::string name) {
|
||||
// This function converts an fp32 model into an FBGEMM based packed model.
|
||||
// marian defined types are used for external project as well.
|
||||
// The targetPrec is passed as int32_t for the exported function definition.
|
||||
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh) {
|
||||
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, int32_t lshNBits) {
|
||||
std::cerr << "Converting from: " << inputFile << ", to: " << outputFile << ", precision: " << targetPrec << std::endl;
|
||||
|
||||
YAML::Node config;
|
||||
@ -264,9 +280,10 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
|
||||
|
||||
// MJD: Note, these are default settings which we might want to change or expose. Use this only with Polonium students.
|
||||
// The LSH will not be used by default even if it exists in the model. That has to be enabled in the decoder config.
|
||||
int lshNBits = 1024;
|
||||
std::string lshOutputWeights = "Wemb";
|
||||
bool addLsh = lshNBits > 0;
|
||||
if(addLsh) {
|
||||
std::cerr << "Adding LSH to model with hash size " << lshNBits << std::endl;
|
||||
// Add dummy parameters for the LSH before the model gets actually initialized.
|
||||
// This creates the parameters with useless values in the tensors, but it gives us the memory we need.
|
||||
graph->setReloaded(false);
|
||||
|
@ -79,7 +79,7 @@ DecoderCpuAvxVersion parseCpuAvxVersion(std::string name);
|
||||
// MJD: added "addLsh" which will now break whatever compilation after update. That's on purpose.
|
||||
// The calling code should be adapted, not this interface. If you need to fix things in QS because of this
|
||||
// talk to me first!
|
||||
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh);
|
||||
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, int32_t lshNBits);
|
||||
|
||||
} // namespace quicksand
|
||||
} // namespace marian
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include "microsoft/shortlist/utils/Converter.h"
|
||||
|
||||
namespace marian {
|
||||
namespace quicksand {
|
||||
|
||||
#include "microsoft/shortlist/logging/LoggerMacros.h"
|
||||
@ -57,3 +58,4 @@ void Converter::HandleConversionError(const std::string& str, const char * type_
|
||||
}
|
||||
|
||||
} // namespace quicksand
|
||||
} // namespace marian
|
@ -5,6 +5,7 @@
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
namespace marian {
|
||||
namespace quicksand {
|
||||
|
||||
class Converter {
|
||||
@ -81,3 +82,4 @@ std::vector<T> Converter::ConvertVectorInternal(I begin, I end, const char * typ
|
||||
}
|
||||
|
||||
} // namespace quicksand
|
||||
} // namespace marian
|
@ -5,6 +5,7 @@
|
||||
#include "microsoft/shortlist/utils/StringUtils.h"
|
||||
#include "microsoft/shortlist/utils/Converter.h"
|
||||
|
||||
namespace marian {
|
||||
namespace quicksand {
|
||||
|
||||
#include "microsoft/shortlist/logging/LoggerMacros.h"
|
||||
@ -414,4 +415,4 @@ void ParameterTree::ReplaceVariablesInternal(
|
||||
}
|
||||
|
||||
} // namespace quicksand
|
||||
|
||||
} // namespace marian
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "microsoft/shortlist/utils/StringUtils.h"
|
||||
|
||||
namespace marian {
|
||||
namespace quicksand {
|
||||
|
||||
class ParameterTree {
|
||||
@ -183,3 +184,4 @@ void ParameterTree::SetParam(const std::string& name, const T& obj) {
|
||||
}
|
||||
|
||||
} // namespace quicksand
|
||||
} // namespace marian
|
@ -4,6 +4,7 @@
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
namespace marian {
|
||||
namespace quicksand {
|
||||
|
||||
#include "microsoft/shortlist/logging/LoggerMacros.h"
|
||||
@ -336,3 +337,4 @@ std::string StringUtils::ToLower(const std::string& str) {
|
||||
}
|
||||
|
||||
} // namespace quicksand
|
||||
} // namespace marian
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "microsoft/shortlist/utils/PrintTypes.h"
|
||||
|
||||
namespace marian {
|
||||
namespace quicksand {
|
||||
|
||||
class StringUtils {
|
||||
@ -96,3 +97,4 @@ std::string StringUtils::ToString(const T& obj) {
|
||||
}
|
||||
|
||||
} // namespace quicksand
|
||||
} // namespace marian
|
@ -62,7 +62,7 @@ public:
|
||||
trgVocab_->load(vocabs.back());
|
||||
auto srcVocab = corpus_->getVocabs()[0];
|
||||
|
||||
std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
|
||||
std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn", {});
|
||||
ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");
|
||||
|
||||
if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) {
|
||||
|
Loading…
Reference in New Issue
Block a user