from __future__ import annotations

import os
import tempfile
import time
from pathlib import Path

import gradio as gr
import pandas as pd
from huggingface_hub import HfApi

# DockQ-style protein-protein evaluation (PINDER track); import path per the pinder eval package.
from pinder.eval.dockq.biotite_dockq import BiotiteDockQ

# Protein-ligand pose evaluation (PLINDER track).
from plinder.eval.docking.write_scores import evaluate


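# Column sets shown in the results tables for the two evaluation tracks.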
EVAL_METRICS = ["system", "LDDT-PLI", "LDDT-LP", "BISY-RMSD"]

EVAL_METRICS_PINDER = ["system", "L_rms", "I_rms", "F_nat", "DOCKQ", "CAPRI_class"]


# Info to change for your repository
# ----------------------------------
TOKEN = os.environ.get("HF_TOKEN") # A read/write token for your org

OWNER = "MLSB" # Change to your org - don't forget to create a results and request dataset, with the correct format!
# ----------------------------------

REPO_ID = f"{OWNER}/leaderboard2024"
QUEUE_REPO = f"{OWNER}/requests"
RESULTS_REPO = f"{OWNER}/results"

# If you set up a cache later, just change HF_HOME
CACHE_PATH = os.getenv("HF_HOME", ".")

# Local caches
EVAL_REQUESTS_PATH = os.path.join(CACHE_PATH, "eval-queue")
EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval-results")
EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk")
EVAL_RESULTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-results-bk")

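# Single Hub client reused for all uploads; requests are anonymous if HF_TOKEN is unset.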
API = HfApi(token=TOKEN)



def get_metrics(
    system_id: str,
    receptor_file: Path,
    ligand_file: Path,
    flexible: bool = True,
    posebusters: bool = True,
    methodname: str = "",
    store: bool = True,
) -> tuple[pd.DataFrame, float]:
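    """Evaluate a predicted protein-ligand pose against a PLINDER reference system.

    Returns the metrics table wrapped in a visible gr.DataFrame plus the wall-clock
    runtime in seconds. When `store` is True, the metrics and the submitted files
    are uploaded to the leaderboard queue dataset.
    """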
    start_time = time.time()
    # gr.File may pass a path string or a tempfile-like object depending on the
    # Gradio version; normalise both inputs to Path.
    if not isinstance(receptor_file, (str, Path)):
        receptor_file = receptor_file.name
    if not isinstance(ligand_file, (str, Path)):
        ligand_file = ligand_file.name
    receptor_file, ligand_file = Path(receptor_file), Path(ligand_file)
    metrics = pd.DataFrame(
        [
            evaluate(
                model_system_id=system_id,
                reference_system_id=system_id,
                receptor_file=receptor_file,
                ligand_file_list=[ligand_file],
                flexible=flexible,
                posebusters=posebusters,
                posebusters_full=False,
            ).get("LIG_0", {})
        ]
    )
    if posebusters:
        pb_cols = [col for col in metrics.columns if col.startswith("posebusters_")]
        metrics["posebusters"] = metrics[pb_cols].sum(axis=1)
        # A pose is valid only if all 20 PoseBusters checks pass.
        metrics["posebusters_valid"] = metrics["posebusters"] == 20
    columns = ["reference", "lddt_pli_ave", "lddt_lp_ave", "bisy_rmsd_ave"]
    if flexible:
        columns.extend(["lddt", "bb_lddt"])
    if posebusters:
        columns.extend([col for col in metrics.columns if col.startswith("posebusters")])

    metrics = metrics[columns].copy()
    mapping = {
        "lddt_pli_ave": "LDDT-PLI",
        "lddt_lp_ave": "LDDT-LP",
        "bisy_rmsd_ave": "BISY-RMSD",
        "reference": "system",
    }
    if flexible:
        mapping["lddt"] = "LDDT"
        mapping["bb_lddt"] = "Backbone LDDT"
    if posebusters:
        mapping["posebusters"] = "PoseBusters #checks"
        mapping["posebusters_valid"] = "PoseBusters valid"
    metrics.rename(
        columns=mapping,
        inplace=True,
    )
    if store:
        # Upload metrics plus the submitted receptor and ligand files to the queue.
        dataset = "plinder"  # assumed subfolder layout in the queue repo
        destination = f"{dataset}/{methodname}/{system_id}"
        with tempfile.NamedTemporaryFile(suffix=".csv") as temp:
            metrics.to_csv(temp.name)
            API.upload_file(
                path_or_fileobj=temp.name,
                path_in_repo=f"{destination}/metrics.csv",
                repo_id=QUEUE_REPO,
                repo_type="dataset",
                commit_message=f"Add {methodname} to eval queue",
            )
        for input_file in (receptor_file, ligand_file):
            API.upload_file(
                path_or_fileobj=input_file,
                path_in_repo=f"{destination}/{input_file.name}",
                repo_id=QUEUE_REPO,
                repo_type="dataset",
                commit_message=f"Add {methodname} to eval queue",
            )
    end_time = time.time()
    run_time = end_time - start_time
    return gr.DataFrame(metrics, visible=True), run_time



def get_metrics_pinder(
    system_id: str,
    complex_file: Path,
    methodname: str = "",
    store: bool = True,
) -> tuple[pd.DataFrame, float]:
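    """Evaluate a predicted protein-protein complex against a PINDER reference system.

    Computes DockQ-style metrics (L_rms, I_rms, F_nat, DockQ, CAPRI class) with
    BiotiteDockQ; on failure, worst-case scores are reported instead. When `store`
    is True, the metrics and the submitted complex are uploaded to the queue.
    """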
    start_time = time.time()
    if not isinstance(complex_file, (str, Path)):
        complex_file = complex_file.name
    complex_file = Path(complex_file)
    # The ground truth is looked up by the user-supplied PINDER system ID; the
    # reference structures are expected to be available locally under ./ground_truth/.
    native = Path(f"./ground_truth/{system_id}.pdb")
    try:
        # Compute DockQ-style metrics for the predicted complex.
        bdq = BiotiteDockQ(native, complex_file, parallel_io=False)
        metrics = bdq.calculate()
        metrics = metrics[["system", "LRMS", "iRMS", "Fnat", "DockQ", "CAPRI"]].copy()
        metrics.rename(
            columns={"LRMS": "L_rms", "iRMS": "I_rms", "Fnat": "F_nat", "DockQ": "DOCKQ", "CAPRI": "CAPRI_class"},
            inplace=True,
        )
    except Exception as e:
        # Report worst-case scores so a failed prediction still produces a row;
        # gr.Warning surfaces the error in the UI without aborting the request.
        failed_metrics = {"L_rms": 100.0, "I_rms": 100.0, "F_nat": 0.0, "DOCKQ": 0.0, "CAPRI_class": "Incorrect"}
        metrics = pd.DataFrame([failed_metrics])
        metrics["system"] = native.stem
        gr.Warning(f"Failed to evaluate prediction [{complex_file}]:\n{e}")
    if store:
        # Upload metrics and the submitted complex to the leaderboard queue.
        dataset = "pinder"  # assumed subfolder layout in the queue repo
        destination = f"{dataset}/{methodname}/{system_id}"
        with tempfile.NamedTemporaryFile(suffix=".csv") as temp:
            metrics.to_csv(temp.name)
            API.upload_file(
                path_or_fileobj=temp.name,
                path_in_repo=f"{destination}/metrics.csv",
                repo_id=QUEUE_REPO,
                repo_type="dataset",
                commit_message=f"Add {methodname} to eval queue",
            )
        API.upload_file(
            path_or_fileobj=complex_file,
            path_in_repo=f"{destination}/{complex_file.name}",
            repo_id=QUEUE_REPO,
            repo_type="dataset",
            commit_message=f"Add {methodname} to eval queue",
        )
    end_time = time.time()
    run_time = end_time - start_time
    return gr.DataFrame(metrics, visible=True), run_time

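# Gradio UI: one tab per evaluation track, each wired to its evaluation function below.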
with gr.Blocks() as app:
    with gr.Tab("🧬 PINDER evaluation template"):
        with gr.Row():
            with gr.Column():
                input_system_id_pinder = gr.Textbox(label="PINDER system ID")
                input_complex_pinder = gr.File(label="Complex file (PDB/CIF)")
                methodname_pinder = gr.Textbox(label="Name of your method in the format mlsb/spacename")
                store_pinder = gr.Checkbox(label="Store on huggingface for leaderboard", value=False)
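        # Example inputs assume the test files are bundled with the Space repository.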
        gr.Examples(
            [
                [
                    "4neh__1__1.B__1.H",
                    "input_protein_test.cif",
                    "mlsb/test",
                    False
                ],
            ],
            [input_system_id_pinder, input_complex_pinder, methodname_pinder, store_pinder],
        )
        eval_btn_pinder = gr.Button("Run Evaluation")

        eval_run_time_pinder = gr.Textbox(label="Evaluation runtime")
        metric_table_pinder = gr.DataFrame(
            pd.DataFrame([], columns=EVAL_METRICS_PINDER), label="Evaluation metrics", visible=False
        )
    with gr.Tab("⚖️ PLINDER evaluation template"):
        with gr.Row():
            with gr.Column():
                input_system_id = gr.Textbox(label="PLINDER system ID")
                input_receptor_file = gr.File(label="Receptor file (CIF)")
                input_ligand_file = gr.File(label="Ligand file (SDF)")
                flexible = gr.Checkbox(label="Flexible docking", value=True)
                posebusters = gr.Checkbox(label="PoseBusters", value=True)
                methodname = gr.Textbox(label="Name of your method in the format mlsb/spacename")
                store = gr.Checkbox(label="Store on huggingface for leaderboard", value=False)
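        # Example inputs assume the test files are bundled with the Space repository.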
        gr.Examples(
            [
                [
                    "4neh__1__1.B__1.H",
                    "input_protein_test.cif",
                    "input_ligand_test.sdf",
                    True,
                    True,
                    "mlsb/test",
                    False
                ],
            ],
            [input_system_id, input_receptor_file, input_ligand_file, flexible, posebusters, methodname, store],
        )
        eval_btn = gr.Button("Run Evaluation")
       
        eval_run_time = gr.Textbox(label="Evaluation runtime")
        metric_table = gr.DataFrame(
            pd.DataFrame([], columns=EVAL_METRICS), label="Evaluation metrics", visible=False
        )


        eval_btn.click(
            get_metrics,
            inputs=[input_system_id, input_receptor_file, input_ligand_file, flexible, posebusters, methodname, store],
            outputs=[metric_table, eval_run_time],
        )
        eval_btn_pinder.click(
            get_metrics_pinder,
            inputs=[input_system_id_pinder, input_complex_pinder, methodname_pinder, store_pinder],
            outputs=[metric_table_pinder, eval_run_time_pinder],
        )

app.launch()