Merge branch 'master' into hihoan/lsh7

This commit is contained in:
Marcin Junczys-Dowmunt 2021-06-29 10:42:13 -07:00
commit 64e787afce
14 changed files with 330 additions and 65 deletions

View File

@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Compute aligned memory sizes using exact sizing
### Fixed
- Added support to MPIWrappest::bcast (and similar) for count of type size_t
- Adding new validation metrics when training is restarted and --reset-valid-stalled is used
- Missing depth-scaling in transformer FFN
- Fixed an issue when loading intgemm16 models from unaligned memory.

View File

@ -1,2 +1,2 @@
v1.10.20
v1.10.21

@ -1 +1 @@
Subproject commit 6f24a6b52a521a3467e99a9c175ba9e136905217
Subproject commit 5bafa8e8c3391bbe9721a16e986408341f95774c

View File

@ -696,6 +696,15 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
"Use approximate knn search in output layer (currently only in transformer)")
->implicit_val("100 1024");
// parameters for on-line quantization
cli.add<bool>("--optimize",
"Optimize the graph on-the-fly", false);
cli.add<std::string>("--gemm-type,-g",
"GEMM Type to be used for on-line quantization/packing: float32, packed16, packed8", "float32");
cli.add<float>("--quantize-range",
"Range for the on-line quantiziation of weight matrix in multiple of this range and standard deviation, 0.0 means min/max quantization",
0.f);
#if 0 // @TODO: Ask Hany if there are any decoding-time options
// add ULR settings
addSuboptionsULR(cli);
@ -747,6 +756,15 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
"Mixed precision for inference, set parameter type in expression graph",
{"float32"});
// parameters for on-line quantization
cli.add<bool>("--optimize",
"Optimize the graph on-the-fly", false);
cli.add<std::string>("--gemm-type,-g",
"GEMM Type to be used for on-line quantization/packing: float32, packed16, packed8", "float32");
cli.add<float>("--quantize-range",
"Range for the on-line quantiziation of weight matrix in multiple of this range and standard deviation, 0.0 means min/max quantization",
0.f);
cli.switchGroup(previous_group);
// clang-format on
}

View File

