import os
import os.path
import re
import shutil
import json
import stat
import tqdm
from collections import OrderedDict
from multiprocessing.pool import ThreadPool as Pool

from modules import shared, sd_models, hashes
from scripts import safetensors_hack, model_util, util
import modules.scripts as scripts


# shared.cmd_opts may not define addnet_max_model_count; fall back to 5 in that case.
MAX_MODEL_COUNT = getattr(shared.cmd_opts, "addnet_max_model_count", 5)
LORA_MODEL_EXTS = [".pt", ".ckpt", ".safetensors"]
# Raw string: "\(" in a plain string literal is an invalid escape sequence (SyntaxWarning on Python 3.12+).
re_legacy_hash = re.compile(r"\(([0-9a-f]{8})\)$")  # matches 8-character hashes, new hash has 12 characters
lora_models = {}  # "My_Lora(abcdef123456)" -> "C:/path/to/model.safetensors"
lora_model_names = {}  # "my_lora" -> "My_Lora(abcdef123456)"
legacy_model_names = {}  # legacy hash, e.g. "abcd1234" -> "My_Lora(abcdef123456)"
lora_models_dir = os.path.join(scripts.basedir(), "models/lora")
os.makedirs(lora_models_dir, exist_ok=True)


def is_safetensors(filename):
    """Return True when *filename* ends with the (case-sensitive) ``.safetensors`` extension."""
    _, extension = os.path.splitext(filename)
    return extension == ".safetensors"


def read_model_metadata(model_path, module):
    """Return the embedded metadata of a model file, or None.

    A pair of surrounding double quotes on *model_path* is stripped first.
    Metadata is only read for the "LoRA" module and ``.safetensors`` files;
    a missing file or any other module/extension yields None.
    """
    path = model_path
    if path.startswith('"') and path.endswith('"'):  # trim '"' at start/end
        path = path[1:-1]

    if not os.path.exists(path):
        return None
    if module != "LoRA" or os.path.splitext(path)[1] != ".safetensors":
        return None
    return safetensors_hack.read_metadata(path)


def write_model_metadata(model_path, module, updates):
    """Merge *updates* into the metadata of a LoRA ``.safetensors`` model and save it in place.

    A pair of surrounding double quotes on *model_path* is stripped first.
    Every value in *updates* is stored as ``str(value)``. When the
    ``additional_networks_back_up_model_when_saving`` option is enabled (default
    True), the model is first copied to ``<model_path>.backup`` unless that
    backup already exists.

    Missing files, modules other than "LoRA" and non-.safetensors files are left
    untouched and are not backed up. Always returns None.
    """
    if model_path.startswith('"') and model_path.endswith('"'):  # trim '"' at start/end
        model_path = model_path[1:-1]
    if not os.path.exists(model_path):
        return None
    # Only LoRA .safetensors files are rewritten; bail out before creating a
    # backup of a file that would not be modified.
    if module != "LoRA" or os.path.splitext(model_path)[1] != ".safetensors":
        return None

    from safetensors.torch import save_file

    back_up = shared.opts.data.get("additional_networks_back_up_model_when_saving", True)
    if back_up:
        backup_path = model_path + ".backup"
        if not os.path.exists(backup_path):
            print(f"[MetadataEditor] Backing up current model to {backup_path}")
            shutil.copyfile(model_path, backup_path)

    tensors, metadata = safetensors_hack.load_file(model_path, "cpu")

    # safetensors metadata values must be strings.
    for k, v in updates.items():
        metadata[k] = str(v)

    save_file(tensors, model_path, metadata)
    print(f"[MetadataEditor] Model saved: {model_path}")
    return None


def get_model_list(module, model, model_dir, sort_by):
    """Return the "Name(hash)" names of all models found (recursively) in a directory.

    An empty *model_dir* means "the directory containing *model*", which is
    looked up in ``lora_models``; an unknown *model* yields ``[]``. A directory
    that does not exist also yields ``[]``. *module* is accepted but not used.
    """
    directory = model_dir
    if directory == "":
        # Get list of models with same folder as this one
        known_path = lora_models.get(model, None)
        if known_path is None:
            return []
        directory = os.path.dirname(known_path)

    if os.path.isdir(directory):
        names, _legacy = get_all_models([directory], sort_by, "")
        return list(names)
    return []


