Spaces:
Build error
Build error
import os | |
import re | |
import sys | |
import shutil | |
import argparse | |
import setup_common | |
# Absolute path of the directory containing this file.
project_directory = os.path.dirname(os.path.abspath(__file__))
# When this script lives in the project's "setup" sub-directory, the project
# root is one level up. Only the final path component is compared: a plain
# substring test ('"setup" in path') also matched unrelated ancestors such as
# "/home/setup_user/kohya_ss" and then pointed sys.path at the wrong directory.
if os.path.basename(project_directory) == "setup":
    project_directory = os.path.dirname(project_directory)
# Put the project root first on the module search path so that the
# project-local `library` package imported below can be resolved.
sys.path.insert(0, project_directory)
from library.custom_logging import setup_logging  # noqa: E402 (needs sys.path set above)

# Module-level logger used by the functions in this file.
log = setup_logging()
def check_torch():
    """Log the detected GPU toolkit and Torch build, and return Torch's major version.

    First logs which vendor toolkit appears to be present: nVidia when
    ``nvidia-smi`` is on PATH or present under ``%SystemRoot%\\System32``,
    AMD when ``rocminfo`` is on PATH or at ``/opt/rocm/bin/rocminfo``,
    otherwise CPU-only. Then imports ``torch`` and logs its version, the
    compute backend (CUDA/cuDNN or ROCm HIP) and every visible GPU.

    Returns:
        int: the leading component of ``torch.__version__``
        (e.g. ``2`` for ``"2.1.0+cu118"``).

    Exits:
        Calls ``sys.exit(1)`` after logging an error if importing or
        querying Torch raises any ``Exception``.
    """
    # Toolkit detection is informational only; it does not affect the result.
    nvidia_smi_windows = os.path.join(
        os.environ.get('SystemRoot') or r'C:\Windows',
        'System32',
        'nvidia-smi.exe',
    )
    if shutil.which('nvidia-smi') is not None or os.path.exists(nvidia_smi_windows):
        log.info('nVidia toolkit detected')
    elif shutil.which('rocminfo') is not None or os.path.exists('/opt/rocm/bin/rocminfo'):
        log.info('AMD toolkit detected')
    else:
        log.info('Using CPU-only Torch')

    try:
        import torch

        log.info(f'Torch {torch.__version__}')
        if not torch.cuda.is_available():
            log.warning('Torch reports CUDA not available')
        else:
            if torch.version.cuda:
                # nVidia build: report CUDA and (if present) cuDNN versions.
                cudnn_version = (
                    torch.backends.cudnn.version()
                    if torch.backends.cudnn.is_available()
                    else "N/A"
                )
                log.info(f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {cudnn_version}')
            elif torch.version.hip:
                # AMD build: report the ROCm HIP version.
                log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
            else:
                log.warning('Unknown Torch backend')
            # The torch.cuda query functions accept a plain device index, so
            # there is no need to build torch.cuda.device context managers.
            for index in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(index)
                log.info(
                    f'Torch detected GPU: {torch.cuda.get_device_name(index)} '
                    f'VRAM {round(props.total_memory / 1024 / 1024)} '
                    f'Arch {torch.cuda.get_device_capability(index)} '
                    f'Cores {props.multi_processor_count}'
                )
        # Parse the whole leading version component. Indexing the first
        # character (``__version__[0]``) would report 1 for a "10.x" release.
        return int(torch.__version__.split('.')[0])
    except Exception as e:
        log.error(f'Could not load torch: {e}')
        sys.exit(1)
def main():
    """Script entry point: verify the repo, report Torch, then install requirements.

    Steps, in order:
      1. ``setup_common.check_repo_version()``.
      2. Parse ``-r/--requirements`` (an explicit requirements file) and
         ``--debug`` (parsed here, but not read by this function).
      3. Run ``check_torch()`` to log the Torch setup and obtain its major version.
      4. Install the explicit file if one was given. Otherwise install
         ``requirements_windows_torch1.txt`` when the major version is 1 and
         ``requirements_windows_torch2.txt`` for any other version.
    """
    setup_common.check_repo_version()

    cli = argparse.ArgumentParser(description='Validate that requirements are satisfied.')
    cli.add_argument('-r', '--requirements', type=str, help='Path to the requirements file.')
    cli.add_argument('--debug', action='store_true', help='Debug on')
    options = cli.parse_args()

    torch_major = check_torch()

    # An explicit (non-empty) -r value wins; otherwise pick by Torch major version.
    if options.requirements:
        requirements_path = options.requirements
    elif torch_major == 1:
        requirements_path = 'requirements_windows_torch1.txt'
    else:
        requirements_path = 'requirements_windows_torch2.txt'

    setup_common.install_requirements(requirements_path, check_no_verify_flag=True)


if __name__ == '__main__':
    main()