mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
Merged PR 19597: Enable mpi wrapper to use size larger than MAX_INT
Enable mpi wrapper to use size larger than MAX_INT.
This commit is contained in:
parent
85eb6adce0
commit
fc0f41f24a
@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
- Compute aligned memory sizes using exact sizing
|
||||
|
||||
### Fixed
|
||||
- Added support to MPIWrappest::bcast (and similar) for count of type size_t
|
||||
- Adding new validation metrics when training is restarted and --reset-valid-stalled is used
|
||||
- Missing depth-scaling in transformer FFN
|
||||
- Fixed an issue when loading intgemm16 models from unaligned memory.
|
||||
|
@ -123,20 +123,73 @@ public:
|
||||
virtual void barrier(MPI_Comm comm = MPI_COMM_WORLD) const override {
|
||||
HANDLE_MPI_ERROR(MPI_Barrier(comm));
|
||||
}
|
||||
|
||||
virtual void bCast(void* buf, size_t count, MPI_Datatype datatype, size_t rootRank, MPI_Comm comm = MPI_COMM_WORLD) const override {
|
||||
HANDLE_MPI_ERROR(MPI_Bcast(buf, (int)count, datatype, (int)rootRank, comm));
|
||||
// MPI_Bcast only supports MAX_INT count, here and in the functions below, we need to cycle through the counts until we have sent
|
||||
// all elemements if count is larger MAX_INT.
|
||||
|
||||
// get the data type size in bytes
|
||||
int datatypeSize;
|
||||
HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));
|
||||
|
||||
// get the limit for int count
|
||||
size_t limit = (size_t)std::numeric_limits<int>::max();
|
||||
size_t remaining = count, offset = 0;
|
||||
|
||||
// while there are elements that we have not sent yet, loop until all has been sent in chunks of at most `limit`.
|
||||
while(remaining > 0) {
|
||||
int intCount = (int)std::min(remaining, limit);
|
||||
HANDLE_MPI_ERROR(MPI_Bcast((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)rootRank, comm));
|
||||
offset += (size_t)intCount;
|
||||
remaining -= (size_t)intCount;
|
||||
}
|
||||
}
|
||||
|
||||
virtual void sSend(void* buf, size_t count, MPI_Datatype datatype, size_t destRank, int tag, MPI_Comm comm) const override {
|
||||
HANDLE_MPI_ERROR(MPI_Ssend(buf, (int)count, datatype, (int)destRank, tag, comm));
|
||||
int datatypeSize;
|
||||
HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));
|
||||
|
||||
size_t limit = (size_t)std::numeric_limits<int>::max();
|
||||
size_t remaining = count, offset = 0;
|
||||
while(remaining > 0) {
|
||||
int intCount = (int)std::min(remaining, limit);
|
||||
HANDLE_MPI_ERROR(MPI_Ssend((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)destRank, tag, comm));
|
||||
offset += (size_t)intCount;
|
||||
remaining -= (size_t)intCount;
|
||||
}
|
||||
}
|
||||
|
||||
virtual void recv(void* buf, size_t count, MPI_Datatype datatype, size_t sourceRank, int tag, MPI_Comm comm, MPI_Status* status) const override {
|
||||
HANDLE_MPI_ERROR(MPI_Recv(buf, (int)count, datatype, (int)sourceRank, tag, comm, status));
|
||||
int datatypeSize;
|
||||
HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));
|
||||
|
||||
size_t limit = (size_t)std::numeric_limits<int>::max();
|
||||
size_t remaining = count, offset = 0;
|
||||
while(remaining > 0) {
|
||||
int intCount = (int)std::min(remaining, limit);
|
||||
HANDLE_MPI_ERROR(MPI_Recv((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)sourceRank, tag, comm, status));
|
||||
offset += (size_t)intCount;
|
||||
remaining -= (size_t)intCount;
|
||||
}
|
||||
}
|
||||
|
||||
virtual void allReduce(const void* sendbuf, void* recvbuf, size_t count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) const override {
|
||||
if (sendbuf == recvbuf)
|
||||
sendbuf = MPI_IN_PLACE; // MSMPI requires this
|
||||
HANDLE_MPI_ERROR(MPI_Allreduce(sendbuf, recvbuf, (int)count, datatype, op, comm));
|
||||
|
||||
int datatypeSize;
|
||||
HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));
|
||||
|
||||
size_t limit = (size_t)std::numeric_limits<int>::max();
|
||||
size_t remaining = count, offset = 0;
|
||||
while(remaining > 0) {
|
||||
int intCount = (int)std::min(remaining, limit);
|
||||
HANDLE_MPI_ERROR(MPI_Allreduce((char*)sendbuf + offset * (size_t)datatypeSize, (char*)recvbuf + offset * (size_t)datatypeSize, intCount, datatype, op, comm));
|
||||
offset += (size_t)intCount;
|
||||
remaining -= (size_t)intCount;
|
||||
}
|
||||
}
|
||||
|
||||
virtual void finalize() override {
|
||||
HANDLE_MPI_ERROR(MPI_Finalize());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user