# chattts / modules /speaker.py
# zhzluke96
# update
# bed01bd
import os
import uuid
from typing import Union
import torch
from box import Box
from modules import models
from modules.utils.SeedContext import SeedContext
def create_speaker_from_seed(seed):
    """Sample a speaker embedding from the ChatTTS model under ``seed``.

    The model's ``sample_random_speaker()`` is called inside
    ``SeedContext(seed, True)``; presumably this makes the draw reproducible
    for a given seed (TODO confirm against ``SeedContext``).

    Args:
        seed: Seed handed to ``SeedContext``.

    Returns:
        Whatever ``sample_random_speaker()`` returns (the speaker embedding).
    """
    tts_model = models.load_chat_tts()
    with SeedContext(seed, True):
        # Returning from inside the block still runs the context's exit logic.
        return tts_model.sample_random_speaker()
class Speaker:
    """A voice identity: a speaker embedding plus descriptive metadata.

    Instances are persisted by pickling the whole object (``torch.save`` /
    ``torch.load``), so attributes added in later versions are backfilled by
    ``fix()`` when an older file is loaded.
    """

    @staticmethod
    def from_file(file_like):
        """Load a pickled ``Speaker`` from a path or file-like object.

        Tensors are mapped to CPU, and ``fix()`` is applied so that files
        written by older versions gain any attributes added since.

        NOTE(review): ``torch.load`` unpickles arbitrary objects, so only load
        speaker files from trusted sources. Recent torch releases default to
        ``weights_only=True``, which rejects pickled custom classes such as
        this one; confirm the pinned torch version (or pass
        ``weights_only=False`` explicitly on versions that accept it).
        """
        speaker = torch.load(file_like, map_location=torch.device("cpu"))
        speaker.fix()
        return speaker

    @staticmethod
    def from_tensor(tensor):
        """Wrap an existing embedding tensor in a new ``Speaker`` (seed = -2)."""
        speaker = Speaker(seed_or_tensor=-2)
        speaker.emb = tensor
        return speaker

    @staticmethod
    def from_seed(seed: int):
        """Create a ``Speaker`` whose embedding is sampled from ``seed``."""
        speaker = Speaker(seed_or_tensor=seed)
        speaker.emb = create_speaker_from_seed(seed)
        return speaker

    def __init__(
        self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe=""
    ):
        """Build a speaker from either an integer seed or an embedding tensor.

        Args:
            seed_or_tensor: An int is stored as ``seed`` (``emb`` starts as
                ``None``). A tensor is stored as ``emb`` and ``seed`` becomes
                -2, the value this class uses for "not built from a seed".
            name: Display name.
            gender: Free-form gender label.
            describe: Free-form description.
        """
        self.id = uuid.uuid4()
        self.seed = -2 if isinstance(seed_or_tensor, torch.Tensor) else seed_or_tensor
        self.name = name
        self.gender = gender
        self.describe = describe
        self.emb = None if isinstance(seed_or_tensor, int) else seed_or_tensor
        # TODO replace emb => tokens
        self.tokens = []

    def to_json(self, with_emb=False):
        """Return the speaker's metadata as a ``Box``.

        Args:
            with_emb: When true, include the embedding as a nested list. The
                value is ``None`` when the flag is false or no embedding is set.
        """
        # Guard against emb=None (e.g. Speaker(seed) before emb is assigned),
        # which previously raised AttributeError on .tolist().
        emb = self.emb.tolist() if with_emb and self.emb is not None else None
        return Box(
            **{
                "id": str(self.id),
                "seed": self.seed,
                "name": self.name,
                "gender": self.gender,
                "describe": self.describe,
                "emb": emb,
            }
        )

    def fix(self):
        """Backfill attributes missing from speakers pickled by older versions.

        Returns:
            True if at least one attribute was added, otherwise False.
        """
        # (attribute, zero-arg factory): factories give each call a fresh
        # value (new uuid, new list) instead of sharing one object.
        defaults = (
            ("id", uuid.uuid4),
            ("seed", lambda: -2),
            ("name", lambda: ""),
            # NOTE(review): __init__ defaults gender to "", backfill uses "*".
            ("gender", lambda: "*"),
            ("describe", lambda: ""),
            ("tokens", list),
        )
        is_update = False
        for attr, make_default in defaults:
            if attr not in self.__dict__:
                setattr(self, attr, make_default())
                is_update = True
        return is_update

    def __hash__(self):
        """Hash by id, consistent with ``__eq__``."""
        return hash(str(self.id))

    def __eq__(self, other):
        """Two speakers are equal when their ids match."""
        if not isinstance(other, Speaker):
            return False
        return str(self.id) == str(other.id)
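# Each speaker is persisted as one .pt file (the pickled Speaker, including its embedding).
# Managing speakers means managing every speaker file under ./data/speakers/.
# A speaker can be created from a seed.
# The list can be refreshed by re-reading all speakers from disk.
# All speakers can be listed.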
class SpeakerManager:
    """Index of the ``Speaker`` objects stored as ``.pt`` files in one directory.

    ``self.speakers`` maps each file name (e.g. ``"alice.pt"``) to the
    ``Speaker`` loaded from it. Every write pickles the speaker with
    ``torch.save`` and, except for ``save_all``, rescans the directory.
    """

    def __init__(self):
        self.speakers = {}
        self.speaker_dir = "./data/speakers/"
        self.refresh_speakers()

    def _path(self, filename):
        """Return the on-disk path of ``filename`` inside ``speaker_dir``."""
        return os.path.join(self.speaker_dir, filename)

    def _save(self, speaker, filename):
        """Pickle ``speaker`` to ``filename``, creating ``speaker_dir`` if needed."""
        os.makedirs(self.speaker_dir, exist_ok=True)
        torch.save(speaker, self._path(filename))

    def refresh_speakers(self):
        """Rebuild ``self.speakers`` from every ``*.pt`` file in ``speaker_dir``.

        A missing directory is treated as empty. Because the mapping is rebuilt
        from scratch, files deleted from disk disappear from it automatically.
        The old code also deleted entries from the dict while iterating it,
        which raises RuntimeError; that step is no longer needed.
        """
        speakers = {}
        if os.path.isdir(self.speaker_dir):
            # sorted(): os.listdir order is arbitrary; keep listings stable.
            for speaker_file in sorted(os.listdir(self.speaker_dir)):
                if speaker_file.endswith(".pt"):
                    speakers[speaker_file] = Speaker.from_file(
                        self._path(speaker_file)
                    )
        self.speakers = speakers

    def list_speakers(self) -> list[Speaker]:
        """Return all loaded speakers."""
        return list(self.speakers.values())

    def create_speaker_from_seed(self, seed, name="", gender="", describe=""):
        """Sample a speaker from ``seed``, save it as ``<name>.pt`` and return it.

        When ``name`` is empty, ``str(seed)`` is used as the name and file name.
        """
        if name == "":
            # str(): a numeric seed would make `name + ".pt"` raise TypeError.
            name = str(seed)
        speaker = Speaker(seed, name=name, gender=gender, describe=describe)
        # Resolves to the module-level function, not this method.
        speaker.emb = create_speaker_from_seed(seed)
        self._save(speaker, name + ".pt")
        self.refresh_speakers()
        return speaker

    def create_speaker_from_tensor(
        self, tensor, filename="", name="", gender="", describe=""
    ):
        """Save an existing embedding as a new speaker file and return the speaker.

        Args:
            tensor: A ``torch.Tensor`` (stored as-is) or anything
                ``torch.as_tensor`` accepts (list, tuple, numpy array, ...).
            filename: File name without ``.pt``. Defaults to ``name``, or to
                the speaker's id when ``name`` is empty as well.

        Raises:
            Whatever ``torch.as_tensor`` raises for unsupported ``tensor``
            input. Previously such input was silently saved with ``emb=None``.
        """
        speaker = Speaker(
            seed_or_tensor=-2, name=name, gender=gender, describe=describe
        )
        if isinstance(tensor, torch.Tensor):
            speaker.emb = tensor
        else:
            speaker.emb = torch.as_tensor(tensor)
        if filename == "":
            # Avoid writing a bare ".pt" when neither filename nor name is given.
            filename = name if name != "" else str(speaker.id)
        self._save(speaker, filename + ".pt")
        self.refresh_speakers()
        return speaker

    def get_speaker(self, name) -> Union[Speaker, None]:
        """Return the first speaker whose ``name`` matches, or ``None``."""
        return next(
            (spk for spk in self.speakers.values() if spk.name == name), None
        )

    def get_speaker_by_id(self, id) -> Union[Speaker, None]:
        """Return the speaker whose id (compared as strings) matches, or ``None``."""
        return next(
            (spk for spk in self.speakers.values() if str(spk.id) == str(id)),
            None,
        )

    def get_speaker_filename(self, id: str):
        """Return the file name of the speaker with this id, or ``None``."""
        return next(
            (
                fname
                for fname, spk in self.speakers.items()
                if str(spk.id) == str(id)
            ),
            None,
        )

    def update_speaker(self, speaker: Speaker):
        """Overwrite the file of the stored speaker that has ``speaker.id``.

        Raises:
            ValueError: if no loaded speaker has that id.
        """
        filename = self.get_speaker_filename(speaker.id)
        if filename is None:
            raise ValueError("Speaker not found for update")
        self._save(speaker, filename)
        self.refresh_speakers()
        return speaker

    def save_all(self):
        """Write every loaded speaker back to the file it was loaded from."""
        # Iterate the mapping itself: looking files up by id was O(n^2) and,
        # when two files shared an id, wrote the wrong speaker into a file.
        for filename, speaker in self.speakers.items():
            self._save(speaker, filename)

    def __len__(self):
        return len(self.speakers)
# Shared module-level manager; constructing it loads the speaker files from
# its speaker_dir ("./data/speakers/") as soon as this module is imported.
speaker_mgr = SpeakerManager()