def traverse_all_files(curr_path, model_list, exts=None):
    """Recursively collect model files below *curr_path*.

    Appends ``(path, os.stat_result)`` tuples to *model_list* in ``os.scandir``
    order, descending into sub-directories as they are encountered, and returns
    that same list.

    Args:
        curr_path: directory to scan.
        model_list: list that results are appended to.
        exts: file extensions treated as models; defaults to ``LORA_MODEL_EXTS``.
    """
    if exts is None:
        exts = LORA_MODEL_EXTS

    with os.scandir(curr_path) as it:
        entries = list(it)

    for entry in entries:
        fname = os.path.join(curr_path, entry.name)
        try:
            fstat = entry.stat()
        except OSError:
            # e.g. a dangling symlink or a file removed mid-scan: skip it
            # rather than aborting the whole scan.
            continue

        # Check for directories first so a folder named "foo.pt" is descended
        # into instead of being listed (and later hashed) as a model file.
        if stat.S_ISDIR(fstat.st_mode):
            try:
                traverse_all_files(fname, model_list, exts)
            except OSError:
                # Unreadable sub-directory: skip it, keep the rest of the scan.
                continue
        elif os.path.splitext(fname)[1] in exts:
            model_list.append((fname, fstat))
    return model_list


def get_model_hash(metadata, filename):
    """Return the new-style model hash for *filename*.

    Without metadata (non-safetensors files) the whole file is hashed with
    ``hashes.calculate_sha256``. With metadata, a stored ``sshs_model_hash``
    value is returned as-is; otherwise ``safetensors_hack.hash_file`` computes it.
    """
    if metadata is None:
        return hashes.calculate_sha256(filename)
    return metadata["sshs_model_hash"] if "sshs_model_hash" in metadata else safetensors_hack.hash_file(filename)


def get_legacy_hash(metadata, filename):
    """Return the legacy model hash for *filename*.

    Without metadata (non-safetensors files) ``sd_models.model_hash`` is used.
    With metadata, a stored ``sshs_legacy_hash`` value is returned as-is;
    otherwise ``safetensors_hack.legacy_hash_file`` computes it.
    """
    if metadata is None:
        return sd_models.model_hash(filename)
    return metadata["sshs_legacy_hash"] if "sshs_legacy_hash" in metadata else safetensors_hack.legacy_hash_file(filename)


import filelock

# In-memory hash cache; None until cache() first loads it from disk.
cache_data = None
# JSON file that cache() reads and dump_cache() writes.
cache_filename = os.path.join(scripts.basedir(), "hashes.json")


def cache(subsection):
    """Return the mutable dict stored under *subsection* of the hash cache.

    On first use the cache is loaded from ``cache_filename`` while holding the
    ``<cache_filename>.lock`` file lock. A missing file starts an empty cache.
    An unreadable, corrupt or non-object JSON file is reported and also replaced
    by an empty cache (it is overwritten on the next dump_cache()), so a damaged
    hashes.json cannot stop models from being listed.

    The returned dict is inserted into the cache, so mutations are persisted
    by dump_cache().
    """
    global cache_data

    if cache_data is None:
        with filelock.FileLock(cache_filename + ".lock"):
            loaded = {}
            if os.path.isfile(cache_filename):
                try:
                    with open(cache_filename, "r", encoding="utf8") as file:
                        loaded = json.load(file)
                except (OSError, ValueError) as ex:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
                    print(f"[AddNet] Failed to read hash cache {cache_filename}, starting with an empty cache: {ex}")
                    loaded = {}
            if not isinstance(loaded, dict):
                print(f"[AddNet] Unexpected content in hash cache {cache_filename}, starting with an empty cache")
                loaded = {}
            cache_data = loaded

    return cache_data.setdefault(subsection, {})


def dump_cache():
    """Write the in-memory hash cache to ``cache_filename``.

    The JSON is written to ``<cache_filename>.tmp`` and then moved over the real
    file with ``os.replace``, so an interrupted write leaves the previous
    hashes.json intact instead of a truncated one. Does nothing if the cache was
    never loaded, which avoids overwriting a good file with ``null``.
    """
    if cache_data is None:
        return

    tmp_filename = cache_filename + ".tmp"
    with filelock.FileLock(cache_filename + ".lock"):
        with open(tmp_filename, "w", encoding="utf8") as file:
            json.dump(cache_data, file, indent=4)
        os.replace(tmp_filename, cache_filename)


def get_model_rating(filename):
    """Return the user rating stored in a model's ``ssmd_rating`` metadata, or 0.

    Non-.safetensors files, models without a rating, and ratings that cannot be
    parsed as a finite number all yield 0. Ratings written in float form
    (e.g. ``"3.0"``) are truncated to an int instead of raising, so one odd
    value cannot break sorting by rating.
    """
    if not model_util.is_safetensors(filename):
        return 0

    metadata = safetensors_hack.read_metadata(filename) or {}
    try:
        return int(float(metadata.get("ssmd_rating", "0")))
    except (TypeError, ValueError, OverflowError):
        return 0


