mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 14402: Sync with public marian-dev master 1.9.31
This simply pulls the recent updates from the public repo.
This commit is contained in:
parent
9001477147
commit
080d75ad59
51
.github/workflows/build-macos-10.15-cpu.yml
vendored
Normal file
51
.github/workflows/build-macos-10.15-cpu.yml
vendored
Normal file
@ -0,0 +1,51 @@
|
||||
name: macos-10.5-cpu
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master ]
|
||||
pull_request:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: macos-10.15
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install dependencies
|
||||
run: brew install openblas protobuf
|
||||
|
||||
# Openblas location is exported explicitly because openblas is keg-only,
|
||||
# which means it was not symlinked into /usr/local/.
|
||||
# CMake cannot find BLAS on GitHub runners if Marian is being compiled
|
||||
# statically, hence USE_STATIC_LIBS=off
|
||||
- name: Configure CMake
|
||||
run: |
|
||||
export LDFLAGS="-L/usr/local/opt/openblas/lib"
|
||||
export CPPFLAGS="-I/usr/local/opt/openblas/include"
|
||||
mkdir -p build
|
||||
cd build
|
||||
cmake .. -DCOMPILE_CPU=on -DCOMPILE_CUDA=off -DCOMPILE_EXAMPLES=on -DCOMPILE_SERVER=on -DCOMPILE_TESTS=on \
|
||||
-DUSE_FBGEMM=on -DUSE_SENTENCEPIECE=on -DUSE_STATIC_LIBS=off
|
||||
|
||||
- name: Compile
|
||||
working-directory: build
|
||||
run: make -j2
|
||||
|
||||
- name: Run unit tests
|
||||
working-directory: build
|
||||
run: make test
|
||||
|
||||
- name: Print versions
|
||||
working-directory: build
|
||||
run: |
|
||||
./marian --version
|
||||
./marian-decoder --version
|
||||
./marian-scorer --version
|
||||
./spm_encode --version
|
||||
|
64
.github/workflows/build-ubuntu-18.04-cpu.yml
vendored
Normal file
64
.github/workflows/build-ubuntu-18.04-cpu.yml
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
name: ubuntu-18.04-cpu
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master ]
|
||||
pull_request:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-18.04
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
# The following packages are already installed on GitHub-hosted runners: build-essential openssl libssl-dev
|
||||
- name: Install dependencies
|
||||
run: sudo apt-get install --no-install-recommends libgoogle-perftools-dev libprotobuf10 libprotobuf-dev protobuf-compiler
|
||||
|
||||
# https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html
|
||||
- name: Install MKL
|
||||
run: |
|
||||
wget -qO- "https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB" | sudo apt-key add -
|
||||
sudo sh -c "echo deb https://apt.repos.intel.com/mkl all main > /etc/apt/sources.list.d/intel-mkl.list"
|
||||
sudo apt-get update -o Dir::Etc::sourcelist="/etc/apt/sources.list.d/intel-mkl.list"
|
||||
sudo apt-get install --no-install-recommends intel-mkl-64bit-2020.0-088
|
||||
|
||||
- name: Print Boost paths
|
||||
run: |
|
||||
ls $BOOST_ROOT_1_69_0
|
||||
ls $BOOST_ROOT_1_69_0/include
|
||||
ls $BOOST_ROOT_1_69_0/lib
|
||||
|
||||
# Boost is already installed on GitHub-hosted runners in a non-standard location
|
||||
# https://github.com/actions/virtual-environments/issues/687#issuecomment-610471671
|
||||
- name: Configure CMake
|
||||
run: |
|
||||
mkdir -p build
|
||||
cd build
|
||||
cmake .. -DCOMPILE_CPU=on -DCOMPILE_CUDA=off -DCOMPILE_EXAMPLES=on -DCOMPILE_SERVER=on -DCOMPILE_TESTS=on \
|
||||
-DUSE_FBGEMM=on -DUSE_SENTENCEPIECE=on \
|
||||
-DBOOST_ROOT=$BOOST_ROOT_1_69_0 -DBOOST_INCLUDEDIR=$BOOST_ROOT_1_69_0/include -DBOOST_LIBRARYDIR=$BOOST_ROOT_1_69_0/lib \
|
||||
-DBoost_ARCHITECTURE=-x64
|
||||
|
||||
- name: Compile
|
||||
working-directory: build
|
||||
run: make -j2
|
||||
|
||||
- name: Run unit tests
|
||||
working-directory: build
|
||||
run: make test
|
||||
|
||||
- name: Print versions
|
||||
working-directory: build
|
||||
run: |
|
||||
./marian --version
|
||||
./marian-decoder --version
|
||||
./marian-scorer --version
|
||||
./spm_encode --version
|
||||
|
49
.github/workflows/build-windows-2019-cpu.yml
vendored
Normal file
49
.github/workflows/build-windows-2019-cpu.yml
vendored
Normal file
@ -0,0 +1,49 @@
|
||||
name: windows-2019-cpu
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master ]
|
||||
pull_request:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: windows-2019
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Prepare vcpkg
|
||||
uses: lukka/run-vcpkg@v3
|
||||
with:
|
||||
vcpkgArguments: protobuf
|
||||
vcpkgGitCommitId: 6185aa76504a5025f36754324abf307cc776f3da
|
||||
vcpkgDirectory: ${{ github.workspace }}/vcpkg/
|
||||
vcpkgTriplet: x64-windows-static
|
||||
|
||||
# Note that we build with a simplified CMake settings JSON file
|
||||
- name: Run CMake
|
||||
uses: lukka/run-cmake@v2
|
||||
with:
|
||||
buildDirectory: ${{ github.workspace }}/build/
|
||||
cmakeAppendedArgs: -G Ninja
|
||||
cmakeListsOrSettingsJson: CMakeSettingsJson
|
||||
cmakeSettingsJsonPath: ${{ github.workspace }}/CMakeSettingsCI.json
|
||||
useVcpkgToolchainFile: true
|
||||
|
||||
- name: Run unit tests
|
||||
working-directory: build/Debug/
|
||||
run: ctest
|
||||
|
||||
- name: Print versions
|
||||
working-directory: build/Debug/
|
||||
run: |
|
||||
.\marian.exe --version
|
||||
.\marian-decoder.exe --version
|
||||
.\marian-scorer.exe --version
|
||||
.\spm_encode.exe --version
|
||||
|
10
CHANGELOG.md
10
CHANGELOG.md
@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Decoding multi-source models in marian-server with --tsv
|
||||
- GitHub workflows on Ubuntu, Windows, and MacOS
|
||||
- LSH indexing to replace short list
|
||||
- ONNX support for transformer models
|
||||
- Add topk operator like PyTorch's topk
|
||||
@ -19,11 +21,19 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
and translation with options --tsv and --tsv-fields n.
|
||||
|
||||
### Fixed
|
||||
- Fix compilation without BLAS installed
|
||||
- Providing a single value to vector-like options using the equals sign, e.g. --models=model.npz
|
||||
- Fix quiet-translation in marian-server
|
||||
- CMake-based compilation on Windows
|
||||
- Fix minor issues with compilation on MacOS
|
||||
- Fix warnings in Windows MSVC builds using CMake
|
||||
- Fix building server with Boost 1.72
|
||||
- Make mini-batch scaling depend on mini-batch-words and not on mini-batch-words-ref
|
||||
- In concatenation make sure that we do not multiply 0 with nan (which results in nan)
|
||||
- Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now
|
||||
absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance.
|
||||
- Fixed bug in finding .git/logs/HEAD when Marian is a submodule in another project.
|
||||
- Properly record cmake variables in the cmake build directory instead of the source tree.
|
||||
|
||||
### Changed
|
||||
- Move Simple-WebSocket-Server to submodule
|
||||
|
@ -51,6 +51,8 @@ message(STATUS "Project version: ${PROJECT_VERSION_STRING_FULL}")
|
||||
execute_process(COMMAND git submodule update --init --recursive --no-fetch
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
# Note that with CMake MSVC build, the option CMAKE_BUILD_TYPE is automatically derived from the key
|
||||
# 'configurationType' in CMakeSettings.json configurations
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
message(WARNING "CMAKE_BUILD_TYPE not set; setting to Release")
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
@ -62,10 +64,11 @@ if(MSVC)
|
||||
# These are used in src/CMakeLists.txt on a per-target basis
|
||||
list(APPEND ALL_WARNINGS /WX; /W4;)
|
||||
|
||||
# Disabled bogus warnings for CPU intrincics:
|
||||
# Disabled bogus warnings for CPU intrinsics:
|
||||
# C4310: cast truncates constant value
|
||||
# C4324: 'marian::cpu::int16::`anonymous-namespace'::ScatterPut': structure was padded due to alignment specifier
|
||||
set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\"")
|
||||
# C4702: unreachable code; note it is also disabled globally in the VS project file
|
||||
set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\"")
|
||||
|
||||
# set(INTRINSICS "/arch:AVX")
|
||||
add_definitions(-DUSE_SSE2=1)
|
||||
@ -79,7 +82,9 @@ if(MSVC)
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} /MTd /Od /Ob0 ${INTRINSICS} /RTC1 /Zi /D_DEBUG")
|
||||
|
||||
# ignores warning LNK4049: locally defined symbol free imported - this comes from zlib
|
||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /NODEFAULTLIB:MSVCRT /ignore:4049")
|
||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /ignore:4049")
|
||||
set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRT")
|
||||
set(CMAKE_EXE_LINKER_FLAGS_DEBUG "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRTD")
|
||||
set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} /LTCG:incremental")
|
||||
|
||||
find_library(SHLWAPI Shlwapi.lib)
|
||||
@ -218,7 +223,9 @@ if(COMPILE_CUDA)
|
||||
|
||||
if(USE_STATIC_LIBS)
|
||||
# link statically to stdlib libraries
|
||||
if(NOT MSVC)
|
||||
set(CMAKE_EXE_LINKER_FLAGS "-static-libgcc -static-libstdc++")
|
||||
endif()
|
||||
|
||||
# look for libraries that have .a suffix
|
||||
set(_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES})
|
||||
@ -250,12 +257,22 @@ if(CUDA_FOUND)
|
||||
endif(COMPILE_CUDA_SM70)
|
||||
|
||||
if(USE_STATIC_LIBS)
|
||||
find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||
set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||
set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||
find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
|
||||
# The cuLIBOS library does not seem to exist in Windows CUDA toolkit installs
|
||||
if(CUDA_culibos_LIBRARY)
|
||||
set(EXT_LIBS ${EXT_LIBS} ${CUDA_culibos_LIBRARY})
|
||||
set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_culibos_LIBRARY})
|
||||
elseif(NOT WIN32)
|
||||
message(FATAL_ERROR "cuLIBOS library not found")
|
||||
endif()
|
||||
# CUDA 10.1 introduces cublasLt library that is required on static build
|
||||
if ((CUDA_VERSION VERSION_EQUAL "10.1" OR CUDA_VERSION VERSION_GREATER "10.1"))
|
||||
find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||
find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
|
||||
if(NOT CUDA_cublasLt_LIBRARY)
|
||||
message(FATAL_ERROR "cuBLASLt library not found")
|
||||
endif()
|
||||
set(EXT_LIBS ${EXT_LIBS} ${CUDA_cublasLt_LIBRARY})
|
||||
set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_cublasLt_LIBRARY})
|
||||
endif()
|
||||
@ -319,7 +336,7 @@ if(NOT MSVC)
|
||||
list(APPEND CUDA_NVCC_FLAGS -ccbin ${CMAKE_C_COMPILER}; -std=c++11; -Xcompiler\ -fPIC; -Xcompiler\ -Wno-unused-result; -Xcompiler\ -Wno-deprecated; -Xcompiler\ -Wno-pragmas; -Xcompiler\ -Wno-unused-value; -Xcompiler\ -Werror;)
|
||||
list(APPEND CUDA_NVCC_FLAGS ${INTRINSICS_NVCC})
|
||||
else()
|
||||
list(APPEND CUDA_NVCC_FLAGS -Xcompiler\ /FS; )
|
||||
list(APPEND CUDA_NVCC_FLAGS -Xcompiler\ /FS; -Xcompiler\ /MT$<$<CONFIG:Debug>:d>; )
|
||||
endif()
|
||||
|
||||
list(REMOVE_DUPLICATES CUDA_NVCC_FLAGS)
|
||||
@ -443,8 +460,11 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/project_version.h.in
|
||||
# Generate build_info.cpp with CMake cache variables
|
||||
include(GetCacheVariables)
|
||||
|
||||
# make sure src/common/build_info.cpp has been removed
|
||||
execute_process(COMMAND rm ${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp
|
||||
OUTPUT_QUIET ERROR_QUIET)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp.in
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp @ONLY)
|
||||
${CMAKE_CURRENT_BINARY_DIR}/src/common/build_info.cpp @ONLY)
|
||||
|
||||
# Compile source files
|
||||
include_directories(${marian_SOURCE_DIR}/src)
|
||||
|
52
CMakeSettingsCI.json
Normal file
52
CMakeSettingsCI.json
Normal file
@ -0,0 +1,52 @@
|
||||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Release",
|
||||
"generator": "Ninja",
|
||||
"configurationType": "Release",
|
||||
"inheritEnvironments": [ "msvc_x64" ],
|
||||
"cmakeCommandArgs": "",
|
||||
"buildCommandArgs": "-v",
|
||||
"ctestCommandArgs": "",
|
||||
"variables": [
|
||||
{ "name": "OPENSSL_USE_STATIC_LIBS:BOOL", "value": "TRUE" },
|
||||
{ "name": "OPENSSL_MSVC_STATIC_RT:BOOL", "value": "TRUE" },
|
||||
|
||||
{ "name": "COMPILE_CUDA:BOOL", "value": "FALSE" },
|
||||
{ "name": "COMPILE_CPU:BOOL", "value": "TRUE" },
|
||||
{ "name": "COMPILE_EXAMPLES:BOOL", "value": "FALSE" },
|
||||
{ "name": "COMPILE_SERVER:BOOL", "value": "FALSE" },
|
||||
{ "name": "COMPILE_TESTS:BOOL", "value": "TRUE" },
|
||||
|
||||
{ "name": "USE_FBGEMM:BOOL", "value": "TRUE" },
|
||||
{ "name": "USE_MPI:BOOL", "value": "FALSE" },
|
||||
{ "name": "USE_SENTENCEPIECE:BOOL", "value": "TRUE" },
|
||||
{ "name": "USE_STATIC_LIBS:BOOL", "value": "TRUE" }
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Debug",
|
||||
"generator": "Ninja",
|
||||
"configurationType": "Debug",
|
||||
"inheritEnvironments": [ "msvc_x64" ],
|
||||
"cmakeCommandArgs": "",
|
||||
"buildCommandArgs": "-v",
|
||||
"ctestCommandArgs": "",
|
||||
"variables": [
|
||||
{ "name": "OPENSSL_MSVC_STATIC_RT:BOOL", "value": "TRUE" },
|
||||
{ "name": "OPENSSL_USE_STATIC_LIBS:BOOL", "value": "TRUE" },
|
||||
|
||||
{ "name": "COMPILE_CUDA:BOOL", "value": "FALSE" },
|
||||
{ "name": "COMPILE_CPU:BOOL", "value": "TRUE" },
|
||||
{ "name": "COMPILE_EXAMPLES:BOOL", "value": "FALSE" },
|
||||
{ "name": "COMPILE_SERVER:BOOL", "value": "FALSE" },
|
||||
{ "name": "COMPILE_TESTS:BOOL", "value": "TRUE" },
|
||||
|
||||
{ "name": "USE_FBGEMM:BOOL", "value": "TRUE" },
|
||||
{ "name": "USE_MPI:BOOL", "value": "FALSE" },
|
||||
{ "name": "USE_SENTENCEPIECE:BOOL", "value": "TRUE" },
|
||||
{ "name": "USE_STATIC_LIBS:BOOL", "value": "TRUE" }
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
@ -89,10 +89,17 @@ find_library(MKL_CORE_LIBRARY
|
||||
NO_DEFAULT_PATH)
|
||||
|
||||
set(MKL_INCLUDE_DIRS ${MKL_INCLUDE_DIR})
|
||||
set(MKL_LIBRARIES ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY})
|
||||
|
||||
if(NOT WIN32 AND NOT APPLE)
|
||||
# Added -Wl block to avoid circular dependencies.
|
||||
# https://stackoverflow.com/questions/5651869/what-are-the-start-group-and-end-group-command-line-options
|
||||
# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor
|
||||
set(MKL_LIBRARIES -Wl,--start-group ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY} -Wl,--end-group)
|
||||
set(MKL_LIBRARIES -Wl,--start-group ${MKL_LIBRARIES} -Wl,--end-group)
|
||||
elseif(APPLE)
|
||||
# MacOS does not support --start-group and --end-group
|
||||
set(MKL_LIBRARIES -Wl,${MKL_LIBRARIES} -Wl,)
|
||||
endif()
|
||||
|
||||
# message("1 ${MKL_INCLUDE_DIR}")
|
||||
# message("2 ${MKL_INTERFACE_LIBRARY}")
|
||||
|
@ -34,17 +34,18 @@ foreach(_variableName ${_variableNames})
|
||||
(NOT "${_variableType}" STREQUAL "INTERNAL") AND
|
||||
(NOT "${_variableValue}" STREQUAL "") )
|
||||
|
||||
|
||||
set(PROJECT_CMAKE_CACHE_ADVANCED "${PROJECT_CMAKE_CACHE_ADVANCED} \"${_variableName}=${_variableValue}\\n\"\n")
|
||||
string(REPLACE "\"" " " _variableValueEscapedQuotes ${_variableValue})
|
||||
string(REPLACE "\\" "/" _variableValueEscaped ${_variableValueEscapedQuotes})
|
||||
set(PROJECT_CMAKE_CACHE_ADVANCED "${PROJECT_CMAKE_CACHE_ADVANCED} \"${_variableName}=${_variableValueEscaped}\\n\"\n")
|
||||
|
||||
# Get the variable's advanced flag
|
||||
get_property(_isAdvanced CACHE ${_variableName} PROPERTY ADVANCED SET)
|
||||
if(NOT _isAdvanced)
|
||||
set(PROJECT_CMAKE_CACHE "${PROJECT_CMAKE_CACHE} \"${_variableName}=${_variableValue}\\n\"\n")
|
||||
set(PROJECT_CMAKE_CACHE "${PROJECT_CMAKE_CACHE} \"${_variableName}=${_variableValueEscaped}\\n\"\n")
|
||||
endif()
|
||||
|
||||
# Print variables for debugging
|
||||
#message(STATUS "${_variableName}=${${_variableName}}")
|
||||
#message(STATUS "${_variableName}=${_variableValueEscaped}")
|
||||
#message(STATUS " Type=${_variableType}")
|
||||
#message(STATUS " Advanced=${_isAdvanced}")
|
||||
endif()
|
||||
|
@ -18,7 +18,7 @@
|
||||
if(PROJECT_VERSION_FILE)
|
||||
file(STRINGS ${PROJECT_VERSION_FILE} PROJECT_VERSION_STRING)
|
||||
else()
|
||||
file(STRINGS ${CMAKE_SOURCE_DIR}/VERSION PROJECT_VERSION_STRING)
|
||||
file(STRINGS ${CMAKE_CURRENT_SOURCE_DIR}/VERSION PROJECT_VERSION_STRING)
|
||||
endif()
|
||||
|
||||
# Get current commit SHA from git
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit 0f8cabf13ec362d50544d33490024e00c3a763be
|
||||
Subproject commit 7b8f6ee5b6ff7779fd993df7f77adf1e2d9adbe5
|
3
src/3rd_party/CLI/App.hpp
vendored
3
src/3rd_party/CLI/App.hpp
vendored
@ -1595,7 +1595,8 @@ class App {
|
||||
if(num < 0) {
|
||||
// RG: We need to keep track if the vector option is empty and handle this separately as
|
||||
// otherwise the parser will mark the command-line option as not set
|
||||
bool emptyVectorArgs = true;
|
||||
// RG: An option value after '=' was already collected
|
||||
bool emptyVectorArgs = (collected <= 0);
|
||||
while(!args.empty() && _recognize(args.back()) == detail::Classifer::NONE) {
|
||||
if(collected >= -num) {
|
||||
// We could break here for allow extras, but we don't
|
||||
|
11
src/3rd_party/CMakeLists.txt
vendored
11
src/3rd_party/CMakeLists.txt
vendored
@ -19,6 +19,9 @@ if(USE_FBGEMM)
|
||||
# only locally disabled for the 3rd_party folder
|
||||
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-value -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused")
|
||||
else()
|
||||
# Do not compile cpuinfo executables due to a linker error, and they are not needed
|
||||
set(CPUINFO_BUILD_TOOLS OFF CACHE BOOL "Build command-line tools")
|
||||
endif()
|
||||
|
||||
set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "Disable fbgemm tests")
|
||||
@ -45,7 +48,7 @@ if(USE_SENTENCEPIECE)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(SPM_ENABLE_TCMALLOC ON CACHE BOOL "Enable TCMalloc if available." FORCE)
|
||||
set(SPM_ENABLE_TCMALLOC ON CACHE BOOL "Enable TCMalloc if available.")
|
||||
|
||||
if(USE_STATIC_LIBS)
|
||||
message(WARNING "You are compiling SentencePiece binaries with -DUSE_STATIC_LIBS=on. \
|
||||
@ -55,8 +58,8 @@ if(USE_SENTENCEPIECE)
|
||||
set(SPM_ENABLE_SHARED OFF CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
|
||||
set(SPM_TCMALLOC_STATIC ON CACHE BOOL "Link static library of TCMALLOC." FORCE)
|
||||
else(USE_STATIC_LIBS)
|
||||
set(SPM_ENABLE_SHARED ON CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
|
||||
set(SPM_TCMALLOC_STATIC OFF CACHE BOOL "Link static library of TCMALLOC." FORCE)
|
||||
set(SPM_ENABLE_SHARED ON CACHE BOOL "Builds shared libaries in addition to static libraries.")
|
||||
set(SPM_TCMALLOC_STATIC OFF CACHE BOOL "Link static library of TCMALLOC.")
|
||||
endif(USE_STATIC_LIBS)
|
||||
|
||||
add_subdirectory(./sentencepiece)
|
||||
@ -66,7 +69,7 @@ if(USE_SENTENCEPIECE)
|
||||
PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
foreach(t sentencepiece sentencepiece_train sentencepiece_train-static
|
||||
foreach(t sentencepiece-static sentencepiece_train-static
|
||||
spm_decode spm_encode spm_export_vocab spm_normalize spm_train)
|
||||
set_property(TARGET ${t} APPEND_STRING PROPERTY COMPILE_FLAGS " -Wno-tautological-compare -Wno-unused")
|
||||
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||
|
4
src/3rd_party/cnpy/cnpy.h
vendored
4
src/3rd_party/cnpy/cnpy.h
vendored
@ -18,6 +18,10 @@
|
||||
#include<map>
|
||||
#include <memory>
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
namespace cnpy {
|
||||
|
||||
struct NpyArray {
|
||||
|
3
src/3rd_party/faiss/VectorTransform.h
vendored
3
src/3rd_party/faiss/VectorTransform.h
vendored
@ -18,6 +18,9 @@
|
||||
#include <stdint.h>
|
||||
|
||||
#include <faiss/Index.h>
|
||||
#ifdef __APPLE__
|
||||
#include <x86intrin.h>
|
||||
#endif
|
||||
|
||||
|
||||
namespace faiss {
|
||||
|
3
src/3rd_party/phf/phf.cc
vendored
3
src/3rd_party/phf/phf.cc
vendored
@ -679,7 +679,8 @@ PHF_PUBLIC void PHF::compact(struct phf *phf) {
|
||||
}
|
||||
|
||||
/* simply keep old array if realloc fails */
|
||||
if ((tmp = realloc(phf->g, phf->r * size)))
|
||||
tmp = realloc(phf->g, phf->r * size);
|
||||
if (tmp != 0)
|
||||
phf->g = static_cast<uint32_t *>(tmp);
|
||||
} /* PHF::compact() */
|
||||
|
||||
|
4
src/3rd_party/zlib/CMakeLists.txt
vendored
4
src/3rd_party/zlib/CMakeLists.txt
vendored
@ -2,11 +2,11 @@
|
||||
file(GLOB ZLIB_SRC *.c)
|
||||
file(GLOB ZLIB_INC *.h)
|
||||
|
||||
# add sources of the wrapper as a "SQLiteCpp" static library
|
||||
# add sources of the wrapper as a "zlib" static library
|
||||
add_library(zlib OBJECT ${ZLIB_SRC} ${ZLIB_INC})
|
||||
|
||||
if(MSVC)
|
||||
target_compile_options(zlib PUBLIC /wd"4996" /wd"4267")
|
||||
target_compile_options(zlib PUBLIC /wd4996 /wd4267)
|
||||
else()
|
||||
target_compile_options(zlib PUBLIC -Wno-implicit-function-declaration)
|
||||
endif()
|
||||
|
@ -20,7 +20,7 @@ add_library(marian STATIC
|
||||
common/config_validator.cpp
|
||||
common/options.cpp
|
||||
common/binary.cpp
|
||||
common/build_info.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/common/build_info.cpp
|
||||
common/io.cpp
|
||||
common/filesystem.cpp
|
||||
common/file_stream.cpp
|
||||
@ -117,21 +117,22 @@ target_compile_options(marian PUBLIC ${ALL_WARNINGS})
|
||||
# Git updates .git/logs/HEAD file whenever you pull or commit something.
|
||||
|
||||
# If Marian is checked out as a submodule in another repository,
|
||||
# there's no .git directory in ${CMAKE_SOURCE_DIR}. Instead .git is a
|
||||
# file that specifies the relative path from ${CMAKE_SOURCE_DIR} to
|
||||
# ./git/modules/<MARIAN_ROOT_DIR> in the root of the repository that
|
||||
# contains Marian as a submodule. We set MARIAN_GIT_DIR to the appropriate
|
||||
# path, depending on whether ${CMAKE_SOURCE_DIR}/.git is a directory or file.
|
||||
if(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git) # not a submodule
|
||||
set(MARIAN_GIT_DIR ${CMAKE_SOURCE_DIR}/.git)
|
||||
else(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
|
||||
file(READ ${CMAKE_SOURCE_DIR}/.git MARIAN_GIT_DIR)
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/../.git is not a directory but a file
|
||||
# that specifies the relative path from ${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
# to ./git/modules/<MARIAN_ROOT_DIR> in the root of the check_out of
|
||||
# the project that contains Marian as a submodule.
|
||||
#
|
||||
# We set MARIAN_GIT_DIR to the appropriate path, depending on whether
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/../.git is a directory or file.
|
||||
set(MARIAN_GIT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../.git)
|
||||
if(NOT IS_DIRECTORY ${MARIAN_GIT_DIR}) # i.e., it's a submodule
|
||||
file(READ ${MARIAN_GIT_DIR} MARIAN_GIT_DIR)
|
||||
string(REGEX REPLACE "gitdir: (.*)\n" "\\1" MARIAN_GIT_DIR ${MARIAN_GIT_DIR})
|
||||
get_filename_component(MARIAN_GIT_DIR "${CMAKE_SOURCE_DIR}/${MARIAN_GIT_DIR}" ABSOLUTE)
|
||||
endif(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
|
||||
get_filename_component(MARIAN_GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../${MARIAN_GIT_DIR}" ABSOLUTE)
|
||||
endif(NOT IS_DIRECTORY ${MARIAN_GIT_DIR})
|
||||
|
||||
add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
|
||||
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
COMMAND git log -1 --pretty=format:\#define\ GIT_REVISION\ \"\%h\ \%ai\" > ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
|
||||
DEPENDS ${MARIAN_GIT_DIR}/logs/HEAD
|
||||
VERBATIM
|
||||
@ -220,7 +221,7 @@ if(COMPILE_SERVER)
|
||||
add_executable(marian_server command/marian_server.cpp)
|
||||
set_target_properties(marian_server PROPERTIES OUTPUT_NAME marian-server)
|
||||
if(MSVC)
|
||||
# Disable warnings from the SimpleWebSocketServer library
|
||||
# Disable warnings from the SimpleWebSocketServer library needed for compilation of marian-server
|
||||
target_compile_options(marian_server PUBLIC ${ALL_WARNINGS} /wd4267 /wd4244 /wd4456 /wd4458)
|
||||
else(MSVC)
|
||||
# -Wno-suggest-override disables warnings from Boost 1.69+
|
||||
|
@ -14,6 +14,7 @@ int main(int argc, char **argv) {
|
||||
// Initialize translation task
|
||||
auto options = parseOptions(argc, argv, cli::mode::server, true);
|
||||
auto task = New<TranslateService<BeamSearch>>(options);
|
||||
auto quiet = options->get<bool>("quiet-translation");
|
||||
|
||||
// Initialize web server
|
||||
WSServer server;
|
||||
@ -21,7 +22,7 @@ int main(int argc, char **argv) {
|
||||
|
||||
auto &translate = server.endpoint["^/translate/?$"];
|
||||
|
||||
translate.on_message = [&task](Ptr<WSServer::Connection> connection,
|
||||
translate.on_message = [&task, quiet](Ptr<WSServer::Connection> connection,
|
||||
Ptr<WSServer::InMessage> message) {
|
||||
// Get input text
|
||||
auto inputText = message->string();
|
||||
@ -30,8 +31,8 @@ int main(int argc, char **argv) {
|
||||
// Translate
|
||||
timer::Timer timer;
|
||||
auto outputText = task->run(inputText);
|
||||
LOG(info, "Best translation: {}", outputText);
|
||||
*sendStream << outputText << std::endl;
|
||||
if(!quiet)
|
||||
LOG(info, "Translation took: {:.5f}s", timer.elapsed());
|
||||
|
||||
// Send translation back
|
||||
|
@ -14,7 +14,9 @@ namespace filesystem {
|
||||
// Pretend that Windows knows no named pipes. It does, by the way, but
|
||||
// they seem to be different from pipes on Unix / Linux. See
|
||||
// https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes
|
||||
bool is_fifo(char const*) { return false; }
|
||||
bool is_fifo(char const* /*path*/) {
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
bool is_fifo(char const* path) {
|
||||
struct stat buf;
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <chrono>
|
||||
#include <ctime>
|
||||
|
||||
namespace marian {
|
||||
namespace timer {
|
||||
|
@ -23,8 +23,11 @@ const SentenceTuple& TextIterator::dereference() const {
|
||||
TextInput::TextInput(std::vector<std::string> inputs,
|
||||
std::vector<Ptr<Vocab>> vocabs,
|
||||
Ptr<Options> options)
|
||||
: DatasetBase(inputs, options), vocabs_(vocabs) {
|
||||
// note: inputs are automatically stored in the inherited variable named paths_, but these are
|
||||
: DatasetBase(inputs, options),
|
||||
vocabs_(vocabs),
|
||||
maxLength_(options_->get<size_t>("max-length")),
|
||||
maxLengthCrop_(options_->get<bool>("max-length-crop")) {
|
||||
// Note: inputs are automatically stored in the inherited variable named paths_, but these are
|
||||
// texts not paths!
|
||||
for(const auto& text : paths_)
|
||||
files_.emplace_back(new std::istringstream(text));
|
||||
@ -42,6 +45,10 @@ SentenceTuple TextInput::next() {
|
||||
std::string line;
|
||||
if(io::getline(*files_[i], line)) {
|
||||
Words words = vocabs_[i]->encode(line, /*addEOS =*/ true, /*inference =*/ inference_);
|
||||
if(this->maxLengthCrop_ && words.size() > this->maxLength_) {
|
||||
words.resize(maxLength_);
|
||||
words.back() = vocabs_.back()->getEosId(); // note: this will not work with class-labels
|
||||
}
|
||||
if(words.empty())
|
||||
words.push_back(Word::ZERO); // @TODO: What is this for? @BUGBUG: addEOS=true, so this can never happen, right?
|
||||
tup.push_back(words);
|
||||
|
@ -33,6 +33,9 @@ private:
|
||||
|
||||
size_t pos_{0};
|
||||
|
||||
size_t maxLength_{0};
|
||||
bool maxLengthCrop_{false};
|
||||
|
||||
public:
|
||||
typedef SentenceTuple Sample;
|
||||
|
||||
|
@ -267,8 +267,8 @@ namespace marian {
|
||||
Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
lazyConstruct(input->shape()[-1]);
|
||||
|
||||
auto affineOrLSH = [this](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
#if BLAS_FOUND
|
||||
auto affineOrLSH = [this](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
if(lsh_) {
|
||||
ABORT_IF( transA, "Transposed query not supported for LSH");
|
||||
ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
|
||||
@ -276,10 +276,12 @@ namespace marian {
|
||||
} else {
|
||||
return affine(x, W, b, transA, transB);
|
||||
}
|
||||
#else
|
||||
return affine(x, W, b, transA, transB);
|
||||
#endif
|
||||
};
|
||||
#else
|
||||
auto affineOrLSH = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
return affine(x, W, b, transA, transB);
|
||||
};
|
||||
#endif
|
||||
|
||||
if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed
|
||||
cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
|
||||
|
@ -32,8 +32,6 @@ Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) {
|
||||
return New<MeanMultiRationalLoss>();
|
||||
else
|
||||
ABORT("Unknown multi-loss-type {}", multiLossType);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace marian
|
||||
|
@ -74,7 +74,6 @@ std::string ScoreCollector::getAlignment(const data::SoftAlignment& align) {
|
||||
} else {
|
||||
ABORT("Unrecognized word alignment type");
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
ScoreCollectorNBest::ScoreCollectorNBest(const Ptr<Options>& options)
|
||||
|
@ -19,7 +19,6 @@ struct QuantizeNodeOp : public UnaryNodeOp {
|
||||
|
||||
NodeOps backwardOps() override {
|
||||
ABORT("Only used for inference");
|
||||
return {NodeOp(0)};
|
||||
}
|
||||
|
||||
const std::string type() override { return "quantizeInt16"; }
|
||||
@ -54,7 +53,6 @@ public:
|
||||
|
||||
NodeOps backwardOps() override {
|
||||
ABORT("Only used for inference");
|
||||
return {NodeOp(0)};
|
||||
}
|
||||
|
||||
const std::string type() override { return "dotInt16"; }
|
||||
@ -92,7 +90,6 @@ public:
|
||||
|
||||
NodeOps backwardOps() override {
|
||||
ABORT("Only used for inference");
|
||||
return {NodeOp(0)};
|
||||
}
|
||||
|
||||
const std::string type() override { return "affineInt16"; }
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Unit tests
|
||||
add_subdirectory(units)
|
||||
|
||||
|
||||
if(NOT MSVC)
|
||||
# Testing apps
|
||||
set(APP_TESTS
|
||||
logger
|
||||
@ -23,3 +23,4 @@ foreach(test ${APP_TESTS})
|
||||
|
||||
set_target_properties("test_${test}" PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||
endforeach(test)
|
||||
endif(NOT MSVC)
|
@ -17,5 +17,10 @@ foreach(test ${UNIT_TESTS})
|
||||
target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch)
|
||||
endif(CUDA_FOUND)
|
||||
|
||||
if(MSVC)
|
||||
# Disable C4305: truncation from 'double' to '_Ty'
|
||||
target_compile_options("run_${test}" PUBLIC /wd4305)
|
||||
endif(MSVC)
|
||||
|
||||
add_test(NAME ${test} COMMAND "run_${test}")
|
||||
endforeach(test)
|
||||
|
@ -5,8 +5,6 @@
|
||||
using namespace marian;
|
||||
|
||||
TEST_CASE("FastOpt can be constructed from a YAML node", "[fastopt]") {
|
||||
YAML::Node node;
|
||||
|
||||
SECTION("from a simple node") {
|
||||
YAML::Node node = YAML::Load("{foo: bar}");
|
||||
const FastOpt o(node);
|
||||
|
@ -40,12 +40,12 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
std::vector<T> vA({1, -2, 3, -4});
|
||||
auto a = graph->constant({2, 2, 1}, inits::fromVector(vA));
|
||||
|
||||
auto compare = [&](Expr res, std::function<float(float)> f, bool exactMatch) -> bool {
|
||||
auto compare = [&](Expr res, std::function<float(float)> f) -> bool {
|
||||
if (res->shape() != Shape({ 2, 2, 1 }))
|
||||
return false;
|
||||
res->val()->get(values);
|
||||
std::vector<float> ref{f(vA[0]), f(vA[1]), f(vA[2]), f(vA[3])};
|
||||
return std::equal(values.begin(), values.end(), ref.begin(), exactMatch ? floatEqual : floatApprox);
|
||||
return std::equal(values.begin(), values.end(), ref.begin(), floatEqual);
|
||||
};
|
||||
|
||||
// @TODO: add all operators and scalar variants here for completeness
|
||||
@ -58,12 +58,12 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
|
||||
graph->forward();
|
||||
|
||||
CHECK(compare(rsmult, [](float a) {return 2.f * a;}, true));
|
||||
CHECK(compare(rabs, [](float a) {return std::abs(a);}, true));
|
||||
CHECK(compare(rmax1, [](float a) {return std::max(a, 1.f);}, true));
|
||||
CHECK(compare(rmax2, [](float a) {return std::max(1.f, a);}, true));
|
||||
CHECK(compare(rmin1, [](float a) {return std::min(a, 1.f);}, true));
|
||||
CHECK(compare(rmin2, [](float a) {return std::min(1.f, a);}, true));
|
||||
CHECK(compare(rsmult, [](float a) {return 2.f * a;}));
|
||||
CHECK(compare(rabs, [](float a) {return std::abs(a);}));
|
||||
CHECK(compare(rmax1, [](float a) {return std::max(a, 1.f);}));
|
||||
CHECK(compare(rmax2, [](float a) {return std::max(1.f, a);}));
|
||||
CHECK(compare(rmin1, [](float a) {return std::min(a, 1.f);}));
|
||||
CHECK(compare(rmin2, [](float a) {return std::min(1.f, a);}));
|
||||
}
|
||||
|
||||
SECTION("elementwise binary operators with broadcasting") {
|
||||
@ -76,12 +76,23 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
auto a = graph->constant({2, 2, 1}, inits::fromVector(vA));
|
||||
auto b = graph->constant({2, 1}, inits::fromVector(vB));
|
||||
|
||||
auto compare = [&](Expr res, std::function<float(float,float)> f, bool exactMatch) -> bool {
|
||||
// Two lambdas below differ in the use of floatEqual or floatApprox and
|
||||
// are not merged because MSVC compiler returns C2446: no conversion from
|
||||
// lambda_x to lambda_y
|
||||
auto compare = [&](Expr res, std::function<float(float,float)> f) -> bool {
|
||||
if (res->shape() != Shape({ 2, 2, 1 }))
|
||||
return false;
|
||||
res->val()->get(values);
|
||||
std::vector<float> ref{f(vA[0], vB[0]), f(vA[1], vB[1]), f(vA[2], vB[0]), f(vA[3], vB[1])};
|
||||
return std::equal(values.begin(), values.end(), ref.begin(), exactMatch ? floatEqual : floatApprox);
|
||||
return std::equal(values.begin(), values.end(), ref.begin(), floatEqual);
|
||||
};
|
||||
|
||||
auto compareApprox = [&](Expr res, std::function<float(float, float)> f) -> bool {
|
||||
if(res->shape() != Shape({2, 2, 1}))
|
||||
return false;
|
||||
res->val()->get(values);
|
||||
std::vector<float> ref{f(vA[0], vB[0]), f(vA[1], vB[1]), f(vA[2], vB[0]), f(vA[3], vB[1])};
|
||||
return std::equal(values.begin(), values.end(), ref.begin(), floatApprox);
|
||||
};
|
||||
|
||||
auto rplus = a + b;
|
||||
@ -100,19 +111,19 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
|
||||
graph->forward();
|
||||
|
||||
CHECK(compare(rplus, [](float a, float b) {return a + b;}, true));
|
||||
CHECK(compare(rminus, [](float a, float b) {return a - b;}, true));
|
||||
CHECK(compare(rmult, [](float a, float b) {return a * b;}, true));
|
||||
CHECK(compare(rdiv, [](float a, float b) {return a / b;}, false));
|
||||
CHECK(compare(rlae, [](float a, float b) {return logf(expf(a) + expf(b));}, false));
|
||||
CHECK(compare(rmax, [](float a, float b) {return std::max(a, b);}, true));
|
||||
CHECK(compare(rmin, [](float a, float b) {return std::min(a, b);}, true));
|
||||
CHECK(compare(rlt, [](float a, float b) {return a < b;}, true));
|
||||
CHECK(compare(req, [](float a, float b) {return a == b;}, true));
|
||||
CHECK(compare(rgt, [](float a, float b) {return a > b;}, true));
|
||||
CHECK(compare(rge, [](float a, float b) {return a >= b;}, true));
|
||||
CHECK(compare(rne, [](float a, float b) {return a != b;}, true));
|
||||
CHECK(compare(rle, [](float a, float b) {return a <= b;}, true));
|
||||
CHECK(compare(rplus, [](float a, float b) {return a + b;}));
|
||||
CHECK(compare(rminus, [](float a, float b) {return a - b;}));
|
||||
CHECK(compare(rmult, [](float a, float b) {return a * b;}));
|
||||
CHECK(compareApprox(rdiv, [](float a, float b) {return a / b;}));
|
||||
CHECK(compareApprox(rlae, [](float a, float b) {return logf(expf(a) + expf(b));}));
|
||||
CHECK(compare(rmax, [](float a, float b) {return std::max(a, b);}));
|
||||
CHECK(compare(rmin, [](float a, float b) {return std::min(a, b);}));
|
||||
CHECK(compare(rlt, [](float a, float b) {return a < b;}));
|
||||
CHECK(compare(req, [](float a, float b) {return a == b;}));
|
||||
CHECK(compare(rgt, [](float a, float b) {return a > b;}));
|
||||
CHECK(compare(rge, [](float a, float b) {return a >= b;}));
|
||||
CHECK(compare(rne, [](float a, float b) {return a != b;}));
|
||||
CHECK(compare(rle, [](float a, float b) {return a <= b;}));
|
||||
}
|
||||
|
||||
SECTION("transposing and reshaping") {
|
||||
@ -399,8 +410,8 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
std::vector<float> SV; // create CSR version of S
|
||||
std::vector<IndexType> SI, SO;
|
||||
SO.push_back((IndexType)SI.size());
|
||||
for (IndexType i = 0; i < S->shape()[0]; i++) {
|
||||
for (IndexType j = 0; j < S->shape()[1]; j++) {
|
||||
for (IndexType i = 0; i < (IndexType)S->shape()[0]; i++) {
|
||||
for (IndexType j = 0; j < (IndexType)S->shape()[1]; j++) {
|
||||
auto k = 4 * i + j;
|
||||
if (vS[k] != 0) {
|
||||
SV.push_back(vS[k]);
|
||||
@ -477,7 +488,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
aff1->val()->get(values);
|
||||
CHECK(values == vAff);
|
||||
|
||||
std::vector<T> values2;
|
||||
values2.clear();
|
||||
CHECK(aff2->shape() == aff1->shape());
|
||||
aff2->val()->get(values2);
|
||||
CHECK(values2 == values);
|
||||
@ -653,7 +664,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
SECTION("relation of rows and columns selection using transpose") {
|
||||
graph->clear();
|
||||
values.clear();
|
||||
std::vector<T> values2;
|
||||
values2.clear();
|
||||
|
||||
std::vector<T> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
|
||||
std::vector<IndexType> idx({0, 1});
|
||||
@ -763,7 +774,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
SECTION("rows/cols as gather operations") {
|
||||
graph->clear();
|
||||
values.clear();
|
||||
std::vector<T> values2;
|
||||
values2.clear();
|
||||
|
||||
|
||||
std::vector<T> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
|
||||
|
@ -181,7 +181,7 @@ void tests(DeviceType type, Type floatType = Type::float32) {
|
||||
|
||||
auto context = concatenate({rnnFw.construct(graph)->transduce(input, mask),
|
||||
rnnBw.construct(graph)->transduce(input, mask)},
|
||||
/*axis =*/ input->shape().size() - 1);
|
||||
/*axis =*/ (int)input->shape().size() - 1);
|
||||
|
||||
if(second > 0) {
|
||||
// add more layers (unidirectional) by transducing the output of the
|
||||
|
@ -79,12 +79,13 @@ void OutputCollector::Write(long sourceId,
|
||||
}
|
||||
}
|
||||
|
||||
StringCollector::StringCollector() : maxId_(-1) {}
|
||||
StringCollector::StringCollector(bool quiet /*=false*/) : maxId_(-1), quiet_(quiet) {}
|
||||
|
||||
void StringCollector::add(long sourceId,
|
||||
const std::string& best1,
|
||||
const std::string& bestn) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if(!quiet_)
|
||||
LOG(info, "Best translation {} : {}", sourceId, best1);
|
||||
outputs_[sourceId] = std::make_pair(best1, bestn);
|
||||
if(maxId_ <= sourceId)
|
||||
|
@ -74,14 +74,15 @@ protected:
|
||||
|
||||
class StringCollector {
|
||||
public:
|
||||
StringCollector();
|
||||
StringCollector(bool quiet = false);
|
||||
StringCollector(const StringCollector&) = delete;
|
||||
|
||||
void add(long sourceId, const std::string& best1, const std::string& bestn);
|
||||
std::vector<std::string> collect(bool nbest);
|
||||
|
||||
protected:
|
||||
long maxId_;
|
||||
long maxId_; // the largest index of the translated source sentences
|
||||
bool quiet_; // if true do not log best translations
|
||||
std::mutex mutex_;
|
||||
|
||||
typedef std::map<long, std::pair<std::string, std::string>> Outputs;
|
||||
|
@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "data/batch_generator.h"
|
||||
#include "data/corpus.h"
|
||||
#include "data/shortlist.h"
|
||||
@ -245,10 +247,14 @@ public:
|
||||
}
|
||||
|
||||
std::string run(const std::string& input) override {
|
||||
auto corpus_ = New<data::TextInput>(std::vector<std::string>({input}), srcVocabs_, options_);
|
||||
// split tab-separated input into fields if necessary
|
||||
auto inputs = options_->get<bool>("tsv", false)
|
||||
? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
|
||||
: std::vector<std::string>({input});
|
||||
auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
|
||||
data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_);
|
||||
|
||||
auto collector = New<StringCollector>();
|
||||
auto collector = New<StringCollector>(options_->get<bool>("quiet-translation", false));
|
||||
auto printer = New<OutputPrinter>(options_, trgVocab_);
|
||||
size_t batchId = 0;
|
||||
|
||||
@ -258,7 +264,6 @@ public:
|
||||
ThreadPool threadPool_(numDevices_, numDevices_);
|
||||
|
||||
for(auto batch : batchGenerator) {
|
||||
|
||||
auto task = [=](size_t id) {
|
||||
thread_local Ptr<ExpressionGraph> graph;
|
||||
thread_local std::vector<Ptr<Scorer>> scorers;
|
||||
@ -287,5 +292,30 @@ public:
|
||||
auto translations = collector->collect(options_->get<bool>("n-best"));
|
||||
return utils::join(translations, "\n");
|
||||
}
|
||||
|
||||
private:
|
||||
// Converts a multi-line input with tab-separated source(s) and target sentences into separate lists
|
||||
// of sentences from source(s) and target sides, e.g.
|
||||
// "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"]
|
||||
std::vector<std::string> convertTsvToLists(const std::string& inputText, size_t numFields) {
|
||||
std::vector<std::string> outputFields(numFields);
|
||||
|
||||
std::string line;
|
||||
std::vector<std::string> lineFields(numFields);
|
||||
std::istringstream inputStream(inputText);
|
||||
bool first = true;
|
||||
while(std::getline(inputStream, line)) {
|
||||
utils::splitTsv(line, lineFields, numFields);
|
||||
for(size_t i = 0; i < numFields; ++i) {
|
||||
if(!first)
|
||||
outputFields[i] += "\n"; // join sentences with a new line sign
|
||||
outputFields[i] += lineFields[i];
|
||||
}
|
||||
if(first)
|
||||
first = false;
|
||||
}
|
||||
|
||||
return outputFields;
|
||||
}
|
||||
};
|
||||
} // namespace marian
|
||||
|
Loading…
Reference in New Issue
Block a user