diff --git a/src/tensors/gpu/backend.h b/src/tensors/gpu/backend.h index 044c344f..9bb97174 100644 --- a/src/tensors/gpu/backend.h +++ b/src/tensors/gpu/backend.h @@ -29,23 +29,40 @@ private: public: Backend(DeviceId deviceId, size_t seed) : marian::Backend(deviceId, seed) { setDevice(); - cublasCreate(&cublasHandle_); - cusparseCreate(&cusparseHandle_); setCudaComputeCapability(); } ~Backend() { setDevice(); - cusparseDestroy(cusparseHandle_); - cublasDestroy(cublasHandle_); + if(cusparseHandle_) { + cusparseDestroy(cusparseHandle_); + cusparseHandle_ = 0; + } + if(cublasHandle_) { + cublasDestroy(cublasHandle_); + cublasHandle_ = 0; + } } void setDevice() override { CUDA_CHECK(cudaSetDevice((int)deviceId_.no)); } void synchronize() override { CUDA_CHECK(cudaStreamSynchronize(0)); } - cublasHandle_t getCublasHandle() { return cublasHandle_; } - cusparseHandle_t getCusparseHandle() { return cusparseHandle_; } + cublasHandle_t getCublasHandle() { + if(!cublasHandle_) { // lazy initialization here to avoid memory usage when unused + setDevice(); + cublasCreate(&cublasHandle_); + } + return cublasHandle_; + } + + cusparseHandle_t getCusparseHandle() { + if(!cusparseHandle_) { // lazy initialization here to avoid memory usage when unused + setDevice(); + cusparseCreate(&cusparseHandle_); + } + return cusparseHandle_; + } CudaCompute getCudaComputeCapability() { return compute_; } @@ -60,8 +77,8 @@ public: } private: - cublasHandle_t cublasHandle_; - cusparseHandle_t cusparseHandle_; + cublasHandle_t cublasHandle_{0}; // make sure it's 0, so it can be initialized lazily + cusparseHandle_t cusparseHandle_{0}; // as above CudaCompute compute_; }; } // namespace gpu