mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
0bad0ce56a
This reverts commit c0c326cbf8
.
274 lines
7.9 KiB
Python
274 lines
7.9 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
|
|
from setuptools import Extension, find_packages, setup
|
|
|
|
# Refuse to run on interpreters older than the minimum supported version,
# exiting with a readable message instead of failing somewhere later.
if sys.version_info[:2] < (3, 6):
    sys.exit("Sorry, Python >= 3.6 is required for fairseq.")
|
|
|
|
|
|
def write_version_py():
    """Sync ``fairseq/version.py`` with ``fairseq/version.txt``.

    Reads the version string from ``fairseq/version.txt`` (relative to the
    current working directory), strips surrounding whitespace, writes it to
    ``fairseq/version.py`` as a ``__version__`` assignment, and returns it.
    """
    package_dir = "fairseq"
    with open(os.path.join(package_dir, "version.txt")) as src:
        current = src.read().strip()

    # Regenerate the module that exposes the version as fairseq.version.__version__.
    with open(os.path.join(package_dir, "version.py"), "w") as dst:
        dst.write(f'__version__ = "{current}"\n')
    return current
|
|
|
|
|
|
# Version string; also regenerates fairseq/version.py as a side effect.
version = write_version_py()

# The README is used as the package's long description. Read it as UTF-8
# explicitly: without an encoding, open() uses the locale's preferred encoding
# (e.g. cp1252 on Windows), which can fail on non-ASCII characters.
with open("README.md", encoding="utf-8") as f:
    readme = f.read()

# Compiler flags shared by the native extensions declared below.
if sys.platform == "darwin":
    # macOS: build against libc++.
    extra_compile_args = ["-stdlib=libc++", "-O3"]
else:
    extra_compile_args = ["-std=c++11", "-O3"]
|
|
|
|
|
|
class NumpyExtension(Extension):
    """Extension whose include path always ends with NumPy's header directory.

    NumPy is imported lazily, only when ``include_dirs`` is read, so declaring
    the extension does not itself require NumPy to be importable.
    Source: https://stackoverflow.com/a/54128391
    """

    def __init__(self, *args, **kwargs):
        # Default for the user-supplied include dirs; presumably the base
        # initializer assigns include_dirs (routed through the setter below),
        # in which case this value is overwritten.
        self.__extra_dirs = []
        super().__init__(*args, **kwargs)

    @property
    def include_dirs(self):
        # Deferred import: resolved when the build reads include_dirs.
        import numpy

        numpy_headers = numpy.get_include()
        return self.__extra_dirs + [numpy_headers]

    @include_dirs.setter
    def include_dirs(self, dirs):
        self.__extra_dirs = dirs
|
|
|
|
|
|
# Native extensions that do not depend on PyTorch. The Cython (.pyx) modules
# are wrapped in NumpyExtension so NumPy's headers are on their include path.
extensions = [
    Extension(
        "fairseq.libbleu",
        sources=[
            "fairseq/clib/libbleu/libbleu.cpp",
            "fairseq/clib/libbleu/module.cpp",
        ],
        extra_compile_args=extra_compile_args,
    ),
] + [
    NumpyExtension(
        module_name,
        sources=[pyx_source],
        language="c++",
        extra_compile_args=extra_compile_args,
    )
    for module_name, pyx_source in (
        ("fairseq.data.data_utils_fast", "fairseq/data/data_utils_fast.pyx"),
        (
            "fairseq.data.token_block_utils_fast",
            "fairseq/data/token_block_utils_fast.pyx",
        ),
    )
]
|
|
|
|
|
|
# Custom setuptools commands; "build_ext" is filled in only when PyTorch's
# build helpers can be imported.
cmdclass = {}

try:
    # torch is not available when generating docs
    from torch.utils import cpp_extension
except ImportError:
    # Best effort: without torch, skip the torch-based extensions and keep
    # setuptools' default build_ext.
    pass
