diff --git a/src/command/marian_conv.cpp b/src/command/marian_conv.cpp
index 943f61d4..b4a5f374 100644
--- a/src/command/marian_conv.cpp
+++ b/src/command/marian_conv.cpp
@@ -86,11 +86,17 @@ int main(int argc, char** argv) {
   graph->setDevice(CPU0);
 
   graph->load(modelFrom);
 
+  std::vector<lsh::ParamConvInfo> toBeLSHed;
   if(addLsh) {
     // Add dummy parameters for the LSH before the model gets actually initialized.
     // This create the parameters with useless values in the tensors, but it gives us the memory we need.
+    toBeLSHed = {
+      {lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits}
+    };
+
     graph->setReloaded(false);
-    lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
+    for(auto p : toBeLSHed)
+      lsh::addDummyParameters(graph, /*paramInfo=*/p);
     graph->setReloaded(true);
   }
 
@@ -99,7 +105,8 @@ int main(int argc, char** argv) {
   if(addLsh) {
     // After initialization, hijack the paramters for the LSH and force-overwrite with correct values.
     // Once this is done we can just pack and save as normal.
-    lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
+    for(auto p : toBeLSHed)
+      lsh::overwriteDummyParameters(graph, /*paramInfo=*/p);
   }
 
   // added a flag if the weights needs to be packed or not
diff --git a/src/layers/lsh.cpp b/src/layers/lsh.cpp
index 8a9c924e..73d45fc7 100644
--- a/src/layers/lsh.cpp
+++ b/src/layers/lsh.cpp
@@ -3,12 +3,14 @@
 #include "common/utils.h"
 #include "3rd_party/faiss/utils/hamming.h"
-#include "3rd_party/faiss/Index.h"
 
 #if BLAS_FOUND
 #include "3rd_party/faiss/VectorTransform.h"
 #endif
 
+#include "common/timer.h"
+
+#include "layers/lsh_impl.h"
 
 namespace marian {
 namespace lsh {
 
@@ -98,24 +100,22 @@ Expr encode(Expr input, Expr rotation) {
   return lambda(inputs, encodedShape, Type::uint8, encodeFwd, encodeHash);
 }
 
-Expr rotator(Expr weights, int nBits) {
+Expr rotator(Expr weights, int inDim, int nBits) {
   auto rotator = [](Expr out, const std::vector<Expr>& inputs) {
     inputs;
     fillRandomRotationMatrix(out->val(), out->graph()->allocator());
   };
 
   static const size_t rotatorHash = (size_t)&rotator;
-  int dim = weights->shape()[-1];
-  return lambda({weights}, {dim, nBits}, Type::float32, rotator, rotatorHash);
+  return lambda({weights}, {inDim, nBits}, Type::float32, rotator, rotatorHash);
 }
 
-Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows) {
+Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNRows, bool noSort/*= false*/) {
   ABORT_IF(encodedQuery->shape()[-1] != encodedWeights->shape()[-1], "Query and index bit vectors need to be of same size ({} != {})", encodedQuery->shape()[-1], encodedWeights->shape()[-1]);
 
   int currBeamSize = encodedQuery->shape()[0];
   int batchSize = encodedQuery->shape()[2];
-  int numHypos = currBeamSize * batchSize;
 
   auto search = [=](Expr out, const std::vector<Expr>& inputs) {
     Expr encodedQuery = inputs[0];
@@ -128,30 +128,25 @@ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows
     if(firstNRows != 0)
       wRows = firstNRows;
 
-    int qRows = encodedQuery->shape().elements() / bytesPerVector;
+    ABORT_IF(dimK > wRows, "k is larger than number of candidate values?"); // @TODO: use min(k, wRows) silently?
 
-    uint8_t* qCodes = encodedQuery->val()->data<uint8_t>();
-    uint8_t* wCodes = encodedWeights->val()->data<uint8_t>();
+    IndexType* outData = out->val()->data<IndexType>();
+    auto gather = [outData, dimK](IndexType rowId, IndexType k, IndexType kthColId, DistType /*dist*/) {
+      outData[rowId * dimK + k] = kthColId;
+    };
 
-    // use actual faiss code for performing the hamming search.
-    std::vector<int> distances(qRows * k);
-    std::vector<faiss::Index::idx_t> ids(qRows * k);
-    faiss::int_maxheap_array_t res = {(size_t)qRows, (size_t)k, ids.data(), distances.data()};
-    faiss::hammings_knn_hc(&res, qCodes, wCodes, (size_t)wRows, (size_t)bytesPerVector, 0);
+    Parameters params;
+    params.k = dimK;
+    params.queryRows = encodedQuery->val()->data<uint8_t>();
+    params.numQueryRows = encodedQuery->shape().elements() / bytesPerVector;
+    params.codeRows = encodedWeights->val()->data<uint8_t>();
+    params.numCodeRows = wRows;
+    params.bytesPerVector = bytesPerVector;
 
-    // Copy int64_t indices to Marian index type and sort by increasing index value per hypothesis.
-    // The sorting is required as we later do a binary search on those values for reverse look-up.
-    uint32_t* outData = out->val()->data<uint32_t>();
-    for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
-      size_t startIdx = k * hypoIdx;
-      size_t endIdx = startIdx + k;
-      for(size_t i = startIdx; i < endIdx; ++i)
-        outData[i] = (uint32_t)ids[i];
-      std::sort(outData + startIdx, outData + endIdx);
-    }
+    hammingTopK(params, gather);
   };
 
-  Shape kShape({currBeamSize, batchSize, k});
+  Shape kShape({currBeamSize, batchSize, dimK});
   return lambda({encodedQuery, encodedWeights}, kShape, Type::uint32, search);
 }
 
@@ -166,7 +161,7 @@ Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows, bool abo
     } else {
      ABORT_IF(abortIfDynamic, "Dynamic creation of LSH rotation matrix prohibited");
      LOG_ONCE(info, "Creating ad-hoc rotation matrix with shape {}", Shape({dim, nBits}));
-     rotMat = rotator(weights, nBits);
+     rotMat = rotator(weights, dim, nBits);
     }
   }
 
@@ -195,34 +190,43 @@ Ptr randomRotation() {
   return New();
 }
 
-void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBitsRot) {
-  auto weights = graph->get(weightsName);
-
-  ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
+void addDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo) {
+  auto weights = graph->get(paramInfo.name);
+  int nBitsRot = paramInfo.nBits;
+
+  ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name);
 
   int nBits = weights->shape()[-1];
+  if(paramInfo.transpose)
+    nBits = weights->shape()[-2];
+
   int nRows = weights->shape().elements() / nBits;
 
   Expr rotation;
   if(nBits != nBitsRot) {
-    LOG(info, "Adding LSH rotation parameter lsh_output_rotation with shape {}", Shape({nBits, nBitsRot}));
-    rotation = graph->param("lsh_output_rotation", {nBits, nBitsRot}, inits::dummy(), Type::float32);
+    LOG(info, "Adding LSH rotation parameter {} with shape {}", paramInfo.rotationName, Shape({nBits, nBitsRot}));
+    rotation = graph->param(paramInfo.rotationName, {nBits, nBitsRot}, inits::dummy(), Type::float32);
     nBits = nBitsRot;
   }
 
   int bytesPerVector = lsh::bytesPerVector(nBits);
-  LOG(info, "Adding LSH encoded weights lsh_output_codes with shape {}", Shape({nRows, bytesPerVector}));
-  auto codes = graph->param("lsh_output_codes", {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
+  LOG(info, "Adding LSH encoded weights {} with shape {}", paramInfo.codesName, Shape({nRows, bytesPerVector}));
+  auto codes = graph->param(paramInfo.codesName, {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
 }
 
-void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName) {
-  Expr weights = graph->get(weightsName);
-  Expr codes = graph->get("lsh_output_codes");
-  Expr rotation = graph->get("lsh_output_rotation");
+void overwriteDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo) {
+  Expr weights = graph->get(paramInfo.name);
+  Expr codes = graph->get(paramInfo.codesName);
+  Expr rotation = graph->get(paramInfo.rotationName);
 
-  ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
+  ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name);
   ABORT_IF(!codes, "Trying to overwrite non-existing LSH parameters lsh_output_codes??");
 
+  if(paramInfo.transpose) {
+    weights = transpose(weights);
+    graph->forward();
+  }
+
   if(rotation) {
     fillRandomRotationMatrix(rotation->val(), weights->graph()->allocator());
     encodeWithRotation(codes->val(), weights->val(), rotation->val(), weights->graph()->allocator());
diff --git a/src/layers/lsh.h b/src/layers/lsh.h
index 7a585891..5065ffcf 100644
--- a/src/layers/lsh.h
+++ b/src/layers/lsh.h
@@ -17,26 +17,34 @@ namespace marian {
 namespace lsh {
-
-  // return the number of full bytes required to encoded that many bits
-  int bytesPerVector(int nBits);
-
   // encodes an input as a bit vector, with optional rotation
   Expr encode(Expr input, Expr rotator = nullptr);
 
   // compute the rotation matrix (maps weights->shape()[-1] to nbits floats)
-  Expr rotator(Expr weights, int nbits);
+  Expr rotator(Expr weights, int inDim, int nbits);
 
   // perform the LSH search on fully encoded input and weights, return k results (indices) per input row
   // @TODO: add a top-k like operator that also returns the bitwise computed distances
-  Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0);
+  Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0, bool noSort = false);
 
   // same as above, but performs encoding on the fly
   Expr search(Expr query, Expr weights, int k, int nbits, int firstNRows = 0, bool abortIfDynamic = false);
 
+  // struct for parameter conversion used in marian-conv
+  struct ParamConvInfo {
+    std::string name;
+    std::string codesName;
+    std::string rotationName;
+    int nBits;
+    bool transpose;
+
+    ParamConvInfo(const std::string& name, const std::string& codesName, const std::string& rotationName, int nBits, bool transpose = false)
+      : name(name), codesName(codesName), rotationName(rotationName), nBits(nBits), transpose(transpose) {}
+  };
+
   // These are helper functions for encoding the LSH into the binary Marian model, used by marian-conv
-  void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBits);
-  void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName);
+  void addDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo);
+  void overwriteDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo);
 
   /**
    * Computes a random rotation matrix for LSH hashing.
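
Note (reviewer sketch, not part of the patch): the new lsh::ParamConvInfo struct above is what lets marian-conv register more than one matrix for LSH conversion, including transposed matrices via the transpose flag, which none of the hunks in this diff exercise yet. A minimal sketch of how a caller might add a second, transposed matrix; lshOutputWeights, lshNBits and graph are the same variables as in marian_conv.cpp, while the "Wemb" parameter name and the "lsh_emb_*" names are hypothetical:

    // weights matrix name, codes name, rotation name, nBits, optional transpose flag
    std::vector<lsh::ParamConvInfo> toBeLSHed = {
      {lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits},
      {"Wemb", "lsh_emb_codes", "lsh_emb_rotation", lshNBits, /*transpose=*/true} // hypothetical entry
    };

    graph->setReloaded(false);
    for(auto p : toBeLSHed)
      lsh::addDummyParameters(graph, /*paramInfo=*/p); // allocates dummy codes/rotation tensors before loading
    graph->setReloaded(true);
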
diff --git a/src/layers/lsh_impl.h b/src/layers/lsh_impl.h
new file mode 100644
index 00000000..d87d23e0
--- /dev/null
+++ b/src/layers/lsh_impl.h
@@ -0,0 +1,186 @@
+#pragma once
+
+#include <vector>
+
+#ifdef _MSC_VER
+#define __builtin_popcountl __popcnt64
+#define __builtin_popcount __popcnt
+#endif
+
+namespace marian {
+namespace lsh {
+
+  struct Parameters {
+    int k;
+    uint8_t* queryRows;
+    int numQueryRows;
+    uint8_t* codeRows;
+    int numCodeRows;
+    int bytesPerVector;
+  };
+
+  typedef uint32_t DistType;
+  typedef uint64_t ChunkType;
+
+  inline DistType popcount(const ChunkType& chunk) {
+    switch (sizeof(ChunkType)) {
+      case 8 : return (DistType)__builtin_popcountl((uint64_t)chunk);
+      case 4 : return (DistType)__builtin_popcount((uint32_t)chunk);
+      default: ABORT("Size {} not supported", sizeof(ChunkType));
+    }
+  }
+
+  // return the number of full bytes required to encode that many bits
+  inline int bytesPerVector(int nBits);
+
+  // compute top-k hamming distances for given query and weight binary codes. Faster than the FAISS version, especially for larger k; nearly constant wrt. k.
+  template <int StaticValue, bool Dynamic, typename T>
+  inline constexpr T getStaticOrDynamic(T dynamicValue) {
+    return Dynamic ? dynamicValue : StaticValue;
+  }
+
+  template <int StepsStatic, bool Dynamic>
+  inline DistType hamming(ChunkType* queryRow, ChunkType* codeRow, int stepsDynamic = 0) {
+    static_assert(Dynamic == true || StepsStatic != 0, "Either define dynamic use of steps or provide non-zero template argument");
+    DistType dist = 0;
+    for(int i = 0; i < getStaticOrDynamic<StepsStatic, Dynamic>(stepsDynamic); ++i)
+      dist += popcount(queryRow[i] ^ codeRow[i]);
+    return dist;
+  }
+
+  template <int warpSize, int NumCodeRows, int BytesPerVector, bool Dynamic, class Functor>
+  inline void hammingTopKUnrollWarp(int queryOffset, const Parameters& parameters, const Functor& gather) {
+    const int numBits = getStaticOrDynamic<BytesPerVector, Dynamic>(parameters.bytesPerVector) * 8;
+    ABORT_IF(numBits % 64 != 0, "LSH hash size has to be a multiple of 64");
+
+    // counter to keep track of seen hamming distances
+    std::vector<std::vector<DistType>> counter(warpSize, std::vector<DistType>(numBits, 0));
+    // buffer the distances for query vector warpRowId to all weight vectors codeRowId
+    std::vector<std::vector<DistType>> distBuffer(warpSize, std::vector<DistType>(getStaticOrDynamic<NumCodeRows, Dynamic>(parameters.numCodeRows), 0));
+    // minimal distances per query
+    std::vector<DistType> minDist(warpSize);
+
+    constexpr int StepStatic = BytesPerVector / sizeof(ChunkType);
+    int stepDynamic = parameters.bytesPerVector / sizeof(ChunkType);
+
+    ChunkType* codeRow = (ChunkType*)parameters.codeRows;
+
+    for(int warpRowId = 0; warpRowId < warpSize; warpRowId++) {
+      std::memset(counter[warpRowId].data(), 0, numBits * sizeof(DistType)); // Reset the counter via memset to 0
+      minDist[warpRowId] = (DistType)numBits;
+    }
+
+    for(IndexType codeRowId = 0; codeRowId < (IndexType)getStaticOrDynamic<NumCodeRows, Dynamic>(parameters.numCodeRows); ++codeRowId, codeRow += getStaticOrDynamic<StepStatic, Dynamic>(stepDynamic)) {
+      ChunkType* queryRow = (ChunkType*)parameters.queryRows;
+      for(IndexType warpRowId = 0; warpRowId < warpSize; warpRowId++, queryRow += getStaticOrDynamic<StepStatic, Dynamic>(stepDynamic)) {
+        // Compute the bit-wise hamming distance
+        DistType dist = hamming<StepStatic, Dynamic>(queryRow, codeRow, stepDynamic);
+
+        // Record the minimal distance seen for this query vector wrt. all weight vectors
+        if(dist < minDist[warpRowId]) {
+          minDist[warpRowId] = dist;
+        }
+
+        // Record the number of weight vectors that have this distance from the query vector.
+        // Note, because there is at most numBits different distances this can be trivially done.
+        // Not the case for generic distances like float.
+        counter[warpRowId][dist]++;
+
+        // Record the distance for this weight vector
+        distBuffer[warpRowId][codeRowId] = dist;
+      }
+    }
+    // warp finished, harvest k top distances
+
+    for(int warpRowId = 0; warpRowId < warpSize; warpRowId++) {
+      // Here we search for the distance at which we have seen equal or more than k elements with
+      // smaller distances. We start with the minimal distance from above which is its own address
+      // to the counter.
+      DistType maxDist = minDist[warpRowId];
+      size_t cummulativeDistances = 0;
+
+      // Accumulate number of elements until we reach k in growing distance order. Note that
+      // counter is indexed by hamming distance - from lowest to highest. Some slots will be 0.
+      // The cumulative sum from position a to b tells you how many elements have distances smaller
+      // than the distance at b.
+      while(cummulativeDistances < parameters.k)
+        cummulativeDistances += counter[warpRowId][maxDist++];
+      if(cummulativeDistances)
+        maxDist--; // fix overcounting
+
+      // Usually, we overshoot by a couple of elements and we need to take care of the distance at which the k-th
+      // element sits. This element has more neighbors at the same distance, but we only care for them
+      // as long as we have not reached k elements in total.
+      // By contrast, we trivially collect all elements below that distance -- these are always safe.
+
+      // This is the number of elements we need to collect at the last distance.
+      DistType maxDistLimit = /*number of elements at maxDist=*/counter[warpRowId][maxDist] - /*overflow=*/((DistType)cummulativeDistances - (DistType)parameters.k);
+      IndexType kSeen = 0;
+      IndexType kSeenAtKDist = 0;
+
+      for(IndexType codeRowId = 0; kSeen < (IndexType)parameters.k && codeRowId < (IndexType)getStaticOrDynamic<NumCodeRows, Dynamic>(parameters.numCodeRows); ++codeRowId) {
+        DistType dist = distBuffer[warpRowId][codeRowId];
+        // - if the current distance is smaller than the maxDist, just consume.
+        // - if the distance is equal to maxDist, make sure to only consume maxDistLimit elements at maxDist
+        //   and ignore the rest (smaller indices make it in first).
+        // - after we finish this loop we have exactly k top values for every query row in original index order.
+        int queryRowId = queryOffset + warpRowId;
+        if(dist < maxDist) {
+          gather(queryRowId, (IndexType)kSeen, codeRowId, dist);
+          kSeen++;
+        } else if(dist == maxDist && kSeenAtKDist < (DistType)maxDistLimit) {
+          gather(queryRowId, (IndexType)kSeen, codeRowId, dist);
+          kSeen++;
+          kSeenAtKDist++;
+        }
+      }
+    }
+  }
+
+  // Faster top-k search for hamming distance. The idea here is that instead of sorting the elements we find a hamming distance at which it is safe
+  // to copy the given index. Copying only the indices below that distance is guaranteed to result in no more than k elements. For elements at that
+  // distance we need to correct for overshooting.
+  // Once we have that distance we only need to traverse the set of distances. In the end we get exactly k elements per queryRows vector.
+  template <int NumCodeRows, int BytesPerVector, bool Dynamic, class Functor>
+  inline void hammingTopKUnroll(const Parameters& parameters, const Functor& gather) {
+    static_assert(Dynamic == true || (NumCodeRows != 0 && BytesPerVector != 0), "Either define dynamic use of variables or provide non-zero template arguments");
+
+    int warpSize = 4; // starting warpSize of 4 seems optimal
+    auto warpParameters = parameters;
+    for(int queryOffset = 0; queryOffset < parameters.numQueryRows; queryOffset += warpSize) {
+      while(parameters.numQueryRows - queryOffset < warpSize)
+        warpSize /= 2;
+
+      int step = getStaticOrDynamic<BytesPerVector, Dynamic>(parameters.bytesPerVector);
+      warpParameters.queryRows = parameters.queryRows + queryOffset * step;
+      warpParameters.numQueryRows = warpSize;
+      switch(warpSize) {
+        case 8 : hammingTopKUnrollWarp<8, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
+        case 4 : hammingTopKUnrollWarp<4, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
+        case 2 : hammingTopKUnrollWarp<2, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
+        case 1 : hammingTopKUnrollWarp<1, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
+        default: ABORT("Unhandled warpSize = {}??", warpSize);
+      }
+    }
+  }
+
+  template <class Functor>
+  inline void hammingTopK(const Parameters& parameters, const Functor& gather) {
+    if(parameters.numCodeRows == 2048 && parameters.bytesPerVector == 64)
+      hammingTopKUnroll< 2048,  64, false>(parameters, gather);
+    else if(parameters.numCodeRows == 4096 && parameters.bytesPerVector == 64)
+      hammingTopKUnroll< 4096,  64, false>(parameters, gather);
+    else if(parameters.numCodeRows == 6144 && parameters.bytesPerVector == 64)
+      hammingTopKUnroll< 6144,  64, false>(parameters, gather);
+    else if(parameters.numCodeRows == 8192 && parameters.bytesPerVector == 64)
+      hammingTopKUnroll< 8192,  64, false>(parameters, gather);
+    else if(parameters.numCodeRows == 32000 && parameters.bytesPerVector == 64)
+      hammingTopKUnroll<32000,  64, false>(parameters, gather);
+    else if(parameters.numCodeRows == 32000 && parameters.bytesPerVector == 128)
+      hammingTopKUnroll<32000, 128, false>(parameters, gather);
+    else
+      hammingTopKUnroll<    0,   0,  true>(parameters, gather);
+  }
+
+} // namespace lsh
+} // namespace marian
\ No newline at end of file
diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp
index a439197b..316c66d1 100644
--- a/src/microsoft/quicksand.cpp
+++ b/src/microsoft/quicksand.cpp
@@ -178,8 +178,7 @@ public:
     auto score = std::get<2>(result);
     // determine alignment if present
     AlignmentSets alignmentSets;
-    if (options_->hasAndNotEmpty("alignment"))
-    {
+    if (options_->hasAndNotEmpty("alignment")) {
      float alignmentThreshold;
      auto alignment = options_->get<std::string>("alignment"); // @TODO: this logic now exists three times in Marian
      if (alignment == "soft")
@@ -287,7 +286,7 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
     // Add dummy parameters for the LSH before the model gets actually initialized.
     // This create the parameters with useless values in the tensors, but it gives us the memory we need.
     graph->setReloaded(false);
-    lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
+    lsh::addDummyParameters(graph, /*paramInfo=*/{lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits});
     graph->setReloaded(true);
   }
 
@@ -296,7 +295,7 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
   if(addLsh) {
     // After initialization, hijack the paramters for the LSH and force-overwrite with correct values.
     // Once this is done we can just pack and save as normal.
-    lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
+    lsh::overwriteDummyParameters(graph, /*paramInfo=*/{lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits});
   }
 
   Type targetPrecType = (Type) targetPrec;
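
Note (reviewer sketch, not part of the patch): outside the expression graph, the new hammingTopK API from lsh_impl.h is driven with a Parameters struct plus a gather callback, the same way searchEncoded() does above. Only the Parameters fields and the gather signature come from the patch; the sizes, buffers, and the helper function below are made up for illustration, and the snippet assumes it is compiled inside Marian so that lsh_impl.h sees the usual ABORT/IndexType definitions:

    #include "layers/lsh_impl.h"
    #include <cstdint>
    #include <vector>

    // Collect the k nearest (by hamming distance) code rows for each query row.
    // queryCodes: numQueryRows x bytesPerVector bytes, codes: numCodeRows x bytesPerVector bytes.
    std::vector<uint32_t> topKIndices(uint8_t* queryCodes, uint8_t* codes) {
      using namespace marian::lsh;

      const int k = 8;
      Parameters params;
      params.k              = k;
      params.queryRows      = queryCodes;
      params.numQueryRows   = 4;
      params.codeRows       = codes;
      params.numCodeRows    = 2048; // hits one of the unrolled instantiations in hammingTopK
      params.bytesPerVector = 64;   // 512 bits; the hash size must be a multiple of 64 bits

      std::vector<uint32_t> out(params.numQueryRows * k);
      // gather() is invoked once per (query row, rank) pair with the index of the rank-th closest code row.
      auto gather = [&](uint32_t rowId, uint32_t rank, uint32_t colId, DistType /*dist*/) {
        out[rowId * k + rank] = colId;
      };

      hammingTopK(params, gather);
      return out;
    }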