def has_user_metadata(filename):
    """Tell whether a model carries any user-editable ``ssmd_*`` metadata keys.

    Always False for files that are not ``.safetensors``.
    """
    if model_util.is_safetensors(filename):
        metadata = safetensors_hack.read_metadata(filename)
        for key in metadata.keys():
            if key.startswith("ssmd_"):
                return True
    return False


def hash_model_file(finfo):
    """Compute the hashes of one model file, or fetch them from the hash cache.

    Args:
        finfo: ``(filename, os.stat_result)`` tuple as produced by traverse_all_files.

    Returns:
        ``{"model": ..., "legacy": ..., "fileinfo": finfo}`` on success, or
        ``{"error": ..., "filename": filename}`` when the file cannot be hashed
        or is named "None". Errors are returned rather than raised because this
        runs inside a thread pool, where one exception would abort the whole scan.
    """
    filename, fstat = finfo[0], finfo[1]
    name = os.path.splitext(os.path.basename(filename))[0]

    # Prevent a hypothetical "None.pt" from being listed: the name is reserved
    # for the "no model" placeholder entry.
    if name == "None":
        return {"error": "the model name 'None' is reserved", "filename": filename}

    # Reuse cached hashes only when the file is unchanged and the entry is complete.
    cached = cache("hashes").get(filename, None)
    if (
        cached is not None
        and cached.get("mtime") == fstat.st_mtime
        and "model" in cached
        and "legacy" in cached
    ):
        return {"model": cached["model"], "legacy": cached["legacy"], "fileinfo": finfo}

    try:
        metadata = safetensors_hack.read_metadata(filename) if model_util.is_safetensors(filename) else None
        model_hash = get_model_hash(metadata, filename)
        legacy_hash = get_legacy_hash(metadata, filename)
    except Exception as ex:  # reported by the caller; keep the other files going
        return {"error": ex, "filename": filename}

    return {"model": model_hash, "legacy": legacy_hash, "fileinfo": finfo}


def _sort_model_results(data, sort_by):
    """Return hash results from hash_model_file ordered by *sort_by*; unknown values keep the incoming order."""
    if sort_by == "name":
        return sorted(data, key=lambda x: os.path.basename(x["fileinfo"][0]))
    if sort_by == "date":
        # Newest first.
        return sorted(data, key=lambda x: -x["fileinfo"][1].st_mtime)
    if sort_by == "path name":
        return sorted(data, key=lambda x: x["fileinfo"][0])
    if sort_by == "rating":
        return sorted(data, key=lambda x: get_model_rating(x["fileinfo"][0]), reverse=True)
    if sort_by == "has user metadata":
        return sorted(
            data,
            key=lambda x: os.path.basename(x["fileinfo"][0]) if has_user_metadata(x["fileinfo"][0]) else "",
            reverse=True,
        )
    return data


def get_all_models(paths, sort_by, filter_by):
    """Scan *paths* for model files, hash them and build the name registries.

    Args:
        paths: directories to scan recursively; entries that are not directories are ignored.
        sort_by: "name", "date", "path name", "rating" or "has user metadata";
            any other value keeps the (unordered) hashing order.
        filter_by: case-insensitive substring that a file's basename must
            contain; blank (after stripping spaces) disables filtering.

    Returns:
        ``(res, res_legacy)``: OrderedDicts mapping ``"Name(hash12)"`` -> file
        path, and legacy hash -> ``"Name(hash12)"``.

    Also records the hashes of every listed model in the "hashes" section of
    the in-memory hash cache (it is not written to disk here).
    """
    fileinfos = []
    for path in paths:
        if os.path.isdir(path):
            fileinfos += traverse_all_files(path, [])

    show_only_safetensors = shared.opts.data.get("additional_networks_show_only_safetensors", False)
    show_only_missing_meta = shared.opts.data.get("additional_networks_show_only_models_with_metadata", "disabled")

    if show_only_safetensors:
        fileinfos = [x for x in fileinfos if is_safetensors(x[0])]

    if show_only_missing_meta == "has metadata":
        fileinfos = [x for x in fileinfos if has_user_metadata(x[0])]
    elif show_only_missing_meta == "missing metadata":
        fileinfos = [x for x in fileinfos if not has_user_metadata(x[0])]

    # The name filter only depends on the basename, so apply it before hashing
    # to avoid hashing files that will not be listed anyway.
    filter_by = filter_by.strip(" ")
    if filter_by:
        needle = filter_by.lower()
        fileinfos = [x for x in fileinfos if needle in os.path.basename(x[0]).lower()]

    print("[AddNet] Updating model hashes...")
    data = []
    thread_count = max(1, int(shared.opts.data.get("additional_networks_hash_thread_count", 1)))
    # The context manager shuts the pool down even if hashing raises.
    with Pool(processes=thread_count) as pool, tqdm.tqdm(total=len(fileinfos)) as pbar:
        for res in pool.imap_unordered(hash_model_file, fileinfos):
            pbar.update()
            if "error" in res:
                print(f"Failed to read model file {res['filename']}: {res['error']}")
            else:
                data.append(res)

    cache_hashes = cache("hashes")

    data = _sort_model_results(data, sort_by)
    if shared.opts.data.get("additional_networks_reverse_sort_order", False):
        data = reversed(data)

    res = OrderedDict()
    res_legacy = OrderedDict()
    for result in data:
        filename, fstat = result["fileinfo"]
        model_hash = result["model"]
        legacy_hash = result["legacy"]

        # Commas in the model name will mess up infotext restoration since the
        # infotext is delimited by commas
        name = os.path.splitext(os.path.basename(filename))[0].replace(",", "_")

        # Prevent a hypothetical "None.pt" from being listed.
        if name == "None":
            continue

        full_name = f"{name}({model_hash[0:12]})"
        res[full_name] = filename
        res_legacy[legacy_hash] = full_name
        cache_hashes[filename] = {"model": model_hash, "legacy": legacy_hash, "mtime": fstat.st_mtime}

    return res, res_legacy


