"""
╔══════════════════════════════════════════════════════════════╗
║ Author: p3nGu1nZz                                            ║
║ Email: rawsonkara@gmail.com                                  ║
║ Copyright (c) 2025 Huggingface                               ║
║ License: MIT                                                 ║
║ This source file is subject to the terms and conditions      ║
║ defined in the file 'LICENSE', which is part of this sorcery ║
║ code package.                                                ║
╚══════════════════════════════════════════════════════════════╝
"""

# Donut banner drawn with Unicode Braille characters; display_donut() wraps it
# in a rich Text styled "bold magenta" and prints it at the start of main().
DONUT = """
    ⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⣀⣀⣀⣀⣀⠀⠀⠀⠀⠀⠀
    ⠀⠀⠀⢀⡤⠖⠛⠛⠉⢉⠉⠉⠙⠛⠛⠻⠷⣦⡀⠀⠀
    ⠀⠀⣰⠉⠀⠀⠖⡄⠀⠘⠀⢠⡰⠀⠀⠀⣨⡝⢦⠀
    ⠀⢀⣯⠯⠃⠀⠀⠀⢐⣤⣤⣤⣄⡀⠈⠀⠀⠀⠈⡇
    ⠀⡜⠁⠀⠀⡀⠀⡴⠛⠉⠙⠺⡻⣦⠀⠀⠲⠆⠀⢹
    ⢀⡷⡦⠀⠘⠃⢸⠁⠀⠀⠀⠀⣻⣿⠀⠀⠀⠀⣿
    ⠘⠿⠁⠀⠀⠀⠈⢆⠀⠀⠀⢀⣴⣿⠏⢶⠀⠖⢠⡏
    ⠀⢗⢒⣤⣄⠀⠀⠀⠉⠙⠛⠛⢩⡄⠀⠀⠀⠀⡾⠁
    ⠀⠈⢧⣸⠱⠆⠀⠀⠁⠀⠀⠀⠀⢀⡠⢤⠀⣠⠞⠁⠀
    ⠀⠀⠀⠉⠻⢬⡩⡉⢭⡑⡂⢐⣪⣵⣲⡾⠞⠋⠀⠀⠀
    ⠀⠀⠀⠀⠀⠀⠈⠉⠉⠁⠉⠉⠉⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀
"""

# PyTorch wheel index for the cu124 builds; shared by the "url" and "pkg_cuda"
# entries below so the address is written only once.
_TORCH_CUDA_INDEX = "https://download.pytorch.org/whl/cu124"

# Central settings for the setup script.
CONFIG = {
    # Wheel index address (the same URL is embedded in the "pkg_cuda" entry).
    "url": _TORCH_CUDA_INDEX,
    # Shell commands executed by upgrade_tools() to refresh the packaging tools.
    "cmds": [
        f"python -m pip install --upgrade {tool}"
        for tool in ("pip", "setuptools", "wheel")
    ],
    # Pinned packages that main() installs on every run.
    "pkg_common": ["tqdm==4.62.3", "pillow==8.4.0"],
    # torch requirement selected when torch.cuda.is_available() is true.
    "pkg_cuda": [f"torch==2.5.1+cu124 --index-url {_TORCH_CUDA_INDEX}"],
    # torch requirement selected in every other case.
    "pkg_cpu": ["torch==2.5.1"],
    # Not read by the functions visible in this file, which hard-code
    # "matplotlib" and "numpy==2.1.3" instead.
    "pkg_indiv": ["matplotlib", "numpy==2.1.3"],
}

import subprocess as sp
import sys
import importlib.metadata as imd

def ensure_rich():
    """
    Install the rich library with pip when no installed copy is found.

    A failing pip run is not handled here, so the CalledProcessError
    propagates to the caller.
    """
    try:
        imd.version("rich")
    except imd.PackageNotFoundError:
        # No metadata for rich: install it with the interpreter running this script.
        pip_install = [sys.executable, "-m", "pip", "install", "rich"]
        sp.check_call(pip_install)

# rich is a third-party dependency, so it has to be installed before the two
# imports below can succeed; that is why they sit here, after the call, rather
# than with the other imports at the top of the file.
ensure_rich()

from rich.console import Console
from rich.text import Text

# Shared rich console through which this script prints its banner and messages.
console = Console()

def run_cmd(cmd):
    """
    Run a command line through the system shell.

    A non-zero exit status is reported on the console in red instead of
    being raised, so the caller carries on with its next step.
    """
    try:
        sp.check_call(cmd, shell=True)
    except sp.CalledProcessError as failure:
        report = f"[red]Error running command: {cmd}\n{failure}[/red]"
        console.print(report)

