mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 14402: Sync with public marian-dev master 1.9.31
This simply pulls the recent updates from the public repo.
This commit is contained in:
parent
9001477147
commit
080d75ad59
51
.github/workflows/build-macos-10.15-cpu.yml
vendored
Normal file
51
.github/workflows/build-macos-10.15-cpu.yml
vendored
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
name: macos-10.5-cpu
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ master ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ master ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
|
||||||
|
runs-on: macos-10.15
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: brew install openblas protobuf
|
||||||
|
|
||||||
|
# Openblas location is exported explicitly because openblas is keg-only,
|
||||||
|
# which means it was not symlinked into /usr/local/.
|
||||||
|
# CMake cannot find BLAS on GitHub runners if Marian is being compiled
|
||||||
|
# statically, hence USE_STATIC_LIBS=off
|
||||||
|
- name: Configure CMake
|
||||||
|
run: |
|
||||||
|
export LDFLAGS="-L/usr/local/opt/openblas/lib"
|
||||||
|
export CPPFLAGS="-I/usr/local/opt/openblas/include"
|
||||||
|
mkdir -p build
|
||||||
|
cd build
|
||||||
|
cmake .. -DCOMPILE_CPU=on -DCOMPILE_CUDA=off -DCOMPILE_EXAMPLES=on -DCOMPILE_SERVER=on -DCOMPILE_TESTS=on \
|
||||||
|
-DUSE_FBGEMM=on -DUSE_SENTENCEPIECE=on -DUSE_STATIC_LIBS=off
|
||||||
|
|
||||||
|
- name: Compile
|
||||||
|
working-directory: build
|
||||||
|
run: make -j2
|
||||||
|
|
||||||
|
- name: Run unit tests
|
||||||
|
working-directory: build
|
||||||
|
run: make test
|
||||||
|
|
||||||
|
- name: Print versions
|
||||||
|
working-directory: build
|
||||||
|
run: |
|
||||||
|
./marian --version
|
||||||
|
./marian-decoder --version
|
||||||
|
./marian-scorer --version
|
||||||
|
./spm_encode --version
|
||||||
|
|
64
.github/workflows/build-ubuntu-18.04-cpu.yml
vendored
Normal file
64
.github/workflows/build-ubuntu-18.04-cpu.yml
vendored
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
name: ubuntu-18.04-cpu
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ master ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ master ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
|
||||||
|
runs-on: ubuntu-18.04
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
# The following packages are already installed on GitHub-hosted runners: build-essential openssl libssl-dev
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt-get install --no-install-recommends libgoogle-perftools-dev libprotobuf10 libprotobuf-dev protobuf-compiler
|
||||||
|
|
||||||
|
# https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html
|
||||||
|
- name: Install MKL
|
||||||
|
run: |
|
||||||
|
wget -qO- "https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB" | sudo apt-key add -
|
||||||
|
sudo sh -c "echo deb https://apt.repos.intel.com/mkl all main > /etc/apt/sources.list.d/intel-mkl.list"
|
||||||
|
sudo apt-get update -o Dir::Etc::sourcelist="/etc/apt/sources.list.d/intel-mkl.list"
|
||||||
|
sudo apt-get install --no-install-recommends intel-mkl-64bit-2020.0-088
|
||||||
|
|
||||||
|
- name: Print Boost paths
|
||||||
|
run: |
|
||||||
|
ls $BOOST_ROOT_1_69_0
|
||||||
|
ls $BOOST_ROOT_1_69_0/include
|
||||||
|
ls $BOOST_ROOT_1_69_0/lib
|
||||||
|
|
||||||
|
# Boost is already installed on GitHub-hosted runners in a non-standard location
|
||||||
|
# https://github.com/actions/virtual-environments/issues/687#issuecomment-610471671
|
||||||
|
- name: Configure CMake
|
||||||
|
run: |
|
||||||
|
mkdir -p build
|
||||||
|
cd build
|
||||||
|
cmake .. -DCOMPILE_CPU=on -DCOMPILE_CUDA=off -DCOMPILE_EXAMPLES=on -DCOMPILE_SERVER=on -DCOMPILE_TESTS=on \
|
||||||
|
-DUSE_FBGEMM=on -DUSE_SENTENCEPIECE=on \
|
||||||
|
-DBOOST_ROOT=$BOOST_ROOT_1_69_0 -DBOOST_INCLUDEDIR=$BOOST_ROOT_1_69_0/include -DBOOST_LIBRARYDIR=$BOOST_ROOT_1_69_0/lib \
|
||||||
|
-DBoost_ARCHITECTURE=-x64
|
||||||
|
|
||||||
|
- name: Compile
|
||||||
|
working-directory: build
|
||||||
|
run: make -j2
|
||||||
|
|
||||||
|
- name: Run unit tests
|
||||||
|
working-directory: build
|
||||||
|
run: make test
|
||||||
|
|
||||||
|
- name: Print versions
|
||||||
|
working-directory: build
|
||||||
|
run: |
|
||||||
|
./marian --version
|
||||||
|
./marian-decoder --version
|
||||||
|
./marian-scorer --version
|
||||||
|
./spm_encode --version
|
||||||
|
|
49
.github/workflows/build-windows-2019-cpu.yml
vendored
Normal file
49
.github/workflows/build-windows-2019-cpu.yml
vendored
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
name: windows-2019-cpu
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ master ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ master ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
|
||||||
|
runs-on: windows-2019
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Prepare vcpkg
|
||||||
|
uses: lukka/run-vcpkg@v3
|
||||||
|
with:
|
||||||
|
vcpkgArguments: protobuf
|
||||||
|
vcpkgGitCommitId: 6185aa76504a5025f36754324abf307cc776f3da
|
||||||
|
vcpkgDirectory: ${{ github.workspace }}/vcpkg/
|
||||||
|
vcpkgTriplet: x64-windows-static
|
||||||
|
|
||||||
|
# Note that we build with a simplified CMake settings JSON file
|
||||||
|
- name: Run CMake
|
||||||
|
uses: lukka/run-cmake@v2
|
||||||
|
with:
|
||||||
|
buildDirectory: ${{ github.workspace }}/build/
|
||||||
|
cmakeAppendedArgs: -G Ninja
|
||||||
|
cmakeListsOrSettingsJson: CMakeSettingsJson
|
||||||
|
cmakeSettingsJsonPath: ${{ github.workspace }}/CMakeSettingsCI.json
|
||||||
|
useVcpkgToolchainFile: true
|
||||||
|
|
||||||
|
- name: Run unit tests
|
||||||
|
working-directory: build/Debug/
|
||||||
|
run: ctest
|
||||||
|
|
||||||
|
- name: Print versions
|
||||||
|
working-directory: build/Debug/
|
||||||
|
run: |
|
||||||
|
.\marian.exe --version
|
||||||
|
.\marian-decoder.exe --version
|
||||||
|
.\marian-scorer.exe --version
|
||||||
|
.\spm_encode.exe --version
|
||||||
|
|
10
CHANGELOG.md
10
CHANGELOG.md
@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
- Decoding multi-source models in marian-server with --tsv
|
||||||
|
- GitHub workflows on Ubuntu, Windows, and MacOS
|
||||||
- LSH indexing to replace short list
|
- LSH indexing to replace short list
|
||||||
- ONNX support for transformer models
|
- ONNX support for transformer models
|
||||||
- Add topk operator like PyTorch's topk
|
- Add topk operator like PyTorch's topk
|
||||||
@ -19,11 +21,19 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||||||
and translation with options --tsv and --tsv-fields n.
|
and translation with options --tsv and --tsv-fields n.
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
- Fix compilation without BLAS installed
|
||||||
|
- Providing a single value to vector-like options using the equals sign, e.g. --models=model.npz
|
||||||
|
- Fix quiet-translation in marian-server
|
||||||
|
- CMake-based compilation on Windows
|
||||||
|
- Fix minor issues with compilation on MacOS
|
||||||
|
- Fix warnings in Windows MSVC builds using CMake
|
||||||
- Fix building server with Boost 1.72
|
- Fix building server with Boost 1.72
|
||||||
- Make mini-batch scaling depend on mini-batch-words and not on mini-batch-words-ref
|
- Make mini-batch scaling depend on mini-batch-words and not on mini-batch-words-ref
|
||||||
- In concatenation make sure that we do not multiply 0 with nan (which results in nan)
|
- In concatenation make sure that we do not multiply 0 with nan (which results in nan)
|
||||||
- Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now
|
- Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now
|
||||||
absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance.
|
absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance.
|
||||||
|
- Fixed bug in finding .git/logs/HEAD when Marian is a submodule in another project.
|
||||||
|
- Properly record cmake variables in the cmake build directory instead of the source tree.
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Move Simple-WebSocket-Server to submodule
|
- Move Simple-WebSocket-Server to submodule
|
||||||
|
@ -51,6 +51,8 @@ message(STATUS "Project version: ${PROJECT_VERSION_STRING_FULL}")
|
|||||||
execute_process(COMMAND git submodule update --init --recursive --no-fetch
|
execute_process(COMMAND git submodule update --init --recursive --no-fetch
|
||||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
|
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
|
|
||||||
|
# Note that with CMake MSVC build, the option CMAKE_BUILD_TYPE is automatically derived from the key
|
||||||
|
# 'configurationType' in CMakeSettings.json configurations
|
||||||
if(NOT CMAKE_BUILD_TYPE)
|
if(NOT CMAKE_BUILD_TYPE)
|
||||||
message(WARNING "CMAKE_BUILD_TYPE not set; setting to Release")
|
message(WARNING "CMAKE_BUILD_TYPE not set; setting to Release")
|
||||||
set(CMAKE_BUILD_TYPE "Release")
|
set(CMAKE_BUILD_TYPE "Release")
|
||||||
@ -62,10 +64,11 @@ if(MSVC)
|
|||||||
# These are used in src/CMakeLists.txt on a per-target basis
|
# These are used in src/CMakeLists.txt on a per-target basis
|
||||||
list(APPEND ALL_WARNINGS /WX; /W4;)
|
list(APPEND ALL_WARNINGS /WX; /W4;)
|
||||||
|
|
||||||
# Disabled bogus warnings for CPU intrincics:
|
# Disabled bogus warnings for CPU intrinsics:
|
||||||
# C4310: cast truncates constant value
|
# C4310: cast truncates constant value
|
||||||
# C4324: 'marian::cpu::int16::`anonymous-namespace'::ScatterPut': structure was padded due to alignment specifier
|
# C4324: 'marian::cpu::int16::`anonymous-namespace'::ScatterPut': structure was padded due to alignment specifier
|
||||||
set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\"")
|
# C4702: unreachable code; note it is also disabled globally in the VS project file
|
||||||
|
set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\"")
|
||||||
|
|
||||||
# set(INTRINSICS "/arch:AVX")
|
# set(INTRINSICS "/arch:AVX")
|
||||||
add_definitions(-DUSE_SSE2=1)
|
add_definitions(-DUSE_SSE2=1)
|
||||||
@ -79,7 +82,9 @@ if(MSVC)
|
|||||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} /MTd /Od /Ob0 ${INTRINSICS} /RTC1 /Zi /D_DEBUG")
|
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} /MTd /Od /Ob0 ${INTRINSICS} /RTC1 /Zi /D_DEBUG")
|
||||||
|
|
||||||
# ignores warning LNK4049: locally defined symbol free imported - this comes from zlib
|
# ignores warning LNK4049: locally defined symbol free imported - this comes from zlib
|
||||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /NODEFAULTLIB:MSVCRT /ignore:4049")
|
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /ignore:4049")
|
||||||
|
set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRT")
|
||||||
|
set(CMAKE_EXE_LINKER_FLAGS_DEBUG "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRTD")
|
||||||
set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} /LTCG:incremental")
|
set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} /LTCG:incremental")
|
||||||
|
|
||||||
find_library(SHLWAPI Shlwapi.lib)
|
find_library(SHLWAPI Shlwapi.lib)
|
||||||
@ -218,7 +223,9 @@ if(COMPILE_CUDA)
|
|||||||
|
|
||||||
if(USE_STATIC_LIBS)
|
if(USE_STATIC_LIBS)
|
||||||
# link statically to stdlib libraries
|
# link statically to stdlib libraries
|
||||||
set(CMAKE_EXE_LINKER_FLAGS "-static-libgcc -static-libstdc++")
|
if(NOT MSVC)
|
||||||
|
set(CMAKE_EXE_LINKER_FLAGS "-static-libgcc -static-libstdc++")
|
||||||
|
endif()
|
||||||
|
|
||||||
# look for libraries that have .a suffix
|
# look for libraries that have .a suffix
|
||||||
set(_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES})
|
set(_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES})
|
||||||
@ -250,12 +257,22 @@ if(CUDA_FOUND)
|
|||||||
endif(COMPILE_CUDA_SM70)
|
endif(COMPILE_CUDA_SM70)
|
||||||
|
|
||||||
if(USE_STATIC_LIBS)
|
if(USE_STATIC_LIBS)
|
||||||
find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||||
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||||
set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
|
||||||
|
# The cuLIBOS library does not seem to exist in Windows CUDA toolkit installs
|
||||||
|
if(CUDA_culibos_LIBRARY)
|
||||||
|
set(EXT_LIBS ${EXT_LIBS} ${CUDA_culibos_LIBRARY})
|
||||||
|
set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_culibos_LIBRARY})
|
||||||
|
elseif(NOT WIN32)
|
||||||
|
message(FATAL_ERROR "cuLIBOS library not found")
|
||||||
|
endif()
|
||||||
# CUDA 10.1 introduces cublasLt library that is required on static build
|
# CUDA 10.1 introduces cublasLt library that is required on static build
|
||||||
if ((CUDA_VERSION VERSION_EQUAL "10.1" OR CUDA_VERSION VERSION_GREATER "10.1"))
|
if ((CUDA_VERSION VERSION_EQUAL "10.1" OR CUDA_VERSION VERSION_GREATER "10.1"))
|
||||||
find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
|
||||||
|
if(NOT CUDA_cublasLt_LIBRARY)
|
||||||
|
message(FATAL_ERROR "cuBLASLt library not found")
|
||||||
|
endif()
|
||||||
set(EXT_LIBS ${EXT_LIBS} ${CUDA_cublasLt_LIBRARY})
|
set(EXT_LIBS ${EXT_LIBS} ${CUDA_cublasLt_LIBRARY})
|
||||||
set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_cublasLt_LIBRARY})
|
set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_cublasLt_LIBRARY})
|
||||||
endif()
|
endif()
|
||||||
@ -319,7 +336,7 @@ if(NOT MSVC)
|
|||||||
list(APPEND CUDA_NVCC_FLAGS -ccbin ${CMAKE_C_COMPILER}; -std=c++11; -Xcompiler\ -fPIC; -Xcompiler\ -Wno-unused-result; -Xcompiler\ -Wno-deprecated; -Xcompiler\ -Wno-pragmas; -Xcompiler\ -Wno-unused-value; -Xcompiler\ -Werror;)
|
list(APPEND CUDA_NVCC_FLAGS -ccbin ${CMAKE_C_COMPILER}; -std=c++11; -Xcompiler\ -fPIC; -Xcompiler\ -Wno-unused-result; -Xcompiler\ -Wno-deprecated; -Xcompiler\ -Wno-pragmas; -Xcompiler\ -Wno-unused-value; -Xcompiler\ -Werror;)
|
||||||
list(APPEND CUDA_NVCC_FLAGS ${INTRINSICS_NVCC})
|
list(APPEND CUDA_NVCC_FLAGS ${INTRINSICS_NVCC})
|
||||||
else()
|
else()
|
||||||
list(APPEND CUDA_NVCC_FLAGS -Xcompiler\ /FS; )
|
list(APPEND CUDA_NVCC_FLAGS -Xcompiler\ /FS; -Xcompiler\ /MT$<$<CONFIG:Debug>:d>; )
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
list(REMOVE_DUPLICATES CUDA_NVCC_FLAGS)
|
list(REMOVE_DUPLICATES CUDA_NVCC_FLAGS)
|
||||||
@ -443,8 +460,11 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/project_version.h.in
|
|||||||
# Generate build_info.cpp with CMake cache variables
|
# Generate build_info.cpp with CMake cache variables
|
||||||
include(GetCacheVariables)
|
include(GetCacheVariables)
|
||||||
|
|
||||||
|
# make sure src/common/build_info.cpp has been removed
|
||||||
|
execute_process(COMMAND rm ${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp
|
||||||
|
OUTPUT_QUIET ERROR_QUIET)
|
||||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp.in
|
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp.in
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp @ONLY)
|
${CMAKE_CURRENT_BINARY_DIR}/src/common/build_info.cpp @ONLY)
|
||||||
|
|
||||||
# Compile source files
|
# Compile source files
|
||||||
include_directories(${marian_SOURCE_DIR}/src)
|
include_directories(${marian_SOURCE_DIR}/src)
|
||||||
|
52
CMakeSettingsCI.json
Normal file
52
CMakeSettingsCI.json
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
{
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Release",
|
||||||
|
"generator": "Ninja",
|
||||||
|
"configurationType": "Release",
|
||||||
|
"inheritEnvironments": [ "msvc_x64" ],
|
||||||
|
"cmakeCommandArgs": "",
|
||||||
|
"buildCommandArgs": "-v",
|
||||||
|
"ctestCommandArgs": "",
|
||||||
|
"variables": [
|
||||||
|
{ "name": "OPENSSL_USE_STATIC_LIBS:BOOL", "value": "TRUE" },
|
||||||
|
{ "name": "OPENSSL_MSVC_STATIC_RT:BOOL", "value": "TRUE" },
|
||||||
|
|
||||||
|
{ "name": "COMPILE_CUDA:BOOL", "value": "FALSE" },
|
||||||
|
{ "name": "COMPILE_CPU:BOOL", "value": "TRUE" },
|
||||||
|
{ "name": "COMPILE_EXAMPLES:BOOL", "value": "FALSE" },
|
||||||
|
{ "name": "COMPILE_SERVER:BOOL", "value": "FALSE" },
|
||||||
|
{ "name": "COMPILE_TESTS:BOOL", "value": "TRUE" },
|
||||||
|
|
||||||
|
{ "name": "USE_FBGEMM:BOOL", "value": "TRUE" },
|
||||||
|
{ "name": "USE_MPI:BOOL", "value": "FALSE" },
|
||||||
|
{ "name": "USE_SENTENCEPIECE:BOOL", "value": "TRUE" },
|
||||||
|
{ "name": "USE_STATIC_LIBS:BOOL", "value": "TRUE" }
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Debug",
|
||||||
|
"generator": "Ninja",
|
||||||
|
"configurationType": "Debug",
|
||||||
|
"inheritEnvironments": [ "msvc_x64" ],
|
||||||
|
"cmakeCommandArgs": "",
|
||||||
|
"buildCommandArgs": "-v",
|
||||||
|
"ctestCommandArgs": "",
|
||||||
|
"variables": [
|
||||||
|
{ "name": "OPENSSL_MSVC_STATIC_RT:BOOL", "value": "TRUE" },
|
||||||
|
{ "name": "OPENSSL_USE_STATIC_LIBS:BOOL", "value": "TRUE" },
|
||||||
|
|
||||||
|
{ "name": "COMPILE_CUDA:BOOL", "value": "FALSE" },
|
||||||
|
{ "name": "COMPILE_CPU:BOOL", "value": "TRUE" },
|
||||||
|
{ "name": "COMPILE_EXAMPLES:BOOL", "value": "FALSE" },
|
||||||
|
{ "name": "COMPILE_SERVER:BOOL", "value": "FALSE" },
|
||||||
|
{ "name": "COMPILE_TESTS:BOOL", "value": "TRUE" },
|
||||||
|
|
||||||
|
{ "name": "USE_FBGEMM:BOOL", "value": "TRUE" },
|
||||||
|
{ "name": "USE_MPI:BOOL", "value": "FALSE" },
|
||||||
|
{ "name": "USE_SENTENCEPIECE:BOOL", "value": "TRUE" },
|
||||||
|
{ "name": "USE_STATIC_LIBS:BOOL", "value": "TRUE" }
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
@ -89,10 +89,17 @@ find_library(MKL_CORE_LIBRARY
|
|||||||
NO_DEFAULT_PATH)
|
NO_DEFAULT_PATH)
|
||||||
|
|
||||||
set(MKL_INCLUDE_DIRS ${MKL_INCLUDE_DIR})
|
set(MKL_INCLUDE_DIRS ${MKL_INCLUDE_DIR})
|
||||||
# Added -Wl block to avoid circular dependencies.
|
set(MKL_LIBRARIES ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY})
|
||||||
# https://stackoverflow.com/questions/5651869/what-are-the-start-group-and-end-group-command-line-options
|
|
||||||
# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor
|
if(NOT WIN32 AND NOT APPLE)
|
||||||
set(MKL_LIBRARIES -Wl,--start-group ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY} -Wl,--end-group)
|
# Added -Wl block to avoid circular dependencies.
|
||||||
|
# https://stackoverflow.com/questions/5651869/what-are-the-start-group-and-end-group-command-line-options
|
||||||
|
# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor
|
||||||
|
set(MKL_LIBRARIES -Wl,--start-group ${MKL_LIBRARIES} -Wl,--end-group)
|
||||||
|
elseif(APPLE)
|
||||||
|
# MacOS does not support --start-group and --end-group
|
||||||
|
set(MKL_LIBRARIES -Wl,${MKL_LIBRARIES} -Wl,)
|
||||||
|
endif()
|
||||||
|
|
||||||
# message("1 ${MKL_INCLUDE_DIR}")
|
# message("1 ${MKL_INCLUDE_DIR}")
|
||||||
# message("2 ${MKL_INTERFACE_LIBRARY}")
|
# message("2 ${MKL_INTERFACE_LIBRARY}")
|
||||||
@ -130,4 +137,4 @@ endif()
|
|||||||
INCLUDE(FindPackageHandleStandardArgs)
|
INCLUDE(FindPackageHandleStandardArgs)
|
||||||
FIND_PACKAGE_HANDLE_STANDARD_ARGS(MKL DEFAULT_MSG MKL_LIBRARIES MKL_INCLUDE_DIRS MKL_INTERFACE_LIBRARY MKL_SEQUENTIAL_LAYER_LIBRARY MKL_CORE_LIBRARY)
|
FIND_PACKAGE_HANDLE_STANDARD_ARGS(MKL DEFAULT_MSG MKL_LIBRARIES MKL_INCLUDE_DIRS MKL_INTERFACE_LIBRARY MKL_SEQUENTIAL_LAYER_LIBRARY MKL_CORE_LIBRARY)
|
||||||
|
|
||||||
MARK_AS_ADVANCED(MKL_INCLUDE_DIRS MKL_LIBRARIES MKL_INTERFACE_LIBRARY MKL_SEQUENTIAL_LAYER_LIBRARY MKL_CORE_LIBRARY)
|
MARK_AS_ADVANCED(MKL_INCLUDE_DIRS MKL_LIBRARIES MKL_INTERFACE_LIBRARY MKL_SEQUENTIAL_LAYER_LIBRARY MKL_CORE_LIBRARY)
|
||||||
|
@ -34,17 +34,18 @@ foreach(_variableName ${_variableNames})
|
|||||||
(NOT "${_variableType}" STREQUAL "INTERNAL") AND
|
(NOT "${_variableType}" STREQUAL "INTERNAL") AND
|
||||||
(NOT "${_variableValue}" STREQUAL "") )
|
(NOT "${_variableValue}" STREQUAL "") )
|
||||||
|
|
||||||
|
string(REPLACE "\"" " " _variableValueEscapedQuotes ${_variableValue})
|
||||||
set(PROJECT_CMAKE_CACHE_ADVANCED "${PROJECT_CMAKE_CACHE_ADVANCED} \"${_variableName}=${_variableValue}\\n\"\n")
|
string(REPLACE "\\" "/" _variableValueEscaped ${_variableValueEscapedQuotes})
|
||||||
|
set(PROJECT_CMAKE_CACHE_ADVANCED "${PROJECT_CMAKE_CACHE_ADVANCED} \"${_variableName}=${_variableValueEscaped}\\n\"\n")
|
||||||
|
|
||||||
# Get the variable's advanced flag
|
# Get the variable's advanced flag
|
||||||
get_property(_isAdvanced CACHE ${_variableName} PROPERTY ADVANCED SET)
|
get_property(_isAdvanced CACHE ${_variableName} PROPERTY ADVANCED SET)
|
||||||
if(NOT _isAdvanced)
|
if(NOT _isAdvanced)
|
||||||
set(PROJECT_CMAKE_CACHE "${PROJECT_CMAKE_CACHE} \"${_variableName}=${_variableValue}\\n\"\n")
|
set(PROJECT_CMAKE_CACHE "${PROJECT_CMAKE_CACHE} \"${_variableName}=${_variableValueEscaped}\\n\"\n")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Print variables for debugging
|
# Print variables for debugging
|
||||||
#message(STATUS "${_variableName}=${${_variableName}}")
|
#message(STATUS "${_variableName}=${_variableValueEscaped}")
|
||||||
#message(STATUS " Type=${_variableType}")
|
#message(STATUS " Type=${_variableType}")
|
||||||
#message(STATUS " Advanced=${_isAdvanced}")
|
#message(STATUS " Advanced=${_isAdvanced}")
|
||||||
endif()
|
endif()
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
if(PROJECT_VERSION_FILE)
|
if(PROJECT_VERSION_FILE)
|
||||||
file(STRINGS ${PROJECT_VERSION_FILE} PROJECT_VERSION_STRING)
|
file(STRINGS ${PROJECT_VERSION_FILE} PROJECT_VERSION_STRING)
|
||||||
else()
|
else()
|
||||||
file(STRINGS ${CMAKE_SOURCE_DIR}/VERSION PROJECT_VERSION_STRING)
|
file(STRINGS ${CMAKE_CURRENT_SOURCE_DIR}/VERSION PROJECT_VERSION_STRING)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Get current commit SHA from git
|
# Get current commit SHA from git
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit 0f8cabf13ec362d50544d33490024e00c3a763be
|
Subproject commit 7b8f6ee5b6ff7779fd993df7f77adf1e2d9adbe5
|
3
src/3rd_party/CLI/App.hpp
vendored
3
src/3rd_party/CLI/App.hpp
vendored
@ -1595,7 +1595,8 @@ class App {
|
|||||||
if(num < 0) {
|
if(num < 0) {
|
||||||
// RG: We need to keep track if the vector option is empty and handle this separately as
|
// RG: We need to keep track if the vector option is empty and handle this separately as
|
||||||
// otherwise the parser will mark the command-line option as not set
|
// otherwise the parser will mark the command-line option as not set
|
||||||
bool emptyVectorArgs = true;
|
// RG: An option value after '=' was already collected
|
||||||
|
bool emptyVectorArgs = (collected <= 0);
|
||||||
while(!args.empty() && _recognize(args.back()) == detail::Classifer::NONE) {
|
while(!args.empty() && _recognize(args.back()) == detail::Classifer::NONE) {
|
||||||
if(collected >= -num) {
|
if(collected >= -num) {
|
||||||
// We could break here for allow extras, but we don't
|
// We could break here for allow extras, but we don't
|
||||||
|
11
src/3rd_party/CMakeLists.txt
vendored
11
src/3rd_party/CMakeLists.txt
vendored
@ -19,6 +19,9 @@ if(USE_FBGEMM)
|
|||||||
# only locally disabled for the 3rd_party folder
|
# only locally disabled for the 3rd_party folder
|
||||||
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-value -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function")
|
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-value -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused")
|
||||||
|
else()
|
||||||
|
# Do not compile cpuinfo executables due to a linker error, and they are not needed
|
||||||
|
set(CPUINFO_BUILD_TOOLS OFF CACHE BOOL "Build command-line tools")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "Disable fbgemm tests")
|
set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "Disable fbgemm tests")
|
||||||
@ -45,7 +48,7 @@ if(USE_SENTENCEPIECE)
|
|||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(SPM_ENABLE_TCMALLOC ON CACHE BOOL "Enable TCMalloc if available." FORCE)
|
set(SPM_ENABLE_TCMALLOC ON CACHE BOOL "Enable TCMalloc if available.")
|
||||||
|
|
||||||
if(USE_STATIC_LIBS)
|
if(USE_STATIC_LIBS)
|
||||||
message(WARNING "You are compiling SentencePiece binaries with -DUSE_STATIC_LIBS=on. \
|
message(WARNING "You are compiling SentencePiece binaries with -DUSE_STATIC_LIBS=on. \
|
||||||
@ -55,8 +58,8 @@ if(USE_SENTENCEPIECE)
|
|||||||
set(SPM_ENABLE_SHARED OFF CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
|
set(SPM_ENABLE_SHARED OFF CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
|
||||||
set(SPM_TCMALLOC_STATIC ON CACHE BOOL "Link static library of TCMALLOC." FORCE)
|
set(SPM_TCMALLOC_STATIC ON CACHE BOOL "Link static library of TCMALLOC." FORCE)
|
||||||
else(USE_STATIC_LIBS)
|
else(USE_STATIC_LIBS)
|
||||||
set(SPM_ENABLE_SHARED ON CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
|
set(SPM_ENABLE_SHARED ON CACHE BOOL "Builds shared libaries in addition to static libraries.")
|
||||||
set(SPM_TCMALLOC_STATIC OFF CACHE BOOL "Link static library of TCMALLOC." FORCE)
|
set(SPM_TCMALLOC_STATIC OFF CACHE BOOL "Link static library of TCMALLOC.")
|
||||||
endif(USE_STATIC_LIBS)
|
endif(USE_STATIC_LIBS)
|
||||||
|
|
||||||
add_subdirectory(./sentencepiece)
|
add_subdirectory(./sentencepiece)
|
||||||
@ -66,7 +69,7 @@ if(USE_SENTENCEPIECE)
|
|||||||
PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||||
|
|
||||||
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||||
foreach(t sentencepiece sentencepiece_train sentencepiece_train-static
|
foreach(t sentencepiece-static sentencepiece_train-static
|
||||||
spm_decode spm_encode spm_export_vocab spm_normalize spm_train)
|
spm_decode spm_encode spm_export_vocab spm_normalize spm_train)
|
||||||
set_property(TARGET ${t} APPEND_STRING PROPERTY COMPILE_FLAGS " -Wno-tautological-compare -Wno-unused")
|
set_property(TARGET ${t} APPEND_STRING PROPERTY COMPILE_FLAGS " -Wno-tautological-compare -Wno-unused")
|
||||||
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||||
|
4
src/3rd_party/cnpy/cnpy.h
vendored
4
src/3rd_party/cnpy/cnpy.h
vendored
@ -18,6 +18,10 @@
|
|||||||
#include<map>
|
#include<map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace cnpy {
|
namespace cnpy {
|
||||||
|
|
||||||
struct NpyArray {
|
struct NpyArray {
|
||||||
|
3
src/3rd_party/faiss/VectorTransform.h
vendored
3
src/3rd_party/faiss/VectorTransform.h
vendored
@ -18,6 +18,9 @@
|
|||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
#include <faiss/Index.h>
|
#include <faiss/Index.h>
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#include <x86intrin.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
namespace faiss {
|
namespace faiss {
|
||||||
|
3
src/3rd_party/phf/phf.cc
vendored
3
src/3rd_party/phf/phf.cc
vendored
@ -679,7 +679,8 @@ PHF_PUBLIC void PHF::compact(struct phf *phf) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* simply keep old array if realloc fails */
|
/* simply keep old array if realloc fails */
|
||||||
if ((tmp = realloc(phf->g, phf->r * size)))
|
tmp = realloc(phf->g, phf->r * size);
|
||||||
|
if (tmp != 0)
|
||||||
phf->g = static_cast<uint32_t *>(tmp);
|
phf->g = static_cast<uint32_t *>(tmp);
|
||||||
} /* PHF::compact() */
|
} /* PHF::compact() */
|
||||||
|
|
||||||
|
4
src/3rd_party/zlib/CMakeLists.txt
vendored
4
src/3rd_party/zlib/CMakeLists.txt
vendored
@ -2,11 +2,11 @@
|
|||||||
file(GLOB ZLIB_SRC *.c)
|
file(GLOB ZLIB_SRC *.c)
|
||||||
file(GLOB ZLIB_INC *.h)
|
file(GLOB ZLIB_INC *.h)
|
||||||
|
|
||||||
# add sources of the wrapper as a "SQLiteCpp" static library
|
# add sources of the wrapper as a "zlib" static library
|
||||||
add_library(zlib OBJECT ${ZLIB_SRC} ${ZLIB_INC})
|
add_library(zlib OBJECT ${ZLIB_SRC} ${ZLIB_INC})
|
||||||
|
|
||||||
if(MSVC)
|
if(MSVC)
|
||||||
target_compile_options(zlib PUBLIC /wd"4996" /wd"4267")
|
target_compile_options(zlib PUBLIC /wd4996 /wd4267)
|
||||||
else()
|
else()
|
||||||
target_compile_options(zlib PUBLIC -Wno-implicit-function-declaration)
|
target_compile_options(zlib PUBLIC -Wno-implicit-function-declaration)
|
||||||
endif()
|
endif()
|
||||||
|
@ -20,7 +20,7 @@ add_library(marian STATIC
|
|||||||
common/config_validator.cpp
|
common/config_validator.cpp
|
||||||
common/options.cpp
|
common/options.cpp
|
||||||
common/binary.cpp
|
common/binary.cpp
|
||||||
common/build_info.cpp
|
${CMAKE_CURRENT_BINARY_DIR}/common/build_info.cpp
|
||||||
common/io.cpp
|
common/io.cpp
|
||||||
common/filesystem.cpp
|
common/filesystem.cpp
|
||||||
common/file_stream.cpp
|
common/file_stream.cpp
|
||||||
@ -70,7 +70,7 @@ add_library(marian STATIC
|
|||||||
layers/generic.cpp
|
layers/generic.cpp
|
||||||
layers/loss.cpp
|
layers/loss.cpp
|
||||||
layers/weight.cpp
|
layers/weight.cpp
|
||||||
layers/lsh.cpp
|
layers/lsh.cpp
|
||||||
|
|
||||||
rnn/cells.cpp
|
rnn/cells.cpp
|
||||||
rnn/attention.cpp
|
rnn/attention.cpp
|
||||||
@ -117,21 +117,22 @@ target_compile_options(marian PUBLIC ${ALL_WARNINGS})
|
|||||||
# Git updates .git/logs/HEAD file whenever you pull or commit something.
|
# Git updates .git/logs/HEAD file whenever you pull or commit something.
|
||||||
|
|
||||||
# If Marian is checked out as a submodule in another repository,
|
# If Marian is checked out as a submodule in another repository,
|
||||||
# there's no .git directory in ${CMAKE_SOURCE_DIR}. Instead .git is a
|
# ${CMAKE_CURRENT_SOURCE_DIR}/../.git is not a directory but a file
|
||||||
# file that specifies the relative path from ${CMAKE_SOURCE_DIR} to
|
# that specifies the relative path from ${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||||
# ./git/modules/<MARIAN_ROOT_DIR> in the root of the repository that
|
# to ./git/modules/<MARIAN_ROOT_DIR> in the root of the check_out of
|
||||||
# contains Marian as a submodule. We set MARIAN_GIT_DIR to the appropriate
|
# the project that contains Marian as a submodule.
|
||||||
# path, depending on whether ${CMAKE_SOURCE_DIR}/.git is a directory or file.
|
#
|
||||||
if(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git) # not a submodule
|
# We set MARIAN_GIT_DIR to the appropriate path, depending on whether
|
||||||
set(MARIAN_GIT_DIR ${CMAKE_SOURCE_DIR}/.git)
|
# ${CMAKE_CURRENT_SOURCE_DIR}/../.git is a directory or file.
|
||||||
else(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
|
set(MARIAN_GIT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../.git)
|
||||||
file(READ ${CMAKE_SOURCE_DIR}/.git MARIAN_GIT_DIR)
|
if(NOT IS_DIRECTORY ${MARIAN_GIT_DIR}) # i.e., it's a submodule
|
||||||
|
file(READ ${MARIAN_GIT_DIR} MARIAN_GIT_DIR)
|
||||||
string(REGEX REPLACE "gitdir: (.*)\n" "\\1" MARIAN_GIT_DIR ${MARIAN_GIT_DIR})
|
string(REGEX REPLACE "gitdir: (.*)\n" "\\1" MARIAN_GIT_DIR ${MARIAN_GIT_DIR})
|
||||||
get_filename_component(MARIAN_GIT_DIR "${CMAKE_SOURCE_DIR}/${MARIAN_GIT_DIR}" ABSOLUTE)
|
get_filename_component(MARIAN_GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../${MARIAN_GIT_DIR}" ABSOLUTE)
|
||||||
endif(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
|
endif(NOT IS_DIRECTORY ${MARIAN_GIT_DIR})
|
||||||
|
|
||||||
add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
|
add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
|
||||||
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
|
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
COMMAND git log -1 --pretty=format:\#define\ GIT_REVISION\ \"\%h\ \%ai\" > ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
|
COMMAND git log -1 --pretty=format:\#define\ GIT_REVISION\ \"\%h\ \%ai\" > ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
|
||||||
DEPENDS ${MARIAN_GIT_DIR}/logs/HEAD
|
DEPENDS ${MARIAN_GIT_DIR}/logs/HEAD
|
||||||
VERBATIM
|
VERBATIM
|
||||||
@ -220,7 +221,7 @@ if(COMPILE_SERVER)
|
|||||||
add_executable(marian_server command/marian_server.cpp)
|
add_executable(marian_server command/marian_server.cpp)
|
||||||
set_target_properties(marian_server PROPERTIES OUTPUT_NAME marian-server)
|
set_target_properties(marian_server PROPERTIES OUTPUT_NAME marian-server)
|
||||||
if(MSVC)
|
if(MSVC)
|
||||||
# Disable warnings from the SimpleWebSocketServer library
|
# Disable warnings from the SimpleWebSocketServer library needed for compilation of marian-server
|
||||||
target_compile_options(marian_server PUBLIC ${ALL_WARNINGS} /wd4267 /wd4244 /wd4456 /wd4458)
|
target_compile_options(marian_server PUBLIC ${ALL_WARNINGS} /wd4267 /wd4244 /wd4456 /wd4458)
|
||||||
else(MSVC)
|
else(MSVC)
|
||||||
# -Wno-suggest-override disables warnings from Boost 1.69+
|
# -Wno-suggest-override disables warnings from Boost 1.69+
|
||||||
|
@ -14,6 +14,7 @@ int main(int argc, char **argv) {
|
|||||||
// Initialize translation task
|
// Initialize translation task
|
||||||
auto options = parseOptions(argc, argv, cli::mode::server, true);
|
auto options = parseOptions(argc, argv, cli::mode::server, true);
|
||||||
auto task = New<TranslateService<BeamSearch>>(options);
|
auto task = New<TranslateService<BeamSearch>>(options);
|
||||||
|
auto quiet = options->get<bool>("quiet-translation");
|
||||||
|
|
||||||
// Initialize web server
|
// Initialize web server
|
||||||
WSServer server;
|
WSServer server;
|
||||||
@ -21,8 +22,8 @@ int main(int argc, char **argv) {
|
|||||||
|
|
||||||
auto &translate = server.endpoint["^/translate/?$"];
|
auto &translate = server.endpoint["^/translate/?$"];
|
||||||
|
|
||||||
translate.on_message = [&task](Ptr<WSServer::Connection> connection,
|
translate.on_message = [&task, quiet](Ptr<WSServer::Connection> connection,
|
||||||
Ptr<WSServer::InMessage> message) {
|
Ptr<WSServer::InMessage> message) {
|
||||||
// Get input text
|
// Get input text
|
||||||
auto inputText = message->string();
|
auto inputText = message->string();
|
||||||
auto sendStream = std::make_shared<WSServer::OutMessage>();
|
auto sendStream = std::make_shared<WSServer::OutMessage>();
|
||||||
@ -30,9 +31,9 @@ int main(int argc, char **argv) {
|
|||||||
// Translate
|
// Translate
|
||||||
timer::Timer timer;
|
timer::Timer timer;
|
||||||
auto outputText = task->run(inputText);
|
auto outputText = task->run(inputText);
|
||||||
LOG(info, "Best translation: {}", outputText);
|
|
||||||
*sendStream << outputText << std::endl;
|
*sendStream << outputText << std::endl;
|
||||||
LOG(info, "Translation took: {:.5f}s", timer.elapsed());
|
if(!quiet)
|
||||||
|
LOG(info, "Translation took: {:.5f}s", timer.elapsed());
|
||||||
|
|
||||||
// Send translation back
|
// Send translation back
|
||||||
connection->send(sendStream, [](const SimpleWeb::error_code &ec) {
|
connection->send(sendStream, [](const SimpleWeb::error_code &ec) {
|
||||||
|
@ -14,7 +14,9 @@ namespace filesystem {
|
|||||||
// Pretend that Windows knows no named pipes. It does, by the way, but
|
// Pretend that Windows knows no named pipes. It does, by the way, but
|
||||||
// they seem to be different from pipes on Unix / Linux. See
|
// they seem to be different from pipes on Unix / Linux. See
|
||||||
// https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes
|
// https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes
|
||||||
bool is_fifo(char const*) { return false; }
|
bool is_fifo(char const* /*path*/) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
bool is_fifo(char const* path) {
|
bool is_fifo(char const* path) {
|
||||||
struct stat buf;
|
struct stat buf;
|
||||||
|
@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
namespace marian {
|
namespace marian {
|
||||||
|
|
||||||
Options::Options()
|
Options::Options()
|
||||||
#if FASTOPT
|
#if FASTOPT
|
||||||
: fastOptions_(options_)
|
: fastOptions_(options_)
|
||||||
#endif
|
#endif
|
||||||
{}
|
{}
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
#include <ctime>
|
||||||
|
|
||||||
namespace marian {
|
namespace marian {
|
||||||
namespace timer {
|
namespace timer {
|
||||||
|
@ -23,8 +23,11 @@ const SentenceTuple& TextIterator::dereference() const {
|
|||||||
TextInput::TextInput(std::vector<std::string> inputs,
|
TextInput::TextInput(std::vector<std::string> inputs,
|
||||||
std::vector<Ptr<Vocab>> vocabs,
|
std::vector<Ptr<Vocab>> vocabs,
|
||||||
Ptr<Options> options)
|
Ptr<Options> options)
|
||||||
: DatasetBase(inputs, options), vocabs_(vocabs) {
|
: DatasetBase(inputs, options),
|
||||||
// note: inputs are automatically stored in the inherited variable named paths_, but these are
|
vocabs_(vocabs),
|
||||||
|
maxLength_(options_->get<size_t>("max-length")),
|
||||||
|
maxLengthCrop_(options_->get<bool>("max-length-crop")) {
|
||||||
|
// Note: inputs are automatically stored in the inherited variable named paths_, but these are
|
||||||
// texts not paths!
|
// texts not paths!
|
||||||
for(const auto& text : paths_)
|
for(const auto& text : paths_)
|
||||||
files_.emplace_back(new std::istringstream(text));
|
files_.emplace_back(new std::istringstream(text));
|
||||||
@ -42,6 +45,10 @@ SentenceTuple TextInput::next() {
|
|||||||
std::string line;
|
std::string line;
|
||||||
if(io::getline(*files_[i], line)) {
|
if(io::getline(*files_[i], line)) {
|
||||||
Words words = vocabs_[i]->encode(line, /*addEOS =*/ true, /*inference =*/ inference_);
|
Words words = vocabs_[i]->encode(line, /*addEOS =*/ true, /*inference =*/ inference_);
|
||||||
|
if(this->maxLengthCrop_ && words.size() > this->maxLength_) {
|
||||||
|
words.resize(maxLength_);
|
||||||
|
words.back() = vocabs_.back()->getEosId(); // note: this will not work with class-labels
|
||||||
|
}
|
||||||
if(words.empty())
|
if(words.empty())
|
||||||
words.push_back(Word::ZERO); // @TODO: What is this for? @BUGBUG: addEOS=true, so this can never happen, right?
|
words.push_back(Word::ZERO); // @TODO: What is this for? @BUGBUG: addEOS=true, so this can never happen, right?
|
||||||
tup.push_back(words);
|
tup.push_back(words);
|
||||||
|
@ -33,6 +33,9 @@ private:
|
|||||||
|
|
||||||
size_t pos_{0};
|
size_t pos_{0};
|
||||||
|
|
||||||
|
size_t maxLength_{0};
|
||||||
|
bool maxLengthCrop_{false};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
typedef SentenceTuple Sample;
|
typedef SentenceTuple Sample;
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ namespace marian {
|
|||||||
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
|
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
|
||||||
|
|
||||||
auto firstLogits = logits_.front()->loss();
|
auto firstLogits = logits_.front()->loss();
|
||||||
ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
|
ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
|
||||||
"Labels not matching logits shape ({} != {}, {})??",
|
"Labels not matching logits shape ({} != {}, {})??",
|
||||||
labels.size() * firstLogits->shape()[-1],
|
labels.size() * firstLogits->shape()[-1],
|
||||||
firstLogits->shape().elements(),
|
firstLogits->shape().elements(),
|
||||||
@ -267,8 +267,8 @@ namespace marian {
|
|||||||
Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||||
lazyConstruct(input->shape()[-1]);
|
lazyConstruct(input->shape()[-1]);
|
||||||
|
|
||||||
auto affineOrLSH = [this](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
|
||||||
#if BLAS_FOUND
|
#if BLAS_FOUND
|
||||||
|
auto affineOrLSH = [this](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||||
if(lsh_) {
|
if(lsh_) {
|
||||||
ABORT_IF( transA, "Transposed query not supported for LSH");
|
ABORT_IF( transA, "Transposed query not supported for LSH");
|
||||||
ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
|
ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
|
||||||
@ -276,10 +276,12 @@ namespace marian {
|
|||||||
} else {
|
} else {
|
||||||
return affine(x, W, b, transA, transB);
|
return affine(x, W, b, transA, transB);
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
return affine(x, W, b, transA, transB);
|
|
||||||
#endif
|
|
||||||
};
|
};
|
||||||
|
#else
|
||||||
|
auto affineOrLSH = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||||
|
return affine(x, W, b, transA, transB);
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed
|
if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed
|
||||||
cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
|
cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
|
||||||
@ -362,7 +364,7 @@ namespace marian {
|
|||||||
factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||||
else
|
else
|
||||||
factorLogits = affine(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
factorLogits = affine(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||||
|
|
||||||
// optionally add lemma-dependent bias
|
// optionally add lemma-dependent bias
|
||||||
if (Plemma) { // [B... x U0]
|
if (Plemma) { // [B... x U0]
|
||||||
int lemmaVocabDim = Plemma->shape()[-1];
|
int lemmaVocabDim = Plemma->shape()[-1];
|
||||||
@ -448,7 +450,7 @@ namespace marian {
|
|||||||
|
|
||||||
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
|
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
|
||||||
auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
|
auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
|
||||||
|
|
||||||
if (options_->has("embFile")) {
|
if (options_->has("embFile")) {
|
||||||
std::string file = opt<std::string>("embFile");
|
std::string file = opt<std::string>("embFile");
|
||||||
if (!file.empty()) {
|
if (!file.empty()) {
|
||||||
|
@ -32,8 +32,6 @@ Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) {
|
|||||||
return New<MeanMultiRationalLoss>();
|
return New<MeanMultiRationalLoss>();
|
||||||
else
|
else
|
||||||
ABORT("Unknown multi-loss-type {}", multiLossType);
|
ABORT("Unknown multi-loss-type {}", multiLossType);
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace marian
|
} // namespace marian
|
||||||
|
@ -74,7 +74,6 @@ std::string ScoreCollector::getAlignment(const data::SoftAlignment& align) {
|
|||||||
} else {
|
} else {
|
||||||
ABORT("Unrecognized word alignment type");
|
ABORT("Unrecognized word alignment type");
|
||||||
}
|
}
|
||||||
return "";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ScoreCollectorNBest::ScoreCollectorNBest(const Ptr<Options>& options)
|
ScoreCollectorNBest::ScoreCollectorNBest(const Ptr<Options>& options)
|
||||||
|
@ -19,7 +19,6 @@ struct QuantizeNodeOp : public UnaryNodeOp {
|
|||||||
|
|
||||||
NodeOps backwardOps() override {
|
NodeOps backwardOps() override {
|
||||||
ABORT("Only used for inference");
|
ABORT("Only used for inference");
|
||||||
return {NodeOp(0)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string type() override { return "quantizeInt16"; }
|
const std::string type() override { return "quantizeInt16"; }
|
||||||
@ -54,7 +53,6 @@ public:
|
|||||||
|
|
||||||
NodeOps backwardOps() override {
|
NodeOps backwardOps() override {
|
||||||
ABORT("Only used for inference");
|
ABORT("Only used for inference");
|
||||||
return {NodeOp(0)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string type() override { return "dotInt16"; }
|
const std::string type() override { return "dotInt16"; }
|
||||||
@ -92,7 +90,6 @@ public:
|
|||||||
|
|
||||||
NodeOps backwardOps() override {
|
NodeOps backwardOps() override {
|
||||||
ABORT("Only used for inference");
|
ABORT("Only used for inference");
|
||||||
return {NodeOp(0)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string type() override { return "affineInt16"; }
|
const std::string type() override { return "affineInt16"; }
|
||||||
|
@ -1,25 +1,26 @@
|
|||||||
# Unit tests
|
# Unit tests
|
||||||
add_subdirectory(units)
|
add_subdirectory(units)
|
||||||
|
|
||||||
|
if(NOT MSVC)
|
||||||
|
# Testing apps
|
||||||
|
set(APP_TESTS
|
||||||
|
logger
|
||||||
|
dropout
|
||||||
|
sqlite
|
||||||
|
prod
|
||||||
|
cli
|
||||||
|
pooling
|
||||||
|
)
|
||||||
|
|
||||||
# Testing apps
|
foreach(test ${APP_TESTS})
|
||||||
set(APP_TESTS
|
add_executable("test_${test}" "${test}.cpp")
|
||||||
logger
|
|
||||||
dropout
|
|
||||||
sqlite
|
|
||||||
prod
|
|
||||||
cli
|
|
||||||
pooling
|
|
||||||
)
|
|
||||||
|
|
||||||
foreach(test ${APP_TESTS})
|
if(CUDA_FOUND)
|
||||||
add_executable("test_${test}" "${test}.cpp")
|
target_link_libraries("test_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS})
|
||||||
|
else(CUDA_FOUND)
|
||||||
|
target_link_libraries("test_${test}" marian ${EXT_LIBS})
|
||||||
|
endif(CUDA_FOUND)
|
||||||
|
|
||||||
if(CUDA_FOUND)
|
set_target_properties("test_${test}" PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||||
target_link_libraries("test_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS})
|
endforeach(test)
|
||||||
else(CUDA_FOUND)
|
endif(NOT MSVC)
|
||||||
target_link_libraries("test_${test}" marian ${EXT_LIBS})
|
|
||||||
endif(CUDA_FOUND)
|
|
||||||
|
|
||||||
set_target_properties("test_${test}" PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
|
||||||
endforeach(test)
|
|
@ -17,5 +17,10 @@ foreach(test ${UNIT_TESTS})
|
|||||||
target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch)
|
target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch)
|
||||||
endif(CUDA_FOUND)
|
endif(CUDA_FOUND)
|
||||||
|
|
||||||
|
if(MSVC)
|
||||||
|
# Disable C4305: truncation from 'double' to '_Ty'
|
||||||
|
target_compile_options("run_${test}" PUBLIC /wd4305)
|
||||||
|
endif(MSVC)
|
||||||
|
|
||||||
add_test(NAME ${test} COMMAND "run_${test}")
|
add_test(NAME ${test} COMMAND "run_${test}")
|
||||||
endforeach(test)
|
endforeach(test)
|
||||||
|
@ -5,8 +5,6 @@
|
|||||||
using namespace marian;
|
using namespace marian;
|
||||||
|
|
||||||
TEST_CASE("FastOpt can be constructed from a YAML node", "[fastopt]") {
|
TEST_CASE("FastOpt can be constructed from a YAML node", "[fastopt]") {
|
||||||
YAML::Node node;
|
|
||||||
|
|
||||||
SECTION("from a simple node") {
|
SECTION("from a simple node") {
|
||||||
YAML::Node node = YAML::Load("{foo: bar}");
|
YAML::Node node = YAML::Load("{foo: bar}");
|
||||||
const FastOpt o(node);
|
const FastOpt o(node);
|
||||||
|
@ -40,12 +40,12 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
std::vector<T> vA({1, -2, 3, -4});
|
std::vector<T> vA({1, -2, 3, -4});
|
||||||
auto a = graph->constant({2, 2, 1}, inits::fromVector(vA));
|
auto a = graph->constant({2, 2, 1}, inits::fromVector(vA));
|
||||||
|
|
||||||
auto compare = [&](Expr res, std::function<float(float)> f, bool exactMatch) -> bool {
|
auto compare = [&](Expr res, std::function<float(float)> f) -> bool {
|
||||||
if (res->shape() != Shape({ 2, 2, 1 }))
|
if (res->shape() != Shape({ 2, 2, 1 }))
|
||||||
return false;
|
return false;
|
||||||
res->val()->get(values);
|
res->val()->get(values);
|
||||||
std::vector<float> ref{f(vA[0]), f(vA[1]), f(vA[2]), f(vA[3])};
|
std::vector<float> ref{f(vA[0]), f(vA[1]), f(vA[2]), f(vA[3])};
|
||||||
return std::equal(values.begin(), values.end(), ref.begin(), exactMatch ? floatEqual : floatApprox);
|
return std::equal(values.begin(), values.end(), ref.begin(), floatEqual);
|
||||||
};
|
};
|
||||||
|
|
||||||
// @TODO: add all operators and scalar variants here for completeness
|
// @TODO: add all operators and scalar variants here for completeness
|
||||||
@ -55,15 +55,15 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
auto rmax2 = maximum(1, a);
|
auto rmax2 = maximum(1, a);
|
||||||
auto rmin1 = minimum(a, 1);
|
auto rmin1 = minimum(a, 1);
|
||||||
auto rmin2 = minimum(1, a);
|
auto rmin2 = minimum(1, a);
|
||||||
|
|
||||||
graph->forward();
|
graph->forward();
|
||||||
|
|
||||||
CHECK(compare(rsmult, [](float a) {return 2.f * a;}, true));
|
CHECK(compare(rsmult, [](float a) {return 2.f * a;}));
|
||||||
CHECK(compare(rabs, [](float a) {return std::abs(a);}, true));
|
CHECK(compare(rabs, [](float a) {return std::abs(a);}));
|
||||||
CHECK(compare(rmax1, [](float a) {return std::max(a, 1.f);}, true));
|
CHECK(compare(rmax1, [](float a) {return std::max(a, 1.f);}));
|
||||||
CHECK(compare(rmax2, [](float a) {return std::max(1.f, a);}, true));
|
CHECK(compare(rmax2, [](float a) {return std::max(1.f, a);}));
|
||||||
CHECK(compare(rmin1, [](float a) {return std::min(a, 1.f);}, true));
|
CHECK(compare(rmin1, [](float a) {return std::min(a, 1.f);}));
|
||||||
CHECK(compare(rmin2, [](float a) {return std::min(1.f, a);}, true));
|
CHECK(compare(rmin2, [](float a) {return std::min(1.f, a);}));
|
||||||
}
|
}
|
||||||
|
|
||||||
SECTION("elementwise binary operators with broadcasting") {
|
SECTION("elementwise binary operators with broadcasting") {
|
||||||
@ -76,12 +76,23 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
auto a = graph->constant({2, 2, 1}, inits::fromVector(vA));
|
auto a = graph->constant({2, 2, 1}, inits::fromVector(vA));
|
||||||
auto b = graph->constant({2, 1}, inits::fromVector(vB));
|
auto b = graph->constant({2, 1}, inits::fromVector(vB));
|
||||||
|
|
||||||
auto compare = [&](Expr res, std::function<float(float,float)> f, bool exactMatch) -> bool {
|
// Two lambdas below differ in the use of floatEqual or floatApprox and
|
||||||
|
// are not merged because MSVC compiler returns C2446: no conversion from
|
||||||
|
// lambda_x to lambda_y
|
||||||
|
auto compare = [&](Expr res, std::function<float(float,float)> f) -> bool {
|
||||||
if (res->shape() != Shape({ 2, 2, 1 }))
|
if (res->shape() != Shape({ 2, 2, 1 }))
|
||||||
return false;
|
return false;
|
||||||
res->val()->get(values);
|
res->val()->get(values);
|
||||||
std::vector<float> ref{f(vA[0], vB[0]), f(vA[1], vB[1]), f(vA[2], vB[0]), f(vA[3], vB[1])};
|
std::vector<float> ref{f(vA[0], vB[0]), f(vA[1], vB[1]), f(vA[2], vB[0]), f(vA[3], vB[1])};
|
||||||
return std::equal(values.begin(), values.end(), ref.begin(), exactMatch ? floatEqual : floatApprox);
|
return std::equal(values.begin(), values.end(), ref.begin(), floatEqual);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto compareApprox = [&](Expr res, std::function<float(float, float)> f) -> bool {
|
||||||
|
if(res->shape() != Shape({2, 2, 1}))
|
||||||
|
return false;
|
||||||
|
res->val()->get(values);
|
||||||
|
std::vector<float> ref{f(vA[0], vB[0]), f(vA[1], vB[1]), f(vA[2], vB[0]), f(vA[3], vB[1])};
|
||||||
|
return std::equal(values.begin(), values.end(), ref.begin(), floatApprox);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto rplus = a + b;
|
auto rplus = a + b;
|
||||||
@ -100,19 +111,19 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
|
|
||||||
graph->forward();
|
graph->forward();
|
||||||
|
|
||||||
CHECK(compare(rplus, [](float a, float b) {return a + b;}, true));
|
CHECK(compare(rplus, [](float a, float b) {return a + b;}));
|
||||||
CHECK(compare(rminus, [](float a, float b) {return a - b;}, true));
|
CHECK(compare(rminus, [](float a, float b) {return a - b;}));
|
||||||
CHECK(compare(rmult, [](float a, float b) {return a * b;}, true));
|
CHECK(compare(rmult, [](float a, float b) {return a * b;}));
|
||||||
CHECK(compare(rdiv, [](float a, float b) {return a / b;}, false));
|
CHECK(compareApprox(rdiv, [](float a, float b) {return a / b;}));
|
||||||
CHECK(compare(rlae, [](float a, float b) {return logf(expf(a) + expf(b));}, false));
|
CHECK(compareApprox(rlae, [](float a, float b) {return logf(expf(a) + expf(b));}));
|
||||||
CHECK(compare(rmax, [](float a, float b) {return std::max(a, b);}, true));
|
CHECK(compare(rmax, [](float a, float b) {return std::max(a, b);}));
|
||||||
CHECK(compare(rmin, [](float a, float b) {return std::min(a, b);}, true));
|
CHECK(compare(rmin, [](float a, float b) {return std::min(a, b);}));
|
||||||
CHECK(compare(rlt, [](float a, float b) {return a < b;}, true));
|
CHECK(compare(rlt, [](float a, float b) {return a < b;}));
|
||||||
CHECK(compare(req, [](float a, float b) {return a == b;}, true));
|
CHECK(compare(req, [](float a, float b) {return a == b;}));
|
||||||
CHECK(compare(rgt, [](float a, float b) {return a > b;}, true));
|
CHECK(compare(rgt, [](float a, float b) {return a > b;}));
|
||||||
CHECK(compare(rge, [](float a, float b) {return a >= b;}, true));
|
CHECK(compare(rge, [](float a, float b) {return a >= b;}));
|
||||||
CHECK(compare(rne, [](float a, float b) {return a != b;}, true));
|
CHECK(compare(rne, [](float a, float b) {return a != b;}));
|
||||||
CHECK(compare(rle, [](float a, float b) {return a <= b;}, true));
|
CHECK(compare(rle, [](float a, float b) {return a <= b;}));
|
||||||
}
|
}
|
||||||
|
|
||||||
SECTION("transposing and reshaping") {
|
SECTION("transposing and reshaping") {
|
||||||
@ -399,8 +410,8 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
std::vector<float> SV; // create CSR version of S
|
std::vector<float> SV; // create CSR version of S
|
||||||
std::vector<IndexType> SI, SO;
|
std::vector<IndexType> SI, SO;
|
||||||
SO.push_back((IndexType)SI.size());
|
SO.push_back((IndexType)SI.size());
|
||||||
for (IndexType i = 0; i < S->shape()[0]; i++) {
|
for (IndexType i = 0; i < (IndexType)S->shape()[0]; i++) {
|
||||||
for (IndexType j = 0; j < S->shape()[1]; j++) {
|
for (IndexType j = 0; j < (IndexType)S->shape()[1]; j++) {
|
||||||
auto k = 4 * i + j;
|
auto k = 4 * i + j;
|
||||||
if (vS[k] != 0) {
|
if (vS[k] != 0) {
|
||||||
SV.push_back(vS[k]);
|
SV.push_back(vS[k]);
|
||||||
@ -477,7 +488,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
aff1->val()->get(values);
|
aff1->val()->get(values);
|
||||||
CHECK(values == vAff);
|
CHECK(values == vAff);
|
||||||
|
|
||||||
std::vector<T> values2;
|
values2.clear();
|
||||||
CHECK(aff2->shape() == aff1->shape());
|
CHECK(aff2->shape() == aff1->shape());
|
||||||
aff2->val()->get(values2);
|
aff2->val()->get(values2);
|
||||||
CHECK(values2 == values);
|
CHECK(values2 == values);
|
||||||
@ -653,7 +664,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
SECTION("relation of rows and columns selection using transpose") {
|
SECTION("relation of rows and columns selection using transpose") {
|
||||||
graph->clear();
|
graph->clear();
|
||||||
values.clear();
|
values.clear();
|
||||||
std::vector<T> values2;
|
values2.clear();
|
||||||
|
|
||||||
std::vector<T> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
|
std::vector<T> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
|
||||||
std::vector<IndexType> idx({0, 1});
|
std::vector<IndexType> idx({0, 1});
|
||||||
@ -763,7 +774,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
SECTION("rows/cols as gather operations") {
|
SECTION("rows/cols as gather operations") {
|
||||||
graph->clear();
|
graph->clear();
|
||||||
values.clear();
|
values.clear();
|
||||||
std::vector<T> values2;
|
values2.clear();
|
||||||
|
|
||||||
|
|
||||||
std::vector<T> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
|
std::vector<T> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
|
||||||
@ -791,14 +802,14 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
SECTION("topk operations") {
|
SECTION("topk operations") {
|
||||||
graph->clear();
|
graph->clear();
|
||||||
values.clear();
|
values.clear();
|
||||||
|
|
||||||
std::vector<T> vA({ 0, .3333, -.2,
|
std::vector<T> vA({ 0, .3333, -.2,
|
||||||
-.3, 0, 4.5,
|
-.3, 0, 4.5,
|
||||||
5.2, -10, 101.45,
|
5.2, -10, 101.45,
|
||||||
-100.05, 0, 1.05e-5});
|
-100.05, 0, 1.05e-5});
|
||||||
|
|
||||||
auto a = graph->constant({2, 2, 3}, inits::fromVector(vA));
|
auto a = graph->constant({2, 2, 3}, inits::fromVector(vA));
|
||||||
|
|
||||||
// get top-k indices and values as a tuple
|
// get top-k indices and values as a tuple
|
||||||
auto rtopk1 = topk(a, /*k=*/2, /*axis=*/-1, /*descending=*/true);
|
auto rtopk1 = topk(a, /*k=*/2, /*axis=*/-1, /*descending=*/true);
|
||||||
auto rval1 = get<0>(rtopk1); // values from top-k
|
auto rval1 = get<0>(rtopk1); // values from top-k
|
||||||
@ -807,13 +818,13 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
|
|
||||||
auto ridx2 = get<1>(topk(a, /*k=*/2, /*axis=*/-1, /*descending=*/false));
|
auto ridx2 = get<1>(topk(a, /*k=*/2, /*axis=*/-1, /*descending=*/false));
|
||||||
auto gval2 = gather(a, -1, ridx2); // get the same values via gather and indices
|
auto gval2 = gather(a, -1, ridx2); // get the same values via gather and indices
|
||||||
|
|
||||||
auto ridx3 = get<1>(argmin(a, -1));
|
auto ridx3 = get<1>(argmin(a, -1));
|
||||||
auto ridx3_ = slice(ridx2, -1, 0); // slice and cast now support uint32_t/IndexType
|
auto ridx3_ = slice(ridx2, -1, 0); // slice and cast now support uint32_t/IndexType
|
||||||
|
|
||||||
// @TODO: add integer types to more operators
|
// @TODO: add integer types to more operators
|
||||||
auto eq3 = eq(cast(ridx3, floatType), cast(ridx3_, floatType));
|
auto eq3 = eq(cast(ridx3, floatType), cast(ridx3_, floatType));
|
||||||
|
|
||||||
auto rtopk4 = argmax(a, /*axis=*/-2); // axes other than -1 are currently implemented via inefficient transpose
|
auto rtopk4 = argmax(a, /*axis=*/-2); // axes other than -1 are currently implemented via inefficient transpose
|
||||||
auto rval4 = get<0>(rtopk4);
|
auto rval4 = get<0>(rtopk4);
|
||||||
auto ridx4 = get<1>(rtopk4);
|
auto ridx4 = get<1>(rtopk4);
|
||||||
@ -824,12 +835,12 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
CHECK(rval1 != gval1);
|
CHECK(rval1 != gval1);
|
||||||
CHECK(rval1->shape() == gval1->shape());
|
CHECK(rval1->shape() == gval1->shape());
|
||||||
CHECK(ridx1->shape() == gval1->shape());
|
CHECK(ridx1->shape() == gval1->shape());
|
||||||
|
|
||||||
std::vector<T> vval1 = { 0.3333, 0,
|
std::vector<T> vval1 = { 0.3333, 0,
|
||||||
4.5, 0,
|
4.5, 0,
|
||||||
101.45, 5.2,
|
101.45, 5.2,
|
||||||
1.05e-5, 0 };
|
1.05e-5, 0 };
|
||||||
|
|
||||||
std::vector<T> rvalues;
|
std::vector<T> rvalues;
|
||||||
std::vector<T> gvalues;
|
std::vector<T> gvalues;
|
||||||
rval1->val()->get(rvalues);
|
rval1->val()->get(rvalues);
|
||||||
@ -837,9 +848,9 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
CHECK( rvalues == gvalues );
|
CHECK( rvalues == gvalues );
|
||||||
CHECK( rvalues == vval1 );
|
CHECK( rvalues == vval1 );
|
||||||
|
|
||||||
std::vector<T> vval2 = { -0.2, 0,
|
std::vector<T> vval2 = { -0.2, 0,
|
||||||
-0.3, 0,
|
-0.3, 0,
|
||||||
-10.0, 5.2,
|
-10.0, 5.2,
|
||||||
-100.05, 0 };
|
-100.05, 0 };
|
||||||
gval2->val()->get(values);
|
gval2->val()->get(values);
|
||||||
CHECK( values == vval2 );
|
CHECK( values == vval2 );
|
||||||
@ -850,10 +861,10 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
|||||||
std::vector<IndexType> vidx4;
|
std::vector<IndexType> vidx4;
|
||||||
ridx4->val()->get(vidx4);
|
ridx4->val()->get(vidx4);
|
||||||
CHECK( ridx4->shape() == Shape({2, 1, 3}) );
|
CHECK( ridx4->shape() == Shape({2, 1, 3}) );
|
||||||
CHECK( vidx4 == std::vector<IndexType>({0, 0, 1,
|
CHECK( vidx4 == std::vector<IndexType>({0, 0, 1,
|
||||||
0, 1, 0}) );
|
0, 1, 0}) );
|
||||||
|
|
||||||
std::vector<T> vval4 = { 0, 0.3333, 4.5,
|
std::vector<T> vval4 = { 0, 0.3333, 4.5,
|
||||||
5.2, 0, 101.45 };
|
5.2, 0, 101.45 };
|
||||||
rval4->val()->get(values);
|
rval4->val()->get(values);
|
||||||
CHECK( values == vval4 );
|
CHECK( values == vval4 );
|
||||||
@ -886,7 +897,7 @@ TEST_CASE("Expression graph supports basic math operations (cpu)", "[operator]")
|
|||||||
|
|
||||||
TEST_CASE("Compare aggregate operator", "[graph]") {
|
TEST_CASE("Compare aggregate operator", "[graph]") {
|
||||||
auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).margin(0.001f); };
|
auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).margin(0.001f); };
|
||||||
|
|
||||||
Config::seed = 1234;
|
Config::seed = 1234;
|
||||||
|
|
||||||
std::vector<float> initc;
|
std::vector<float> initc;
|
||||||
@ -908,7 +919,7 @@ TEST_CASE("Compare aggregate operator", "[graph]") {
|
|||||||
SECTION("initializing with zero (cpu)") {
|
SECTION("initializing with zero (cpu)") {
|
||||||
std::vector<float> values1;
|
std::vector<float> values1;
|
||||||
std::vector<float> values2;
|
std::vector<float> values2;
|
||||||
|
|
||||||
auto graph1 = New<ExpressionGraph>();
|
auto graph1 = New<ExpressionGraph>();
|
||||||
graph1->setDevice({0, DeviceType::cpu});
|
graph1->setDevice({0, DeviceType::cpu});
|
||||||
graph1->reserveWorkspaceMB(40);
|
graph1->reserveWorkspaceMB(40);
|
||||||
@ -916,7 +927,7 @@ TEST_CASE("Compare aggregate operator", "[graph]") {
|
|||||||
auto graph2 = New<ExpressionGraph>();
|
auto graph2 = New<ExpressionGraph>();
|
||||||
graph2->setDevice({0, DeviceType::gpu});
|
graph2->setDevice({0, DeviceType::gpu});
|
||||||
graph2->reserveWorkspaceMB(40);
|
graph2->reserveWorkspaceMB(40);
|
||||||
|
|
||||||
auto chl1 = graph1->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc));
|
auto chl1 = graph1->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc));
|
||||||
auto adj1 = graph1->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita));
|
auto adj1 = graph1->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita));
|
||||||
auto prod1 = scalar_product(chl1, adj1, -1);
|
auto prod1 = scalar_product(chl1, adj1, -1);
|
||||||
@ -935,4 +946,4 @@ TEST_CASE("Compare aggregate operator", "[graph]") {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
@ -181,7 +181,7 @@ void tests(DeviceType type, Type floatType = Type::float32) {
|
|||||||
|
|
||||||
auto context = concatenate({rnnFw.construct(graph)->transduce(input, mask),
|
auto context = concatenate({rnnFw.construct(graph)->transduce(input, mask),
|
||||||
rnnBw.construct(graph)->transduce(input, mask)},
|
rnnBw.construct(graph)->transduce(input, mask)},
|
||||||
/*axis =*/ input->shape().size() - 1);
|
/*axis =*/ (int)input->shape().size() - 1);
|
||||||
|
|
||||||
if(second > 0) {
|
if(second > 0) {
|
||||||
// add more layers (unidirectional) by transducing the output of the
|
// add more layers (unidirectional) by transducing the output of the
|
||||||
|
@ -79,13 +79,14 @@ void OutputCollector::Write(long sourceId,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StringCollector::StringCollector() : maxId_(-1) {}
|
StringCollector::StringCollector(bool quiet /*=false*/) : maxId_(-1), quiet_(quiet) {}
|
||||||
|
|
||||||
void StringCollector::add(long sourceId,
|
void StringCollector::add(long sourceId,
|
||||||
const std::string& best1,
|
const std::string& best1,
|
||||||
const std::string& bestn) {
|
const std::string& bestn) {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
LOG(info, "Best translation {} : {}", sourceId, best1);
|
if(!quiet_)
|
||||||
|
LOG(info, "Best translation {} : {}", sourceId, best1);
|
||||||
outputs_[sourceId] = std::make_pair(best1, bestn);
|
outputs_[sourceId] = std::make_pair(best1, bestn);
|
||||||
if(maxId_ <= sourceId)
|
if(maxId_ <= sourceId)
|
||||||
maxId_ = sourceId;
|
maxId_ = sourceId;
|
||||||
|
@ -74,14 +74,15 @@ protected:
|
|||||||
|
|
||||||
class StringCollector {
|
class StringCollector {
|
||||||
public:
|
public:
|
||||||
StringCollector();
|
StringCollector(bool quiet = false);
|
||||||
StringCollector(const StringCollector&) = delete;
|
StringCollector(const StringCollector&) = delete;
|
||||||
|
|
||||||
void add(long sourceId, const std::string& best1, const std::string& bestn);
|
void add(long sourceId, const std::string& best1, const std::string& bestn);
|
||||||
std::vector<std::string> collect(bool nbest);
|
std::vector<std::string> collect(bool nbest);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
long maxId_;
|
long maxId_; // the largest index of the translated source sentences
|
||||||
|
bool quiet_; // if true do not log best translations
|
||||||
std::mutex mutex_;
|
std::mutex mutex_;
|
||||||
|
|
||||||
typedef std::map<long, std::pair<std::string, std::string>> Outputs;
|
typedef std::map<long, std::pair<std::string, std::string>> Outputs;
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "data/batch_generator.h"
|
#include "data/batch_generator.h"
|
||||||
#include "data/corpus.h"
|
#include "data/corpus.h"
|
||||||
#include "data/shortlist.h"
|
#include "data/shortlist.h"
|
||||||
@ -245,10 +247,14 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string run(const std::string& input) override {
|
std::string run(const std::string& input) override {
|
||||||
auto corpus_ = New<data::TextInput>(std::vector<std::string>({input}), srcVocabs_, options_);
|
// split tab-separated input into fields if necessary
|
||||||
|
auto inputs = options_->get<bool>("tsv", false)
|
||||||
|
? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
|
||||||
|
: std::vector<std::string>({input});
|
||||||
|
auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
|
||||||
data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_);
|
data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_);
|
||||||
|
|
||||||
auto collector = New<StringCollector>();
|
auto collector = New<StringCollector>(options_->get<bool>("quiet-translation", false));
|
||||||
auto printer = New<OutputPrinter>(options_, trgVocab_);
|
auto printer = New<OutputPrinter>(options_, trgVocab_);
|
||||||
size_t batchId = 0;
|
size_t batchId = 0;
|
||||||
|
|
||||||
@ -258,7 +264,6 @@ public:
|
|||||||
ThreadPool threadPool_(numDevices_, numDevices_);
|
ThreadPool threadPool_(numDevices_, numDevices_);
|
||||||
|
|
||||||
for(auto batch : batchGenerator) {
|
for(auto batch : batchGenerator) {
|
||||||
|
|
||||||
auto task = [=](size_t id) {
|
auto task = [=](size_t id) {
|
||||||
thread_local Ptr<ExpressionGraph> graph;
|
thread_local Ptr<ExpressionGraph> graph;
|
||||||
thread_local std::vector<Ptr<Scorer>> scorers;
|
thread_local std::vector<Ptr<Scorer>> scorers;
|
||||||
@ -287,5 +292,30 @@ public:
|
|||||||
auto translations = collector->collect(options_->get<bool>("n-best"));
|
auto translations = collector->collect(options_->get<bool>("n-best"));
|
||||||
return utils::join(translations, "\n");
|
return utils::join(translations, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Converts a multi-line input with tab-separated source(s) and target sentences into separate lists
|
||||||
|
// of sentences from source(s) and target sides, e.g.
|
||||||
|
// "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"]
|
||||||
|
std::vector<std::string> convertTsvToLists(const std::string& inputText, size_t numFields) {
|
||||||
|
std::vector<std::string> outputFields(numFields);
|
||||||
|
|
||||||
|
std::string line;
|
||||||
|
std::vector<std::string> lineFields(numFields);
|
||||||
|
std::istringstream inputStream(inputText);
|
||||||
|
bool first = true;
|
||||||
|
while(std::getline(inputStream, line)) {
|
||||||
|
utils::splitTsv(line, lineFields, numFields);
|
||||||
|
for(size_t i = 0; i < numFields; ++i) {
|
||||||
|
if(!first)
|
||||||
|
outputFields[i] += "\n"; // join sentences with a new line sign
|
||||||
|
outputFields[i] += lineFields[i];
|
||||||
|
}
|
||||||
|
if(first)
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputFields;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace marian
|
} // namespace marian
|
||||||
|
Loading…
Reference in New Issue
Block a user