Cleaner handling of numpy-based extensions in setup.py

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/853

Differential Revision: D17147879

Pulled By: myleott

fbshipit-source-id: b1f5e838533de62ade52fa82112ea5308734c70f
This commit is contained in:
Myle Ott 2019-08-31 16:52:03 -07:00 committed by Facebook Github Bot
parent 746e59a262
commit 8d4588b1ba

View File

@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
import sys
@ -23,6 +22,23 @@ else:
extra_compile_args = ['-std=c++11', '-O3']
class NumpyExtension(Extension):
"""Source: https://stackoverflow.com/a/54128391"""
def __init__(self, *args, **kwargs):
self.__include_dirs = []
super().__init__(*args, **kwargs)
@property
def include_dirs(self):
import numpy
return self.__include_dirs + [numpy.get_include()]
@include_dirs.setter
def include_dirs(self, dirs):
self.__include_dirs = dirs
extensions = [
Extension(
'fairseq.libbleu',
@ -32,13 +48,13 @@ extensions = [
],
extra_compile_args=extra_compile_args,
),
Extension(
NumpyExtension(
'fairseq.data.data_utils_fast',
sources=['fairseq/data/data_utils_fast.pyx'],
language='c++',
extra_compile_args=extra_compile_args,
),
Extension(
NumpyExtension(
'fairseq.data.token_block_utils_fast',
sources=['fairseq/data/token_block_utils_fast.pyx'],
language='c++',
@ -47,15 +63,6 @@ extensions = [
]
class CustomBuildExtCommand(build_ext):
"""Source: https://stackoverflow.com/a/42163080"""
def run(self):
# Import numpy here, only when headers are needed
import numpy
self.include_dirs.append(numpy.get_include())
super().run()
setup(
name='fairseq',
version='0.8.0',
@ -71,7 +78,6 @@ setup(
long_description=readme,
long_description_content_type='text/markdown',
setup_requires=[
'numpy',
'cython',
'numpy',
'setuptools>=18.0',
@ -99,6 +105,5 @@ setup(
'fairseq-validate = fairseq_cli.validate:cli_main',
],
},
cmdclass={'build_ext': CustomBuildExtCommand},
zip_safe=False,
)