This commit is contained in:
Frank Seide 2019-08-30 13:51:03 -07:00
commit 7be1855c7f
35 changed files with 1806 additions and 132 deletions

4
.gitmodules vendored
View File

@ -10,3 +10,7 @@
[submodule "src/3rd_party/nccl"]
path = src/3rd_party/nccl
url = https://github.com/marian-nmt/nccl
[submodule "src/3rd_party/fbgemm"]
path = src/3rd_party/fbgemm
url = https://github.com/marian-nmt/FBGEMM
branch = master

View File

@ -5,7 +5,6 @@ if (POLICY CMP0074)
cmake_policy(SET CMP0074 NEW) # CMake 3.12
endif ()
project(marian CXX C)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
@ -22,6 +21,7 @@ option(USE_MPI "Use MPI library" OFF)
option(COMPILE_EXAMPLES "Compile examples" OFF)
option(COMPILE_TESTS "Compile tests" OFF)
option(COMPILE_SERVER "Compile marian-server" ON)
option(USE_FBGEMM "Use FBGEMM" ON)
# Project versioning
find_package(Git QUIET)
@ -59,44 +59,54 @@ if(MSVC)
find_library(SHLWAPI Shlwapi.lib)
set(EXT_LIBS ${EXT_LIBS} SHLWAPI)
else()
else(MSVC)
# Detect supported CPU intrinsics for the current platform. This will
# only be used with BUILD_ARCH=native. For overridden BUILD_ARCH we
# minimally use -msse4.1. This seems to work with MKL.
set(INTRINSICS "")
if(BUILD_ARCH STREQUAL "native")
message(STATUS "Checking support for CPU intrinsics")
include(FindSSE)
if(SSE2_FOUND)
message(STATUS "SSE2 support found")
set(INTRINSICS "${INTRINSICS} -msse2")
endif(SSE2_FOUND)
if(SSE3_FOUND)
message(STATUS "SSE3 support found")
set(INTRINSICS "${INTRINSICS} -msse3")
endif(SSE3_FOUND)
if(SSE4_1_FOUND)
message(STATUS "SSE4.1 support found")
set(INTRINSICS "${INTRINSICS} -msse4.1")
endif(SSE4_1_FOUND)
if(AVX_FOUND)
message(STATUS "AVX support found")
set(INTRINSICS "${INTRINSICS} -mavx")
endif(AVX_FOUND)
if(AVX2_FOUND)
message(STATUS "AVX2 support found")
set(INTRINSICS "${INTRINSICS} -mavx2")
endif(AVX2_FOUND)
else()
set(INTRINSICS "-msse4.1")
endif()
# Detect supported CPU intrinsics for the current platform. This will
# only be used with BUILD_ARCH=native. For overridden BUILD_ARCH we
# minimally use -msse4.1. This seems to work with MKL.
set(INTRINSICS "")
if(BUILD_ARCH STREQUAL "native")
message(STATUS "Checking support for CPU intrinsics")
include(FindSSE)
if(SSE2_FOUND)
message(STATUS "SSE2 support found")
set(INTRINSICS "${INTRINSICS} -msse2")
endif(SSE2_FOUND)
if(SSE3_FOUND)
message(STATUS "SSE3 support found")
set(INTRINSICS "${INTRINSICS} -msse3")
endif(SSE3_FOUND)
if(SSE4_1_FOUND)
message(STATUS "SSE4.1 support found")
set(INTRINSICS "${INTRINSICS} -msse4.1")
endif(SSE4_1_FOUND)
if(AVX_FOUND)
message(STATUS "AVX support found")
set(INTRINSICS "${INTRINSICS} -mavx")
endif(AVX_FOUND)
if(AVX2_FOUND)
message(STATUS "AVX2 support found")
set(INTRINSICS "${INTRINSICS} -mavx2")
endif(AVX2_FOUND)
if(AVX512_FOUND)
message(STATUS "AVX512 support found")
set(INTRINSICS "${INTRINSICS} -mavx512f")
list(APPEND INTRINSICS_NVCC -Xcompiler\ -mavx512f)
endif(AVX512_FOUND)
else()
set(INTRINSICS "-msse4.1")
endif()
set(DISABLE_GLOBALLY "-Wno-unused-result")
# When FBGEMM is enabled, link against the fbgemm target and the platform's
# dynamic-loading library, and expose USE_FBGEMM=1 to the C/C++ sources.
# CMAKE_DL_LIBS expands to the library providing dlopen()/dlclose() on the
# current platform (e.g. "dl" on Linux, empty where libc already provides it),
# which avoids hard-coding -ldl.
if(USE_FBGEMM)
  set(EXT_LIBS ${EXT_LIBS} fbgemm ${CMAKE_DL_LIBS})
  add_definitions(-DUSE_FBGEMM=1)
endif(USE_FBGEMM)
# These are used in src/CMakeLists.txt on a per-target basis
list(APPEND ALL_WARNINGS -Wall; -Werror; -Wno-unused-result; -Wno-deprecated; -Wno-pragmas; -Wno-unused-parameter; -Wextra; -Wno-unused-function;
-Wno-unused-value; -Wno-unknown-pragmas; -Wno-sign-compare; -Wno-missing-field-initializers;)
set(DISABLE_GLOBALLY "-Wno-unused-result")
# These are used in src/CMakeLists.txt on a per-target basis
list(APPEND ALL_WARNINGS -Wall; -Werror; -Wno-unused-result; -Wno-deprecated; -Wno-pragmas; -Wno-unused-parameter; -Wextra; -Wno-unused-function;
-Wno-unused-value; -Wno-unknown-pragmas; -Wno-sign-compare; -Wno-missing-field-initializers;)
# This warning does not exist prior to gcc 5.0
if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0)
@ -111,7 +121,7 @@ list(APPEND ALL_WARNINGS -Wall; -Werror; -Wno-unused-result; -Wno-deprecated; -W
set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_RELEASE} -pg -g -rdynamic")
set(CMAKE_CXX_FLAGS_PROFGEN "${CMAKE_CXX_FLAGS_RELEASE} -fprofile-generate -fprofile-correction")
set(CMAKE_CXX_FLAGS_PROFUSE "${CMAKE_CXX_FLAGS_RELEASE} -fprofile-use -fprofile-correction")
endif()
endif(MSVC)
# Downloading SentencePiece if requested and set to compile with it.
# Requires all the dependencies imposed by SentencePiece

View File

@ -53,11 +53,11 @@ else()
set(COR_LIB "mkl_core")
endif()
if(MSVC)
set(ProgramFilesx86 "ProgramFiles(x86)")
set(INTEL_ROOT_DEFAULT $ENV{${ProgramFilesx86}}/IntelSWTools/compilers_and_libraries/windows)
else()
set(INTEL_ROOT_DEFAULT "/opt/intel")
if(MSVC)
set(ProgramFilesx86 "ProgramFiles(x86)")
set(INTEL_ROOT_DEFAULT $ENV{${ProgramFilesx86}}/IntelSWTools/compilers_and_libraries/windows)
else()
set(INTEL_ROOT_DEFAULT "/opt/intel")
endif()
set(INTEL_ROOT ${INTEL_ROOT_DEFAULT} CACHE PATH "Folder contains intel libs")
find_path(MKL_ROOT include/mkl.h PATHS $ENV{MKLROOT} ${INTEL_ROOT}/mkl
@ -89,7 +89,10 @@ find_library(MKL_CORE_LIBRARY
NO_DEFAULT_PATH)
set(MKL_INCLUDE_DIRS ${MKL_INCLUDE_DIR})
set(MKL_LIBRARIES ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY})
# Added -Wl block to avoid circular dependencies.
# https://stackoverflow.com/questions/5651869/what-are-the-start-group-and-end-group-command-line-options
# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor
set(MKL_LIBRARIES -Wl,--start-group ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY} -Wl,--end-group)
# message("1 ${MKL_INCLUDE_DIR}")
# message("2 ${MKL_INTERFACE_LIBRARY}")

View File

