merge with internal master

This commit is contained in:
Marcin Junczys-Dowmunt 2022-02-11 06:03:16 -08:00
commit b0275e7754
23 changed files with 415 additions and 217 deletions

View File

@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Scripts using PyYAML now use `safe_load`; see https://msg.pyyaml.org/load
### Changed
- Make guided-alignment faster via sparse memory layout, add alignment points for EOS, remove losses other than ce.
- Changed minimal C++ standard to C++17
- Faster LSH top-k search on CPU
## [1.11.0] - 2022-02-08

View File

@ -6,7 +6,7 @@ if (POLICY CMP0074)
endif ()
project(marian CXX C)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(BUILD_ARCH native CACHE STRING "Compile for this CPU architecture.")
@ -91,10 +91,11 @@ if(MSVC)
# C4310: cast truncates constant value
# C4324: 'marian::cpu::int16::`anonymous-namespace'::ScatterPut': structure was padded due to alignment specifier
# C4702: unreachable code; note it is also disabled globally in the VS project file
# C4996: warning STL4015: The std::iterator class template (used as a base class to provide typedefs) is deprecated in C++17
if(USE_SENTENCEPIECE)
set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\" /wd\"4100\"")
set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\" /wd\"4996\" /wd\"4100\"")
else()
set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\"")
set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\" /wd\"4996\"")
endif()
# set(INTRINSICS "/arch:AVX")

View File

@ -1 +1 @@
v1.11.1
v1.11.3

View File

@ -192,6 +192,9 @@ stages:
displayName: Ubuntu
timeoutInMinutes: 90
# Minimal tested configurations for marian-dev v1.11 and C++17:
# * Ubuntu 16.04, GCC 7.5, CMake 3.10.2, CUDA 9.2 (probably GCC 6 would work too)
# * Ubuntu 18.04, GCC 7.5, CMake 3.12.2, CUDA 10.0
strategy:
matrix:
################################################################
@ -319,51 +322,6 @@ stages:
displayName: Print versions
workingDirectory: build
######################################################################
- job: BuildUbuntuMinimal
condition: eq(${{ parameters.runBuilds }}, true)
displayName: Ubuntu CPU+GPU gcc-7 cmake 3.5
pool:
vmImage: ubuntu-18.04
steps:
- checkout: self
submodules: true
# The script simplifies installation of different versions of CUDA.
- bash: ./scripts/ci/install_cuda_ubuntu.sh "10.0"
displayName: Install CUDA
# CMake 3.5.1 is the minimum version supported
- bash: |
wget -nv https://cmake.org/files/v3.5/cmake-3.5.1-Linux-x86_64.tar.gz
tar zxf cmake-3.5.1-Linux-x86_64.tar.gz
./cmake-3.5.1-Linux-x86_64/bin/cmake --version
displayName: Download CMake
# GCC 5 is the minimum version supported
- bash: |
/usr/bin/gcc-7 --version
mkdir -p build
cd build
CC=/usr/bin/gcc-7 CXX=/usr/bin/g++-7 CUDAHOSTCXX=/usr/bin/g++-7 \
../cmake-3.5.1-Linux-x86_64/bin/cmake .. \
-DCOMPILE_CPU=on \
-DCUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda-10.0
displayName: Configure CMake
- bash: make -j3
displayName: Compile
workingDirectory: build
- bash: |
./marian --version
./marian-decoder --version
./marian-scorer --version
displayName: Print versions
workingDirectory: build
######################################################################
- job: BuildMacOS
condition: eq(${{ parameters.runBuilds }}, true)

@ -1 +1 @@
Subproject commit 0ca966eadd2a4885a10d41e0f2f51445ab6fd038
Subproject commit 6d5921cc7de91f4e915b59e9c52c9a76c4e99b00

@ -1 +1 @@
Subproject commit f7971b790abac39e557346bd5907c693d4939778
Subproject commit d59f7ad85ecfdf4a788c095ac9fc1c447094e39e

View File