else:
    # Only the import is guarded. An ImportError raised while declaring the
    # extensions below now surfaces instead of being swallowed, which could
    # otherwise leave `extensions` half-populated with no matching build_ext.
    extensions.extend(
        [
            cpp_extension.CppExtension(
                "fairseq.libbase",
                sources=[
                    "fairseq/clib/libbase/balanced_assignment.cpp",
                ],
            ),
            cpp_extension.CppExtension(
                "fairseq.libnat",
                sources=[
                    "fairseq/clib/libnat/edit_dist.cpp",
                ],
            ),
            cpp_extension.CppExtension(
                "alignment_train_cpu_binding",
                sources=[
                    "examples/operators/alignment_train_cpu.cpp",
                ],
            ),
        ]
    )
    # GPU extensions are built only when CUDA_HOME is set in the environment.
    if "CUDA_HOME" in os.environ:
        # NOTE(review): these mix .cu and .cpp sources but are declared with
        # CppExtension rather than cpp_extension.CUDAExtension; confirm the
        # CUDA runtime gets linked as intended.
        extensions.extend(
            [
                cpp_extension.CppExtension(
                    "fairseq.libnat_cuda",
                    sources=[
                        "fairseq/clib/libnat_cuda/edit_dist.cu",
                        "fairseq/clib/libnat_cuda/binding.cpp",
                    ],
                ),
                cpp_extension.CppExtension(
                    "fairseq.ngram_repeat_block_cuda",
                    sources=[
                        "fairseq/clib/cuda/ngram_repeat_block_cuda.cpp",
                        "fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu",
                    ],
                ),
                cpp_extension.CppExtension(
                    "alignment_train_cuda_binding",
                    sources=[
                        "examples/operators/alignment_train_kernel.cu",
                        "examples/operators/alignment_train_cuda.cpp",
                    ],
                ),
            ]
        )
    # Build every extension with torch's BuildExtension (presumably required
    # for the torch CppExtensions registered above).
    cmdclass["build_ext"] = cpp_extension.BuildExtension
|
|
|
|
|
|
if "READTHEDOCS" not in os.environ:
    dependency_links = []
else:
    # Documentation builds: compile no native extensions and drop any
    # custom build_ext command (no-op if none was registered).
    extensions = []
    cmdclass.pop("build_ext", None)

    # use CPU build of PyTorch
    # NOTE(review): dependency_links is deprecated and ignored by modern pip;
    # confirm whether the docs build still depends on it.
    dependency_links = [
        "https://download.pytorch.org/whl/cpu/torch-1.7.0%2Bcpu-cp36-cp36m-linux_x86_64.whl"
    ]
|
|
|
|
|
|
if "clean" in sys.argv[1:]:
    # Source: https://bit.ly/2NLVsgE
    # Remove compiled extension modules (*.so, *.pyd) at the top level of
    # fairseq/ and exactly one directory below it.
    import glob

    print("deleting Cython files...")

    # Done in Python instead of `rm -f` through a shell, so it also works on
    # Windows (which produces the .pyd files but has no `rm`). The patterns
    # match what the former sh globs matched: without globstar, `**` acts
    # like `*`, i.e. a single directory level.
    for _pattern in (
        "fairseq/*.so",
        "fairseq/*/*.so",
        "fairseq/*.pyd",
        "fairseq/*/*.pyd",
    ):
        for _stale in glob.glob(_pattern):
            try:
                os.remove(_stale)
            except OSError as err:
                # Like `rm -f`: report the problem and keep cleaning.
                print(f"could not delete {_stale}: {err}", file=sys.stderr)