@ -56,6 +56,14 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
ELSE (AVX2_TRUE)
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
ENDIF (AVX2_TRUE)
# AVX512 detection: when CPUINFO contains "avx512" the regex collapses the whole
# string to that token; when it does not match, string(REGEX REPLACE) returns the
# input unchanged and the comparison below yields false.
STRING(REGEX REPLACE "^.*(avx512).*$" "\\1" SSE_THERE ${CPUINFO})
STRING(COMPARE EQUAL "avx512" "${SSE_THERE}" AVX512_TRUE)
# Publish the result as a cache entry (consumed by the including CMakeLists via AVX512_FOUND).
IF (AVX512_TRUE)
set(AVX512_FOUND true CACHE BOOL "AVX512 available on host")
ELSE (AVX512_TRUE)
set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
ENDIF (AVX512_TRUE)
ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin")
EXEC_PROGRAM("/usr/sbin/sysctl -n machdep.cpu.features" OUTPUT_VARIABLE
@ -108,6 +116,14 @@ ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin")
ELSE (AVX2_TRUE)
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
ENDIF (AVX2_TRUE)
# AVX512 detection: when CPUINFO contains "avx512" the regex collapses the whole
# string to that token; when it does not match, string(REGEX REPLACE) returns the
# input unchanged and the comparison below yields false.
# NOTE(review): in this branch CPUINFO comes from `sysctl -n machdep.cpu.features`,
# which typically reports upper-case flag names - confirm whether the lower-case
# "avx512" pattern can ever match here.
STRING(REGEX REPLACE "^.*(avx512).*$" "\\1" SSE_THERE ${CPUINFO})
STRING(COMPARE EQUAL "avx512" "${SSE_THERE}" AVX512_TRUE)
# Publish the result as a cache entry (consumed by the including CMakeLists via AVX512_FOUND).
IF (AVX512_TRUE)
set(AVX512_FOUND true CACHE BOOL "AVX512 available on host")
ELSE (AVX512_TRUE)
set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
ENDIF (AVX512_TRUE)
ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Windows")
# TODO
@ -117,6 +133,7 @@ ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Windows")
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
set(AVX_FOUND false CACHE BOOL "AVX available on host")
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
ELSE(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(SSE2_FOUND true CACHE BOOL "SSE2 available on host")
set(SSE3_FOUND false CACHE BOOL "SSE3 available on host")
@ -124,6 +141,7 @@ ELSE(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
set(AVX_FOUND false CACHE BOOL "AVX available on host")
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
ENDIF(CMAKE_SYSTEM_NAME MATCHES "Linux")
if(NOT SSE2_FOUND)
@ -144,5 +162,8 @@ endif(NOT AVX_FOUND)
if(NOT AVX2_FOUND)
MESSAGE(STATUS "Could not find hardware support for AVX2 on this machine.")
endif(NOT AVX2_FOUND)
if(NOT AVX512_FOUND)
MESSAGE(STATUS "Could not find hardware support for AVX512 on this machine.")
endif(NOT AVX512_FOUND)
# Hide the detection results from the default cache view. CMake arguments are
# separated by whitespace only: a trailing comma becomes part of the variable
# name, so the list must not contain commas.
mark_as_advanced(SSE2_FOUND SSE3_FOUND SSSE3_FOUND SSE4_1_FOUND AVX_FOUND AVX2_FOUND)
mark_as_advanced(SSE2_FOUND SSE3_FOUND SSSE3_FOUND SSE4_1_FOUND AVX_FOUND AVX2_FOUND AVX512_FOUND)

View File

@ -6,6 +6,18 @@ add_subdirectory(./SQLiteCpp)
add_subdirectory(./pathie-cpp)
add_subdirectory(./zlib)
if(USE_FBGEMM)
  # @TODO: find out if this is somehow harmful. Unless the user already defined it,
  # CMAKE_SUPPRESS_DEVELOPER_WARNINGS is switched on here; it is meant to silence
  # developer warnings emitted by the CMake files of 3rd_party tools.
  if(NOT DEFINED CMAKE_SUPPRESS_DEVELOPER_WARNINGS)
    set(CMAKE_SUPPRESS_DEVELOPER_WARNINGS 1 CACHE INTERNAL "No dev warnings")
  endif()

  # Configure fbgemm without its tests and benchmarks, then add it to the build.
  set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "Disable fbgemm tests")
  set(FBGEMM_BUILD_BENCHMARKS OFF CACHE BOOL "Disable fbgemm benchmark")
  add_subdirectory(./fbgemm)
endif()
if(USE_SENTENCEPIECE)
if(USE_STATIC_LIBS)
set(_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES})

1
src/3rd_party/fbgemm vendored Submodule

@ -0,0 +1 @@
Subproject commit 49e8018ab2397c175354317b35c6be6dd68f8932

View File

@ -4,6 +4,7 @@ include_directories(.)
include_directories(3rd_party)
include_directories(3rd_party/SQLiteCpp/include)
include_directories(3rd_party/sentencepiece)
include_directories(3rd_party/fbgemm/include)
add_library(marian STATIC
common/version.cpp
@ -40,6 +41,7 @@ add_library(marian STATIC
tensors/cpu/sharp/int_gemm.cpp
tensors/cpu/sharp/avx_gemm.cpp
tensors/cpu/sharp/sse_gemm.cpp
tensors/cpu/sharp/packed_gemm.cpp
graph/expression_graph.cpp
graph/expression_operators.cpp

View File

@ -31,8 +31,9 @@ int main(int argc, char** argv) {
marian::io::getYamlFromModel(config, "special:model.yml", modelFrom);
configStr << config;
auto graph = New<ExpressionGraph>(true, false);
auto graph = New<ExpressionGraph>(true);
graph->setDevice(CPU0);
graph->getBackend()->setOptimized(false);
graph->load(modelFrom);
graph->forward();

View File

@ -527,6 +527,9 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
"Optimize speed aggressively sacrificing memory or precision");
cli.add<bool>("--skip-cost",
"Ignore model cost during translation, not recommended for beam-size > 1");
cli.add<std::string>("--gemm-type",
"Select GEMM options: auto, mklfp32, intrinint16, fp16packed, int8packed",
"auto");
cli.add<std::vector<std::string>>("--shortlist",
"Use softmax shortlist: path first best prune");

View File

@ -4,7 +4,7 @@ add_executable(mnist_example mnist/mnist_ffnn.cpp)
foreach(exec iris_example mnist_example)
target_link_libraries(${exec} marian ${EXT_LIBS})
if(CUDA_FOUND)
target_link_libraries(${exec} marian marian_cuda ${EXT_LIBS})
target_link_libraries(${exec} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS})
endif(CUDA_FOUND)
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
endforeach(exec)

View File

@ -20,15 +20,26 @@ class AutoTuner : public AutoTunerRecorder {
private:
typedef std::function<Return(Args...)> Algorithm;
const size_t max = 100;
// When the autotuner decides the fastest algorithm for a specific tensor operation (e.g. GEMM),
// the autotuner runs each algorithm at least this 'collectStatMax' number of times and
// collects the statistics.
const size_t collectStatMax = 50;
UPtr<timer::CPUTimer> timer_;
// This structure holds a hash key and an algorithm function (e.g. int16, packed gemm, mkl gemm)
// for a specific operation size
// hash: a unique hash key for each operation size
// (e.g. m, n, k, transpose A, transpose B, bias size for GEMM)
// algorithm: a function that holds an algorithm
struct HashedAlgorithm {
size_t hash;
Algorithm algorithm;
};
// This structure represents the collected statistics.
// time: total accumulated time of this operator execution with the given algorithm
// runs: total number of times this algorithm was executed
struct Stat {
double time;
size_t runs;
@ -53,7 +64,7 @@ private:
auto& stat = it->second;
// collect more stats
if(stat.runs < max)
if(stat.runs < collectStatMax)
return i;
if(stat.time < bestTime) {
@ -93,7 +104,7 @@ public:
auto it = stats_.find(hash);
if(it != stats_.end()) {
if(it->second.runs < max) {
if(it->second.runs < collectStatMax) {
it->second.time += seconds.count();
it->second.runs += 1;
}

View File

@ -5,8 +5,8 @@
namespace marian {
ExpressionGraph::ExpressionGraph(bool inference, bool optimized)
: inferenceOnly_(inference), optimized_(optimized), backend_(nullptr) {}
ExpressionGraph::ExpressionGraph(bool inference)
: inferenceOnly_(inference), backend_(nullptr) {}
void ExpressionGraph::setDevice(DeviceId deviceId, Ptr<Device> device) {
if(!backend_) {

View File

@ -1,7 +1,6 @@
#pragma once
#include "common/config.h"
#include "common/definitions.h"
#include "tensors/backend.h"
#include "tensors/tensor_allocator.h"
@ -130,7 +129,6 @@ private:
std::unordered_map<size_t, std::vector<Expr>> memoized_;
bool inferenceOnly_{false};
bool optimized_{false};
Ptr<Backend> backend_;
bool reloaded_{false};
@ -148,7 +146,7 @@ public:
*
* Constructor should be used as New<ExpressionGraph>()
*/
ExpressionGraph(bool inference = false, bool optimized = false);
ExpressionGraph(bool inference = false);
void setInference(bool inference) { inferenceOnly_ = inference; }
bool isInference() { return inferenceOnly_; }
@ -165,9 +163,6 @@ public:
Ptr<Backend> getBackend() { return backend_; }
void setOptimized(bool optimized) { optimized_ = optimized; }
bool isOptimized() { return (optimized_ && inferenceOnly_); }
void switchParams(const std::string& newNamespace) {
namespace_ = newNamespace;
}

View File

@ -7,6 +7,11 @@
#include "graph/auto_tuner.h"
#include "tensors/cpu/int16.h"
#include "tensors/cpu/expanded_gemm.h"
#if USE_FBGEMM
#include "fbgemm/Utils.h"
#endif
namespace marian {
@ -386,7 +391,8 @@ Expr dot(Expr a, Expr b, bool transA, bool transB, float scale) {
// Currently only true when command line options
// --optimize --cpu-thread=N with N > 0 are set.
if(a->graph()->isOptimized() && device == DeviceType::cpu) {
if(device == DeviceType::cpu && a->graph()->getBackend()->isOptimized()
&& a->graph()->getBackend()->getGemmType() == GemmType::IntrinInt16) {
// dotInt16 computes A * B.T, hence the transpose for B to get A * B
// if transA = false and transB = false.
@ -409,9 +415,13 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
float clipValue = a->graph()->getBackend()->getClip();
if(a->graph()->isOptimized() && device == DeviceType::cpu) {
bool autotune = true;
if(autotune) {
if(device == DeviceType::cpu && a->graph()->getBackend()->isOptimized()) {
GemmType gemmType = a->graph()->getBackend()->getGemmType();
// When gemmType is set to 'auto', an autotuner decides the best algorithm available.
// A new autotuner is created, then different kinds of algorithms are added to the autotuner.
// For each GEMM size, there is a unique hash key.
// (e.g. m, n, k, transpose A, transpose B, bias size for GEMM)
if(gemmType == GemmType::Auto) {
thread_local Ptr<AutoTuner<Expr>> tuner = New<AutoTuner<Expr>>();
// start with new set of algorithms
@ -431,60 +441,154 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
util::hash_combine(hash, transA);
util::hash_combine(hash, transB);
// add first algorithm variant (Int16)
size_t hash1 = hash;
util::hash_combine(hash1, 1);
auto rec1 = [=](Expr e, bool stop = false) {
e->record(tuner, hash1, stop);
#if USE_FBGEMM
// Use Packed GEMM only if the node b in the graph is memoized.
// More specifically, packed GEMM is used only if the B matrix (weight) is constant.
// In general, 'memoized' means that the node is a constant variable or
// a combination of constant nodes which is also a constant variable
// when it's computed once.
// Those memoized nodes are cached to avoid duplicated computations.
// 07/10/2019 - Use packed GEMM only if the cpu architecture supports AVX2.
// The check relies on one of the fbgemm's sub modules, cpuinfo (https://github.com/pytorch/cpuinfo).
// It looks at the cpu register
// (https://github.com/pytorch/cpuinfo/blob/master/src/x86/isa.c#L391),
// and this cpu lookup is executed only once and the state is kept in FBGEMM.
if(fbgemm::fbgemmHasAvx2Support() && b->memoize()) {
// add packed GEMM algorithm variant (Packed GEMM) to the autotuner
// Once an algorithm is added to the autotuner,
// the autotuner runs all the added algorithms a designated number of times.
// One algorithm is run per call of this operation
// and the stat for that algorithm is collected.
// When all the algorithms reach the maximum stat collection count,
// the autotuner decides the best algorithm, and keeps using it afterward.
size_t hashPack = hash;
util::hash_combine(hashPack, 1);
auto recPack = [=](Expr e, bool stop = false) {
e->record(tuner, hashPack, stop);
return e;
};
auto algPack = [=]() {
auto packed = cpu::variant::pack(b, cpu::variant::PackMatrix::B, transB, clipValue);
return recPack(
cpu::variant::affine(
clip(a, clipValue),
packed,
b->shape(),
bias,
transA,
transB,
scale),
true);
};
tuner->insert({hashPack, algPack});
}
#endif // USE_FBGEMM
// add second algorithm variant (Int16) to the autotuner
size_t hashInt16 = hash;
util::hash_combine(hashInt16, 2);
auto recInt16 = [=](Expr e, bool stop = false) {
e->record(tuner, hashInt16, stop);
return e;
};
auto alg1 = [=]() {
return rec1(
auto algInt16 = [=]() {
return recInt16(
cpu::int16::affine(
rec1(cpu::int16::quantize(transA ? rec1(transpose(a)) : a,
clipValue)),
cpu::int16::quantize(transB ? b : transpose(b), clipValue),
recInt16(
cpu::int16::quantize(
transA ? recInt16(transpose(a)) : a,
clipValue)),
cpu::int16::quantize(
transB ? b : transpose(b),
clipValue),
bias,
scale),
true);
};
tuner->insert({hash1, alg1});
tuner->insert({hashInt16, algInt16});
// add second algorithm variant (CBlas)
size_t hash2 = hash;
util::hash_combine(hash2, 2);
auto rec2 = [=](Expr e, bool stop = false) {
e->record(tuner, hash2, stop);
// add third algorithm variant (CBlas) to the autotuner
size_t hashCblas = hash;
util::hash_combine(hashCblas, 3);
auto recCblas = [=](Expr e, bool stop = false) {
e->record(tuner, hashCblas, stop);
return e;
};
auto alg2 = [=]() {
auto algCblas = [=]() {
auto ac = clip(a, clipValue);
if(ac != a)
ac = rec2(ac);
ac = recCblas(ac);
auto bc = clip(b, clipValue);
if(bc != b)
bc = rec2(bc);
bc = recCblas(bc);
int rows = ac->shape().elements() / ac->shape()[-1];
Expr ones = ac->graph()->ones({rows, 1});
std::vector<Expr> nodes = {ac, bc, bias, ones};
return rec2(Expression<AffineNodeOp>(nodes, transA, transB, scale),
true);
return recCblas(Expression<AffineNodeOp>(nodes, transA, transB, scale),
true);
};
tuner->insert({hash2, alg2});
tuner->insert({hashCblas, algCblas});
// execute algorithm with autotuning
return tuner->run();
} else {
// cpu int16 version
return cpu::int16::affine(
cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
cpu::int16::quantize(transB ? b : transpose(b), clipValue),
bias,
scale);
if(gemmType == GemmType::IntrinInt16) {
// cpu int16 version
return cpu::int16::affine(
cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
cpu::int16::quantize(transB ? b : transpose(b), clipValue),
bias,
scale);
} else if(gemmType == GemmType::FbFp16Packed) {
#if USE_FBGEMM
// 07/10/2019 - Use packed GEMM only if the cpu architecture supports AVX2.
// The check relies on one of the fbgemm's sub modules, cpuinfo (https://github.com/pytorch/cpuinfo).
// It looks at the cpu register
// (https://github.com/pytorch/cpuinfo/blob/master/src/x86/isa.c#L391),
// and this cpu lookup is executed only once and the state is kept in FBGEMM.
if(fbgemm::fbgemmHasAvx2Support() && b->memoize()) {
auto packed = cpu::variant::pack(b, cpu::variant::PackMatrix::B, transB, clipValue);
return cpu::variant::affine(
clip(a, clipValue),
packed,
b->shape(),
bias,
transA,
transB,
scale);
} else {
int rows = a->shape().elements() / a->shape()[-1];
Expr ones = a->graph()->ones({rows, 1});
std::vector<Expr> nodes = {clip(a, clipValue), clip(b, clipValue), bias, ones};
return Expression<AffineNodeOp>(nodes, transA, transB, scale);
}
#else
ABORT("Packed GEMM is not available in this build");
#endif // USE_FBGEMM
} else if(gemmType == GemmType::MklFp32) {
// general version, MKL, CBlas or CUDA
// if clipValue > 0, the inputs will be clipped to range [-clipValue,
// clipValue] This is meant to keep values at the same range as used during
// training when optimizing for 8-bit integer products. Likely to be removed
// in the future when we explore better ways to handle this.
int rows = a->shape().elements() / a->shape()[-1];
Expr ones = a->graph()->ones({rows, 1});
std::vector<Expr> nodes
= {clip(a, clipValue), clip(b, clipValue), bias, ones};
return Expression<AffineNodeOp>(nodes, transA, transB, scale);
} else {
ABORT("GemmType..{} not available by affine()", gemmType);
}
}
} else {
// general version, MKL, CBlas or CUDA

View File

@ -1,7 +1,7 @@
#include "quicksand.h"
#include "marian.h"
#ifdef MKL_FOUND
#if MKL_FOUND
#include "mkl.h"
#endif
@ -63,13 +63,17 @@ public:
vocabs_.push_back(std::dynamic_pointer_cast<VocabWrapper>(vi)->getVocab());
// setting 16-bit optimization to false for now. Re-enable with better caching or pre-computation
graph_ = New<ExpressionGraph>(/*inference=*/true, /*optimize=*/false);
graph_ = New<ExpressionGraph>(/*inference=*/true);
DeviceId deviceId{0, DeviceType::cpu};
device_ = New<cpu::WrappedDevice>(deviceId);
graph_->setDevice(deviceId, device_);
#ifdef MKL_FOUND
// Use packed GEMM for the production
graph_->getBackend()->setOptimized(true);
graph_->getBackend()->setGemmType("fp16packed");
#if MKL_FOUND
mkl_set_num_threads(options->get<int>("mkl-threads", 1));
#endif

View File

@ -8,9 +8,9 @@
namespace marian {
class CharS2SEncoder : public EncoderS2S {
using EncoderS2S::EncoderS2S;
public:
CharS2SEncoder(Ptr<Options> options) : EncoderS2S(options) {}
virtual Ptr<EncoderState> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch) override {
graph_ = graph;

View File

@ -32,7 +32,7 @@ Ptr<EncoderBase> EncoderFactory::construct(Ptr<ExpressionGraph> graph) {
#ifdef CUDNN
if(options_->get<std::string>("type") == "char-s2s")
return New<CharS2SEncoder>(options_);
return New<CharS2SEncoder>(graph, options_);
#endif
if(options_->get<std::string>("type") == "transformer")

View File

@ -1,4 +1,4 @@
#include "models/transformer.h"
#include "models/transformer.h"
namespace marian {
// factory functions

View File

@ -67,8 +67,12 @@ public:
auto devices = Config::getDevices(options_);
for(auto device : devices) {
auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
auto graph = New<ExpressionGraph>(true);
graph->setDevice(device);
if (device.type == DeviceType::cpu) {
graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
}
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);
}

View File

@ -5,6 +5,14 @@
namespace marian {
// GEMM type enum
// Identifies which GEMM implementation a backend should use. The numeric values
// are written out explicitly; the FBGEMM-based variants start at 10, leaving a
// gap after the non-FBGEMM ones.
typedef enum { Auto = 0, // auto tuning between available GEMMs
MklFp32 = 1, // MKL based GEMM, fp32
IntrinInt16 = 2, // Intrinsic implementation of Int 16 GEMM
FbFp16Packed = 10, // FBGEMM based fp16 GEMM with packing
FbInt8Packed = 11 // FBGEMM based int8 GEMM with packing
} GemmType;
class Backend {
protected:
DeviceId deviceId_;
@ -16,9 +24,7 @@ protected:
public:
Backend(DeviceId deviceId, size_t seed)
: deviceId_(deviceId),
seed_(seed),
randomGenerator_(createRandomGenerator(seed, deviceId)) {}
: deviceId_(deviceId), seed_(seed), randomGenerator_(createRandomGenerator(seed, deviceId)) {}
virtual DeviceId getDeviceId() { return deviceId_; };
virtual Ptr<RandomGenerator> getRandomGenerator() { return randomGenerator_; }
@ -29,6 +35,15 @@ public:
virtual void setClip(float clipValue) { clipValue_ = clipValue; }
float getClip() { return clipValue_; }
// for CPU, sets to use optimized code for inference.
// for GPU, this is invalid. for gpu, isOptimized() function always returns false.
virtual void setOptimized(bool optimize) = 0;
virtual bool isOptimized() = 0;
// for CPU, selects different GEMM types for the inference.
// for GPU, there's no gemm type. so, it does nothing.
virtual void setGemmType(std::string gemmType) = 0;
virtual GemmType getGemmType() = 0;
};
Ptr<Backend> BackendByDeviceId(DeviceId deviceId, size_t seed);

View File

@ -10,10 +10,30 @@ namespace marian {
namespace cpu {
// CPU backend: stores the 'optimized' flag and the selected GEMM type on top of
// the generic marian::Backend. setDevice() and synchronize() are no-ops here.
class Backend : public marian::Backend {
protected:
  bool optimized_{false};              // whether optimized inference code is requested
  GemmType gemmType_{GemmType::Auto};  // GEMM implementation selected for inference

public:
  Backend(DeviceId deviceId, size_t seed) : marian::Backend(deviceId, seed) {}

  void setDevice() override {}
  void synchronize() override {}

  // For CPU and inference only: switches the optimized inference code on or off.
  // Does nothing for GPU.
  void setOptimized(bool optimize) override { optimized_ = optimize; }
  bool isOptimized() override { return optimized_; }

  // For CPU only: translates the textual GEMM type into the GemmType enum.
  // Does nothing for GPU.
  // The packed variants are only recognized in builds with USE_FBGEMM; any
  // unrecognized name aborts.
  void setGemmType(std::string gemmType) override {
    if(gemmType == "auto") {
      gemmType_ = GemmType::Auto;
      return;
    }
    if(gemmType == "mklfp32") {
      gemmType_ = GemmType::MklFp32;
      return;
    }
    if(gemmType == "intrinint16") {
      gemmType_ = GemmType::IntrinInt16;
      return;
    }
#if USE_FBGEMM
    if(gemmType == "fp16packed") {
      gemmType_ = GemmType::FbFp16Packed;
      return;
    }
    if(gemmType == "int8packed") {
      gemmType_ = GemmType::FbInt8Packed;
      return;
    }
#endif // USE_FBGEMM
    ABORT("Unknown GEMM type - '{}'", gemmType);
  }

  GemmType getGemmType() override { return gemmType_; }
};
} // namespace cpu
} // namespace marian

View File

@ -9,12 +9,6 @@
namespace marian {
namespace cpu {
Device::~Device() {
free(data_);
data_ = nullptr;
size_ = 0;
}
// allocate function for tensor reserve() below.
// Needed for AVX512, while not available on all compilers. It seems clang
// does not have aligned_alloc for all cstlib versions. If AVX512 is not used
@ -35,6 +29,10 @@ Device::~Device() {
#define FREE(ptr) free(ptr)
#endif
// Releases the device buffer through the FREE macro, which is defined by the
// preprocessor conditional preceding this definition.
Device::~Device() {
FREE(data_);
}
void Device::reserve(size_t size) {
size = align(size);
ABORT_IF(size < size_ || size == 0,

View File

@ -0,0 +1,195 @@
#pragma once
#include "graph/node.h"
#include "tensors/cpu/sharp/packed_gemm.h"
#if USE_FBGEMM
#include "3rd_party/fbgemm/include/fbgemm/FbgemmFP16.h"
using namespace fbgemm;
#endif // USE_FBGEMM
namespace marian {
namespace cpu {
namespace variant {
// Selects which matrix a pack function operates on:
// the A matrix (value 0) or the B matrix (value 1).
enum class PackMatrix : uint8_t { A = 0, B = 1 };
// Pack a matrix into a cache-friendly block format (inference only).
// The node's value is a 1D uint8 tensor of packsize_ bytes; all layout parameters
// below are computed by newShape().
// PackMatrix packMat_: the type of packed matrix - A or B matrix (only B is accepted)
// bool transpose_: whether the input matrix is to be treated as transposed
// int nrow_: the number of rows (after the optional transpose)
// int ncol_: the number of columns (after the optional transpose)
// int kernel_ncol_blocks_: the number of column blocks (fixed to 2)
// int brow_: the number of rows in a block (fixed to 512)
// int bcol_: the number of columns in a block (8 * kernel_ncol_blocks_)
// int last_brow_: the number of rows in the last block
// int nbrow_: the number of row blocks (nrow_ / brow_, rounded up)
// int nbcol_: the number of column blocks (ncol_ / bcol_, rounded up)
// uint64_t packsize_: the size of the packed matrix in bytes
// (fp16 elements of the padded block grid + padding (1024) + extra temporary memory (256))
struct PackNodeOp : public UnaryNodeOp {
PackMatrix packMat_;
bool transpose_;
int nrow_;
int ncol_;
int kernel_ncol_blocks_;
int brow_;
int bcol_;
int last_brow_;
int nbrow_;
int nbcol_;
uint64_t packsize_;
// Builds the pack node for 'a'. Aborts unless packMat is B, clipValue is 0 and
// the node is memoized.
// NOTE(review): newShape() is evaluated as an argument of the base-class
// initializer and assigns the layout members of this object before the rest of
// the constructor runs - confirm this ordering is intended.
PackNodeOp(Expr a, PackMatrix packMat, bool transpose, float clipValue)
: UnaryNodeOp(a, newShape(a, transpose), Type::uint8),
packMat_(packMat),
transpose_(transpose) {
if(packMat != PackMatrix::B)
ABORT("Only prepacking of B (weight matrix) is supported");
if(clipValue != 0)
ABORT("Clipping is not supported");
if(!memoize_)
ABORT("Only constant weight node can be packed");
}
// Forward pass: hands the child's value and the precomputed layout parameters
// to PackFp32, which writes into val_.
NodeOps forwardOps() override {
return {NodeOp(PackFp32(val_,
child(0)->val(),
transpose_,
nrow_,
ncol_,
kernel_ncol_blocks_,
brow_,
bcol_,
last_brow_,
nbrow_,
nbcol_,
packsize_))
};
}
// No gradient: aborts when called.
NodeOps backwardOps() override {
ABORT("PackNodeOp only available for inference");
return {NodeOp(0)};
}
const std::string type() override { return "packMat"; }
// Computes the block layout for the 2D matrix 'a' (stored in the members) and
// returns the 1D output shape {packsize_}. Aborts when built without USE_FBGEMM.
Shape newShape(Expr a, bool transpose) {
#if USE_FBGEMM
auto shapeMat = a->shape();
// Should be 2D - weight matrix
ABORT_IF(shapeMat.size() != 2,
"Weight Matrix should be 2D");
nrow_ = transpose ? shapeMat[1] : shapeMat[0];
ncol_ = transpose ? shapeMat[0] : shapeMat[1];
kernel_ncol_blocks_ = 2;
brow_ = 512;
bcol_ = 8 * kernel_ncol_blocks_;
// a full block if nrow_ divides evenly, otherwise the remainder
last_brow_ = nrow_ % brow_ == 0 ? brow_ : nrow_ % brow_;
// block counts, rounded up when the dimension is not a multiple of the block size
nbrow_ = nrow_ % brow_ == 0 ? nrow_ / brow_ : (nrow_ + brow_) / brow_;
nbcol_ = ncol_ % bcol_ == 0 ? ncol_ / bcol_ : (ncol_ + bcol_) / bcol_;
const int padding = 1024; // required by sw pipelined kernels
const int specialMem = 256;
// size in bytes of the padded block grid plus the two extra regions
packsize_ = ((nbrow_ * brow_) * (nbcol_ * bcol_)) * sizeof(fbgemm::float16) + padding + specialMem;
Shape outShape({(int)packsize_});
return outShape;
#else // USE_FBGEMM
ABORT("Packed GEMM requires a build with USE_FBGEMM enabled");
return Shape();
#endif // USE_FBGEMM
}
};
// Affine transform (matrix multiplication) using a packed B matrix. Inference only.
// float scalar_: scalar multiplier
// size_t m_: the number of rows in A and C
// size_t n_: the number of columns in B and C
// size_t k_: the number of columns in A and the number of rows in B
// bool transA_: transpose A
// bool transB_: transpose B
// NOTE(review): forwardOps() passes only m_, n_ and transA_ to GemmPackFp32;
// scalar_, k_ and transB_ are stored but not forwarded - confirm this is intended.
class AffineNodeOp : public NaryNodeOp {
private:
  float scalar_;
  size_t m_;
  size_t n_;
  size_t k_;
  bool transA_;
  bool transB_;

public:
  // nodes: {A, packed B, bias}; bShape: shape of B before packing.
  AffineNodeOp(const std::vector<Expr>& nodes, Shape bShape, bool transA, bool transB, float scalar)
      : NaryNodeOp(nodes, newShape(nodes[0], bShape, transA, transB), Type::float32),
        scalar_(scalar),
        transA_(transA),
        transB_(transB) {
    // A: all leading dimensions are folded into the row count
    size_t rowsA = nodes[0]->shape().elements() / nodes[0]->shape()[-1];
    size_t colsA = nodes[0]->shape().back();
    m_ = transA ? colsA : rowsA;
    k_ = transA ? rowsA : colsA;

    // B: pick the output column count depending on the transpose flag
    size_t rowsB = bShape.elements() / bShape[-1];
    size_t colsB = bShape[-1];
    n_ = transB ? rowsB : colsB;
  }

  // Output shape: shape of (optionally transposed) A with its last dimension
  // replaced by the last dimension of (optionally transposed) B.
  // Aborts if the inner dimensions do not match.
  Shape newShape(Expr a, Shape bShape, bool transA, bool transB) {
    auto shapeA = a->shape();
    if(transA) {
      // swap the last two dimensions of A
      shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
      shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
    }

    auto shapeB = bShape;
    if(transB) {
      // swap the last two dimensions of B
      shapeB.set(shapeB.size() - 2, bShape[shapeB.size() - 1]);
      shapeB.set(shapeB.size() - 1, bShape[shapeB.size() - 2]);
    }

    Shape result = shapeA;
    result.set(result.size() - 1, shapeB[shapeB.size() - 1]);
    ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
             "Matrix product requires inner dimensions to match");
    return result;
  }

  // Forward pass: children are A (0), packed B (1) and bias (2).
  NodeOps forwardOps() override {
    return {
      NodeOp(GemmPackFp32(val_,
                          child(0)->val(),
                          child(1)->val(),
                          child(2)->val(),
                          m_,
                          n_,
                          transA_))
    };
  }

  // No gradient: aborts when called.
  NodeOps backwardOps() override {
    ABORT("Only used for inference");
    return {NodeOp(0)};
  }

  const std::string type() override { return "fp16packed"; }
};
// Creates a packed-GEMM affine node from A, the packed B matrix and the bias c.
// bShape is the shape of B before packing.
static inline Expr affine(Expr a, Expr b, Shape bShape, Expr c, bool transA, bool transB, float scalar) {
  std::vector<Expr> inputs{a, b, c};
  return Expression<cpu::variant::AffineNodeOp>(inputs, bShape, transA, transB, scalar);
}
// Creates a node that packs matrix 'a'; all arguments are forwarded to PackNodeOp.
static inline Expr pack(Expr a, PackMatrix packMat, bool transpose, float clipValue) {
  Expr packedNode = Expression<cpu::variant::PackNodeOp>(a, packMat, transpose, clipValue);
  return packedNode;
}
} // namespace variant
} // namespace cpu
} // namespace marian

View File

@ -89,17 +89,13 @@ void AddBias(marian::Tensor C, const marian::Tensor Bias) {
const float* x = C->data();
const float* bias = Bias->data();
int m = C->shape().elements() / C->shape()[-1];
int n = C->shape()[-1];
#ifdef __AVX512F__
int n16 = n & ~15;
#else
int n4 = (n / 4) * 4;
#endif
const int m = C->shape().elements() / C->shape()[-1];
const int n = C->shape()[-1];
for(int j = 0; j < m; ++j) {
int i = 0;
#ifdef __AVX512F__
int n16 = n & ~15;
for(; i < n16; i += 16) {
__m512 ai = _mm512_loadu_ps(x + j * n + i);
__m512 bi = _mm512_loadu_ps(bias + i);
@ -107,6 +103,7 @@ void AddBias(marian::Tensor C, const marian::Tensor Bias) {
_mm512_storeu_ps(y + j * n + i, yi);
}
#else
int n4 = (n / 4) * 4;
for(; i < n4; i += 4) {
__m128 ai = _mm_loadu_ps(x + j * n + i);
__m128 bi = _mm_loadu_ps(bias + i);

View File

@ -0,0 +1,239 @@
#include "packed_gemm.h"
#include "tensors/tensor_allocator.h"
#include "tensors/tensor_operators.h"
#include <emmintrin.h>
#include <immintrin.h>
#include <tmmintrin.h>
#include <xmmintrin.h>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <unordered_map>
//#include <chrono>
#ifdef _MSC_VER
#pragma warning(disable: 4505) // warning C4505: 'fbgemmAlignedAlloc' in fbgemm.h: unreferenced local function has been removed (missing 'static inline')
#endif
#if USE_FBGEMM
#include "3rd_party/fbgemm/include/fbgemm/FbgemmFP16.h"
#include "3rd_party/fbgemm/include/fbgemm/QuantUtils.h"
#include "3rd_party/fbgemm/include/fbgemm/Fbgemm.h"
#ifdef _OPENMP
#include <omp.h>
#endif
#if MKL_FOUND
#include <mkl.h>
#include <mkl_types.h>
#endif
using namespace fbgemm;
#endif // USE_FBGEMM
namespace marian {
namespace cpu {
namespace variant { // Variants of GEMM implementations
#if USE_FBGEMM
// Placeholder PackedGemmMatrixFP16 instance, initialized with dummy values.
// Instantiating this class performs the actual packing operation. If we created an
// instance on every GEMM call, we would pack every time, which is very slow.
// In Caffe2, the operator is stateful and holds an instance of this class.
// Marian has no logic for this; we can only cache a tensor (i.e. a chunk of memory).
// So, for now, we keep the packed memory in our own 1D tensor and, when we call GEMM,
// we reuse this instance again and again by replacing its class members (including the
// memory pointer). Eventually, a new constructor that accepts pre-allocated and
// pre-packed memory as a parameter will be added to the class in FBGEMM. Once that is
// done, this temporary buffer will be removed.
// When constructing this dummy buffer, ones are used for all the parameters so that the
// minimum amount of memory is allocated.
//
// In a multi-instance marian setting (as a dynamic library),
// different marian instances should not share this variable.
static thread_local PackedGemmMatrixFP16 packedPlaceholder(1, 1, 1, 1, 1, 1, 1, 1);
// Address arithmetic for the blocked row-major packed format.
// This mirrors code from FBGEMM and is meant to be removed once the FBGEMM API changes.
/**
 * Maps an element of the original (non-packed) matrix to its index in the packed,
 * block-formatted array.
 *
 * The matrix is split into blocks of brow_ x bcol_ elements; blocks in the last block row
 * hold last_brow_ rows instead of brow_.
 *
 * @param r_ row index in the original matrix
 * @param c_ column index in the original matrix
 * @param brow_ number of rows in a block
 * @param bcol_ number of columns in a block
 * @param nbrow_ number of block rows
 * @param nbcol_ number of block columns
 * @param last_brow_ number of rows in the blocks of the last block row
 * @return element index into the packed array
 */
inline uint64_t addr(const int r_,
                     const int c_,
                     const int brow_,
                     const int bcol_,
                     const int nbrow_,
                     const int nbcol_,
                     const int last_brow_) {
  const uint64_t row = (uint64_t)r_;
  const uint64_t col = (uint64_t)c_;

  // Which block the element falls into.
  const uint64_t blockRow = row / brow_;
  const uint64_t blockCol = col / bcol_;

  // Number of elements per block within this block row: the last block row is shorter.
  int elemsPerBlock;
  if(blockRow != nbrow_ - 1) {
    elemsPerBlock = brow_ * bcol_;
  } else {
    elemsPerBlock = last_brow_ * bcol_;
  }

  // Elements in all preceding (full-height) block rows.
  const uint64_t precedingBlockRows = (blockRow * nbcol_) * (brow_ * bcol_);
  // Elements in the preceding blocks of the same block row.
  const uint64_t precedingBlocks = blockCol * elemsPerBlock;
  // Row-major position inside the block itself.
  const uint64_t withinBlock = row % brow_ * bcol_ + col % bcol_;

  return precedingBlockRows + precedingBlocks + withinBlock;
}
// Packs the fp32 matrix `in` into the blocked fp16 layout and stores it in `out`.
//
// Layout written to the `out` buffer (byte offsets):
//   [0, 8)     packsize, as uint64_t
//   [8, 40)    eight int32_t values: nrow, ncol, kernel_ncol_blocks, brow, bcol,
//              last_brow, nbrow, nbcol
//   [256, ...) the packed fp16 elements, positioned via addr()
// GemmPackFp32() reads the same fields back from these offsets.
//
// out:                destination tensor; packsize bytes of it are zeroed and then filled
// in:                 source fp32 matrix
// transpose:          if false, element (i, j) is read from in[i * ncol + j];
//                     if true, it is read from in[i + nrow * j]
// nrow, ncol:         dimensions of the matrix being packed
// kernel_ncol_blocks, brow, bcol, last_brow, nbrow, nbcol:
//                     block-format parameters, stored in the header and used for addressing
// packsize:           total size in bytes of the packed buffer
void PackFp32(marian::Tensor out,
              const marian::Tensor in,
              const bool transpose,
              const int nrow,
              const int ncol,
              const int kernel_ncol_blocks,
              const int brow,
              const int bcol,
              const int last_brow,
              const int nbrow,
              const int nbcol,
              const uint64_t packsize) {
  // Zero the whole buffer in one call. The previous hand-written loop used an `int`
  // counter against the 64-bit packsize, which overflows for buffers above 2 GB.
  uint8_t* outmemorg = out->data<uint8_t>();
  memset(outmemorg, 0, packsize);

  // The first 8 bytes hold the total size of the packed memory.
  uint64_t* auxmemsize = (uint64_t*)outmemorg;
  auxmemsize[0] = packsize;

  // The FBGEMM-related parameters follow directly after the size field.
  int32_t header[8];
  header[0] = nrow;
  header[1] = ncol;
  header[2] = kernel_ncol_blocks;
  header[3] = brow;
  header[4] = bcol;
  header[5] = last_brow;
  header[6] = nbrow;
  header[7] = nbcol;
  memcpy(auxmemsize + 1, header, sizeof(header));

  // The fp16 payload starts at a fixed offset of 256 bytes.
  fbgemm::float16* outmem = (fbgemm::float16*)(outmemorg + 256);

  // tconv() takes a float16 argument only to select the target type; a stack variable
  // is sufficient, no heap allocation needed.
  fbgemm::float16 halfTag = 0;

  // Convert each element to fp16 and write it to its position in the blocked layout.
  const float* inmem = in->data();
  for(int i = 0; i < nrow; i++) {
    for(int j = 0; j < ncol; j++) {
      const float value = !transpose ? inmem[i * ncol + j] : inmem[i + nrow * j];
      outmem[addr(i, j, brow, bcol, nbrow, nbcol, last_brow)] = tconv(value, halfTag);
    }
  }
}
// GEMM operation on the packed B matrix.
// Every row of C is first set to bias; cblas_gemm_compute is then called on C with
// 1 as its scaling argument, so that the product with the packed B is added on top.
//
// C: output matrix
// A: A matrix
// B: B matrix (packed by PackFp32; header fields and fp16 payload are read back from it)
// bias: bias vector with n elements, copied into every row of C
// m: the number of rows in A and C
// n: the number of columns in B and C
// transA: transpose of A matrix
// B is already packed. So, we don't need transB
void GemmPackFp32(marian::Tensor C,
                  const marian::Tensor A,
                  const marian::Tensor B,
                  const marian::Tensor bias,
                  const size_t m,
                  const size_t n,
                  const int transA) {
  // row major

  // packedPlaceholder is thread_local. Bind the calling thread's instance to a reference
  // here, so that the OpenMP worker threads below all use this patched instance instead of
  // each resolving the name to its own, unpatched thread-local dummy.
  PackedGemmMatrixFP16& packedB = packedPlaceholder;

  // keep the original mem
  fbgemm::float16* pmat = packedB.pmat_;

  // retrieve aux fields from the memory
  uint64_t* packedmemSize = (uint64_t*)B->data();
  packedB.size_ = packedmemSize[0];
  int32_t header[8];
  memcpy(header, packedmemSize + 1, sizeof(header));
  packedB.nrow_ = header[0];
  packedB.ncol_ = header[1];
  packedB.kernel_ncol_blocks_ = header[2];
  packedB.brow_ = header[3];
  packedB.bcol_ = header[4];
  packedB.last_brow_ = header[5];
  packedB.nbrow_ = header[6];
  packedB.nbcol_ = header[7];

  // packed matrix: the fp16 payload starts 256 bytes into the buffer
  packedB.pmat_ = (fbgemm::float16*)(B->data<uint8_t>() + 256);

  // Initialize every row of C with the bias.
#if MKL_FOUND
  for(size_t i = 0; i < m; ++i) {
    mkl_somatcopy('R', 'N', 1, n, 1, bias->data(), n, C->data() + n * i, n);
  }
#else
  for(size_t i = 0; i < m; ++i) {
    std::copy(bias->data(), bias->data() + n, C->data() + n * i);
  }
#endif

#ifdef _OPENMP
#pragma omp parallel
#endif
  {
#ifdef _OPENMP
    int num_threads = omp_get_num_threads();
    int tid = omp_get_thread_num();
#else
    int num_threads = 1;
    int tid = 0;
#endif
    fbgemm::cblas_gemm_compute(transA ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
                               (int)m,
                               A->data(),
                               packedB,
                               1,
                               C->data(),
                               tid,
                               num_threads);
  }

  // return back the original mem
  packedB.pmat_ = pmat;
}
#else // USE_FBGEMM
// Stub used when USE_FBGEMM is not set: packed GEMM is only supported through FBGEMM,
// so this always aborts.
// Parameter names are commented out because none of them is used; this avoids
// unused-parameter warnings, which become errors in builds that treat warnings as errors.
void PackFp32(marian::Tensor /*out*/,
              const marian::Tensor /*in*/,
              const bool /*transpose*/,
              const int /*nrow*/,
              const int /*ncol*/,
              const int /*kernel_ncol_blocks*/,
              const int /*brow*/,
              const int /*bcol*/,
              const int /*last_brow*/,
              const int /*nbrow*/,
              const int /*nbcol*/,
              const uint64_t /*packsize*/) {
  // does nothing. supports only FBGEMM based packed gemm at this moment.
  ABORT("FBGEMM is needed to use packed GEMM.");
}
// Stub used when USE_FBGEMM is not set: packed GEMM is only supported through FBGEMM,
// so this always aborts.
// Parameter names are commented out because none of them is used; this avoids
// unused-parameter warnings, which become errors in builds that treat warnings as errors.
void GemmPackFp32(marian::Tensor /*C*/,
                  const marian::Tensor /*A*/,
                  const marian::Tensor /*B*/,
                  const marian::Tensor /*bias*/,
                  const size_t /*m*/,
                  const size_t /*n*/,
                  const int /*transA*/) {
  // does nothing. supports only FBGEMM based packed gemm at this moment.
  ABORT("FBGEMM is needed to use packed GEMM.");
}
#endif // USE_FBGEMM
} // namespace variant
} // namespace cpu
} // namespace marian

View File

@ -0,0 +1,54 @@
#pragma once
#include "tensors/tensor.h"
namespace marian {
namespace cpu {
namespace variant { // Variants of GEMM implementations
// Pack a matrix into a cache-efficient blocked format.
// out: output tensor - packed format
// in: input tensor - normal format
// transpose: the matrix is transposed
// nrow: the number of rows
// ncol: the number of columns
// kernel_ncol_blocks: the number of column blocks
//   NOTE(review): how this differs from nbcol is not visible here - confirm against FBGEMM.
// brow: the number of rows in a block
// bcol: the number of columns in a block
// last_brow: the number of rows in the last block
// nbrow: the number of block rows (blocks along the row dimension)
// nbcol: the number of block columns (blocks along the column dimension)
// packsize: the size of the packed matrix in bytes
// (the number of fp16 elements + padding (1024) + extra temporary memory (256))
void PackFp32(marian::Tensor out,
const marian::Tensor in,
const bool transpose,
const int nrow,
const int ncol,
const int kernel_ncol_blocks,
const int brow,
const int bcol,
const int last_brow,
const int nbrow,
const int nbcol,
const uint64_t packsize);
// GEMM operation on the packed B matrix
// C: output matrix
// A: A matrix
// B: B matrix (packed, as produced by PackFp32)
// bias: bias tensor; its first n values are copied into every row of C before the GEMM
// m: the number of rows in A and C
// n: the number of columns in B and C
// transA: transpose of A matrix (non-zero means transposed; defaults to 0)
// B is already packed. So, we don't need transB
void GemmPackFp32(marian::Tensor C,
const marian::Tensor A,
const marian::Tensor B,
const marian::Tensor bias,
const size_t m,
const size_t n,
const int transA = 0);
} // namespace variant
} // namespace cpu
} // namespace marian

View File

@ -10,6 +10,10 @@
#include "functional/functional.h"
#include "functional/tensor.h"
#if MKL_FOUND
#include <mkl.h>
#endif
namespace marian {
namespace cpu {
@ -186,6 +190,73 @@ void Transpose0213(Tensor out, Tensor in) {
}
}
// Given a 4D array, transpose (swap) the initial 3 dimensions while keeping the last dimension.
// e.g. 1234 --> 2134, 1234 --> 3214 (4 is always kept).
// This is an optimized version for swapping the first 3 dimensions,
// assuming the last dimension is large enough to benefit from a vectorized copy.
// Each innermost row is either copied to (add == false) or accumulated into (add == true)
// its place in the output.
// Only implemented when MKL is available; otherwise the function aborts.
//
// @param out output tensor
// @param in input tensor
// @param vAxis target (transposed) axes of each given axis; must have exactly 4 entries
template <bool add>
void TransposeFirst3In4(Tensor out, Tensor in, const std::vector<int>& vAxis) {
  ABORT_IF(vAxis.size() != 4, "This function handles only 4D arrays.");
#if MKL_FOUND
  const int innermost = in->shape()[-1];

  // Extents of the three leading output dimensions.
  const int l1 = in->shape()[vAxis[0]];
  const int l2 = in->shape()[vAxis[1]];
  const int l3 = in->shape()[vAxis[2]];

  // Loop-invariant lookups, hoisted out of the parallel loop.
  const int inDim1 = in->shape()[1];
  const int inDim2 = in->shape()[2];
  const float* inData = in->data();
  float* outData = out->data();

#pragma omp parallel for
  for(int k = 0; k < l1; ++k) {
    int shift = k * l2 * l3;
    for(int j = 0; j < l2; ++j) {
      for(int i = 0; i < l3; ++i) {
        // find the mapping between the transposed output dimensional indices (oi, oj, ok)
        // and original input dimensional indices (i, j, k).
        // These must be declared inside the parallel loop: variables declared outside of it
        // are shared between the OpenMP threads, which would be a data race.
        int oi, oj, ok;
        if(vAxis[0] == 0) {
          if(vAxis[1] == 1) {
            oi = i; oj = j; ok = k;
          } else {
            oi = j; oj = i; ok = k;
          }
        } else if(vAxis[0] == 1) {
          if(vAxis[1] == 0) {
            oi = i; oj = k; ok = j;
          } else {
            oi = j; oj = k; ok = i;
          }
        } else {
          if(vAxis[1] == 0) {
            oi = k; oj = i; ok = j;
          } else {
            oi = k; oj = j; ok = i;
          }
        }

        // Row indices (in units of `innermost` floats) in the input and output tensors.
        int src = ok * inDim1 * inDim2 + oj * inDim2 + oi;
        int dst = l3 * j + shift + i;

        // Compute the element offsets in size_t to avoid int overflow on large tensors.
        const float* inRow = inData + (size_t)src * innermost;
        float* outRow = outData + (size_t)dst * innermost;

        if(!add) {
          mkl_somatcopy('R', 'N', 1, innermost, 1.0f, inRow, innermost, outRow, innermost);
        } else {
          for(int ii = 0; ii < innermost; ++ii) {
            outRow[ii] += inRow[ii];
          }
        }
      }
    }
  }
#else
  // it shouldn't come into here. This function is called only when MKL is available.
  ABORT("Should not get here");
#endif  // MKL_FOUND
}
inline void transpose4x4_SSE(const float* A,
float* B,
const int lda,
@ -262,6 +333,10 @@ void TransposeGeneric(Tensor out, Tensor in, const std::vector<int>& vAxis) {
void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis) {
if(vAxis == std::vector<int>({0, 2, 1, 3}))
Transpose0213<false>(out, in);
#if MKL_FOUND
else if(vAxis.size() == 4 && vAxis[3] == 3)
TransposeFirst3In4<false>(out, in, vAxis);
#endif // MKL_FOUND
else if(vAxis == std::vector<int>({1, 0}) && in->shape()[-1] % 16 == 0
&& in->shape()[-2] % 16 == 0)
Transpose10(out, in);

View File

@ -21,7 +21,7 @@ protected:
public:
Device(DeviceId deviceId, size_t alignment = 256)
: deviceId_(deviceId), data_(0), size_(0), alignment_(alignment) {}
: deviceId_(deviceId), alignment_(alignment) {}
virtual ~Device(){};

View File

@ -1,13 +1,14 @@
#pragma once
#include "common/config.h"
#include "tensors/backend.h" // note: this is one folder up
#include "tensors/backend.h" // note: this is one folder up
#include "tensors/gpu/cuda_helpers.h"
#include "common/logging.h"
#include <cuda.h>
#include <cublas_v2.h>
#include <cusparse.h>
#include <cuda.h>
#include <curand.h>
#include <cusparse.h>
namespace marian {
namespace gpu {
@ -33,6 +34,24 @@ public:
// Returns the cuBLAS handle stored in this backend.
cublasHandle_t getCublasHandle() { return cublasHandle_; }
// Returns the cuSPARSE handle stored in this backend.
cusparseHandle_t getCusparseHandle() { return cusparseHandle_; }
// On the CPU backend this selects optimized code for inference.
// On the GPU backend the setting is not supported: the requested value is not stored,
// only an info-level message is logged via LOG_ONCE. isOptimized() always returns false.
void setOptimized(bool optimize) override {
LOG_ONCE(info, "setOptimized() not supported for GPU_{}", optimize);
}
// Always returns false on the GPU backend, regardless of earlier setOptimized() calls.
bool isOptimized() override {
return false;
}
// On the CPU backend this selects between different GEMM types for inference.
// The GPU backend has no GEMM type: the requested value is not stored, only an
// info-level message is logged via LOG_ONCE.
void setGemmType(std::string gemmType) override {
LOG_ONCE(info, "setGemmType() not supported for GPU_{}", gemmType);
}
// The GPU backend has no GEMM type: logs an info-level message via LOG_ONCE and
// always returns GemmType::Auto.
GemmType getGemmType() override {
LOG_ONCE(info, "getGemmType() not supported for GPU");
return GemmType::Auto;
}
private:
cublasHandle_t cublasHandle_;
cusparseHandle_t cusparseHandle_;

View File

@ -16,7 +16,7 @@ foreach(test ${APP_TESTS})
add_executable("test_${test}" "${test}.cpp")
if(CUDA_FOUND)
target_link_libraries("test_${test}" marian marian_cuda ${EXT_LIBS})
target_link_libraries("test_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS})
else(CUDA_FOUND)
target_link_libraries("test_${test}" marian ${EXT_LIBS})
endif(CUDA_FOUND)

View File

@ -5,8 +5,9 @@ int main(int argc, char** argv) {
using namespace marian;
{
auto g = New<ExpressionGraph>(true, false);
auto g = New<ExpressionGraph>(true);
g->setDevice({0, DeviceType::cpu});
g->getBackend()->setOptimized(false);
g->reserveWorkspaceMB(2512);
timer::AutoTimer timer;
@ -37,8 +38,10 @@ int main(int argc, char** argv) {
}
{
auto g = New<ExpressionGraph>(true, true);
auto g = New<ExpressionGraph>(true);
g->setDevice({0, DeviceType::cpu});
g->getBackend()->setOptimized(true);
g->getBackend()->setGemmType("auto");
g->reserveWorkspaceMB(2512);
timer::AutoTimer timer;

View File

@ -10,7 +10,7 @@ foreach(test ${UNIT_TESTS})
add_executable("run_${test}" run_tests.cpp "${test}.cpp")
if(CUDA_FOUND)
target_link_libraries("run_${test}" marian marian_cuda ${EXT_LIBS} Catch)
target_link_libraries("run_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS} Catch)
else(CUDA_FOUND)
target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch)
endif(CUDA_FOUND)

View File

@ -55,9 +55,13 @@ public:
size_t id = 0;
for(auto device : devices) {
auto task = [&](DeviceId device, size_t id) {
auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
auto graph = New<ExpressionGraph>(true);
graph->setDevice(device);
graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
if (device.type == DeviceType::cpu) {
graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
}
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_[id] = graph;
@ -167,9 +171,13 @@ public:
// initialize scorers
for(auto device : devices) {
auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
auto graph = New<ExpressionGraph>(true);
graph->setDevice(device);
graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
if (device.type == DeviceType::cpu) {
graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
}
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);

View File

@ -49,7 +49,7 @@
<LinkIncremental>false</LinkIncremental>
<ExecutablePath>$(ExecutablePath)</ExecutablePath>
<IntDir>$(SolutionDir)$(Platform)\$(Configuration)\Marian\</IntDir>
<IncludePath>%CUDA_PATH%\include;..\src;..\src\3rd_party;%BOOST_INCLUDE_PATH%;%ZLIB_PATH%\include;%MKL_PATH%\include;$(VC_IncludePath);$(WindowsSDK_IncludePath);</IncludePath>
<IncludePath>..\src\3rd_party\fbgemm\third_party\cpuinfo\deps\clog\include;..\src\3rd_party\fbgemm\third_party\cpuinfo\src;..\src\3rd_party\fbgemm\third_party\cpuinfo\include;..\src\3rd_party\fbgemm\third_party\asmjit\src;%MKL_PATH%\include;..\src\3rd_party\fbgemm\include;%CUDA_PATH%\include;..\src;..\src\3rd_party;%BOOST_INCLUDE_PATH%;%ZLIB_PATH%\include;$(VC_IncludePath);$(WindowsSDK_IncludePath);</IncludePath>
<LibraryPath>%CUDA_PATH%\lib\x64;%BOOST_LIB_PATH%;%ZLIB_PATH%\lib;%MKL_PATH%\lib\intel64;$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64);$(NETFXKitsDir)Lib\um\x64</LibraryPath>
</PropertyGroup>
<ItemDefinitionGroup>
@ -69,21 +69,22 @@
</PrecompiledHeader>
<WarningLevel>Level4</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>CUDA_FOUND=1; MKL_FOUND=1; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>FBGEMM_EXPORTS;USE_FBGEMM=1;ASMJIT_VARAPI;CUDA_FOUND=1; MKL_FOUND=1; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>false</SDLCheck>
<TreatWarningAsError>true</TreatWarningAsError>
<AdditionalOptions>/bigobj %(AdditionalOptions)</AdditionalOptions>
<AdditionalOptions>/bigobj %(AdditionalOptions) /arch:AVX2</AdditionalOptions>
<RuntimeLibrary Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">MultiThreadedDebugDLL</RuntimeLibrary>
<DisableSpecificWarnings>4996; 4702</DisableSpecificWarnings>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
<MinimalRebuild>false</MinimalRebuild>
<ObjectFileName>$(IntDir)%(RelativeDir)</ObjectFileName>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>cudart_static.lib;cublas.lib;cusparse.lib;curand.lib;zlib.lib;msmpi.lib;mkl_intel_ilp64.lib;mkl_sequential.lib;mkl_core.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;shlwapi.lib;%(AdditionalDependencies)</AdditionalDependencies>
<StackReserveSize>100000000</StackReserveSize>
<TreatLinkerWarningAsErrors>true</TreatLinkerWarningAsErrors>
<TreatLinkerWarningAsErrors>false</TreatLinkerWarningAsErrors>
</Link>
<CudaCompile>
<Include>$(SolutionDir)..\src\;$(SolutionDir)..\src\3rd_party</Include>
@ -104,10 +105,10 @@
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>CUDA_FOUND=1; MKL_FOUND=1; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>FBGEMM_EXPORTS;USE_FBGEMM=1;CUDA_FOUND=1; MKL_FOUND=1; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>false</SDLCheck>
<FavorSizeOrSpeed>Speed</FavorSizeOrSpeed>
<AdditionalOptions>/d2Zi+ /bigobj %(AdditionalOptions)</AdditionalOptions>
<AdditionalOptions>/d2Zi+ /bigobj %(AdditionalOptions) /arch:AVX2</AdditionalOptions>
<TreatWarningAsError>true</TreatWarningAsError>
<RuntimeLibrary Condition="'$(Configuration)|$(Platform)'=='Release|x64'">MultiThreadedDLL</RuntimeLibrary>
<RuntimeLibrary Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">MultiThreaded</RuntimeLibrary>
@ -115,6 +116,7 @@
<OmitFramePointers>true</OmitFramePointers>
<DisableSpecificWarnings>4996; 4702</DisableSpecificWarnings>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
<ObjectFileName>$(IntDir)%(RelativeDir)</ObjectFileName>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
@ -123,7 +125,7 @@
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>cudart_static.lib;cublas.lib;cusparse.lib;curand.lib;zlib.lib;msmpi.lib;mkl_intel_ilp64.lib;mkl_sequential.lib;mkl_core.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;shlwapi.lib;%(AdditionalDependencies)</AdditionalDependencies>
<StackReserveSize>100000000</StackReserveSize>
<TreatLinkerWarningAsErrors>true</TreatLinkerWarningAsErrors>
<TreatLinkerWarningAsErrors>false</TreatLinkerWarningAsErrors>
</Link>
<CudaCompile>
<Include>$(SolutionDir)..\src\;$(SolutionDir)..\src\3rd_party</Include>
@ -136,6 +138,294 @@
</ItemDefinitionGroup>
<ItemGroup>
<ClCompile Include="..\src\3rd_party\ExceptionWithCallStack.cpp" />
<ClCompile Include="..\src\3rd_party\fbgemm\src\ExecuteKernel.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\ExecuteKernelU8S8.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\Fbgemm.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmConv.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmFP16.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmFP16UKernelsAvx2.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmI8DepthwiseAvx2.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmI8Spmdm.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC16.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC16Avx512.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC32.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC32Avx512.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\GroupwiseConvAcc32Avx2.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\OptimizedKernelsAvx2.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackAMatrix.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithIm2Col.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithQuantRowOffset.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithRowOffset.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackBMatrix.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackMatrix.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackWeightMatrixForGConv.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackWeightsForConv.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\QuantUtils.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\QuantUtilsAvx2.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\RefImplementations.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\Utils.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\UtilsAvx2.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\UtilsAvx512.cc">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\arch.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\assembler.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codebuilder.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codecompiler.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codeemitter.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codeholder.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<!-- asmjit "base" sources bundled under fbgemm\third_party.
     Every entry sets TreatWarningAsError to false for both Release|x64 and
     Debug|x64, so warnings raised by this third-party code are not promoted
     to build errors in either configuration. -->
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\constpool.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\cpuinfo.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\func.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\globals.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\inst.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\logging.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\operand.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\osutils.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\regalloc.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\runtime.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\string.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\utils.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\vmem.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\zone.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<!-- asmjit x86 backend sources bundled under fbgemm\third_party.
     Each file opts out of warnings-as-errors for Release|x64 and Debug|x64,
     the only two configurations these overrides are conditioned on. -->
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86assembler.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86builder.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86compiler.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86inst.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instimpl.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86internal.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86logging.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86operand.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86operand_regs.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86regalloc.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<!-- C sources of cpuinfo and its clog dependency, both bundled under
     fbgemm\third_party. Only the generic and x86 parts are listed, including
     the x86\windows initialisation file. As with the other third-party code,
     warnings-as-errors is switched off per file for Release|x64 and Debug|x64. -->
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\deps\clog\src\clog.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\api.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\init.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\descriptor.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\deterministic.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\init.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\info.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\init.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\isa.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\name.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\topology.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\uarch.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\vendor.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows\init.c">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<!-- pathie-cpp sources carry no per-file metadata, so they build with the
     project-level compiler settings. -->
<ClCompile Include="..\src\3rd_party\pathie-cpp\src\entry_iterator.cpp" />
<ClCompile Include="..\src\3rd_party\pathie-cpp\src\errors.cpp" />
<ClCompile Include="..\src\3rd_party\pathie-cpp\src\path.cpp" />
@ -346,6 +636,79 @@
<!-- Last two yaml-cpp sources, followed by header registrations. ClInclude
     items make headers visible in the project; they are not compiled on
     their own. -->
<ClCompile Include="..\src\3rd_party\yaml-cpp\binary_renamed.cpp" />
<ClCompile Include="..\src\3rd_party\yaml-cpp\yaml-node.cpp" />
<ClInclude Include="..\src\3rd_party\ExceptionWithCallStack.h" />
<!-- fbgemm headers under include\fbgemm. -->
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\ConvUtils.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Fbgemm.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmBuild.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmFP16.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmI8DepthwiseAvx2.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmI8Spmdm.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\OutputProcessing-inl.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\PackingTraits-inl.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\QuantUtils.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\QuantUtilsAvx2.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Types.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Utils.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\UtilsAvx2.h" />
<!-- fbgemm headers that live next to the implementation in src\. -->
<ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernel.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernelGeneric.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernelU8S8.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\src\FbgemmFP16UKernelsAvx2.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\src\GenerateKernel.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\src\GroupwiseConv.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\src\OptimizedKernelsAvx2.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\src\RefImplementations.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\src\TransposeUtils.h" />
<!-- asmjit headers bundled under fbgemm\third_party: top-level umbrella
     headers first, then base\, then the x86 backend. Note that arm.h is
     listed although no ARM sources are compiled by this project section. -->
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\arm.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\asmjit.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\asmjit_apibegin.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\asmjit_apiend.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\asmjit_build.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\arch.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\assembler.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codebuilder.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codecompiler.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codeemitter.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codeholder.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\constpool.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\cpuinfo.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\func.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\globals.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\inst.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\logging.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\misc_p.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\operand.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\osutils.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\regalloc_p.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\runtime.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\simdtypes.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\string.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\utils.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\vmem.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\zone.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86assembler.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86builder.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86compiler.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86emitter.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86globals.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86inst.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instimpl_p.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86internal_p.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86logging_p.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86misc.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86operand.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86regalloc_p.h" />
<!-- Headers of cpuinfo and its clog dependency (bundled under
     fbgemm\third_party), limited to the generic and x86/windows parts. -->
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\deps\clog\include\clog.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\include\cpuinfo-mock.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\include\cpuinfo.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\common.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\internal-api.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\log.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\utils.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\api.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cpuid.h" />
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows\api.h" />
<!-- The nccl collectives header is excluded from the Release|x64 build;
     no exclusion for other configurations is visible on this item. -->
<ClInclude Include="..\src\3rd_party\nccl\src\collectives\collectives.h">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</ClInclude>
@ -612,6 +975,10 @@
<!-- CPU tensor backend sources. packed_gemm.cpp is the only file in this
     group that disables warnings-as-errors (Release|x64 and Debug|x64).
     NOTE(review): presumably because it pulls in FBGEMM headers that do not
     compile warning-free; confirm against the source file. -->
<ClCompile Include="..\src\tensors\cpu\prod.cpp" />
<ClCompile Include="..\src\tensors\cpu\sharp\avx_gemm.cpp" />
<ClCompile Include="..\src\tensors\cpu\sharp\int_gemm.cpp" />
<ClCompile Include="..\src\tensors\cpu\sharp\packed_gemm.cpp">
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
<TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
</ClCompile>
<ClCompile Include="..\src\tensors\cpu\sharp\sse_gemm.cpp" />
<ClCompile Include="..\src\tensors\cpu\tensor_operators.cpp" />
<ClCompile Include="..\src\graph\expression_graph.cpp" />
@ -951,7 +1318,9 @@
<!-- Header registrations for rnn\ and tensors\. expanded_gemm.h and
     sharp\packed_gemm.h are the headers that accompany the packed GEMM
     source added to the compile list. -->
<ClInclude Include="..\src\rnn\types.h" />
<ClInclude Include="..\src\tensors\allocator.h" />
<ClInclude Include="..\src\tensors\backend.h" />
<ClInclude Include="..\src\tensors\cpu\expanded_gemm.h" />
<ClInclude Include="..\src\tensors\cpu\sharp\int_gemm.h" />
<ClInclude Include="..\src\tensors\cpu\sharp\packed_gemm.h" />
<ClInclude Include="..\src\tensors\device.h" />
<ClInclude Include="..\src\tensors\dispatch.h" />
<ClInclude Include="..\src\tensors\gpu\add.h" />

View File

@ -490,6 +490,225 @@
<!-- Filter assignments: each <Filter> names the virtual folder under which
     Visual Studio shows the file in Solution Explorer. These entries only
     affect how the project tree is displayed, not how files are compiled. -->
<ClCompile Include="..\src\tensors\gpu\prod.cpp">
<Filter>tensors\gpu</Filter>
</ClCompile>
<ClCompile Include="..\src\tensors\cpu\sharp\packed_gemm.cpp">
<Filter>tensors\cpu\sharp</Filter>
</ClCompile>
<!-- Filter assignments for the fbgemm library sources: every .cc file from
     fbgemm\src is shown under the 3rd_party\fbgemm\src virtual folder,
     mirroring the on-disk layout. -->
<ClCompile Include="..\src\3rd_party\fbgemm\src\ExecuteKernel.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\ExecuteKernelU8S8.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\Fbgemm.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmConv.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmFP16.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmFP16UKernelsAvx2.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmI8DepthwiseAvx2.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmI8Spmdm.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC16.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC16Avx512.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC32.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC32Avx512.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\GroupwiseConvAcc32Avx2.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\OptimizedKernelsAvx2.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackAMatrix.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithIm2Col.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithQuantRowOffset.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithRowOffset.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackBMatrix.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackMatrix.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackWeightMatrixForGConv.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\PackWeightsForConv.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\QuantUtils.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\QuantUtilsAvx2.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\RefImplementations.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\Utils.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\UtilsAvx2.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\src\UtilsAvx512.cc">
<Filter>3rd_party\fbgemm\src</Filter>
</ClCompile>
<!-- Filter assignments for the cpuinfo C sources. The virtual folder of each
     file mirrors its directory below 3rd_party\fbgemm\third_party\cpuinfo\src
     (generic, x86, and x86\windows). -->
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\api.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\init.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\info.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\init.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\isa.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\name.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\topology.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\uarch.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\vendor.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows\init.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows</Filter>
</ClCompile>
<!-- Filter assignments for the cpuinfo x86 cache-detection sources.
     NOTE(review): the filter name ends in "cacehe" while the files live in
     ...\x86\cache\, so this looks like a typo for "cache". The name must
     match a <Filter Include="..."> declaration that is not visible in this
     hunk; rename the declaration and these three references together,
     otherwise the files lose their folder in Solution Explorer. -->
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\descriptor.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86\cacehe</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\deterministic.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86\cacehe</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\init.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86\cacehe</Filter>
</ClCompile>
<!-- Filter assignments for the asmjit x86 backend sources; all are grouped
     under one virtual folder that mirrors the asmjit\x86 directory. -->
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86assembler.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86builder.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86compiler.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86inst.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instimpl.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86internal.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86logging.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86operand.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86operand_regs.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86regalloc.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClCompile>
<!-- Filter assignments for the asmjit base sources, grouped under a virtual
     folder that mirrors the asmjit\base directory. -->
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\arch.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\assembler.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codebuilder.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codecompiler.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codeemitter.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codeholder.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\constpool.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\cpuinfo.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\func.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\globals.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\inst.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\logging.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\operand.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\osutils.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\regalloc.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\runtime.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\string.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\utils.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\vmem.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\zone.cpp">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClCompile>
<!-- clog is a dependency shipped inside cpuinfo (cpuinfo\deps\clog). -->
<ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\deps\clog\src\clog.c">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\deps\clog\src</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="..\src\marian.h" />
@ -1555,6 +1774,231 @@
<ClInclude Include="..\src\tensors\gpu\add.inc">
<Filter>tensors\gpu</Filter>
</ClInclude>
<ClInclude Include="..\src\tensors\cpu\expanded_gemm.h">
<Filter>tensors\cpu</Filter>
</ClInclude>
<ClInclude Include="..\src\tensors\cpu\sharp\packed_gemm.h">
<Filter>tensors\cpu\sharp</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\ConvUtils.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Fbgemm.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmBuild.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmFP16.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmI8DepthwiseAvx2.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmI8Spmdm.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\OutputProcessing-inl.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\PackingTraits-inl.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\QuantUtils.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\QuantUtilsAvx2.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Types.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Utils.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\UtilsAvx2.h">
<Filter>3rd_party\fbgemm\include\fbgemm</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernel.h">
<Filter>3rd_party\fbgemm\src</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernelGeneric.h">
<Filter>3rd_party\fbgemm\src</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernelU8S8.h">
<Filter>3rd_party\fbgemm\src</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\src\FbgemmFP16UKernelsAvx2.h">
<Filter>3rd_party\fbgemm\src</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\src\GenerateKernel.h">
<Filter>3rd_party\fbgemm\src</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\src\GroupwiseConv.h">
<Filter>3rd_party\fbgemm\src</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\src\OptimizedKernelsAvx2.h">
<Filter>3rd_party\fbgemm\src</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\src\RefImplementations.h">
<Filter>3rd_party\fbgemm\src</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\src\TransposeUtils.h">
<Filter>3rd_party\fbgemm\src</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\include\cpuinfo.h">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\include</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\include\cpuinfo-mock.h">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\include</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\api.h">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cpuid.h">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows\api.h">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\common.h">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\internal-api.h">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\log.h">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\utils.h">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\arm.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\asmjit.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\asmjit_apibegin.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\asmjit_apiend.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\asmjit_build.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86assembler.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86builder.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86compiler.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86emitter.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86globals.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86inst.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instimpl_p.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86internal_p.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86logging_p.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86misc.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86operand.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86regalloc_p.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\arch.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\assembler.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codebuilder.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codecompiler.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codeemitter.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\codeholder.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\constpool.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\cpuinfo.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\func.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\globals.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\inst.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\logging.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\misc_p.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\operand.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\osutils.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\regalloc_p.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\runtime.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\simdtypes.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\string.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\utils.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\vmem.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\base\zone.h">
<Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\base</Filter>
</ClInclude>
<ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\deps\clog\include\clog.h">
<Filter>3rd_party\fbgemm\third_party\cpuinfo\deps\clog\include</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="3rd_party">
@ -1722,6 +2166,69 @@
<Filter Include="tests">
<UniqueIdentifier>{a86d650a-2268-43d9-9d74-cb17cd6b534b}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm">
<UniqueIdentifier>{4bb88f6d-7ddf-41e0-91be-a43dbcd0e9b0}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\include">
<UniqueIdentifier>{6c2bef00-97a0-4881-a6f0-ded54b8520bf}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\include\fbgemm">
<UniqueIdentifier>{95f7ce7c-c649-4d57-8d2a-d724bd75fe84}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\src">
<UniqueIdentifier>{41f7fbeb-2a73-4747-800c-46307cd0b52b}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party">
<UniqueIdentifier>{dc2722bc-af78-4923-82cb-9a09cb290fbf}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\asmjit">
<UniqueIdentifier>{577ae810-9593-423d-a398-0787252022b4}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\cpuinfo">
<UniqueIdentifier>{f97ae984-fe9a-45f6-a3f4-af90875209ba}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\cpuinfo\include">
<UniqueIdentifier>{4e7efd32-ec9d-4a1f-b454-656ba5c03275}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\cpuinfo\src">
<UniqueIdentifier>{15b8bcc0-2a07-4d39-8e03-18daa0c33d09}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\cpuinfo\src\x86">
<UniqueIdentifier>{ffd4cf44-177f-47a2-870a-438df9ca3be4}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\cpuinfo\src\x86\cacehe">
<UniqueIdentifier>{b600923b-21c1-492a-bfd9-0aa1082ebcd7}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows">
<UniqueIdentifier>{79535a0d-1cdc-45a9-89fb-e9c5794ddff5}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo">
<UniqueIdentifier>{5709c1ff-41f9-4f83-badb-a7a7c98c1fae}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\asmjit\src">
<UniqueIdentifier>{a35aa317-6132-4c31-8f9a-8ec68a4b1c39}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\asmjit\src\asmjit">
<UniqueIdentifier>{fc12d7c4-41df-48c0-9017-e8f4d7538cf8}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86">
<UniqueIdentifier>{5818c959-7963-4d8e-9e87-b61f340476c2}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\asmjit\src\asmjit\base">
<UniqueIdentifier>{15414ec0-8761-4068-afef-822b7bed88df}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\cpuinfo\deps">
<UniqueIdentifier>{d4505c8d-5e6e-4baf-8525-dc59ae8b6415}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\cpuinfo\deps\clog">
<UniqueIdentifier>{fb9777f1-6887-4286-a58c-0956b356a815}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\cpuinfo\deps\clog\include">
<UniqueIdentifier>{17125bd0-f21b-4e95-a922-690f5665e9b6}</UniqueIdentifier>
</Filter>
<Filter Include="3rd_party\fbgemm\third_party\cpuinfo\deps\clog\src">
<UniqueIdentifier>{8fd74b1e-d3c1-4158-ad46-4a447222934e}</UniqueIdentifier>
</Filter>
</ItemGroup>
<ItemGroup>
<None Include="..\src\3rd_party\nccl\src\bootstrap.cu">