mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
changes for review
This commit is contained in:
parent
a332e550a5
commit
cd292d3b32
@ -63,7 +63,7 @@ std::string findReplace(const std::string& in, const std::string& what, const st
|
||||
double parseDouble(std::string s);
|
||||
double parseNumber(std::string s);
|
||||
|
||||
|
||||
// prints vector values with a custom label.
|
||||
template<class T>
|
||||
void Debug(const T *arr, size_t size, const std::string &str) {
|
||||
std::cerr << str << ":" << size << ": ";
|
||||
|
@ -274,7 +274,10 @@ void FactoredVocab::constructGroupInfoFromFactorVocab() {
|
||||
groupRanges_[g].second = u + 1;
|
||||
groupCounts[g]++;
|
||||
}
|
||||
|
||||
// required by LSH shortlist
|
||||
lemmaSize_ = groupCounts[0];
|
||||
|
||||
for (size_t g = 0; g < numGroups; g++) { // detect non-overlapping groups
|
||||
LOG(info, "[vocab] Factor group '{}' has {} members", groupPrefixes_[g], groupCounts[g]);
|
||||
if (groupCounts[g] == 0) { // factor group is unused --@TODO: once this is not hard-coded, this is an error condition
|
||||
|
@ -24,9 +24,9 @@ Shortlist::Shortlist(const std::vector<WordIndex>& indices)
|
||||
|
||||
// Destructor with an empty body; no explicit cleanup is performed here.
Shortlist::~Shortlist() {}
|
||||
|
||||
WordIndex Shortlist::reverseMap(int , int , int idx) const { return indices_[idx]; }
|
||||
// Maps a shortlist position back to a word index by returning indices_[idx].
// The beam and batch arguments are accepted but ignored by this implementation.
// No bounds check is performed on idx.
WordIndex Shortlist::reverseMap(int /*beamIdx*/, int /*batchIdx*/, int idx) const { return indices_[idx]; }
|
||||
|
||||
WordIndex Shortlist::tryForwardMap(int , int , WordIndex wIdx) const {
|
||||
WordIndex Shortlist::tryForwardMap(WordIndex wIdx) const {
|
||||
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
|
||||
if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
|
||||
return (int)std::distance(indices_.begin(), first); // return coordinate if found
|
||||
@ -83,15 +83,8 @@ Ptr<faiss::IndexLSH> LSHShortlist::index_;
|
||||
// Constructs an LSH shortlist. The Shortlist base is initialized with an empty
// index vector, and k, nbits and lemmaSize are stored in k_, nbits_ and lemmaSize_.
// The constructor body does no further work; the loop that pre-filled indices_
// with 0..k_-1 is commented out.
LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize)
|
||||
: Shortlist(std::vector<WordIndex>())
|
||||
, k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
|
||||
/*
|
||||
for (int i = 0; i < k_; ++i) {
|
||||
indices_.push_back(i);
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
//#define BLAS_FOUND 1
|
||||
|
||||
WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
|
||||
//int currBeamSize = indicesExpr_->shape()[0];
|
||||
int currBatchSize = indicesExpr_->shape()[1];
|
||||
@ -100,15 +93,6 @@ WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
|
||||
return indices_[idx];
|
||||
}
|
||||
|
||||
WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const {
|
||||
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
|
||||
bool found = first != indices_.end();
|
||||
if(found && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
|
||||
return (int)std::distance(indices_.begin(), first); // return coordinate if found
|
||||
else
|
||||
return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17?
|
||||
}
|
||||
|
||||
// Accessor: returns the stored indicesExpr_ member as-is.
// NOTE(review): presumably indicesExpr_ is populated by filter() — confirm, since the assignment is not visible here.
Expr LSHShortlist::getIndicesExpr() const {
|
||||
return indicesExpr_;
|
||||
}
|
||||
@ -128,7 +112,6 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
|
||||
int dim = values->shape()[-1];
|
||||
|
||||
if(!index_) {
|
||||
//std::cerr << "build lsh index" << std::endl;
|
||||
LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
|
||||
index_.reset(new faiss::IndexLSH(dim, nbits_,
|
||||
/*rotate=*/dim != nbits_,
|
||||
@ -199,7 +182,6 @@ void LSHShortlist::createCachedTensors(Expr weights,
|
||||
|
||||
// Constructs the generator by storing k, nbits and lemmaSize in k_, nbits_ and
// lemmaSize_. The body performs no other work (the debug print is commented out).
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits, size_t lemmaSize)
|
||||
: k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
|
||||
//std::cerr << "LSHShortlistGenerator" << std::endl;
|
||||
}
|
||||
|
||||
Ptr<Shortlist> LSHShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
|
||||
|
@ -29,7 +29,7 @@ protected:
|
||||
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
|
||||
Expr cachedShortb_; // these match the current value of shortlist_
|
||||
Expr cachedShortLemmaEt_;
|
||||
bool done_;
|
||||
bool done_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch
|
||||
|
||||
void createCachedTensors(Expr weights,
|
||||
bool isLegacyUntransposedW,
|
||||
@ -43,7 +43,7 @@ public:
|
||||
virtual ~Shortlist();
|
||||
|
||||
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const;
|
||||
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const;
|
||||
virtual WordIndex tryForwardMap(WordIndex wIdx) const;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
|
||||
virtual Expr getIndicesExpr() const;
|
||||
@ -66,12 +66,14 @@ public:
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
// implements SLIDE for faster inference.
|
||||
// https://arxiv.org/pdf/1903.03129.pdf
|
||||
class LSHShortlist: public Shortlist {
|
||||
private:
|
||||
int k_;
|
||||
int nbits_;
|
||||
size_t lemmaSize_;
|
||||
static Ptr<faiss::IndexLSH> index_;
|
||||
int k_; // number of candidates returned from each input
|
||||
int nbits_; // length of hash
|
||||
size_t lemmaSize_; // vocab size
|
||||
static Ptr<faiss::IndexLSH> index_; // LSH index to store all possible candidates
|
||||
|
||||
void createCachedTensors(Expr weights,
|
||||
bool isLegacyUntransposedW,
|
||||
@ -82,7 +84,6 @@ private:
|
||||
public:
|
||||
LSHShortlist(int k, int nbits, size_t lemmaSize);
|
||||
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override;
|
||||
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const override;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
|
||||
virtual Expr getIndicesExpr() const override;
|
||||
|
@ -133,7 +133,7 @@ size_t Vocab::lemmaSize() const {
|
||||
return vImpl_->lemmaSize();
|
||||
}
|
||||
|
||||
// number of vocabulary items
|
||||
// type of vocabulary items
|
||||
std::string Vocab::type() const { return vImpl_->type(); }
|
||||
|
||||
// return EOS symbol id
|
||||
|
@ -61,7 +61,7 @@ public:
|
||||
// number of vocabulary items
|
||||
size_t size() const;
|
||||
|
||||
// number of vocabulary items
|
||||
// number of lemma items. Same as size() except in factored models
|
||||
size_t lemmaSize() const;
|
||||
|
||||
// number of vocabulary items
|
||||
|
@ -247,8 +247,6 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<
|
||||
|
||||
std::vector<float> Logits::getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr)
|
||||
const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
|
||||
//std::cerr << "indicesExpr=" << indicesExpr->shape() << std::endl;
|
||||
//int batchSize
|
||||
int batchSize = indicesExpr->shape()[0];
|
||||
int currBeamSize = indicesExpr->shape()[1];
|
||||
int numHypos = batchSize * currBeamSize;
|
||||
|
@ -56,12 +56,12 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
lazyConstruct(input->shape()[-1]);
|
||||
|
||||
auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
|
||||
/*
|
||||
std::cerr << "affineOrDot.x=" << x->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.W=" << W->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.b=" << b->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl;
|
||||
|
||||
*/
|
||||
if(b)
|
||||
return affine(x, W, b, transA, transB);
|
||||
else
|
||||
|
@ -315,7 +315,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
|
||||
suppressed.erase(std::remove_if(suppressed.begin(),
|
||||
suppressed.end(),
|
||||
[&](WordIndex i) {
|
||||
return shortlist->tryForwardMap(4545, 3343, i) == data::Shortlist::npos; // TODO beamIdx
|
||||
return shortlist->tryForwardMap(i) == data::Shortlist::npos; // TODO beamIdx
|
||||
}),
|
||||
suppressed.end());
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user