@ -60,6 +60,13 @@ CUDA_PACKAGES_IN=(
CUDA_PACKAGES=""
for package in "${CUDA_PACKAGES_IN[@]}"; do
# @todo This is not perfect. Should probably provide a separate list for diff versions
# cuda-compiler-X-Y if CUDA >= 9.1 else cuda-nvcc-X-Y
if [[ "${package}" == "nvcc" ]] && version_ge "$CUDA_VERSION_MAJOR_MINOR" "9.1" ; then
package="compiler"
elif [[ "${package}" == "compiler" ]] && version_lt "$CUDA_VERSION_MAJOR_MINOR" "9.1" ; then
package="nvcc"
fi
# Build the full package name and append to the string.
CUDA_PACKAGES+=" cuda-${package}-${CUDA_MAJOR}-${CUDA_MINOR}"
done
@ -72,8 +79,8 @@ echo "CUDA_PACKAGES ${CUDA_PACKAGES}"
PIN_FILENAME="cuda-ubuntu${UBUNTU_VERSION}.pin"
PIN_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/${PIN_FILENAME}"
APT_KEY_URL="http://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/7fa2af80.pub"
REPO_URL="http://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/"
APT_KEY_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/7fa2af80.pub"
REPO_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/"
echo "PIN_FILENAME ${PIN_FILENAME}"
echo "PIN_URL ${PIN_URL}"

View File

@ -344,7 +344,7 @@ inline HalfFloat operator+ (HalfFloat one, HalfFloat two)
// compute the difference between the two exponents. shifts with negative
// numbers are undefined, thus we need two code paths
register int expDiff = one.IEEE.Exp - two.IEEE.Exp;
/*register*/ int expDiff = one.IEEE.Exp - two.IEEE.Exp;
if (0 == expDiff)
{

View File

@ -86,11 +86,17 @@ int main(int argc, char** argv) {
graph->setDevice(CPU0);
graph->load(modelFrom);
std::vector<lsh::ParamConvInfo> toBeLSHed;
if(addLsh) {
// Add dummy parameters for the LSH before the model gets actually initialized.
// This creates the parameters with useless values in the tensors, but it gives us the memory we need.
toBeLSHed = {
{lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits}
};
graph->setReloaded(false);
lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
for(auto p : toBeLSHed)
lsh::addDummyParameters(graph, /*paramInfo=*/p);
graph->setReloaded(true);
}
@ -99,7 +105,8 @@ int main(int argc, char** argv) {
if(addLsh) {
// After initialization, hijack the parameters for the LSH and force-overwrite with correct values.
// Once this is done we can just pack and save as normal.
lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
for(auto p : toBeLSHed)
lsh::overwriteDummyParameters(graph, /*paramInfo=*/p);
}
// added a flag if the weights needs to be packed or not

View File

@ -510,7 +510,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"none");
cli.add<std::string>("--guided-alignment-cost",
"Cost type for guided alignment: ce (cross-entropy), mse (mean square error), mult (multiplication)",
"mse");
"ce");
cli.add<double>("--guided-alignment-weight",
"Weight for guided alignment cost",
0.1);

View File

@ -2,6 +2,8 @@
#include "common/utils.h"
#include <algorithm>
#include <cmath>
#include <set>
namespace marian {
namespace data {
@ -10,10 +12,11 @@ WordAlignment::WordAlignment() {}
WordAlignment::WordAlignment(const std::vector<Point>& align) : data_(align) {}
WordAlignment::WordAlignment(const std::string& line) {
WordAlignment::WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos) {
std::vector<std::string> atok = utils::splitAny(line, " -");
for(size_t i = 0; i < atok.size(); i += 2)
data_.emplace_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f });
data_.push_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f });
data_.push_back(Point{ srcEosPos, tgtEosPos, 1.f }); // add alignment point for both EOS symbols
}
void WordAlignment::sort() {
@ -22,6 +25,35 @@ void WordAlignment::sort() {
});
}
void WordAlignment::normalize(bool reverse/*=false*/) {
std::vector<size_t> counts;
counts.reserve(data_.size());
// reverse==false : normalize target word prob by number of source words
// reverse==true : normalize source word prob by number of target words
auto srcOrTgt = [](const Point& p, bool reverse) {
return reverse ? p.srcPos : p.tgtPos;
};
for(const auto& a : data_) {
size_t pos = srcOrTgt(a, reverse);
if(counts.size() <= pos)
counts.resize(pos + 1, 0);
counts[pos]++;
}
// a.prob at this point is either 1 or normalized to a different value,
// but we just set it to 1 / count, so multiple calls result in re-normalization
// regardless of forward or reverse direction. We also set the remaining values to 1.
for(auto& a : data_) {
size_t pos = srcOrTgt(a, reverse);
if(counts[pos] > 1)
a.prob = 1.f / counts[pos];
else
a.prob = 1.f;
}
}
std::string WordAlignment::toString() const {
std::stringstream str;
for(auto p = begin(); p != end(); ++p) {
@ -32,7 +64,7 @@ std::string WordAlignment::toString() const {
return str.str();
}
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft,
float threshold /*= 1.f*/) {
WordAlignment align;
// Alignments by maximum value
@ -58,7 +90,6 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
}
}
}
// Sort alignment pairs in ascending order
align.sort();

View File

