guidel's picture
Duplicate from OFA-Sys/OFA-Generic_Interface
8c90e7d
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import subprocess
import sys
from setuptools import Extension, find_packages, setup
# Abort immediately on interpreters older than Python 3.6.
if sys.version_info[:2] < (3, 6):
    sys.exit("Sorry, Python >= 3.6 is required for fairseq.")
def write_version_py():
    """Compute the package version and write it to ``fairseq/version.py``.

    The base version is read from ``fairseq/version.txt``. When ``git`` can
    report the current commit, the first seven characters of its hash are
    appended as a ``+<sha>`` suffix; otherwise the base version is used as is.

    Returns:
        The version string that was written to ``fairseq/version.py``.

    Raises:
        OSError: if ``fairseq/version.txt`` cannot be read or
            ``fairseq/version.py`` cannot be written.
    """
    with open(os.path.join("fairseq", "version.txt"), encoding="utf-8") as f:
        version = f.read().strip()

    # append latest commit hash to version string
    try:
        sha = (
            subprocess.check_output(["git", "rev-parse", "HEAD"])
            .decode("ascii")
            .strip()
        )
    except (subprocess.CalledProcessError, OSError, UnicodeError):
        # Best effort only: git is missing, this is not a git checkout, or
        # the output is unusable. Keep the plain version in those cases.
        sha = ""
    if sha:
        # Skip the suffix entirely when git printed nothing, so the version
        # never ends in a dangling "+".
        version += "+" + sha[:7]

    # write version info to fairseq/version.py
    with open(os.path.join("fairseq", "version.py"), "w", encoding="utf-8") as f:
        f.write('__version__ = "{}"\n'.format(version))
    return version
# Version string (also written to fairseq/version.py); passed to setup() below.
version = write_version_py()

# README.md is used as the long description. Read it explicitly as UTF-8 so
# the build does not depend on the locale's default encoding.
with open("README.md", encoding="utf-8") as f:
    readme = f.read()
# Compiler flags shared by the native extensions: the first flag depends on
# the platform (macOS vs. everything else), the optimisation level does not.
extra_compile_args = [
    "-stdlib=libc++" if sys.platform == "darwin" else "-std=c++11",
    "-O3",
]
class NumpyExtension(Extension):
    """Extension whose ``include_dirs`` always ends with NumPy's header path.

    NumPy is imported only when ``include_dirs`` is read, not when the
    extension object is created.

    Source: https://stackoverflow.com/a/54128391
    """

    def __init__(self, *args, **kwargs):
        # Set up the backing list before delegating to the parent class;
        # presumably Extension.__init__ assigns include_dirs, which is routed
        # through the property setter below.
        self.__include_dirs = []
        super().__init__(*args, **kwargs)

    @property
    def include_dirs(self):
        """Return the configured include directories plus NumPy's headers."""
        import numpy

        configured = self.__include_dirs
        return configured + [numpy.get_include()]

    @include_dirs.setter
    def include_dirs(self, value):
        # Only the explicitly configured directories are stored; the NumPy
        # path is appended on every read.
        self.__include_dirs = value
# Native extensions that are built regardless of whether torch is importable:
# the libbleu C++ module followed by the Cython helpers that need NumPy headers.
extensions = [
    Extension(
        "fairseq.libbleu",
        sources=[
            "fairseq/clib/libbleu/libbleu.cpp",
            "fairseq/clib/libbleu/module.cpp",
        ],
        extra_compile_args=extra_compile_args,
    ),
] + [
    NumpyExtension(
        module_name,
        sources=[pyx_source],
        language="c++",
        extra_compile_args=extra_compile_args,
    )
    for module_name, pyx_source in (
        ("fairseq.data.data_utils_fast", "fairseq/data/data_utils_fast.pyx"),
        (
            "fairseq.data.token_block_utils_fast",
            "fairseq/data/token_block_utils_fast.pyx",
        ),
    )
]
# Command overrides handed to setup(); stays empty unless torch is importable.
cmdclass = {}

try:
    # torch is not available when generating docs
    from torch.utils import cpp_extension

    def _cpp_ext(module_name, *source_files):
        """Build a torch CppExtension for *module_name* from *source_files*."""
        return cpp_extension.CppExtension(module_name, sources=list(source_files))

    # Extensions that are always added when torch can be imported.
    extensions.extend(
        [
            _cpp_ext(
                "fairseq.libbase",
                "fairseq/clib/libbase/balanced_assignment.cpp",
            ),
            _cpp_ext(
                "fairseq.libnat",
                "fairseq/clib/libnat/edit_dist.cpp",
            ),
            _cpp_ext(
                "alignment_train_cpu_binding",
                "examples/operators/alignment_train_cpu.cpp",
            ),
        ]
    )

    # Extensions with .cu sources are only added when CUDA_HOME is set.
    if "CUDA_HOME" in os.environ:
        extensions.extend(
            [
                _cpp_ext(
                    "fairseq.libnat_cuda",
                    "fairseq/clib/libnat_cuda/edit_dist.cu",
                    "fairseq/clib/libnat_cuda/binding.cpp",
                ),
                _cpp_ext(
                    "fairseq.ngram_repeat_block_cuda",
                    "fairseq/clib/cuda/ngram_repeat_block_cuda.cpp",
                    "fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu",
                ),
                _cpp_ext(
                    "alignment_train_cuda_binding",
                    "examples/operators/alignment_train_kernel.cu",
                    "examples/operators/alignment_train_cuda.cpp",
                ),
            ]
        )

    cmdclass["build_ext"] = cpp_extension.BuildExtension
