mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 22490: Faster LSH top-k for CPU
This PR replaces the top-k search from FAISS on the CPU with a more specialized version for discrete distances in sub-linear time.
This commit is contained in:
parent
f00d062189
commit
e6dbacb310
@ -86,11 +86,17 @@ int main(int argc, char** argv) {
|
||||
graph->setDevice(CPU0);
|
||||
graph->load(modelFrom);
|
||||
|
||||
std::vector<lsh::ParamConvInfo> toBeLSHed;
|
||||
if(addLsh) {
|
||||
// Add dummy parameters for the LSH before the model gets actually initialized.
|
||||
// This creates the parameters with useless values in the tensors, but it gives us the memory we need.
|
||||
toBeLSHed = {
|
||||
{lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits}
|
||||
};
|
||||
|
||||
graph->setReloaded(false);
|
||||
lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
|
||||
for(auto p : toBeLSHed)
|
||||
lsh::addDummyParameters(graph, /*paramInfo=*/p);
|
||||
graph->setReloaded(true);
|
||||
}
|
||||
|
||||
@ -99,7 +105,8 @@ int main(int argc, char** argv) {
|
||||
if(addLsh) {
|
||||
// After initialization, hijack the parameters for the LSH and force-overwrite with correct values.
|
||||
// Once this is done we can just pack and save as normal.
|
||||
lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
|
||||
for(auto p : toBeLSHed)
|
||||
lsh::overwriteDummyParameters(graph, /*paramInfo=*/p);
|
||||
}
|
||||
|
||||
// added a flag if the weights needs to be packed or not
|
||||
|
@ -3,12 +3,14 @@
|
||||
#include "common/utils.h"
|
||||
|
||||
#include "3rd_party/faiss/utils/hamming.h"
|
||||
#include "3rd_party/faiss/Index.h"
|
||||
|
||||
#if BLAS_FOUND
|
||||
#include "3rd_party/faiss/VectorTransform.h"
|
||||
#endif
|
||||
|
||||
#include "common/timer.h"
|
||||
|
||||
#include "layers/lsh_impl.h"
|
||||
|
||||
namespace marian {
|
||||
namespace lsh {
|
||||
@ -98,24 +100,22 @@ Expr encode(Expr input, Expr rotation) {
|
||||
return lambda(inputs, encodedShape, Type::uint8, encodeFwd, encodeHash);
|
||||
}
|
||||
|
||||
Expr rotator(Expr weights, int nBits) {
|
||||
Expr rotator(Expr weights, int inDim, int nBits) {
|
||||
auto rotator = [](Expr out, const std::vector<Expr>& inputs) {
|
||||
inputs;
|
||||
fillRandomRotationMatrix(out->val(), out->graph()->allocator());
|
||||
};
|
||||
|
||||
static const size_t rotatorHash = (size_t)&rotator;
|
||||
int dim = weights->shape()[-1];
|
||||
return lambda({weights}, {dim, nBits}, Type::float32, rotator, rotatorHash);
|
||||
return lambda({weights}, {inDim, nBits}, Type::float32, rotator, rotatorHash);
|
||||
}
|
||||
|
||||
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows) {
|
||||
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNRows, bool noSort/*= false*/) {
|
||||
ABORT_IF(encodedQuery->shape()[-1] != encodedWeights->shape()[-1],
|
||||
"Query and index bit vectors need to be of same size ({} != {})", encodedQuery->shape()[-1], encodedWeights->shape()[-1]);
|
||||
|
||||
int currBeamSize = encodedQuery->shape()[0];
|
||||
int batchSize = encodedQuery->shape()[2];
|
||||
int numHypos = currBeamSize * batchSize;
|
||||
|
||||
auto search = [=](Expr out, const std::vector<Expr>& inputs) {
|
||||
Expr encodedQuery = inputs[0];
|
||||
@ -128,30 +128,25 @@ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows
|
||||
if(firstNRows != 0)
|
||||
wRows = firstNRows;
|
||||
|
||||
int qRows = encodedQuery->shape().elements() / bytesPerVector;
|
||||
ABORT_IF(dimK > wRows, "k is larger than number of candidate values?"); // @TODO: use min(k, wRows) silently?
|
||||
|
||||
uint8_t* qCodes = encodedQuery->val()->data<uint8_t>();
|
||||
uint8_t* wCodes = encodedWeights->val()->data<uint8_t>();
|
||||
|
||||
// use actual faiss code for performing the hamming search.
|
||||
std::vector<int> distances(qRows * k);
|
||||
std::vector<faiss::Index::idx_t> ids(qRows * k);
|
||||
faiss::int_maxheap_array_t res = {(size_t)qRows, (size_t)k, ids.data(), distances.data()};
|
||||
faiss::hammings_knn_hc(&res, qCodes, wCodes, (size_t)wRows, (size_t)bytesPerVector, 0);
|
||||
|
||||
// Copy int64_t indices to Marian index type and sort by increasing index value per hypothesis.
|
||||
// The sorting is required as we later do a binary search on those values for reverse look-up.
|
||||
uint32_t* outData = out->val()->data<uint32_t>();
|
||||
for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
|
||||
size_t startIdx = k * hypoIdx;
|
||||
size_t endIdx = startIdx + k;
|
||||
for(size_t i = startIdx; i < endIdx; ++i)
|
||||
outData[i] = (uint32_t)ids[i];
|
||||
std::sort(outData + startIdx, outData + endIdx);
|
||||
}
|
||||
IndexType* outData = out->val()->data<IndexType>();
|
||||
auto gather = [outData, dimK](IndexType rowId, IndexType k, IndexType kthColId, DistType /*dist*/) {
|
||||
outData[rowId * dimK + k] = kthColId;
|
||||
};
|
||||
|
||||
Shape kShape({currBeamSize, batchSize, k});
|
||||
Parameters params;
|
||||
params.k = dimK;
|
||||
params.queryRows = encodedQuery->val()->data<uint8_t>();
|
||||
params.numQueryRows = encodedQuery->shape().elements() / bytesPerVector;
|
||||
params.codeRows = encodedWeights->val()->data<uint8_t>();
|
||||
params.numCodeRows = wRows;
|
||||
params.bytesPerVector = bytesPerVector;
|
||||
|
||||
hammingTopK(params, gather);
|
||||
};
|
||||
|
||||
Shape kShape({currBeamSize, batchSize, dimK});
|
||||
return lambda({encodedQuery, encodedWeights}, kShape, Type::uint32, search);
|
||||
}
|
||||
|
||||
@ -166,7 +161,7 @@ Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows, bool abo
|
||||
} else {
|
||||
ABORT_IF(abortIfDynamic, "Dynamic creation of LSH rotation matrix prohibited");
|
||||
LOG_ONCE(info, "Creating ad-hoc rotation matrix with shape {}", Shape({dim, nBits}));
|
||||
rotMat = rotator(weights, nBits);
|
||||
rotMat = rotator(weights, dim, nBits);
|
||||
}
|
||||
}
|
||||
|
||||
@ -195,34 +190,43 @@ Ptr<inits::NodeInitializer> randomRotation() {
|
||||
return New<RandomRotation>();
|
||||
}
|
||||
|
||||
void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBitsRot) {
|
||||
auto weights = graph->get(weightsName);
|
||||
void addDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo) {
|
||||
auto weights = graph->get(paramInfo.name);
|
||||
int nBitsRot = paramInfo.nBits;
|
||||
|
||||
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
|
||||
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name);
|
||||
|
||||
int nBits = weights->shape()[-1];
|
||||
if(paramInfo.transpose)
|
||||
nBits = weights->shape()[-2];
|
||||
|
||||
int nRows = weights->shape().elements() / nBits;
|
||||
|
||||
Expr rotation;
|
||||
if(nBits != nBitsRot) {
|
||||
LOG(info, "Adding LSH rotation parameter lsh_output_rotation with shape {}", Shape({nBits, nBitsRot}));
|
||||
rotation = graph->param("lsh_output_rotation", {nBits, nBitsRot}, inits::dummy(), Type::float32);
|
||||
LOG(info, "Adding LSH rotation parameter {} with shape {}", paramInfo.rotationName, Shape({nBits, nBitsRot}));
|
||||
rotation = graph->param(paramInfo.rotationName, {nBits, nBitsRot}, inits::dummy(), Type::float32);
|
||||
nBits = nBitsRot;
|
||||
}
|
||||
|
||||
int bytesPerVector = lsh::bytesPerVector(nBits);
|
||||
LOG(info, "Adding LSH encoded weights lsh_output_codes with shape {}", Shape({nRows, bytesPerVector}));
|
||||
auto codes = graph->param("lsh_output_codes", {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
|
||||
LOG(info, "Adding LSH encoded weights {} with shape {}", paramInfo.codesName, Shape({nRows, bytesPerVector}));
|
||||
auto codes = graph->param(paramInfo.codesName, {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
|
||||
}
|
||||
|
||||
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName) {
|
||||
Expr weights = graph->get(weightsName);
|
||||
Expr codes = graph->get("lsh_output_codes");
|
||||
Expr rotation = graph->get("lsh_output_rotation");
|
||||
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo) {
|
||||
Expr weights = graph->get(paramInfo.name);
|
||||
Expr codes = graph->get(paramInfo.codesName);
|
||||
Expr rotation = graph->get(paramInfo.rotationName);
|
||||
|
||||
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
|
||||
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name);
|
||||
ABORT_IF(!codes, "Trying to overwrite non-existing LSH parameters lsh_output_codes??");
|
||||
|
||||
if(paramInfo.transpose) {
|
||||
weights = transpose(weights);
|
||||
graph->forward();
|
||||
}
|
||||
|
||||
if(rotation) {
|
||||
fillRandomRotationMatrix(rotation->val(), weights->graph()->allocator());
|
||||
encodeWithRotation(codes->val(), weights->val(), rotation->val(), weights->graph()->allocator());
|
||||
|
@ -17,26 +17,34 @@
|
||||
|
||||
namespace marian {
|
||||
namespace lsh {
|
||||
|
||||
// return the number of full bytes required to encode that many bits
|
||||
int bytesPerVector(int nBits);
|
||||
|
||||
// encodes an input as a bit vector, with optional rotation
|
||||
Expr encode(Expr input, Expr rotator = nullptr);
|
||||
|
||||
// compute the rotation matrix (maps weights->shape()[-1] to nbits floats)
|
||||
Expr rotator(Expr weights, int nbits);
|
||||
Expr rotator(Expr weights, int inDim, int nbits);
|
||||
|
||||
// perform the LSH search on fully encoded input and weights, return k results (indices) per input row
|
||||
// @TODO: add a top-k like operator that also returns the bitwise computed distances
|
||||
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0);
|
||||
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0, bool noSort = false);
|
||||
|
||||
// same as above, but performs encoding on the fly
|
||||
Expr search(Expr query, Expr weights, int k, int nbits, int firstNRows = 0, bool abortIfDynamic = false);
|
||||
|
||||
// struct for parameter conversion used in marian-conv
|
||||
struct ParamConvInfo {
|
||||
std::string name;
|
||||
std::string codesName;
|
||||
std::string rotationName;
|
||||
int nBits;
|
||||
bool transpose;
|
||||
|
||||
ParamConvInfo(const std::string& name, const std::string& codesName, const std::string& rotationName, int nBits, bool transpose = false)
|
||||
: name(name), codesName(codesName), rotationName(rotationName), nBits(nBits), transpose(transpose) {}
|
||||
};
|
||||
|
||||
// These are helper functions for encoding the LSH into the binary Marian model, used by marian-conv
|
||||
void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBits);
|
||||
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName);
|
||||
void addDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo);
|
||||
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo);
|
||||
|
||||
/**
|
||||
* Computes a random rotation matrix for LSH hashing.
|
||||
|
186
src/layers/lsh_impl.h
Normal file
186
src/layers/lsh_impl.h
Normal file
@ -0,0 +1,186 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define __builtin_popcountl __popcnt64
|
||||
#define __builtin_popcount __popcnt
|
||||
#endif
|
||||
|
||||
namespace marian {
|
||||
namespace lsh {
|
||||
|
||||
// Bundles all inputs of the hamming-distance top-k search so they can be
// passed around as a single argument (see hammingTopK below).
struct Parameters {
  int k;               // number of top entries (nearest neighbors) to return per query row
  uint8_t* queryRows;    // encoded query bit vectors, numQueryRows x bytesPerVector bytes
  int numQueryRows;    // number of query rows in queryRows
  uint8_t* codeRows;     // encoded weight/code bit vectors, numCodeRows x bytesPerVector bytes
  int numCodeRows;     // number of code rows searched per query
  int bytesPerVector;  // bytes per encoded bit vector (hash size in bits / 8)
};
|
||||
|
||||
typedef uint32_t DistType;  // type used for hamming distances and histogram counters
typedef uint64_t ChunkType; // bit vectors are XOR-ed and counted in chunks of this width

// Counts the set bits in one chunk via compiler intrinsics. The switch on
// sizeof(ChunkType) is resolved at compile time, so this collapses to a single
// popcount instruction on targets that have one.
inline DistType popcount(const ChunkType& chunk) {
  switch (sizeof(ChunkType)) {
    // NOTE(review): __builtin_popcountl takes `unsigned long`, which is only
    // 32 bits wide on some ABIs; __builtin_popcountll would be safer for the
    // 64-bit case. Presumably only LP64 builds are targeted -- TODO confirm.
    // (The call site cannot be changed here without also updating the MSVC
    // #define mapping __builtin_popcountl to __popcnt64 at the top of the file.)
    case 8 : return (DistType)__builtin_popcountl((uint64_t)chunk);
    case 4 : return (DistType)__builtin_popcount((uint32_t)chunk);
    default: ABORT("Size {} not supported", sizeof(ChunkType));
  }
}
|
||||
|
||||
// return the number of full bytes required to encode that many bits
|
||||
inline int bytesPerVector(int nBits);
|
||||
|
||||
// compute top-k hamming distances for given query and weight binary codes. Faster than the FAISS version, especially for larger k; runtime is nearly constant wrt. k.
|
||||
// Selects between a compile-time constant and a runtime value: when Dynamic is
// true the runtime argument wins, otherwise the template argument StaticValue
// is used, which lets the compiler treat loop bounds as constants and unroll.
template <int StaticValue = 0, bool Dynamic=true, typename T>
inline constexpr T getStaticOrDynamic(T dynamicValue) {
  return !Dynamic ? static_cast<T>(StaticValue) : dynamicValue;
}
|
||||
|
||||
template <size_t StepsStatic, bool Dynamic=false>
|
||||
inline DistType hamming(ChunkType* queryRow, ChunkType* codeRow, int stepsDynamic = 0) {
|
||||
static_assert(Dynamic == true || StepsStatic != 0, "Either define dynamic use of steps or provide non-zero template argument");
|
||||
DistType dist = 0;
|
||||
for(int i = 0; i < getStaticOrDynamic<StepsStatic, Dynamic>(stepsDynamic); ++i)
|
||||
dist += popcount(queryRow[i] ^ codeRow[i]);
|
||||
return dist;
|
||||
}
|
||||
|
||||
template <int warpSize, int NumCodeRows, int BytesPerVector, bool Dynamic, class Functor>
|
||||
inline void hammingTopKUnrollWarp(int queryOffset, const Parameters& parameters, const Functor& gather) {
|
||||
const int numBits = getStaticOrDynamic<BytesPerVector, Dynamic>(parameters.bytesPerVector) * 8;
|
||||
ABORT_IF(numBits % 64 != 0, "LSH hash size has to be a multiple of 64");
|
||||
|
||||
// counter to keep track of seen hamming distances
|
||||
std::vector<std::vector<DistType>> counter(warpSize, std::vector<DistType>(numBits, 0));
|
||||
// buffer the distances for query vector warpRowId to all weight weight vectors codeRowId
|
||||
std::vector<std::vector<DistType>> distBuffer(warpSize, std::vector<DistType>(getStaticOrDynamic<NumCodeRows, Dynamic>(parameters.numCodeRows), 0));
|
||||
// minimal distances per query
|
||||
std::vector<DistType> minDist(warpSize);
|
||||
|
||||
constexpr int StepStatic = BytesPerVector / sizeof(ChunkType);
|
||||
int stepDynamic = parameters.bytesPerVector / sizeof(ChunkType);
|
||||
|
||||
ChunkType* codeRow = (ChunkType*)parameters.codeRows;
|
||||
|
||||
for(int warpRowId = 0; warpRowId < warpSize; warpRowId++) {
|
||||
std::memset(counter[warpRowId].data(), 0, numBits * sizeof(DistType)); // Reset the counter via memset to 0
|
||||
minDist[warpRowId] = (DistType)numBits;
|
||||
}
|
||||
|
||||
for(IndexType codeRowId = 0; codeRowId < (IndexType)getStaticOrDynamic<NumCodeRows, Dynamic>(parameters.numCodeRows); ++codeRowId, codeRow += getStaticOrDynamic<StepStatic, Dynamic>(stepDynamic)) {
|
||||
ChunkType* queryRow = (ChunkType*)parameters.queryRows;
|
||||
for(IndexType warpRowId = 0; warpRowId < warpSize; warpRowId++, queryRow += getStaticOrDynamic<StepStatic, Dynamic>(stepDynamic)) {
|
||||
// Compute the bit-wise hamming distance
|
||||
DistType dist = hamming<StepStatic, Dynamic>(queryRow, codeRow, stepDynamic);
|
||||
|
||||
// Record the minimal distance seen for this query vector wrt. all weight vectors
|
||||
if(dist < minDist[warpRowId]) {
|
||||
minDist[warpRowId] = dist;
|
||||
}
|
||||
|
||||
// Record the number of weight vectors that have this distance from the query vector.
|
||||
// Note, because there is at most numBits different distances this can be trivially done.
|
||||
// Not the case for generic distances like float.
|
||||
counter[warpRowId][dist]++;
|
||||
|
||||
// Record the distance for this weight vector
|
||||
distBuffer[warpRowId][codeRowId] = dist;
|
||||
}
|
||||
}
|
||||
// warp finished, harvest k top distances
|
||||
|
||||
for(int warpRowId = 0; warpRowId < warpSize; warpRowId++) {
|
||||
// Here we search for the distance at which we have seen equal or more than k elements with
|
||||
// smaller distances. We start with the minimal distance from above which is its own address
|
||||
// to the counter.
|
||||
DistType maxDist = minDist[warpRowId];
|
||||
size_t cummulativeDistances = 0;
|
||||
|
||||
// Accumulate number of elements until we reach k in growing distance order. Note that
|
||||
// counter is indexed by hamming distance - from lowest to highest. Some slots will be 0.
|
||||
// The cumulative sum from position a to b tells you how many elements have distances smaller
|
||||
// than the distance at b.
|
||||
while(cummulativeDistances < parameters.k)
|
||||
cummulativeDistances += counter[warpRowId][maxDist++];
|
||||
if(cummulativeDistances)
|
||||
maxDist--; // fix overcounting
|
||||
|
||||
// Usually, we overshoot by a couple of elements and we need to take care of the distance at which the k-th
|
||||
// element sits. This elements has more neighbors at the same distance, but we only care for them
|
||||
// as long we have not reached k elements in total.
|
||||
// By contrast, we trivially collect all elements below that distance -- these are always safe.
|
||||
|
||||
// This is the number of elements we need to collect at the last distance.
|
||||
DistType maxDistLimit = /*number of elements at maxDist=*/counter[warpRowId][maxDist] - /*overflow=*/((DistType)cummulativeDistances - (DistType)parameters.k);
|
||||
IndexType kSeen = 0;
|
||||
IndexType kSeenAtKDist = 0;
|
||||
|
||||
for(IndexType codeRowId = 0; kSeen < (IndexType)parameters.k && codeRowId < (IndexType)getStaticOrDynamic<NumCodeRows, Dynamic>(parameters.numCodeRows); ++codeRowId) {
|
||||
DistType dist = distBuffer[warpRowId][codeRowId];
|
||||
// - if the current distance is smaller than the maxDist, just consume.
|
||||
// - if the distance is equal to maxDist, make sure to only consume maxDistLimit elements at maxDist
|
||||
// and ignore the rest (smaller indices make it in first).
|
||||
// - after we finish this loop we have exactly k top values for every query row in original index order.
|
||||
int queryRowId = queryOffset + warpRowId;
|
||||
if(dist < maxDist) {
|
||||
gather(queryRowId, (IndexType)kSeen, codeRowId, dist);
|
||||
kSeen++;
|
||||
} else if(dist == maxDist && kSeenAtKDist < (DistType)maxDistLimit) {
|
||||
gather(queryRowId, (IndexType)kSeen, codeRowId, dist);
|
||||
kSeen++;
|
||||
kSeenAtKDist++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Faster top-k search for hamming distance. The idea here is that instead of sorting the elements we find a hamming distances at which it is safe
|
||||
// to copy the given index. Copying only the indices below that distance is guaranteed to results in no more than k elements. For elements at that
|
||||
// distance we need to correct for overshooting.
|
||||
// Once we have that distance we only need to traverse the set of distances. In the end we get exactly k elements per queryRows vector.
|
||||
template <int NumCodeRows, int BytesPerVector, bool Dynamic, class Functor>
|
||||
inline void hammingTopKUnroll(const Parameters& parameters, const Functor& gather) {
|
||||
static_assert(Dynamic == true || (NumCodeRows != 0 && BytesPerVector != 0), "Either define dynamic use of variables or provide non-zero template arguments");
|
||||
|
||||
int warpSize = 4; // starting warpSize of 4 seems optimal
|
||||
auto warpParameters = parameters;
|
||||
for(int queryOffset = 0; queryOffset < parameters.numQueryRows; queryOffset += warpSize) {
|
||||
while(parameters.numQueryRows - queryOffset < warpSize)
|
||||
warpSize /= 2;
|
||||
|
||||
int step = getStaticOrDynamic<BytesPerVector, Dynamic>(parameters.bytesPerVector);
|
||||
warpParameters.queryRows = parameters.queryRows + queryOffset * step;
|
||||
warpParameters.numQueryRows = warpSize;
|
||||
switch(warpSize) {
|
||||
case 8 : hammingTopKUnrollWarp<8, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
|
||||
case 4 : hammingTopKUnrollWarp<4, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
|
||||
case 2 : hammingTopKUnrollWarp<2, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
|
||||
case 1 : hammingTopKUnrollWarp<1, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
|
||||
default: ABORT("Unhandled warpSize = {}??", warpSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Functor>
|
||||
inline void hammingTopK(const Parameters& parameters, const Functor& gather) {
|
||||
if(parameters.numCodeRows == 2048 && parameters.bytesPerVector == 64)
|
||||
hammingTopKUnroll< 2048, 64, false>(parameters, gather);
|
||||
else if(parameters.numCodeRows == 4096 && parameters.bytesPerVector == 64)
|
||||
hammingTopKUnroll< 4096, 64, false>(parameters, gather);
|
||||
else if(parameters.numCodeRows == 6144 && parameters.bytesPerVector == 64)
|
||||
hammingTopKUnroll< 6144, 64, false>(parameters, gather);
|
||||
else if(parameters.numCodeRows == 8192 && parameters.bytesPerVector == 64)
|
||||
hammingTopKUnroll< 8192, 64, false>(parameters, gather);
|
||||
else if(parameters.numCodeRows == 32000 && parameters.bytesPerVector == 64)
|
||||
hammingTopKUnroll<32000, 64, false>(parameters, gather);
|
||||
else if(parameters.numCodeRows == 32000 && parameters.bytesPerVector == 128)
|
||||
hammingTopKUnroll<32000, 128, false>(parameters, gather);
|
||||
else
|
||||
hammingTopKUnroll< 0, 0, true>(parameters, gather);
|
||||
}
|
||||
|
||||
} // namespace lsh
|
||||
} // namespace marian
|
@ -178,8 +178,7 @@ public:
|
||||
auto score = std::get<2>(result);
|
||||
// determine alignment if present
|
||||
AlignmentSets alignmentSets;
|
||||
if (options_->hasAndNotEmpty("alignment"))
|
||||
{
|
||||
if (options_->hasAndNotEmpty("alignment")) {
|
||||
float alignmentThreshold;
|
||||
auto alignment = options_->get<std::string>("alignment"); // @TODO: this logic now exists three times in Marian
|
||||
if (alignment == "soft")
|
||||
@ -287,7 +286,7 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
|
||||
// Add dummy parameters for the LSH before the model gets actually initialized.
|
||||
// This creates the parameters with useless values in the tensors, but it gives us the memory we need.
|
||||
graph->setReloaded(false);
|
||||
lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
|
||||
lsh::addDummyParameters(graph, /*paramInfo=*/{lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits});
|
||||
graph->setReloaded(true);
|
||||
}
|
||||
|
||||
@ -296,7 +295,7 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
|
||||
if(addLsh) {
|
||||
// After initialization, hijack the parameters for the LSH and force-overwrite with correct values.
|
||||
// Once this is done we can just pack and save as normal.
|
||||
lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
|
||||
lsh::overwriteDummyParameters(graph, /*paramInfo=*/{lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits});
|
||||
}
|
||||
|
||||
Type targetPrecType = (Type) targetPrec;
|
||||
|
Loading…
Reference in New Issue
Block a user