mirror of
https://github.com/marian-nmt/marian.git
synced 2024-07-14 17:40:36 +03:00
Merge branch 'master' into pmaster
This commit is contained in:
commit
4ffd292881
@ -12,8 +12,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||||||
- Fused inplace-dropout in FFN layer in Transformer
|
- Fused inplace-dropout in FFN layer in Transformer
|
||||||
- `--force-decode` option for marian-decoder
|
- `--force-decode` option for marian-decoder
|
||||||
- `--output-sampling` now works with ensembles (requires proper normalization via e.g `--weights 0.5 0.5`)
|
- `--output-sampling` now works with ensembles (requires proper normalization via e.g `--weights 0.5 0.5`)
|
||||||
|
- `--valid-reset-all` option
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
- Make concat factors not break old vector implementation
|
||||||
- Use allocator in hashing
|
- Use allocator in hashing
|
||||||
- Read/restore checkpoints from main process only when training with MPI
|
- Read/restore checkpoints from main process only when training with MPI
|
||||||
- Multi-loss casts type to first loss-type before accumulation (aborted before due to missing cast)
|
- Multi-loss casts type to first loss-type before accumulation (aborted before due to missing cast)
|
||||||
|
@ -16,13 +16,30 @@ parameters:
|
|||||||
# The pipeline CI trigger is set on the branch master only and PR trigger on a
|
# The pipeline CI trigger is set on the branch master only and PR trigger on a
|
||||||
# (non-draft) pull request to any branch
|
# (non-draft) pull request to any branch
|
||||||
trigger:
|
trigger:
|
||||||
- master
|
# This minimizes the number of parallel pipeline runs. When a pipeline is
|
||||||
|
# running, the CI waits until it is completed before starting another one.
|
||||||
|
batch: true
|
||||||
|
branches:
|
||||||
|
include:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
exclude:
|
||||||
|
- azure-regression-tests.yml
|
||||||
|
- contrib
|
||||||
|
- doc
|
||||||
|
- examples
|
||||||
|
- regression-tests
|
||||||
|
- scripts
|
||||||
|
- VERSION
|
||||||
|
- vs
|
||||||
|
- '**/*.md'
|
||||||
|
- '**/*.txt'
|
||||||
|
|
||||||
pool:
|
pool:
|
||||||
name: Azure Pipelines
|
name: Azure Pipelines
|
||||||
|
|
||||||
variables:
|
variables:
|
||||||
- group: marian-prod-tests
|
- group: marian-regression-tests
|
||||||
- name: BOOST_ROOT_WINDOWS
|
- name: BOOST_ROOT_WINDOWS
|
||||||
value: "C:/hostedtoolcache/windows/Boost/1.72.0/x86_64"
|
value: "C:/hostedtoolcache/windows/Boost/1.72.0/x86_64"
|
||||||
- name: BOOST_URL
|
- name: BOOST_URL
|
||||||
@ -32,7 +49,7 @@ variables:
|
|||||||
- name: MKL_DIR
|
- name: MKL_DIR
|
||||||
value: "$(Build.SourcesDirectory)/mkl"
|
value: "$(Build.SourcesDirectory)/mkl"
|
||||||
- name: MKL_URL
|
- name: MKL_URL
|
||||||
value: "https://romang.blob.core.windows.net/mariandev/ci/mkl-2020.1-windows-static.zip"
|
value: "https://data.statmt.org/romang/marian-regression-tests/ci/mkl-2020.1-windows-static.zip"
|
||||||
- name: VCPKG_COMMIT
|
- name: VCPKG_COMMIT
|
||||||
value: 2022.03.10
|
value: 2022.03.10
|
||||||
- name: VCPKG_DIR
|
- name: VCPKG_DIR
|
||||||
@ -52,6 +69,7 @@ stages:
|
|||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
- job: BuildWindows
|
- job: BuildWindows
|
||||||
|
cancelTimeoutInMinutes: 1
|
||||||
condition: eq(${{ parameters.runBuilds }}, true)
|
condition: eq(${{ parameters.runBuilds }}, true)
|
||||||
displayName: Windows
|
displayName: Windows
|
||||||
|
|
||||||
@ -188,6 +206,7 @@ stages:
|
|||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
- job: BuildUbuntu
|
- job: BuildUbuntu
|
||||||
|
cancelTimeoutInMinutes: 1
|
||||||
condition: eq(${{ parameters.runBuilds }}, true)
|
condition: eq(${{ parameters.runBuilds }}, true)
|
||||||
displayName: Ubuntu
|
displayName: Ubuntu
|
||||||
timeoutInMinutes: 120
|
timeoutInMinutes: 120
|
||||||
@ -324,6 +343,7 @@ stages:
|
|||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
- job: BuildMacOS
|
- job: BuildMacOS
|
||||||
|
cancelTimeoutInMinutes: 1
|
||||||
condition: eq(${{ parameters.runBuilds }}, true)
|
condition: eq(${{ parameters.runBuilds }}, true)
|
||||||
displayName: macOS CPU clang
|
displayName: macOS CPU clang
|
||||||
|
|
||||||
@ -373,6 +393,7 @@ stages:
|
|||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
- job: BuildInstall
|
- job: BuildInstall
|
||||||
|
cancelTimeoutInMinutes: 1
|
||||||
condition: eq(${{ parameters.runBuilds }}, true)
|
condition: eq(${{ parameters.runBuilds }}, true)
|
||||||
displayName: Linux CPU library install
|
displayName: Linux CPU library install
|
||||||
|
|
||||||
@ -435,6 +456,7 @@ stages:
|
|||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
- job: TestWindows
|
- job: TestWindows
|
||||||
|
cancelTimeoutInMinutes: 1
|
||||||
displayName: Windows CPU+FBGEMM
|
displayName: Windows CPU+FBGEMM
|
||||||
|
|
||||||
pool:
|
pool:
|
||||||
@ -528,14 +550,14 @@ stages:
|
|||||||
displayName: Machine statistics
|
displayName: Machine statistics
|
||||||
workingDirectory: marian-prod-tests
|
workingDirectory: marian-prod-tests
|
||||||
|
|
||||||
# The current SAS token will expire on 8/30/2023 and a new one will need to be set in Marian > Pipelines > Library
|
# The current SAS token will expire on 12/31/2023 and a new one will need to be set in Marian > Pipelines > Library
|
||||||
- bash: |
|
- bash: |
|
||||||
cd models
|
cd models
|
||||||
bash download-models.sh
|
bash download-models.sh
|
||||||
ls
|
ls
|
||||||
displayName: Prepare tests
|
displayName: Prepare tests
|
||||||
env:
|
env:
|
||||||
AWS_SECRET_SAS_TOKEN: $(blob-sas-token)
|
AZURE_STORAGE_SAS_TOKEN: $(marian-prod-tests-blob-sas-token)
|
||||||
workingDirectory: marian-prod-tests
|
workingDirectory: marian-prod-tests
|
||||||
|
|
||||||
# Avoid using $(Build.SourcesDirectory) in bash tasks because on Windows pools it uses '\'
|
# Avoid using $(Build.SourcesDirectory) in bash tasks because on Windows pools it uses '\'
|
||||||
@ -560,6 +582,7 @@ stages:
|
|||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
- job: TestLinux
|
- job: TestLinux
|
||||||
|
cancelTimeoutInMinutes: 1
|
||||||
displayName: Linux CPU+FBGEMM
|
displayName: Linux CPU+FBGEMM
|
||||||
|
|
||||||
pool:
|
pool:
|
||||||
@ -572,7 +595,10 @@ stages:
|
|||||||
|
|
||||||
# The following packages are already installed on Azure-hosted runners: build-essential openssl libssl-dev
|
# The following packages are already installed on Azure-hosted runners: build-essential openssl libssl-dev
|
||||||
# No need to install libprotobuf{17,10,9v5} on Ubuntu {20,18,16}.04 because it is installed together with libprotobuf-dev
|
# No need to install libprotobuf{17,10,9v5} on Ubuntu {20,18,16}.04 because it is installed together with libprotobuf-dev
|
||||||
- bash: sudo apt-get install -y libgoogle-perftools-dev libprotobuf-dev protobuf-compiler gcc-9 g++-9
|
# Installing libunwind-dev fixes a bug in 2204 (the libunwind-14 and libunwind-dev conflict)
|
||||||
|
- bash: |
|
||||||
|
sudo apt-get install -y libunwind-dev
|
||||||
|
sudo apt-get install -y libgoogle-perftools-dev libprotobuf-dev protobuf-compiler gcc-9 g++-9
|
||||||
displayName: Install packages
|
displayName: Install packages
|
||||||
|
|
||||||
# https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html
|
# https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html
|
||||||
@ -629,14 +655,14 @@ stages:
|
|||||||
displayName: Machine statistics
|
displayName: Machine statistics
|
||||||
workingDirectory: marian-prod-tests
|
workingDirectory: marian-prod-tests
|
||||||
|
|
||||||
# The current SAS token will expire on 8/30/2023 and a new one will need to be set in Marian > Pipelines > Library
|
# The current SAS token will expire on 12/31/2023 and a new one will need to be set in Marian > Pipelines > Library
|
||||||
- bash: |
|
- bash: |
|
||||||
cd models
|
cd models
|
||||||
bash download-models.sh
|
bash download-models.sh
|
||||||
ls
|
ls
|
||||||
displayName: Prepare tests
|
displayName: Prepare tests
|
||||||
env:
|
env:
|
||||||
AWS_SECRET_SAS_TOKEN: $(blob-sas-token)
|
AZURE_STORAGE_SAS_TOKEN: $(marian-prod-tests-blob-sas-token)
|
||||||
workingDirectory: marian-prod-tests
|
workingDirectory: marian-prod-tests
|
||||||
|
|
||||||
- bash: MARIAN=../marian-dev/build bash ./run_mrt.sh '#cpu' '#basics' '#devops'
|
- bash: MARIAN=../marian-dev/build bash ./run_mrt.sh '#cpu' '#basics' '#devops'
|
||||||
|
@ -14,12 +14,16 @@ trigger: none
|
|||||||
# Hosted Azure DevOps Pool determining OS, CUDA version and available GPUs
|
# Hosted Azure DevOps Pool determining OS, CUDA version and available GPUs
|
||||||
pool: mariandevops-pool-m60-eus
|
pool: mariandevops-pool-m60-eus
|
||||||
|
|
||||||
|
variables:
|
||||||
|
- group: marian-regression-tests
|
||||||
|
|
||||||
stages:
|
stages:
|
||||||
- stage: TestsGPU
|
- stage: TestsGPU
|
||||||
jobs:
|
jobs:
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
- job: TestsGPULinux
|
- job: TestsGPULinux
|
||||||
|
cancelTimeoutInMinutes: 1
|
||||||
displayName: Linux GPU tests
|
displayName: Linux GPU tests
|
||||||
timeoutInMinutes: 120
|
timeoutInMinutes: 120
|
||||||
|
|
||||||
@ -103,11 +107,14 @@ stages:
|
|||||||
workingDirectory: build
|
workingDirectory: build
|
||||||
|
|
||||||
# Always run regression tests from the master branch
|
# Always run regression tests from the master branch
|
||||||
|
# The current SAS token will expire on 12/31/2023 and a new one will need to be set in Marian > Pipelines > Library
|
||||||
- bash: |
|
- bash: |
|
||||||
git checkout master
|
git checkout master
|
||||||
git pull origin master
|
git pull origin master
|
||||||
make install
|
make install
|
||||||
displayName: Prepare regression tests
|
displayName: Prepare regression tests
|
||||||
|
env:
|
||||||
|
AZURE_STORAGE_SAS_TOKEN: $(marian-pub-tests-blob-sas-token)
|
||||||
workingDirectory: regression-tests
|
workingDirectory: regression-tests
|
||||||
|
|
||||||
# Continue on error to be able to collect outputs and publish them as an artifact
|
# Continue on error to be able to collect outputs and publish them as an artifact
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit 488d454a0177ef300eab91ab813e485d420dc38d
|
Subproject commit 2a8bed3f0e937a9de2d6fa92dee3bcf482d3d47b
|
@ -42,11 +42,15 @@ def main():
|
|||||||
else:
|
else:
|
||||||
print(model[args.key])
|
print(model[args.key])
|
||||||
else:
|
else:
|
||||||
|
total_nb_of_parameters = 0
|
||||||
for key in model:
|
for key in model:
|
||||||
if args.matrix_shapes:
|
if not key == S2S_SPECIAL_NODE:
|
||||||
print(key, model[key].shape)
|
total_nb_of_parameters += np.prod(model[key].shape)
|
||||||
|
if args.matrix_info:
|
||||||
|
print(key, model[key].shape, model[key].dtype)
|
||||||
else:
|
else:
|
||||||
print(key)
|
print(key)
|
||||||
|
print('Total number of parameters:', total_nb_of_parameters)
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -57,8 +61,8 @@ def parse_args():
|
|||||||
help="print values from special:model.yml node")
|
help="print values from special:model.yml node")
|
||||||
parser.add_argument("-f", "--full-matrix", action="store_true",
|
parser.add_argument("-f", "--full-matrix", action="store_true",
|
||||||
help="force numpy to print full arrays for single key")
|
help="force numpy to print full arrays for single key")
|
||||||
parser.add_argument("-ms", "--matrix-shapes", action="store_true",
|
parser.add_argument("-mi", "--matrix-info", action="store_true",
|
||||||
help="print shapes of all arrays in the model")
|
help="print full matrix info for all keys. Includes shape and dtype")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
@ -269,7 +269,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
|||||||
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
|
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
|
||||||
cli.add<int>("--transformer-dim-ffn",
|
cli.add<int>("--transformer-dim-ffn",
|
||||||
"Size of position-wise feed-forward network (transformer)",
|
"Size of position-wise feed-forward network (transformer)",
|
||||||
2048);
|
2048);
|
||||||
cli.add<int>("--transformer-decoder-dim-ffn",
|
cli.add<int>("--transformer-decoder-dim-ffn",
|
||||||
"Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.",
|
"Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.",
|
||||||
0);
|
0);
|
||||||
@ -591,7 +591,9 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
|
|||||||
"Multiple metrics can be specified",
|
"Multiple metrics can be specified",
|
||||||
{"cross-entropy"});
|
{"cross-entropy"});
|
||||||
cli.add<bool>("--valid-reset-stalled",
|
cli.add<bool>("--valid-reset-stalled",
|
||||||
"Reset all stalled validation metrics when the training is restarted");
|
"Reset stalled validation metrics when the training is restarted");
|
||||||
|
cli.add<bool>("--valid-reset-all",
|
||||||
|
"Reset all validation metrics when the training is restarted");
|
||||||
cli.add<size_t>("--early-stopping",
|
cli.add<size_t>("--early-stopping",
|
||||||
"Stop if the first validation metric does not improve for arg consecutive validation steps",
|
"Stop if the first validation metric does not improve for arg consecutive validation steps",
|
||||||
10);
|
10);
|
||||||
|
@ -128,7 +128,7 @@ SentenceTuple Corpus::next() {
|
|||||||
size_t vocabId = i - shift;
|
size_t vocabId = i - shift;
|
||||||
bool altered;
|
bool altered;
|
||||||
preprocessLine(fields[i], vocabId, curId, /*out=*/altered);
|
preprocessLine(fields[i], vocabId, curId, /*out=*/altered);
|
||||||
if (altered)
|
if(altered)
|
||||||
tup.markAltered();
|
tup.markAltered();
|
||||||
addWordsToSentenceTuple(fields[i], vocabId, tup);
|
addWordsToSentenceTuple(fields[i], vocabId, tup);
|
||||||
}
|
}
|
||||||
|
@ -476,7 +476,10 @@ void CorpusBase::addAlignmentsToBatch(Ptr<CorpusBatch> batch,
|
|||||||
// If the batch vector is altered within marian by, for example, case augmentation,
|
// If the batch vector is altered within marian by, for example, case augmentation,
|
||||||
// the guided alignments we received for this tuple cease to be valid.
|
// the guided alignments we received for this tuple cease to be valid.
|
||||||
// Hence skip setting alignments for that sentence tuple..
|
// Hence skip setting alignments for that sentence tuple..
|
||||||
if (!batchVector[b].isAltered()) {
|
if (batchVector[b].isAltered()) {
|
||||||
|
LOG_ONCE(info, "Using guided-alignment with case-augmentation is not recommended and can result in unexpected behavior");
|
||||||
|
aligns.push_back(WordAlignment());
|
||||||
|
} else {
|
||||||
aligns.push_back(std::move(batchVector[b].getAlignment()));
|
aligns.push_back(std::move(batchVector[b].getAlignment()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -56,12 +56,16 @@ public:
|
|||||||
* @brief Returns whether this Tuple was altered or augmented from what
|
* @brief Returns whether this Tuple was altered or augmented from what
|
||||||
* was provided to Marian in input.
|
* was provided to Marian in input.
|
||||||
*/
|
*/
|
||||||
bool isAltered() const { return altered_; }
|
bool isAltered() const {
|
||||||
|
return altered_;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Mark that this Tuple was internally altered or augmented by Marian
|
* @brief Mark that this Tuple was internally altered or augmented by Marian
|
||||||
*/
|
*/
|
||||||
void markAltered() { altered_ = true; }
|
void markAltered() {
|
||||||
|
altered_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Adds a new sentence at the end of the tuple.
|
* @brief Adds a new sentence at the end of the tuple.
|
||||||
|
@ -21,8 +21,7 @@ namespace marian {
|
|||||||
maxSizeUnused;
|
maxSizeUnused;
|
||||||
// If model has already been loaded, then assume this is a shared object, and skip loading it again.
|
// If model has already been loaded, then assume this is a shared object, and skip loading it again.
|
||||||
// This can be multi-threaded, so must run under lock.
|
// This can be multi-threaded, so must run under lock.
|
||||||
static std::mutex s_mtx;
|
std::lock_guard<std::mutex> criticalSection(loadMtx_);
|
||||||
std::lock_guard<std::mutex> criticalSection(s_mtx);
|
|
||||||
if (size() != 0) {
|
if (size() != 0) {
|
||||||
//LOG(info, "[vocab] Attempting to load model a second time; skipping (assuming shared vocab)");
|
//LOG(info, "[vocab] Attempting to load model a second time; skipping (assuming shared vocab)");
|
||||||
return size();
|
return size();
|
||||||
|
@ -110,6 +110,7 @@ private:
|
|||||||
Word unkId_{};
|
Word unkId_{};
|
||||||
WordLUT vocab_;
|
WordLUT vocab_;
|
||||||
size_t lemmaSize_;
|
size_t lemmaSize_;
|
||||||
|
std::mutex loadMtx_;
|
||||||
|
|
||||||
// factors
|
// factors
|
||||||
char factorSeparator_ = '|'; // separator symbol for parsing factored words
|
char factorSeparator_ = '|'; // separator symbol for parsing factored words
|
||||||
|
@ -64,6 +64,14 @@ Expr ExpressionGraph::add(Expr node) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes the node from the set of roots (will not be initialized during back propagation)
|
||||||
|
* @param node a pointer to a expression node
|
||||||
|
*/
|
||||||
|
void ExpressionGraph::removeAsRoot(Expr node) {
|
||||||
|
topNodes_.erase(node);
|
||||||
|
}
|
||||||
|
|
||||||
// Call on every checkpoint in backwards order
|
// Call on every checkpoint in backwards order
|
||||||
void createSubtape(Expr node) {
|
void createSubtape(Expr node) {
|
||||||
auto subtape = New<std::list<Expr>>();
|
auto subtape = New<std::list<Expr>>();
|
||||||
|
@ -676,6 +676,12 @@ public:
|
|||||||
* @param node a pointer to a expression node
|
* @param node a pointer to a expression node
|
||||||
*/
|
*/
|
||||||
Expr add(Expr node);
|
Expr add(Expr node);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes the node from the set of roots (will not be initialized during back propagation)
|
||||||
|
* @param node a pointer to a expression node
|
||||||
|
*/
|
||||||
|
void removeAsRoot(Expr node);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Allocate memory for the forward pass of the given node.
|
* Allocate memory for the forward pass of the given node.
|
||||||
|
@ -27,6 +27,11 @@ Expr checkpoint(Expr a) {
|
|||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Expr removeAsRoot(Expr a) {
|
||||||
|
a->graph()->removeAsRoot(a); // ugly, hence why hidden here
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
|
Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
|
||||||
LambdaNodeFunctor fwd, size_t hash) {
|
LambdaNodeFunctor fwd, size_t hash) {
|
||||||
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
|
return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
|
||||||
|
@ -16,6 +16,11 @@ Expr debug(Expr a, const std::string& message = "");
|
|||||||
*/
|
*/
|
||||||
Expr checkpoint(Expr a);
|
Expr checkpoint(Expr a);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes the node from the set of root nodes, no-op otherwise
|
||||||
|
*/
|
||||||
|
Expr removeAsRoot(Expr node);
|
||||||
|
|
||||||
typedef Expr(ActivationFunction)(Expr); ///< ActivationFunction has signature Expr(Expr)
|
typedef Expr(ActivationFunction)(Expr); ///< ActivationFunction has signature Expr(Expr)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -6,10 +6,8 @@ namespace marian {
|
|||||||
Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||||
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
|
: LayerBase(graph, options), inference_(opt<bool>("inference")) {
|
||||||
std::string name = opt<std::string>("prefix");
|
std::string name = opt<std::string>("prefix");
|
||||||
int dimVoc = opt<int>("dimVocab");
|
int dimVoc = opt<int>("dimVocab");
|
||||||
int dimEmb = opt<int>("dimEmb");
|
int dimEmb = opt<int>("dimEmb");
|
||||||
int dimFactorEmb = opt<int>("dimFactorEmb");
|
|
||||||
|
|
||||||
bool fixed = opt<bool>("fixed", false);
|
bool fixed = opt<bool>("fixed", false);
|
||||||
|
|
||||||
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
|
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
|
||||||
@ -21,6 +19,7 @@ Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
|||||||
dimVoc = (int)factoredVocab_->factorVocabSize();
|
dimVoc = (int)factoredVocab_->factorVocabSize();
|
||||||
LOG_ONCE(info, "[embedding] Factored embeddings enabled");
|
LOG_ONCE(info, "[embedding] Factored embeddings enabled");
|
||||||
if(opt<std::string>("factorsCombine") == "concat") {
|
if(opt<std::string>("factorsCombine") == "concat") {
|
||||||
|
int dimFactorEmb = opt<int>("dimFactorEmb", 0);
|
||||||
ABORT_IF(dimFactorEmb == 0,
|
ABORT_IF(dimFactorEmb == 0,
|
||||||
"Embedding: If concatenation is chosen to combine the factor embeddings, a factor "
|
"Embedding: If concatenation is chosen to combine the factor embeddings, a factor "
|
||||||
"embedding size must be specified.");
|
"embedding size must be specified.");
|
||||||
@ -179,8 +178,8 @@ Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape&
|
|||||||
"prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb"
|
"prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb"
|
||||||
: prefix_ + "_Wemb",
|
: prefix_ + "_Wemb",
|
||||||
"fixed", embeddingFix_,
|
"fixed", embeddingFix_,
|
||||||
"dimFactorEmb", opt<int>("factors-dim-emb"), // for factored embeddings
|
"dimFactorEmb", opt<int>("factors-dim-emb", 0), // for factored embeddings
|
||||||
"factorsCombine", opt<std::string>("factors-combine"), // for factored embeddings
|
"factorsCombine", opt<std::string>("factors-combine", ""), // for factored embeddings
|
||||||
"vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
|
"vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
|
||||||
// clang-format on
|
// clang-format on
|
||||||
if(options_->hasAndNotEmpty("embedding-vectors")) {
|
if(options_->hasAndNotEmpty("embedding-vectors")) {
|
||||||
|
@ -26,7 +26,8 @@ guidedAlignmentToSparse(Ptr<data::CorpusBatch> batch) {
|
|||||||
|
|
||||||
std::sort(byIndex.begin(), byIndex.end(), [](const BiPoint& a, const BiPoint& b) { return std::get<0>(a) < std::get<0>(b); });
|
std::sort(byIndex.begin(), byIndex.end(), [](const BiPoint& a, const BiPoint& b) { return std::get<0>(a) < std::get<0>(b); });
|
||||||
std::vector<IndexType> indices; std::vector<float> valuesFwd;
|
std::vector<IndexType> indices; std::vector<float> valuesFwd;
|
||||||
indices.reserve(byIndex.size()); valuesFwd.reserve(byIndex.size());
|
indices.reserve(byIndex.size());
|
||||||
|
valuesFwd.reserve(byIndex.size());
|
||||||
for(auto& p : byIndex) {
|
for(auto& p : byIndex) {
|
||||||
indices.push_back((IndexType)std::get<0>(p));
|
indices.push_back((IndexType)std::get<0>(p));
|
||||||
valuesFwd.push_back(std::get<1>(p));
|
valuesFwd.push_back(std::get<1>(p));
|
||||||
@ -40,28 +41,33 @@ static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> graph,
|
|||||||
Ptr<Options> options,
|
Ptr<Options> options,
|
||||||
Expr attention) { // [beam depth=1, max src length, batch size, tgt length]
|
Expr attention) { // [beam depth=1, max src length, batch size, tgt length]
|
||||||
std::string guidedLossType = options->get<std::string>("guided-alignment-cost"); // @TODO: change "cost" to "loss"
|
std::string guidedLossType = options->get<std::string>("guided-alignment-cost"); // @TODO: change "cost" to "loss"
|
||||||
|
// @TODO: It is ugly to check the multi-loss type here, but doing this right requires
|
||||||
|
// a substantial rewrite of the multi-loss architecture, which is planned anyways.
|
||||||
|
std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
|
||||||
|
|
||||||
// We dropped support for other losses which are not possible to implement with sparse labels.
|
// We dropped support for other losses which are not possible to implement with sparse labels.
|
||||||
// They were most likely not used anyway.
|
// They were most likely not used anyway.
|
||||||
ABORT_IF(guidedLossType != "ce", "Only alignment loss type 'ce' is supported");
|
ABORT_IF(guidedLossType != "ce", "Only alignment loss type 'ce' is supported");
|
||||||
|
|
||||||
float guidedLossWeight = options->get<float>("guided-alignment-weight");
|
float guidedLossWeight = options->get<float>("guided-alignment-weight");
|
||||||
|
const auto& [indices, values] = guidedAlignmentToSparse(batch);
|
||||||
auto [indices, values] = guidedAlignmentToSparse(batch);
|
|
||||||
auto alignmentIndices = graph->indices(indices);
|
Expr alignmentLoss;
|
||||||
auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values));
|
size_t numLabels = indices.size(); // can be zero
|
||||||
auto attentionAtAligned = cols(flatten(attention), alignmentIndices);
|
if(indices.empty()) {
|
||||||
|
removeAsRoot(stopGradient(attention)); // unused, hence make sure we don't polute the backwards operations
|
||||||
float epsilon = 1e-6f;
|
alignmentLoss = graph->zeros({1});
|
||||||
Expr alignmentLoss = -sum(cast(alignmentValues * log(attentionAtAligned + epsilon), Type::float32));
|
numLabels = multiLossType == "sum" ? 0 : 1;
|
||||||
size_t numLabels = alignmentIndices->shape().elements();
|
} else {
|
||||||
|
float epsilon = 1e-6f;
|
||||||
|
auto alignmentIndices = graph->indices(indices);
|
||||||
|
auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values));
|
||||||
|
auto attentionAtAligned = cols(flatten(attention), alignmentIndices);
|
||||||
|
alignmentLoss = -sum(cast(alignmentValues * log(attentionAtAligned + epsilon), Type::float32));
|
||||||
|
}
|
||||||
// Create label node, also weigh by scalar so labels and cost are in the same domain.
|
// Create label node, also weigh by scalar so labels and cost are in the same domain.
|
||||||
// Fractional label counts are OK. But only if combined as "sum".
|
// Fractional label counts are OK. But only if combined as "sum".
|
||||||
// @TODO: It is ugly to check the multi-loss type here, but doing this right requires
|
if (multiLossType == "sum") // sum of sums
|
||||||
// a substantial rewrite of the multi-loss architecture, which is planned anyways.
|
|
||||||
std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
|
|
||||||
if (multiLossType == "sum") // sum of sums
|
|
||||||
return RationalLoss(guidedLossWeight * alignmentLoss, guidedLossWeight * numLabels);
|
return RationalLoss(guidedLossWeight * alignmentLoss, guidedLossWeight * numLabels);
|
||||||
else
|
else
|
||||||
return RationalLoss(guidedLossWeight * alignmentLoss, (float)numLabels);
|
return RationalLoss(guidedLossWeight * alignmentLoss, (float)numLabels);
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
#include "layers/lsh.h"
|
#include "common/timer.h"
|
||||||
#include "tensors/tensor_operators.h"
|
|
||||||
#include "common/utils.h"
|
#include "common/utils.h"
|
||||||
|
#include "layers/lsh.h"
|
||||||
|
#include "layers/lsh_impl.h"
|
||||||
|
#include "tensors/tensor_operators.h"
|
||||||
|
|
||||||
|
#if _MSC_VER
|
||||||
|
#include "3rd_party/faiss/Index.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "3rd_party/faiss/utils/hamming.h"
|
#include "3rd_party/faiss/utils/hamming.h"
|
||||||
|
|
||||||
@ -8,10 +14,6 @@
|
|||||||
#include "3rd_party/faiss/VectorTransform.h"
|
#include "3rd_party/faiss/VectorTransform.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "common/timer.h"
|
|
||||||
|
|
||||||
#include "layers/lsh_impl.h"
|
|
||||||
|
|
||||||
namespace marian {
|
namespace marian {
|
||||||
namespace lsh {
|
namespace lsh {
|
||||||
|
|
||||||
@ -116,7 +118,7 @@ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNR
|
|||||||
|
|
||||||
int currBeamSize = encodedQuery->shape()[0];
|
int currBeamSize = encodedQuery->shape()[0];
|
||||||
int batchSize = encodedQuery->shape()[2];
|
int batchSize = encodedQuery->shape()[2];
|
||||||
|
|
||||||
auto search = [=](Expr out, const std::vector<Expr>& inputs) {
|
auto search = [=](Expr out, const std::vector<Expr>& inputs) {
|
||||||
Expr encodedQuery = inputs[0];
|
Expr encodedQuery = inputs[0];
|
||||||
Expr encodedWeights = inputs[1];
|
Expr encodedWeights = inputs[1];
|
||||||
@ -130,6 +132,32 @@ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNR
|
|||||||
|
|
||||||
ABORT_IF(dimK > wRows, "k is larger than number of candidate values?"); // @TODO: use min(k, wRows) silently?
|
ABORT_IF(dimK > wRows, "k is larger than number of candidate values?"); // @TODO: use min(k, wRows) silently?
|
||||||
|
|
||||||
|
#if _MSC_VER // unfortunately MSVC is horrible at loop unrolling, so we fall back to the old code (hrmph!) @TODO: figure this out one day
|
||||||
|
int qRows = encodedQuery->shape().elements() / bytesPerVector;
|
||||||
|
|
||||||
|
uint8_t* qCodes = encodedQuery->val()->data<uint8_t>();
|
||||||
|
uint8_t* wCodes = encodedWeights->val()->data<uint8_t>();
|
||||||
|
|
||||||
|
// use actual faiss code for performing the hamming search.
|
||||||
|
std::vector<int> distances(qRows * dimK);
|
||||||
|
std::vector<faiss::Index::idx_t> ids(qRows * dimK);
|
||||||
|
faiss::int_maxheap_array_t res = {(size_t)qRows, (size_t)dimK, ids.data(), distances.data()};
|
||||||
|
faiss::hammings_knn_hc(&res, qCodes, wCodes, (size_t)wRows, (size_t)bytesPerVector, 0);
|
||||||
|
|
||||||
|
// Copy int64_t indices to Marian index type and sort by increasing index value per hypothesis.
|
||||||
|
// The sorting is required as we later do a binary search on those values for reverse look-up.
|
||||||
|
uint32_t* outData = out->val()->data<uint32_t>();
|
||||||
|
|
||||||
|
int numHypos = out->shape().elements() / dimK;
|
||||||
|
for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
|
||||||
|
size_t startIdx = dimK * hypoIdx;
|
||||||
|
size_t endIdx = startIdx + dimK;
|
||||||
|
for(size_t i = startIdx; i < endIdx; ++i)
|
||||||
|
outData[i] = (uint32_t)ids[i];
|
||||||
|
if(!noSort)
|
||||||
|
std::sort(outData + startIdx, outData + endIdx);
|
||||||
|
}
|
||||||
|
#else // this is using the new code for search, other parts of the code, like conversion are fine.
|
||||||
IndexType* outData = out->val()->data<IndexType>();
|
IndexType* outData = out->val()->data<IndexType>();
|
||||||
auto gather = [outData, dimK](IndexType rowId, IndexType k, IndexType kthColId, DistType /*dist*/) {
|
auto gather = [outData, dimK](IndexType rowId, IndexType k, IndexType kthColId, DistType /*dist*/) {
|
||||||
outData[rowId * dimK + k] = kthColId;
|
outData[rowId * dimK + k] = kthColId;
|
||||||
@ -144,6 +172,7 @@ Expr searchEncoded(Expr encodedQuery, Expr encodedWeights, int dimK, int firstNR
|
|||||||
params.bytesPerVector = bytesPerVector;
|
params.bytesPerVector = bytesPerVector;
|
||||||
|
|
||||||
hammingTopK(params, gather);
|
hammingTopK(params, gather);
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
Shape kShape({currBeamSize, batchSize, dimK});
|
Shape kShape({currBeamSize, batchSize, dimK});
|
||||||
|
@ -494,12 +494,17 @@ public:
|
|||||||
state_->wordsDisp = 0;
|
state_->wordsDisp = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(options_->get<bool>("valid-reset-stalled")) {
|
if(options_->get<bool>("valid-reset-stalled") || options_->get<bool>("valid-reset-all")) {
|
||||||
state_->stalled = 0;
|
state_->stalled = 0;
|
||||||
state_->maxStalled = 0;
|
state_->maxStalled = 0;
|
||||||
for(const auto& validator : validators_) {
|
for(const auto& validator : validators_) {
|
||||||
if(state_->validators[validator->type()])
|
if(state_->validators[validator->type()]) {
|
||||||
|
// reset the number of stalled validations, e.g. when the validation set is the same
|
||||||
state_->validators[validator->type()]["stalled"] = 0;
|
state_->validators[validator->type()]["stalled"] = 0;
|
||||||
|
// reset last best results as well, e.g. when the validation set changes
|
||||||
|
if(options_->get<bool>("valid-reset-all"))
|
||||||
|
state_->validators[validator->type()]["last-best"] = validator->initScore();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -512,10 +517,10 @@ public:
|
|||||||
if(mpi_->isMainProcess())
|
if(mpi_->isMainProcess())
|
||||||
if(filesystem::exists(nameYaml))
|
if(filesystem::exists(nameYaml))
|
||||||
yamlStr = io::InputFileStream(nameYaml).readToString();
|
yamlStr = io::InputFileStream(nameYaml).readToString();
|
||||||
|
|
||||||
if(mpi_)
|
if(mpi_)
|
||||||
mpi_->bCast(yamlStr);
|
mpi_->bCast(yamlStr);
|
||||||
|
|
||||||
loadFromString(yamlStr);
|
loadFromString(yamlStr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user