import datetime
import os
import shutil
import subprocess
import tempfile
import uuid

import logging
from typing import List

# Option names bound, in this order, to the positional ``*args`` that
# ``run_cli_command`` receives after ``other_args_dict``.
ARG_ORDER = [
    "samples_per_complex",
    "keep_local_structures",
    "save_visualisation",
    "pocket_center_x",
    "pocket_center_y",
    "pocket_center_z",
    "flexible_sidechains",
]


def set_env_variables():
    """Fill in default values for the environment variables this module uses.

    * ``DiffDock-Pocket-Dir``: if unset, defaults to the absolute path of
      ``../DiffDock-Pocket`` (relative to the current working directory).
      Raises ``ValueError`` if that directory does not exist.
    * ``LOG_LEVEL``: if unset, defaults to ``"INFO"``.

    Variables that are already set are never overwritten.
    """
    env = os.environ
    if "DiffDock-Pocket-Dir" not in env:
        candidate = os.path.abspath("../DiffDock-Pocket")
        if not os.path.exists(candidate):
            raise ValueError(f"DiffDock-Pocket-Dir {candidate} not found")
        env["DiffDock-Pocket-Dir"] = candidate

    env.setdefault("LOG_LEVEL", "INFO")

    if "LOG_LEVEL" not in os.environ:
        os.environ["LOG_LEVEL"] = "INFO"


def configure_logging(level=None):
    """Configure the root logger with a timestamped stderr handler.

    Args:
        level: A ``logging`` level (e.g. ``logging.DEBUG``). When ``None``,
            the level name is read from the ``LOG_LEVEL`` environment
            variable (case-insensitive, default ``"INFO"``). An unrecognised
            name falls back to ``logging.INFO`` and a warning is logged,
            instead of raising.

    Note:
        This configures the root logger, so it affects other libraries too.
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers.
    """
    bad_level_name = None
    if level is None:
        level_name = os.environ.get("LOG_LEVEL", "INFO").strip().upper()
        resolved = getattr(logging, level_name, None)
        # getattr on the logging module can also return functions/classes
        # (e.g. lowercase "debug" -> logging.debug), so accept only real
        # integer levels.
        if isinstance(resolved, int) and not isinstance(resolved, bool):
            level = resolved
        else:
            level = logging.INFO
            bad_level_name = level_name

    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[
            logging.StreamHandler(),  # Outputs logs to stderr by default
            # To also log to a file, add e.g.:
            # logging.FileHandler('my_app.log', mode='a', encoding='utf-8')
        ],
    )

    if bad_level_name is not None:
        logging.warning("Unrecognised LOG_LEVEL %r; falling back to INFO", bad_level_name)


def kwargs_to_cli_args(**kwargs) -> List[str]:
    """
    Translate keyword arguments into a list of ``--name[=value]`` CLI tokens.

    A ``True`` boolean becomes a bare ``--name`` flag and a ``False`` boolean
    is dropped. Any other value is emitted as ``--name=value``, unless it is
    ``None`` or its string form is empty, in which case it is dropped.
    """
    def _token(name, val):
        if isinstance(val, bool):
            return f"--{name}" if val else None
        if val is None or str(val) == "":
            return None
        return f"--{name}={val}"

    candidates = (_token(name, val) for name, val in kwargs.items())
    return [tok for tok in candidates if tok is not None]


def read_file_lines(fi_path: str):
    """Return the full text content of the file at ``fi_path`` as one string."""
    with open(fi_path, "r") as handle:
        content = handle.read()
    return content


def _log_process_result(result) -> None:
    """Log a finished process's stdout, exit status and (filtered) stderr."""
    logging.debug(f"Command output:\n{result.stdout}")
    failed = result.returncode != 0
    if failed:
        logging.error(f"Command exited with return code {result.returncode}")
    if result.stderr:
        # Skip progress-bar lines, identified by the "%|" marker.
        stderr_text = "\n".join(
            line for line in result.stderr.split("\n") if "%|" not in line
        )
        # stderr also carries warnings/progress output, so only report it as
        # an error when the process actually failed.
        log = logging.error if failed else logging.warning
        log(f"Command error:\n{stderr_text}")


def _find_rank1_outputs(out_dir: str):
    """Return ``(pdb_path, sdf_path)`` of the rank-1 results inside ``out_dir``.

    Both are ``None`` unless ``out_dir`` contains exactly one sub-directory;
    the returned paths are not checked for existence.
    """
    sub_dirs = [
        path
        for path in (os.path.join(out_dir, name) for name in os.listdir(out_dir))
        if os.path.isdir(path)
    ]
    if len(sub_dirs) != 1:
        return None, None
    sub_dir = sub_dirs[0]
    return (
        os.path.join(sub_dir, "rank1_reverseprocess_protein.pdb"),
        os.path.join(sub_dir, "rank1.sdf"),
    )


