Merged PR 19597: Enable mpi wrapper to use size larger than MAX_INT

Enable the MPI wrapper to handle element counts larger than INT_MAX.
This commit is contained in:
Martin Junczys-Dowmunt 2021-06-28 23:15:23 +00:00
parent 85eb6adce0
commit fc0f41f24a
3 changed files with 59 additions and 5 deletions

View File

@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Compute aligned memory sizes using exact sizing
### Fixed
- Added support to MPIWrapper::bCast (and similar) for count of type size_t
- Adding new validation metrics when training is restarted and --reset-valid-stalled is used
- Missing depth-scaling in transformer FFN
- Fixed an issue when loading intgemm16 models from unaligned memory.

View File

@ -1,2 +1,2 @@
v1.10.20
v1.10.21

View File

@ -123,20 +123,73 @@ public:
// Calls MPI_Barrier on `comm` (MPI_COMM_WORLD when no communicator is given).
// The MPI return code is handed to HANDLE_MPI_ERROR, which is defined elsewhere.
virtual void barrier(MPI_Comm comm = MPI_COMM_WORLD) const override {
HANDLE_MPI_ERROR(MPI_Barrier(comm));
}
// Broadcasts `count` elements of type `datatype`, starting at `buf`, from rank `rootRank`
// over `comm` (MPI_COMM_WORLD by default).
//
// MPI_Bcast takes its count as an `int`, so a `size_t` count cannot be passed through directly.
// The buffer is therefore transmitted as consecutive chunks of at most
// std::numeric_limits<int>::max() elements each, until all elements have been handled.
// If `count` is 0 no MPI_Bcast call is issued at all.
virtual void bCast(void* buf, size_t count, MPI_Datatype datatype, size_t rootRank, MPI_Comm comm = MPI_COMM_WORLD) const override {
  // size of one element in bytes; needed to turn the element offset into a byte offset
  int datatypeSize;
  HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));

  // largest element count a single MPI call can take
  const size_t limit = (size_t)std::numeric_limits<int>::max();

  size_t remaining = count, offset = 0; // both measured in elements, not bytes
  // while there are elements that have not been sent yet, send the next chunk of at most `limit` elements
  while(remaining > 0) {
    int intCount = (int)std::min(remaining, limit); // cannot overflow: bounded by `limit`
    HANDLE_MPI_ERROR(MPI_Bcast((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)rootRank, comm));
    offset    += (size_t)intCount;
    remaining -= (size_t)intCount;
  }
}
// Sends `count` elements of type `datatype`, starting at `buf`, to rank `destRank` with the given
// `tag` over `comm`, using MPI_Ssend.
//
// MPI_Ssend takes its count as an `int`, so the buffer is sent as consecutive messages of at most
// std::numeric_limits<int>::max() elements each; a count above that limit results in more than
// one MPI_Ssend with the same tag. If `count` is 0 no MPI_Ssend call is issued at all.
virtual void sSend(void* buf, size_t count, MPI_Datatype datatype, size_t destRank, int tag, MPI_Comm comm) const override {
  // size of one element in bytes; needed to turn the element offset into a byte offset
  int datatypeSize;
  HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));

  // largest element count a single MPI call can take
  const size_t limit = (size_t)std::numeric_limits<int>::max();

  size_t remaining = count, offset = 0; // both measured in elements, not bytes
  while(remaining > 0) {
    int intCount = (int)std::min(remaining, limit); // cannot overflow: bounded by `limit`
    HANDLE_MPI_ERROR(MPI_Ssend((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)destRank, tag, comm));
    offset    += (size_t)intCount;
    remaining -= (size_t)intCount;
  }
}
// Receives `count` elements of type `datatype` into `buf` from rank `sourceRank` with the given
// `tag` over `comm`, using MPI_Recv.
//
// MPI_Recv takes its count as an `int`, so the data is received as consecutive messages of at most
// std::numeric_limits<int>::max() elements each; a count above that limit results in more than
// one MPI_Recv with the same source and tag.
//
// `status` is passed unchanged to every MPI_Recv call, so on return it describes only the last
// chunk that was received. If `count` is 0 no MPI_Recv call is issued and `status` is not written.
// NOTE(review): if callers pass a wildcard source or tag, successive chunks are matched
// independently of each other -- confirm that no caller relies on wildcards with large counts.
virtual void recv(void* buf, size_t count, MPI_Datatype datatype, size_t sourceRank, int tag, MPI_Comm comm, MPI_Status* status) const override {
  // size of one element in bytes; needed to turn the element offset into a byte offset
  int datatypeSize;
  HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));

  // largest element count a single MPI call can take
  const size_t limit = (size_t)std::numeric_limits<int>::max();

  size_t remaining = count, offset = 0; // both measured in elements, not bytes
  while(remaining > 0) {
    int intCount = (int)std::min(remaining, limit); // cannot overflow: bounded by `limit`
    HANDLE_MPI_ERROR(MPI_Recv((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)sourceRank, tag, comm, status));
    offset    += (size_t)intCount;
    remaining -= (size_t)intCount;
  }
}
// Reduces `count` elements of type `datatype` from `sendbuf` across all ranks of `comm` with `op`
// and stores the result in `recvbuf` on every rank, using MPI_Allreduce.
//
// If `sendbuf` and `recvbuf` are the same pointer (or `sendbuf` already is MPI_IN_PLACE), the
// reduction is done in place by passing MPI_IN_PLACE as the send buffer; MSMPI requires this.
//
// MPI_Allreduce takes its count as an `int`, so the buffers are processed as consecutive chunks of
// at most std::numeric_limits<int>::max() elements each. If `count` is 0 no MPI_Allreduce call is
// issued at all.
virtual void allReduce(const void* sendbuf, void* recvbuf, size_t count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) const override {
  // Decide once whether to reduce in place. MPI_IN_PLACE is a sentinel value, not a real address,
  // so it must be handed to MPI unmodified for every chunk and never be offset.
  const bool inPlace = (sendbuf == recvbuf) || (sendbuf == MPI_IN_PLACE);

  // size of one element in bytes; needed to turn the element offset into a byte offset
  int datatypeSize;
  HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));

  // largest element count a single MPI call can take
  const size_t limit = (size_t)std::numeric_limits<int>::max();

  size_t remaining = count, offset = 0; // both measured in elements, not bytes
  while(remaining > 0) {
    int intCount = (int)std::min(remaining, limit); // cannot overflow: bounded by `limit`
    size_t byteOffset = offset * (size_t)datatypeSize;

    void* recvChunk = (char*)recvbuf + byteOffset;
    // const is cast away as before, so this also builds against MPI headers that declare the
    // send buffer as non-const `void*`
    void* sendChunk = inPlace ? (void*)MPI_IN_PLACE : (void*)((char*)sendbuf + byteOffset);

    HANDLE_MPI_ERROR(MPI_Allreduce(sendChunk, recvChunk, intCount, datatype, op, comm));
    offset    += (size_t)intCount;
    remaining -= (size_t)intCount;
  }
}
// Calls MPI_Finalize. The MPI return code is handed to HANDLE_MPI_ERROR, which is defined elsewhere.
virtual void finalize() override {
HANDLE_MPI_ERROR(MPI_Finalize());
}