Merged PR 19685: Marianize LSH as operators for mmapping and use in Quicksand

This PR turns the LSH index and search into a set of operators that live in the expression graph. This makes creation etc. thread-safe (one index per graph) and makes it possible to implement GPU versions later.

This also allows the LSH to be memory-mapped as a Marian parameter, since the index now only needs to be turned into something that can be saved to disk using the existing tensors. This happens in marian_conv or in the equivalent interface function of the Quicksand interface (a condensed sketch of that flow follows below).
Martin Junczys-Dowmunt 2021-07-09 20:35:09 +00:00
parent d6c09b24de
commit 35c822eb4e
24 changed files with 502 additions and 740 deletions
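
For orientation, the conversion path added in marian_conv (and mirrored in the Quicksand converter further down) boils down to roughly the following sequence. This is a condensed sketch using the variable names from the diff; the weight name Wemb and the 1024-bit default come from the new --add-lsh option.

auto graph = New<ExpressionGraphPackable>();
graph->setDevice(CPU0);
graph->load(modelFrom);
// Reserve dummy LSH parameters (lsh_output_codes and, if a rotation is needed, lsh_output_rotation)
// so that the graph allocates memory for them before initialization.
graph->setReloaded(false);
lsh::addDummyParameters(graph, /*weights=*/"Wemb", /*nBits=*/1024);
graph->setReloaded(true);
graph->forward(); // run the initializers
// Overwrite the dummy tensors with the real codes and rotation matrix, then pack and save as usual;
// from here on the LSH is just another set of tensors in the model file and can be memory-mapped.
lsh::overwriteDummyParameters(graph, /*weights=*/"Wemb");
graph->packAndSave(modelTo, configStr.str(), saveGemmType, Type::float32);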

View File

@ -1,119 +0,0 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include "Index.h"
#include "common/logging.h"
#include <cstring>
namespace faiss {
Index::~Index ()
{
}
void Index::train(idx_t /*n*/, const float* /*x*/) {
// does nothing by default
}
void Index::range_search (idx_t , const float *, float,
RangeSearchResult *) const
{
ABORT ("range search not implemented");
}
void Index::assign (idx_t n, const float * x, idx_t * labels, idx_t k)
{
float * distances = new float[n * k];
ScopeDeleter<float> del(distances);
search (n, x, k, distances, labels);
}
void Index::add_with_ids(
idx_t /*n*/,
const float* /*x*/,
const idx_t* /*xids*/) {
ABORT ("add_with_ids not implemented for this type of index");
}
size_t Index::remove_ids(const IDSelector& /*sel*/) {
ABORT ("remove_ids not implemented for this type of index");
return -1;
}
void Index::reconstruct (idx_t, float * ) const {
ABORT ("reconstruct not implemented for this type of index");
}
void Index::reconstruct_n (idx_t i0, idx_t ni, float *recons) const {
for (idx_t i = 0; i < ni; i++) {
reconstruct (i0 + i, recons + i * d);
}
}
void Index::search_and_reconstruct (idx_t n, const float *x, idx_t k,
float *distances, idx_t *labels,
float *recons) const {
search (n, x, k, distances, labels);
for (idx_t i = 0; i < n; ++i) {
for (idx_t j = 0; j < k; ++j) {
idx_t ij = i * k + j;
idx_t key = labels[ij];
float* reconstructed = recons + ij * d;
if (key < 0) {
// Fill with NaNs
memset(reconstructed, -1, sizeof(*reconstructed) * d);
} else {
reconstruct (key, reconstructed);
}
}
}
}
void Index::compute_residual (const float * x,
float * residual, idx_t key) const {
reconstruct (key, residual);
for (size_t i = 0; i < d; i++) {
residual[i] = x[i] - residual[i];
}
}
void Index::compute_residual_n (idx_t n, const float* xs,
float* residuals,
const idx_t* keys) const {
//#pragma omp parallel for
for (idx_t i = 0; i < n; ++i) {
compute_residual(&xs[i * d], &residuals[i * d], keys[i]);
}
}
size_t Index::sa_code_size () const
{
ABORT ("standalone codec not implemented for this type of index");
}
void Index::sa_encode (idx_t, const float *,
uint8_t *) const
{
ABORT ("standalone codec not implemented for this type of index");
}
void Index::sa_decode (idx_t, const uint8_t *,
float *) const
{
ABORT ("standalone codec not implemented for this type of index");
}
}

View File

@ -39,11 +39,6 @@
namespace faiss {
/// Forward declarations see AuxIndexStructures.h
struct IDSelector;
struct RangeSearchResult;
struct DistanceComputer;
/** Abstract structure for an index, supports adding vectors and searching them.
*
* All vectors provided at add or search time are 32-bit float arrays,
@ -53,178 +48,6 @@ struct Index {
using idx_t = int64_t; ///< all indices are this type
using component_t = float;
using distance_t = float;
int d; ///< vector dimension
idx_t ntotal; ///< total nb of indexed vectors
bool verbose; ///< verbosity level
/// set if the Index does not require training, or if training is
/// done already
bool is_trained;
/// type of metric this index uses for search
MetricType metric_type;
float metric_arg; ///< argument of the metric type
explicit Index (idx_t d = 0, MetricType metric = METRIC_L2):
d((int)d),
ntotal(0),
verbose(false),
is_trained(true),
metric_type (metric),
metric_arg(0) {}
virtual ~Index ();
/** Perform training on a representative set of vectors
*
* @param n nb of training vectors
* @param x training vectors, size n * d
*/
virtual void train(idx_t n, const float* x);
/** Add n vectors of dimension d to the index.
*
* Vectors are implicitly assigned labels ntotal .. ntotal + n - 1
* This function slices the input vectors in chunks smaller than
* blocksize_add and calls add_core.
* @param x input matrix, size n * d
*/
virtual void add (idx_t n, const float *x) = 0;
/** Same as add, but stores xids instead of sequential ids.
*
* The default implementation fails with an assertion, as it is
* not supported by all indexes.
*
* @param xids if non-null, ids to store for the vectors (size n)
*/
virtual void add_with_ids (idx_t n, const float * x, const idx_t *xids);
/** query n vectors of dimension d to the index.
*
* return at most k vectors. If there are not enough results for a
* query, the result array is padded with -1s.
*
* @param x input vectors to search, size n * d
* @param labels output labels of the NNs, size n*k
* @param distances output pairwise distances, size n*k
*/
virtual void search (idx_t n, const float *x, idx_t k,
float *distances, idx_t *labels) const = 0;
/** query n vectors of dimension d to the index.
*
* return all vectors with distance < radius. Note that many
* indexes do not implement the range_search (only the k-NN search
* is mandatory).
*
* @param x input vectors to search, size n * d
* @param radius search radius
* @param result result table
*/
virtual void range_search (idx_t n, const float *x, float radius,
RangeSearchResult *result) const;
/** return the indexes of the k vectors closest to the query x.
*
* This function is identical to search but only returns labels of neighbors.
* @param x input vectors to search, size n * d
* @param labels output labels of the NNs, size n*k
*/
void assign (idx_t n, const float * x, idx_t * labels, idx_t k = 1);
/// removes all elements from the database.
virtual void reset() = 0;
/** removes IDs from the index. Not supported by all
* indexes. Returns the number of elements removed.
*/
virtual size_t remove_ids (const IDSelector & sel);
/** Reconstruct a stored vector (or an approximation if lossy coding)
*
* this function may not be defined for some indexes
* @param key id of the vector to reconstruct
* @param recons reconstructed vector (size d)
*/
virtual void reconstruct (idx_t key, float * recons) const;
/** Reconstruct vectors i0 to i0 + ni - 1
*
* this function may not be defined for some indexes
* @param recons reconstructed vector (size ni * d)
*/
virtual void reconstruct_n (idx_t i0, idx_t ni, float *recons) const;
/** Similar to search, but also reconstructs the stored vectors (or an
* approximation in the case of lossy coding) for the search results.
*
* If there are not enough results for a query, the resulting arrays
* are padded with -1s.
*
* @param recons reconstructed vectors size (n, k, d)
**/
virtual void search_and_reconstruct (idx_t n, const float *x, idx_t k,
float *distances, idx_t *labels,
float *recons) const;
/** Computes a residual vector after indexing encoding.
*
* The residual vector is the difference between a vector and the
* reconstruction that can be decoded from its representation in
* the index. The residual can be used for multiple-stage indexing
* methods, like IndexIVF's methods.
*
* @param x input vector, size d
* @param residual output residual vector, size d
* @param key encoded index, as returned by search and assign
*/
virtual void compute_residual (const float * x,
float * residual, idx_t key) const;
/** Computes a residual vector after indexing encoding (batch form).
* Equivalent to calling compute_residual for each vector.
*
* The residual vector is the difference between a vector and the
* reconstruction that can be decoded from its representation in
* the index. The residual can be used for multiple-stage indexing
* methods, like IndexIVF's methods.
*
* @param n number of vectors
* @param xs input vectors, size (n x d)
* @param residuals output residual vectors, size (n x d)
* @param keys encoded index, as returned by search and assign
*/
virtual void compute_residual_n (idx_t n, const float* xs,
float* residuals,
const idx_t* keys) const;
/* The standalone codec interface */
/** size of the produced codes in bytes */
virtual size_t sa_code_size () const;
/** encode a set of vectors
*
* @param n number of vectors
* @param x input vectors, size n * d
* @param bytes output encoded vectors, size n * sa_code_size()
*/
virtual void sa_encode (idx_t n, const float *x,
uint8_t *bytes) const;
/** decode a set of vectors
*
* @param n number of vectors
* @param bytes input encoded vectors, size n * sa_code_size()
* @param x output vectors, size n * d
*/
virtual void sa_decode (idx_t n, const uint8_t *bytes,
float *x) const;
};
}

View File

@ -1,224 +0,0 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexLSH.h>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <faiss/utils/hamming.h>
#include "common/logging.h"
namespace faiss {
/***************************************************************
* IndexLSH
***************************************************************/
IndexLSH::IndexLSH (idx_t d, int nbits, bool rotate_data, bool train_thresholds):
Index(d), nbits(nbits), rotate_data(rotate_data),
train_thresholds (train_thresholds), rrot(d, nbits)
{
is_trained = !train_thresholds;
bytes_per_vec = (nbits + 7) / 8;
if (rotate_data) {
rrot.init(5);
} else {
ABORT_UNLESS(d >= nbits, "d >= nbits");
}
}
IndexLSH::IndexLSH ():
nbits (0), bytes_per_vec(0), rotate_data (false), train_thresholds (false)
{
}
const float * IndexLSH::apply_preprocess (idx_t n, const float *x) const
{
float *xt = nullptr;
if (rotate_data) {
// also applies bias if exists
xt = rrot.apply (n, x);
} else if (d != nbits) {
assert (nbits < d);
xt = new float [nbits * n];
float *xp = xt;
for (idx_t i = 0; i < n; i++) {
const float *xl = x + i * d;
for (int j = 0; j < nbits; j++)
*xp++ = xl [j];
}
}
if (train_thresholds) {
if (xt == NULL) {
xt = new float [nbits * n];
memcpy (xt, x, sizeof(*x) * n * nbits);
}
float *xp = xt;
for (idx_t i = 0; i < n; i++)
for (int j = 0; j < nbits; j++)
*xp++ -= thresholds [j];
}
return xt ? xt : x;
}
void IndexLSH::train (idx_t n, const float *x)
{
if (train_thresholds) {
thresholds.resize (nbits);
train_thresholds = false;
const float *xt = apply_preprocess (n, x);
ScopeDeleter<float> del (xt == x ? nullptr : xt);
train_thresholds = true;
float * transposed_x = new float [n * nbits];
ScopeDeleter<float> del2 (transposed_x);
for (idx_t i = 0; i < n; i++)
for (idx_t j = 0; j < nbits; j++)
transposed_x [j * n + i] = xt [i * nbits + j];
for (idx_t i = 0; i < nbits; i++) {
float *xi = transposed_x + i * n;
// std::nth_element
std::sort (xi, xi + n);
if (n % 2 == 1)
thresholds [i] = xi [n / 2];
else
thresholds [i] = (xi [n / 2 - 1] + xi [n / 2]) / 2;
}
}
is_trained = true;
}
void IndexLSH::add (idx_t n, const float *x)
{
ABORT_UNLESS (is_trained, "is_trained");
codes.resize ((ntotal + n) * bytes_per_vec);
sa_encode (n, x, &codes[ntotal * bytes_per_vec]);
ntotal += n;
}
void IndexLSH::search (
idx_t n,
const float *x,
idx_t k,
float *distances,
idx_t *labels) const
{
ABORT_UNLESS (is_trained, "is_trained");
const float *xt = apply_preprocess (n, x);
ScopeDeleter<float> del (xt == x ? nullptr : xt);
uint8_t * qcodes = new uint8_t [n * bytes_per_vec];
ScopeDeleter<uint8_t> del2 (qcodes);
fvecs2bitvecs (xt, qcodes, nbits, n);
int * idistances = new int [n * k];
ScopeDeleter<int> del3 (idistances);
int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances};
hammings_knn_hc (&res, qcodes, codes.data(),
ntotal, bytes_per_vec, true);
// convert distances to floats
for (int i = 0; i < k * n; i++)
distances[i] = idistances[i];
}
void IndexLSH::transfer_thresholds (LinearTransform *vt) {
if (!train_thresholds) return;
ABORT_UNLESS (nbits == vt->d_out, "nbits == vt->d_out");
if (!vt->have_bias) {
vt->b.resize (nbits, 0);
vt->have_bias = true;
}
for (int i = 0; i < nbits; i++)
vt->b[i] -= thresholds[i];
train_thresholds = false;
thresholds.clear();
}
void IndexLSH::reset() {
codes.clear();
ntotal = 0;
}
size_t IndexLSH::sa_code_size () const
{
return bytes_per_vec;
}
void IndexLSH::sa_encode (idx_t n, const float *x,
uint8_t *bytes) const
{
ABORT_UNLESS (is_trained, "is_trained");
const float *xt = apply_preprocess (n, x);
ScopeDeleter<float> del (xt == x ? nullptr : xt);
fvecs2bitvecs (xt, bytes, nbits, n);
}
void IndexLSH::sa_decode (idx_t n, const uint8_t *bytes,
float *x) const
{
float *xt = x;
ScopeDeleter<float> del;
if (rotate_data || nbits != d) {
xt = new float [n * nbits];
del.set(xt);
}
bitvecs2fvecs (bytes, xt, nbits, n);
if (train_thresholds) {
float *xp = xt;
for (idx_t i = 0; i < n; i++) {
for (int j = 0; j < nbits; j++) {
*xp++ += thresholds [j];
}
}
}
if (rotate_data) {
rrot.reverse_transform (n, xt, x);
} else if (nbits != d) {
for (idx_t i = 0; i < n; i++) {
memcpy (x + i * d, xt + i * nbits,
nbits * sizeof(xt[0]));
}
}
}
} // namespace faiss

View File

@ -1,90 +0,0 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#ifndef INDEX_LSH_H
#define INDEX_LSH_H
#include <vector>
#include <faiss/Index.h>
#include <faiss/VectorTransform.h>
namespace faiss {
/** The sign of each vector component is put in a binary signature */
struct IndexLSH:Index {
typedef unsigned char uint8_t;
int nbits; ///< nb of bits per vector
int bytes_per_vec; ///< nb of 8-bits per encoded vector
bool rotate_data; ///< whether to apply a random rotation to input
bool train_thresholds; ///< whether we train thresholds or use 0
RandomRotationMatrix rrot; ///< optional random rotation
std::vector <float> thresholds; ///< thresholds to compare with
/// encoded dataset
std::vector<uint8_t> codes;
IndexLSH (
idx_t d, int nbits,
bool rotate_data = true,
bool train_thresholds = false);
/** Preprocesses and resizes the input to the size required to
* binarize the data
*
* @param x input vectors, size n * d
* @return output vectors, size n * bits. May be the same pointer
* as x, otherwise it should be deleted by the caller
*/
const float *apply_preprocess (idx_t n, const float *x) const;
void train(idx_t n, const float* x) override;
void add(idx_t n, const float* x) override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
void reset() override;
/// transfer the thresholds to a pre-processing stage (and unset
/// train_thresholds)
void transfer_thresholds (LinearTransform * vt);
~IndexLSH() override {}
IndexLSH ();
/* standalone codec interface.
*
* The vectors are decoded to +/- 1 (not 0, 1) */
size_t sa_code_size () const override;
void sa_encode (idx_t n, const float *x,
uint8_t *bytes) const override;
void sa_decode (idx_t n, const uint8_t *bytes,
float *x) const override;
};
}
#endif

View File

@ -10,8 +10,8 @@
namespace faiss {
#ifdef _MSC_VER
#define bzero(p,n) (memset((p),0,(n)))
#ifdef _MSC_VER
#define bzero(p,n) (memset((p),0,(n)))
#endif
inline BitstringWriter::BitstringWriter(uint8_t *code, int code_size):
code (code), code_size (code_size), i(0)
@ -29,7 +29,7 @@ inline void BitstringWriter::write(uint64_t x, int nbit) {
i += nbit;
return;
} else {
int j = i >> 3;
size_t j = i >> 3;
code[j++] |= x << (i & 7);
i += nbit;
x >>= na;
@ -57,7 +57,7 @@ inline uint64_t BitstringReader::read(int nbit) {
return res;
} else {
int ofs = na;
int j = (i >> 3) + 1;
size_t j = (i >> 3) + 1;
i += nbit;
nbit -= na;
while (nbit > 8) {
@ -160,7 +160,7 @@ struct HammingComputer20 {
void set (const uint8_t *a8, int code_size) {
assert (code_size == 20);
const uint64_t *a = (uint64_t *)a8;
a0 = a[0]; a1 = a[1]; a2 = a[2];
a0 = a[0]; a1 = a[1]; a2 = (uint32_t)a[2];
}
inline int hamming (const uint8_t *b8) const {

View File

@ -31,7 +31,7 @@
#ifdef _MSC_VER
#include <intrin.h> // needed for some intrinsics in <memory>
#define __builtin_popcountl __popcnt64
#define __builtin_popcountl __popcnt64
#endif
/* The Hamming distance type */
@ -116,7 +116,7 @@ struct BitstringReader {
extern size_t hamming_batch_size;
static inline int popcount64(uint64_t x) {
return __builtin_popcountl(x);
return (int)__builtin_popcountl(x);
}

View File

@ -75,6 +75,7 @@ set(MARIAN_SOURCES
layers/embedding.cpp
layers/output.cpp
layers/logits.cpp
layers/lsh.cpp
rnn/cells.cpp
rnn/attention.cpp

View File

@ -2,6 +2,7 @@
#include "common/cli_wrapper.h"
#include "tensors/cpu/expression_graph_packable.h"
#include "onnx/expression_graph_onnx_exporter.h"
#include "layers/lsh.h"
#include <sstream>
@ -25,6 +26,9 @@ int main(int argc, char** argv) {
cli->add<std::string>("--gemm-type,-g", "GEMM Type to be used: float32, packed16, packed8avx2, packed8avx512, "
"intgemm8, intgemm8ssse3, intgemm8avx2, intgemm8avx512, intgemm16, intgemm16sse2, intgemm16avx2, intgemm16avx512",
"float32");
cli->add<std::vector<std::string>>("--add-lsh",
"Encode output matrix and optional rotation matrix into model file. "
"arg1: number of bits in LSH encoding, arg2: name of output weights matrix")->implicit_val("1024 Wemb");
cli->add<std::vector<std::string>>("--vocabs,-V", "Vocabulary file, required for ONNX export");
cli->parse(argc, argv);
options->merge(config);
@ -34,6 +38,16 @@ int main(int argc, char** argv) {
auto exportAs = options->get<std::string>("export-as");
auto vocabPaths = options->get<std::vector<std::string>>("vocabs");// , std::vector<std::string>());
bool addLsh = options->hasAndNotEmpty("add-lsh");
int lshNBits = 1024;
std::string lshOutputWeights = "Wemb";
if(addLsh) {
auto lshParams = options->get<std::vector<std::string>>("add-lsh");
lshNBits = std::stoi(lshParams[0]);
if(lshParams.size() > 1)
lshOutputWeights = lshParams[1];
}
// We accept any type here and will later croak during packAndSave if the type cannot be used for conversion
Type saveGemmType = typeFromString(options->get<std::string>("gemm-type", "float32"));
@ -45,23 +59,36 @@ int main(int argc, char** argv) {
marian::io::getYamlFromModel(config, "special:model.yml", modelFrom);
configStr << config;
auto load = [&](Ptr<ExpressionGraph> graph) {
graph->setDevice(CPU0);
graph->load(modelFrom);
graph->forward(); // run the initializers
};
if (exportAs == "marian-bin") {
auto graph = New<ExpressionGraphPackable>();
load(graph);
graph->setDevice(CPU0);
graph->load(modelFrom);
if(addLsh) {
// Add dummy parameters for the LSH before the model gets actually initialized.
// This creates the parameters with useless values in the tensors, but it gives us the memory we need.
graph->setReloaded(false);
lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
graph->setReloaded(true);
}
graph->forward(); // run the initializers
if(addLsh) {
// After initialization, hijack the parameters for the LSH and force-overwrite them with correct values.
// Once this is done we can just pack and save as normal.
lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
}
// added a flag indicating whether the weights need to be packed or not
graph->packAndSave(modelTo, configStr.str(), /* --gemm-type */ saveGemmType, Type::float32);
}
else if (exportAs == "onnx-encode") {
#ifdef USE_ONNX
auto graph = New<ExpressionGraphONNXExporter>();
load(graph);
graph->setDevice(CPU0);
graph->load(modelFrom);
graph->forward(); // run the initializers
auto modelOptions = New<Options>(config)->with("vocabs", vocabPaths, "inference", true);
graph->exportToONNX(modelTo, modelOptions, vocabPaths);

View File

@ -1,10 +1,7 @@
#include "data/shortlist.h"
#include "microsoft/shortlist/utils/ParameterTree.h"
#include "marian.h"
#if BLAS_FOUND
#include "3rd_party/faiss/IndexLSH.h"
#endif
#include "layers/lsh.h"
namespace marian {
namespace data {
@ -47,7 +44,6 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
Shape kShape({k});
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k);
initialized_ = true;
}
@ -78,12 +74,10 @@ void Shortlist::createCachedTensors(Expr weights,
}
///////////////////////////////////////////////////////////////////////////////////
Ptr<faiss::IndexLSH> LSHShortlist::index_;
std::mutex LSHShortlist::mutex_;
LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize)
: Shortlist(std::vector<WordIndex>())
, k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
: Shortlist(std::vector<WordIndex>()),
k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
}
WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
@ -99,67 +93,23 @@ Expr LSHShortlist::getIndicesExpr() const {
}
void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
#if BLAS_FOUND
ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
"LSH index (--output-approx-knn) currently not implemented for GPU");
int currBeamSize = input->shape()[0];
int batchSize = input->shape()[2];
int numHypos = currBeamSize * batchSize;
auto forward = [this, numHypos](Expr out, const std::vector<Expr>& inputs) {
auto query = inputs[0];
auto values = inputs[1];
int dim = values->shape()[-1];
mutex_.lock();
if(!index_) {
LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
index_.reset(new faiss::IndexLSH(dim, nbits_,
/*rotate=*/dim != nbits_,
/*train_thresholds*/false));
index_->train(lemmaSize_, values->val()->data<float>());
index_->add( lemmaSize_, values->val()->data<float>());
}
mutex_.unlock();
int qRows = query->shape().elements() / dim;
std::vector<float> distances(qRows * k_);
std::vector<faiss::Index::idx_t> ids(qRows * k_);
index_->search(qRows, query->val()->data<float>(), k_,
distances.data(), ids.data());
indices_.clear();
for(auto iter = ids.begin(); iter != ids.end(); ++iter) {
faiss::Index::idx_t id = *iter;
indices_.push_back((WordIndex)id);
}
for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
size_t startIdx = k_ * hypoIdx;
size_t endIdx = startIdx + k_;
std::sort(indices_.begin() + startIdx, indices_.begin() + endIdx);
}
out->val()->set(indices_);
};
Shape kShape({currBeamSize, batchSize, k_});
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
indicesExpr_ = callback(lsh::search(input, weights, k_, nbits_, (int)lemmaSize_),
[this](Expr node) {
node->val()->get(indices_); // set the value of the field indices_ whenever the graph traverses this node
});
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k_);
#else
input; weights; isLegacyUntransposedW; b; lemmaEt;
ABORT("LSH output layer requires a CPU BLAS library");
#endif
}
void LSHShortlist::createCachedTensors(Expr weights,
bool isLegacyUntransposedW,
Expr b,
Expr lemmaEt,
int k) {
bool isLegacyUntransposedW,
Expr b,
Expr lemmaEt,
int k) {
int currBeamSize = indicesExpr_->shape()[0];
int batchSize = indicesExpr_->shape()[1];
ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");

View File

@ -25,7 +25,8 @@ namespace data {
class Shortlist {
protected:
std::vector<WordIndex> indices_; // [packed shortlist index] -> word index, used to select columns from output embeddings
Expr indicesExpr_;
Expr indicesExpr_; // cache an expression that contains the short list indices
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;

View File

@ -646,6 +646,16 @@ public:
return it->second;
}
/**
* Return the Parameters object related to the graph by elementType.
* The Parameters object holds the whole set of the parameter nodes of the given type.
*/
Ptr<Parameters>& params(Type elementType) {
auto it = paramsByElementType_.find(elementType);
ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", elementType);
return it->second;
}
/**
* Set default element type for the graph.
* The default value is used if some node type is not specified.

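
For illustration, the new accessor could be used to enumerate the non-float32 parameters (for example the uint8 LSH codes). A rough sketch, assuming the graph actually holds parameters of that type (otherwise the accessor aborts):

auto& uint8Params = graph->params(Type::uint8);   // aborts if no uint8 parameters exist in this graph
for(auto& p : uint8Params->getMap())
  LOG(info, "uint8 parameter {} with shape {}", p.first, p.second->shape());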
View File

@ -28,13 +28,17 @@ Expr checkpoint(Expr a) {
}
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
LambdaNodeFunctor fwd) {
return Expression<LambdaNodeOp>(nodes, shape, type, fwd);
LambdaNodeFunctor fwd, size_t hash) {
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
}
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
LambdaNodeFunctor fwd, LambdaNodeFunctor bwd) {
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd);
LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash) {
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd, hash);
}
Expr callback(Expr node, LambdaNodeCallback call) {
return Expression<CallbackNodeOp>(node, call);
}
// logistic function. Note: scipy name is expit()

View File

@ -26,12 +26,19 @@ typedef std::function<void(Expr out, const std::vector<Expr>& in)> LambdaNodeFun
/**
* Arbitrary node with forward operation only.
*/
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd);
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, size_t hash = 0);
/**
* Arbitrary node with forward and backward operation.
*/
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, LambdaNodeFunctor bwd);
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type, LambdaNodeFunctor fwd, LambdaNodeFunctor bwd, size_t hash = 0);
/**
* Convenience typedef for graph @ref lambda expressions.
*/
typedef std::function<void(Expr)> LambdaNodeCallback;
Expr callback(Expr node, LambdaNodeCallback call);
/**
* @addtogroup graph_ops_activation Activation Functions

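
As a rough illustration of the new callback() operator (the node names and the values of k and nbits are placeholders; this mirrors how LSHShortlist::filter uses it in the shortlist diff earlier):

// Attach a side effect to a node's forward pass. The callback runs right after the wrapped
// node has been computed; the wrapper itself is a zero-cost reshape onto the same memory.
auto indices = std::make_shared<std::vector<uint32_t>>();
Expr knn = lsh::search(query, weights, /*k=*/100, /*nbits=*/1024);
Expr knnWithCopy = callback(knn, [indices](Expr node) {
  node->val()->get(*indices); // pull the freshly computed indices out of the graph
});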
View File

@ -11,6 +11,15 @@ namespace marian {
namespace inits {
class DummyInit : public NodeInitializer {
public:
void apply(Tensor tensor) override {
tensor;
}
};
Ptr<NodeInitializer> dummy() { return New<DummyInit>(); }
class LambdaInit : public NodeInitializer {
private:
std::function<void(Tensor)> lambda_;
@ -237,24 +246,3 @@ template Ptr<NodeInitializer> range<IndexType>(IndexType begin, IndexType end, I
} // namespace inits
} // namespace marian
#if BLAS_FOUND
#include "faiss/VectorTransform.h"
namespace marian {
namespace inits {
Ptr<NodeInitializer> randomRotation(size_t seed) {
auto rot = [=](Tensor t) {
int rows = t->shape()[-2];
int cols = t->shape()[-1];
faiss::RandomRotationMatrix rrot(cols, rows); // transposed in faiss
rrot.init((int)seed);
t->set(rrot.A);
};
return fromLambda(rot, Type::float32);
}
} // namespace inits
} // namespace marian
#endif

View File

@ -35,6 +35,11 @@ public:
virtual ~NodeInitializer() {}
};
/**
* Dummy do-nothing initializer. Mostly for testing.
*/
Ptr<NodeInitializer> dummy();
/**
* Use a lambda function of form [](Tensor t) { do something with t } to initialize tensor.
* @param func functor
@ -263,13 +268,6 @@ Ptr<NodeInitializer> fromWord2vec(const std::string& file,
*/
Ptr<NodeInitializer> sinusoidalPositionEmbeddings(int start);
/**
* Computes a random rotation matrix for LSH hashing.
* This is part of a hash function. The values are orthonormal and computed via
* QR decomposition. Same seed results in same random rotation.
*/
Ptr<NodeInitializer> randomRotation(size_t seed = Config::seed);
/**
* Computes the equivalent of Python's range().
* Computes a range from begin to end-1, like Python's range().

View File

@ -21,20 +21,26 @@ private:
std::unique_ptr<LambdaNodeFunctor> forward_;
std::unique_ptr<LambdaNodeFunctor> backward_;
size_t externalHash_;
public:
LambdaNodeOp(Inputs inputs, Shape shape, Type type,
LambdaNodeFunctor forward)
LambdaNodeFunctor forward,
size_t externalHash = 0)
: NaryNodeOp(inputs, shape, type),
forward_(new LambdaNodeFunctor(forward)) {
forward_(new LambdaNodeFunctor(forward)),
externalHash_(externalHash) {
Node::trainable_ = !!backward_;
}
LambdaNodeOp(Inputs inputs, Shape shape, Type type,
LambdaNodeFunctor forward,
LambdaNodeFunctor backward)
LambdaNodeFunctor backward,
size_t externalHash = 0)
: NaryNodeOp(inputs, shape, type),
forward_(new LambdaNodeFunctor(forward)),
backward_(new LambdaNodeFunctor(backward)) {
backward_(new LambdaNodeFunctor(backward)),
externalHash_(externalHash) {
}
void forward() override {
@ -50,8 +56,12 @@ public:
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, forward_.get());
util::hash_combine(seed, backward_.get());
if(externalHash_ != 0) {
util::hash_combine(seed, externalHash_);
} else {
util::hash_combine(seed, forward_.get());
util::hash_combine(seed, backward_.get());
}
return seed;
}

View File

@ -795,7 +795,7 @@ private:
};
class ReshapeNodeOp : public UnaryNodeOp {
private:
protected:
friend class SerializationHelpers;
Expr reshapee_;
@ -858,6 +858,45 @@ public:
}
};
// @TODO: add version with access to backward step
// This allows attaching a lambda function to any node during execution. It is a no-op otherwise,
// i.e. it doesn't consume any memory or take any time to execute (it's a reshape onto itself) other than the
// compute in the lambda function. It gets called after the forward step of the argument node.
class CallbackNodeOp : public ReshapeNodeOp {
private:
typedef std::function<void(Expr)> LambdaNodeCallback;
std::unique_ptr<LambdaNodeCallback> callback_;
public:
CallbackNodeOp(Expr node, LambdaNodeCallback callback)
: ReshapeNodeOp(node, node->shape()),
callback_(new LambdaNodeCallback(callback)) {
}
void forward() override {
(*callback_)(ReshapeNodeOp::reshapee_);
}
const std::string type() override { return "callback"; }
virtual size_t hash() override {
size_t seed = ReshapeNodeOp::hash();
util::hash_combine(seed, callback_.get());
return seed;
}
virtual bool equal(Expr node) override {
if(!ReshapeNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<CallbackNodeOp>(node);
if(!cnode)
return false;
if(callback_ != cnode->callback_) // pointer compare on purpose
return false;
return true;
}
};
// @TODO: review if still required as this is an ugly hack anyway.
// Memory less operator that clips gradients during backward step
// Executes this as an additional operation on the gradient.

src/layers/lsh.cpp Normal file (233 lines added)
View File

@ -0,0 +1,233 @@
#include "layers/lsh.h"
#include "tensors/tensor_operators.h"
#include "common/utils.h"
#include "3rd_party/faiss/utils/hamming.h"
#include "3rd_party/faiss/Index.h"
#if BLAS_FOUND
#include "3rd_party/faiss/VectorTransform.h"
#endif
namespace marian {
namespace lsh {
int bytesPerVector(int nBits) {
return (nBits + 7) / 8;
}
void fillRandomRotationMatrix(Tensor output, Ptr<Allocator> allocator) {
#if BLAS_FOUND
int nRows = output->shape()[-2];
int nBits = output->shape()[-1];
// @TODO re-implement using Marian code so it uses the correct random generator etc.
faiss::RandomRotationMatrix rrot(nRows, nBits);
// Then we do not need to use this seed at all
rrot.init(5); // currently set to 5 following the default from FAISS, this could be any number really.
// The faiss random rotation matrix is column major, hence we create a temporary tensor,
// copy the rotation matrix into it and transpose to output.
Shape tempShape = {nBits, nRows};
auto memory = allocator->alloc(requiredBytes(tempShape, output->type()));
auto temp = TensorBase::New(memory,
tempShape,
output->type(),
output->getBackend());
temp->set(rrot.A);
TransposeND(output, temp, {0, 1, 3, 2});
allocator->free(memory);
#else
output; allocator;
ABORT("LSH with rotation matrix requires Marian to be compiled with a BLAS library");
#endif
}
void encode(Tensor output, Tensor input) {
int nBits = input->shape()[-1]; // the number of bits equals the last dimension of the float matrix
int nRows = input->shape().elements() / nBits;
faiss::fvecs2bitvecs(input->data<float>(), output->data<uint8_t>(), (size_t)nBits, (size_t)nRows);
}
void encodeWithRotation(Tensor output, Tensor input, Tensor rotation, Ptr<Allocator> allocator) {
int nBits = input->shape()[-1]; // the number of bits equals the last dimension of the float matrix unless we rotate
int nRows = input->shape().elements() / nBits;
Tensor tempInput = input;
MemoryPiece::PtrType memory;
if(rotation) {
int nBitsRot = rotation->shape()[-1];
Shape tempShape = {nRows, nBitsRot};
memory = allocator->alloc(requiredBytes(tempShape, rotation->type()));
tempInput = TensorBase::New(memory, tempShape, rotation->type(), rotation->getBackend());
Prod(tempInput, input, rotation, false, false, 0.f, 1.f);
}
encode(output, tempInput);
if(memory)
allocator->free(memory);
};
Expr encode(Expr input, Expr rotation) {
auto encodeFwd = [](Expr out, const std::vector<Expr>& inputs) {
if(inputs.size() == 1) {
encode(out->val(), inputs[0]->val());
} else if(inputs.size() == 2) {
encodeWithRotation(out->val(), inputs[0]->val(), inputs[1]->val(), out->graph()->allocator());
} else {
ABORT("Too many inputs to encode??");
}
};
// Use the address of the first lambda function as an immutable hash. Making it static and const makes sure
// that this hash value will not change. Next we pass the hash into the lambda functor where it will be used
// to identify this unique operation. Marian's ExpressionGraph can automatically memoize and identify nodes
// that operate only on immutable nodes (parameters) and have the same hash. This way we make sure that the
// codes node won't actually get recomputed throughout ExpressionGraph lifetime. `codes` will be reused
// and the body of the lambda will not be called again. This does however build one index per graph.
static const size_t encodeHash = (size_t)&encodeFwd;
Shape encodedShape = input->shape();
int nBits = rotation ? rotation->shape()[-1] : input->shape()[-1];
encodedShape.set(-1, bytesPerVector(nBits));
std::vector<Expr> inputs = {input};
if(rotation)
inputs.push_back(rotation);
return lambda(inputs, encodedShape, Type::uint8, encodeFwd, encodeHash);
}
Expr rotator(Expr weights, int nBits) {
auto rotator = [](Expr out, const std::vector<Expr>& inputs) {
inputs;
fillRandomRotationMatrix(out->val(), out->graph()->allocator());
};
static const size_t rotatorHash = (size_t)&rotator;
int dim = weights->shape()[-1];
return lambda({weights}, {dim, nBits}, Type::float32, rotator, rotatorHash);
}
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows) {
ABORT_IF(encodedQuery->shape()[-1] != encodedWeights->shape()[-1],
"Query and index bit vectors need to be of same size ({} != {})", encodedQuery->shape()[-1], encodedWeights->shape()[-1]);
int currBeamSize = encodedQuery->shape()[0];
int batchSize = encodedQuery->shape()[2];
int numHypos = currBeamSize * batchSize;
auto search = [=](Expr out, const std::vector<Expr>& inputs) {
Expr encodedQuery = inputs[0];
Expr encodedWeights = inputs[1];
int bytesPerVector = encodedWeights->shape()[-1];
int wRows = encodedWeights->shape().elements() / bytesPerVector;
// we use this with Factored Segmenter to skip the factor embeddings at the end
if(firstNRows != 0)
wRows = firstNRows;
int qRows = encodedQuery->shape().elements() / bytesPerVector;
uint8_t* qCodes = encodedQuery->val()->data<uint8_t>();
uint8_t* wCodes = encodedWeights->val()->data<uint8_t>();
// use actual faiss code for performing the hamming search.
std::vector<int> distances(qRows * k);
std::vector<faiss::Index::idx_t> ids(qRows * k);
faiss::int_maxheap_array_t res = {(size_t)qRows, (size_t)k, ids.data(), distances.data()};
faiss::hammings_knn_hc(&res, qCodes, wCodes, (size_t)wRows, (size_t)bytesPerVector, 0);
// Copy int64_t indices to Marian index type and sort by increasing index value per hypothesis.
// The sorting is required as we later do a binary search on those values for reverse look-up.
uint32_t* outData = out->val()->data<uint32_t>();
for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
size_t startIdx = k * hypoIdx;
size_t endIdx = startIdx + k;
for(size_t i = startIdx; i < endIdx; ++i)
outData[i] = (uint32_t)ids[i];
std::sort(outData + startIdx, outData + endIdx);
}
};
Shape kShape({currBeamSize, batchSize, k});
return lambda({encodedQuery, encodedWeights}, kShape, Type::uint32, search);
}
Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows) {
int dim = weights->shape()[-1];
Expr rotMat = nullptr;
if(dim != nBits) {
rotMat = weights->graph()->get("lsh_output_rotation");
if(rotMat) {
LOG_ONCE(info, "Reusing parameter LSH rotation matrix {} with shape {}", rotMat->name(), rotMat->shape());
} else {
LOG_ONCE(info, "Creating ad-hoc rotation matrix with shape {}", Shape({dim, nBits}));
rotMat = rotator(weights, nBits);
}
}
Expr encodedWeights = weights->graph()->get("lsh_output_codes");
if(encodedWeights) {
LOG_ONCE(info, "Reusing parameter LSH code matrix {} with shape {}", encodedWeights->name(), encodedWeights->shape());
} else {
LOG_ONCE(info, "Creating ad-hoc code matrix with shape {}", Shape({weights->shape()[-2], lsh::bytesPerVector(nBits)}));
encodedWeights = encode(weights, rotMat);
}
return searchEncoded(encode(query, rotMat), encodedWeights, k, firstNRows);
}
class RandomRotation : public inits::NodeInitializer {
public:
void apply(Tensor tensor) override {
auto sharedAllocator = allocator_.lock();
ABORT_IF(!sharedAllocator, "Allocator in RandomRotation has not been set or expired");
fillRandomRotationMatrix(tensor, sharedAllocator);
}
};
Ptr<inits::NodeInitializer> randomRotation() {
return New<RandomRotation>();
}
void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBitsRot) {
auto weights = graph->get(weightsName);
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
int nBits = weights->shape()[-1];
int nRows = weights->shape().elements() / nBits;
Expr rotation;
if(nBits != nBitsRot) {
LOG(info, "Adding LSH rotation parameter lsh_output_rotation with shape {}", Shape({nBits, nBitsRot}));
rotation = graph->param("lsh_output_rotation", {nBits, nBitsRot}, inits::dummy(), Type::float32);
nBits = nBitsRot;
}
int bytesPerVector = lsh::bytesPerVector(nBits);
LOG(info, "Adding LSH encoded weights lsh_output_codes with shape {}", Shape({nRows, bytesPerVector}));
auto codes = graph->param("lsh_output_codes", {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
}
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName) {
Expr weights = graph->get(weightsName);
Expr codes = graph->get("lsh_output_codes");
Expr rotation = graph->get("lsh_output_rotation");
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
ABORT_IF(!codes, "Trying to overwrite non-existing LSH parameters lsh_output_codes??");
if(rotation) {
fillRandomRotationMatrix(rotation->val(), weights->graph()->allocator());
encodeWithRotation(codes->val(), weights->val(), rotation->val(), weights->graph()->allocator());
} else {
encode(codes->val(), weights->val());
}
}
}
}

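For a sense of scale: bytesPerVector(nBits) is (nBits + 7) / 8, so with the default of 1024 bits each row of lsh_output_codes takes 128 bytes. Taking a hypothetical output matrix with 32,000 rows and embedding dimension 512 (illustrative numbers, not from this PR), the code matrix is about 32,000 x 128 bytes, roughly 4 MB of uint8, and the rotation matrix lsh_output_rotation is 512 x 1024 float32 values, roughly 2 MB. Both are saved into the model file and can be memory-mapped like any other parameter.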
src/layers/lsh.h Normal file (49 lines added)
View File

@ -0,0 +1,49 @@
#pragma once
#include "graph/expression_operators.h"
#include "graph/node_initializers.h"
#include <vector>
/**
* In this file we basically take the faiss::IndexLSH and pick it apart so that the individual steps
* can be implemented as Marian inference operators. We can encode the inputs and weights into their
* bitwise equivalents, apply the hashing rotation (if required), and perform the actual search.
*
* This also makes it possible to create parameters that get dumped into the model weight file. This is currently
* a bit hacky (see marian-conv), but once this is done the model can memory-map the LSH with existing
* mechanisms and no additional memory is consumed to build the index or rotation matrix.
*/
namespace marian {
namespace lsh {
// return the number of full bytes required to encode that many bits
int bytesPerVector(int nBits);
// encodes an input as a bit vector, with optional rotation
Expr encode(Expr input, Expr rotator = nullptr);
// compute the rotation matrix (maps weights->shape()[-1] to nbits floats)
Expr rotator(Expr weights, int nbits);
// perform the LSH search on fully encoded input and weights, return k results (indices) per input row
// @TODO: add a top-k like operator that also returns the bitwise computed distances
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0);
// same as above, but performs encoding on the fly
Expr search(Expr query, Expr weights, int k, int nbits, int firstNRows = 0);
// These are helper functions for encoding the LSH into the binary Marian model, used by marian-conv
void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBits);
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName);
/**
* Computes a random rotation matrix for LSH hashing.
* This is part of a hash function. The values are orthonormal and computed via
* QR decomposition.
*/
Ptr<inits::NodeInitializer> randomRotation();
}
}

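
A rough usage sketch of the pieces declared here (query, weights and the sizes are placeholders; this roughly mirrors what lsh::search does internally when no pre-encoded parameters are found in the graph):

int nbits = 1024;
Expr rot     = lsh::rotator(weights, nbits);            // only needed when the embedding dimension != nbits
Expr wCodes  = lsh::encode(weights, rot);               // uint8 codes; memoized per graph via the external hash
Expr qCodes  = lsh::encode(query, rot);                 // encoded per query batch
Expr indices = lsh::searchEncoded(qCodes, wCodes, /*k=*/100); // uint32 indices, sorted per hypothesis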
View File

@ -11,6 +11,7 @@
#include "data/alignment.h"
#include "data/vocab_base.h"
#include "tensors/cpu/expression_graph_packable.h"
#include "layers/lsh.h"
#if USE_FBGEMM
#include "fbgemm/Utils.h"
@ -248,7 +249,7 @@ DecoderCpuAvxVersion parseCpuAvxVersion(std::string name) {
// This function converts an fp32 model into an FBGEMM based packed model.
// marian defined types are used for external project as well.
// The targetPrec is passed as int32_t for the exported function definition.
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec) {
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh) {
std::cerr << "Converting from: " << inputFile << ", to: " << outputFile << ", precision: " << targetPrec << std::endl;
YAML::Node config;
@ -260,7 +261,26 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
graph->setDevice(CPU0);
graph->load(inputFile);
graph->forward();
// MJD: Note, these are default settings which we might want to change or expose. Use this only with Polonium students.
// The LSH will not be used by default even if it exists in the model. That has to be enabled in the decoder config.
int lshNBits = 1024;
std::string lshOutputWeights = "Wemb";
if(addLsh) {
// Add dummy parameters for the LSH before the model gets actually initialized.
// This creates the parameters with useless values in the tensors, but it gives us the memory we need.
graph->setReloaded(false);
lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
graph->setReloaded(true);
}
graph->forward(); // run the initializers
if(addLsh) {
// After initialization, hijack the parameters for the LSH and force-overwrite them with correct values.
// Once this is done we can just pack and save as normal.
lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
}
Type targetPrecType = (Type) targetPrec;
if (targetPrecType == Type::packed16

View File

@ -76,7 +76,10 @@ std::vector<Ptr<IVocabWrapper>> loadVocabs(const std::vector<std::string>& vocab
DecoderCpuAvxVersion getCpuAvxVersion();
DecoderCpuAvxVersion parseCpuAvxVersion(std::string name);
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec);
// MJD: added "addLsh", which will now break compilation after the update. That's on purpose.
// The calling code should be adapted, not this interface. If you need to fix things in QS because of this,
// talk to me first!
bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec, bool addLsh);
} // namespace quicksand
} // namespace marian

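
A caller on the Quicksand side would then be adapted along these lines (the file names and the chosen precision are placeholders; the int32_t cast matches how targetPrec is interpreted inside convertModel):

bool ok = marian::quicksand::convertModel("model.npz", "model.bin",
                                          (int32_t)marian::Type::packed16,
                                          /*addLsh=*/true);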
View File

@ -27,14 +27,17 @@ public:
virtual ~ExpressionGraphPackable() {}
// Convert model weights into packed format and save to IO items.
// @TODO: review this
void packAndSave(const std::string& name, const std::string& meta, Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
std::vector<io::Item> pack(Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
std::vector<io::Item> ioItems;
// handle packable parameters first (a float32 parameter is packable)
auto packableParameters = paramsByElementType_[Type::float32];
// sorted by name in std::map
for (auto p : params()->getMap()) {
for (auto p : packableParameters->getMap()) {
std::string pName = p.first;
LOG(info, "Processing parameter {} with shape {} and type {}", pName, p.second->shape(), p.second->value_type());
if (!namespace_.empty()) {
if (pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
pName = pName.substr(namespace_.size() + 2);
@ -257,6 +260,33 @@ public:
}
}
// Now handle all non-float32 parameters
for(auto& iter : paramsByElementType_) {
auto type = iter.first;
if(type == Type::float32)
continue;
for (auto p : iter.second->getMap()) {
std::string pName = p.first;
LOG(info, "Processing parameter {} with shape {} and type {}", pName, p.second->shape(), p.second->value_type());
if (!namespace_.empty()) {
if (pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
pName = pName.substr(namespace_.size() + 2);
}
Tensor val = p.second->val();
io::Item item;
val->get(item, pName);
ioItems.emplace_back(std::move(item));
}
}
return ioItems;
}
void packAndSave(const std::string& name, const std::string& meta, Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
auto ioItems = pack(gemmElementType, saveElementType);
if (!meta.empty())
io::addMetaToItems(meta, "special:model.yml", ioItems);
io::saveItems(name, ioItems);

View File

@ -35,7 +35,8 @@ class TensorBase {
ENABLE_INTRUSIVE_PTR(TensorBase)
// Constructors are private, use TensorBase::New(...)
protected:
// Constructors are protected, use TensorBase::New(...)
TensorBase(MemoryPiece::PtrType memory,
Shape shape,
Type type,
@ -61,10 +62,10 @@ class TensorBase {
shape_(shape), type_(type), backend_(backend) {}
public:
// Use this whenever pointing to MemoryPiece
// Use this whenever pointing to TensorBase
typedef IPtr<TensorBase> PtrType;
// Use this whenever creating a pointer to MemoryPiece
// Use this whenever creating a pointer to TensorBase
template <class ...Args>
static PtrType New(Args&& ...args) {
return PtrType(new TensorBase(std::forward<Args>(args)...));

View File

@ -142,8 +142,9 @@ public:
// for periods.
bool enteredNewPeriodOf(std::string schedulingParam) const {
auto period = SchedulingParameter::parse(schedulingParam);
// @TODO: adapt to logical epochs
ABORT_IF(period.unit == SchedulingUnit::epochs,
"Unit {} is not supported for frequency parameters (the one(s) with value {})",
"Unit {} is not supported for frequency parameters",
schedulingParam);
auto previousProgress = getPreviousProgressIn(period.unit);
auto progress = getProgressIn(period.unit);