except ImportError:
    pass
# Extra download locations passed to setup(); only populated for docs builds.
dependency_links = []

if "READTHEDOCS" in os.environ:
    # don't build extensions when generating docs
    extensions = []
    cmdclass.pop("build_ext", None)

    # use CPU build of PyTorch
    dependency_links.append(
        "https://download.pytorch.org/whl/cpu/torch-1.7.0%2Bcpu-cp36-cp36m-linux_x86_64.whl"
    )
def remove_built_extensions(package_dir="fairseq"):
    """Delete compiled extension files (``*.so`` / ``*.pyd``) under *package_dir*.

    The search is recursive and uses only the standard library, so it does not
    depend on a POSIX shell or on the ``rm`` command being available.

    Args:
        package_dir: Directory to search; defaults to the ``fairseq`` package.

    Returns:
        The list of paths that were actually removed.
    """
    import glob

    removed = []
    for suffix in (".so", ".pyd"):
        # With recursive=True, "**" matches zero or more directory levels, so
        # this covers files directly in package_dir and in any subdirectory.
        pattern = os.path.join(package_dir, "**", "*" + suffix)
        for path in glob.glob(pattern, recursive=True):
            try:
                os.remove(path)
            except OSError:
                # Same spirit as `rm -f`: skip anything that cannot be removed.
                continue
            removed.append(path)
    return removed


if "clean" in sys.argv[1:]:
    # Source: https://bit.ly/2NLVsgE
    print("deleting Cython files...")
    remove_built_extensions()
# The megatron "mpu" package is only listed when its directory is present in
# this checkout.
extra_packages = (
    ["fairseq.model_parallel.megatron.mpu"]
    if os.path.exists(os.path.join("fairseq", "model_parallel", "megatron", "mpu"))
    else []
)
def do_setup(package_data):
    """Invoke ``setuptools.setup`` with the fairseq package metadata.

    Args:
        package_data: Value forwarded unchanged as ``setup``'s ``package_data``
            argument.
    """
    trove_classifiers = [
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ]

    # The same NumPy pins are used at build time and at install time.
    numpy_pins = [
        'numpy<1.20.0; python_version<"3.7"',
        'numpy; python_version>="3.7"',
    ]
    build_requirements = ["cython", *numpy_pins, "setuptools>=18.0"]
    runtime_requirements = [
        "cffi",
        "cython",
        'dataclasses; python_version<"3.7"',
        "hydra-core>=1.0.7,<1.1",
        "omegaconf<2.1",
        *numpy_pins,
        "regex",
        "sacrebleu>=1.4.12",
        # "torch",
        "tqdm",
        "bitarray",
        # "torchaudio>=0.8.0",
    ]

    # Everything discovered by find_packages except examples, scripts and
    # tests, plus any optional packages detected at module level.
    discovered_packages = find_packages(
        exclude=[
            "examples",
            "examples.*",
            "scripts",
            "scripts.*",
            "tests",
            "tests.*",
        ]
    )

    console_scripts = [
        "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main",
        "fairseq-generate = fairseq_cli.generate:cli_main",
        "fairseq-hydra-train = fairseq_cli.hydra_train:cli_main",
        "fairseq-interactive = fairseq_cli.interactive:cli_main",
        "fairseq-preprocess = fairseq_cli.preprocess:cli_main",
        "fairseq-score = fairseq_cli.score:cli_main",
        "fairseq-train = fairseq_cli.train:cli_main",
        "fairseq-validate = fairseq_cli.validate:cli_main",
    ]

    setup(
        name="fairseq",
        version=version,
        description="Facebook AI Research Sequence-to-Sequence Toolkit",
        url="https://github.com/pytorch/fairseq",
        classifiers=trove_classifiers,
        long_description=readme,
        long_description_content_type="text/markdown",
        setup_requires=build_requirements,
        install_requires=runtime_requirements,
        dependency_links=dependency_links,
        packages=discovered_packages + extra_packages,
        package_data=package_data,
        ext_modules=extensions,
        test_suite="tests",
        entry_points={"console_scripts": console_scripts},
        cmdclass=cmdclass,
        zip_safe=False,
    )
def get_files(path, relative_to="fairseq"):
    """List every file below *path*, skipping compiled ``.pyc`` files.

    The directory tree is walked with symlinked directories followed.

    Args:
        path: Directory to walk.
        relative_to: Directory the returned paths are made relative to.

    Returns:
        A list of file paths expressed relative to *relative_to*.
    """
    return [
        os.path.join(os.path.relpath(folder, relative_to), filename)
        for folder, _subdirs, filenames in os.walk(path, followlinks=True)
        for filename in filenames
        if not filename.endswith(".pyc")
    ]
if __name__ == "__main__":
    try:
        # symlink examples into fairseq package so package_data accepts them
        fairseq_examples = os.path.join("fairseq", "examples")
        if not ("build_ext" in sys.argv[1:] or os.path.exists(fairseq_examples)):
            os.symlink(os.path.join("..", "examples"), fairseq_examples)

        # Ship the example files and the config files as package data.
        example_files = get_files(fairseq_examples)
        config_files = get_files(os.path.join("fairseq", "config"))
        package_data = {"fairseq": example_files + config_files}
        do_setup(package_data)
    finally:
        # Remove the examples symlink again, except on build_ext runs.
        if "build_ext" not in sys.argv[1:]:
            if os.path.islink(fairseq_examples):
                os.unlink(fairseq_examples)