Merged PR 14402: Sync with public marian-dev master 1.9.31

This simply pulls the recent updates from the public repo.
Roman Grundkiewicz 2020-07-28 22:19:40 +00:00 committed by Martin Junczys-Dowmunt
parent 9001477147
commit 080d75ad59
36 changed files with 471 additions and 147 deletions

View File

@ -0,0 +1,51 @@
name: macos-10.15-cpu
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
jobs:
build:
runs-on: macos-10.15
steps:
- name: Checkout
uses: actions/checkout@v2
with:
submodules: recursive
- name: Install dependencies
run: brew install openblas protobuf
# Openblas location is exported explicitly because openblas is keg-only,
# which means it was not symlinked into /usr/local/.
# CMake cannot find BLAS on GitHub runners if Marian is being compiled
# statically, hence USE_STATIC_LIBS=off
- name: Configure CMake
run: |
export LDFLAGS="-L/usr/local/opt/openblas/lib"
export CPPFLAGS="-I/usr/local/opt/openblas/include"
mkdir -p build
cd build
cmake .. -DCOMPILE_CPU=on -DCOMPILE_CUDA=off -DCOMPILE_EXAMPLES=on -DCOMPILE_SERVER=on -DCOMPILE_TESTS=on \
-DUSE_FBGEMM=on -DUSE_SENTENCEPIECE=on -DUSE_STATIC_LIBS=off
- name: Compile
working-directory: build
run: make -j2
- name: Run unit tests
working-directory: build
run: make test
- name: Print versions
working-directory: build
run: |
./marian --version
./marian-decoder --version
./marian-scorer --version
./spm_encode --version

View File

@ -0,0 +1,64 @@
name: ubuntu-18.04-cpu
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
jobs:
build:
runs-on: ubuntu-18.04
steps:
- name: Checkout
uses: actions/checkout@v2
with:
submodules: recursive
# The following packages are already installed on GitHub-hosted runners: build-essential openssl libssl-dev
- name: Install dependencies
run: sudo apt-get install --no-install-recommends libgoogle-perftools-dev libprotobuf10 libprotobuf-dev protobuf-compiler
# https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html
- name: Install MKL
run: |
wget -qO- "https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB" | sudo apt-key add -
sudo sh -c "echo deb https://apt.repos.intel.com/mkl all main > /etc/apt/sources.list.d/intel-mkl.list"
sudo apt-get update -o Dir::Etc::sourcelist="/etc/apt/sources.list.d/intel-mkl.list"
sudo apt-get install --no-install-recommends intel-mkl-64bit-2020.0-088
- name: Print Boost paths
run: |
ls $BOOST_ROOT_1_69_0
ls $BOOST_ROOT_1_69_0/include
ls $BOOST_ROOT_1_69_0/lib
# Boost is already installed on GitHub-hosted runners in a non-standard location
# https://github.com/actions/virtual-environments/issues/687#issuecomment-610471671
- name: Configure CMake
run: |
mkdir -p build
cd build
cmake .. -DCOMPILE_CPU=on -DCOMPILE_CUDA=off -DCOMPILE_EXAMPLES=on -DCOMPILE_SERVER=on -DCOMPILE_TESTS=on \
-DUSE_FBGEMM=on -DUSE_SENTENCEPIECE=on \
-DBOOST_ROOT=$BOOST_ROOT_1_69_0 -DBOOST_INCLUDEDIR=$BOOST_ROOT_1_69_0/include -DBOOST_LIBRARYDIR=$BOOST_ROOT_1_69_0/lib \
-DBoost_ARCHITECTURE=-x64
- name: Compile
working-directory: build
run: make -j2
- name: Run unit tests
working-directory: build
run: make test
- name: Print versions
working-directory: build
run: |
./marian --version
./marian-decoder --version
./marian-scorer --version
./spm_encode --version

View File

@ -0,0 +1,49 @@
name: windows-2019-cpu
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
jobs:
build:
runs-on: windows-2019
steps:
- name: Checkout
uses: actions/checkout@v2
with:
submodules: recursive
- name: Prepare vcpkg
uses: lukka/run-vcpkg@v3
with:
vcpkgArguments: protobuf
vcpkgGitCommitId: 6185aa76504a5025f36754324abf307cc776f3da
vcpkgDirectory: ${{ github.workspace }}/vcpkg/
vcpkgTriplet: x64-windows-static
# Note that we build with a simplified CMake settings JSON file
- name: Run CMake
uses: lukka/run-cmake@v2
with:
buildDirectory: ${{ github.workspace }}/build/
cmakeAppendedArgs: -G Ninja
cmakeListsOrSettingsJson: CMakeSettingsJson
cmakeSettingsJsonPath: ${{ github.workspace }}/CMakeSettingsCI.json
useVcpkgToolchainFile: true
- name: Run unit tests
working-directory: build/Debug/
run: ctest
- name: Print versions
working-directory: build/Debug/
run: |
.\marian.exe --version
.\marian-decoder.exe --version
.\marian-scorer.exe --version
.\spm_encode.exe --version

View File

@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
- Decoding multi-source models in marian-server with --tsv
- GitHub workflows on Ubuntu, Windows, and MacOS
- LSH indexing to replace short list
- ONNX support for transformer models
- Add topk operator like PyTorch's topk
@ -19,11 +21,19 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
and translation with options --tsv and --tsv-fields n.
### Fixed
- Fix compilation without BLAS installed
- Providing a single value to vector-like options using the equals sign, e.g. --models=model.npz
- Fix quiet-translation in marian-server
- CMake-based compilation on Windows
- Fix minor issues with compilation on MacOS
- Fix warnings in Windows MSVC builds using CMake
- Fix building server with Boost 1.72
- Make mini-batch scaling depend on mini-batch-words and not on mini-batch-words-ref
- In concatenation make sure that we do not multiply 0 with nan (which results in nan)
- Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. Tolerance is now
absolute and not relative; we had incorrectly assumed that epsilon sets an absolute
tolerance (see the short illustration after this changelog excerpt).
- Fixed bug in finding .git/logs/HEAD when Marian is a submodule in another project.
- Properly record cmake variables in the cmake build directory instead of the source tree.
### Changed
- Move Simple-WebSocket-Server to submodule
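For context on the epsilon/margin entry above, here is a minimal Catch sketch (not part of the Marian test suite; assumes the bundled catch.hpp) showing that epsilon scales with the magnitude of the expected value while margin is a fixed absolute tolerance:

// approx_sketch.cpp -- hypothetical example, not from Marian's tests
#define CATCH_CONFIG_MAIN
#include "catch.hpp"

TEST_CASE("epsilon is relative, margin is absolute", "[approx]") {
  // Relative tolerance: the allowed error is roughly epsilon * |expected|.
  // With an expected value of 0 the allowed error is 0, so this would fail:
  // CHECK(0.0005f == Approx(0.f).epsilon(0.01));
  // Absolute tolerance: the allowed error is the margin itself, so this passes:
  CHECK(0.0005f == Approx(0.f).margin(0.001));
  // For a large expected value the two behave very differently:
  CHECK(101.0f == Approx(100.f).epsilon(0.01));        // within 1% of 100
  CHECK_FALSE(101.0f == Approx(100.f).margin(0.001));  // not within 0.001
}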

View File

@ -51,6 +51,8 @@ message(STATUS "Project version: ${PROJECT_VERSION_STRING_FULL}")
execute_process(COMMAND git submodule update --init --recursive --no-fetch
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
# Note that with CMake MSVC build, the option CMAKE_BUILD_TYPE is automatically derived from the key
# 'configurationType' in CMakeSettings.json configurations
if(NOT CMAKE_BUILD_TYPE)
message(WARNING "CMAKE_BUILD_TYPE not set; setting to Release")
set(CMAKE_BUILD_TYPE "Release")
@ -62,10 +64,11 @@ if(MSVC)
# These are used in src/CMakeLists.txt on a per-target basis
list(APPEND ALL_WARNINGS /WX; /W4;)
# Disabled bogus warnings for CPU intrincics:
# Disabled bogus warnings for CPU intrinsics:
# C4310: cast truncates constant value
# C4324: 'marian::cpu::int16::`anonymous-namespace'::ScatterPut': structure was padded due to alignment specifier
set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\"")
# C4702: unreachable code; note it is also disabled globally in the VS project file
set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\"")
# set(INTRINSICS "/arch:AVX")
add_definitions(-DUSE_SSE2=1)
@ -79,7 +82,9 @@ if(MSVC)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} /MTd /Od /Ob0 ${INTRINSICS} /RTC1 /Zi /D_DEBUG")
# ignores warning LNK4049: locally defined symbol free imported - this comes from zlib
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /NODEFAULTLIB:MSVCRT /ignore:4049")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /ignore:4049")
set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRT")
set(CMAKE_EXE_LINKER_FLAGS_DEBUG "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRTD")
set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} /LTCG:incremental")
find_library(SHLWAPI Shlwapi.lib)
@ -218,7 +223,9 @@ if(COMPILE_CUDA)
if(USE_STATIC_LIBS)
# link statically to stdlib libraries
set(CMAKE_EXE_LINKER_FLAGS "-static-libgcc -static-libstdc++")
if(NOT MSVC)
set(CMAKE_EXE_LINKER_FLAGS "-static-libgcc -static-libstdc++")
endif()
# look for libraries that have .a suffix
set(_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES})
@ -250,12 +257,22 @@ if(CUDA_FOUND)
endif(COMPILE_CUDA_SM70)
if(USE_STATIC_LIBS)
find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
# The cuLIBOS library does not seem to exist in Windows CUDA toolkit installs
if(CUDA_culibos_LIBRARY)
set(EXT_LIBS ${EXT_LIBS} ${CUDA_culibos_LIBRARY})
set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_culibos_LIBRARY})
elseif(NOT WIN32)
message(FATAL_ERROR "cuLIBOS library not found")
endif()
# CUDA 10.1 introduces cublasLt library that is required on static build
if ((CUDA_VERSION VERSION_EQUAL "10.1" OR CUDA_VERSION VERSION_GREATER "10.1"))
find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
if(NOT CUDA_cublasLt_LIBRARY)
message(FATAL_ERROR "cuBLASLt library not found")
endif()
set(EXT_LIBS ${EXT_LIBS} ${CUDA_cublasLt_LIBRARY})
set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_cublasLt_LIBRARY})
endif()
@ -319,7 +336,7 @@ if(NOT MSVC)
list(APPEND CUDA_NVCC_FLAGS -ccbin ${CMAKE_C_COMPILER}; -std=c++11; -Xcompiler\ -fPIC; -Xcompiler\ -Wno-unused-result; -Xcompiler\ -Wno-deprecated; -Xcompiler\ -Wno-pragmas; -Xcompiler\ -Wno-unused-value; -Xcompiler\ -Werror;)
list(APPEND CUDA_NVCC_FLAGS ${INTRINSICS_NVCC})
else()
list(APPEND CUDA_NVCC_FLAGS -Xcompiler\ /FS; )
list(APPEND CUDA_NVCC_FLAGS -Xcompiler\ /FS; -Xcompiler\ /MT$<$<CONFIG:Debug>:d>; )
endif()
list(REMOVE_DUPLICATES CUDA_NVCC_FLAGS)
@ -443,8 +460,11 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/project_version.h.in
# Generate build_info.cpp with CMake cache variables
include(GetCacheVariables)
# make sure src/common/build_info.cpp has been removed
execute_process(COMMAND rm ${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp
OUTPUT_QUIET ERROR_QUIET)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp.in
${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp @ONLY)
${CMAKE_CURRENT_BINARY_DIR}/src/common/build_info.cpp @ONLY)
# Compile source files
include_directories(${marian_SOURCE_DIR}/src)

CMakeSettingsCI.json Normal file
View File

@ -0,0 +1,52 @@
{
"configurations": [
{
"name": "Release",
"generator": "Ninja",
"configurationType": "Release",
"inheritEnvironments": [ "msvc_x64" ],
"cmakeCommandArgs": "",
"buildCommandArgs": "-v",
"ctestCommandArgs": "",
"variables": [
{ "name": "OPENSSL_USE_STATIC_LIBS:BOOL", "value": "TRUE" },
{ "name": "OPENSSL_MSVC_STATIC_RT:BOOL", "value": "TRUE" },
{ "name": "COMPILE_CUDA:BOOL", "value": "FALSE" },
{ "name": "COMPILE_CPU:BOOL", "value": "TRUE" },
{ "name": "COMPILE_EXAMPLES:BOOL", "value": "FALSE" },
{ "name": "COMPILE_SERVER:BOOL", "value": "FALSE" },
{ "name": "COMPILE_TESTS:BOOL", "value": "TRUE" },
{ "name": "USE_FBGEMM:BOOL", "value": "TRUE" },
{ "name": "USE_MPI:BOOL", "value": "FALSE" },
{ "name": "USE_SENTENCEPIECE:BOOL", "value": "TRUE" },
{ "name": "USE_STATIC_LIBS:BOOL", "value": "TRUE" }
]
},
{
"name": "Debug",
"generator": "Ninja",
"configurationType": "Debug",
"inheritEnvironments": [ "msvc_x64" ],
"cmakeCommandArgs": "",
"buildCommandArgs": "-v",
"ctestCommandArgs": "",
"variables": [
{ "name": "OPENSSL_MSVC_STATIC_RT:BOOL", "value": "TRUE" },
{ "name": "OPENSSL_USE_STATIC_LIBS:BOOL", "value": "TRUE" },
{ "name": "COMPILE_CUDA:BOOL", "value": "FALSE" },
{ "name": "COMPILE_CPU:BOOL", "value": "TRUE" },
{ "name": "COMPILE_EXAMPLES:BOOL", "value": "FALSE" },
{ "name": "COMPILE_SERVER:BOOL", "value": "FALSE" },
{ "name": "COMPILE_TESTS:BOOL", "value": "TRUE" },
{ "name": "USE_FBGEMM:BOOL", "value": "TRUE" },
{ "name": "USE_MPI:BOOL", "value": "FALSE" },
{ "name": "USE_SENTENCEPIECE:BOOL", "value": "TRUE" },
{ "name": "USE_STATIC_LIBS:BOOL", "value": "TRUE" }
]
}
]
}

View File

@ -1,2 +1,2 @@
v1.9.26
v1.9.32

View File

@ -89,10 +89,17 @@ find_library(MKL_CORE_LIBRARY
NO_DEFAULT_PATH)
set(MKL_INCLUDE_DIRS ${MKL_INCLUDE_DIR})
# Added -Wl block to avoid circular dependencies.
# https://stackoverflow.com/questions/5651869/what-are-the-start-group-and-end-group-command-line-options
# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor
set(MKL_LIBRARIES -Wl,--start-group ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY} -Wl,--end-group)
set(MKL_LIBRARIES ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY})
if(NOT WIN32 AND NOT APPLE)
# Added -Wl block to avoid circular dependencies.
# https://stackoverflow.com/questions/5651869/what-are-the-start-group-and-end-group-command-line-options
# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor
set(MKL_LIBRARIES -Wl,--start-group ${MKL_LIBRARIES} -Wl,--end-group)
elseif(APPLE)
# MacOS does not support --start-group and --end-group
set(MKL_LIBRARIES -Wl,${MKL_LIBRARIES} -Wl,)
endif()
# message("1 ${MKL_INCLUDE_DIR}")
# message("2 ${MKL_INTERFACE_LIBRARY}")
@ -130,4 +137,4 @@ endif()
INCLUDE(FindPackageHandleStandardArgs)
FIND_PACKAGE_HANDLE_STANDARD_ARGS(MKL DEFAULT_MSG MKL_LIBRARIES MKL_INCLUDE_DIRS MKL_INTERFACE_LIBRARY MKL_SEQUENTIAL_LAYER_LIBRARY MKL_CORE_LIBRARY)
MARK_AS_ADVANCED(MKL_INCLUDE_DIRS MKL_LIBRARIES MKL_INTERFACE_LIBRARY MKL_SEQUENTIAL_LAYER_LIBRARY MKL_CORE_LIBRARY)
MARK_AS_ADVANCED(MKL_INCLUDE_DIRS MKL_LIBRARIES MKL_INTERFACE_LIBRARY MKL_SEQUENTIAL_LAYER_LIBRARY MKL_CORE_LIBRARY)

View File

@ -34,17 +34,18 @@ foreach(_variableName ${_variableNames})
(NOT "${_variableType}" STREQUAL "INTERNAL") AND
(NOT "${_variableValue}" STREQUAL "") )
set(PROJECT_CMAKE_CACHE_ADVANCED "${PROJECT_CMAKE_CACHE_ADVANCED} \"${_variableName}=${_variableValue}\\n\"\n")
string(REPLACE "\"" " " _variableValueEscapedQuotes ${_variableValue})
string(REPLACE "\\" "/" _variableValueEscaped ${_variableValueEscapedQuotes})
set(PROJECT_CMAKE_CACHE_ADVANCED "${PROJECT_CMAKE_CACHE_ADVANCED} \"${_variableName}=${_variableValueEscaped}\\n\"\n")
# Get the variable's advanced flag
get_property(_isAdvanced CACHE ${_variableName} PROPERTY ADVANCED SET)
if(NOT _isAdvanced)
set(PROJECT_CMAKE_CACHE "${PROJECT_CMAKE_CACHE} \"${_variableName}=${_variableValue}\\n\"\n")
set(PROJECT_CMAKE_CACHE "${PROJECT_CMAKE_CACHE} \"${_variableName}=${_variableValueEscaped}\\n\"\n")
endif()
# Print variables for debugging
#message(STATUS "${_variableName}=${${_variableName}}")
#message(STATUS "${_variableName}=${_variableValueEscaped}")
#message(STATUS " Type=${_variableType}")
#message(STATUS " Advanced=${_isAdvanced}")
endif()

View File

@ -18,7 +18,7 @@
if(PROJECT_VERSION_FILE)
file(STRINGS ${PROJECT_VERSION_FILE} PROJECT_VERSION_STRING)
else()
file(STRINGS ${CMAKE_SOURCE_DIR}/VERSION PROJECT_VERSION_STRING)
file(STRINGS ${CMAKE_CURRENT_SOURCE_DIR}/VERSION PROJECT_VERSION_STRING)
endif()
# Get current commit SHA from git

@ -1 +1 @@
Subproject commit 0f8cabf13ec362d50544d33490024e00c3a763be
Subproject commit 7b8f6ee5b6ff7779fd993df7f77adf1e2d9adbe5

View File

@ -1595,7 +1595,8 @@ class App {
if(num < 0) {
// RG: We need to keep track if the vector option is empty and handle this separately as
// otherwise the parser will mark the command-line option as not set
bool emptyVectorArgs = true;
// RG: An option value after '=' was already collected
bool emptyVectorArgs = (collected <= 0);
while(!args.empty() && _recognize(args.back()) == detail::Classifer::NONE) {
if(collected >= -num) {
// We could break here for allow extras, but we don't
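The fix above concerns vector-valued options: a value attached with '=' has already been collected before this loop runs, so emptyVectorArgs must start as false in that case. A minimal user-level sketch of the behavior this restores, written against upstream CLI11's public API (not Marian's vendored copy; file name and option are illustrative):

// cli11_vector_sketch.cpp -- hypothetical example, assumes upstream CLI11
#include <iostream>
#include <string>
#include <vector>
#include "CLI/CLI.hpp"

int main(int argc, char** argv) {
  CLI::App app{"vector option sketch"};
  std::vector<std::string> models;
  app.add_option("--models", models, "model file(s)");
  CLI11_PARSE(app, argc, argv);
  // Both invocations should yield one collected value:
  //   ./a.out --models model.npz
  //   ./a.out --models=model.npz
  std::cout << "collected " << models.size() << " model(s)\n";
  return 0;
}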

View File

@ -19,6 +19,9 @@ if(USE_FBGEMM)
# only locally disabled for the 3rd_party folder
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-value -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused")
else()
# Do not compile cpuinfo executables due to a linker error, and they are not needed
set(CPUINFO_BUILD_TOOLS OFF CACHE BOOL "Build command-line tools")
endif()
set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "Disable fbgemm tests")
@ -45,7 +48,7 @@ if(USE_SENTENCEPIECE)
endif()
endif()
set(SPM_ENABLE_TCMALLOC ON CACHE BOOL "Enable TCMalloc if available." FORCE)
set(SPM_ENABLE_TCMALLOC ON CACHE BOOL "Enable TCMalloc if available.")
if(USE_STATIC_LIBS)
message(WARNING "You are compiling SentencePiece binaries with -DUSE_STATIC_LIBS=on. \
@ -55,8 +58,8 @@ if(USE_SENTENCEPIECE)
set(SPM_ENABLE_SHARED OFF CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
set(SPM_TCMALLOC_STATIC ON CACHE BOOL "Link static library of TCMALLOC." FORCE)
else(USE_STATIC_LIBS)
set(SPM_ENABLE_SHARED ON CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
set(SPM_TCMALLOC_STATIC OFF CACHE BOOL "Link static library of TCMALLOC." FORCE)
set(SPM_ENABLE_SHARED ON CACHE BOOL "Builds shared libaries in addition to static libraries.")
set(SPM_TCMALLOC_STATIC OFF CACHE BOOL "Link static library of TCMALLOC.")
endif(USE_STATIC_LIBS)
add_subdirectory(./sentencepiece)
@ -66,7 +69,7 @@ if(USE_SENTENCEPIECE)
PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
foreach(t sentencepiece sentencepiece_train sentencepiece_train-static
foreach(t sentencepiece-static sentencepiece_train-static
spm_decode spm_encode spm_export_vocab spm_normalize spm_train)
set_property(TARGET ${t} APPEND_STRING PROPERTY COMPILE_FLAGS " -Wno-tautological-compare -Wno-unused")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)

View File

@ -18,6 +18,10 @@
#include<map>
#include <memory>
#ifdef __APPLE__
#include <unistd.h>
#endif
namespace cnpy {
struct NpyArray {

View File

@ -18,6 +18,9 @@
#include <stdint.h>
#include <faiss/Index.h>
#ifdef __APPLE__
#include <x86intrin.h>
#endif
namespace faiss {

View File

@ -679,7 +679,8 @@ PHF_PUBLIC void PHF::compact(struct phf *phf) {
}
/* simply keep old array if realloc fails */
if ((tmp = realloc(phf->g, phf->r * size)))
tmp = realloc(phf->g, phf->r * size);
if (tmp != 0)
phf->g = static_cast<uint32_t *>(tmp);
} /* PHF::compact() */
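The change above moves the assignment out of the if condition; the intent, keeping the old buffer valid when realloc fails, is unchanged. A standalone sketch of that pattern (names are illustrative, not taken from phf):

#include <cstdint>
#include <cstdlib>

// Grow a uint32_t buffer in place; on realloc failure the original
// allocation is kept and remains valid (hypothetical helper, not from phf).
static bool grow_buffer(uint32_t** data, size_t new_count) {
  void* tmp = realloc(*data, new_count * sizeof(uint32_t));
  if (tmp == nullptr)
    return false;  // keep the old array if realloc fails
  *data = static_cast<uint32_t*>(tmp);
  return true;
}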

View File

@ -2,11 +2,11 @@
file(GLOB ZLIB_SRC *.c)
file(GLOB ZLIB_INC *.h)
# add sources of the wrapper as a "SQLiteCpp" static library
# add sources of the wrapper as a "zlib" static library
add_library(zlib OBJECT ${ZLIB_SRC} ${ZLIB_INC})
if(MSVC)
target_compile_options(zlib PUBLIC /wd"4996" /wd"4267")
target_compile_options(zlib PUBLIC /wd4996 /wd4267)
else()
target_compile_options(zlib PUBLIC -Wno-implicit-function-declaration)
endif()

View File

@ -20,7 +20,7 @@ add_library(marian STATIC
common/config_validator.cpp
common/options.cpp
common/binary.cpp
common/build_info.cpp
${CMAKE_CURRENT_BINARY_DIR}/common/build_info.cpp
common/io.cpp
common/filesystem.cpp
common/file_stream.cpp
@ -70,7 +70,7 @@ add_library(marian STATIC
layers/generic.cpp
layers/loss.cpp
layers/weight.cpp
layers/lsh.cpp
layers/lsh.cpp
rnn/cells.cpp
rnn/attention.cpp
@ -117,21 +117,22 @@ target_compile_options(marian PUBLIC ${ALL_WARNINGS})
# Git updates .git/logs/HEAD file whenever you pull or commit something.
# If Marian is checked out as a submodule in another repository,
# there's no .git directory in ${CMAKE_SOURCE_DIR}. Instead .git is a
# file that specifies the relative path from ${CMAKE_SOURCE_DIR} to
# ./git/modules/<MARIAN_ROOT_DIR> in the root of the repository that
# contains Marian as a submodule. We set MARIAN_GIT_DIR to the appropriate
# path, depending on whether ${CMAKE_SOURCE_DIR}/.git is a directory or file.
if(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git) # not a submodule
set(MARIAN_GIT_DIR ${CMAKE_SOURCE_DIR}/.git)
else(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
file(READ ${CMAKE_SOURCE_DIR}/.git MARIAN_GIT_DIR)
# ${CMAKE_CURRENT_SOURCE_DIR}/../.git is not a directory but a file
# that specifies the relative path from ${CMAKE_CURRENT_SOURCE_DIR}/..
# to ./git/modules/<MARIAN_ROOT_DIR> in the root of the check_out of
# the project that contains Marian as a submodule.
#
# We set MARIAN_GIT_DIR to the appropriate path, depending on whether
# ${CMAKE_CURRENT_SOURCE_DIR}/../.git is a directory or file.
set(MARIAN_GIT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../.git)
if(NOT IS_DIRECTORY ${MARIAN_GIT_DIR}) # i.e., it's a submodule
file(READ ${MARIAN_GIT_DIR} MARIAN_GIT_DIR)
string(REGEX REPLACE "gitdir: (.*)\n" "\\1" MARIAN_GIT_DIR ${MARIAN_GIT_DIR})
get_filename_component(MARIAN_GIT_DIR "${CMAKE_SOURCE_DIR}/${MARIAN_GIT_DIR}" ABSOLUTE)
endif(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
get_filename_component(MARIAN_GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../${MARIAN_GIT_DIR}" ABSOLUTE)
endif(NOT IS_DIRECTORY ${MARIAN_GIT_DIR})
add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMAND git log -1 --pretty=format:\#define\ GIT_REVISION\ \"\%h\ \%ai\" > ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
DEPENDS ${MARIAN_GIT_DIR}/logs/HEAD
VERBATIM
@ -220,7 +221,7 @@ if(COMPILE_SERVER)
add_executable(marian_server command/marian_server.cpp)
set_target_properties(marian_server PROPERTIES OUTPUT_NAME marian-server)
if(MSVC)
# Disable warnings from the SimpleWebSocketServer library
# Disable warnings from the SimpleWebSocketServer library needed for compilation of marian-server
target_compile_options(marian_server PUBLIC ${ALL_WARNINGS} /wd4267 /wd4244 /wd4456 /wd4458)
else(MSVC)
# -Wno-suggest-override disables warnings from Boost 1.69+

View File

@ -14,6 +14,7 @@ int main(int argc, char **argv) {
// Initialize translation task
auto options = parseOptions(argc, argv, cli::mode::server, true);
auto task = New<TranslateService<BeamSearch>>(options);
auto quiet = options->get<bool>("quiet-translation");
// Initialize web server
WSServer server;
@ -21,8 +22,8 @@ int main(int argc, char **argv) {
auto &translate = server.endpoint["^/translate/?$"];
translate.on_message = [&task](Ptr<WSServer::Connection> connection,
Ptr<WSServer::InMessage> message) {
translate.on_message = [&task, quiet](Ptr<WSServer::Connection> connection,
Ptr<WSServer::InMessage> message) {
// Get input text
auto inputText = message->string();
auto sendStream = std::make_shared<WSServer::OutMessage>();
@ -30,9 +31,9 @@ int main(int argc, char **argv) {
// Translate
timer::Timer timer;
auto outputText = task->run(inputText);
LOG(info, "Best translation: {}", outputText);
*sendStream << outputText << std::endl;
LOG(info, "Translation took: {:.5f}s", timer.elapsed());
if(!quiet)
LOG(info, "Translation took: {:.5f}s", timer.elapsed());
// Send translation back
connection->send(sendStream, [](const SimpleWeb::error_code &ec) {

View File

@ -14,7 +14,9 @@ namespace filesystem {
// Pretend that Windows knows no named pipes. It does, by the way, but
// they seem to be different from pipes on Unix / Linux. See
// https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes
bool is_fifo(char const*) { return false; }
bool is_fifo(char const* /*path*/) {
return false;
}
#else
bool is_fifo(char const* path) {
struct stat buf;

View File

@ -2,9 +2,9 @@
namespace marian {
Options::Options()
Options::Options()
#if FASTOPT
: fastOptions_(options_)
: fastOptions_(options_)
#endif
{}

View File

@ -3,6 +3,7 @@
#include <iostream>
#include <sstream>
#include <chrono>
#include <ctime>
namespace marian {
namespace timer {

View File

@ -23,8 +23,11 @@ const SentenceTuple& TextIterator::dereference() const {
TextInput::TextInput(std::vector<std::string> inputs,
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> options)
: DatasetBase(inputs, options), vocabs_(vocabs) {
// note: inputs are automatically stored in the inherited variable named paths_, but these are
: DatasetBase(inputs, options),
vocabs_(vocabs),
maxLength_(options_->get<size_t>("max-length")),
maxLengthCrop_(options_->get<bool>("max-length-crop")) {
// Note: inputs are automatically stored in the inherited variable named paths_, but these are
// texts not paths!
for(const auto& text : paths_)
files_.emplace_back(new std::istringstream(text));
@ -42,6 +45,10 @@ SentenceTuple TextInput::next() {
std::string line;
if(io::getline(*files_[i], line)) {
Words words = vocabs_[i]->encode(line, /*addEOS =*/ true, /*inference =*/ inference_);
if(this->maxLengthCrop_ && words.size() > this->maxLength_) {
words.resize(maxLength_);
words.back() = vocabs_.back()->getEosId(); // note: this will not work with class-labels
}
if(words.empty())
words.push_back(Word::ZERO); // @TODO: What is this for? @BUGBUG: addEOS=true, so this can never happen, right?
tup.push_back(words);
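A small sketch of what the new max-length-crop branch above does to one encoded sentence: the word sequence is truncated to max-length and the last position is overwritten with EOS so the sentence stays well-formed. Plain integers stand in for Marian's Word type, and the EOS id is passed in explicitly rather than taken from the vocabulary:

#include <cstddef>
#include <vector>

// Hypothetical stand-ins: word ids as size_t, a caller-supplied EOS id.
using WordId = std::size_t;

void cropToMaxLength(std::vector<WordId>& words, std::size_t maxLength, WordId eosId) {
  if (maxLength > 0 && words.size() > maxLength) {
    words.resize(maxLength);
    words.back() = eosId;  // keep a terminating EOS after cropping
  }
}

// e.g. words = {w1, w2, w3, w4, EOS}, maxLength = 3  ->  {w1, w2, EOS}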

View File

@ -33,6 +33,9 @@ private:
size_t pos_{0};
size_t maxLength_{0};
bool maxLengthCrop_{false};
public:
typedef SentenceTuple Sample;

View File

@ -23,7 +23,7 @@ namespace marian {
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
auto firstLogits = logits_.front()->loss();
ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
"Labels not matching logits shape ({} != {}, {})??",
labels.size() * firstLogits->shape()[-1],
firstLogits->shape().elements(),
@ -267,8 +267,8 @@ namespace marian {
Logits Output::applyAsLogits(Expr input) /*override final*/ {
lazyConstruct(input->shape()[-1]);
auto affineOrLSH = [this](Expr x, Expr W, Expr b, bool transA, bool transB) {
#if BLAS_FOUND
auto affineOrLSH = [this](Expr x, Expr W, Expr b, bool transA, bool transB) {
if(lsh_) {
ABORT_IF( transA, "Transposed query not supported for LSH");
ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
@ -276,10 +276,12 @@ namespace marian {
} else {
return affine(x, W, b, transA, transB);
}
#else
return affine(x, W, b, transA, transB);
#endif
};
#else
auto affineOrLSH = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
return affine(x, W, b, transA, transB);
};
#endif
if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed
cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
@ -362,7 +364,7 @@ namespace marian {
factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
else
factorLogits = affine(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
// optionally add lemma-dependent bias
if (Plemma) { // [B... x U0]
int lemmaVocabDim = Plemma->shape()[-1];
@ -448,7 +450,7 @@ namespace marian {
// Embedding layer initialization should depend only on embedding size, hence fanIn=false
auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
if (options_->has("embFile")) {
std::string file = opt<std::string>("embFile");
if (!file.empty()) {

View File

@ -32,8 +32,6 @@ Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) {
return New<MeanMultiRationalLoss>();
else
ABORT("Unknown multi-loss-type {}", multiLossType);
return nullptr;
}
} // namespace marian

View File

@ -74,7 +74,6 @@ std::string ScoreCollector::getAlignment(const data::SoftAlignment& align) {
} else {
ABORT("Unrecognized word alignment type");
}
return "";
}
ScoreCollectorNBest::ScoreCollectorNBest(const Ptr<Options>& options)

View File

@ -19,7 +19,6 @@ struct QuantizeNodeOp : public UnaryNodeOp {
NodeOps backwardOps() override {
ABORT("Only used for inference");
return {NodeOp(0)};
}
const std::string type() override { return "quantizeInt16"; }
@ -54,7 +53,6 @@ public:
NodeOps backwardOps() override {
ABORT("Only used for inference");
return {NodeOp(0)};
}
const std::string type() override { return "dotInt16"; }
@ -92,7 +90,6 @@ public:
NodeOps backwardOps() override {
ABORT("Only used for inference");
return {NodeOp(0)};
}
const std::string type() override { return "affineInt16"; }

View File

@ -1,25 +1,26 @@
# Unit tests
add_subdirectory(units)
if(NOT MSVC)
# Testing apps
set(APP_TESTS
logger
dropout
sqlite
prod
cli
pooling
)
# Testing apps
set(APP_TESTS
logger
dropout
sqlite
prod
cli
pooling
)
foreach(test ${APP_TESTS})
add_executable("test_${test}" "${test}.cpp")
foreach(test ${APP_TESTS})
add_executable("test_${test}" "${test}.cpp")
if(CUDA_FOUND)
target_link_libraries("test_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS})
else(CUDA_FOUND)
target_link_libraries("test_${test}" marian ${EXT_LIBS})
endif(CUDA_FOUND)
if(CUDA_FOUND)
target_link_libraries("test_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS})
else(CUDA_FOUND)
target_link_libraries("test_${test}" marian ${EXT_LIBS})
endif(CUDA_FOUND)
set_target_properties("test_${test}" PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
endforeach(test)
set_target_properties("test_${test}" PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
endforeach(test)
endif(NOT MSVC)

View File

@ -17,5 +17,10 @@ foreach(test ${UNIT_TESTS})
target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch)
endif(CUDA_FOUND)
if(MSVC)
# Disable C4305: truncation from 'double' to '_Ty'
target_compile_options("run_${test}" PUBLIC /wd4305)
endif(MSVC)
add_test(NAME ${test} COMMAND "run_${test}")
endforeach(test)

View File

@ -5,8 +5,6 @@
using namespace marian;
TEST_CASE("FastOpt can be constructed from a YAML node", "[fastopt]") {
YAML::Node node;
SECTION("from a simple node") {
YAML::Node node = YAML::Load("{foo: bar}");
const FastOpt o(node);

View File

@ -40,12 +40,12 @@ void tests(DeviceType device, Type floatType = Type::float32) {
std::vector<T> vA({1, -2, 3, -4});
auto a = graph->constant({2, 2, 1}, inits::fromVector(vA));
auto compare = [&](Expr res, std::function<float(float)> f, bool exactMatch) -> bool {
auto compare = [&](Expr res, std::function<float(float)> f) -> bool {
if (res->shape() != Shape({ 2, 2, 1 }))
return false;
res->val()->get(values);
std::vector<float> ref{f(vA[0]), f(vA[1]), f(vA[2]), f(vA[3])};
return std::equal(values.begin(), values.end(), ref.begin(), exactMatch ? floatEqual : floatApprox);
return std::equal(values.begin(), values.end(), ref.begin(), floatEqual);
};
// @TODO: add all operators and scalar variants here for completeness
@ -55,15 +55,15 @@ void tests(DeviceType device, Type floatType = Type::float32) {
auto rmax2 = maximum(1, a);
auto rmin1 = minimum(a, 1);
auto rmin2 = minimum(1, a);
graph->forward();
CHECK(compare(rsmult, [](float a) {return 2.f * a;}, true));
CHECK(compare(rabs, [](float a) {return std::abs(a);}, true));
CHECK(compare(rmax1, [](float a) {return std::max(a, 1.f);}, true));
CHECK(compare(rmax2, [](float a) {return std::max(1.f, a);}, true));
CHECK(compare(rmin1, [](float a) {return std::min(a, 1.f);}, true));
CHECK(compare(rmin2, [](float a) {return std::min(1.f, a);}, true));
CHECK(compare(rsmult, [](float a) {return 2.f * a;}));
CHECK(compare(rabs, [](float a) {return std::abs(a);}));
CHECK(compare(rmax1, [](float a) {return std::max(a, 1.f);}));
CHECK(compare(rmax2, [](float a) {return std::max(1.f, a);}));
CHECK(compare(rmin1, [](float a) {return std::min(a, 1.f);}));
CHECK(compare(rmin2, [](float a) {return std::min(1.f, a);}));
}
SECTION("elementwise binary operators with broadcasting") {
@ -76,12 +76,23 @@ void tests(DeviceType device, Type floatType = Type::float32) {
auto a = graph->constant({2, 2, 1}, inits::fromVector(vA));
auto b = graph->constant({2, 1}, inits::fromVector(vB));
auto compare = [&](Expr res, std::function<float(float,float)> f, bool exactMatch) -> bool {
// Two lambdas below differ in the use of floatEqual or floatApprox and
// are not merged because MSVC compiler returns C2446: no conversion from
// lambda_x to lambda_y
auto compare = [&](Expr res, std::function<float(float,float)> f) -> bool {
if (res->shape() != Shape({ 2, 2, 1 }))
return false;
res->val()->get(values);
std::vector<float> ref{f(vA[0], vB[0]), f(vA[1], vB[1]), f(vA[2], vB[0]), f(vA[3], vB[1])};
return std::equal(values.begin(), values.end(), ref.begin(), exactMatch ? floatEqual : floatApprox);
return std::equal(values.begin(), values.end(), ref.begin(), floatEqual);
};
auto compareApprox = [&](Expr res, std::function<float(float, float)> f) -> bool {
if(res->shape() != Shape({2, 2, 1}))
return false;
res->val()->get(values);
std::vector<float> ref{f(vA[0], vB[0]), f(vA[1], vB[1]), f(vA[2], vB[0]), f(vA[3], vB[1])};
return std::equal(values.begin(), values.end(), ref.begin(), floatApprox);
};
auto rplus = a + b;
@ -100,19 +111,19 @@ void tests(DeviceType device, Type floatType = Type::float32) {
graph->forward();
CHECK(compare(rplus, [](float a, float b) {return a + b;}, true));
CHECK(compare(rminus, [](float a, float b) {return a - b;}, true));
CHECK(compare(rmult, [](float a, float b) {return a * b;}, true));
CHECK(compare(rdiv, [](float a, float b) {return a / b;}, false));
CHECK(compare(rlae, [](float a, float b) {return logf(expf(a) + expf(b));}, false));
CHECK(compare(rmax, [](float a, float b) {return std::max(a, b);}, true));
CHECK(compare(rmin, [](float a, float b) {return std::min(a, b);}, true));
CHECK(compare(rlt, [](float a, float b) {return a < b;}, true));
CHECK(compare(req, [](float a, float b) {return a == b;}, true));
CHECK(compare(rgt, [](float a, float b) {return a > b;}, true));
CHECK(compare(rge, [](float a, float b) {return a >= b;}, true));
CHECK(compare(rne, [](float a, float b) {return a != b;}, true));
CHECK(compare(rle, [](float a, float b) {return a <= b;}, true));
CHECK(compare(rplus, [](float a, float b) {return a + b;}));
CHECK(compare(rminus, [](float a, float b) {return a - b;}));
CHECK(compare(rmult, [](float a, float b) {return a * b;}));
CHECK(compareApprox(rdiv, [](float a, float b) {return a / b;}));
CHECK(compareApprox(rlae, [](float a, float b) {return logf(expf(a) + expf(b));}));
CHECK(compare(rmax, [](float a, float b) {return std::max(a, b);}));
CHECK(compare(rmin, [](float a, float b) {return std::min(a, b);}));
CHECK(compare(rlt, [](float a, float b) {return a < b;}));
CHECK(compare(req, [](float a, float b) {return a == b;}));
CHECK(compare(rgt, [](float a, float b) {return a > b;}));
CHECK(compare(rge, [](float a, float b) {return a >= b;}));
CHECK(compare(rne, [](float a, float b) {return a != b;}));
CHECK(compare(rle, [](float a, float b) {return a <= b;}));
}
SECTION("transposing and reshaping") {
@ -399,8 +410,8 @@ void tests(DeviceType device, Type floatType = Type::float32) {
std::vector<float> SV; // create CSR version of S
std::vector<IndexType> SI, SO;
SO.push_back((IndexType)SI.size());
for (IndexType i = 0; i < S->shape()[0]; i++) {
for (IndexType j = 0; j < S->shape()[1]; j++) {
for (IndexType i = 0; i < (IndexType)S->shape()[0]; i++) {
for (IndexType j = 0; j < (IndexType)S->shape()[1]; j++) {
auto k = 4 * i + j;
if (vS[k] != 0) {
SV.push_back(vS[k]);
@ -477,7 +488,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
aff1->val()->get(values);
CHECK(values == vAff);
std::vector<T> values2;
values2.clear();
CHECK(aff2->shape() == aff1->shape());
aff2->val()->get(values2);
CHECK(values2 == values);
@ -653,7 +664,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
SECTION("relation of rows and columns selection using transpose") {
graph->clear();
values.clear();
std::vector<T> values2;
values2.clear();
std::vector<T> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
std::vector<IndexType> idx({0, 1});
@ -763,7 +774,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
SECTION("rows/cols as gather operations") {
graph->clear();
values.clear();
std::vector<T> values2;
values2.clear();
std::vector<T> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
@ -791,14 +802,14 @@ void tests(DeviceType device, Type floatType = Type::float32) {
SECTION("topk operations") {
graph->clear();
values.clear();
std::vector<T> vA({ 0, .3333, -.2,
-.3, 0, 4.5,
5.2, -10, 101.45,
std::vector<T> vA({ 0, .3333, -.2,
-.3, 0, 4.5,
5.2, -10, 101.45,
-100.05, 0, 1.05e-5});
auto a = graph->constant({2, 2, 3}, inits::fromVector(vA));
// get top-k indices and values as a tuple
auto rtopk1 = topk(a, /*k=*/2, /*axis=*/-1, /*descending=*/true);
auto rval1 = get<0>(rtopk1); // values from top-k
@ -807,13 +818,13 @@ void tests(DeviceType device, Type floatType = Type::float32) {
auto ridx2 = get<1>(topk(a, /*k=*/2, /*axis=*/-1, /*descending=*/false));
auto gval2 = gather(a, -1, ridx2); // get the same values via gather and indices
auto ridx3 = get<1>(argmin(a, -1));
auto ridx3_ = slice(ridx2, -1, 0); // slice and cast now support uint32_t/IndexType
// @TODO: add integer types to more operators
auto eq3 = eq(cast(ridx3, floatType), cast(ridx3_, floatType));
auto rtopk4 = argmax(a, /*axis=*/-2); // axes other than -1 are currently implemented via inefficient transpose
auto rval4 = get<0>(rtopk4);
auto ridx4 = get<1>(rtopk4);
@ -824,12 +835,12 @@ void tests(DeviceType device, Type floatType = Type::float32) {
CHECK(rval1 != gval1);
CHECK(rval1->shape() == gval1->shape());
CHECK(ridx1->shape() == gval1->shape());
std::vector<T> vval1 = { 0.3333, 0,
4.5, 0,
101.45, 5.2,
std::vector<T> vval1 = { 0.3333, 0,
4.5, 0,
101.45, 5.2,
1.05e-5, 0 };
std::vector<T> rvalues;
std::vector<T> gvalues;
rval1->val()->get(rvalues);
@ -837,9 +848,9 @@ void tests(DeviceType device, Type floatType = Type::float32) {
CHECK( rvalues == gvalues );
CHECK( rvalues == vval1 );
std::vector<T> vval2 = { -0.2, 0,
-0.3, 0,
-10.0, 5.2,
std::vector<T> vval2 = { -0.2, 0,
-0.3, 0,
-10.0, 5.2,
-100.05, 0 };
gval2->val()->get(values);
CHECK( values == vval2 );
@ -850,10 +861,10 @@ void tests(DeviceType device, Type floatType = Type::float32) {
std::vector<IndexType> vidx4;
ridx4->val()->get(vidx4);
CHECK( ridx4->shape() == Shape({2, 1, 3}) );
CHECK( vidx4 == std::vector<IndexType>({0, 0, 1,
CHECK( vidx4 == std::vector<IndexType>({0, 0, 1,
0, 1, 0}) );
std::vector<T> vval4 = { 0, 0.3333, 4.5,
std::vector<T> vval4 = { 0, 0.3333, 4.5,
5.2, 0, 101.45 };
rval4->val()->get(values);
CHECK( values == vval4 );
@ -886,7 +897,7 @@ TEST_CASE("Expression graph supports basic math operations (cpu)", "[operator]")
TEST_CASE("Compare aggregate operator", "[graph]") {
auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).margin(0.001f); };
Config::seed = 1234;
std::vector<float> initc;
@ -908,7 +919,7 @@ TEST_CASE("Compare aggregate operator", "[graph]") {
SECTION("initializing with zero (cpu)") {
std::vector<float> values1;
std::vector<float> values2;
auto graph1 = New<ExpressionGraph>();
graph1->setDevice({0, DeviceType::cpu});
graph1->reserveWorkspaceMB(40);
@ -916,7 +927,7 @@ TEST_CASE("Compare aggregate operator", "[graph]") {
auto graph2 = New<ExpressionGraph>();
graph2->setDevice({0, DeviceType::gpu});
graph2->reserveWorkspaceMB(40);
auto chl1 = graph1->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc));
auto adj1 = graph1->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita));
auto prod1 = scalar_product(chl1, adj1, -1);
@ -935,4 +946,4 @@ TEST_CASE("Compare aggregate operator", "[graph]") {
}
#endif
#endif
#endif

View File

@ -181,7 +181,7 @@ void tests(DeviceType type, Type floatType = Type::float32) {
auto context = concatenate({rnnFw.construct(graph)->transduce(input, mask),
rnnBw.construct(graph)->transduce(input, mask)},
/*axis =*/ input->shape().size() - 1);
/*axis =*/ (int)input->shape().size() - 1);
if(second > 0) {
// add more layers (unidirectional) by transducing the output of the

View File

@ -79,13 +79,14 @@ void OutputCollector::Write(long sourceId,
}
}
StringCollector::StringCollector() : maxId_(-1) {}
StringCollector::StringCollector(bool quiet /*=false*/) : maxId_(-1), quiet_(quiet) {}
void StringCollector::add(long sourceId,
const std::string& best1,
const std::string& bestn) {
std::lock_guard<std::mutex> lock(mutex_);
LOG(info, "Best translation {} : {}", sourceId, best1);
if(!quiet_)
LOG(info, "Best translation {} : {}", sourceId, best1);
outputs_[sourceId] = std::make_pair(best1, bestn);
if(maxId_ <= sourceId)
maxId_ = sourceId;

View File

@ -74,14 +74,15 @@ protected:
class StringCollector {
public:
StringCollector();
StringCollector(bool quiet = false);
StringCollector(const StringCollector&) = delete;
void add(long sourceId, const std::string& best1, const std::string& bestn);
std::vector<std::string> collect(bool nbest);
protected:
long maxId_;
long maxId_; // the largest index of the translated source sentences
bool quiet_; // if true do not log best translations
std::mutex mutex_;
typedef std::map<long, std::pair<std::string, std::string>> Outputs;

View File

@ -1,5 +1,7 @@
#pragma once
#include <string>
#include "data/batch_generator.h"
#include "data/corpus.h"
#include "data/shortlist.h"
@ -245,10 +247,14 @@ public:
}
std::string run(const std::string& input) override {
auto corpus_ = New<data::TextInput>(std::vector<std::string>({input}), srcVocabs_, options_);
// split tab-separated input into fields if necessary
auto inputs = options_->get<bool>("tsv", false)
? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
: std::vector<std::string>({input});
auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_);
auto collector = New<StringCollector>();
auto collector = New<StringCollector>(options_->get<bool>("quiet-translation", false));
auto printer = New<OutputPrinter>(options_, trgVocab_);
size_t batchId = 0;
@ -258,7 +264,6 @@ public:
ThreadPool threadPool_(numDevices_, numDevices_);
for(auto batch : batchGenerator) {
auto task = [=](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local std::vector<Ptr<Scorer>> scorers;
@ -287,5 +292,30 @@ public:
auto translations = collector->collect(options_->get<bool>("n-best"));
return utils::join(translations, "\n");
}
private:
// Converts a multi-line input with tab-separated source(s) and target sentences into separate lists
// of sentences from source(s) and target sides, e.g.
// "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"]
std::vector<std::string> convertTsvToLists(const std::string& inputText, size_t numFields) {
std::vector<std::string> outputFields(numFields);
std::string line;
std::vector<std::string> lineFields(numFields);
std::istringstream inputStream(inputText);
bool first = true;
while(std::getline(inputStream, line)) {
utils::splitTsv(line, lineFields, numFields);
for(size_t i = 0; i < numFields; ++i) {
if(!first)
outputFields[i] += "\n"; // join sentences with a new line sign
outputFields[i] += lineFields[i];
}
if(first)
first = false;
}
return outputFields;
}
};
} // namespace marian
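To make the convertTsvToLists comment above concrete, here is a tiny standalone sketch of the same reshaping. It is a hypothetical free function that splits on tabs directly, whereas Marian's version uses utils::splitTsv and lives inside the service class:

#include <sstream>
#include <string>
#include <vector>

// "src1\ttrg1\nsrc2\ttrg2" with 2 fields -> {"src1\nsrc2", "trg1\ntrg2"}
std::vector<std::string> tsvToLists(const std::string& inputText, size_t numFields) {
  std::vector<std::string> outputFields(numFields);
  std::istringstream inputStream(inputText);
  std::string line;
  bool first = true;
  while (std::getline(inputStream, line)) {
    std::istringstream lineStream(line);
    for (size_t i = 0; i < numFields; ++i) {
      std::string field;                       // missing fields become ""
      std::getline(lineStream, field, '\t');
      if (!first)
        outputFields[i] += "\n";               // join sentences with a newline
      outputFields[i] += field;
    }
    first = false;
  }
  return outputFields;
}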