From d2410c4207b3a32cd1147236982abec2273a3d69 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Tue, 27 Aug 2019 10:06:26 -0700 Subject: [PATCH] Minor cleanup for setup.py Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/1078 Differential Revision: D17072514 Pulled By: myleott fbshipit-source-id: 69a8c8c9cc7caa7e04c414329a5d79e6e1a6621c --- fairseq/data/data_utils.py | 6 +++-- hubconf.py | 1 + setup.py | 51 +++++++++++++++++++------------------- 3 files changed, 30 insertions(+), 28 deletions(-) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 22c8c60db..9d72d93e6 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -10,11 +10,11 @@ except ImportError: import contextlib import itertools import os - -import numpy as np import sys import types +import numpy as np + def infer_language_pair(path): """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx""" @@ -204,12 +204,14 @@ def batch_by_size( raise ImportError( 'Please build Cython components with: `pip install --editable .`' ) + max_tokens = max_tokens if max_tokens is not None else sys.maxsize max_sentences = max_sentences if max_sentences is not None else sys.maxsize bsz_mult = required_batch_size_multiple if isinstance(indices, types.GeneratorType): indices = np.fromiter(indices, dtype=np.int64, count=-1) + return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult) diff --git a/hubconf.py b/hubconf.py index 34179c9db..c13977085 100644 --- a/hubconf.py +++ b/hubconf.py @@ -11,6 +11,7 @@ from fairseq.models import MODEL_REGISTRY dependencies = [ + 'numpy', 'regex', 'requests', 'torch', diff --git a/setup.py b/setup.py index d900b9465..9ec2d7360 100644 --- a/setup.py +++ b/setup.py @@ -11,47 +11,45 @@ import sys if sys.version_info < (3,): sys.exit('Sorry, Python3 is required for fairseq.') + with open('README.md') as f: readme = f.read() + if sys.platform == 'darwin': extra_compile_args = ['-stdlib=libc++', '-O3'] - 
extra_link_args = ['-stdlib=libc++'] else: extra_compile_args = ['-std=c++11', '-O3'] - extra_link_args = ['-std=c++11'] - -bleu = Extension( - 'fairseq.libbleu', - sources=[ - 'fairseq/clib/libbleu/libbleu.cpp', - 'fairseq/clib/libbleu/module.cpp', - ], - extra_compile_args=extra_compile_args, -) -def get_cython_modules(): - token_block_utils = Extension( - "fairseq.data.token_block_utils_fast", - ["fairseq/data/token_block_utils_fast.pyx"], +extensions = [ + Extension( + 'fairseq.libbleu', + sources=[ + 'fairseq/clib/libbleu/libbleu.cpp', + 'fairseq/clib/libbleu/module.cpp', + ], extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, - ) - data_utils_fast = Extension( - "fairseq.data.data_utils_fast", - ["fairseq/data/data_utils_fast.pyx"], - language="c++", + ), + Extension( + 'fairseq.data.data_utils_fast', + sources=['fairseq/data/data_utils_fast.pyx'], + language='c++', extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, - ) - return [token_block_utils, data_utils_fast] + ), + Extension( + 'fairseq.data.token_block_utils_fast', + sources=['fairseq/data/token_block_utils_fast.pyx'], + language='c++', + extra_compile_args=extra_compile_args, + ), +] def my_build_ext(pars): """ Delay loading of numpy headers. - More details: https://stackoverflow.com/questions/54117786/add-numpy-get-include-argument-to-setuptools-without-preinstalled-numpy + More details: https://stackoverflow.com/a/54138355 """ from setuptools.command.build_ext import build_ext as _build_ext @@ -81,6 +79,7 @@ setup( setup_requires=[ 'numpy', 'cython', + 'numpy', 'setuptools>=18.0', ], install_requires=[ @@ -93,7 +92,7 @@ setup( 'tqdm', ], packages=find_packages(exclude=['scripts', 'tests']), - ext_modules=get_cython_modules() + [bleu], + ext_modules=extensions, test_suite='tests', entry_points={ 'console_scripts': [