mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 19685: Marianize LSH as operators for mmapping and use in Quicksand
This PR turns the LSH index and search into a set of operators that live in the expression graph. This makes creation etc. thread-safe (one index per graph) and allows to later implement GPU versions. This allows to mmap the LSH as a Marian parameter since now we only need to turn the index into something that can be saved to disk using the existing tensors. This happens in marian_conv or the equivalent interface function in the Quicksand interface.
This commit is contained in:
parent
d6c09b24de
commit
35c822eb4e
119
src/3rd_party/faiss/Index.cpp
vendored
119
src/3rd_party/faiss/Index.cpp
vendored
@ -1,119 +0,0 @@
|
||||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// -*- c++ -*-
|
||||
|
||||
#include "Index.h"
|
||||
#include "common/logging.h"
|
||||
#include <cstring>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
Index::~Index ()
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
void Index::train(idx_t /*n*/, const float* /*x*/) {
|
||||
// does nothing by default
|
||||
}
|
||||
|
||||
|
||||
void Index::range_search (idx_t , const float *, float,
|
||||
RangeSearchResult *) const
|
||||
{
|
||||
ABORT ("range search not implemented");
|
||||
}
|
||||
|
||||
void Index::assign (idx_t n, const float * x, idx_t * labels, idx_t k)
|
||||
{
|
||||
float * distances = new float[n * k];
|
||||
ScopeDeleter<float> del(distances);
|
||||
search (n, x, k, distances, labels);
|
||||
}
|
||||
|
||||
void Index::add_with_ids(
|
||||
idx_t /*n*/,
|
||||
const float* /*x*/,
|
||||
const idx_t* /*xids*/) {
|
||||
ABORT ("add_with_ids not implemented for this type of index");
|
||||
}
|
||||
|
||||
size_t Index::remove_ids(const IDSelector& /*sel*/) {
|
||||
ABORT ("remove_ids not implemented for this type of index");
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
void Index::reconstruct (idx_t, float * ) const {
|
||||
ABORT ("reconstruct not implemented for this type of index");
|
||||
}
|
||||
|
||||
|
||||
void Index::reconstruct_n (idx_t i0, idx_t ni, float *recons) const {
|
||||
for (idx_t i = 0; i < ni; i++) {
|
||||
reconstruct (i0 + i, recons + i * d);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Index::search_and_reconstruct (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
float *recons) const {
|
||||
search (n, x, k, distances, labels);
|
||||
for (idx_t i = 0; i < n; ++i) {
|
||||
for (idx_t j = 0; j < k; ++j) {
|
||||
idx_t ij = i * k + j;
|
||||
idx_t key = labels[ij];
|
||||
float* reconstructed = recons + ij * d;
|
||||
if (key < 0) {
|
||||
// Fill with NaNs
|
||||
memset(reconstructed, -1, sizeof(*reconstructed) * d);
|
||||
} else {
|
||||
reconstruct (key, reconstructed);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Index::compute_residual (const float * x,
|
||||
float * residual, idx_t key) const {
|
||||
reconstruct (key, residual);
|
||||
for (size_t i = 0; i < d; i++) {
|
||||
residual[i] = x[i] - residual[i];
|
||||
}
|
||||
}
|
||||
|
||||
void Index::compute_residual_n (idx_t n, const float* xs,
|
||||
float* residuals,
|
||||
const idx_t* keys) const {
|
||||
//#pragma omp parallel for
|
||||
for (idx_t i = 0; i < n; ++i) {
|
||||
compute_residual(&xs[i * d], &residuals[i * d], keys[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
size_t Index::sa_code_size () const
|
||||
{
|
||||
ABORT ("standalone codec not implemented for this type of index");
|
||||
}
|
||||
|
||||
void Index::sa_encode (idx_t, const float *,
|
||||
uint8_t *) const
|
||||
{
|
||||
ABORT ("standalone codec not implemented for this type of index");
|
||||
}
|
||||
|
||||
void Index::sa_decode (idx_t, const uint8_t *,
|
||||
float *) const
|
||||
{
|
||||
ABORT ("standalone codec not implemented for this type of index");
|
||||
}
|
||||
|
||||
}
|
177
src/3rd_party/faiss/Index.h
vendored
177
src/3rd_party/faiss/Index.h
vendored
@ -39,11 +39,6 @@
|
||||
|
||||
namespace faiss {
|
||||
|
||||
/// Forward declarations see AuxIndexStructures.h
|
||||
struct IDSelector;
|
||||
struct RangeSearchResult;
|
||||
struct DistanceComputer;
|
||||
|
||||
/** Abstract structure for an index, supports adding vectors and searching them.
|
||||
*
|
||||
* All vectors provided at add or search time are 32-bit float arrays,
|
||||
@ -53,178 +48,6 @@ struct Index {
|
||||
using idx_t = int64_t; ///< all indices are this type
|
||||
using component_t = float;
|
||||
using distance_t = float;
|
||||
|
||||
int d; ///< vector dimension
|
||||
idx_t ntotal; ///< total nb of indexed vectors
|
||||
bool verbose; ///< verbosity level
|
||||
|
||||
/// set if the Index does not require training, or if training is
|
||||
/// done already
|
||||
bool is_trained;
|
||||
|
||||
/// type of metric this index uses for search
|
||||
MetricType metric_type;
|
||||
float metric_arg; ///< argument of the metric type
|
||||
|
||||
explicit Index (idx_t d = 0, MetricType metric = METRIC_L2):
|
||||
d((int)d),
|
||||
ntotal(0),
|
||||
verbose(false),
|
||||
is_trained(true),
|
||||
metric_type (metric),
|
||||
metric_arg(0) {}
|
||||
|
||||
virtual ~Index ();
|
||||
|
||||
|
||||
/** Perform training on a representative set of vectors
|
||||
*
|
||||
* @param n nb of training vectors
|
||||
* @param x training vecors, size n * d
|
||||
*/
|
||||
virtual void train(idx_t n, const float* x);
|
||||
|
||||
/** Add n vectors of dimension d to the index.
|
||||
*
|
||||
* Vectors are implicitly assigned labels ntotal .. ntotal + n - 1
|
||||
* This function slices the input vectors in chuncks smaller than
|
||||
* blocksize_add and calls add_core.
|
||||
* @param x input matrix, size n * d
|
||||
*/
|
||||
virtual void add (idx_t n, const float *x) = 0;
|
||||
|
||||
/** Same as add, but stores xids instead of sequential ids.
|
||||
*
|
||||
* The default implementation fails with an assertion, as it is
|
||||
* not supported by all indexes.
|
||||
*
|
||||
* @param xids if non-null, ids to store for the vectors (size n)
|
||||
*/
|
||||
virtual void add_with_ids (idx_t n, const float * x, const idx_t *xids);
|
||||
|
||||
/** query n vectors of dimension d to the index.
|
||||
*
|
||||
* return at most k vectors. If there are not enough results for a
|
||||
* query, the result array is padded with -1s.
|
||||
*
|
||||
* @param x input vectors to search, size n * d
|
||||
* @param labels output labels of the NNs, size n*k
|
||||
* @param distances output pairwise distances, size n*k
|
||||
*/
|
||||
virtual void search (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels) const = 0;
|
||||
|
||||
/** query n vectors of dimension d to the index.
|
||||
*
|
||||
* return all vectors with distance < radius. Note that many
|
||||
* indexes do not implement the range_search (only the k-NN search
|
||||
* is mandatory).
|
||||
*
|
||||
* @param x input vectors to search, size n * d
|
||||
* @param radius search radius
|
||||
* @param result result table
|
||||
*/
|
||||
virtual void range_search (idx_t n, const float *x, float radius,
|
||||
RangeSearchResult *result) const;
|
||||
|
||||
/** return the indexes of the k vectors closest to the query x.
|
||||
*
|
||||
* This function is identical as search but only return labels of neighbors.
|
||||
* @param x input vectors to search, size n * d
|
||||
* @param labels output labels of the NNs, size n*k
|
||||
*/
|
||||
void assign (idx_t n, const float * x, idx_t * labels, idx_t k = 1);
|
||||
|
||||
/// removes all elements from the database.
|
||||
virtual void reset() = 0;
|
||||
|
||||
/** removes IDs from the index. Not supported by all
|
||||
* indexes. Returns the number of elements removed.
|
||||
*/
|
||||
virtual size_t remove_ids (const IDSelector & sel);
|
||||
|
||||
/** Reconstruct a stored vector (or an approximation if lossy coding)
|
||||
*
|
||||
* this function may not be defined for some indexes
|
||||
* @param key id of the vector to reconstruct
|
||||
* @param recons reconstucted vector (size d)
|
||||
*/
|
||||
virtual void reconstruct (idx_t key, float * recons) const;
|
||||
|
||||
/** Reconstruct vectors i0 to i0 + ni - 1
|
||||
*
|
||||
* this function may not be defined for some indexes
|
||||
* @param recons reconstucted vector (size ni * d)
|
||||
*/
|
||||
virtual void reconstruct_n (idx_t i0, idx_t ni, float *recons) const;
|
||||
|
||||
/** Similar to search, but also reconstructs the stored vectors (or an
|
||||
* approximation in the case of lossy coding) for the search results.
|
||||
*
|
||||
* If there are not enough results for a query, the resulting arrays
|
||||
* is padded with -1s.
|
||||
*
|
||||
* @param recons reconstructed vectors size (n, k, d)
|
||||
**/
|
||||
virtual void search_and_reconstruct (idx_t n, const float *x, idx_t k,
|
||||
float *distances, idx_t *labels,
|
||||
float *recons) const;
|
||||
|
||||
/** Computes a residual vector after indexing encoding.
|
||||
*
|
||||
* The residual vector is the difference between a vector and the
|
||||
* reconstruction that can be decoded from its representation in
|
||||
* the index. The residual can be used for multiple-stage indexing
|
||||
* methods, like IndexIVF's methods.
|
||||
*
|
||||
* @param x input vector, size d
|
||||
* @param residual output residual vector, size d
|
||||
* @param key encoded index, as returned by search and assign
|
||||
*/
|
||||
virtual void compute_residual (const float * x,
|
||||
float * residual, idx_t key) const;
|
||||
|
||||
/** Computes a residual vector after indexing encoding (batch form).
|
||||
* Equivalent to calling compute_residual for each vector.
|
||||
*
|
||||
* The residual vector is the difference between a vector and the
|
||||
* reconstruction that can be decoded from its representation in
|
||||
* the index. The residual can be used for multiple-stage indexing
|
||||
* methods, like IndexIVF's methods.
|
||||
*
|
||||
* @param n number of vectors
|
||||
* @param xs input vectors, size (n x d)
|
||||
* @param residuals output residual vectors, size (n x d)
|
||||
* @param keys encoded index, as returned by search and assign
|
||||
*/
|
||||
virtual void compute_residual_n (idx_t n, const float* xs,
|
||||
float* residuals,
|
||||
const idx_t* keys) const;
|
||||
|
||||
/* The standalone codec interface */
|
||||
|
||||
/** size of the produced codes in bytes */
|
||||
virtual size_t sa_code_size () const;
|
||||
|
||||
/** encode a set of vectors
|
||||
*
|
||||
* @param n number of vectors
|
||||
* @param x input vectors, size n * d
|
||||
* @param bytes output encoded vectors, size n * sa_code_size()
|
||||
*/
|
||||
virtual void sa_encode (idx_t n, const float *x,
|
||||
uint8_t *bytes) const;
|
||||
|
||||
/** encode a set of vectors
|
||||
*
|
||||
* @param n number of vectors
|
||||
* @param bytes input encoded vectors, size n * sa_code_size()
|
||||
* @param x output vectors, size n * d
|
||||
*/
|
||||
virtual void sa_decode (idx_t n, const uint8_t *bytes,
|
||||
float *x) const;
|
||||
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
224
src/3rd_party/faiss/IndexLSH.cpp
vendored
224
src/3rd_party/faiss/IndexLSH.cpp
vendored
@ -1,224 +0,0 @@
|
||||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// -*- c++ -*-
|
||||
|
||||
#include <faiss/IndexLSH.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <faiss/utils/hamming.h>
|
||||
#include "common/logging.h"
|
||||
|
||||
|
||||
namespace faiss {
|
||||
|
||||
/***************************************************************
|
||||
* IndexLSH
|
||||
***************************************************************/
|
||||
|
||||
|
||||
IndexLSH::IndexLSH (idx_t d, int nbits, bool rotate_data, bool train_thresholds):
|
||||
Index(d), nbits(nbits), rotate_data(rotate_data),
|
||||
train_thresholds (train_thresholds), rrot(d, nbits)
|
||||
{
|
||||
is_trained = !train_thresholds;
|
||||
|
||||
bytes_per_vec = (nbits + 7) / 8;
|
||||
|
||||
if (rotate_data) {
|
||||
rrot.init(5);
|
||||
} else {
|
||||
ABORT_UNLESS(d >= nbits, "d >= nbits");
|
||||
}
|
||||
}
|
||||
|
||||
IndexLSH::IndexLSH ():
|
||||
nbits (0), bytes_per_vec(0), rotate_data (false), train_thresholds (false)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
const float * IndexLSH::apply_preprocess (idx_t n, const float *x) const
|
||||
{
|
||||
|
||||
float *xt = nullptr;
|
||||
if (rotate_data) {
|
||||
// also applies bias if exists
|
||||
xt = rrot.apply (n, x);
|
||||
} else if (d != nbits) {
|
||||
assert (nbits < d);
|
||||
xt = new float [nbits * n];
|
||||
float *xp = xt;
|
||||
for (idx_t i = 0; i < n; i++) {
|
||||
const float *xl = x + i * d;
|
||||
for (int j = 0; j < nbits; j++)
|
||||
*xp++ = xl [j];
|
||||
}
|
||||
}
|
||||
|
||||
if (train_thresholds) {
|
||||
|
||||
if (xt == NULL) {
|
||||
xt = new float [nbits * n];
|
||||
memcpy (xt, x, sizeof(*x) * n * nbits);
|
||||
}
|
||||
|
||||
float *xp = xt;
|
||||
for (idx_t i = 0; i < n; i++)
|
||||
for (int j = 0; j < nbits; j++)
|
||||
*xp++ -= thresholds [j];
|
||||
}
|
||||
|
||||
return xt ? xt : x;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void IndexLSH::train (idx_t n, const float *x)
|
||||
{
|
||||
if (train_thresholds) {
|
||||
thresholds.resize (nbits);
|
||||
train_thresholds = false;
|
||||
const float *xt = apply_preprocess (n, x);
|
||||
ScopeDeleter<float> del (xt == x ? nullptr : xt);
|
||||
train_thresholds = true;
|
||||
|
||||
float * transposed_x = new float [n * nbits];
|
||||
ScopeDeleter<float> del2 (transposed_x);
|
||||
|
||||
for (idx_t i = 0; i < n; i++)
|
||||
for (idx_t j = 0; j < nbits; j++)
|
||||
transposed_x [j * n + i] = xt [i * nbits + j];
|
||||
|
||||
for (idx_t i = 0; i < nbits; i++) {
|
||||
float *xi = transposed_x + i * n;
|
||||
// std::nth_element
|
||||
std::sort (xi, xi + n);
|
||||
if (n % 2 == 1)
|
||||
thresholds [i] = xi [n / 2];
|
||||
else
|
||||
thresholds [i] = (xi [n / 2 - 1] + xi [n / 2]) / 2;
|
||||
|
||||
}
|
||||
}
|
||||
is_trained = true;
|
||||
}
|
||||
|
||||
|
||||
void IndexLSH::add (idx_t n, const float *x)
|
||||
{
|
||||
ABORT_UNLESS (is_trained, "is_trained");
|
||||
codes.resize ((ntotal + n) * bytes_per_vec);
|
||||
|
||||
sa_encode (n, x, &codes[ntotal * bytes_per_vec]);
|
||||
|
||||
ntotal += n;
|
||||
}
|
||||
|
||||
|
||||
void IndexLSH::search (
|
||||
idx_t n,
|
||||
const float *x,
|
||||
idx_t k,
|
||||
float *distances,
|
||||
idx_t *labels) const
|
||||
{
|
||||
ABORT_UNLESS (is_trained, "is_trained");
|
||||
const float *xt = apply_preprocess (n, x);
|
||||
ScopeDeleter<float> del (xt == x ? nullptr : xt);
|
||||
|
||||
uint8_t * qcodes = new uint8_t [n * bytes_per_vec];
|
||||
ScopeDeleter<uint8_t> del2 (qcodes);
|
||||
|
||||
fvecs2bitvecs (xt, qcodes, nbits, n);
|
||||
|
||||
int * idistances = new int [n * k];
|
||||
ScopeDeleter<int> del3 (idistances);
|
||||
|
||||
int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances};
|
||||
|
||||
hammings_knn_hc (&res, qcodes, codes.data(),
|
||||
ntotal, bytes_per_vec, true);
|
||||
|
||||
|
||||
// convert distances to floats
|
||||
for (int i = 0; i < k * n; i++)
|
||||
distances[i] = idistances[i];
|
||||
|
||||
}
|
||||
|
||||
|
||||
void IndexLSH::transfer_thresholds (LinearTransform *vt) {
|
||||
if (!train_thresholds) return;
|
||||
ABORT_UNLESS (nbits == vt->d_out, "nbits == vt->d_out");
|
||||
if (!vt->have_bias) {
|
||||
vt->b.resize (nbits, 0);
|
||||
vt->have_bias = true;
|
||||
}
|
||||
for (int i = 0; i < nbits; i++)
|
||||
vt->b[i] -= thresholds[i];
|
||||
train_thresholds = false;
|
||||
thresholds.clear();
|
||||
}
|
||||
|
||||
void IndexLSH::reset() {
|
||||
codes.clear();
|
||||
ntotal = 0;
|
||||
}
|
||||
|
||||
|
||||
size_t IndexLSH::sa_code_size () const
|
||||
{
|
||||
return bytes_per_vec;
|
||||
}
|
||||
|
||||
void IndexLSH::sa_encode (idx_t n, const float *x,
|
||||
uint8_t *bytes) const
|
||||
{
|
||||
ABORT_UNLESS (is_trained, "is_trained");
|
||||
const float *xt = apply_preprocess (n, x);
|
||||
ScopeDeleter<float> del (xt == x ? nullptr : xt);
|
||||
fvecs2bitvecs (xt, bytes, nbits, n);
|
||||
}
|
||||
|
||||
void IndexLSH::sa_decode (idx_t n, const uint8_t *bytes,
|
||||
float *x) const
|
||||
{
|
||||
float *xt = x;
|
||||
ScopeDeleter<float> del;
|
||||
if (rotate_data || nbits != d) {
|
||||
xt = new float [n * nbits];
|
||||
del.set(xt);
|
||||
}
|
||||
bitvecs2fvecs (bytes, xt, nbits, n);
|
||||
|
||||
if (train_thresholds) {
|
||||
float *xp = xt;
|
||||
for (idx_t i = 0; i < n; i++) {
|
||||
for (int j = 0; j < nbits; j++) {
|
||||
*xp++ += thresholds [j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (rotate_data) {
|
||||
rrot.reverse_transform (n, xt, x);
|
||||
} else if (nbits != d) {
|
||||
for (idx_t i = 0; i < n; i++) {
|
||||
memcpy (x + i * d, xt + i * nbits,
|
||||
nbits * sizeof(xt[0]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
} // namespace faiss
|
90
src/3rd_party/faiss/IndexLSH.h
vendored
90
src/3rd_party/faiss/IndexLSH.h
vendored
@ -1,90 +0,0 @@
|
||||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// -*- c++ -*-
|
||||
|
||||
#ifndef INDEX_LSH_H
|
||||
#define INDEX_LSH_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <faiss/Index.h>
|
||||
#include <faiss/VectorTransform.h>
|
||||
|
||||
namespace faiss {
|
||||
|
||||
|
||||
/** The sign of each vector component is put in a binary signature */
|
||||
struct IndexLSH:Index {
|
||||
typedef unsigned char uint8_t;
|
||||
|
||||
int nbits; ///< nb of bits per vector
|
||||
int bytes_per_vec; ///< nb of 8-bits per encoded vector
|
||||
bool rotate_data; ///< whether to apply a random rotation to input
|
||||
bool train_thresholds; ///< whether we train thresholds or use 0
|
||||
|
||||
RandomRotationMatrix rrot; ///< optional random rotation
|
||||
|
||||
std::vector <float> thresholds; ///< thresholds to compare with
|
||||
|
||||
/// encoded dataset
|
||||
std::vector<uint8_t> codes;
|
||||
|
||||
IndexLSH (
|
||||
idx_t d, int nbits,
|
||||
bool rotate_data = true,
|
||||
bool train_thresholds = false);
|
||||
|
||||
/** Preprocesses and resizes the input to the size required to
|
||||
* binarize the data
|
||||
*
|
||||
* @param x input vectors, size n * d
|
||||
* @return output vectors, size n * bits. May be the same pointer
|
||||
* as x, otherwise it should be deleted by the caller
|
||||
*/
|
||||
const float *apply_preprocess (idx_t n, const float *x) const;
|
||||
|
||||
void train(idx_t n, const float* x) override;
|
||||
|
||||
void add(idx_t n, const float* x) override;
|
||||
|
||||
void search(
|
||||
idx_t n,
|
||||
const float* x,
|
||||
idx_t k,
|
||||
float* distances,
|
||||
idx_t* labels) const override;
|
||||
|
||||
void reset() override;
|
||||
|
||||
/// transfer the thresholds to a pre-processing stage (and unset
|
||||
/// train_thresholds)
|
||||
void transfer_thresholds (LinearTransform * vt);
|
||||
|
||||
~IndexLSH() override {}
|
||||
|
||||
IndexLSH ();
|
||||
|
||||
/* standalone codec interface.
|
||||
*
|
||||
* The vectors are decoded to +/- 1 (not 0, 1) */
|
||||
|
||||
size_t sa_code_size () const override;
|
||||
|
||||
void sa_encode (idx_t n, const float *x,
|
||||
uint8_t *bytes) const override;
|
||||
|
||||
void sa_decode (idx_t n, const uint8_t *bytes,
|
||||
float *x) const override;
|
||||
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
#endif
|
10
src/3rd_party/faiss/utils/hamming-inl.h
vendored
10
src/3rd_party/faiss/utils/hamming-inl.h
vendored
@ -10,8 +10,8 @@
|
||||
namespace faiss {
|
||||
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define bzero(p,n) (memset((p),0,(n)))
|
||||
#ifdef _MSC_VER
|
||||
#define bzero(p,n) (memset((p),0,(n)))
|
||||
#endif
|
||||
inline BitstringWriter::BitstringWriter(uint8_t *code, int code_size):
|
||||
code (code), code_size (code_size), i(0)
|
||||
@ -29,7 +29,7 @@ inline void BitstringWriter::write(uint64_t x, int nbit) {
|
||||
i += nbit;
|
||||
return;
|
||||
} else {
|
||||
int j = i >> 3;
|
||||
size_t j = i >> 3;
|
||||
code[j++] |= x << (i & 7);
|
||||
i += nbit;
|
||||
x >>= na;
|
||||
@ -57,7 +57,7 @@ inline uint64_t BitstringReader::read(int nbit) {
|
||||
return res;
|
||||
} else {
|
||||
int ofs = na;
|
||||
int j = (i >> 3) + 1;
|
||||
size_t j = (i >> 3) + 1;
|
||||
i += nbit;
|
||||
nbit -= na;
|
||||
while (nbit > 8) {
|
||||
@ -160,7 +160,7 @@ struct HammingComputer20 {
|
||||
void set (const uint8_t *a8, int code_size) {
|
||||
assert (code_size == 20);
|
||||
const uint64_t *a = (uint64_t *)a8;
|
||||
a0 = a[0]; a1 = a[1]; a2 = a[2];
|
||||
a0 = a[0]; a1 = a[1]; a2 = (uint32_t)a[2];
|
||||
}
|
||||
|
||||
inline int hamming (const uint8_t *b8) const {
|
||||
|
4
src/3rd_party/faiss/utils/hamming.h
vendored
4
src/3rd_party/faiss/utils/hamming.h
vendored
@ -31,7 +31,7 @@
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <intrin.h> // needed for some intrinsics in <memory>
|
||||
#define __builtin_popcountl __popcnt64
|
||||
#define __builtin_popcountl __popcnt64
|
||||
#endif
|
||||
|
||||
/* The Hamming distance type */
|
||||
@ -116,7 +116,7 @@ struct BitstringReader {
|
||||
extern size_t hamming_batch_size;
|
||||
|
||||
static inline int popcount64(uint64_t x) {
|
||||
return __builtin_popcountl(x);
|
||||
return (int)__builtin_popcountl(x);
|
||||
}
|
||||
|
||||
|
||||
|
@ -75,6 +75,7 @@ set(MARIAN_SOURCES
|
||||
layers/embedding.cpp
|
||||
layers/output.cpp
|
||||
layers/logits.cpp
|
||||
layers/lsh.cpp
|
||||
|
||||
rnn/cells.cpp
|
||||
rnn/attention.cpp
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include "common/cli_wrapper.h"
|
||||
#include "tensors/cpu/expression_graph_packable.h"
|
||||
#include "onnx/expression_graph_onnx_exporter.h"
|
||||
#include "layers/lsh.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
@ -25,6 +26,9 @@ int main(int argc, char** argv) {
|
||||
cli->add<std::string>("--gemm-type,-g", "GEMM Type to be used: float32, packed16, packed8avx2, packed8avx512, "
|
||||
"intgemm8, intgemm8ssse3, intgemm8avx2, intgemm8avx512, intgemm16, intgemm16sse2, intgemm16avx2, intgemm16avx512",
|
||||
"float32");
|
||||
cli->add<std::vector<std::string>>("--add-lsh",
|
||||
"Encode output matrix and optional rotation matrix into model file. "
|
||||
"arg1: number of bits in LSH encoding, arg2: name of output weights matrix")->implicit_val("1024 Wemb");
|
||||
cli->add<std::vector<std::string>>("--vocabs,-V", "Vocabulary file, required for ONNX export");
|
||||
cli->parse(argc, argv);
|
||||
options->merge(config);
|
||||
@ -34,6 +38,16 @@ int main(int argc, char** argv) {
|
||||
|
||||
auto exportAs = options->get<std::string>("export-as");
|
||||
auto vocabPaths = options->get<std::vector<std::string>>("vocabs");// , std::vector<std::string>());
|
||||
|
||||
bool addLsh = options->hasAndNotEmpty("add-lsh");
|
||||
int lshNBits = 1024;
|
||||
std::string lshOutputWeights = "Wemb";
|
||||
if(addLsh) {
|
||||
auto lshParams = options->get<std::vector<std::string>>("add-lsh");
|
||||
lshNBits = std::stoi(lshParams[0]);
|
||||
if(lshParams.size() > 1)
|
||||
lshOutputWeights = lshParams[1];
|
||||
}
|
||||
|
||||
// We accept any type here and will later croak during packAndSave if the type cannot be used for conversion
|
||||
Type saveGemmType = typeFromString(options->get<std::string>("gemm-type", "float32"));
|
||||
@ -45,23 +59,36 @@ int main(int argc, char** argv) {
|
||||
marian::io::getYamlFromModel(config, "special:model.yml", modelFrom);
|
||||
configStr << config;
|
||||
|
||||
auto load = [&](Ptr<ExpressionGraph> graph) {
|
||||
graph->setDevice(CPU0);
|
||||
graph->load(modelFrom);
|
||||
graph->forward(); // run the initializers
|
||||
};
|
||||
|
||||
|
||||
if (exportAs == "marian-bin") {
|
||||
auto graph = New<ExpressionGraphPackable>();
|
||||
load(graph);
|
||||
graph->setDevice(CPU0);
|
||||
graph->load(modelFrom);
|
||||
|
||||
if(addLsh) {
|
||||
// Add dummy parameters for the LSH before the model gets actually initialized.
|
||||
// This create the parameters with useless values in the tensors, but it gives us the memory we need.
|
||||
graph->setReloaded(false);
|
||||
lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
|
||||
graph->setReloaded(true);
|
||||
}
|
||||
|
||||
graph->forward(); // run the initializers
|
||||
|
||||
if(addLsh) {
|
||||
// After initialization, hijack the paramters for the LSH and force-overwrite with correct values.
|
||||
// Once this is done we can just pack and save as normal.
|
||||
lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
|
||||
}
|
||||
|
||||
// added a flag if the weights needs to be packed or not
|
||||
graph->packAndSave(modelTo, configStr.str(), /* --gemm-type */ saveGemmType, Type::float32);
|
||||
}
|
||||
else if (exportAs == "onnx-encode") {
|
||||
#ifdef USE_ONNX
|
||||
auto graph = New<ExpressionGraphONNXExporter>();
|
||||
load(graph);
|
||||
graph->setDevice(CPU0);
|
||||
graph->load(modelFrom);
|
||||
graph->forward(); // run the initializers
|
||||
auto modelOptions = New<Options>(config)->with("vocabs", vocabPaths, "inference", true);
|
||||
|
||||
graph->exportToONNX(modelTo, modelOptions, vocabPaths);
|
||||
|
@ -1,10 +1,7 @@
|
||||
#include "data/shortlist.h"
|
||||
#include "microsoft/shortlist/utils/ParameterTree.h"
|
||||
#include "marian.h"
|
||||
|
||||
#if BLAS_FOUND
|
||||
#include "3rd_party/faiss/IndexLSH.h"
|
||||
#endif
|
||||
#include "layers/lsh.h"
|
||||
|
||||
namespace marian {
|
||||
namespace data {
|
||||
@ -47,7 +44,6 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
|
||||
Shape kShape({k});
|
||||
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
|
||||
|
||||
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
|
||||
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k);
|
||||
initialized_ = true;
|
||||
}
|
||||
@ -78,12 +74,10 @@ void Shortlist::createCachedTensors(Expr weights,
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
Ptr<faiss::IndexLSH> LSHShortlist::index_;
|
||||
std::mutex LSHShortlist::mutex_;
|
||||
|
||||
LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize)
|
||||
: Shortlist(std::vector<WordIndex>())
|
||||
, k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
|
||||
: Shortlist(std::vector<WordIndex>()),
|
||||
k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
|
||||
}
|
||||
|
||||
WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
|
||||
@ -99,67 +93,23 @@ Expr LSHShortlist::getIndicesExpr() const {
|
||||
}
|
||||
|
||||
void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
|
||||
#if BLAS_FOUND
|
||||
|
||||
ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
|
||||
"LSH index (--output-approx-knn) currently not implemented for GPU");
|
||||
|
||||
int currBeamSize = input->shape()[0];
|
||||
int batchSize = input->shape()[2];
|
||||
int numHypos = currBeamSize * batchSize;
|
||||
|
||||
auto forward = [this, numHypos](Expr out, const std::vector<Expr>& inputs) {
|
||||
auto query = inputs[0];
|
||||
auto values = inputs[1];
|
||||
int dim = values->shape()[-1];
|
||||
|
||||
mutex_.lock();
|
||||
if(!index_) {
|
||||
LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
|
||||
index_.reset(new faiss::IndexLSH(dim, nbits_,
|
||||
/*rotate=*/dim != nbits_,
|
||||
/*train_thesholds*/false));
|
||||
index_->train(lemmaSize_, values->val()->data<float>());
|
||||
index_->add( lemmaSize_, values->val()->data<float>());
|
||||
}
|
||||
mutex_.unlock();
|
||||
|
||||
int qRows = query->shape().elements() / dim;
|
||||
std::vector<float> distances(qRows * k_);
|
||||
std::vector<faiss::Index::idx_t> ids(qRows * k_);
|
||||
|
||||
index_->search(qRows, query->val()->data<float>(), k_,
|
||||
distances.data(), ids.data());
|
||||
|
||||
indices_.clear();
|
||||
for(auto iter = ids.begin(); iter != ids.end(); ++iter) {
|
||||
faiss::Index::idx_t id = *iter;
|
||||
indices_.push_back((WordIndex)id);
|
||||
}
|
||||
|
||||
for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
|
||||
size_t startIdx = k_ * hypoIdx;
|
||||
size_t endIdx = startIdx + k_;
|
||||
std::sort(indices_.begin() + startIdx, indices_.begin() + endIdx);
|
||||
}
|
||||
out->val()->set(indices_);
|
||||
};
|
||||
|
||||
Shape kShape({currBeamSize, batchSize, k_});
|
||||
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
|
||||
indicesExpr_ = callback(lsh::search(input, weights, k_, nbits_, (int)lemmaSize_),
|
||||
[this](Expr node) {
|
||||
node->val()->get(indices_); // set the value of the field indices_ whenever the graph traverses this node
|
||||
});
|
||||
|
||||
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k_);
|
||||
|
||||
#else
|
||||
input; weights; isLegacyUntransposedW; b; lemmaEt;
|
||||
ABORT("LSH output layer requires a CPU BLAS library");
|
||||
#endif
|
||||
}
|
||||
|
||||
void LSHShortlist::createCachedTensors(Expr weights,
|
||||
bool isLegacyUntransposedW,
|
||||
Expr b,
|
||||
Expr lemmaEt,
|
||||
int k) {
|
||||
bool isLegacyUntransposedW,
|
||||
Expr b,
|
||||
Expr lemmaEt,
|
||||
int k) {
|
||||
int currBeamSize = indicesExpr_->shape()[0];
|
||||
int batchSize = indicesExpr_->shape()[1];
|
||||
ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");
|
||||
|
@ -25,7 +25,8 @@ namespace data {
|
||||
class Shortlist {
|
||||
protected:
|
||||
std::vector<WordIndex> indices_; // // [packed shortlist index] -> word index, used to select columns from output embeddings
|
||||
Expr indicesExpr_;
|
||||
Expr indicesExpr_; // cache an expression that contains the short list indices
|
||||
|
||||
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
|
||||
Expr cachedShortb_; // these match the current value of shortlist_
|
||||
Expr cachedShortLemmaEt_;
|
||||
|
@ -646,6 +646,16 @@ public:
|
||||
return it->second;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the Parameters object related to the graph by elementType.
|
||||
* The Parameters object holds the whole set of the parameter nodes of the given type.
|
||||
*/
|
||||
Ptr<Parameters>& params(Type elementType) {
|
||||
auto it = paramsByElementType_.find(elementType);
|
||||
ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_);
|
||||
return it->second;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set default element type for the graph.
|
||||
* The default value is used if some node type is not specified.
|
||||
|
@ -28,13 +28,17 @@ Expr checkpoint(Expr a) {
|
||||
}
|
||||
|
||||
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
|
||||
LambdaNodeFunctor fwd) {
|
||||
return Expression<LambdaNodeOp>(nodes, shape, type, fwd);
|
||||
LambdaNodeFunctor fwd, size_t hash) {
|
||||
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
|
||||
}
|
||||
|
||||
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
|
||||
LambdaNodeFunctor fwd, LambdaNodeFunctor bwd) {
|
||||
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd);
|
||||
LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash) {
|
||||
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd, hash);
|
||||
}
|
||||
|
||||
Expr callback(Expr node, LambdaNodeCallback call) {
|
||||
return Expression<CallbackNodeOp>(node, call);
|
||||
}
|
||||
|
||||
// logistic function. Note: scipy name is expit()
|
||||
|
@ -26,12 +26,19 @@ typedef std::function<void(Expr out, const std::vector<Expr>& in)> LambdaNodeFun
|
||||
/**
|
||||
* Arbitrary node with forward operation only.
|
||||
*/
|
||||
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd);
|
||||
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, size_t hash = 0);
|
||||
|
||||
/**
|
||||
* Arbitrary node with forward and backward operation.
|
||||
*/
|
||||
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, LambdaNodeFunctor bwd);
|
||||
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash = 0);
|
||||
|
||||
|
||||
/**
|
||||
* Convience typedef for graph @ref lambda expressions.
|
||||
*/
|
||||
typedef std::function<void(Expr)> LambdaNodeCallback;
|
||||
Expr callback(Expr node, LambdaNodeCallback call);
|
||||
|
||||
/**
|
||||
* @addtogroup graph_ops_activation Activation Functions
|
||||
|
@ -11,6 +11,15 @@ namespace marian {
|
||||
|
||||
namespace inits {
|
||||
|
||||
class DummyInit : public NodeInitializer {
|
||||
public:
|
||||
void apply(Tensor tensor) override {
|
||||
tensor;
|
||||
}
|
||||
};
|
||||
|
||||
Ptr<NodeInitializer> dummy() { return New<DummyInit>(); }
|
||||
|
||||
class LambdaInit : public NodeInitializer {
|
||||
private:
|
||||
std::function<void(Tensor)> lambda_;
|
||||
@ -237,24 +246,3 @@ template Ptr<NodeInitializer> range<IndexType>(IndexType begin, IndexType end, I
|
||||
} // namespace inits
|
||||
} // namespace marian
|
||||
|
||||
|
||||
#if BLAS_FOUND
|
||||
#include "faiss/VectorTransform.h"
|
||||
|
||||
namespace marian {
|
||||
namespace inits {
|
||||
|
||||
Ptr<NodeInitializer> randomRotation(size_t seed) {
|
||||
auto rot = [=](Tensor t) {
|
||||
int rows = t->shape()[-2];
|
||||
int cols = t->shape()[-1];
|
||||
faiss::RandomRotationMatrix rrot(cols, rows); // transposed in faiss
|
||||
rrot.init((int)seed);
|
||||
t->set(rrot.A);
|
||||
};
|
||||
return fromLambda(rot, Type::float32);
|
||||
}
|
||||
|
||||
} // namespace inits
|
||||
} // namespace marian
|
||||
#endif
|
||||
|
@ -35,6 +35,11 @@ public:
|
||||
virtual ~NodeInitializer() {}
|
||||
};
|
||||
|
||||
/**
|
||||
* Dummy do-nothing initializer. Mostly for testing.
|
||||
*/
|
||||
Ptr<NodeInitializer> dummy();
|
||||
|
||||
/**
|
||||
* Use a lambda function of form [](Tensor t) { do something with t } to initialize tensor.
|
||||
* @param func functor
|
||||
@ -263,13 +268,6 @@ Ptr<NodeInitializer> fromWord2vec(const std::string& file,
|
||||
*/
|
||||
Ptr<NodeInitializer> sinusoidalPositionEmbeddings(int start);
|
||||
|
||||
/**
|
||||
* Computes a random rotation matrix for LSH hashing.
|
||||
* This is part of a hash function. The values are orthonormal and computed via
|
||||
* QR decomposition. Same seed results in same random rotation.
|
||||
*/
|
||||
Ptr<NodeInitializer> randomRotation(size_t seed = Config::seed);
|
||||
|
||||
/**
|
||||
* Computes the equivalent of Python's range().
|
||||
* Computes a range from begin to end-1, like Python's range().
|
||||
|
@ -21,20 +21,26 @@ private:
|
||||
std::unique_ptr<LambdaNodeFunctor> forward_;
|
||||
std::unique_ptr<LambdaNodeFunctor> backward_;
|
||||
|
||||
size_t externalHash_;
|
||||
|
||||
public:
|
||||
LambdaNodeOp(Inputs inputs, Shape shape, Type type,
|
||||
LambdaNodeFunctor forward)
|
||||
LambdaNodeFunctor forward,
|
||||
size_t externalHash = 0)
|
||||
: NaryNodeOp(inputs, shape, type),
|
||||
forward_(new LambdaNodeFunctor(forward)) {
|
||||
forward_(new LambdaNodeFunctor(forward)),
|
||||
externalHash_(externalHash) {
|
||||
Node::trainable_ = !!backward_;
|
||||
}
|
||||
|
||||
LambdaNodeOp(Inputs inputs, Shape shape, Type type,
|
||||
LambdaNodeFunctor forward,
|
||||
LambdaNodeFunctor backward)
|
||||
LambdaNodeFunctor backward,
|
||||
size_t externalHash = 0)
|
||||
: NaryNodeOp(inputs, shape, type),
|
||||
forward_(new LambdaNodeFunctor(forward)),
|
||||
backward_(new LambdaNodeFunctor(backward)) {
|
||||
backward_(new LambdaNodeFunctor(backward)),
|
||||
externalHash_(externalHash) {
|
||||
}
|
||||
|
||||
void forward() override {
|
||||
@ -50,8 +56,12 @@ public:
|
||||
|
||||
virtual size_t hash() override {
|
||||
size_t seed = NaryNodeOp::hash();
|
||||
util::hash_combine(seed, forward_.get());
|
||||
util::hash_combine(seed, backward_.get());
|
||||
if(externalHash_ != 0) {
|
||||
util::hash_combine(seed, externalHash_);
|
||||
} else {
|
||||
util::hash_combine(seed, forward_.get());
|
||||
util::hash_combine(seed, backward_.get());
|
||||
}
|
||||
return seed;
|
||||
}
|
||||
|
||||
|
@ -795,7 +795,7 @@ private:
|
||||
};
|
||||
|
||||
class ReshapeNodeOp : public UnaryNodeOp {
|
||||
private:
|
||||
protected:
|
||||
friend class SerializationHelpers;
|
||||
Expr reshapee_;
|
||||
|
||||
@ -858,6 +858,45 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// @TODO: add version with access to backward step
|
||||
// This allows to attach a lambda function to any node during the execution. It is a non-operation otherwise
|
||||
// i.e. doesn't consume any memory or take any time to execute (it's a reshape onto itself) other than the
|
||||
// compute in the lambda function. It gets called after the forward step of the argument node.
|
||||
class CallbackNodeOp : public ReshapeNodeOp {
|
||||
private:
|
||||
typedef std::function<void(Expr)> LambdaNodeCallback;
|
||||
std::unique_ptr<LambdaNodeCallback> callback_;
|
||||
|
||||
public:
|
||||
CallbackNodeOp(Expr node, LambdaNodeCallback callback)
|
||||
: ReshapeNodeOp(node, node->shape()),
|
||||
callback_(new LambdaNodeCallback(callback)) {
|
||||
}
|
||||
|
||||
void forward() override {
|
||||
(*callback_)(ReshapeNodeOp::reshapee_);
|
||||
}
|
||||
|
||||
const std::string type() override { return "callback"; }
|
||||
|
||||
virtual size_t hash() override {
|
||||
size_t seed = ReshapeNodeOp::hash();
|
||||
util::hash_combine(seed, callback_.get());
|
||||
return seed;
|
||||
}
|
||||
|
||||
virtual bool equal(Expr node) override {
|
||||
if(!ReshapeNodeOp::equal(node))
|
||||
return false;
|
||||
auto cnode = std::dynamic_pointer_cast<CallbackNodeOp>(node);
|
||||
if(!cnode)
|
||||
return false;
|
||||
if(callback_ != cnode->callback_) // pointer compare on purpose
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// @TODO: review if still required as this is an ugly hack anyway.
|
||||
// Memory less operator that clips gradients during backward step
|
||||
// Executes this as an additional operation on the gradient.
|
||||
|
233
src/layers/lsh.cpp
Normal file
233
src/layers/lsh.cpp
Normal file
@ -0,0 +1,233 @@
|
||||
#include "layers/lsh.h"
|
||||
#include "tensors/tensor_operators.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
#include "3rd_party/faiss/utils/hamming.h"
|
||||
#include "3rd_party/faiss/Index.h"
|
||||
|
||||
#if BLAS_FOUND
|
||||
#include "3rd_party/faiss/VectorTransform.h"
|
||||
#endif
|
||||
|
||||
|
||||
namespace marian {
|
||||
namespace lsh {
|
||||
|
||||
int bytesPerVector(int nBits) {
|
||||
return (nBits + 7) / 8;
|
||||
}
|
||||
|
||||
void fillRandomRotationMatrix(Tensor output, Ptr<Allocator> allocator) {
|
||||
#if BLAS_FOUND
|
||||
int nRows = output->shape()[-2];
|
||||
int nBits = output->shape()[-1];
|
||||
|
||||
// @TODO re-implement using Marian code so it uses the correct random generator etc.
|
||||
faiss::RandomRotationMatrix rrot(nRows, nBits);
|
||||
// Then we do not need to use this seed at all
|
||||
rrot.init(5); // currently set to 5 following the default from FAISS, this could be any number really.
|
||||
|
||||
// The faiss random rotation matrix is column major, hence we create a temporary tensor,
|
||||
// copy the rotation matrix into it and transpose to output.
|
||||
Shape tempShape = {nBits, nRows};
|
||||
auto memory = allocator->alloc(requiredBytes(tempShape, output->type()));
|
||||
auto temp = TensorBase::New(memory,
|
||||
tempShape,
|
||||
output->type(),
|
||||
output->getBackend());
|
||||
temp->set(rrot.A);
|
||||
TransposeND(output, temp, {0, 1, 3, 2});
|
||||
allocator->free(memory);
|
||||
#else
|
||||
output; allocator;
|
||||
ABORT("LSH with rotation matrix requires Marian to be compiled with a BLAS library");
|
||||
#endif
|
||||
}
|
||||
|
||||
void encode(Tensor output, Tensor input) {
|
||||
int nBits = input->shape()[-1]; // number of bits is equal last dimension of float matrix
|
||||
int nRows = input->shape().elements() / nBits;
|
||||
faiss::fvecs2bitvecs(input->data<float>(), output->data<uint8_t>(), (size_t)nBits, (size_t)nRows);
|
||||
}
|
||||
|
||||
void encodeWithRotation(Tensor output, Tensor input, Tensor rotation, Ptr<Allocator> allocator) {
|
||||
int nBits = input->shape()[-1]; // number of bits is equal last dimension of float matrix unless we rotate
|
||||
int nRows = input->shape().elements() / nBits;
|
||||
|
||||
Tensor tempInput = input;
|
||||
MemoryPiece::PtrType memory;
|
||||
if(rotation) {
|
||||
int nBitsRot = rotation->shape()[-1];
|
||||
Shape tempShape = {nRows, nBitsRot};
|
||||
memory = allocator->alloc(requiredBytes(tempShape, rotation->type()));
|
||||
tempInput = TensorBase::New(memory, tempShape, rotation->type(), rotation->getBackend());
|
||||
Prod(tempInput, input, rotation, false, false, 0.f, 1.f);
|
||||
}
|
||||
encode(output, tempInput);
|
||||
|
||||
if(memory)
|
||||
allocator->free(memory);
|
||||
};
|
||||
|
||||
Expr encode(Expr input, Expr rotation) {
|
||||
auto encodeFwd = [](Expr out, const std::vector<Expr>& inputs) {
|
||||
if(inputs.size() == 1) {
|
||||
encode(out->val(), inputs[0]->val());
|
||||
} else if(inputs.size() == 2) {
|
||||
encodeWithRotation(out->val(), inputs[0]->val(), inputs[1]->val(), out->graph()->allocator());
|
||||
} else {
|
||||
ABORT("Too many inputs to encode??");
|
||||
}
|
||||
};
|
||||
|
||||
// Use the address of the first lambda function as an immutable hash. Making it static and const makes sure
|
||||
// that this hash value will not change. Next pass the hash into the lambda functor were it will be used
|
||||
// to identify this unique operation. Marian's ExpressionGraph can automatically memoize and identify nodes
|
||||
// that operate only on immutable nodes (parameters) and have the same hash. This way we make sure that the
|
||||
// codes node won't actually get recomputed throughout ExpressionGraph lifetime. `codes` will be reused
|
||||
// and the body of the lambda will not be called again. This does however build one index per graph.
|
||||
static const size_t encodeHash = (size_t)&encodeFwd;
|
||||
|
||||
Shape encodedShape = input->shape();
|
||||
|
||||
int nBits = rotation ? rotation->shape()[-1] : input->shape()[-1];
|
||||
encodedShape.set(-1, bytesPerVector(nBits));
|
||||
std::vector<Expr> inputs = {input};
|
||||
if(rotation)
|
||||
inputs.push_back(rotation);
|
||||
return lambda(inputs, encodedShape, Type::uint8, encodeFwd, encodeHash);
|
||||
}
|
||||
|
||||
Expr rotator(Expr weights, int nBits) {
|
||||
auto rotator = [](Expr out, const std::vector<Expr>& inputs) {
|
||||
inputs;
|
||||
fillRandomRotationMatrix(out->val(), out->graph()->allocator());
|
||||
};
|
||||
|
||||
static const size_t rotatorHash = (size_t)&rotator;
|
||||
int dim = weights->shape()[-1];
|
||||
return lambda({weights}, {dim, nBits}, Type::float32, rotator, rotatorHash);
|
||||
}
|
||||
|
||||
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows) {
|
||||
ABORT_IF(encodedQuery->shape()[-1] != encodedWeights->shape()[-1],
|
||||
"Query and index bit vectors need to be of same size ({} != {})", encodedQuery->shape()[-1], encodedWeights->shape()[-1]);
|
||||
|
||||
int currBeamSize = encodedQuery->shape()[0];
|
||||
int batchSize = encodedQuery->shape()[2];
|
||||
int numHypos = currBeamSize * batchSize;
|
||||
|
||||
auto search = [=](Expr out, const std::vector<Expr>& inputs) {
|
||||
Expr encodedQuery = inputs[0];
|
||||
Expr encodedWeights = inputs[1];
|
||||
|
||||
int bytesPerVector = encodedWeights->shape()[-1];
|
||||
int wRows = encodedWeights->shape().elements() / bytesPerVector;
|
||||
|
||||
// we use this with Factored Segmenter to skip the factor embeddings at the end
|
||||
if(firstNRows != 0)
|
||||
wRows = firstNRows;
|
||||
|
||||
int qRows = encodedQuery->shape().elements() / bytesPerVector;
|
||||
|
||||
uint8_t* qCodes = encodedQuery->val()->data<uint8_t>();
|
||||
uint8_t* wCodes = encodedWeights->val()->data<uint8_t>();
|
||||
|
||||
// use actual faiss code for performing the hamming search.
|
||||
std::vector<int> distances(qRows * k);
|
||||
std::vector<faiss::Index::idx_t> ids(qRows * k);
|
||||
faiss::int_maxheap_array_t res = {(size_t)qRows, (size_t)k, ids.data(), distances.data()};
|
||||
faiss::hammings_knn_hc(&res, qCodes, wCodes, (size_t)wRows, (size_t)bytesPerVector, 0);
|
||||
|
||||
// Copy int64_t indices to Marian index type and sort by increasing index value per hypothesis.
|
||||
// The sorting is required as we later do a binary search on those values for reverse look-up.
|
||||
uint32_t* outData = out->val()->data<uint32_t>();
|
||||
for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
|
||||
size_t startIdx = k * hypoIdx;
|
||||
size_t endIdx = startIdx + k;
|
||||
for(size_t i = startIdx; i < endIdx; ++i)
|
||||
outData[i] = (uint32_t)ids[i];
|
||||
std::sort(outData + startIdx, outData + endIdx);
|
||||
}
|
||||
};
|
||||
|
||||
Shape kShape({currBeamSize, batchSize, k});
|
||||
return lambda({encodedQuery, encodedWeights}, kShape, Type::uint32, search);
|
||||
}
|
||||
|
||||
Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows) {
|
||||
int dim = weights->shape()[-1];
|
||||
|
||||
Expr rotMat = nullptr;
|
||||
if(dim != nBits) {
|
||||
rotMat = weights->graph()->get("lsh_output_rotation");
|
||||
if(rotMat) {
|
||||
LOG_ONCE(info, "Reusing parameter LSH rotation matrix {} with shape {}", rotMat->name(), rotMat->shape());
|
||||
} else {
|
||||
LOG_ONCE(info, "Creating ad-hoc rotation matrix with shape {}", Shape({dim, nBits}));
|
||||
rotMat = rotator(weights, nBits);
|
||||
}
|
||||
}
|
||||
|
||||
Expr encodedWeights = weights->graph()->get("lsh_output_codes");
|
||||
if(encodedWeights) {
|
||||
LOG_ONCE(info, "Reusing parameter LSH code matrix {} with shape {}", encodedWeights->name(), encodedWeights->shape());
|
||||
} else {
|
||||
LOG_ONCE(info, "Creating ad-hoc code matrix with shape {}", Shape({weights->shape()[-2], lsh::bytesPerVector(nBits)}));
|
||||
encodedWeights = encode(weights, rotMat);
|
||||
}
|
||||
|
||||
return searchEncoded(encode(query, rotMat), encodedWeights, k, firstNRows);
|
||||
}
|
||||
|
||||
class RandomRotation : public inits::NodeInitializer {
|
||||
public:
|
||||
void apply(Tensor tensor) override {
|
||||
auto sharedAllocator = allocator_.lock();
|
||||
ABORT_IF(!sharedAllocator, "Allocator in RandomRotation has not been set or expired");
|
||||
fillRandomRotationMatrix(tensor, sharedAllocator);
|
||||
}
|
||||
};
|
||||
|
||||
Ptr<inits::NodeInitializer> randomRotation() {
|
||||
return New<RandomRotation>();
|
||||
}
|
||||
|
||||
void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBitsRot) {
|
||||
auto weights = graph->get(weightsName);
|
||||
|
||||
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
|
||||
|
||||
int nBits = weights->shape()[-1];
|
||||
int nRows = weights->shape().elements() / nBits;
|
||||
|
||||
Expr rotation;
|
||||
if(nBits != nBitsRot) {
|
||||
LOG(info, "Adding LSH rotation parameter lsh_output_rotation with shape {}", Shape({nBits, nBitsRot}));
|
||||
rotation = graph->param("lsh_output_rotation", {nBits, nBitsRot}, inits::dummy(), Type::float32);
|
||||
nBits = nBitsRot;
|
||||
}
|
||||
|
||||
int bytesPerVector = lsh::bytesPerVector(nBits);
|
||||
LOG(info, "Adding LSH encoded weights lsh_output_codes with shape {}", Shape({nRows, bytesPerVector}));
|
||||
auto codes = graph->param("lsh_output_codes", {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
|
||||
}
|
||||
|
||||
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName) {
|
||||
Expr weights = graph->get(weightsName);
|
||||
Expr codes = graph->get("lsh_output_codes");
|
||||
Expr rotation = graph->get("lsh_output_rotation");
|
||||
|
||||
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
|
||||
ABORT_IF(!codes, "Trying to overwrite non-existing LSH parameters lsh_output_codes??");
|
||||
|
||||
if(rotation) {
|
||||
fillRandomRotationMatrix(rotation->val(), weights->graph()->allocator());
|
||||
encodeWithRotation(codes->val(), weights->val(), rotation->val(), weights->graph()->allocator());
|
||||
} else {
|
||||
encode(codes->val(), weights->val());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
49
src/layers/lsh.h
Normal file
49
src/layers/lsh.h
Normal file
@ -0,0 +1,49 @@
|
||||
#pragma once
|
||||
|
||||
#include "graph/expression_operators.h"
|
||||
#include "graph/node_initializers.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
/**
|
||||
* In this file we bascially take the faiss::IndexLSH and pick it apart so that the individual steps
|
||||
* can be implemented as Marian inference operators. We can encode the inputs and weights into their
|
||||
* bitwise equivalents, apply the hashing rotation (if required), and perform the actual search.
|
||||
*
|
||||
* This also allows to create parameters that get dumped into the model weight file. This is currently
|
||||
* a bit hacky (see marian-conv), but once this is done the model can memory-map the LSH with existing
|
||||
* mechanisms and no additional memory is consumed to build the index or rotation matrix.
|
||||
*/
|
||||
|
||||
namespace marian {
|
||||
namespace lsh {
|
||||
|
||||
// return the number of full bytes required to encoded that many bits
|
||||
int bytesPerVector(int nBits);
|
||||
|
||||
// encodes an input as a bit vector, with optional rotation
|
||||
Expr encode(Expr input, Expr rotator = nullptr);
|
||||
|
||||
// compute the rotation matrix (maps weights->shape()[-1] to nbits floats)
|
||||
Expr rotator(Expr weights, int nbits);
|
||||
|
||||
// perform the LSH search on fully encoded input and weights, return k results (indices) per input row
|
||||
// @TODO: add a top-k like operator that also returns the bitwise computed distances
|
||||
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0);
|
||||
|
||||
// same as above, but performs encoding on the fly
|
||||
Expr search(Expr query, Expr weights, int k, int nbits, int firstNRows = 0);
|
||||
|
||||
// These are helper functions for encoding the LSH into the binary Marian model, used by marian-conv
|
||||
void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBits);
|
||||
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName);
|
||||
|
||||
/**
|
||||
* Computes a random rotation matrix for LSH hashing.
|
||||
* This is part of a hash function. The values are orthonormal and computed via
|
||||
* QR decomposition.
|
||||
*/
|
||||
Ptr<inits::NodeInitializer> randomRotation();
|
||||
}
|
||||
|
||||
}
|
@ -11,6 +11,7 @@
|
||||
#include "data/alignment.h"
|
||||
#include "data/vocab_base.h"
|
||||
#include "tensors/cpu/expression_graph_packable.h"
|
||||
#include "layers/lsh.h"
|
||||
|
||||
#if USE_FBGEMM
|
||||
#include "fbgemm/Utils.h"
|
||||
@ -248,7 +249,7 @@ DecoderCpuAvxVersion parseCpuAvxVersion(std::string name) {
|
||||
// This function converts an fp32 model into an FBGEMM based packed model.
|
||||
// marian defined types are used for external project as well.
|
||||
// The targetPrec is passed as int32_t for the exported function definition.
|
||||
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec) {
|
||||
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh) {
|
||||
std::cerr << "Converting from: " << inputFile << ", to: " << outputFile << ", precision: " << targetPrec << std::endl;
|
||||
|
||||
YAML::Node config;
|
||||
@ -260,7 +261,26 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
|
||||
graph->setDevice(CPU0);
|
||||
|
||||
graph->load(inputFile);
|
||||
graph->forward();
|
||||
|
||||
// MJD: Note, this is a default settings which we might want to change or expose. Use this only with Polonium students.
|
||||
// The LSH will not be used by default even if it exists in the model. That has to be enabled in the decoder config.
|
||||
int lshNBits = 1024;
|
||||
std::string lshOutputWeights = "Wemb";
|
||||
if(addLsh) {
|
||||
// Add dummy parameters for the LSH before the model gets actually initialized.
|
||||
// This create the parameters with useless values in the tensors, but it gives us the memory we need.
|
||||
graph->setReloaded(false);
|
||||
lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
|
||||
graph->setReloaded(true);
|
||||
}
|
||||
|
||||
graph->forward(); // run the initializers
|
||||
|
||||
if(addLsh) {
|
||||
// After initialization, hijack the paramters for the LSH and force-overwrite with correct values.
|
||||
// Once this is done we can just pack and save as normal.
|
||||
lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
|
||||
}
|
||||
|
||||
Type targetPrecType = (Type) targetPrec;
|
||||
if (targetPrecType == Type::packed16
|
||||
|
@ -76,7 +76,10 @@ std::vector<Ptr<IVocabWrapper>> loadVocabs(const std::vector<std::string>& vocab
|
||||
DecoderCpuAvxVersion getCpuAvxVersion();
|
||||
DecoderCpuAvxVersion parseCpuAvxVersion(std::string name);
|
||||
|
||||
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec);
|
||||
// MJD: added "addLsh" which will now break whatever compilation after update. That's on purpose.
|
||||
// The calling code should be adapted, not this interface. If you need to fix things in QS because of this
|
||||
// talk to me first!
|
||||
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh);
|
||||
|
||||
} // namespace quicksand
|
||||
} // namespace marian
|
||||
|
@ -27,14 +27,17 @@ public:
|
||||
virtual ~ExpressionGraphPackable() {}
|
||||
|
||||
// Convert model weights into packed format and save to IO items.
|
||||
// @TODO: review this
|
||||
void packAndSave(const std::string& name, const std::string& meta, Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
|
||||
std::vector<io::Item> pack(Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
|
||||
std::vector<io::Item> ioItems;
|
||||
|
||||
// handle packable parameters first (a float32 parameter is packable)
|
||||
auto packableParameters = paramsByElementType_[Type::float32];
|
||||
// sorted by name in std::map
|
||||
for (auto p : params()->getMap()) {
|
||||
for (auto p : packableParameters->getMap()) {
|
||||
std::string pName = p.first;
|
||||
|
||||
LOG(info, "Processing parameter {} with shape {} and type {}", pName, p.second->shape(), p.second->value_type());
|
||||
|
||||
if (!namespace_.empty()) {
|
||||
if (pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
|
||||
pName = pName.substr(namespace_.size() + 2);
|
||||
@ -257,6 +260,33 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// Now handle all non-float32 parameters
|
||||
for(auto& iter : paramsByElementType_) {
|
||||
auto type = iter.first;
|
||||
if(type == Type::float32)
|
||||
continue;
|
||||
|
||||
for (auto p : iter.second->getMap()) {
|
||||
std::string pName = p.first;
|
||||
LOG(info, "Processing parameter {} with shape {} and type {}", pName, p.second->shape(), p.second->value_type());
|
||||
|
||||
if (!namespace_.empty()) {
|
||||
if (pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
|
||||
pName = pName.substr(namespace_.size() + 2);
|
||||
}
|
||||
|
||||
Tensor val = p.second->val();
|
||||
io::Item item;
|
||||
val->get(item, pName);
|
||||
ioItems.emplace_back(std::move(item));
|
||||
}
|
||||
}
|
||||
|
||||
return ioItems;
|
||||
}
|
||||
|
||||
void packAndSave(const std::string& name, const std::string& meta, Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
|
||||
auto ioItems = pack(gemmElementType, saveElementType);
|
||||
if (!meta.empty())
|
||||
io::addMetaToItems(meta, "special:model.yml", ioItems);
|
||||
io::saveItems(name, ioItems);
|
||||
|
@ -35,7 +35,8 @@ class TensorBase {
|
||||
|
||||
ENABLE_INTRUSIVE_PTR(TensorBase)
|
||||
|
||||
// Constructors are private, use TensorBase::New(...)
|
||||
protected:
|
||||
// Constructors are protected, use TensorBase::New(...)
|
||||
TensorBase(MemoryPiece::PtrType memory,
|
||||
Shape shape,
|
||||
Type type,
|
||||
@ -61,10 +62,10 @@ class TensorBase {
|
||||
shape_(shape), type_(type), backend_(backend) {}
|
||||
|
||||
public:
|
||||
// Use this whenever pointing to MemoryPiece
|
||||
// Use this whenever pointing to TensorBase
|
||||
typedef IPtr<TensorBase> PtrType;
|
||||
|
||||
// Use this whenever creating a pointer to MemoryPiece
|
||||
// Use this whenever creating a pointer to TensorBase
|
||||
template <class ...Args>
|
||||
static PtrType New(Args&& ...args) {
|
||||
return PtrType(new TensorBase(std::forward<Args>(args)...));
|
||||
|
@ -142,8 +142,9 @@ public:
|
||||
// for periods.
|
||||
bool enteredNewPeriodOf(std::string schedulingParam) const {
|
||||
auto period = SchedulingParameter::parse(schedulingParam);
|
||||
// @TODO: adapt to logical epochs
|
||||
ABORT_IF(period.unit == SchedulingUnit::epochs,
|
||||
"Unit {} is not supported for frequency parameters (the one(s) with value {})",
|
||||
"Unit {} is not supported for frequency parameters",
|
||||
schedulingParam);
|
||||
auto previousProgress = getPreviousProgressIn(period.unit);
|
||||
auto progress = getProgressIn(period.unit);
|
||||
|
Loading…
Reference in New Issue
Block a user