Merged PR 13049: Remove repeated memory allocation for fbgemm temp scratch space

Repeated memory allocation/deallocation in a multi-threaded environment caused heap contention problem which all the threads are waiting each other to synchronously allocate/deallocate memory. To resolve this, several unnecessary memory allocations are removed.

1. Remove repeated memory allocation and deallocation for the fbgemm's scratch space.
2. Remove unnecessary memory allocation in fbgemm submodule (https://github.com/marian-nmt/FBGEMM/pull/3/files)
3. Make "USE_ONNX" false as a default option in vsproj windows build
4. Improve the variable naming in int8 fbgemm code for a better readability
This commit is contained in:
Young Jin Kim 2020-05-26 20:22:38 +00:00
parent c8a62dd2c8
commit ae7cae2760
3 changed files with 85 additions and 79 deletions

@ -1 +1 @@
Subproject commit 84e66a976046180187724aff60a236c5378fde7c
Subproject commit 5454259a1890888daba43a55c90e455fc120466e

View File

@ -251,7 +251,7 @@ void fbgemmPacked8PackInfo(const marian::Shape& shape,
// This function computes the offset values for each column which are used for compensating the remainders of quantized values
// More detailed math is avilable in the FBGEMM's blog - https://engineering.fb.com/ml-applications/fbgemm/
inline void col_offsets_with_zero_pt_s8acc32(
inline void colOffsetsWithZeroPtS8acc32(
bool transpose,
int K,
int N,
@ -355,8 +355,8 @@ void fbgemmPacked8Pack(marian::Tensor out,
int len = k * n;
// 1. collect stats for each column
float* bqScale = new float[n];
int32_t* bqZeropoint = new int32_t[n];
float* quantScaleB = new float[n];
int32_t* quantZeropointB = new int32_t[n];
const float* data = inData;
float val = 0;
@ -367,8 +367,8 @@ void fbgemmPacked8Pack(marian::Tensor out,
// This routine compute the quantization range for each column - either one of min/max range or quantRangeStdDevs sigma range.
for (size_t jj = 0; jj < n; jj++) { // for each column, collect stats (min/max or mean/std.dev.)
float min = std::numeric_limits<float>::max(), max = std::numeric_limits<float>::min();
double mean = 0, sqrsum = 0;
float min = std::numeric_limits<float>::max(), max = std::numeric_limits<float>::lowest();
double mean = 0, sqrSum = 0;
for (size_t ii = 0; ii < k; ii++) { // in a column, go throuhg all the rows and collect stats
val = getVal2dArr(data, ii, jj, k, n, transpose);
// If quantRangeStdDevs is 0.f, min/max values of the columns is used as a quantization range
@ -380,22 +380,22 @@ void fbgemmPacked8Pack(marian::Tensor out,
} else {
// Quantize by std.dev. range
mean += val;
sqrsum += val * val;
sqrSum += val * val;
}
}
// If a quantization range (in multiples of std. dev.) is given with a non-zero value,
// it calculate the range for this column (different quantization scale/offset are used for each column)
if(quantRangeStdDevs != 0.f) {
mean /= k;
sqrsum /= k;
sqrsum -= mean * mean;
sqrsum = sqrt(sqrsum);
min = (float)(mean - quantRangeStdDevs * sqrsum);
max = (float)(mean + quantRangeStdDevs * sqrsum);
sqrSum /= k;
sqrSum -= mean * mean;
sqrSum = sqrt(sqrSum);
min = (float)(mean - quantRangeStdDevs * sqrSum);
max = (float)(mean + quantRangeStdDevs * sqrSum);
}
// based on the quantization range, this computes the scale and offset for the quantization
bqScale[jj] = (max - min) / quantizedRange;
bqZeropoint[jj] = (int32_t)(quantizedMax - max / bqScale[jj]);
quantScaleB[jj] = (max - min) / quantizedRange;
quantZeropointB[jj] = (int32_t)(quantizedMax - max / quantScaleB[jj]);
}
// 2. quantize
@ -408,8 +408,8 @@ void fbgemmPacked8Pack(marian::Tensor out,
#endif
for (int jj = 0; jj < n; jj++) {
TensorQuantizationParams bQuantParam;
bQuantParam.scale = bqScale[jj];
bQuantParam.zero_point = bqZeropoint[jj];
bQuantParam.scale = quantScaleB[jj];
bQuantParam.zero_point = quantZeropointB[jj];
bQuantParam.precision = 7; // Use half of the quantization range to prevent overflow of VPMADDUBSW
if (transpose)
@ -422,13 +422,13 @@ void fbgemmPacked8Pack(marian::Tensor out,
}
// 3. compute column offsets
int32_t* col_offsets = new int32_t[n];
col_offsets_with_zero_pt_s8acc32(transpose, k, n, quantized, bqZeropoint, col_offsets, 1);
int32_t* colOffsets = new int32_t[n];
colOffsetsWithZeroPtS8acc32(transpose, k, n, quantized, quantZeropointB, colOffsets, 1);
int8_t* packedbuf = out->data<int8_t>();
int8_t* packedBuf = out->data<int8_t>();
for(auto i = 0; i < packsize; i++) {
packedbuf[i] = 0;
packedBuf[i] = 0;
}
// 4. packing
@ -436,23 +436,23 @@ void fbgemmPacked8Pack(marian::Tensor out,
PackBMatrix<int8_t> packedBN(
transpose ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
nrow, ncol, quantized, transpose ? nrow : ncol, packedbuf, 1, params);
nrow, ncol, quantized, transpose ? nrow : ncol, packedBuf, 1, params);
// copy quantization scale
memcpy(packedbuf + (packsize - n * (sizeof(float) + sizeof(int32_t) + sizeof(int32_t))), bqScale, n * sizeof(float));
memcpy(packedBuf + (packsize - n * (sizeof(float) + sizeof(int32_t) + sizeof(int32_t))), quantScaleB, n * sizeof(float));
// copy quantization offset
memcpy(packedbuf + (packsize - n * (sizeof(int32_t) + sizeof(int32_t))), bqZeropoint, n * sizeof(int32_t));
memcpy(packedBuf + (packsize - n * (sizeof(int32_t) + sizeof(int32_t))), quantZeropointB, n * sizeof(int32_t));
// copy column offsets to the memory
memcpy(packedbuf + (packsize - n * sizeof(int32_t)), col_offsets, n * sizeof(int32_t));
memcpy(packedBuf + (packsize - n * sizeof(int32_t)), colOffsets, n * sizeof(int32_t));
#ifdef _MSC_VER
_aligned_free(quantized);
#else
free(quantized);
#endif
delete[] col_offsets;
delete[] bqScale;
delete[] bqZeropoint;
delete[] colOffsets;
delete[] quantScaleB;
delete[] quantZeropointB;
}
// GEMM operation on the packed B matrix
@ -549,73 +549,93 @@ void fbgemmPacked8Gemm(marian::Tensor C,
const fbgemm::BlockingFactors* params = getBlockingFactors(packType);
if((packType == Type::packed8avx2 && fbgemmHasAvx512Support())
|| (packType == Type::packed8avx512 && !fbgemmHasAvx512Support())) {
// Check if the packed format matches with the available AVX instruction set in the machine
const bool avx512Support = fbgemmHasAvx512Support();
if((packType == Type::packed8avx2 && avx512Support)
|| (packType == Type::packed8avx512 && !avx512Support)) {
ABORT("FBGEMM doesn't allow to use {} packing order on {} CPUs",
packType == Type::packed8avx2 ? "AVX2" : "AVX512",
fbgemmHasAvx512Support() ? "AVX512" : "AVX2");
avx512Support ? "AVX512" : "AVX2");
}
// compute range to quantize A (activations) - (min/max quantization)
float min_est = std::numeric_limits<float>::max(), max_est = std::numeric_limits<float>::min();
float minA = std::numeric_limits<float>::max(), maxA = std::numeric_limits<float>::lowest();
int elem = A->shape().elements();
float* data = A->data();
int elemA = A->shape().elements();
float* dataA = A->data();
// AVX based find min/max
FindMinMax(data, &min_est, &max_est, elem);
FindMinMax(dataA, &minA, &maxA, elemA);
float ascale = (max_est - min_est) / 255;
int32_t azeropoint = (int32_t)(255 - max_est / ascale);
float quantScaleA = (maxA - minA) / 255;
int32_t quantZeropointA = (int32_t)(255 - maxA / quantScaleA);
std::vector<int32_t> row_offset_buf(PackAWithQuantRowOffset<uint8_t>::rowOffsetBufferSize());
PackAWithQuantRowOffset<uint8_t> packAN(
// To avoid any repeated memory allocation and deallocation, make the scratch buffer variables static thread_local
// In a multi-threaded situation, heap access lock for the memory allocation/free could
// makes all the threads are blocked by each other. (heap contention)
const size_t sizeBufA = params->KCB * params->MCB;
static thread_local std::vector<uint8_t> packedBufA;
if (packedBufA.size() < sizeBufA)
packedBufA.resize(sizeBufA);
const size_t sizeRowOffsetBufA = PackAWithQuantRowOffset<uint8_t>::rowOffsetBufferSize();
static thread_local std::vector<int32_t> rowOffsetBufA;
if (rowOffsetBufA.size() < sizeRowOffsetBufA)
rowOffsetBufA.resize(sizeRowOffsetBufA);
PackAWithQuantRowOffset<uint8_t> packA(
transA ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
(int32_t)(transA ? k : m),
(int32_t)(transA ? m : k),
A->data(),
(int32_t)(transA ? m : k),
nullptr, /*buffer for packed matrix*/
ascale,
azeropoint,
// buffer for packed matrix, pass a pre-allocated memory to avoid additional allocation/deallocation inside fbgemm
packedBufA.data(),
quantScaleA,
quantZeropointA,
1, /*groups*/
row_offset_buf.data(),
rowOffsetBufA.data(),
params);
// packed matrix size of B
int bPackSize = PackMatrix<PackBMatrix<int8_t>, int8_t>::packedBufferSize((int32_t)k, (int32_t)n);
int packSizeB = PackMatrix<PackBMatrix<int8_t>, int8_t>::packedBufferSize((int32_t)k, (int32_t)n);
// retrieve B matrix
int8_t* bdata = B->data<int8_t>();
float* bqScale = new float[n];
memcpy(bqScale, bdata + bPackSize, n * sizeof(float));
int8_t* dataB = B->data<int8_t>();
int32_t* bqZeropoint = new int32_t[n];
memcpy(bqZeropoint, bdata + bPackSize + n * sizeof(float), n * sizeof(int32_t));
// To avoid any repeated memory allocation and deallocation, make the scratch buffer variables static thread_local
// In a multi-threaded situation, heap access lock for the memory allocation/free could
// makes all the threads are blocked by each other. (heap contention)
static thread_local std::vector<float> quantScaleB;
if (quantScaleB.size() < n)
quantScaleB.resize(n);
memcpy(quantScaleB.data(), dataB + packSizeB, n * sizeof(float));
int32_t* col_offsets = new int32_t[n];
memcpy(col_offsets, bdata + bPackSize + n * (sizeof(float) + sizeof(int32_t)), n * sizeof(int32_t));
static thread_local std::vector<int32_t> quantZeropointB;
if (quantZeropointB.size() < n)
quantZeropointB.resize(n);
memcpy(quantZeropointB.data(), dataB + packSizeB + n * sizeof(float), n * sizeof(int32_t));
static thread_local std::vector<int32_t> colOffsetsB;
if (colOffsetsB.size() < n)
colOffsetsB.resize(n);
memcpy(colOffsetsB.data(), dataB + packSizeB + n * (sizeof(float) + sizeof(int32_t)), n * sizeof(int32_t));
DoNothing<float, float> doNothingObj{};
ReQuantizeForFloat<false, QuantizationGranularity::OUT_CHANNEL> outputProcObj(
doNothingObj,
ascale,
bqScale,
azeropoint,
bqZeropoint,
packAN.getRowOffsetBuffer(),
col_offsets,
quantScaleA,
quantScaleB.data(),
quantZeropointA,
quantZeropointB.data(),
packA.getRowOffsetBuffer(),
colOffsetsB.data(),
nullptr,
(std::uint32_t) n);
PackBMatrix<int8_t> repackedBN(
transB ? matrix_op_t::Transpose : matrix_op_t::NoTranspose, (int32_t) k, (int32_t) n, bdata, (int32_t) (transB ? k : n), 1, params);
PackBMatrix<int8_t> repackedB(
transB ? matrix_op_t::Transpose : matrix_op_t::NoTranspose, (int32_t) k, (int32_t) n, dataB, (int32_t) (transB ? k : n), 1, params);
// gemm computation
fbgemmPacked(packAN, repackedBN, C->data(), (int32_t*)C->data(), (int32_t) n, outputProcObj, 0, 1, params);
delete[] col_offsets;
delete[] bqZeropoint;
delete[] bqScale;
fbgemmPacked(packA, repackedB, C->data(), (int32_t*)C->data(), (int32_t) n, outputProcObj, 0, 1, params);
}
#endif // USE_FBGEMM

View File

@ -70,7 +70,7 @@
</PrecompiledHeader>
<WarningLevel>Level4</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions>USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>false</SDLCheck>
<TreatWarningAsError>true</TreatWarningAsError>
<AdditionalOptions>/bigobj /arch:AVX %(AdditionalOptions)</AdditionalOptions>
@ -107,7 +107,7 @@
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions>USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>false</SDLCheck>
<FavorSizeOrSpeed>Speed</FavorSizeOrSpeed>
<AdditionalOptions>/d2Zi+ /bigobj /arch:AVX %(AdditionalOptions)</AdditionalOptions>
@ -144,44 +144,30 @@
<ClCompile Include="..\src\3rd_party\faiss\Index.cpp">
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Release|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<ClCompile Include="..\src\3rd_party\faiss\IndexLSH.cpp">
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Release|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<ClCompile Include="..\src\3rd_party\faiss\utils\hamming.cpp">
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Release|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<ClCompile Include="..\src\3rd_party\faiss\utils\Heap.cpp">
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Release|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<ClCompile Include="..\src\3rd_party\faiss\utils\misc.cpp">
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Release|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<ClCompile Include="..\src\3rd_party\faiss\utils\random.cpp">
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Release|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<ClCompile Include="..\src\3rd_party\faiss\VectorTransform.cpp">
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
<WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(Configuration)|$(Platform)'=='Release|x64'">USE_ONNX=1;USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\bench\BenchUtils.cc">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>