|
|
|
|
|
|
# Package megatron's `mpu` module only when its directory exists in the source
# tree (presumably an optional checkout/submodule — verify).
extra_packages = (
    ["fairseq.model_parallel.megatron.mpu"]
    if os.path.exists(os.path.join("fairseq", "model_parallel", "megatron", "mpu"))
    else []
)
|
|
|
|
|
|
def do_setup(package_data):
    """Run setuptools.setup() with fairseq's metadata and build configuration.

    Args:
        package_data: passed through unchanged as ``setup(package_data=...)``;
            maps package names to the non-Python files to ship with them.

    Also reads the module-level ``version``, ``readme``, ``dependency_links``,
    ``extra_packages``, ``extensions`` and ``cmdclass`` values defined earlier
    in this file.
    """
    setup(
        name="fairseq",
        version=version,
        description="Facebook AI Research Sequence-to-Sequence Toolkit",
        url="https://github.com/pytorch/fairseq",
        classifiers=[
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: MIT License",
            "Programming Language :: Python :: 3.6",
            "Programming Language :: Python :: 3.7",
            "Programming Language :: Python :: 3.8",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
        ],
        long_description=readme,
        long_description_content_type="text/markdown",
        # Metadata counterpart of the sys.version_info check at the top of
        # this file: lets pip skip this release on unsupported interpreters
        # instead of downloading it and failing during the build.
        python_requires=">=3.6",
        setup_requires=[
            "cython",
            'numpy<1.20.0; python_version<"3.7"',
            'numpy; python_version>="3.7"',
            "setuptools>=18.0",
        ],
        install_requires=[
            "cffi",
            "cython",
            'dataclasses; python_version<"3.7"',
            "hydra-core>=1.0.7,<1.1",
            "omegaconf<2.1",
            'numpy<1.20.0; python_version<"3.7"',
            'numpy; python_version>="3.7"',
            "regex",
            "sacrebleu>=1.4.12",
            "torch",
            "tqdm",
            "bitarray",
            "torchaudio>=0.8.0",
        ],
        dependency_links=dependency_links,
        packages=find_packages(
            exclude=[
                "examples",
                "examples.*",
                "scripts",
                "scripts.*",
                "tests",
                "tests.*",
            ]
        )
        + extra_packages,
        package_data=package_data,
        ext_modules=extensions,
        test_suite="tests",
        entry_points={
            "console_scripts": [
                "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main",
                "fairseq-generate = fairseq_cli.generate:cli_main",
                "fairseq-hydra-train = fairseq_cli.hydra_train:cli_main",
                "fairseq-interactive = fairseq_cli.interactive:cli_main",
                "fairseq-preprocess = fairseq_cli.preprocess:cli_main",
                "fairseq-score = fairseq_cli.score:cli_main",
                "fairseq-train = fairseq_cli.train:cli_main",
                "fairseq-validate = fairseq_cli.validate:cli_main",
            ],
        },
        cmdclass=cmdclass,
        zip_safe=False,
    )
|
|
|
|
|
|
def get_files(path, relative_to="fairseq"):
    """List every file under ``path``, skipping compiled ``.pyc`` files.

    Symlinked directories are followed. Each returned path is expressed
    relative to ``relative_to``; the list is in ``os.walk`` order, and is
    empty when ``path`` does not exist.
    """
    collected = []
    for dirpath, _subdirs, filenames in os.walk(path, followlinks=True):
        rel_dir = os.path.relpath(dirpath, relative_to)
        collected.extend(
            os.path.join(rel_dir, name)
            for name in filenames
            if not name.endswith(".pyc")
        )
    return collected
|
|
|
|
|
|
if __name__ == "__main__":
    # symlink examples into fairseq package so package_data accepts them
    fairseq_examples = os.path.join("fairseq", "examples")
    try:
        # The link is not created for a plain `build_ext` invocation.
        creating_allowed = "build_ext" not in sys.argv[1:]
        if creating_allowed and not os.path.exists(fairseq_examples):
            os.symlink(os.path.join("..", "examples"), fairseq_examples)

        # Ship the examples and config trees as package data; get_files
        # reports the paths relative to the fairseq/ directory.
        package_data = {"fairseq": get_files(fairseq_examples)}
        package_data["fairseq"] += get_files(os.path.join("fairseq", "config"))
        do_setup(package_data)
    finally:
        # Outside `build_ext`, remove the link whenever one is present so
        # the source tree is left without it.
        removal_allowed = "build_ext" not in sys.argv[1:]
        if removal_allowed and os.path.islink(fairseq_examples):
            os.unlink(fairseq_examples)
|