import os
from typing import Union
import torch

from modules import models
from modules.utils.SeedContext import SeedContext

import uuid


def create_speaker_from_seed(seed):
    """Sample a speaker embedding from the ChatTTS model under a fixed seed.

    The model is obtained through ``models.load_chat_tts()`` and the sampling
    call runs inside ``SeedContext(seed)``.

    Args:
        seed: Value handed to ``SeedContext`` for the duration of the sampling.

    Returns:
        Whatever ``sample_random_speaker()`` returns (the speaker embedding).
    """
    tts_model = models.load_chat_tts()
    # Only the sampling call is wrapped in the seed context.
    with SeedContext(seed):
        speaker_embedding = tts_model.sample_random_speaker()
    return speaker_embedding


class Speaker:
    """A single voice: descriptive metadata plus an optional embedding.

    Attributes set by ``__init__``:
        id: Random ``uuid.UUID`` identifying the speaker.
        seed: Seed the speaker was created with.
        name, gender, describe: Free-form metadata.
        emb: The speaker embedding; ``None`` until a caller assigns one.

    Equality and hashing are based solely on ``str(self.id)``.
    """

    def __init__(self, seed, name="", gender="", describe=""):
        self.id = uuid.uuid4()
        self.seed = seed
        self.name = name
        self.gender = gender
        self.describe = describe
        self.emb = None

    def to_json(self, with_emb=False):
        """Return a plain-dict representation of the speaker.

        Args:
            with_emb: When true, include the embedding converted with its
                ``tolist()`` method; otherwise ``"emb"`` is ``None``.

        Returns:
            dict with keys ``id`` (as a string), ``seed``, ``name``,
            ``gender``, ``describe`` and ``emb``.
        """
        # The embedding is None right after construction and may be missing
        # entirely on objects unpickled from older files, so never call
        # ``tolist()`` unconditionally.
        emb = getattr(self, "emb", None)
        return {
            "id": str(self.id),
            "seed": self.seed,
            "name": self.name,
            "gender": self.gender,
            "describe": self.describe,
            "emb": emb.tolist() if with_emb and emb is not None else None,
        }

    def fix(self):
        """Fill in metadata attributes missing from the instance.

        Only the instance ``__dict__`` is inspected, so this repairs objects
        that were restored without running ``__init__``.

        Returns:
            ``True`` if at least one attribute had to be added, else ``False``.
        """
        is_update = False

        # The id needs a freshly generated value, so it cannot live in the
        # static defaults table below.
        if "id" not in self.__dict__:
            self.id = uuid.uuid4()
            is_update = True

        defaults = {
            "seed": -2,
            "name": "",
            "gender": "*",
            "describe": "",
        }
        for attr, default in defaults.items():
            if attr not in self.__dict__:
                setattr(self, attr, default)
                is_update = True

        return is_update

    def __hash__(self):
        return hash(str(self.id))

    def __eq__(self, other):
        if not isinstance(other, Speaker):
            return False
        return str(self.id) == str(other.id)


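# Each speaker is a single embedding file (.pt).
# Managing speakers means managing every speaker file under ./data/speakers/.
# A speaker can be created from a seed.
# The list can be refreshed by re-reading all speaker files.
# All speakers can be listed.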
class SpeakerManager:
    """In-memory index of the speaker objects stored as ``.pt`` files.

    ``self.speakers`` maps each file name found in ``self.speaker_dir`` to the
    object loaded from it. Every mutating operation writes the file with
    ``torch.save`` and then re-reads the whole directory.
    """

    def __init__(self):
        self.speakers = {}
        self.speaker_dir = "./data/speakers/"
        self.refresh_speakers()

    def _path(self, filename):
        """Return the on-disk path of ``filename`` inside the speaker directory."""
        return os.path.join(self.speaker_dir, filename)

    def refresh_speakers(self):
        """Reload every ``.pt`` file in the speaker directory.

        Each loaded object has its ``fix()`` method called; when that reports a
        change, the repaired object is written back to the same file.
        """
        self.speakers = {}
        for speaker_file in os.listdir(self.speaker_dir):
            if not speaker_file.endswith(".pt"):
                continue
            path = self._path(speaker_file)
            # NOTE(security): torch.load unpickles arbitrary Python objects;
            # only files from a trusted source should live in this directory.
            speaker = torch.load(path, map_location=torch.device("cpu"))
            self.speakers[speaker_file] = speaker

            if speaker.fix():
                torch.save(speaker, path)

    def list_speakers(self):
        """Return the currently loaded speakers as a list."""
        return list(self.speakers.values())

    def create_speaker_from_seed(self, seed, name="", gender="", describe=""):
        """Create a speaker from ``seed``, save it as ``<name>.pt`` and return it.

        When ``name`` is empty, the textual form of ``seed`` is used both as
        the speaker name and as the file stem.
        """
        if name == "":
            # Seeds are typically ints; convert so that building the file
            # name below does not raise TypeError.
            name = str(seed)
        filename = name + ".pt"
        speaker = Speaker(seed, name=name, gender=gender, describe=describe)
        speaker.emb = create_speaker_from_seed(seed)
        torch.save(speaker, self._path(filename))
        self.refresh_speakers()
        return speaker

    def create_speaker_from_tensor(
        self, tensor, filename="", name="", gender="", describe=""
    ):
        """Create a speaker from an existing embedding and save it.

        Args:
            tensor: A ``torch.Tensor`` (stored as is) or a ``list`` (converted
                with ``torch.tensor``). For any other type the embedding is
                left as ``None``.
            filename: File stem; ``".pt"`` is appended.
            name: Speaker name; defaults to ``filename`` when empty.
            gender, describe: Metadata forwarded to ``Speaker``.

        Returns:
            The newly created speaker (seed is set to ``-2``).
        """
        if name == "":
            name = filename
        speaker = Speaker(seed=-2, name=name, gender=gender, describe=describe)
        if isinstance(tensor, torch.Tensor):
            speaker.emb = tensor
        elif isinstance(tensor, list):
            speaker.emb = torch.tensor(tensor)
        torch.save(speaker, self._path(filename + ".pt"))
        self.refresh_speakers()
        return speaker

    def get_speaker(self, name) -> Union["Speaker", None]:
        """Return the first loaded speaker whose ``name`` equals ``name``, or ``None``."""
        for speaker in self.speakers.values():
            if speaker.name == name:
                return speaker
        return None

    def get_speaker_by_id(self, id) -> Union["Speaker", None]:
        """Return the first loaded speaker whose id matches (compared as strings), or ``None``."""
        for speaker in self.speakers.values():
            if str(speaker.id) == str(id):
                return speaker
        return None

    def get_speaker_filename(self, id: str):
        """Return the file name of the first loaded speaker with this id, or ``None``."""
        for fname, spk in self.speakers.items():
            if str(spk.id) == str(id):
                return fname
        return None

    def update_speaker(self, speaker: "Speaker"):
        """Overwrite the file of the loaded speaker that shares ``speaker.id``.

        Returns:
            The ``speaker`` argument.

        Raises:
            ValueError: If no loaded speaker has that id.
        """
        filename = self.get_speaker_filename(speaker.id)
        if filename is None:
            raise ValueError("Speaker not found for update")

        torch.save(speaker, self._path(filename))
        self.refresh_speakers()
        return speaker

    def save_all(self):
        """Write every loaded speaker back to the file it was loaded from."""
        # Iterate (file name, speaker) pairs directly: looking the name up
        # again by id is quadratic and picks the wrong file when two files
        # happen to contain the same id.
        for filename, speaker in self.speakers.items():
            torch.save(speaker, self._path(filename))

    def __len__(self):
        return len(self.speakers)


# Module-level shared manager instance. Constructing it runs
# SpeakerManager.__init__, which lists the speaker directory and loads every
# .pt file in it, so this work happens when the module is imported.
speaker_mgr = SpeakerManager()