def run_cli_command(
    protein_path: str,
    ligand: str,
    other_args_dict: dict,
    *args,
    work_dir=None,
):
    """Run DiffDock-Pocket's ``inference.py`` and collect its outputs.

    Args:
        protein_path: Passed to the script as ``--protein_path``.
        ligand: Passed to the script as ``--ligand``.
        other_args_dict: Extra ``name -> value`` options, converted with
            ``kwargs_to_cli_args``. The caller's dict is not modified;
            ``None`` is treated as an empty dict.
        *args: Exactly ``len(ARG_ORDER)`` values, bound in order to the
            option names listed in ``ARG_ORDER``.
        work_dir: Directory the script runs in. Defaults to the
            ``DiffDock-Pocket-Dir`` environment variable, else
            ``~/DiffDock-Pocket``.

    Returns:
        ``(zip_path, pdb_text, sdf_text)``: the path of a zip archive of the
        whole output directory (written under ``tmp/`` relative to the current
        working directory), and the contents of
        ``rank1_reverseprocess_protein.pdb`` and ``rank1.sdf`` from the run's
        single output sub-directory (``None`` for each one not found).

    Raises:
        AssertionError: if the number of positional ``args`` is wrong.
    """
    if work_dir is None:
        # Resolve the fallback lazily: os.environ["HOME"] as a .get() default
        # was evaluated eagerly and raised KeyError when HOME was unset.
        work_dir = os.environ.get("DiffDock-Pocket-Dir")
        if work_dir is None:
            work_dir = os.path.join(os.path.expanduser("~"), "DiffDock-Pocket")

    assert len(args) == len(ARG_ORDER), f'Expected {len(ARG_ORDER)} arguments, got {len(args)}'

    # Copy so the caller's dict is not mutated by the additions below.
    all_arg_dict = dict(other_args_dict or {})
    all_arg_dict["protein_path"] = protein_path
    all_arg_dict["ligand"] = ligand
    all_arg_dict.update(zip(ARG_ORDER, args))

    command = ["python3", "inference.py", *kwargs_to_cli_args(**all_arg_dict)]

    with tempfile.TemporaryDirectory() as temp_dir:
        command.append(f"--out_dir={temp_dir}")

        command_str = " ".join(command)
        logging.info(f"Executing command: {command_str}")

        # check=False: failures are reported via the return code rather than
        # an exception, so the output directory is still inspected and zipped.
        result = subprocess.run(
            command,
            cwd=work_dir,
            check=False,
            text=True,
            capture_output=True,
            env=os.environ,
        )
        _log_process_result(result)

        pdb_path, sdf_path = _find_rank1_outputs(temp_dir)
        logging.debug(f"PDB path: {pdb_path}")
        logging.debug(f"SDF path: {sdf_path}")
        pdb_text = read_file_lines(pdb_path) if pdb_path and os.path.exists(pdb_path) else None
        sdf_text = read_file_lines(sdf_path) if sdf_path and os.path.exists(sdf_path) else None

        # Archive the output directory before TemporaryDirectory deletes it.
        # A timestamp plus a short UUID keeps concurrent runs from colliding.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        uuid_tag = str(uuid.uuid4())[0:8]
        unique_filename = f"diffdock_pocket_output_{timestamp}_{uuid_tag}"
        zip_base_name = os.path.join("tmp", unique_filename)

        logging.debug(f"About to zip directory '{temp_dir}' to {unique_filename}")
        full_zip_path = shutil.make_archive(zip_base_name, "zip", temp_dir)
        logging.debug(f"Directory '{temp_dir}' zipped to '{full_zip_path}'")

    return full_zip_path, pdb_text, sdf_text


def main_test():
    """Manually run one example docking job against a local DiffDock-Pocket checkout.

    Expects the checkout at ``../DiffDock-Pocket`` relative to the current
    working directory, including its ``example_data`` files.
    """
    set_env_variables()
    configure_logging()

    work_dir = os.path.abspath("../DiffDock-Pocket")
    os.environ["DiffDock-Pocket-Dir"] = work_dir
    example_dir = os.path.join(work_dir, "example_data")
    protein_path = os.path.join(example_dir, "3dpf_protein.pdb")
    ligand = os.path.join(example_dir, "3dpf_ligand.sdf")
    other_arg_file = os.path.join(example_dir, "example_args.yml")

    # run_cli_command takes these options positionally, in ARG_ORDER order
    # (it does not accept them as keyword arguments). Options left out here
    # become None, which kwargs_to_cli_args omits from the command line.
    options = {
        "samples_per_complex": 1,
        "keep_local_structures": True,
        "save_visualisation": True,
    }
    positional = [options.get(name) for name in ARG_ORDER]

    zip_path, pdb_text, sdf_text = run_cli_command(
        protein_path,
        ligand,
        # NOTE(review): forwarded as --other_arg_file=...; confirm that
        # inference.py accepts this option name.
        {"other_arg_file": other_arg_file},
        *positional,
    )
    logging.info(f"Output archive: {zip_path}")
    logging.info(f"Rank-1 protein found: {pdb_text is not None}; rank-1 ligand found: {sdf_text is not None}")


if __name__ == "__main__":
    main_test()