Minor cleanup for setup.py

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/1078

Differential Revision: D17072514

Pulled By: myleott

fbshipit-source-id: 69a8c8c9cc7caa7e04c414329a5d79e6e1a6621c
This commit is contained in:
Myle Ott 2019-08-27 10:06:26 -07:00 committed by Facebook Github Bot
parent 920b85d4bd
commit d2410c4207
3 changed files with 30 additions and 28 deletions

View File

@ -10,11 +10,11 @@ except ImportError:
import contextlib
import itertools
import os
import numpy as np
import sys
import types
import numpy as np
def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
@ -204,12 +204,14 @@ def batch_by_size(
raise ImportError(
'Please build Cython components with: `pip install --editable .`'
)
max_tokens = max_tokens if max_tokens is not None else sys.maxsize
max_sentences = max_sentences if max_sentences is not None else sys.maxsize
bsz_mult = required_batch_size_multiple
if isinstance(indices, types.GeneratorType):
indices = np.fromiter(indices, dtype=np.int64, count=-1)
return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult)

View File

@ -11,6 +11,7 @@ from fairseq.models import MODEL_REGISTRY
dependencies = [
'numpy',
'regex',
'requests',
'torch',

View File

@ -11,47 +11,45 @@ import sys
if sys.version_info < (3,):
sys.exit('Sorry, Python3 is required for fairseq.')
with open('README.md') as f:
readme = f.read()
if sys.platform == 'darwin':
extra_compile_args = ['-stdlib=libc++', '-O3']
extra_link_args = ['-stdlib=libc++']
else:
extra_compile_args = ['-std=c++11', '-O3']
extra_link_args = ['-std=c++11']
bleu = Extension(
'fairseq.libbleu',
sources=[
'fairseq/clib/libbleu/libbleu.cpp',
'fairseq/clib/libbleu/module.cpp',
],
extra_compile_args=extra_compile_args,
)
def get_cython_modules():
token_block_utils = Extension(
"fairseq.data.token_block_utils_fast",
["fairseq/data/token_block_utils_fast.pyx"],
extensions = [
Extension(
'fairseq.libbleu',
sources=[
'fairseq/clib/libbleu/libbleu.cpp',
'fairseq/clib/libbleu/module.cpp',
],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
data_utils_fast = Extension(
"fairseq.data.data_utils_fast",
["fairseq/data/data_utils_fast.pyx"],
language="c++",
),
Extension(
'fairseq.data.data_utils_fast',
sources=['fairseq/data/data_utils_fast.pyx'],
language='c++',
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
return [token_block_utils, data_utils_fast]
),
Extension(
'fairseq.data.token_block_utils_fast',
sources=['fairseq/data/token_block_utils_fast.pyx'],
language='c++',
extra_compile_args=extra_compile_args,
),
]
def my_build_ext(pars):
"""
Delay loading of numpy headers.
More details: https://stackoverflow.com/questions/54117786/add-numpy-get-include-argument-to-setuptools-without-preinstalled-numpy
More details: https://stackoverflow.com/a/54138355
"""
from setuptools.command.build_ext import build_ext as _build_ext
@ -81,6 +79,7 @@ setup(
setup_requires=[
'numpy',
'cython',
'numpy',
'setuptools>=18.0',
],
install_requires=[
@ -93,7 +92,7 @@ setup(
'tqdm',
],
packages=find_packages(exclude=['scripts', 'tests']),
ext_modules=get_cython_modules() + [bleu],
ext_modules=extensions,
test_suite='tests',
entry_points={
'console_scripts': [