@ -1,20 +1,22 @@
#pragma once
#include <sstream>
#include <tuple>
#include <vector>
namespace marian {
namespace data {
class WordAlignment {
struct Point
{
public:
struct Point {
size_t srcPos;
size_t tgtPos;
float prob;
};
private:
std::vector<Point> data_;
public:
WordAlignment();
@ -28,11 +30,14 @@ private:
public:
/**
* @brief Constructs word alignments from textual representation.
* @brief Constructs word alignments from textual representation. Adds alignment point for externally
* supplied EOS positions in source and target string.
*
* @param line String in the form of "0-0 1-1 1-2", etc.
*/
WordAlignment(const std::string& line);
WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos);
Point& operator[](size_t i) { return data_[i]; }
auto begin() const -> decltype(data_.begin()) { return data_.begin(); }
auto end() const -> decltype(data_.end()) { return data_.end(); }
@ -46,6 +51,12 @@ public:
*/
void sort();
/**
* @brief Normalizes alignment probabilities of target words to sum to 1 over source words alignments.
* This is needed for correct cost computation for guided alignment training with CE cost criterion.
*/
void normalize(bool reverse=false);
/**
* @brief Returns textual representation.
*/
@ -56,7 +67,7 @@ public:
// Also used on QuickSAND boundary where beam and batch size is 1. Then it is simply [t][s] -> P(s|t)
typedef std::vector<std::vector<float>> SoftAlignment; // [trg pos][beam depth * max src length * batch size]
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft,
float threshold = 1.f);
std::string SoftAlignToString(SoftAlignment align);

View File

@ -24,7 +24,7 @@ public:
const std::vector<size_t>& getSentenceIds() const { return sentenceIds_; }
void setSentenceIds(const std::vector<size_t>& ids) { sentenceIds_ = ids; }
virtual void setGuidedAlignment(std::vector<float>&&) = 0;
virtual void setGuidedAlignment(std::vector<WordAlignment>&&) = 0;
virtual void setDataWeights(const std::vector<float>&) = 0;
virtual ~Batch() {};
protected:

View File

@ -132,14 +132,13 @@ SentenceTuple Corpus::next() {
tup.markAltered();
addWordsToSentenceTuple(fields[i], vocabId, tup);
}
// weights are added last to the sentence tuple, because this runs a validation that needs
// length of the target sequence
if(alignFileIdx_ > -1)
addAlignmentToSentenceTuple(fields[alignFileIdx_], tup);
if(weightFileIdx_ > -1)
addWeightsToSentenceTuple(fields[weightFileIdx_], tup);
}
// weights are added last to the sentence tuple, because this runs a validation that needs
// length of the target sequence
if(alignFileIdx_ > -1)
addAlignmentToSentenceTuple(fields[alignFileIdx_], tup);
if(weightFileIdx_ > -1)
addWeightsToSentenceTuple(fields[weightFileIdx_], tup);
// check if all streams are valid, that is, non-empty and no longer than maximum allowed length
if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) {

View File

@ -429,11 +429,13 @@ void CorpusBase::addWordsToSentenceTuple(const std::string& line,
void CorpusBase::addAlignmentToSentenceTuple(const std::string& line,
SentenceTupleImpl& tup) const {
ABORT_IF(rightLeft_,
"Guided alignment and right-left model cannot be used "
"together at the moment");
ABORT_IF(rightLeft_, "Guided alignment and right-left model cannot be used together at the moment");
ABORT_IF(tup.size() != 2, "Using alignment between source and target, but sentence tuple has {} elements??", tup.size());
auto align = WordAlignment(line);
size_t srcEosPos = tup[0].size() - 1;
size_t tgtEosPos = tup[1].size() - 1;
auto align = WordAlignment(line, srcEosPos, tgtEosPos);
tup.setAlignment(align);
}
@ -457,22 +459,17 @@ void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupl
void CorpusBase::addAlignmentsToBatch(Ptr<CorpusBatch> batch,
const std::vector<Sample>& batchVector) {
int srcWords = (int)batch->front()->batchWidth();
int trgWords = (int)batch->back()->batchWidth();
std::vector<WordAlignment> aligns;
int dimBatch = (int)batch->getSentenceIds().size();
std::vector<float> aligns(srcWords * dimBatch * trgWords, 0.f);
aligns.reserve(dimBatch);
for(int b = 0; b < dimBatch; ++b) {
// If the batch vector is altered within marian by, for example, case augmentation,
// the guided alignments we received for this tuple cease to be valid.
// Hence skip setting alignments for that sentence tuple..
if (!batchVector[b].isAltered()) {
for(auto p : batchVector[b].getAlignment()) {
size_t idx = p.srcPos * dimBatch * trgWords + b * trgWords + p.tgtPos;
aligns[idx] = 1.f;
}
aligns.push_back(std::move(batchVector[b].getAlignment()));
}
}
batch->setGuidedAlignment(std::move(aligns));

View File

@ -338,7 +338,7 @@ public:
class CorpusBatch : public Batch {
protected:
std::vector<Ptr<SubBatch>> subBatches_;
std::vector<float> guidedAlignment_; // [max source len, batch size, max target len] flattened
std::vector<WordAlignment> guidedAlignment_; // [max source len, batch size, max target len] flattened
std::vector<float> dataWeights_;
public:
@ -444,8 +444,17 @@ public:
if(options->get("guided-alignment", std::string("none")) != "none") {
// @TODO: if > 1 encoder, verify that all encoders have the same sentence lengths
std::vector<float> alignment(batchSize * lengths.front() * lengths.back(),
0.f);
std::vector<data::WordAlignment> alignment;
for(size_t k = 0; k < batchSize; ++k) {
data::WordAlignment perSentence;
// fill with random alignment points; add twice as many points as there are target words to be safe.
for(size_t j = 0; j < lengths.back() * 2; ++j) {
size_t i = rand() % lengths.back();
perSentence.push_back(i, j, 1.0f);
}
alignment.push_back(std::move(perSentence));
}
batch->setGuidedAlignment(std::move(alignment));
}
@ -501,29 +510,14 @@ public:
}
if(!guidedAlignment_.empty()) {
size_t oldTrgWords = back()->batchWidth();
size_t oldSize = size();
pos = 0;
for(auto split : splits) {
auto cb = std::static_pointer_cast<CorpusBatch>(split);
size_t srcWords = cb->front()->batchWidth();
size_t trgWords = cb->back()->batchWidth();
size_t dimBatch = cb->size();
std::vector<float> aligns(srcWords * dimBatch * trgWords, 0.f);
for(size_t i = 0; i < dimBatch; ++i) {
size_t bi = i + pos;
for(size_t sid = 0; sid < srcWords; ++sid) {
for(size_t tid = 0; tid < trgWords; ++tid) {
size_t bidx = sid * oldSize * oldTrgWords + bi * oldTrgWords + tid; // [sid, bi, tid]
size_t idx = sid * dimBatch * trgWords + i * trgWords + tid;
aligns[idx] = guidedAlignment_[bidx];
}
}
}
cb->setGuidedAlignment(std::move(aligns));
std::vector<WordAlignment> batchAlignment;
for(size_t i = 0; i < dimBatch; ++i)
batchAlignment.push_back(std::move(guidedAlignment_[i + pos]));
cb->setGuidedAlignment(std::move(batchAlignment));
pos += dimBatch;
}
}
@ -556,15 +550,11 @@ public:
return splits;
}
const std::vector<float>& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened
void setGuidedAlignment(std::vector<float>&& aln) override {
const std::vector<WordAlignment>& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened
void setGuidedAlignment(std::vector<WordAlignment>&& aln) override {
guidedAlignment_ = std::move(aln);
}
size_t locateInGuidedAlignments(size_t b, size_t s, size_t t) {
return ((s * size()) + b) * widthTrg() + t;
}
std::vector<float>& getDataWeights() { return dataWeights_; }
void setDataWeights(const std::vector<float>& weights) override {
dataWeights_ = weights;

View File

@ -77,7 +77,7 @@ public:
size_t size() const override { return inputs_.front().shape()[0]; }
void setGuidedAlignment(std::vector<float>&&) override {
void setGuidedAlignment(std::vector<WordAlignment>&&) override {
ABORT("Guided alignment in DataBatch is not implemented");
}
void setDataWeights(const std::vector<float>&) override {

View File

@ -286,6 +286,8 @@ Expr operator/(float a, Expr b) {
/*********************************************************/
// Concatenates the given expressions along axis `ax`.
// If exactly one expression is passed in, it is handed back unchanged and no
// ConcatenateNodeOp is created for it.
Expr concatenate(const std::vector<Expr>& concats, int ax) {
  const bool needsConcatNode = concats.size() != 1;
  if(needsConcatNode)
    return Expression<ConcatenateNodeOp>(concats, ax);
  return concats[0];
}

View File

@ -5,62 +5,57 @@
namespace marian {
static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> /*graph*/,
// Converts the per-sentence guided alignments of a batch into a sparse (index, value) form.
//
// For every alignment point (srcPos, tgtPos) of sentence b the flat index
//   srcPos * dimBatch * trgWords + b * trgWords + tgtPos
// is computed, i.e. the position of that point in a flattened [src, batch, trg] tensor.
// The value stored with it is the point's probability after forward normalization
// (WordAlignment::normalize(/*reverse=*/false)).
//
// @param batch corpus batch carrying the guided alignments
// @return tuple of (flat indices sorted in ascending order, matching probabilities)
//
// NOTE(review): this reads getGuidedAlignment()[b] for every b < batch->size(), so it
// assumes the batch holds exactly one WordAlignment per sentence, in sentence order --
// confirm against the code that fills the guided alignments.
static inline const std::tuple<std::vector<IndexType>, std::vector<float>>
guidedAlignmentToSparse(Ptr<data::CorpusBatch> batch) {
  // Keep the extents unsigned: they are only used for index arithmetic together with
  // size_t positions and as loop bounds of size_t counters.
  const size_t trgWords = batch->back()->batchWidth();
  const size_t dimBatch = batch->size();

  typedef std::tuple<size_t, float> BiPoint;
  std::vector<BiPoint> byIndex;
  byIndex.reserve(dimBatch * trgWords);

  for(size_t b = 0; b < dimBatch; ++b) {
    // Deliberate copy: normalize() modifies the probabilities, the batch must stay untouched.
    auto guidedAlignmentFwd = batch->getGuidedAlignment()[b];
    guidedAlignmentFwd.normalize(/*reverse=*/false); // normalize forward
    for(size_t i = 0; i < guidedAlignmentFwd.size(); ++i) {
      const auto& pFwd = guidedAlignmentFwd[i];
      IndexType idx = (IndexType)(pFwd.srcPos * dimBatch * trgWords + b * trgWords + pFwd.tgtPos);
      byIndex.emplace_back(idx, pFwd.prob);
    }
  }

  // Order the points by their flat index.
  std::sort(byIndex.begin(), byIndex.end(),
            [](const BiPoint& a, const BiPoint& b) { return std::get<0>(a) < std::get<0>(b); });

  // Fill the two output vectors directly inside the result tuple so that returning
  // does not have to copy them.
  std::tuple<std::vector<IndexType>, std::vector<float>> sparse;
  auto& indices   = std::get<0>(sparse);
  auto& valuesFwd = std::get<1>(sparse);
  indices.reserve(byIndex.size());
  valuesFwd.reserve(byIndex.size());
  for(const auto& p : byIndex) {
    indices.push_back((IndexType)std::get<0>(p));
    valuesFwd.push_back(std::get<1>(p));
  }

  return sparse;
}
static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
Ptr<Options> options,
Expr attention) { // [beam depth=1, max src length, batch size, tgt length]
std::string guidedLossType = options->get<std::string>("guided-alignment-cost"); // @TODO: change "cost" to "loss"
// We dropped support for other losses which are not possible to implement with sparse labels.
// They were most likely not used anyway.
ABORT_IF(guidedLossType != "ce", "Only alignment loss type 'ce' is supported");
float guidedLossWeight = options->get<float>("guided-alignment-weight");
const auto& shape = attention->shape(); // [beam depth=1, max src length, batch size, tgt length]
float epsilon = 1e-6f;
Expr alignmentLoss; // sum up loss over all attention/alignment positions
size_t numLabels;
if(guidedLossType == "ce") {
// normalizedAlignment is multi-hot, but ce requires normalized probabilities, so need to normalize to P(s|t)
auto dimBatch = shape[-2];
auto dimTrgWords = shape[-1];
auto dimSrcWords = shape[-3];
ABORT_IF(shape[-4] != 1, "Guided alignments with beam??");
auto normalizedAlignment = batch->getGuidedAlignment(); // [dimSrcWords, dimBatch, dimTrgWords] flattened, matches shape of 'attention'
auto srcBatch = batch->front();
const auto& srcMask = srcBatch->mask();
ABORT_IF(shape.elements() != normalizedAlignment.size(), "Attention-matrix and alignment shapes differ??");
ABORT_IF(dimBatch != batch->size() || dimTrgWords != batch->widthTrg() || dimSrcWords != batch->width(), "Attention-matrix and batch shapes differ??");
auto locate = [=](size_t s, size_t b, size_t t) { return ((s * dimBatch) + b) * dimTrgWords + t; };
for (size_t b = 0; b < dimBatch; b++) {
for (size_t t = 0; t < dimTrgWords; t++) {
for (size_t s = 0; s < dimSrcWords; s++)
ABORT_IF(locate(s, b, t) != batch->locateInGuidedAlignments(b, s, t), "locate() and locateInGuidedAlignments() differ??");
// renormalize the alignment such that it sums up to 1
float sum = 0;
for (size_t s = 0; s < dimSrcWords; s++)
sum += srcMask[srcBatch->locate(b, s)] * normalizedAlignment[locate(s, b, t)]; // these values are 0 or 1
if (sum != 0 && sum != 1)
for (size_t s = 0; s < dimSrcWords; s++)
normalizedAlignment[locate(s, b, t)] /= sum;
}
}
auto alignment = constant_like(attention, std::move(normalizedAlignment));
alignmentLoss = -sum(flatten(alignment * log(attention + epsilon)));
numLabels = batch->back()->batchWords();
ABORT_IF(numLabels > shape.elements() / shape[-3], "Num labels of guided alignment cost is off??");
} else {
auto alignment = constant_like(attention, batch->getGuidedAlignment());
if(guidedLossType == "mse")
alignmentLoss = sum(flatten(square(attention - alignment))) / 2.f;
else if(guidedLossType == "mult") // @TODO: I don't know what this criterion is for. Can we remove it?
alignmentLoss = -log(sum(flatten(attention * alignment)) + epsilon);
else
ABORT("Unknown alignment cost type: {}", guidedLossType);
// every position is a label as they should all agree
// @TODO: there should be positional masking here ... on the other hand, positions that are not
// in a sentence should always agree (both being 0). Lack of masking affects label count only which is
// probably negligible?
numLabels = shape.elements();
}
auto [indices, values] = guidedAlignmentToSparse(batch);
auto alignmentIndices = graph->indices(indices);
auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values));
auto attentionAtAligned = cols(flatten(attention), alignmentIndices);
float epsilon = 1e-6f;
Expr alignmentLoss = -sum(alignmentValues * log(attentionAtAligned + epsilon));
size_t numLabels = alignmentIndices->shape().elements();
// Create label node, also weigh by scalar so labels and cost are in the same domain.
// Fractional label counts are OK. But only if combined as "sum".
// @TODO: It is ugly to check the multi-loss type here, but doing this right requires

View File

@ -3,12 +3,14 @@
#include "common/utils.h"
#include "3rd_party/faiss/utils/hamming.h"
#include "3rd_party/faiss/Index.h"
#if BLAS_FOUND
#include "3rd_party/faiss/VectorTransform.h"
#endif
#include "common/timer.h"
#include "layers/lsh_impl.h"
namespace marian {
namespace lsh {
@ -98,24 +100,22 @@ Expr encode(Expr input, Expr rotation) {
return lambda(inputs, encodedShape, Type::uint8, encodeFwd, encodeHash);
}
Expr rotator(Expr weights, int nBits) {
Expr rotator(Expr weights, int inDim, int nBits) {
auto rotator = [](Expr out, const std::vector<Expr>& inputs) {
inputs;
fillRandomRotationMatrix(out->val(), out->graph()->allocator());
};
static const size_t rotatorHash = (size_t)&rotator;
int dim = weights->shape()[-1];
return lambda({weights}, {dim, nBits}, Type::float32, rotator, rotatorHash);
return lambda({weights}, {inDim, nBits}, Type::float32, rotator, rotatorHash);
}
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows) {
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNRows, bool noSort/*= false*/) {
ABORT_IF(encodedQuery->shape()[-1] != encodedWeights->shape()[-1],
"Query and index bit vectors need to be of same size ({} != {})", encodedQuery->shape()[-1], encodedWeights->shape()[-1]);
int currBeamSize = encodedQuery->shape()[0];
int batchSize = encodedQuery->shape()[2];
int numHypos = currBeamSize * batchSize;
auto search = [=](Expr out, const std::vector<Expr>& inputs) {
Expr encodedQuery = inputs[0];
@ -128,30 +128,25 @@ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows
if(firstNRows != 0)
wRows = firstNRows;
int qRows = encodedQuery->shape().elements() / bytesPerVector;
ABORT_IF(dimK > wRows, "k is larger than number of candidate values?"); // @TODO: use min(k, wRows) silently?
uint8_t* qCodes = encodedQuery->val()->data<uint8_t>();
uint8_t* wCodes = encodedWeights->val()->data<uint8_t>();
IndexType* outData = out->val()->data<IndexType>();
auto gather = [outData, dimK](IndexType rowId, IndexType k, IndexType kthColId, DistType /*dist*/) {
outData[rowId * dimK + k] = kthColId;
};
// use actual faiss code for performing the hamming search.
std::vector<int> distances(qRows * k);
std::vector<faiss::Index::idx_t> ids(qRows * k);
faiss::int_maxheap_array_t res = {(size_t)qRows, (size_t)k, ids.data(), distances.data()};
faiss::hammings_knn_hc(&res, qCodes, wCodes, (size_t)wRows, (size_t)bytesPerVector, 0);
Parameters params;
params.k = dimK;
params.queryRows = encodedQuery->val()->data<uint8_t>();
params.numQueryRows = encodedQuery->shape().elements() / bytesPerVector;
params.codeRows = encodedWeights->val()->data<uint8_t>();
params.numCodeRows = wRows;
params.bytesPerVector = bytesPerVector;
// Copy int64_t indices to Marian index type and sort by increasing index value per hypothesis.
// The sorting is required as we later do a binary search on those values for reverse look-up.
uint32_t* outData = out->val()->data<uint32_t>();
for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
size_t startIdx = k * hypoIdx;
size_t endIdx = startIdx + k;
for(size_t i = startIdx; i < endIdx; ++i)
outData[i] = (uint32_t)ids[i];
std::sort(outData + startIdx, outData + endIdx);
}
hammingTopK(params, gather);
};
Shape kShape({currBeamSize, batchSize, k});
Shape kShape({currBeamSize, batchSize, dimK});
return lambda({encodedQuery, encodedWeights}, kShape, Type::uint32, search);
}
@ -166,7 +161,7 @@ Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows, bool abo
} else {
ABORT_IF(abortIfDynamic, "Dynamic creation of LSH rotation matrix prohibited");
LOG_ONCE(info, "Creating ad-hoc rotation matrix with shape {}", Shape({dim, nBits}));
rotMat = rotator(weights, nBits);
rotMat = rotator(weights, dim, nBits);
}
}
@ -195,34 +190,43 @@ Ptr<inits::NodeInitializer> randomRotation() {
return New<RandomRotation>();
}
void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBitsRot) {
auto weights = graph->get(weightsName);
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
void addDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo) {
auto weights = graph->get(paramInfo.name);
int nBitsRot = paramInfo.nBits;
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name);
int nBits = weights->shape()[-1];
if(paramInfo.transpose)
nBits = weights->shape()[-2];
int nRows = weights->shape().elements() / nBits;
Expr rotation;
if(nBits != nBitsRot) {
LOG(info, "Adding LSH rotation parameter lsh_output_rotation with shape {}", Shape({nBits, nBitsRot}));
rotation = graph->param("lsh_output_rotation", {nBits, nBitsRot}, inits::dummy(), Type::float32);
LOG(info, "Adding LSH rotation parameter {} with shape {}", paramInfo.rotationName, Shape({nBits, nBitsRot}));
rotation = graph->param(paramInfo.rotationName, {nBits, nBitsRot}, inits::dummy(), Type::float32);
nBits = nBitsRot;
}
int bytesPerVector = lsh::bytesPerVector(nBits);
LOG(info, "Adding LSH encoded weights lsh_output_codes with shape {}", Shape({nRows, bytesPerVector}));
auto codes = graph->param("lsh_output_codes", {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
LOG(info, "Adding LSH encoded weights {} with shape {}", paramInfo.codesName, Shape({nRows, bytesPerVector}));
auto codes = graph->param(paramInfo.codesName, {nRows, bytesPerVector}, inits::dummy(), Type::uint8);
}
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName) {
Expr weights = graph->get(weightsName);
Expr codes = graph->get("lsh_output_codes");
Expr rotation = graph->get("lsh_output_rotation");
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo) {
Expr weights = graph->get(paramInfo.name);
Expr codes = graph->get(paramInfo.codesName);
Expr rotation = graph->get(paramInfo.rotationName);
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName);
ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name);
ABORT_IF(!codes, "Trying to overwrite non-existing LSH parameters lsh_output_codes??");
if(paramInfo.transpose) {
weights = transpose(weights);
graph->forward();
}
if(rotation) {
fillRandomRotationMatrix(rotation->val(), weights->graph()->allocator());
encodeWithRotation(codes->val(), weights->val(), rotation->val(), weights->graph()->allocator());

View File

@ -17,26 +17,34 @@
namespace marian {
namespace lsh {
// return the number of full bytes required to encode that many bits
int bytesPerVector(int nBits);
// encodes an input as a bit vector, with optional rotation
Expr encode(Expr input, Expr rotator = nullptr);
// compute the rotation matrix (maps weights->shape()[-1] to nbits floats)
Expr rotator(Expr weights, int nbits);
Expr rotator(Expr weights, int inDim, int nbits);
// perform the LSH search on fully encoded input and weights, return k results (indices) per input row
// @TODO: add a top-k like operator that also returns the bitwise computed distances
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0);
Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0, bool noSort = false);
// same as above, but performs encoding on the fly
Expr search(Expr query, Expr weights, int k, int nbits, int firstNRows = 0, bool abortIfDynamic = false);
// Describes one weight matrix that marian-conv equips with LSH parameters.
// It is consumed by addDummyParameters() and overwriteDummyParameters().
struct ParamConvInfo {
  std::string name;          // name of the weight matrix looked up in the graph
  std::string codesName;     // name of the parameter holding the LSH-encoded weights
  std::string rotationName;  // name of the parameter holding the LSH rotation matrix
  int nBits;                 // number of bits of the LSH hash
  bool transpose;            // whether the weight matrix is to be treated as transposed

  ParamConvInfo(const std::string& weightsName,
                const std::string& lshCodesName,
                const std::string& lshRotationName,
                int numBits,
                bool transposeWeights = false)
    : name(weightsName),
      codesName(lshCodesName),
      rotationName(lshRotationName),
      nBits(numBits),
      transpose(transposeWeights) {}
};
// These are helper functions for encoding the LSH into the binary Marian model, used by marian-conv
void addDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName, int nBits);
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, std::string weightsName);
void addDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo);
void overwriteDummyParameters(Ptr<ExpressionGraph> graph, ParamConvInfo paramInfo);
/**
* Computes a random rotation matrix for LSH hashing.

186
src/layers/lsh_impl.h Normal file
View File

@ -0,0 +1,186 @@
#pragma once
#include <vector>
#ifdef _MSC_VER
#define __builtin_popcountl __popcnt64
#define __builtin_popcount __popcnt
#endif
namespace marian {
namespace lsh {
// Inputs of the Hamming top-k search over bit-packed codes.
// The pointers are not owned by this struct.
struct Parameters {
int k;               // number of results to gather per query row
uint8_t* queryRows;  // start of the bit-packed query rows
int numQueryRows;    // number of query rows
uint8_t* codeRows;   // start of the bit-packed code rows the queries are compared against
int numCodeRows;     // number of code rows
int bytesPerVector;  // size of a single row in bytes (the same for query and code rows)
};
typedef uint32_t DistType;   // type of a Hamming distance (a bit count)
typedef uint64_t ChunkType;  // machine word in which bit vectors are processed

// Returns the number of set bits in `chunk`.
//
// The intrinsic is chosen by the compile-time width of ChunkType. For 8-byte chunks the
// `long long` variant is used on GCC/Clang: `__builtin_popcountl` takes `unsigned long`,
// which is only 32 bits wide on some targets and would then ignore the upper half of
// the chunk. An unsupported chunk width is rejected at compile time.
inline DistType popcount(const ChunkType& chunk) {
  static_assert(sizeof(ChunkType) == 8 || sizeof(ChunkType) == 4,
                "ChunkType has to be 4 or 8 bytes wide");
#ifdef _MSC_VER
  if constexpr (sizeof(ChunkType) == 8)
    return (DistType)__popcnt64((uint64_t)chunk);
  else
    return (DistType)__popcnt((uint32_t)chunk);
#else
  if constexpr (sizeof(ChunkType) == 8)
    return (DistType)__builtin_popcountll((unsigned long long)chunk);
  else
    return (DistType)__builtin_popcount((unsigned int)chunk);
#endif
}
// return the number of full bytes required to encode that many bits
inline int bytesPerVector(int nBits);
// compute top-k hamming distances for given query and weight binary codes. Faster than FAISS version, especially for larger k nearly constant wrt. k.
// Selects between a compile-time constant and a run-time value:
// returns `dynamicValue` if Dynamic is true, otherwise StaticValue converted to T.
template <int StaticValue = 0, bool Dynamic=true, typename T>
inline constexpr T getStaticOrDynamic(T dynamicValue) {
  if constexpr (Dynamic)
    return dynamicValue;
  else
    return StaticValue;
}
// Hamming distance between two bit vectors stored as arrays of ChunkType.
// The number of chunks to compare is StepsStatic, or `stepsDynamic` if Dynamic is true.
template <size_t StepsStatic, bool Dynamic=false>
inline DistType hamming(ChunkType* queryRow, ChunkType* codeRow, int stepsDynamic = 0) {
  static_assert(Dynamic == true || StepsStatic != 0, "Either define dynamic use of steps or provide non-zero template argument");

  const int numSteps = getStaticOrDynamic<StepsStatic, Dynamic>(stepsDynamic);

  // Differing bits are exactly the set bits of the XOR of both chunks.
  DistType differingBits = 0;
  for(int step = 0; step < numSteps; ++step) {
    const ChunkType diff = queryRow[step] ^ codeRow[step];
    differingBits += popcount(diff);
  }
  return differingBits;
}
template <int warpSize, int NumCodeRows, int BytesPerVector, bool Dynamic, class Functor>
inline void hammingTopKUnrollWarp(int queryOffset, const Parameters& parameters, const Functor& gather) {
const int numBits = getStaticOrDynamic<BytesPerVector, Dynamic>(parameters.bytesPerVector) * 8;
ABORT_IF(numBits % 64 != 0, "LSH hash size has to be a multiple of 64");
// counter to keep track of seen hamming distances
std::vector<std::vector<DistType>> counter(warpSize, std::vector<DistType>(numBits, 0));
// buffer the distances for query vector warpRowId to all weight weight vectors codeRowId
std::vector<std::vector<DistType>> distBuffer(warpSize, std::vector<DistType>(getStaticOrDynamic<NumCodeRows, Dynamic>(parameters.numCodeRows), 0));
// minimal distances per query
std::vector<DistType> minDist(warpSize);
constexpr int StepStatic = BytesPerVector / sizeof(ChunkType);
int stepDynamic = parameters.bytesPerVector / sizeof(ChunkType);
ChunkType* codeRow = (ChunkType*)parameters.codeRows;
for(int warpRowId = 0; warpRowId < warpSize; warpRowId++) {
std::memset(counter[warpRowId].data(), 0, numBits * sizeof(DistType)); // Reset the counter via memset to 0
minDist[warpRowId] = (DistType)numBits;
}
for(IndexType codeRowId = 0; codeRowId < (IndexType)getStaticOrDynamic<NumCodeRows, Dynamic>(parameters.numCodeRows); ++codeRowId, codeRow += getStaticOrDynamic<StepStatic, Dynamic>(stepDynamic)) {
ChunkType* queryRow = (ChunkType*)parameters.queryRows;
for(IndexType warpRowId = 0; warpRowId < warpSize; warpRowId++, queryRow += getStaticOrDynamic<StepStatic, Dynamic>(stepDynamic)) {
// Compute the bit-wise hamming distance
DistType dist = hamming<StepStatic, Dynamic>(queryRow, codeRow, stepDynamic);
// Record the minimal distance seen for this query vector wrt. all weight vectors
if(dist < minDist[warpRowId]) {
minDist[warpRowId] = dist;
}
// Record the number of weight vectors that have this distance from the query vector.
// Note, because there is at most numBits different distances this can be trivially done.
// Not the case for generic distances like float.
counter[warpRowId][dist]++;
// Record the distance for this weight vector
distBuffer[warpRowId][codeRowId] = dist;
}
}
// warp finished, harvest k top distances
for(int warpRowId = 0; warpRowId < warpSize; warpRowId++) {
// Here we search for the distance at which we have seen equal or more than k elements with
// smaller distances. We start with the minimal distance from above which is its own address
// to the counter.
DistType maxDist = minDist[warpRowId];
size_t cummulativeDistances = 0;
// Accumulate number of elements until we reach k in growing distance order. Note that
// counter is indexed by hamming distance - from lowest to highest. Some slots will be 0.
// The cumulative sum from position a to b tells you how many elements have distances smaller
// than the distance at b.
while(cummulativeDistances < parameters.k)
cummulativeDistances += counter[warpRowId][maxDist++];
if(cummulativeDistances)
maxDist--; // fix overcounting
// Usually, we overshoot by a couple of elements and we need to take care of the distance at which the k-th
// element sits. This elements has more neighbors at the same distance, but we only care for them
// as long we have not reached k elements in total.
// By contrast, we trivially collect all elements below that distance -- these are always safe.
// This is the number of elements we need to collect at the last distance.
DistType maxDistLimit = /*number of elements at maxDist=*/counter[warpRowId][maxDist] - /*overflow=*/((DistType)cummulativeDistances - (DistType)parameters.k);
IndexType kSeen = 0;
IndexType kSeenAtKDist = 0;
for(IndexType codeRowId = 0; kSeen < (IndexType)parameters.k && codeRowId < (IndexType)getStaticOrDynamic<NumCodeRows, Dynamic>(parameters.numCodeRows); ++codeRowId) {
DistType dist = distBuffer[warpRowId][codeRowId];
// - if the current distance is smaller than the maxDist, just consume.
// - if the distance is equal to maxDist, make sure to only consume maxDistLimit elements at maxDist
// and ignore the rest (smaller indices make it in first).
// - after we finish this loop we have exactly k top values for every query row in original index order.
int queryRowId = queryOffset + warpRowId;
if(dist < maxDist) {
gather(queryRowId, (IndexType)kSeen, codeRowId, dist);
kSeen++;
} else if(dist == maxDist && kSeenAtKDist < (DistType)maxDistLimit) {
gather(queryRowId, (IndexType)kSeen, codeRowId, dist);
kSeen++;
kSeenAtKDist++;
}
}
}
}
// Faster top-k search for hamming distance. The idea here is that instead of sorting the elements we find a hamming distances at which it is safe
// to copy the given index. Copying only the indices below that distance is guaranteed to results in no more than k elements. For elements at that
// distance we need to correct for overshooting.
// Once we have that distance we only need to traverse the set of distances. In the end we get exactly k elements per queryRows vector.
template <int NumCodeRows, int BytesPerVector, bool Dynamic, class Functor>
inline void hammingTopKUnroll(const Parameters& parameters, const Functor& gather) {
static_assert(Dynamic == true || (NumCodeRows != 0 && BytesPerVector != 0), "Either define dynamic use of variables or provide non-zero template arguments");
int warpSize = 4; // starting warpSize of 4 seems optimal
auto warpParameters = parameters;
for(int queryOffset = 0; queryOffset < parameters.numQueryRows; queryOffset += warpSize) {
while(parameters.numQueryRows - queryOffset < warpSize)
warpSize /= 2;
int step = getStaticOrDynamic<BytesPerVector, Dynamic>(parameters.bytesPerVector);
warpParameters.queryRows = parameters.queryRows + queryOffset * step;
warpParameters.numQueryRows = warpSize;
switch(warpSize) {
case 8 : hammingTopKUnrollWarp<8, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
case 4 : hammingTopKUnrollWarp<4, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
case 2 : hammingTopKUnrollWarp<2, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
case 1 : hammingTopKUnrollWarp<1, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break;
default: ABORT("Unhandled warpSize = {}??", warpSize);
}
}
}
// Entry point of the hamming top-k search. Selects a specialization with compile-time sizes for a
// fixed set of (numCodeRows, bytesPerVector) combinations and falls back to the version that reads
// the sizes at run time for everything else.
template <class Functor>
inline void hammingTopK(const Parameters& parameters, const Functor& gather) {
  const int codeRows = parameters.numCodeRows;
  const int codeBytes = parameters.bytesPerVector;

  if(codeBytes == 64) {
    switch(codeRows) {
      case 2048 : hammingTopKUnroll< 2048,  64, false>(parameters, gather); return;
      case 4096 : hammingTopKUnroll< 4096,  64, false>(parameters, gather); return;
      case 6144 : hammingTopKUnroll< 6144,  64, false>(parameters, gather); return;
      case 8192 : hammingTopKUnroll< 8192,  64, false>(parameters, gather); return;
      case 32000: hammingTopKUnroll<32000,  64, false>(parameters, gather); return;
      default: break; // no specialization for this size, use the dynamic version below
    }
  } else if(codeBytes == 128 && codeRows == 32000) {
    hammingTopKUnroll<32000, 128, false>(parameters, gather);
    return;
  }

  hammingTopKUnroll<    0,   0, true>(parameters, gather);
}
} // namespace lsh
} // namespace marian

View File

@ -178,8 +178,7 @@ public:
auto score = std::get<2>(result);
// determine alignment if present
AlignmentSets alignmentSets;
if (options_->hasAndNotEmpty("alignment"))
{
if (options_->hasAndNotEmpty("alignment")) {
float alignmentThreshold;
auto alignment = options_->get<std::string>("alignment"); // @TODO: this logic now exists three times in Marian
if (alignment == "soft")
@ -287,7 +286,7 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
// Add dummy parameters for the LSH before the model gets actually initialized.
// This creates the parameters with useless values in the tensors, but it gives us the memory we need.
graph->setReloaded(false);
lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits);
lsh::addDummyParameters(graph, /*paramInfo=*/{lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits});
graph->setReloaded(true);
}
@ -296,7 +295,7 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
if(addLsh) {
// After initialization, hijack the parameters for the LSH and force-overwrite with correct values.
// Once this is done we can just pack and save as normal.
lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights);
lsh::overwriteDummyParameters(graph, /*paramInfo=*/{lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits});
}
Type targetPrecType = (Type) targetPrec;