mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-05 09:40:15 +03:00
Merged PR 11998: Lazy init for cuda handles (cusparse and cublas)
This does a lazy init of the two CUDA handles that we are using on the GPU. When initialized eagerly, cusparse will consume about 250MB of CPU RAM and about 75MB of GPU RAM, so it should only be initialized when actually needed.
This commit is contained in:
parent
3c7a88f4e9
commit
69d6f02711
@ -29,23 +29,40 @@ private:
|
||||
public:
|
||||
Backend(DeviceId deviceId, size_t seed) : marian::Backend(deviceId, seed) {
|
||||
setDevice();
|
||||
cublasCreate(&cublasHandle_);
|
||||
cusparseCreate(&cusparseHandle_);
|
||||
setCudaComputeCapability();
|
||||
}
|
||||
|
||||
~Backend() {
|
||||
setDevice();
|
||||
cusparseDestroy(cusparseHandle_);
|
||||
cublasDestroy(cublasHandle_);
|
||||
if(cusparseHandle_) {
|
||||
cusparseDestroy(cusparseHandle_);
|
||||
cusparseHandle_ = 0;
|
||||
}
|
||||
if(cublasHandle_) {
|
||||
cublasDestroy(cublasHandle_);
|
||||
cublasHandle_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void setDevice() override { CUDA_CHECK(cudaSetDevice((int)deviceId_.no)); }
|
||||
|
||||
void synchronize() override { CUDA_CHECK(cudaStreamSynchronize(0)); }
|
||||
|
||||
cublasHandle_t getCublasHandle() { return cublasHandle_; }
|
||||
cusparseHandle_t getCusparseHandle() { return cusparseHandle_; }
|
||||
cublasHandle_t getCublasHandle() {
|
||||
if(!cublasHandle_) { // lazy initialization here to avoid memory usage when unused
|
||||
setDevice();
|
||||
cublasCreate(&cublasHandle_);
|
||||
}
|
||||
return cublasHandle_;
|
||||
}
|
||||
|
||||
cusparseHandle_t getCusparseHandle() {
|
||||
if(!cusparseHandle_) { // lazy initialization here to avoid memory usage when unused
|
||||
setDevice();
|
||||
cusparseCreate(&cusparseHandle_);
|
||||
}
|
||||
return cusparseHandle_;
|
||||
}
|
||||
|
||||
CudaCompute getCudaComputeCapability() { return compute_; }
|
||||
|
||||
@ -60,8 +77,8 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
cublasHandle_t cublasHandle_;
|
||||
cusparseHandle_t cusparseHandle_;
|
||||
cublasHandle_t cublasHandle_{0}; // make sure it's 0, so it can be initalized lazily
|
||||
cusparseHandle_t cusparseHandle_{0}; // as above
|
||||
CudaCompute compute_;
|
||||
};
|
||||
} // namespace gpu
|
||||
|
Loading…
Reference in New Issue
Block a user