Merge branch 'master' into pmaster

commit 4ffd292881
Author: Marcin Junczys-Dowmunt
Date:   2023-02-20 12:15:33 -08:00
20 changed files with 166 additions and 55 deletions

View File

@ -12,8 +12,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Fused inplace-dropout in FFN layer in Transformer
- `--force-decode` option for marian-decoder
- `--output-sampling` now works with ensembles (requires proper normalization via e.g. `--weights 0.5 0.5`)
- `--valid-reset-all` option
### Fixed
- Make concat factors not break old vector implementation
- Use allocator in hashing
- Read/restore checkpoints from main process only when training with MPI
- Multi-loss casts type to first loss-type before accumulation (aborted before due to missing cast)

View File

@ -1 +1 @@
v1.11.13
v1.11.15

View File

@ -16,13 +16,30 @@ parameters:
# The pipeline CI trigger is set on the master branch only, and the PR trigger
# fires on a (non-draft) pull request to any branch
trigger:
- master
# This minimizes the number of parallel pipeline runs. When a pipeline is
# running, the CI waits until it is completed before starting another one.
batch: true
branches:
include:
- master
paths:
exclude:
- azure-regression-tests.yml
- contrib
- doc
- examples
- regression-tests
- scripts
- VERSION
- vs
- '**/*.md'
- '**/*.txt'
pool:
name: Azure Pipelines
variables:
- group: marian-prod-tests
- group: marian-regression-tests
- name: BOOST_ROOT_WINDOWS
value: "C:/hostedtoolcache/windows/Boost/1.72.0/x86_64"
- name: BOOST_URL
@ -32,7 +49,7 @@ variables:
- name: MKL_DIR
value: "$(Build.SourcesDirectory)/mkl"
- name: MKL_URL
value: "https://romang.blob.core.windows.net/mariandev/ci/mkl-2020.1-windows-static.zip"
value: "https://data.statmt.org/romang/marian-regression-tests/ci/mkl-2020.1-windows-static.zip"
- name: VCPKG_COMMIT
value: 2022.03.10
- name: VCPKG_DIR
@ -52,6 +69,7 @@ stages:
######################################################################
- job: BuildWindows
cancelTimeoutInMinutes: 1
condition: eq(${{ parameters.runBuilds }}, true)
displayName: Windows
@ -188,6 +206,7 @@ stages:
######################################################################
- job: BuildUbuntu
cancelTimeoutInMinutes: 1
condition: eq(${{ parameters.runBuilds }}, true)
displayName: Ubuntu
timeoutInMinutes: 120
@ -324,6 +343,7 @@ stages:
######################################################################
- job: BuildMacOS
cancelTimeoutInMinutes: 1
condition: eq(${{ parameters.runBuilds }}, true)
displayName: macOS CPU clang
@ -373,6 +393,7 @@ stages:
######################################################################
- job: BuildInstall
cancelTimeoutInMinutes: 1
condition: eq(${{ parameters.runBuilds }}, true)
displayName: Linux CPU library install
@ -435,6 +456,7 @@ stages:
######################################################################
- job: TestWindows
cancelTimeoutInMinutes: 1
displayName: Windows CPU+FBGEMM
pool:
@ -528,14 +550,14 @@ stages:
displayName: Machine statistics
workingDirectory: marian-prod-tests
# The current SAS token will expire on 8/30/2023 and a new one will need to be set in Marian > Pipelines > Library
# The current SAS token will expire on 12/31/2023 and a new one will need to be set in Marian > Pipelines > Library
- bash: |
cd models
bash download-models.sh
ls
displayName: Prepare tests
env:
AWS_SECRET_SAS_TOKEN: $(blob-sas-token)
AZURE_STORAGE_SAS_TOKEN: $(marian-prod-tests-blob-sas-token)
workingDirectory: marian-prod-tests
# Avoid using $(Build.SourcesDirectory) in bash tasks because on Windows pools it uses '\' as the path separator
@ -560,6 +582,7 @@ stages:
######################################################################
- job: TestLinux
cancelTimeoutInMinutes: 1
displayName: Linux CPU+FBGEMM
pool:
@ -572,7 +595,10 @@ stages:
# The following packages are already installed on Azure-hosted runners: build-essential openssl libssl-dev
# No need to install libprotobuf{17,10,9v5} on Ubuntu {20,18,16}.04 because it is installed together with libprotobuf-dev
- bash: sudo apt-get install -y libgoogle-perftools-dev libprotobuf-dev protobuf-compiler gcc-9 g++-9
# Installing libunwind-dev fixes a bug on Ubuntu 22.04 (the libunwind-14 and libunwind-dev conflict)
- bash: |
sudo apt-get install -y libunwind-dev
sudo apt-get install -y libgoogle-perftools-dev libprotobuf-dev protobuf-compiler gcc-9 g++-9
displayName: Install packages
# https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html
@ -629,14 +655,14 @@ stages:
displayName: Machine statistics
workingDirectory: marian-prod-tests
# The current SAS token will expire on 8/30/2023 and a new one will need to be set in Marian > Pipelines > Library
# The current SAS token will expire on 12/31/2023 and a new one will need to be set in Marian > Pipelines > Library
- bash: |
cd models
bash download-models.sh
ls
displayName: Prepare tests
env:
AWS_SECRET_SAS_TOKEN: $(blob-sas-token)
AZURE_STORAGE_SAS_TOKEN: $(marian-prod-tests-blob-sas-token)
workingDirectory: marian-prod-tests
- bash: MARIAN=../marian-dev/build bash ./run_mrt.sh '#cpu' '#basics' '#devops'

View File

@ -14,12 +14,16 @@ trigger: none
# Hosted Azure DevOps Pool determining OS, CUDA version and available GPUs
pool: mariandevops-pool-m60-eus
variables:
- group: marian-regression-tests
stages:
- stage: TestsGPU
jobs:
######################################################################
- job: TestsGPULinux
cancelTimeoutInMinutes: 1
displayName: Linux GPU tests
timeoutInMinutes: 120
@ -103,11 +107,14 @@ stages:
workingDirectory: build
# Always run regression tests from the master branch
# The current SAS token will expire on 12/31/2023 and a new one will need to be set in Marian > Pipelines > Library
- bash: |
git checkout master
git pull origin master
make install
displayName: Prepare regression tests
env:
AZURE_STORAGE_SAS_TOKEN: $(marian-pub-tests-blob-sas-token)
workingDirectory: regression-tests
# Continue on error to be able to collect outputs and publish them as an artifact

@ -1 +1 @@
Subproject commit 488d454a0177ef300eab91ab813e485d420dc38d
Subproject commit 2a8bed3f0e937a9de2d6fa92dee3bcf482d3d47b

View File

@ -42,11 +42,15 @@ def main():
        else:
            print(model[args.key])
    else:
        total_nb_of_parameters = 0
        for key in model:
            if args.matrix_shapes:
                print(key, model[key].shape)
            if not key == S2S_SPECIAL_NODE:
                total_nb_of_parameters += np.prod(model[key].shape)
            if args.matrix_info:
                print(key, model[key].shape, model[key].dtype)
            else:
                print(key)
        print('Total number of parameters:', total_nb_of_parameters)
def parse_args():
@ -57,8 +61,8 @@ def parse_args():
help="print values from special:model.yml node")
parser.add_argument("-f", "--full-matrix", action="store_true",
help="force numpy to print full arrays for single key")
parser.add_argument("-ms", "--matrix-shapes", action="store_true",
help="print shapes of all arrays in the model")
parser.add_argument("-mi", "--matrix-info", action="store_true",
help="print full matrix info for all keys. Includes shape and dtype")
return parser.parse_args()

View File

@ -269,7 +269,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
cli.add<int>("--transformer-dim-ffn",
"Size of position-wise feed-forward network (transformer)",
2048);
cli.add<int>("--transformer-decoder-dim-ffn",
"Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.",
0);
@ -591,7 +591,9 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
"Multiple metrics can be specified",
{"cross-entropy"});
cli.add<bool>("--valid-reset-stalled",
"Reset all stalled validation metrics when the training is restarted");
"Reset stalled validation metrics when the training is restarted");
cli.add<bool>("--valid-reset-all",
"Reset all validation metrics when the training is restarted");
cli.add<size_t>("--early-stopping",
"Stop if the first validation metric does not improve for arg consecutive validation steps",
10);

View File

@ -128,7 +128,7 @@ SentenceTuple Corpus::next() {
size_t vocabId = i - shift;
bool altered;
preprocessLine(fields[i], vocabId, curId, /*out=*/altered);
if (altered)
if(altered)
tup.markAltered();
addWordsToSentenceTuple(fields[i], vocabId, tup);
}

View File

@ -476,7 +476,10 @@ void CorpusBase::addAlignmentsToBatch(Ptr<CorpusBatch> batch,
// If the batch vector is altered within marian by, for example, case augmentation,
// the guided alignments we received for this tuple cease to be valid.
// Hence, skip setting alignments for that sentence tuple.
if (!batchVector[b].isAltered()) {
if (batchVector[b].isAltered()) {
LOG_ONCE(info, "Using guided-alignment with case-augmentation is not recommended and can result in unexpected behavior");
aligns.push_back(WordAlignment());
} else {
aligns.push_back(std::move(batchVector[b].getAlignment()));
}
}

View File

@ -56,12 +56,16 @@ public:
* @brief Returns whether this Tuple was altered or augmented from what
* was provided to Marian as input.
*/
bool isAltered() const { return altered_; }
bool isAltered() const {
return altered_;
}
/**
* @brief Mark that this Tuple was internally altered or augmented by Marian
*/
void markAltered() { altered_ = true; }
void markAltered() {
altered_ = true;
}
/**
* @brief Adds a new sentence at the end of the tuple.

View File

@ -21,8 +21,7 @@ namespace marian {
maxSizeUnused;
// If the model has already been loaded, assume this is a shared object and skip loading it again.
// This can be multi-threaded, so must run under lock.
static std::mutex s_mtx;
std::lock_guard<std::mutex> criticalSection(s_mtx);
std::lock_guard<std::mutex> criticalSection(loadMtx_);
if (size() != 0) {
//LOG(info, "[vocab] Attempting to load model a second time; skipping (assuming shared vocab)");
return size();
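
A minimal sketch of the locking change above, under stated assumptions: the class name, members, and placeholder loading logic are illustrative, not Marian's actual vocab class; only the lock/skip pattern mirrors the diff. The old function-local static mutex serialized loads across all vocab instances, whereas a member mutex like loadMtx_ only serializes concurrent loads of the same object.

#include <cstddef>
#include <mutex>
#include <string>

class Vocab {  // hypothetical stand-in for the vocab class in the diff
public:
  size_t load(const std::string& path) {
    // Member mutex: guards only *this*, so two distinct vocab objects
    // can load in parallel (the old static mutex blocked them globally).
    std::lock_guard<std::mutex> criticalSection(loadMtx_);
    if (size_ != 0)      // already loaded => shared object, skip reloading
      return size_;
    size_ = path.size(); // placeholder for the actual file loading
    return size_;
  }

private:
  size_t size_ = 0;
  std::mutex loadMtx_;   // one mutex per instance, mirroring the diff
};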

View File

@ -110,6 +110,7 @@ private:
Word unkId_{};
WordLUT vocab_;
size_t lemmaSize_;
std::mutex loadMtx_;
// factors
char factorSeparator_ = '|'; // separator symbol for parsing factored words

View File

@ -64,6 +64,14 @@ Expr ExpressionGraph::add(Expr node) {
}
}
/**
* Removes the node from the set of roots (will not be initialized during back propagation)
* @param node a pointer to an expression node
*/
void ExpressionGraph::removeAsRoot(Expr node) {
topNodes_.erase(node);
}
// Call on every checkpoint in backwards order
void createSubtape(Expr node) {
auto subtape = New<std::list<Expr>>();
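
A hedged sketch of the root-set bookkeeping that removeAsRoot() manipulates; the Graph and Node types here are illustrative stand-ins, and only the topNodes_ insert/erase pattern mirrors the diff above.

#include <memory>
#include <unordered_set>

struct Node {};                      // illustrative node type
using Expr = std::shared_ptr<Node>;  // stand-in for Marian's Expr

class Graph {
public:
  Expr add(Expr node) {
    topNodes_.insert(node);  // newly added nodes are tracked as roots
    return node;
  }
  void removeAsRoot(Expr node) {
    topNodes_.erase(node);   // no longer seeds the backward pass
  }

private:
  std::unordered_set<Expr> topNodes_;  // set of current root nodes
};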

View File

@ -676,6 +676,12 @@ public:
* @param node a pointer to an expression node
*/
Expr add(Expr node);
/**
* Removes the node from the set of roots (will not be initialized during back propagation)
* @param node a pointer to an expression node
*/
void removeAsRoot(Expr node);
/**
* Allocate memory for the forward pass of the given node.

View File

@ -27,6 +27,11 @@ Expr checkpoint(Expr a) {
return a;
}
Expr removeAsRoot(Expr a) {
a->graph()->removeAsRoot(a); // ugly, hence hidden here
return a;
}
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
LambdaNodeFunctor fwd, size_t hash) {
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);

View File

@ -16,6 +16,11 @@ Expr debug(Expr a, const std::string& message = "");
*/
Expr checkpoint(Expr a);
/**
* Removes the node from the set of root nodes; a no-op if the node is not a root
*/
Expr removeAsRoot(Expr node);
typedef Expr(ActivationFunction)(Expr); ///< ActivationFunction has signature Expr(Expr)
/**

View File

@ -6,10 +6,8 @@ namespace marian {
Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
std::string name = opt<std::string>("prefix");
int dimVoc = opt<int>("dimVocab");
int dimEmb = opt<int>("dimEmb");
int dimFactorEmb = opt<int>("dimFactorEmb");
int dimVoc = opt<int>("dimVocab");
int dimEmb = opt<int>("dimEmb");
bool fixed = opt<bool>("fixed", false);
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
@ -21,6 +19,7 @@ Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
dimVoc = (int)factoredVocab_->factorVocabSize();
LOG_ONCE(info, "[embedding] Factored embeddings enabled");
if(opt<std::string>("factorsCombine") == "concat") {
int dimFactorEmb = opt<int>("dimFactorEmb", 0);
ABORT_IF(dimFactorEmb == 0,
"Embedding: If concatenation is chosen to combine the factor embeddings, a factor "
"embedding size must be specified.");
@ -179,8 +178,8 @@ Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape&
"prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb"
: prefix_ + "_Wemb",
"fixed", embeddingFix_,
"dimFactorEmb", opt<int>("factors-dim-emb"), // for factored embeddings
"factorsCombine", opt<std::string>("factors-combine"), // for factored embeddings
"dimFactorEmb", opt<int>("factors-dim-emb", 0), // for factored embeddings
"factorsCombine", opt<std::string>("factors-combine", ""), // for factored embeddings
"vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
// clang-format on
if(options_->hasAndNotEmpty("embedding-vectors")) {

View File

@ -26,7 +26,8 @@ guidedAlignmentToSparse(Ptr<data::CorpusBatch> batch) {
std::sort(byIndex.begin(), byIndex.end(), [](const BiPoint& a, const BiPoint& b) { return std::get<0>(a) < std::get<0>(b); });
std::vector<IndexType> indices; std::vector<float> valuesFwd;
indices.reserve(byIndex.size()); valuesFwd.reserve(byIndex.size());
indices.reserve(byIndex.size());
valuesFwd.reserve(byIndex.size());
for(auto& p : byIndex) {
indices.push_back((IndexType)std::get<0>(p));
valuesFwd.push_back(std::get<1>(p));
@ -40,28 +41,33 @@ static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> graph,
Ptr<Options> options,
Expr attention) { // [beam depth=1, max src length, batch size, tgt length]
std::string guidedLossType = options->get<std::string>("guided-alignment-cost"); // @TODO: change "cost" to "loss"
// @TODO: It is ugly to check the multi-loss type here, but doing this right requires
// a substantial rewrite of the multi-loss architecture, which is planned anyways.
std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
// We dropped support for other losses which are not possible to implement with sparse labels.
// They were most likely not used anyway.
ABORT_IF(guidedLossType != "ce", "Only alignment loss type 'ce' is supported");
float guidedLossWeight = options->get<float>("guided-alignment-weight");
auto [indices, values] = guidedAlignmentToSparse(batch);
auto alignmentIndices = graph->indices(indices);
auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values));
auto attentionAtAligned = cols(flatten(attention), alignmentIndices);
float epsilon = 1e-6f;
Expr alignmentLoss = -sum(cast(alignmentValues * log(attentionAtAligned + epsilon), Type::float32));
size_t numLabels = alignmentIndices->shape().elements();
const auto& [indices, values] = guidedAlignmentToSparse(batch);
Expr alignmentLoss;
size_t numLabels = indices.size(); // can be zero
if(indices.empty()) {
removeAsRoot(stopGradient(attention)); // unused, hence make sure we don't pollute the backwards operations
alignmentLoss = graph->zeros({1});
numLabels = multiLossType == "sum" ? 0 : 1;
} else {
float epsilon = 1e-6f;
auto alignmentIndices = graph->indices(indices);
auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values));
auto attentionAtAligned = cols(flatten(attention), alignmentIndices);
alignmentLoss = -sum(cast(alignmentValues * log(attentionAtAligned + epsilon), Type::float32));
}
// Create label node, also weigh by scalar so labels and cost are in the same domain.
// Fractional label counts are OK. But only if combined as "sum".
// @TODO: It is ugly to check the multi-loss type here, but doing this right requires
// a substantial rewrite of the multi-loss architecture, which is planned anyways.
std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
if (multiLossType == "sum") // sum of sums
if (multiLossType == "sum") // sum of sums
return RationalLoss(guidedLossWeight * alignmentLoss, guidedLossWeight * numLabels);
else
return RationalLoss(guidedLossWeight * alignmentLoss, (float)numLabels);
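
An illustrative plain-C++ sketch (the function name and types are ours, not Marian graph ops) of the quantity the sparse alignment loss above computes: only the attention entries selected by the sparse indices contribute, i.e. loss = -sum_j values[j] * log(attention[indices[j]] + epsilon), and empty indices yield zero loss, matching the removeAsRoot branch.

#include <cmath>
#include <cstdint>
#include <vector>

float guidedAlignmentCE(const std::vector<float>& flatAttention,
                        const std::vector<uint32_t>& indices,
                        const std::vector<float>& values) {
  const float epsilon = 1e-6f;  // same epsilon as in the diff above
  float loss = 0.f;
  for (size_t j = 0; j < indices.size(); ++j)
    loss -= values[j] * std::log(flatAttention[indices[j]] + epsilon);
  return loss;  // empty indices => 0, like the zeros({1}) branch
}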

View File

@ -1,6 +1,12 @@
#include "layers/lsh.h"
#include "tensors/tensor_operators.h"
#include "common/timer.h"
#include "common/utils.h"
#include "layers/lsh.h"
#include "layers/lsh_impl.h"
#include "tensors/tensor_operators.h"
#if _MSC_VER
#include "3rd_party/faiss/Index.h"
#endif
#include "3rd_party/faiss/utils/hamming.h"
@ -8,10 +14,6 @@
#include "3rd_party/faiss/VectorTransform.h"
#endif
#include "common/timer.h"
#include "layers/lsh_impl.h"
namespace marian {
namespace lsh {
@ -116,7 +118,7 @@ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNR
int currBeamSize = encodedQuery->shape()[0];
int batchSize = encodedQuery->shape()[2];
auto search = [=](Expr out, const std::vector<Expr>& inputs) {
Expr encodedQuery = inputs[0];
Expr encodedWeights = inputs[1];
@ -130,6 +132,32 @@ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNR
ABORT_IF(dimK > wRows, "k is larger than number of candidate values?"); // @TODO: use min(k, wRows) silently?
#if _MSC_VER // unfortunately MSVC is horrible at loop unrolling, so we fall back to the old code (hrmph!) @TODO: figure this out one day
int qRows = encodedQuery->shape().elements() / bytesPerVector;
uint8_t* qCodes = encodedQuery->val()->data<uint8_t>();
uint8_t* wCodes = encodedWeights->val()->data<uint8_t>();
// use actual faiss code for performing the hamming search.
std::vector<int> distances(qRows * dimK);
std::vector<faiss::Index::idx_t> ids(qRows * dimK);
faiss::int_maxheap_array_t res = {(size_t)qRows, (size_t)dimK, ids.data(), distances.data()};
faiss::hammings_knn_hc(&res, qCodes, wCodes, (size_t)wRows, (size_t)bytesPerVector, 0);
// Copy int64_t indices to Marian index type and sort by increasing index value per hypothesis.
// The sorting is required as we later do a binary search on those values for reverse look-up.
uint32_t* outData = out->val()->data<uint32_t>();
int numHypos = out->shape().elements() / dimK;
for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
size_t startIdx = dimK * hypoIdx;
size_t endIdx = startIdx + dimK;
for(size_t i = startIdx; i < endIdx; ++i)
outData[i] = (uint32_t)ids[i];
if(!noSort)
std::sort(outData + startIdx, outData + endIdx);
}
#else // this is using the new code for search; other parts of the code, like conversion, are fine.
IndexType* outData = out->val()->data<IndexType>();
auto gather = [outData, dimK](IndexType rowId, IndexType k, IndexType kthColId, DistType /*dist*/) {
outData[rowId * dimK + k] = kthColId;
@ -144,6 +172,7 @@ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNR
params.bytesPerVector = bytesPerVector;
hammingTopK(params, gather);
#endif
};
Shape kShape({currBeamSize, batchSize, dimK});
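
A sketch of why the per-hypothesis sort above is worth it; the function and its name are illustrative, not Marian's API. Once each hypothesis's dimK indices are sorted in ascending order, the later reverse look-up ("is candidate c among this hypothesis's top-k, and at which position?") becomes a binary search.

#include <algorithm>
#include <cstddef>
#include <cstdint>

int findCandidate(const uint32_t* outData, int dimK, int hypoIdx, uint32_t c) {
  const uint32_t* begin = outData + (size_t)hypoIdx * dimK;
  const uint32_t* end   = begin + dimK;
  const uint32_t* it    = std::lower_bound(begin, end, c); // sorted => O(log k)
  if (it != end && *it == c)
    return (int)(it - begin);  // position within the sorted top-k list
  return -1;                   // not among this hypothesis's candidates
}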

View File

@ -494,12 +494,17 @@ public:
state_->wordsDisp = 0;
}
if(options_->get<bool>("valid-reset-stalled")) {
if(options_->get<bool>("valid-reset-stalled") || options_->get<bool>("valid-reset-all")) {
state_->stalled = 0;
state_->maxStalled = 0;
for(const auto& validator : validators_) {
if(state_->validators[validator->type()])
if(state_->validators[validator->type()]) {
// reset the number of stalled validations, e.g. when the validation set is the same
state_->validators[validator->type()]["stalled"] = 0;
// reset last best results as well, e.g. when the validation set changes
if(options_->get<bool>("valid-reset-all"))
state_->validators[validator->type()]["last-best"] = validator->initScore();
}
}
}
@ -512,10 +517,10 @@ public:
if(mpi_->isMainProcess())
if(filesystem::exists(nameYaml))
yamlStr = io::InputFileStream(nameYaml).readToString();
if(mpi_)
mpi_->bCast(yamlStr);
loadFromString(yamlStr);
}
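
A hedged raw-MPI sketch of the pattern behind mpi_->bCast(yamlStr) above (Marian wraps this in its own MPI layer; the function here is our illustration): only the main process reads the checkpoint YAML, then the string is broadcast so every worker restores identical training state.

#include <mpi.h>
#include <string>
#include <vector>

void broadcastString(std::string& s, MPI_Comm comm) {
  int rank = 0;
  MPI_Comm_rank(comm, &rank);
  // Broadcast the length first so non-root ranks can size their buffers.
  unsigned long long len = (rank == 0) ? s.size() : 0;
  MPI_Bcast(&len, 1, MPI_UNSIGNED_LONG_LONG, /*root=*/0, comm);
  if (len == 0) {
    if (rank != 0) s.clear();
    return;
  }
  std::vector<char> buf(len);
  if (rank == 0)
    std::copy(s.begin(), s.end(), buf.begin());
  MPI_Bcast(buf.data(), (int)len, MPI_CHAR, /*root=*/0, comm);
  if (rank != 0)
    s.assign(buf.begin(), buf.end());
}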