Hieu Hoang 2021-04-28 23:40:00 -07:00
parent 49e379bba5
commit 909df372d1
8 changed files with 59 additions and 192 deletions

View File: src/CMakeLists.txt

@@ -72,7 +72,6 @@ set(MARIAN_SOURCES
layers/generic.cpp
layers/loss.cpp
layers/weight.cpp
layers/lsh.cpp
layers/embedding.cpp
layers/output.cpp
layers/logits.cpp

View File: src/data/shortlist.cpp

@@ -1,5 +1,6 @@
#include "data/shortlist.h"
#include "microsoft/shortlist/utils/ParameterTree.h"
#include "marian.h"
namespace marian {
namespace data {
@@ -12,6 +13,48 @@ const T* get(const void*& current, size_t num = 1) {
return ptr;
}
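// get<T>() reads `num` items of type T from the raw buffer and advances `current`
// past them, a small helper for walking the binary shortlist format (inferred from
// the signature shown in the hunk context above).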
//////////////////////////////////////////////////////////////////////////////////////
Shortlist::Shortlist(const std::vector<WordIndex>& indices)
: indices_(indices) {}
const std::vector<WordIndex>& Shortlist::indices() const { return indices_; }
WordIndex Shortlist::reverseMap(int idx) { return indices_[idx]; }
WordIndex Shortlist::tryForwardMap(WordIndex wIdx) {
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
if(first != indices_.end() && *first == wIdx) // check whether an element not less than wIdx exists and equals wIdx
return (WordIndex)std::distance(indices_.begin(), first); // return its position in the shortlist if found
else
return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17?
}
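// A possible C++17 shape for tryForwardMap, per the @TODO above: returning
// std::optional<WordIndex> removes the need for the npos sentinel. Illustrative
// sketch only (requires #include <optional>), not part of this commit:
#if 0
std::optional<WordIndex> Shortlist::tryForwardMap(WordIndex wIdx) {
  auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
  if(first != indices_.end() && *first == wIdx)
    return (WordIndex)std::distance(indices_.begin(), first);
  return std::nullopt;
}
#endif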
void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
int k = indices_.size();
int currBeamSize = input->shape()[0];
int batchSize = input->shape()[2];
std::cerr << "currBeamSize=" << currBeamSize << std::endl;
std::cerr << "batchSize=" << batchSize << std::endl;
Expr indicesExprBC;
broadcast(weights, isLegacyUntransposedW, b, lemmaEt, indicesExprBC, k);
}
void Shortlist::broadcast(Expr weights,
bool isLegacyUntransposedW,
Expr b,
Expr lemmaEt,
Expr indicesExprBC,
int k) {
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indices());
if (b) {
cachedShortb_ = index_select(b, -1, indices());
}
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indices());
return;
}
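// Note on the caching above: index_select slices the full output parameters down to
// the k shortlisted entries, so the decoder's final affine works on a [dimEmb, k]
// matrix rather than [dimEmb, dimVoc]. With, say, dimVoc = 32000 and k = 100 that is
// a 320x smaller product per step (illustrative numbers, not from this commit).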
//////////////////////////////////////////////////////////////////////////////////////
QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> srcVocab,
Ptr<const Vocab> trgVocab,

View File: src/data/shortlist.h

@@ -19,26 +19,29 @@ namespace marian {
namespace data {
class Shortlist {
private:
protected:
std::vector<WordIndex> indices_; // [packed shortlist index] -> word index, used to select columns from output embeddings
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
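// broadcast() and filter() are declared virtual so derived shortlist types can
// override how the shortlisted parameter slices are constructed.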
virtual void broadcast(Expr weights,
bool isLegacyUntransposedW,
Expr b,
Expr lemmaEt,
Expr indicesExprBC,
int k);
public:
static constexpr WordIndex npos{std::numeric_limits<WordIndex>::max()}; // used to identify invalid shortlist entries similar to std::string::npos
Shortlist(const std::vector<WordIndex>& indices)
: indices_(indices) {}
Shortlist(const std::vector<WordIndex>& indices);
const std::vector<WordIndex>& indices() const { return indices_; }
WordIndex reverseMap(int idx) { return indices_[idx]; }
WordIndex tryForwardMap(WordIndex wIdx) {
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
if(first != indices_.end() && *first == wIdx) // check whether an element not less than wIdx exists and equals wIdx
return (WordIndex)std::distance(indices_.begin(), first); // return its position in the shortlist if found
else
return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17?
}
const std::vector<WordIndex>& indices() const;
WordIndex reverseMap(int idx);
WordIndex tryForwardMap(WordIndex wIdx);
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
};
class ShortlistGenerator {

View File: src/layers/generic.cpp

@@ -4,7 +4,6 @@
#include "layers/constructors.h"
#include "layers/generic.h"
#include "layers/loss.h"
#include "layers/lsh.h"
#include "models/states.h" // for EncoderState
namespace marian {

View File: src/layers/lsh.cpp (deleted)

@@ -1,130 +0,0 @@
#include "layers/lsh.h"
#include "graph/expression_operators.h"
#include "tensors/cpu/prod_blas.h"
#if BLAS_FOUND
#include "3rd_party/faiss/IndexLSH.h"
#endif
namespace marian {
Expr LSH::apply(Expr input, Expr W, Expr b) {
auto idx = search(input, W);
return affine(idx, input, W, b);
}
Expr LSH::search(Expr query, Expr values) {
#if BLAS_FOUND
ABORT_IF(query->graph()->getDeviceId().type == DeviceType::gpu,
"LSH index (--output-approx-knn) currently not implemented for GPU");
auto kShape = query->shape();
kShape.set(-1, k_);
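// the result keeps the query's leading dimensions; only the last dimension
// changes, holding the k retrieved neighbor ids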
auto forward = [this](Expr out, const std::vector<Expr>& inputs) {
auto query = inputs[0];
auto values = inputs[1];
int dim = values->shape()[-1];
if(!index_ || indexHash_ != values->hash()) {
LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
index_.reset(new faiss::IndexLSH(dim, nbits_,
/*rotate=*/dim != nbits_,
/*train_thresholds=*/false));
int vRows = values->shape().elements() / dim;
index_->train(vRows, values->val()->data<float>());
index_->add( vRows, values->val()->data<float>());
indexHash_ = values->hash();
}
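// the index is rebuilt lazily above whenever the values matrix changes (detected
// via its hash); each query row then retrieves its k nearest entries below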
int qRows = query->shape().elements() / dim;
std::vector<float> distances(qRows * k_);
std::vector<faiss::Index::idx_t> ids(qRows * k_);
index_->search(qRows, query->val()->data<float>(), k_,
distances.data(), ids.data());
std::vector<IndexType> vOut;
vOut.reserve(ids.size());
for(auto id : ids)
vOut.push_back((IndexType)id);
out->val()->set(vOut);
};
return lambda({query, values}, kShape, Type::uint32, forward);
#else
query; values; // touch the parameters to silence unused-variable warnings
ABORT("LSH output layer requires a CPU BLAS library");
#endif
}
Expr LSH::affine(Expr idx, Expr input, Expr W, Expr b) {
auto outShape = input->shape();
int dimVoc = W->shape()[-2];
outShape.set(-1, dimVoc);
auto forward = [this](Expr out, const std::vector<Expr>& inputs) {
auto lowest = NumericLimits<float>(out->value_type()).lowest;
out->val()->set(lowest);
int dimIn = inputs[1]->shape()[-1];
int dimOut = out->shape()[-1];
int dimRows = out->shape().elements() / dimOut;
auto outPtr = out->val()->data<float>();
auto idxPtr = inputs[0]->val()->data<uint32_t>();
auto queryPtr = inputs[1]->val()->data<float>();
auto WPtr = inputs[2]->val()->data<float>();
auto bPtr = inputs.size() > 3 ? inputs[3]->val()->data<float>() : nullptr; // nullptr if no bias given
for(int row = 0; row < dimRows; ++row) {
auto currIdxPtr = idxPtr + row * k_; // move to next batch of k entries
auto currQueryPtr = queryPtr + row * dimIn; // move to next input query vector
auto currOutPtr = outPtr + row * dimOut; // move to next output position vector (of vocabulary size)
for(int k = 0; k < k_; k++) {
int relPos = currIdxPtr[k]; // k-th best vocabulary item
auto currWPtr = WPtr + relPos * dimIn; // offset for k-th best embedding
currOutPtr[relPos] = bPtr ? bPtr[relPos] : 0; // write bias value to position, init to 0 if no bias given
// proceed one vector product at a time writing to the correct position
sgemm(false, true, 1, 1, dimIn, 1.0f, currQueryPtr, dimIn, currWPtr, dimIn, 1.0f, &currOutPtr[relPos], 1);
}
}
};
std::vector<Expr> nodes = {idx, input, W};
if(b) // bias is optional
nodes.push_back(b);
return lambda(nodes,
outShape,
input->value_type(),
forward);
}
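// Net effect of affine(): per output row, only the k retrieved vocabulary positions
// receive a real logit (bias plus one dot product each via the 1x1 sgemm); all other
// positions keep the lowest float value, effectively masking them from the softmax.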
// @TODO: alternative version which does the same as above with Marian operators, currently missing "scatter".
// this uses more memory and likely to be slower. Would make sense to have a scatter node that actually creates
// the node instead of relying on an existing node, e.g. scatter(shape, defaultValue, axis, indices, values);
#if 0
Expr LSH::affine(Expr idx, Expr input, Expr W, Expr b) {
int dim = input->shape()[-1];
int bch = idx->shape().elements() / k;
auto W = reshape(rows(Wt_, flatten(idx)), {bch, k, dim}); // [rows, k, dim]
auto b = reshape(cols(b_, flatten(idx)), {bch, 1, k}); // [rows, 1, k]
auto aff = reshape(bdot(reshape(input, {bch, 1, dim}), W, false, true) + b, idx->shape()); // [beam, time, batch, k]
int dimVoc = Wt_->shape()[-2];
auto oShape = input->shape();
oShape.set(-1, dimVoc);
auto lowest = graph_->constant(oShape,
inits::fromValue(NumericLimits<float>(input->value_type()).lowest),
input->value_type());
return scatter(lowest, -1, idx, aff);
}
#endif
} // namespace marian

View File: src/layers/lsh.h (deleted)

@@ -1,31 +0,0 @@
#include "graph/expression_graph.h"
#include <memory>
namespace faiss {
struct IndexLSH;
}
namespace marian {
class LSH {
public:
LSH(int k, int nbits) : k_{k}, nbits_{nbits} {
#if !BLAS_FOUND
ABORT("LSH-based output approximation requires BLAS library");
#endif
}
Expr apply(Expr query, Expr values, Expr bias);
private:
Ptr<faiss::IndexLSH> index_;
size_t indexHash_{0};
int k_{100};
int nbits_{1024};
Expr search(Expr query, Expr values);
Expr affine(Expr idx, Expr query, Expr values, Expr bias);
};
}
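// Typical usage, sketched from the constructor and apply() declarations above and
// the call sites removed from output.cpp below (illustrative, assumes CPU + BLAS):
//   auto lsh = New<LSH>(/*k=*/100, /*nbits=*/1024);
//   Expr logits = lsh->apply(query, Wt, b); // search, then sparse affine over the k hits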

View File: src/layers/output.cpp

@@ -2,7 +2,6 @@
#include "common/timer.h"
#include "data/factored_vocab.h"
#include "layers/loss.h"
#include "layers/lsh.h"
namespace marian {
namespace mlp {
@@ -12,13 +11,6 @@ namespace mlp {
if(Wt_)
return;
// this option is only set in the decoder
if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) {
auto k = opt<std::vector<int>>("output-approx-knn")[0];
auto nbits = opt<std::vector<int>>("output-approx-knn")[1];
lsh_ = New<LSH>(k, nbits);
}
auto name = options_->get<std::string>("prefix");
auto numOutputClasses = options_->get<int>("dim");
@@ -71,13 +63,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
};
auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
if(lsh_) {
ABORT_IF(transA, "Transposed query not supported for LSH");
ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
return lsh_->apply(x, W, b); // knows how to deal with undefined bias
} else {
return affineOrDot(x, W, b, transA, transB);
}
};
if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one

View File: src/layers/output.h

@@ -7,7 +7,6 @@
#include "marian.h"
namespace marian {
class LSH;
namespace mlp {
@@ -28,7 +27,6 @@ private:
// optional parameters set/updated after construction
Expr tiedParam_;
Ptr<data::Shortlist> shortlist_;
Ptr<LSH> lsh_;
void lazyConstruct(int inputDim);