diff --git a/CHANGELOG.md b/CHANGELOG.md index 90f913c4..4e2a40d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Scripts using PyYAML now use `safe_load`; see https://msg.pyyaml.org/load ### Changed +- Make guided-alignment faster via sparse memory layout, add alignment points for EOS, remove losses other than ce. +- Changed minimal C++ standard to C++-17 +- Faster LSH top-k search on CPU ## [1.11.0] - 2022-02-08 diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c41b365..dbad75cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,7 +6,7 @@ if (POLICY CMP0074) endif () project(marian CXX C) -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(BUILD_ARCH native CACHE STRING "Compile for this CPU architecture.") @@ -91,10 +91,11 @@ if(MSVC) # C4310: cast truncates constant value # C4324: 'marian::cpu::int16::`anonymous-namespace'::ScatterPut': structure was padded due to alignment specifier # C4702: unreachable code; note it is also disabled globally in the VS project file + # C4996: warning STL4015: The std::iterator class template (used as a base class to provide typedefs) is deprecated in C++17 if(USE_SENTENCEPIECE) - set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\" /wd\"4100\"") + set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\" /wd\"4996\" /wd\"4100\"") else() - set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\"") + set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\" /wd\"4996\"") endif() # set(INTRINSICS "/arch:AVX") diff --git a/VERSION b/VERSION index 65b4811d..3d461ead 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v1.11.1 +v1.11.3 diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 0348ebb4..f5e92400 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -192,6 +192,9 @@ stages: displayName: Ubuntu timeoutInMinutes: 90 + # Minimal tested configurations for marian-dev v1.11 and C++17: + # * Ubuntu 16.04, GCC 7.5, CMake 3.10.2, CUDA 9.2 (probably GCC 6 would work too) + # * Ubuntu 18.04, GCC 7.5, CMake 3.12.2, CUDA 10.0 strategy: matrix: ################################################################ @@ -319,51 +322,6 @@ stages: displayName: Print versions workingDirectory: build - ###################################################################### - - job: BuildUbuntuMinimal - condition: eq(${{ parameters.runBuilds }}, true) - displayName: Ubuntu CPU+GPU gcc-7 cmake 3.5 - - pool: - vmImage: ubuntu-18.04 - - steps: - - checkout: self - submodules: true - - # The script simplifies installation of different versions of CUDA. - - bash: ./scripts/ci/install_cuda_ubuntu.sh "10.0" - displayName: Install CUDA - - # CMake 3.5.1 is the minimum version supported - - bash: | - wget -nv https://cmake.org/files/v3.5/cmake-3.5.1-Linux-x86_64.tar.gz - tar zxf cmake-3.5.1-Linux-x86_64.tar.gz - ./cmake-3.5.1-Linux-x86_64/bin/cmake --version - displayName: Download CMake - - # GCC 5 is the minimum version supported - - bash: | - /usr/bin/gcc-7 --version - mkdir -p build - cd build - CC=/usr/bin/gcc-7 CXX=/usr/bin/g++-7 CUDAHOSTCXX=/usr/bin/g++-7 \ - ../cmake-3.5.1-Linux-x86_64/bin/cmake .. \ - -DCOMPILE_CPU=on \ - -DCUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda-10.0 - displayName: Configure CMake - - - bash: make -j3 - displayName: Compile - workingDirectory: build - - - bash: | - ./marian --version - ./marian-decoder --version - ./marian-scorer --version - displayName: Print versions - workingDirectory: build - ###################################################################### - job: BuildMacOS condition: eq(${{ parameters.runBuilds }}, true) diff --git a/examples b/examples index 0ca966ea..6d5921cc 160000 --- a/examples +++ b/examples @@ -1 +1 @@ -Subproject commit 0ca966eadd2a4885a10d41e0f2f51445ab6fd038 +Subproject commit 6d5921cc7de91f4e915b59e9c52c9a76c4e99b00 diff --git a/regression-tests b/regression-tests index f7971b79..d59f7ad8 160000 --- a/regression-tests +++ b/regression-tests @@ -1 +1 @@ -Subproject commit f7971b790abac39e557346bd5907c693d4939778 +Subproject commit d59f7ad85ecfdf4a788c095ac9fc1c447094e39e diff --git a/scripts/ci/install_cuda_ubuntu.sh b/scripts/ci/install_cuda_ubuntu.sh index b058294a..de60a5b6 100755 --- a/scripts/ci/install_cuda_ubuntu.sh +++ b/scripts/ci/install_cuda_ubuntu.sh @@ -60,6 +60,13 @@ CUDA_PACKAGES_IN=( CUDA_PACKAGES="" for package in "${CUDA_PACKAGES_IN[@]}"; do + # @todo This is not perfect. Should probably provide a separate list for diff versions + # cuda-compiler-X-Y if CUDA >= 9.1 else cuda-nvcc-X-Y + if [[ "${package}" == "nvcc" ]] && version_ge "$CUDA_VERSION_MAJOR_MINOR" "9.1" ; then + package="compiler" + elif [[ "${package}" == "compiler" ]] && version_lt "$CUDA_VERSION_MAJOR_MINOR" "9.1" ; then + package="nvcc" + fi # Build the full package name and append to the string. CUDA_PACKAGES+=" cuda-${package}-${CUDA_MAJOR}-${CUDA_MINOR}" done @@ -72,8 +79,8 @@ echo "CUDA_PACKAGES ${CUDA_PACKAGES}" PIN_FILENAME="cuda-ubuntu${UBUNTU_VERSION}.pin" PIN_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/${PIN_FILENAME}" -APT_KEY_URL="http://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/7fa2af80.pub" -REPO_URL="http://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/" +APT_KEY_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/7fa2af80.pub" +REPO_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/" echo "PIN_FILENAME ${PIN_FILENAME}" echo "PIN_URL ${PIN_URL}" diff --git a/src/3rd_party/half_float/umHalf.inl b/src/3rd_party/half_float/umHalf.inl index 3f5285a2..257ba1c2 100644 --- a/src/3rd_party/half_float/umHalf.inl +++ b/src/3rd_party/half_float/umHalf.inl @@ -344,7 +344,7 @@ inline HalfFloat operator+ (HalfFloat one, HalfFloat two) // compute the difference between the two exponents. shifts with negative // numbers are undefined, thus we need two code paths - register int expDiff = one.IEEE.Exp - two.IEEE.Exp; + /*register*/ int expDiff = one.IEEE.Exp - two.IEEE.Exp; if (0 == expDiff) { diff --git a/src/command/marian_conv.cpp b/src/command/marian_conv.cpp index 943f61d4..b4a5f374 100644 --- a/src/command/marian_conv.cpp +++ b/src/command/marian_conv.cpp @@ -86,11 +86,17 @@ int main(int argc, char** argv) { graph->setDevice(CPU0); graph->load(modelFrom); + std::vector toBeLSHed; if(addLsh) { // Add dummy parameters for the LSH before the model gets actually initialized. // This create the parameters with useless values in the tensors, but it gives us the memory we need. + toBeLSHed = { + {lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits} + }; + graph->setReloaded(false); - lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits); + for(auto p : toBeLSHed) + lsh::addDummyParameters(graph, /*paramInfo=*/p); graph->setReloaded(true); } @@ -99,7 +105,8 @@ int main(int argc, char** argv) { if(addLsh) { // After initialization, hijack the paramters for the LSH and force-overwrite with correct values. // Once this is done we can just pack and save as normal. - lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights); + for(auto p : toBeLSHed) + lsh::overwriteDummyParameters(graph, /*paramInfo=*/p); } // added a flag if the weights needs to be packed or not diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 837bee53..0d956495 100644 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -510,7 +510,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) { "none"); cli.add("--guided-alignment-cost", "Cost type for guided alignment: ce (cross-entropy), mse (mean square error), mult (multiplication)", - "mse"); + "ce"); cli.add("--guided-alignment-weight", "Weight for guided alignment cost", 0.1); diff --git a/src/data/alignment.cpp b/src/data/alignment.cpp index 928beb21..3b7e0d66 100644 --- a/src/data/alignment.cpp +++ b/src/data/alignment.cpp @@ -2,6 +2,8 @@ #include "common/utils.h" #include +#include +#include namespace marian { namespace data { @@ -10,10 +12,11 @@ WordAlignment::WordAlignment() {} WordAlignment::WordAlignment(const std::vector& align) : data_(align) {} -WordAlignment::WordAlignment(const std::string& line) { +WordAlignment::WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos) { std::vector atok = utils::splitAny(line, " -"); for(size_t i = 0; i < atok.size(); i += 2) - data_.emplace_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f }); + data_.push_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f }); + data_.push_back(Point{ srcEosPos, tgtEosPos, 1.f }); // add alignment point for both EOS symbols } void WordAlignment::sort() { @@ -22,6 +25,35 @@ void WordAlignment::sort() { }); } +void WordAlignment::normalize(bool reverse/*=false*/) { + std::vector counts; + counts.reserve(data_.size()); + + // reverse==false : normalize target word prob by number of source words + // reverse==true : normalize source word prob by number of target words + auto srcOrTgt = [](const Point& p, bool reverse) { + return reverse ? p.srcPos : p.tgtPos; + }; + + for(const auto& a : data_) { + size_t pos = srcOrTgt(a, reverse); + if(counts.size() <= pos) + counts.resize(pos + 1, 0); + counts[pos]++; + } + + // a.prob at this point is either 1 or normalized to a different value, + // but we just set it to 1 / count, so multiple calls result in re-normalization + // regardless of forward or reverse direction. We also set the remaining values to 1. + for(auto& a : data_) { + size_t pos = srcOrTgt(a, reverse); + if(counts[pos] > 1) + a.prob = 1.f / counts[pos]; + else + a.prob = 1.f; + } +} + std::string WordAlignment::toString() const { std::stringstream str; for(auto p = begin(); p != end(); ++p) { @@ -32,7 +64,7 @@ std::string WordAlignment::toString() const { return str.str(); } -WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, +WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft, float threshold /*= 1.f*/) { WordAlignment align; // Alignments by maximum value @@ -58,7 +90,6 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, } } } - // Sort alignment pairs in ascending order align.sort(); diff --git a/src/data/alignment.h b/src/data/alignment.h index 1c68bb39..f27bea38 100644 --- a/src/data/alignment.h +++ b/src/data/alignment.h @@ -1,20 +1,22 @@ #pragma once #include +#include #include namespace marian { namespace data { class WordAlignment { - struct Point - { +public: + struct Point { size_t srcPos; size_t tgtPos; float prob; }; private: std::vector data_; + public: WordAlignment(); @@ -28,11 +30,14 @@ private: public: /** - * @brief Constructs word alignments from textual representation. + * @brief Constructs word alignments from textual representation. Adds alignment point for externally + * supplied EOS positions in source and target string. * * @param line String in the form of "0-0 1-1 1-2", etc. */ - WordAlignment(const std::string& line); + WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos); + + Point& operator[](size_t i) { return data_[i]; } auto begin() const -> decltype(data_.begin()) { return data_.begin(); } auto end() const -> decltype(data_.end()) { return data_.end(); } @@ -46,6 +51,12 @@ public: */ void sort(); + /** + * @brief Normalizes alignment probabilities of target words to sum to 1 over source words alignments. + * This is needed for correct cost computation for guided alignment training with CE cost criterion. + */ + void normalize(bool reverse=false); + /** * @brief Returns textual representation. */ @@ -56,7 +67,7 @@ public: // Also used on QuickSAND boundary where beam and batch size is 1. Then it is simply [t][s] -> P(s|t) typedef std::vector> SoftAlignment; // [trg pos][beam depth * max src length * batch size] -WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, +WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft, float threshold = 1.f); std::string SoftAlignToString(SoftAlignment align); diff --git a/src/data/batch.h b/src/data/batch.h index 3c592b31..761f46a4 100644 --- a/src/data/batch.h +++ b/src/data/batch.h @@ -24,7 +24,7 @@ public: const std::vector& getSentenceIds() const { return sentenceIds_; } void setSentenceIds(const std::vector& ids) { sentenceIds_ = ids; } - virtual void setGuidedAlignment(std::vector&&) = 0; + virtual void setGuidedAlignment(std::vector&&) = 0; virtual void setDataWeights(const std::vector&) = 0; virtual ~Batch() {}; protected: diff --git a/src/data/corpus.cpp b/src/data/corpus.cpp index 643a7de9..2fbe4982 100644 --- a/src/data/corpus.cpp +++ b/src/data/corpus.cpp @@ -132,14 +132,13 @@ SentenceTuple Corpus::next() { tup.markAltered(); addWordsToSentenceTuple(fields[i], vocabId, tup); } - - // weights are added last to the sentence tuple, because this runs a validation that needs - // length of the target sequence - if(alignFileIdx_ > -1) - addAlignmentToSentenceTuple(fields[alignFileIdx_], tup); - if(weightFileIdx_ > -1) - addWeightsToSentenceTuple(fields[weightFileIdx_], tup); } + // weights are added last to the sentence tuple, because this runs a validation that needs + // length of the target sequence + if(alignFileIdx_ > -1) + addAlignmentToSentenceTuple(fields[alignFileIdx_], tup); + if(weightFileIdx_ > -1) + addWeightsToSentenceTuple(fields[weightFileIdx_], tup); // check if all streams are valid, that is, non-empty and no longer than maximum allowed length if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) { diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp index 636752c9..71c9f990 100644 --- a/src/data/corpus_base.cpp +++ b/src/data/corpus_base.cpp @@ -429,11 +429,13 @@ void CorpusBase::addWordsToSentenceTuple(const std::string& line, void CorpusBase::addAlignmentToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const { - ABORT_IF(rightLeft_, - "Guided alignment and right-left model cannot be used " - "together at the moment"); + ABORT_IF(rightLeft_, "Guided alignment and right-left model cannot be used together at the moment"); + ABORT_IF(tup.size() != 2, "Using alignment between source and target, but sentence tuple has {} elements??", tup.size()); - auto align = WordAlignment(line); + size_t srcEosPos = tup[0].size() - 1; + size_t tgtEosPos = tup[1].size() - 1; + + auto align = WordAlignment(line, srcEosPos, tgtEosPos); tup.setAlignment(align); } @@ -457,22 +459,17 @@ void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupl void CorpusBase::addAlignmentsToBatch(Ptr batch, const std::vector& batchVector) { - int srcWords = (int)batch->front()->batchWidth(); - int trgWords = (int)batch->back()->batchWidth(); + std::vector aligns; + int dimBatch = (int)batch->getSentenceIds().size(); - - std::vector aligns(srcWords * dimBatch * trgWords, 0.f); - + aligns.reserve(dimBatch); + for(int b = 0; b < dimBatch; ++b) { - // If the batch vector is altered within marian by, for example, case augmentation, // the guided alignments we received for this tuple cease to be valid. // Hence skip setting alignments for that sentence tuple.. if (!batchVector[b].isAltered()) { - for(auto p : batchVector[b].getAlignment()) { - size_t idx = p.srcPos * dimBatch * trgWords + b * trgWords + p.tgtPos; - aligns[idx] = 1.f; - } + aligns.push_back(std::move(batchVector[b].getAlignment())); } } batch->setGuidedAlignment(std::move(aligns)); diff --git a/src/data/corpus_base.h b/src/data/corpus_base.h index a54c20f8..4e6d923e 100644 --- a/src/data/corpus_base.h +++ b/src/data/corpus_base.h @@ -338,7 +338,7 @@ public: class CorpusBatch : public Batch { protected: std::vector> subBatches_; - std::vector guidedAlignment_; // [max source len, batch size, max target len] flattened + std::vector guidedAlignment_; // [max source len, batch size, max target len] flattened std::vector dataWeights_; public: @@ -444,8 +444,17 @@ public: if(options->get("guided-alignment", std::string("none")) != "none") { // @TODO: if > 1 encoder, verify that all encoders have the same sentence lengths - std::vector alignment(batchSize * lengths.front() * lengths.back(), - 0.f); + + std::vector alignment; + for(size_t k = 0; k < batchSize; ++k) { + data::WordAlignment perSentence; + // fill with random alignment points, add more twice the number of words to be safe. + for(size_t j = 0; j < lengths.back() * 2; ++j) { + size_t i = rand() % lengths.back(); + perSentence.push_back(i, j, 1.0f); + } + alignment.push_back(std::move(perSentence)); + } batch->setGuidedAlignment(std::move(alignment)); } @@ -501,29 +510,14 @@ public: } if(!guidedAlignment_.empty()) { - size_t oldTrgWords = back()->batchWidth(); - size_t oldSize = size(); - pos = 0; for(auto split : splits) { auto cb = std::static_pointer_cast(split); - size_t srcWords = cb->front()->batchWidth(); - size_t trgWords = cb->back()->batchWidth(); size_t dimBatch = cb->size(); - - std::vector aligns(srcWords * dimBatch * trgWords, 0.f); - - for(size_t i = 0; i < dimBatch; ++i) { - size_t bi = i + pos; - for(size_t sid = 0; sid < srcWords; ++sid) { - for(size_t tid = 0; tid < trgWords; ++tid) { - size_t bidx = sid * oldSize * oldTrgWords + bi * oldTrgWords + tid; // [sid, bi, tid] - size_t idx = sid * dimBatch * trgWords + i * trgWords + tid; - aligns[idx] = guidedAlignment_[bidx]; - } - } - } - cb->setGuidedAlignment(std::move(aligns)); + std::vector batchAlignment; + for(size_t i = 0; i < dimBatch; ++i) + batchAlignment.push_back(std::move(guidedAlignment_[i + pos])); + cb->setGuidedAlignment(std::move(batchAlignment)); pos += dimBatch; } } @@ -556,15 +550,11 @@ public: return splits; } - const std::vector& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened - void setGuidedAlignment(std::vector&& aln) override { + const std::vector& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened + void setGuidedAlignment(std::vector&& aln) override { guidedAlignment_ = std::move(aln); } - size_t locateInGuidedAlignments(size_t b, size_t s, size_t t) { - return ((s * size()) + b) * widthTrg() + t; - } - std::vector& getDataWeights() { return dataWeights_; } void setDataWeights(const std::vector& weights) override { dataWeights_ = weights; diff --git a/src/examples/mnist/dataset.h b/src/examples/mnist/dataset.h index b0492b85..c665fa65 100644 --- a/src/examples/mnist/dataset.h +++ b/src/examples/mnist/dataset.h @@ -77,7 +77,7 @@ public: size_t size() const override { return inputs_.front().shape()[0]; } - void setGuidedAlignment(std::vector&&) override { + void setGuidedAlignment(std::vector&&) override { ABORT("Guided alignment in DataBatch is not implemented"); } void setDataWeights(const std::vector&) override { diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index 5294fca3..ca5e6805 100644 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -286,6 +286,8 @@ Expr operator/(float a, Expr b) { /*********************************************************/ Expr concatenate(const std::vector& concats, int ax) { + if(concats.size() == 1) + return concats[0]; return Expression(concats, ax); } diff --git a/src/layers/guided_alignment.h b/src/layers/guided_alignment.h index f08d3f09..d2171c50 100644 --- a/src/layers/guided_alignment.h +++ b/src/layers/guided_alignment.h @@ -5,62 +5,57 @@ namespace marian { -static inline RationalLoss guidedAlignmentCost(Ptr /*graph*/, +static inline const std::tuple, std::vector> +guidedAlignmentToSparse(Ptr batch) { + int trgWords = (int)batch->back()->batchWidth(); + int dimBatch = (int)batch->size(); + + typedef std::tuple BiPoint; + std::vector byIndex; + byIndex.reserve(dimBatch * trgWords); + + for(size_t b = 0; b < dimBatch; ++b) { + auto guidedAlignmentFwd = batch->getGuidedAlignment()[b]; // this copies + guidedAlignmentFwd.normalize(/*reverse=*/false); // normalize forward + for(size_t i = 0; i < guidedAlignmentFwd.size(); ++i) { + auto pFwd = guidedAlignmentFwd[i]; + IndexType idx = (IndexType)(pFwd.srcPos * dimBatch * trgWords + b * trgWords + pFwd.tgtPos); + byIndex.push_back({idx, pFwd.prob}); + } + } + + std::sort(byIndex.begin(), byIndex.end(), [](const BiPoint& a, const BiPoint& b) { return std::get<0>(a) < std::get<0>(b); }); + std::vector indices; std::vector valuesFwd; + indices.reserve(byIndex.size()); valuesFwd.reserve(byIndex.size()); + for(auto& p : byIndex) { + indices.push_back((IndexType)std::get<0>(p)); + valuesFwd.push_back(std::get<1>(p)); + } + + return {indices, valuesFwd}; +} + +static inline RationalLoss guidedAlignmentCost(Ptr graph, Ptr batch, Ptr options, Expr attention) { // [beam depth=1, max src length, batch size, tgt length] - std::string guidedLossType = options->get("guided-alignment-cost"); // @TODO: change "cost" to "loss" + + // We dropped support for other losses which are not possible to implement with sparse labels. + // They were most likely not used anyway. + ABORT_IF(guidedLossType != "ce", "Only alignment loss type 'ce' is supported"); + float guidedLossWeight = options->get("guided-alignment-weight"); - const auto& shape = attention->shape(); // [beam depth=1, max src length, batch size, tgt length] - float epsilon = 1e-6f; - Expr alignmentLoss; // sum up loss over all attention/alignment positions - size_t numLabels; - if(guidedLossType == "ce") { - // normalizedAlignment is multi-hot, but ce requires normalized probabilities, so need to normalize to P(s|t) - auto dimBatch = shape[-2]; - auto dimTrgWords = shape[-1]; - auto dimSrcWords = shape[-3]; - ABORT_IF(shape[-4] != 1, "Guided alignments with beam??"); - auto normalizedAlignment = batch->getGuidedAlignment(); // [dimSrcWords, dimBatch, dimTrgWords] flattened, matches shape of 'attention' - auto srcBatch = batch->front(); - const auto& srcMask = srcBatch->mask(); - ABORT_IF(shape.elements() != normalizedAlignment.size(), "Attention-matrix and alignment shapes differ??"); - ABORT_IF(dimBatch != batch->size() || dimTrgWords != batch->widthTrg() || dimSrcWords != batch->width(), "Attention-matrix and batch shapes differ??"); - auto locate = [=](size_t s, size_t b, size_t t) { return ((s * dimBatch) + b) * dimTrgWords + t; }; - for (size_t b = 0; b < dimBatch; b++) { - for (size_t t = 0; t < dimTrgWords; t++) { - for (size_t s = 0; s < dimSrcWords; s++) - ABORT_IF(locate(s, b, t) != batch->locateInGuidedAlignments(b, s, t), "locate() and locateInGuidedAlignments() differ??"); - // renormalize the alignment such that it sums up to 1 - float sum = 0; - for (size_t s = 0; s < dimSrcWords; s++) - sum += srcMask[srcBatch->locate(b, s)] * normalizedAlignment[locate(s, b, t)]; // these values are 0 or 1 - if (sum != 0 && sum != 1) - for (size_t s = 0; s < dimSrcWords; s++) - normalizedAlignment[locate(s, b, t)] /= sum; - } - } - auto alignment = constant_like(attention, std::move(normalizedAlignment)); - alignmentLoss = -sum(flatten(alignment * log(attention + epsilon))); - numLabels = batch->back()->batchWords(); - ABORT_IF(numLabels > shape.elements() / shape[-3], "Num labels of guided alignment cost is off??"); - } else { - auto alignment = constant_like(attention, batch->getGuidedAlignment()); - if(guidedLossType == "mse") - alignmentLoss = sum(flatten(square(attention - alignment))) / 2.f; - else if(guidedLossType == "mult") // @TODO: I don't know what this criterion is for. Can we remove it? - alignmentLoss = -log(sum(flatten(attention * alignment)) + epsilon); - else - ABORT("Unknown alignment cost type: {}", guidedLossType); - // every position is a label as they should all agree - // @TODO: there should be positional masking here ... on the other hand, positions that are not - // in a sentence should always agree (both being 0). Lack of masking affects label count only which is - // probably negligible? - numLabels = shape.elements(); - } + auto [indices, values] = guidedAlignmentToSparse(batch); + auto alignmentIndices = graph->indices(indices); + auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values)); + auto attentionAtAligned = cols(flatten(attention), alignmentIndices); + float epsilon = 1e-6f; + Expr alignmentLoss = -sum(alignmentValues * log(attentionAtAligned + epsilon)); + size_t numLabels = alignmentIndices->shape().elements(); + // Create label node, also weigh by scalar so labels and cost are in the same domain. // Fractional label counts are OK. But only if combined as "sum". // @TODO: It is ugly to check the multi-loss type here, but doing this right requires diff --git a/src/layers/lsh.cpp b/src/layers/lsh.cpp index 8a9c924e..73d45fc7 100644 --- a/src/layers/lsh.cpp +++ b/src/layers/lsh.cpp @@ -3,12 +3,14 @@ #include "common/utils.h" #include "3rd_party/faiss/utils/hamming.h" -#include "3rd_party/faiss/Index.h" #if BLAS_FOUND #include "3rd_party/faiss/VectorTransform.h" #endif +#include "common/timer.h" + +#include "layers/lsh_impl.h" namespace marian { namespace lsh { @@ -98,24 +100,22 @@ Expr encode(Expr input, Expr rotation) { return lambda(inputs, encodedShape, Type::uint8, encodeFwd, encodeHash); } -Expr rotator(Expr weights, int nBits) { +Expr rotator(Expr weights, int inDim, int nBits) { auto rotator = [](Expr out, const std::vector& inputs) { inputs; fillRandomRotationMatrix(out->val(), out->graph()->allocator()); }; static const size_t rotatorHash = (size_t)&rotator; - int dim = weights->shape()[-1]; - return lambda({weights}, {dim, nBits}, Type::float32, rotator, rotatorHash); + return lambda({weights}, {inDim, nBits}, Type::float32, rotator, rotatorHash); } -Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows) { +Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNRows, bool noSort/*= false*/) { ABORT_IF(encodedQuery->shape()[-1] != encodedWeights->shape()[-1], "Query and index bit vectors need to be of same size ({} != {})", encodedQuery->shape()[-1], encodedWeights->shape()[-1]); int currBeamSize = encodedQuery->shape()[0]; int batchSize = encodedQuery->shape()[2]; - int numHypos = currBeamSize * batchSize; auto search = [=](Expr out, const std::vector& inputs) { Expr encodedQuery = inputs[0]; @@ -128,30 +128,25 @@ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows if(firstNRows != 0) wRows = firstNRows; - int qRows = encodedQuery->shape().elements() / bytesPerVector; + ABORT_IF(dimK > wRows, "k is larger than number of candidate values?"); // @TODO: use min(k, wRows) silently? - uint8_t* qCodes = encodedQuery->val()->data(); - uint8_t* wCodes = encodedWeights->val()->data(); + IndexType* outData = out->val()->data(); + auto gather = [outData, dimK](IndexType rowId, IndexType k, IndexType kthColId, DistType /*dist*/) { + outData[rowId * dimK + k] = kthColId; + }; - // use actual faiss code for performing the hamming search. - std::vector distances(qRows * k); - std::vector ids(qRows * k); - faiss::int_maxheap_array_t res = {(size_t)qRows, (size_t)k, ids.data(), distances.data()}; - faiss::hammings_knn_hc(&res, qCodes, wCodes, (size_t)wRows, (size_t)bytesPerVector, 0); + Parameters params; + params.k = dimK; + params.queryRows = encodedQuery->val()->data(); + params.numQueryRows = encodedQuery->shape().elements() / bytesPerVector; + params.codeRows = encodedWeights->val()->data(); + params.numCodeRows = wRows; + params.bytesPerVector = bytesPerVector; - // Copy int64_t indices to Marian index type and sort by increasing index value per hypothesis. - // The sorting is required as we later do a binary search on those values for reverse look-up. - uint32_t* outData = out->val()->data(); - for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) { - size_t startIdx = k * hypoIdx; - size_t endIdx = startIdx + k; - for(size_t i = startIdx; i < endIdx; ++i) - outData[i] = (uint32_t)ids[i]; - std::sort(outData + startIdx, outData + endIdx); - } + hammingTopK(params, gather); }; - Shape kShape({currBeamSize, batchSize, k}); + Shape kShape({currBeamSize, batchSize, dimK}); return lambda({encodedQuery, encodedWeights}, kShape, Type::uint32, search); } @@ -166,7 +161,7 @@ Expr search(Expr query, Expr weights, int k, int nBits, int firstNRows, bool abo } else { ABORT_IF(abortIfDynamic, "Dynamic creation of LSH rotation matrix prohibited"); LOG_ONCE(info, "Creating ad-hoc rotation matrix with shape {}", Shape({dim, nBits})); - rotMat = rotator(weights, nBits); + rotMat = rotator(weights, dim, nBits); } } @@ -195,34 +190,43 @@ Ptr randomRotation() { return New(); } -void addDummyParameters(Ptr graph, std::string weightsName, int nBitsRot) { - auto weights = graph->get(weightsName); - - ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName); +void addDummyParameters(Ptr graph, ParamConvInfo paramInfo) { + auto weights = graph->get(paramInfo.name); + int nBitsRot = paramInfo.nBits; + + ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name); int nBits = weights->shape()[-1]; + if(paramInfo.transpose) + nBits = weights->shape()[-2]; + int nRows = weights->shape().elements() / nBits; Expr rotation; if(nBits != nBitsRot) { - LOG(info, "Adding LSH rotation parameter lsh_output_rotation with shape {}", Shape({nBits, nBitsRot})); - rotation = graph->param("lsh_output_rotation", {nBits, nBitsRot}, inits::dummy(), Type::float32); + LOG(info, "Adding LSH rotation parameter {} with shape {}", paramInfo.rotationName, Shape({nBits, nBitsRot})); + rotation = graph->param(paramInfo.rotationName, {nBits, nBitsRot}, inits::dummy(), Type::float32); nBits = nBitsRot; } int bytesPerVector = lsh::bytesPerVector(nBits); - LOG(info, "Adding LSH encoded weights lsh_output_codes with shape {}", Shape({nRows, bytesPerVector})); - auto codes = graph->param("lsh_output_codes", {nRows, bytesPerVector}, inits::dummy(), Type::uint8); + LOG(info, "Adding LSH encoded weights {} with shape {}", paramInfo.codesName, Shape({nRows, bytesPerVector})); + auto codes = graph->param(paramInfo.codesName, {nRows, bytesPerVector}, inits::dummy(), Type::uint8); } -void overwriteDummyParameters(Ptr graph, std::string weightsName) { - Expr weights = graph->get(weightsName); - Expr codes = graph->get("lsh_output_codes"); - Expr rotation = graph->get("lsh_output_rotation"); +void overwriteDummyParameters(Ptr graph, ParamConvInfo paramInfo) { + Expr weights = graph->get(paramInfo.name); + Expr codes = graph->get(paramInfo.codesName); + Expr rotation = graph->get(paramInfo.rotationName); - ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", weightsName); + ABORT_IF(!weights, "Trying to encode non-existing weights matrix {}??", paramInfo.name); ABORT_IF(!codes, "Trying to overwrite non-existing LSH parameters lsh_output_codes??"); + if(paramInfo.transpose) { + weights = transpose(weights); + graph->forward(); + } + if(rotation) { fillRandomRotationMatrix(rotation->val(), weights->graph()->allocator()); encodeWithRotation(codes->val(), weights->val(), rotation->val(), weights->graph()->allocator()); diff --git a/src/layers/lsh.h b/src/layers/lsh.h index 7a585891..5065ffcf 100644 --- a/src/layers/lsh.h +++ b/src/layers/lsh.h @@ -17,26 +17,34 @@ namespace marian { namespace lsh { - - // return the number of full bytes required to encoded that many bits - int bytesPerVector(int nBits); - // encodes an input as a bit vector, with optional rotation Expr encode(Expr input, Expr rotator = nullptr); // compute the rotation matrix (maps weights->shape()[-1] to nbits floats) - Expr rotator(Expr weights, int nbits); + Expr rotator(Expr weights, int inDim, int nbits); // perform the LSH search on fully encoded input and weights, return k results (indices) per input row // @TODO: add a top-k like operator that also returns the bitwise computed distances - Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0); + Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int k, int firstNRows = 0, bool noSort = false); // same as above, but performs encoding on the fly Expr search(Expr query, Expr weights, int k, int nbits, int firstNRows = 0, bool abortIfDynamic = false); + // struct for parameter conversion used in marian-conv + struct ParamConvInfo { + std::string name; + std::string codesName; + std::string rotationName; + int nBits; + bool transpose; + + ParamConvInfo(const std::string& name, const std::string& codesName, const std::string& rotationName, int nBits, bool transpose = false) + : name(name), codesName(codesName), rotationName(rotationName), nBits(nBits), transpose(transpose) {} + }; + // These are helper functions for encoding the LSH into the binary Marian model, used by marian-conv - void addDummyParameters(Ptr graph, std::string weightsName, int nBits); - void overwriteDummyParameters(Ptr graph, std::string weightsName); + void addDummyParameters(Ptr graph, ParamConvInfo paramInfo); + void overwriteDummyParameters(Ptr graph, ParamConvInfo paramInfo); /** * Computes a random rotation matrix for LSH hashing. diff --git a/src/layers/lsh_impl.h b/src/layers/lsh_impl.h new file mode 100644 index 00000000..d87d23e0 --- /dev/null +++ b/src/layers/lsh_impl.h @@ -0,0 +1,186 @@ +#pragma once + +#include + +#ifdef _MSC_VER +#define __builtin_popcountl __popcnt64 +#define __builtin_popcount __popcnt +#endif + +namespace marian { +namespace lsh { + + struct Parameters { + int k; + uint8_t* queryRows; + int numQueryRows; + uint8_t* codeRows; + int numCodeRows; + int bytesPerVector; + }; + + typedef uint32_t DistType; + typedef uint64_t ChunkType; + + inline DistType popcount(const ChunkType& chunk) { + switch (sizeof(ChunkType)) { + case 8 : return (DistType)__builtin_popcountl((uint64_t)chunk); + case 4 : return (DistType)__builtin_popcount((uint32_t)chunk); + default: ABORT("Size {} not supported", sizeof(ChunkType)); + } + } + + // return the number of full bytes required to encoded that many bits + inline int bytesPerVector(int nBits); + + // compute top-k hamming distances for given query and weight binary codes. Faster than FAISS version, especially for larger k nearly constant wrt. k. + template + inline constexpr T getStaticOrDynamic(T dynamicValue) { + return Dynamic ? dynamicValue : StaticValue; + } + + template + inline DistType hamming(ChunkType* queryRow, ChunkType* codeRow, int stepsDynamic = 0) { + static_assert(Dynamic == true || StepsStatic != 0, "Either define dynamic use of steps or provide non-zero template argument"); + DistType dist = 0; + for(int i = 0; i < getStaticOrDynamic(stepsDynamic); ++i) + dist += popcount(queryRow[i] ^ codeRow[i]); + return dist; + } + + template + inline void hammingTopKUnrollWarp(int queryOffset, const Parameters& parameters, const Functor& gather) { + const int numBits = getStaticOrDynamic(parameters.bytesPerVector) * 8; + ABORT_IF(numBits % 64 != 0, "LSH hash size has to be a multiple of 64"); + + // counter to keep track of seen hamming distances + std::vector> counter(warpSize, std::vector(numBits, 0)); + // buffer the distances for query vector warpRowId to all weight weight vectors codeRowId + std::vector> distBuffer(warpSize, std::vector(getStaticOrDynamic(parameters.numCodeRows), 0)); + // minimal distances per query + std::vector minDist(warpSize); + + constexpr int StepStatic = BytesPerVector / sizeof(ChunkType); + int stepDynamic = parameters.bytesPerVector / sizeof(ChunkType); + + ChunkType* codeRow = (ChunkType*)parameters.codeRows; + + for(int warpRowId = 0; warpRowId < warpSize; warpRowId++) { + std::memset(counter[warpRowId].data(), 0, numBits * sizeof(DistType)); // Reset the counter via memset to 0 + minDist[warpRowId] = (DistType)numBits; + } + + for(IndexType codeRowId = 0; codeRowId < (IndexType)getStaticOrDynamic(parameters.numCodeRows); ++codeRowId, codeRow += getStaticOrDynamic(stepDynamic)) { + ChunkType* queryRow = (ChunkType*)parameters.queryRows; + for(IndexType warpRowId = 0; warpRowId < warpSize; warpRowId++, queryRow += getStaticOrDynamic(stepDynamic)) { + // Compute the bit-wise hamming distance + DistType dist = hamming(queryRow, codeRow, stepDynamic); + + // Record the minimal distance seen for this query vector wrt. all weight vectors + if(dist < minDist[warpRowId]) { + minDist[warpRowId] = dist; + } + + // Record the number of weight vectors that have this distance from the query vector. + // Note, because there is at most numBits different distances this can be trivially done. + // Not the case for generic distances like float. + counter[warpRowId][dist]++; + + // Record the distance for this weight vector + distBuffer[warpRowId][codeRowId] = dist; + } + } + // warp finished, harvest k top distances + + for(int warpRowId = 0; warpRowId < warpSize; warpRowId++) { + // Here we search for the distance at which we have seen equal or more than k elements with + // smaller distances. We start with the minimal distance from above which is its own address + // to the counter. + DistType maxDist = minDist[warpRowId]; + size_t cummulativeDistances = 0; + + // Accumulate number of elements until we reach k in growing distance order. Note that + // counter is indexed by hamming distance - from lowest to highest. Some slots will be 0. + // The cumulative sum from position a to b tells you how many elements have distances smaller + // than the distance at b. + while(cummulativeDistances < parameters.k) + cummulativeDistances += counter[warpRowId][maxDist++]; + if(cummulativeDistances) + maxDist--; // fix overcounting + + // Usually, we overshoot by a couple of elements and we need to take care of the distance at which the k-th + // element sits. This elements has more neighbors at the same distance, but we only care for them + // as long we have not reached k elements in total. + // By contrast, we trivially collect all elements below that distance -- these are always safe. + + // This is the number of elements we need to collect at the last distance. + DistType maxDistLimit = /*number of elements at maxDist=*/counter[warpRowId][maxDist] - /*overflow=*/((DistType)cummulativeDistances - (DistType)parameters.k); + IndexType kSeen = 0; + IndexType kSeenAtKDist = 0; + + for(IndexType codeRowId = 0; kSeen < (IndexType)parameters.k && codeRowId < (IndexType)getStaticOrDynamic(parameters.numCodeRows); ++codeRowId) { + DistType dist = distBuffer[warpRowId][codeRowId]; + // - if the current distance is smaller than the maxDist, just consume. + // - if the distance is equal to maxDist, make sure to only consume maxDistLimit elements at maxDist + // and ignore the rest (smaller indices make it in first). + // - after we finish this loop we have exactly k top values for every query row in original index order. + int queryRowId = queryOffset + warpRowId; + if(dist < maxDist) { + gather(queryRowId, (IndexType)kSeen, codeRowId, dist); + kSeen++; + } else if(dist == maxDist && kSeenAtKDist < (DistType)maxDistLimit) { + gather(queryRowId, (IndexType)kSeen, codeRowId, dist); + kSeen++; + kSeenAtKDist++; + } + } + } + } + + // Faster top-k search for hamming distance. The idea here is that instead of sorting the elements we find a hamming distances at which it is safe + // to copy the given index. Copying only the indices below that distance is guaranteed to results in no more than k elements. For elements at that + // distance we need to correct for overshooting. + // Once we have that distance we only need to traverse the set of distances. In the end we get exactly k elements per queryRows vector. + template + inline void hammingTopKUnroll(const Parameters& parameters, const Functor& gather) { + static_assert(Dynamic == true || (NumCodeRows != 0 && BytesPerVector != 0), "Either define dynamic use of variables or provide non-zero template arguments"); + + int warpSize = 4; // starting warpSize of 4 seems optimal + auto warpParameters = parameters; + for(int queryOffset = 0; queryOffset < parameters.numQueryRows; queryOffset += warpSize) { + while(parameters.numQueryRows - queryOffset < warpSize) + warpSize /= 2; + + int step = getStaticOrDynamic(parameters.bytesPerVector); + warpParameters.queryRows = parameters.queryRows + queryOffset * step; + warpParameters.numQueryRows = warpSize; + switch(warpSize) { + case 8 : hammingTopKUnrollWarp<8, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break; + case 4 : hammingTopKUnrollWarp<4, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break; + case 2 : hammingTopKUnrollWarp<2, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break; + case 1 : hammingTopKUnrollWarp<1, NumCodeRows, BytesPerVector, Dynamic>(queryOffset, warpParameters, gather); break; + default: ABORT("Unhandled warpSize = {}??", warpSize); + } + } + } + + template + inline void hammingTopK(const Parameters& parameters, const Functor& gather) { + if(parameters.numCodeRows == 2048 && parameters.bytesPerVector == 64) + hammingTopKUnroll< 2048, 64, false>(parameters, gather); + else if(parameters.numCodeRows == 4096 && parameters.bytesPerVector == 64) + hammingTopKUnroll< 4096, 64, false>(parameters, gather); + else if(parameters.numCodeRows == 6144 && parameters.bytesPerVector == 64) + hammingTopKUnroll< 6144, 64, false>(parameters, gather); + else if(parameters.numCodeRows == 8192 && parameters.bytesPerVector == 64) + hammingTopKUnroll< 8192, 64, false>(parameters, gather); + else if(parameters.numCodeRows == 32000 && parameters.bytesPerVector == 64) + hammingTopKUnroll<32000, 64, false>(parameters, gather); + else if(parameters.numCodeRows == 32000 && parameters.bytesPerVector == 128) + hammingTopKUnroll<32000, 128, false>(parameters, gather); + else + hammingTopKUnroll< 0, 0, true>(parameters, gather); + } + +} // namespace lsh +} // namespace marian \ No newline at end of file diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp index a439197b..316c66d1 100644 --- a/src/microsoft/quicksand.cpp +++ b/src/microsoft/quicksand.cpp @@ -178,8 +178,7 @@ public: auto score = std::get<2>(result); // determine alignment if present AlignmentSets alignmentSets; - if (options_->hasAndNotEmpty("alignment")) - { + if (options_->hasAndNotEmpty("alignment")) { float alignmentThreshold; auto alignment = options_->get("alignment"); // @TODO: this logic now exists three times in Marian if (alignment == "soft") @@ -287,7 +286,7 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP // Add dummy parameters for the LSH before the model gets actually initialized. // This create the parameters with useless values in the tensors, but it gives us the memory we need. graph->setReloaded(false); - lsh::addDummyParameters(graph, /*weights=*/lshOutputWeights, /*nBits=*/lshNBits); + lsh::addDummyParameters(graph, /*paramInfo=*/{lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits}); graph->setReloaded(true); } @@ -296,7 +295,7 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP if(addLsh) { // After initialization, hijack the paramters for the LSH and force-overwrite with correct values. // Once this is done we can just pack and save as normal. - lsh::overwriteDummyParameters(graph, /*weights=*/lshOutputWeights); + lsh::overwriteDummyParameters(graph, /*paramInfo=*/{lshOutputWeights, "lsh_output_codes", "lsh_output_rotation", lshNBits}); } Type targetPrecType = (Type) targetPrec;