@ -483,7 +483,45 @@ Expr dot(Expr a, Expr b, bool transA, bool transB, float scale) {
// --optimize --cpu-thread=N with N > 0 are set.
if(device == DeviceType::cpu) {
if(isFloat(aElementType) && isFloat(bElementType)) {
return Expression<DotNodeOp>(a, b, transA, transB, scale);
if(b->memoize() && (a->graph()->getBackend()->getGemmType() == GemmType::FbFp16Packed ||
a->graph()->getBackend()->getGemmType() == GemmType::FbInt8Packed)) {
#if USE_FBGEMM
if(a->graph()->getBackend()->getGemmType() == GemmType::FbFp16Packed) {
auto packedB = cpu::variant::pack(
marian::Type::packed16, b, cpu::variant::PackMatrix::B, transB);
return cpu::variant::dot(marian::Type::packed16,
a, packedB, b->shape(), transA, transB, scale);
} else {
float quantizeRange = b->graph()->getBackend()->getQuantizeRange();
if(fbgemm::fbgemmHasAvx512Support()) {
auto packedB = cpu::variant::pack(marian::Type::packed8avx512,
b,
cpu::variant::PackMatrix::B,
transB,
quantizeRange);
return cpu::variant::dot(marian::Type::packed8avx512,
a, packedB, b->shape(), transA, transB, scale);
} else if(fbgemm::fbgemmHasAvx2Support()) {
auto packedB = cpu::variant::pack(marian::Type::packed8avx2,
b,
cpu::variant::PackMatrix::B,
transB,
quantizeRange);
return cpu::variant::dot(marian::Type::packed8avx2,
a, packedB, b->shape(), transA, transB, scale);
} else {
ABORT(
"AVX2 is not available. At least, AVX2 is needed to use fbgemm-based packed "
"GEMM");
}
}
#else
ABORT("Packed GEMM is not available in this build");
#endif // USE_FBGEMM
} else {
return Expression<DotNodeOp>(
a, b, transA, transB, scale);
}
} else if(isFloat(aElementType) && isIntgemm(bElementType)) {
return cpu::integer::affineOrDot(a, b, nullptr, transA, transB, scale);
} else if(isFloat(aElementType) && isPacked(bElementType)) {
@ -495,7 +533,8 @@ Expr dot(Expr a, Expr b, bool transA, bool transB, float scale) {
// and this cpu lookup is executed only once and the state is kept in FBGEMM.
if(fbgemm::fbgemmHasAvx2Support()) {
// This variant of dot product can handle matrix multiplications with packed8 and packed16 weight matrix (B).
return cpu::variant::dot(a,
return cpu::variant::dot(b->value_type(),
a,
b,
b->shape(),
transA,
@ -545,7 +584,48 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
if(device == DeviceType::cpu) {
if(isFloat(aElementType) && isFloat(bElementType)) {
return affineDefault(a, b, bias, transA, transB, scale);
if(a->graph()->getBackend()->isOptimized()) {
if(b->memoize() && (a->graph()->getBackend()->getGemmType() == GemmType::FbFp16Packed ||
a->graph()->getBackend()->getGemmType() == GemmType::FbInt8Packed)) {
#if USE_FBGEMM
if(a->graph()->getBackend()->getGemmType() == GemmType::FbFp16Packed) {
auto packedB = cpu::variant::pack(
marian::Type::packed16, b, cpu::variant::PackMatrix::B, transB);
return cpu::variant::affine(marian::Type::packed16,
a, packedB, b->shape(), bias, transA, transB, scale);
} else {
float quantizeRange = b->graph()->getBackend()->getQuantizeRange();
if(fbgemm::fbgemmHasAvx512Support()) {
auto packedB = cpu::variant::pack(marian::Type::packed8avx512,
b,
cpu::variant::PackMatrix::B,
transB,
quantizeRange);
return cpu::variant::affine(marian::Type::packed8avx512,
a, packedB, b->shape(), bias, transA, transB, scale);
} else if(fbgemm::fbgemmHasAvx2Support()) {
auto packedB = cpu::variant::pack(marian::Type::packed8avx2,
b,
cpu::variant::PackMatrix::B,
transB,
quantizeRange);
return cpu::variant::affine(marian::Type::packed8avx2,
a, packedB, b->shape(), bias, transA, transB, scale);
} else {
ABORT(
"AVX2 is not available. At least, AVX2 is needed to use fbgemm-based packed "
"GEMM");
}
}
#else
ABORT("Packed GEMM is not available in this build");
#endif // USE_FBGEMM
} else {
return affineDefault(a, b, bias, transA, transB, scale);
}
} else {
return affineDefault(a, b, bias, transA, transB, scale);
}
} else if(isFloat(aElementType) && isIntgemm(bElementType)) {
return cpu::integer::affineOrDot(a, b, bias, transA, transB, scale);
} else if(isFloat(aElementType) && isPacked(bElementType)) {
@ -557,7 +637,8 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
// and this cpu lookup is executed only once and the state is kept in FBGEMM.
if(fbgemm::fbgemmHasAvx2Support()) {
// This variant of affine product can handle matrix multiplications with packed8 and packed16 weight matrix (B).
return cpu::variant::affine(a,
return cpu::variant::affine(b->value_type(),
a,
b,
b->shape(),
bias,

View File

@ -177,6 +177,8 @@ static inline std::function<Expr(Expr)> activationByName(const std::string& actN
return (ActivationFunction*)swish;
else if (actName == "gelu")
return (ActivationFunction*)gelu;
else if (actName == "sigmoid")
return (ActivationFunction*)sigmoid;
else if (actName == "") // return identity function if activation name is empty
return [](Expr x) { return x; };
ABORT("Invalid activation name '{}'", actName);

View File

@ -5,6 +5,14 @@
namespace marian {
// GEMM type enum
typedef enum {
Auto = 0, // auto tuning between available GEMMs
Float32 = 1, // MKL based GEMM, fp32
FbFp16Packed = 10, // FBGEMM based fp16 GEMM with packing
FbInt8Packed = 11 // FBGEMM based int8 GEMM with packing
} GemmType;
class Backend {
protected:
DeviceId deviceId_;
@ -21,6 +29,19 @@ public:
// for GPU only, calls cudaSetDevice, does nothing on CPU. Maybe change name.
virtual void setDevice() = 0;
virtual void synchronize() = 0;
// for CPU, sets to use optimized code for inference.
// for GPU, this is invalid. for gpu, isOptimized() function always returns false.
virtual void setOptimized(bool optimize) = 0;
virtual bool isOptimized() = 0;
// for CPU, selects different GEMM types for the inference.
// for GPU, there's no gemm type. so, it does nothing.
virtual void setGemmType(std::string gemmType) = 0;
virtual GemmType getGemmType() = 0;
// for CPU, sets quantization range of weight matrices for the inference.
// for GPU, there's no quantization. so, it does nothing.
virtual void setQuantizeRange(float range) = 0;
virtual float getQuantizeRange() = 0;
};
Ptr<Backend> BackendByDeviceId(DeviceId deviceId, size_t seed);

View File

@ -10,10 +10,34 @@ namespace marian {
namespace cpu {
class Backend : public marian::Backend {
protected:
bool optimized_{false};
GemmType gemmType_{GemmType::Float32};
float quantizeRange_{0.f};
public:
Backend(DeviceId deviceId, size_t seed) : marian::Backend(deviceId, seed) {}
void setDevice() override {}
void synchronize() override {}
// for CPU & inference only, sets to use optimized code for inference. Does nothing for GPU.
void setOptimized(bool optimize) override { optimized_ = optimize; }
bool isOptimized() override { return optimized_; }
// for CPU only, selects different GEMM types for the inference. Does nothing for GPU.
void setGemmType(std::string gemmType) override {
if (gemmType == "auto") gemmType_ = GemmType::Auto;
else if (gemmType == "float32") gemmType_ = GemmType::Float32;
#if USE_FBGEMM
else if (gemmType == "packed16") gemmType_ = GemmType::FbFp16Packed;
else if (gemmType.find("packed8") == 0) gemmType_ = GemmType::FbInt8Packed;
#endif // USE_FBGEMM
else ABORT("Unknown GEMM type - '{}'", gemmType);
}
GemmType getGemmType() override { return gemmType_; }
// for CPU, sets quantization range of weight matrices for the inference.
// for GPU, there's no quantization. so, it does nothing.
void setQuantizeRange(float range) override { quantizeRange_ = range; }
float getQuantizeRange() override { return quantizeRange_; }
};
} // namespace cpu

View File

@ -138,15 +138,18 @@ struct FbgemmPacked8PackNodeOp : public UnaryNodeOp {
int nrow_;
int ncol_;
uint64_t packsize_;
float quantizeRange_;
FbgemmPacked8PackNodeOp(Expr a,
PackMatrix packMat,
marian::Type packType,
bool transpose)
: UnaryNodeOp(a, newShape(a, transpose), Type::uint8),
bool transpose,
float quantizeRange)
: UnaryNodeOp(a, newShape(a, packType, transpose), Type::uint8),
packMat_(packMat),
packType_(packType),
transpose_(transpose) {
transpose_(transpose),
quantizeRange_(quantizeRange){
if(packMat != PackMatrix::B)
ABORT("Only prepacking of B (weight matrix) is supported");
if(!memoize_)
@ -161,7 +164,8 @@ struct FbgemmPacked8PackNodeOp : public UnaryNodeOp {
transpose_,
nrow_,
ncol_,
packsize_))
packsize_,
quantizeRange_))
};
#else // USE_FBGEMM
ABORT("FbgemmPacked8PackNodeOp can only be used with FBGEMM enabled.");
@ -177,13 +181,19 @@ struct FbgemmPacked8PackNodeOp : public UnaryNodeOp {
const std::string type() override { return "packMatInt8"; }
#if USE_FBGEMM
Shape newShape(Expr a, bool transpose) {
fbgemmPacked8PackInfo(a->shape(), packType_, transpose, nrow_, ncol_, packsize_);
Shape newShape(Expr a, marian::Type packType, bool transpose) {
fbgemmPacked8PackInfo(
a->shape(),
packType,
transpose,
nrow_,
ncol_,
packsize_);
Shape outShape({(int)packsize_});
return outShape;
}
#else
Shape newShape(Expr /*a*/, bool /*transpose*/) {
Shape newShape(Expr /*a*/, marian::Type /*packType*/, bool /*transpose*/) {
ABORT("Packed GEMM requires a build with USE_FBGEMM enabled");
return Shape();
}
@ -282,10 +292,17 @@ private:
size_t k_;
bool transA_;
bool transB_;
Type elementType_;
public:
FbgemmPacked8AffineNodeOp(const std::vector<Expr>& nodes, Shape bShape, bool transA, bool transB, float /*scalar*/)
: NaryNodeOp(nodes, newShape(nodes[0], bShape, transA, transB), Type::float32)/*, scalar_(scalar) */ {
FbgemmPacked8AffineNodeOp(Type elementType,
const std::vector<Expr>& nodes,
Shape bShape,
bool transA,
bool transB,
float /*scalar*/)
: NaryNodeOp(nodes, newShape(nodes[0], bShape, transA, transB), Type::float32),
elementType_(elementType) {
transA_ = transA;
transB_ = transB;
m_ = nodes[0]->shape().elements() / nodes[0]->shape()[-1];
@ -324,7 +341,8 @@ public:
#if USE_FBGEMM
// Do addBias only if it has a bias term
if (children().size() > 2) {
nodeOps = { NodeOp(fbgemmPacked8Gemm(val_,
nodeOps = { NodeOp(fbgemmPacked8Gemm(elementType_,
val_,
child(0)->val(),
child(1)->val(),
m_,
@ -334,7 +352,8 @@ public:
transB_);
marian::cpu::integer::AddBias(val_, child(2)->val())) };
} else {
nodeOps = { NodeOp(fbgemmPacked8Gemm(val_,
nodeOps = { NodeOp(fbgemmPacked8Gemm(elementType_,
val_,
child(0)->val(),
child(1)->val(),
m_,
@ -358,39 +377,46 @@ public:
const std::string type() override { return "gemmPacked8"; }
};
static inline Expr affine(Expr a, Expr b, Shape bShape, Expr c, bool transA, bool transB, float scalar) {
static inline Expr affine(Type elementType,
Expr a,
Expr b,
Shape bShape,
Expr c,
bool transA,
bool transB,
float scalar) {
std::vector<Expr> nodes = {a, b, c};
Type elementType = b->value_type();
if (elementType == Type::packed16)
return Expression<FbgemmPacked16AffineNodeOp>(nodes, bShape, transA, transB, scalar);
else if (isPacked(elementType) && sizeOf(elementType) == 1)
return Expression<FbgemmPacked8AffineNodeOp>(nodes, bShape, transA, transB, scalar);
return Expression<cpu::variant::FbgemmPacked8AffineNodeOp>(
elementType, nodes, bShape, transA, transB, scalar);
else {
ABORT("Only int8 and fp16 are available. {}", elementType);
return nullptr;
}
}
static inline Expr pack(Type elementType, Expr a, PackMatrix packMat, bool transpose) {
static inline Expr pack(Type elementType, Expr a, PackMatrix packMat, bool transpose, float quantizeRange = 0.f) {
if (elementType == Type::packed16)
return Expression<FbgemmPacked16PackNodeOp>(a, packMat, transpose);
else if (isPacked(elementType) && sizeOf(elementType) == 1)
return Expression<FbgemmPacked8PackNodeOp>(a, packMat, elementType, transpose);
return Expression<cpu::variant::FbgemmPacked8PackNodeOp>(a, packMat, elementType, transpose, quantizeRange);
else {
ABORT("Only int8 and fp16 are available. {}", elementType);
return nullptr;
}
}
static inline Expr dot(Expr a, Expr b, Shape bShape, bool transA, bool transB, float scalar) {
static inline Expr dot(Type elementType, Expr a, Expr b, Shape bShape, bool transA, bool transB, float scalar) {
std::vector<Expr> nodes = {a, b};
Type elementType = b->value_type();
if (elementType == Type::packed16)
return Expression<FbgemmPacked16AffineNodeOp>(nodes, bShape, transA, transB, scalar);
else if (isPacked(elementType) && sizeOf(elementType) == 1)
return Expression<FbgemmPacked8AffineNodeOp>(nodes, bShape, transA, transB, scalar);
return Expression<cpu::variant::FbgemmPacked8AffineNodeOp>(
elementType, nodes, bShape, transA, transB, scalar);
else {
ABORT("Only int8 and fp16 are available. {}", elementType);
return nullptr;

View File

@ -360,10 +360,10 @@ void fbgemmPacked8Pack(marian::Tensor out,
const float* data = inData;
float val = 0;
// Use half of the quantization range to prevent overflow of VPMADDUBSW
constexpr static int quantizedRange = 127;
constexpr static int quantizedMax = 63;
// Use half of the quantization range to prevent overflow of VPMADDUBSW
constexpr static int quantizedRange = 127;
constexpr static int quantizedMax = 63;
// This routine compute the quantization range for each column - either one of min/max range or quantRangeStdDevs sigma range.
for (size_t jj = 0; jj < n; jj++) { // for each column, collect stats (min/max or mean/std.dev.)
@ -371,32 +371,32 @@ void fbgemmPacked8Pack(marian::Tensor out,
double mean = 0, sqrSum = 0;
for (size_t ii = 0; ii < k; ii++) { // in a column, go throuhg all the rows and collect stats
val = getVal2dArr(data, ii, jj, k, n, transpose);
// If quantRangeStdDevs is 0.f, min/max values of the columns is used as a quantization range
if(quantRangeStdDevs == 0.f) {
if(min > val)
min = val;
if(max < val)
max = val;
} else {
// Quantize by std.dev. range
mean += val;
sqrSum += val * val;
}
}
// If a quantization range (in multiples of std. dev.) is given with a non-zero value,
// it calculate the range for this column (different quantization scale/offset are used for each column)
if(quantRangeStdDevs != 0.f) {
mean /= k;
sqrSum /= k;
sqrSum -= mean * mean;
sqrSum = sqrt(sqrSum);
min = (float)(mean - quantRangeStdDevs * sqrSum);
max = (float)(mean + quantRangeStdDevs * sqrSum);
}
// based on the quantization range, this computes the scale and offset for the quantization
quantScaleB[jj] = (max - min) / quantizedRange;
quantZeropointB[jj] = (int32_t)(quantizedMax - max / quantScaleB[jj]);
}
// If quantRangeStdDevs is 0.f, min/max values of the columns is used as a quantization range
if(quantRangeStdDevs == 0.f) {
if(min > val)
min = val;
if(max < val)
max = val;
} else {
// Quantize by std.dev. range
mean += val;
sqrSum += val * val;
}
}
// If a quantization range (in multiples of std. dev.) is given with a non-zero value,
// it calculate the range for this column (different quantization scale/offset are used for each column)
if(quantRangeStdDevs != 0.f) {
mean /= k;
sqrSum /= k;
sqrSum -= mean * mean;
sqrSum = sqrt(sqrSum);
min = (float)(mean - quantRangeStdDevs * sqrSum);
max = (float)(mean + quantRangeStdDevs * sqrSum);
}
// based on the quantization range, this computes the scale and offset for the quantization
quantScaleB[jj] = (max - min) / quantizedRange;
quantZeropointB[jj] = (int32_t)(quantizedMax - max / quantScaleB[jj]);
}
// 2. quantize
int8_t* quantized = 0;
@ -410,7 +410,7 @@ void fbgemmPacked8Pack(marian::Tensor out,
TensorQuantizationParams bQuantParam;
bQuantParam.scale = quantScaleB[jj];
bQuantParam.zero_point = quantZeropointB[jj];
bQuantParam.precision = 7; // Use half of the quantization range to prevent overflow of VPMADDUBSW
bQuantParam.precision = 7; // Use half of the quantization range to prevent overflow of VPMADDUBSW
if (transpose)
fbgemm::Quantize<int8_t>(data + jj * k, quantized + jj * k, k, bQuantParam);
@ -536,7 +536,8 @@ void fbgemmPacked16Gemm(marian::Tensor C,
// k: the number of columns in A and the number of rows in B
// transA: whether A matrix is transposed or not
// transB: whether B matrix is transposed or not
void fbgemmPacked8Gemm(marian::Tensor C,
void fbgemmPacked8Gemm(Type packType,
marian::Tensor C,
const marian::Tensor A,
const marian::Tensor B,
const size_t m,
@ -544,9 +545,6 @@ void fbgemmPacked8Gemm(marian::Tensor C,
const size_t k,
const int transA,
const int transB) {
// pack type
marian::Type packType = B->type();
const fbgemm::BlockingFactors* params = getBlockingFactors(packType);
// Check if the packed format matches with the available AVX instruction set in the machine

View File

@ -135,7 +135,8 @@ void fbgemmPacked16Gemm(marian::Tensor C,
// k: the number of columns in A and rows in B
// transA: transpose of A matrix
// transB: transpose of B matrix
void fbgemmPacked8Gemm(marian::Tensor C,
void fbgemmPacked8Gemm(Type packType,
marian::Tensor C,
const marian::Tensor A,
const marian::Tensor B,
const size_t m,

View File

@ -64,6 +64,36 @@ public:
return cusparseHandle_;
}
// for CPU, sets to use optimized code for inference.
// for GPU, this is invalid. for gpu, isOptimized() function always returns false.
void setOptimized(bool optimize) override {
LOG_ONCE(info, "setOptimized() not supported for GPU_{}", optimize);
}
bool isOptimized() override {
LOG_ONCE(info, "isOptimized() not supported for GPU");
return false;
};
// for CPU, selects different GEMM types for the inference.
// for GPU, there's no gemm type. so, it does nothing.
void setGemmType(std::string gemmType) override {
LOG_ONCE(info, "setGemmType() not supported for GPU_{}", gemmType);
}
GemmType getGemmType() override {
LOG_ONCE(info, "getGemmType() not supported for GPU");
return GemmType::Float32;
}
// for CPU, sets quantization range of weight matrices for the inference.
// for GPU, there's no quantization. so, it does nothing.
void setQuantizeRange(float range) override {
LOG_ONCE(info, "setQuantizeRange() not supported for GPU_{}", range);
}
float getQuantizeRange() override {
LOG_ONCE(info, "getQuantizeRange() not supported for GPU");
return 0.f;
}
CudaCompute getCudaComputeCapability() { return compute_; }
private:

View File

@ -123,20 +123,73 @@ public:
virtual void barrier(MPI_Comm comm = MPI_COMM_WORLD) const override {
HANDLE_MPI_ERROR(MPI_Barrier(comm));
}
virtual void bCast(void* buf, size_t count, MPI_Datatype datatype, size_t rootRank, MPI_Comm comm = MPI_COMM_WORLD) const override {
HANDLE_MPI_ERROR(MPI_Bcast(buf, (int)count, datatype, (int)rootRank, comm));
// MPI_Bcast only supports MAX_INT count, here and in the functions below, we need to cycle through the counts until we have sent
// all elemements if count is larger MAX_INT.
// get the data type size in bytes
int datatypeSize;
HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));
// get the limit for int count
size_t limit = (size_t)std::numeric_limits<int>::max();
size_t remaining = count, offset = 0;
// while there are elements that we have not sent yet, loop until all has been sent in chunks of at most `limit`.
while(remaining > 0) {
int intCount = (int)std::min(remaining, limit);
HANDLE_MPI_ERROR(MPI_Bcast((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)rootRank, comm));
offset += (size_t)intCount;
remaining -= (size_t)intCount;
}
}
virtual void sSend(void* buf, size_t count, MPI_Datatype datatype, size_t destRank, int tag, MPI_Comm comm) const override {
HANDLE_MPI_ERROR(MPI_Ssend(buf, (int)count, datatype, (int)destRank, tag, comm));
int datatypeSize;
HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));
size_t limit = (size_t)std::numeric_limits<int>::max();
size_t remaining = count, offset = 0;
while(remaining > 0) {
int intCount = (int)std::min(remaining, limit);
HANDLE_MPI_ERROR(MPI_Ssend((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)destRank, tag, comm));
offset += (size_t)intCount;
remaining -= (size_t)intCount;
}
}
virtual void recv(void* buf, size_t count, MPI_Datatype datatype, size_t sourceRank, int tag, MPI_Comm comm, MPI_Status* status) const override {
HANDLE_MPI_ERROR(MPI_Recv(buf, (int)count, datatype, (int)sourceRank, tag, comm, status));
int datatypeSize;
HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));
size_t limit = (size_t)std::numeric_limits<int>::max();
size_t remaining = count, offset = 0;
while(remaining > 0) {
int intCount = (int)std::min(remaining, limit);
HANDLE_MPI_ERROR(MPI_Recv((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)sourceRank, tag, comm, status));
offset += (size_t)intCount;
remaining -= (size_t)intCount;
}
}
virtual void allReduce(const void* sendbuf, void* recvbuf, size_t count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) const override {
if (sendbuf == recvbuf)
sendbuf = MPI_IN_PLACE; // MSMPI requires this
HANDLE_MPI_ERROR(MPI_Allreduce(sendbuf, recvbuf, (int)count, datatype, op, comm));
int datatypeSize;
HANDLE_MPI_ERROR(MPI_Type_size(datatype, &datatypeSize));
size_t limit = (size_t)std::numeric_limits<int>::max();
size_t remaining = count, offset = 0;
while(remaining > 0) {
int intCount = (int)std::min(remaining, limit);
HANDLE_MPI_ERROR(MPI_Allreduce((char*)sendbuf + offset * (size_t)datatypeSize, (char*)recvbuf + offset * (size_t)datatypeSize, intCount, datatype, op, comm));
offset += (size_t)intCount;
remaining -= (size_t)intCount;
}
}
virtual void finalize() override {
HANDLE_MPI_ERROR(MPI_Finalize());
}

View File

@ -91,6 +91,11 @@ public:
auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
graph->setDefaultElementType(typeFromString(prec[0]));
graph->setDevice(device);
if (device.type == DeviceType::cpu) {
graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
graph->getBackend()->setQuantizeRange(options_->get<float>("quantize-range"));
}
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_[id] = graph;
@ -284,6 +289,11 @@ public:
auto precison = options_->get<std::vector<std::string>>("precision", {"float32"});
graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph
graph->setDevice(device);
if (device.type == DeviceType::cpu) {
graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
graph->getBackend()->setQuantizeRange(options_->get<float>("quantize-range"));
}
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);