def install_pkg(pkg):
    """
    Install a package using pip.

    Args:
        pkg: A pip requirement string. It may be followed by pip options
            separated from the requirement by a space, e.g.
            "torch==2.5.1+cu124 --index-url <url>".

    A failing pip run is reported on the console in red rather than raised,
    so the remaining setup steps still execute.
    """
    import shlex

    # pip treats each argv entry as one requirement, so a spec that carries
    # options ("name==1.0 --index-url <url>") must be split before the call;
    # passed as a single argument it is rejected as an invalid requirement.
    # A spec without options is forwarded untouched, exactly as before.
    option_start = pkg.find(" -")
    if option_start == -1:
        pip_args = [pkg]
    else:
        requirement = pkg[:option_start].rstrip()
        pip_args = [requirement, *shlex.split(pkg[option_start:])]

    try:
        sp.check_call([sys.executable, "-m", "pip", "install", *pip_args])
    except sp.CalledProcessError as err:
        console.print(f"[red]Error installing package: {pkg}\n{err}[/red]")

def is_installed(pkg):
    """
    Report whether installed-package metadata exists for the given name.

    Returns True when importlib.metadata can resolve a version for pkg and
    False when it raises PackageNotFoundError.
    """
    try:
        imd.version(pkg)
    except imd.PackageNotFoundError:
        found = False
    else:
        found = True
    return found

def display_donut():
    """
    Print the donut banner on the console in bold magenta.
    """
    console.print(Text(DONUT, style="bold magenta"))

def upgrade_tools():
    """
    Upgrade pip, setuptools, and wheel.

    Every command in CONFIG["cmds"] is executed through run_cmd. A command
    whose first word is the bare name "python" is re-pointed at the
    interpreter running this script, so the upgrade targets the same
    environment that install_pkg installs into (it uses sys.executable too)
    and still works where no "python" executable is on PATH.
    """
    import shlex

    # run_cmd goes through the shell, so the interpreter path must be quoted
    # in case it contains spaces. sys.executable may be empty or None when
    # Python cannot determine it; the commands are then run unchanged.
    interpreter = sys.executable
    if not interpreter:
        quoted = None
    elif sys.platform == "win32":
        quoted = f'"{interpreter}"'
    else:
        quoted = shlex.quote(interpreter)

    for cmd in CONFIG["cmds"]:
        head, sep, tail = cmd.partition(" ")
        if quoted and head == "python":
            cmd = f"{quoted}{sep}{tail}"
        run_cmd(cmd)

def install_packages(packages):
    """
    Install a list of packages if not already installed.

    Args:
        packages: Iterable of pip requirement strings, e.g. "tqdm==4.62.3".

    Only the presence of the distribution is checked: a package that is
    already installed is skipped even if its version differs from the pin.
    """
    import re

    # Distribution name at the start of a requirement. It must be followed by
    # the end of the string, whitespace, or a character that begins an extras
    # list, version specifier, marker or direct reference, so that specs such
    # as "pkg>=1.0", "pkg[extra]==1.0" or "pkg --index-url <url>" resolve to
    # "pkg" instead of a name that can never be found.
    name_pattern = re.compile(
        r"\s*([A-Za-z0-9](?:[A-Za-z0-9._-]*[A-Za-z0-9])?)(?=$|[\s\[<>=!~;@])"
    )

    for pkg in packages:
        match = name_pattern.match(pkg)
        # Fall back to the historical "==" split for anything the pattern
        # does not recognise (URLs, local paths, ...).
        name = match.group(1) if match else pkg.split("==")[0]
        if not is_installed(name):
            install_pkg(pkg)

def ensure_numpy():
    """
    Make sure a numpy whose version string starts with "2.1.3" is present.

    When numpy is missing, or its installed version does not start with
    "2.1.3", numpy==2.1.3 is installed through install_pkg.
    """
    try:
        version_ok = imd.version("numpy").startswith("2.1.3")
    except imd.PackageNotFoundError:
        # numpy is not installed at all.
        version_ok = False

    if not version_ok:
        install_pkg("numpy==2.1.3")

def main():
    """
    Run the whole setup: banner, tool upgrades, then package installation.
    """
    display_donut()
    upgrade_tools()
    ensure_numpy()
    install_pkg("matplotlib")
    install_packages(CONFIG["pkg_common"])

    # Pick the torch package list: the CUDA list only when torch can be
    # imported and torch.cuda.is_available() is true, the CPU list otherwise
    # (including when torch is not importable at all).
    try:
        import torch
        cuda_ready = torch.cuda.is_available()
    except ImportError:
        cuda_ready = False

    # NOTE(review): install_packages skips anything is_installed already
    # finds, so when the import above succeeded this call normally installs
    # nothing; a fresh environment therefore always gets the CPU list.
    install_packages(CONFIG["pkg_cuda"] if cuda_ready else CONFIG["pkg_cpu"])

    console.print("[bold green]All Your Donut Belong To Us![/bold green]\n")

# Run the setup only when this file is executed as a script; importing the
# module defines the functions without calling main().
if __name__ == "__main__":
    main()