import copy
import glob
import json
import os
#os.environ['CURL_CA_BUNDLE'] = ''
#os.environ['REQUESTS_CA_BUNDLE'] = ''
os.environ["CURL_CA_BUNDLE"] = "/etc/ssl/certs/ca-certificates.crt"
print(f"{os.environ.get('CURL_CA_BUNDLE') = }")
print(f"{os.environ.get('REQUESTS_CA_BUNDLE') = }")
import hashlib
import time
import requests
from collections import namedtuple
from xml.sax.saxutils import escape as xmlEscape, quoteattr as xmlQuoteAttr

import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, snapshot_download

from compare_significance import SUPPORTED_METRICS

VISIBLE_METRICS = SUPPORTED_METRICS + ["macro_f1"]

api = HfApi()

ORG = "xdolez52"
REPO = f"{ORG}/LLM_benchmark_data"
HF_TOKEN = os.environ.get("HF_TOKEN")
TASKS_METADATA_PATH = "./tasks_metadata.json"

MARKDOWN_SPECIAL_CHARACTERS = {
    "#": "&#35;",  # for usage as entities in xml.sax.saxutils.escape; escape() replaces "&" first, so these values are not mangled
    "\\": "&#92;",
    "`": "&#96;",
    "*": "&#42;",
    "_": "&#95;",
    "{": "&#123;",
    "}": "&#125;",
    "[": "&#91;",
    "]": "&#93;",
    "(": "&#40;",
    ")": "&#41;",
    "+": "&#43;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "=": "&#61;",
    "|": "&#124;"
}

def check_significance_send_task(model_a_path, model_b_path):
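    """Upload two submission files to the significance-testing service.

    On 202 Accepted, returns the response URL, which can then be polled
    for the test result; raises RuntimeError on 429 or any other status.
    """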
    url = 'https://czechllm.fit.vutbr.cz:4443/benczechmark-leaderboard/compare_significance/'

    # prepare and send request
    with (
        open(model_a_path, 'rb') as model_a_fp,
        open(model_b_path, 'rb') as model_b_fp,
    ):
        files = {
            'model_a': model_a_fp,
            'model_b': model_b_fp,
        }
        response = requests.post(url, files=files)

    # check response
    if response.status_code == 202:
        result_url = response.url
        #task_id = response.json()['task_id']
    elif response.status_code == 429:
        raise RuntimeError('Server is too busy. Please try again later.')  # TODO: catch and re-raise as gr.Error
    else:
        raise RuntimeError(f'Failed to submit task. Status code: {response.status_code}')  # TODO: catch and re-raise as gr.Error

    return result_url

def check_significance_wait_for_result(result_url):
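    """Poll the task URL every 5 seconds until the result is ready (HTTP 200)."""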
    while True:
        response = requests.get(result_url)
        if response.status_code == 200:
            result = response.json()
            break
        elif response.status_code == 202:
            time.sleep(5)
        else:
            raise RuntimeError(f'Failed to get result. Status code: {response.status_code}')  # TODO: catch and re-raise as gr.Error
    
    print(result)
    return result['result']

def check_significance(model_a_path, model_b_path):
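    """Run a full significance comparison of two submissions and return its result."""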
    result_url = check_significance_send_task(model_a_path, model_b_path)
    result = check_significance_wait_for_result(result_url)
    return result

class LeaderboardServer:
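    """Serves leaderboard data stored in a Hugging Face dataset repository.

    Keeps a local snapshot of the repository, tracks known submissions and
    tournament (pairwise significance) results, and renders the leaderboard
    as a pandas DataFrame.
    """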
    def __init__(self):
        self.server_address = REPO
        self.repo_type = "dataset"
        self.local_leaderboard = snapshot_download(
            self.server_address,
            repo_type=self.repo_type,
            token=HF_TOKEN,
            local_dir="./",
        )
        self.submission_id_to_file = {}  # Map submission ids to file paths
        self.tasks_metadata = json.load(open(TASKS_METADATA_PATH))
        self.tasks_categories = {self.tasks_metadata[task]["category"] for task in self.tasks_metadata}
        self.tasks_category_overall = "Overall"
        self.submission_ids = set()
        self.fetch_existing_models()
        self.tournament_results = self.load_tournament_results()
        self.pre_submit = None

    def update_leaderboard(self):
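        """Re-download the repository snapshot and refresh the cached state."""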
        self.local_leaderboard = snapshot_download(
            self.server_address,
            repo_type=self.repo_type,
            token=HF_TOKEN,
            local_dir="./",
        )
        self.fetch_existing_models()
        self.tournament_results = self.load_tournament_results()

    def load_tournament_results(self):
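        """Load pairwise tournament results from tournament.json, or return {} if absent."""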
        metadata_rank_paths = os.path.join(self.local_leaderboard, "tournament.json")
        if not os.path.exists(metadata_rank_paths):
            return {}
        with open(metadata_rank_paths) as ranks_file:
            results = json.load(ranks_file)
        return results

    def fetch_existing_models(self):
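        """Index all submission files in the local snapshot by their submission id."""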
        # Models data
        for submission_file in glob.glob(os.path.join(self.local_leaderboard, "data") + "/*.json"):
            data = json.load(open(submission_file))
            metadata = data.get('metadata')
            if metadata is None:
                continue
            submission_id = metadata["submission_id"]
            self.submission_ids.add(submission_id)

            self.submission_id_to_file[submission_id] = submission_file

    def get_leaderboard(self, tournament_results=None, category=None):
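        """Build the leaderboard DataFrame for a task category (default: Overall).

        A submission's task score is the percentage of competitors it beats
        with statistical significance; category scores average the task
        scores, and the overall view averages across categories.
        """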
        tournament_results = tournament_results if tournament_results else self.tournament_results
        category = category if category else self.tasks_category_overall

        if len(tournament_results) == 0:
            return pd.DataFrame(columns=['No submissions yet'])
        else:
            processed_results = []
            for submission_id in tournament_results.keys():
                path = self.submission_id_to_file.get(submission_id)
                if path:
                    data = json.load(open(path))
                elif self.pre_submit and submission_id == self.pre_submit.submission_id:
                    # Not uploaded yet; read the pre-submitted file instead
                    data = json.load(open(self.pre_submit.file))
                else:
                    raise gr.Error(f"Internal error: Submission [{submission_id}] not found")
                
                if submission_id != data["metadata"]["submission_id"]:
                    raise gr.Error(f"Proper submission [{submission_id}] not found")

                local_results = {}
                win_score = {}
                visible_metrics_map_word_to_header = {}
                for task in self.tasks_metadata.keys():
                    task_category = self.tasks_metadata[task]["category"]
                    if category not in (self.tasks_category_overall, task_category):
                        continue

                    # Task score: percentage of competitors beaten in the tournament
                    num_of_competitors = 0
                    num_of_wins = 0
                    for competitor_id in tournament_results[submission_id].keys() - {submission_id}:  # without self
                        num_of_competitors += 1
                        if tournament_results[submission_id][competitor_id][task]:
                            num_of_wins += 1
                    # Guard the first submission (no competitors yet): count it as a 100% win rate
                    task_score = num_of_wins / num_of_competitors * 100 if num_of_competitors else 100.0
                    win_score.setdefault(task_category, []).append(task_score)

                    if category == task_category:
                        local_results[task] = task_score
                        for metric in VISIBLE_METRICS:
                            visible_metrics_map_word_to_header[task + "_" + metric] = self.tasks_metadata[task]["abbreviation"] + " " + metric
                            metric_value = data['results'][task].get(metric)
                            if metric_value is not None:
                                local_results[task + "_" + metric] = metric_value * 100
                                break  # Only the first metric of every task

                for c in win_score:
                    win_score[c] = sum(win_score[c]) / len(win_score[c])
                
                if category == self.tasks_category_overall:
                    for c in win_score:
                        local_results[c] = win_score[c]
                    local_results["average_score"] = sum(win_score.values()) / len(win_score)
                else:
                    local_results["average_score"] = win_score[category]
                
                model_link = data["metadata"]["link_to_model"]
                model_title = data["metadata"]["team_name"] + "/" + data["metadata"]["model_name"]
                model_title_abbr = self.abbreviate(data["metadata"]["team_name"], 14) + "/" + self.abbreviate(data["metadata"]["model_name"], 14)
                local_results["model"] = f'<a href={xmlQuoteAttr(model_link)} title={xmlQuoteAttr(model_title)}>{xmlEscape(model_title_abbr, MARKDOWN_SPECIAL_CHARACTERS)}</a>'
                release = data["metadata"].get("submission_timestamp")
                release = time.strftime("%Y-%m-%d", time.gmtime(release)) if release else "N/A"
                local_results["release"] = release
                local_results["model_type"] = data["metadata"]["model_type"]
                local_results["parameters"] = data["metadata"]["parameters"]
                
                if self.pre_submit and submission_id == self.pre_submit.submission_id:
                    processed_results.insert(0, local_results)
                else:
                    processed_results.append(local_results)
            dataframe = pd.DataFrame.from_records(processed_results)
            
            extra_attributes_map_word_to_header = {
                "model": "Model",
                "release": "Release",
                "average_score": "Average ⬆️",
                "team_name": "Team name",
                "model_name": "Model name",
                "model_type": "Type",
                "parameters": "# θ (B)",
                "input_length": "Input length (# tokens)",
                "precision": "Precision",
                "description": "Description",
                "link_to_model": "Link to model"
            }
            first_attributes = [
                "model",
                "release",
                "model_type",
                "parameters",
                "average_score",
            ]
            df_order = [
                key
                for key in dict.fromkeys(
                    first_attributes
                    + list(self.tasks_metadata.keys())
                    + list(dataframe.columns)
                ).keys()
                if key in dataframe.columns
            ]
            dataframe = dataframe[df_order]
            attributes_map_word_to_header = {key: value["abbreviation"] for key, value in self.tasks_metadata.items()}
            attributes_map_word_to_header.update(extra_attributes_map_word_to_header)
            attributes_map_word_to_header.update(visible_metrics_map_word_to_header)
            dataframe = dataframe.rename(
                columns=attributes_map_word_to_header
            )
            return dataframe

    def start_tournament(self, new_submission_id, new_model_file):
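        """Run significance tests of a new submission against every existing one.

        Each pair is tested in both directions; the returned copy of the
        tournament results records, per task, whether one model significantly
        beats the other.
        """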
        new_tournament = copy.deepcopy(self.tournament_results)
        new_tournament[new_submission_id] = {}
        new_tournament[new_submission_id][new_submission_id] = {
            task: False for task in self.tasks_metadata.keys()
        }
        
        for competitor_id in self.submission_ids:
            res = check_significance_send_task(new_model_file, self.submission_id_to_file[competitor_id])
            res_inverse = check_significance_send_task(self.submission_id_to_file[competitor_id], new_model_file)
            
            res = check_significance_wait_for_result(res)
            res_inverse = check_significance_wait_for_result(res_inverse)
            
            new_tournament[new_submission_id][competitor_id] = {
                task: data["significant"] for task, data in res.items()
            }
            new_tournament[competitor_id][new_submission_id] = {
                task: data["significant"] for task, data in res_inverse.items()
            }
        return new_tournament

    @staticmethod
    def abbreviate(s, max_length, dots_place="center"):
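        """Shorten s to at most max_length characters, marking the cut with an ellipsis."""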
        if len(s) <= max_length:
            return s
        else:
            if max_length <= 1:
                return "…"
            elif dots_place == "begin":
                return "…" + s[-max_length + 1:].lstrip()
            elif dots_place == "center" and max_length >= 3:
                max_length_begin = max_length // 2
                max_length_end = max_length - max_length_begin - 1
                return s[:max_length_begin].rstrip() + "…" + s[-max_length_end:].lstrip()
            else:  # dots_place == "end"
                return s[:max_length - 1].rstrip() + "…"

    @staticmethod
    def create_submission_id(metadata):
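        """Derive a file-name-safe submission id from the submission metadata."""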
        # The ID length must be limited because it is used in the file name
        submission_id = "_".join([metadata[key][:7] for key in (
            "team_name",
            "model_name",
            "model_predictions_sha256",
            "model_results_sha256",
        )])
        submission_id = submission_id.replace("/", "_").replace("\n", "_").strip()
        return submission_id

    @staticmethod
    def get_sha256_hexdigest(obj):
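        """Hash obj as canonical JSON (sorted keys, compact separators) with SHA-256."""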
        data = json.dumps(
            obj,
            separators=(',', ':'),
            sort_keys=True,
            ensure_ascii=True,
        ).encode()
        result = hashlib.sha256(data).hexdigest()
        return result
    
    PreSubmit = namedtuple('PreSubmit', 'tournament_results, submission_id, file')
    
    def prepare_model_for_submission(self, file, metadata) -> None:
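        """Attach metadata to a submission file and run its tournament.

        Computes content hashes, derives the submission id, rewrites the file
        with the enriched metadata, and stores the resulting tournament as
        self.pre_submit for a later save_pre_submit() call.
        """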
        with open(file, "r") as f:
            data = json.load(f)
        
        data["metadata"] = metadata
        
        metadata["model_predictions_sha256"] = self.get_sha256_hexdigest(data["predictions"])
        metadata["model_results_sha256"] = self.get_sha256_hexdigest(data["results"])
        
        submission_id = self.create_submission_id(metadata)
        metadata["submission_id"] = submission_id
        
        metadata["submission_timestamp"] = time.time()  # timestamp
        
        with open(file, "w") as f:
            json.dump(data, f, separators=(',', ':'))  # compact JSON
        
        tournament_results = self.start_tournament(submission_id, file)
        self.pre_submit = self.PreSubmit(tournament_results, submission_id, file)

    def save_pre_submit(self):
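        """Upload the pre-submitted file and the updated tournament results to the repo."""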
        if self.pre_submit:
            tournament_results, submission_id, file = self.pre_submit
            api.upload_file(
                path_or_fileobj=file,
                path_in_repo=f"data/{submission_id}.json",
                repo_id=self.server_address,
                repo_type=self.repo_type,
                token=HF_TOKEN,
            )

            # Temporary save tournament results
            tournament_results_path = os.path.join(self.local_leaderboard, "tournament.json")
            with open(tournament_results_path, "w") as f:
                json.dump(tournament_results, f, sort_keys=True, indent=2)  # readable JSON

            api.upload_file(
                path_or_fileobj=tournament_results_path,
                path_in_repo="tournament.json",
                repo_id=self.server_address,
                repo_type=self.repo_type,
                token=HF_TOKEN,
            )

    def get_model_detail(self, submission_id):
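        """Return the stored metadata for a submission id."""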
        path = self.submission_id_to_file.get(submission_id)
        if path is None:
            raise gr.Error(f"Submission [{submission_id}] not found")
        data = json.load(open(path))
        return data["metadata"]