def find_closest_lora_model_name(search: str):
    """Resolve a user-supplied model reference to a key of ``lora_models``.

    Matches are tried in this order:
      1. exact "Name(hash)" key (case-sensitive);
      2. path of an existing model file (compared after ``os.path.normpath``);
      3. bare model name (case-insensitive);
      4. "Name(abcd1234)" with an 8-character legacy hash;
      5. the shortest model name containing *search* (case-insensitive).

    Returns:
        The matching "Name(hash)" key, or None when *search* is empty, "None",
        or nothing matches.
    """
    if not search or search == "None":
        return None

    # Match name and hash, case-sensitive
    # "MyModel-epoch00002(abcdef123456)"
    if search in lora_models:
        return search

    # Match model path, case-sensitive (from metadata editor)
    # "C:/path/to/mymodel-epoch00002.safetensors"
    # Normalize both sides: stored paths are built with os.path.join and need
    # not be in normal form. The "None" placeholder has no path and is skipped.
    if os.path.isfile(search):
        wanted = os.path.normpath(search)
        for name_and_hash, path in lora_models.items():
            if path is not None and os.path.normpath(path) == wanted:
                return name_and_hash

    search = search.lower()

    # Match full name, case-insensitive
    # "mymodel-epoch00002"
    if search in lora_model_names:
        return lora_model_names[search]

    # Match legacy hash (8 characters)
    # "MyModel(abcd1234)"
    match = re_legacy_hash.search(search)
    if match is not None:
        legacy_hash = match.group(1)
        if legacy_hash in legacy_model_names:
            return legacy_model_names[legacy_hash]

    # Use any model whose name contains the search term, case-insensitive,
    # preferring the shortest name (ties keep registry order)
    # "mymodel"
    applicable = [name for name in lora_model_names if search in name.lower()]
    if not applicable:
        return None
    return lora_model_names[min(applicable, key=len)]


def update_models():
    """Rescan all LoRA directories and rebuild the module-level model registries.

    Scans ``lora_models_dir`` plus every existing directory listed in the
    ``additional_networks_extra_lora_path`` option. It then repopulates
    ``lora_models`` (with a leading ``"None": None`` entry),
    ``lora_model_names`` and ``legacy_model_names``, and writes the hash cache to disk.

    The registries are cleared and refilled in place, so code holding a
    reference to these dicts sees the new contents and no entries for removed
    or renamed models linger.
    """
    paths = [lora_models_dir]
    extra_lora_paths = util.split_path_list(shared.opts.data.get("additional_networks_extra_lora_path", ""))
    for path in extra_lora_paths:
        path = path.lstrip()
        if os.path.isdir(path):
            paths.append(path)

    sort_by = shared.opts.data.get("additional_networks_sort_models_by", "name")
    filter_by = shared.opts.data.get("additional_networks_model_name_filter", "")
    res, res_legacy = get_all_models(paths, sort_by, filter_by)

    lora_models.clear()
    lora_models["None"] = None
    lora_models.update(res)

    lora_model_names.clear()
    for name_and_hash, filename in lora_models.items():
        if filename is None:  # the "None" placeholder
            continue
        name = os.path.splitext(os.path.basename(filename))[0].lower()
        lora_model_names[name] = name_and_hash

    legacy_model_names.clear()
    legacy_model_names.update(res_legacy)
    dump_cache()


# Populate lora_models / lora_model_names / legacy_model_names (and refresh the
# on-disk hash cache) as soon as this module is imported.
update_models()