File size: 4,950 Bytes
01e655b 02e90e4 01e655b 02e90e4 01e655b 02e90e4 49bce5c 01e655b 02e90e4 01e655b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import os
from typing import Union
import torch
from modules import models
from modules.utils.SeedContext import SeedContext
import uuid
def create_speaker_from_seed(seed):
    """Sample a random speaker embedding from the ChatTTS model under ``seed``.

    The sampling runs inside ``SeedContext(seed)`` so the same seed is meant
    to reproduce the same embedding.
    """
    tts_model = models.load_chat_tts()
    with SeedContext(seed):
        speaker_embedding = tts_model.sample_random_speaker()
    return speaker_embedding
class Speaker:
    """A voice preset: a seed, descriptive metadata and an optional embedding.

    Instances are pickled whole (``torch.save(speaker, ...)``); ``fix``
    backfills attributes that objects pickled by older code may lack.
    """

    def __init__(self, seed, name="", gender="", describe=""):
        self.id = uuid.uuid4()
        self.seed = seed
        self.name = name
        self.gender = gender
        self.describe = describe
        # Speaker embedding; assigned by the creator after construction.
        self.emb = None

    def to_json(self, with_emb=False):
        """Return a JSON-serialisable dict describing this speaker.

        Args:
            with_emb: when True, include the embedding as a nested list.
                ``"emb"`` is ``None`` when this is False or no embedding
                has been set.
        """
        emb = getattr(self, "emb", None)
        return {
            "id": str(self.id),
            "seed": self.seed,
            "name": self.name,
            "gender": self.gender,
            "describe": self.describe,
            # emb may be unset (None); only call .tolist() on a real value.
            "emb": emb.tolist() if with_emb and emb is not None else None,
        }

    def fix(self):
        """Add any attribute missing from this instance, using a default value.

        Intended for objects unpickled from files written by older versions
        of this class, whose ``__dict__`` may lack newer fields.

        Returns:
            True if at least one attribute was added (the object should be
            re-saved), False if it was already complete.
        """
        present = vars(self)  # live view of the instance __dict__
        is_update = False
        if "id" not in present:
            self.id = uuid.uuid4()
            is_update = True
        # -2 is the same "no seed" marker used for tensor-created speakers.
        # NOTE(review): gender defaults to "*" here but to "" in __init__;
        # kept unchanged so already-migrated files stay consistent.
        defaults = {
            "seed": -2,
            "name": "",
            "gender": "*",
            "describe": "",
            "emb": None,
        }
        for attr, default in defaults.items():
            if attr not in present:
                setattr(self, attr, default)
                is_update = True
        return is_update

    def __hash__(self):
        # Hash and equality are both keyed on the id so they stay consistent.
        return hash(str(self.id))

    def __eq__(self, other):
        if not isinstance(other, Speaker):
            # Defer to the other operand; Python falls back to False.
            return NotImplemented
        return str(self.id) == str(other.id)

    def __repr__(self):
        return (
            f"{type(self).__name__}(id={str(getattr(self, 'id', None))!r}, "
            f"name={getattr(self, 'name', None)!r}, "
            f"seed={getattr(self, 'seed', None)!r})"
        )
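# Each speaker is stored as one .pt file (a pickled Speaker, embedding included).
# Managing speakers means managing every speaker file under ./data/speakers/.
# A speaker can be created from a seed (or from an existing embedding tensor).
# The list can be refreshed to reload all speakers from disk.
# All loaded speakers can be listed.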
class SpeakerManager:
    """Load, create and persist Speaker objects stored as ``.pt`` files.

    Every ``*.pt`` file in ``speaker_dir`` holds one pickled Speaker;
    ``self.speakers`` maps file name -> loaded Speaker.
    """

    def __init__(self):
        self.speakers = {}
        self.speaker_dir = "./data/speakers/"
        self.refresh_speakers()

    def _path(self, filename):
        """Return the path of ``filename`` inside ``speaker_dir``."""
        return os.path.join(self.speaker_dir, filename)

    def _save(self, speaker, filename):
        """Pickle ``speaker`` to ``filename``, creating ``speaker_dir`` if needed."""
        os.makedirs(self.speaker_dir, exist_ok=True)
        torch.save(speaker, self._path(filename))

    def refresh_speakers(self):
        """Reload every ``.pt`` file in ``speaker_dir`` into ``self.speakers``.

        Speakers that ``Speaker.fix`` upgrades are written back to disk.
        A missing directory yields an empty registry instead of an error.
        """
        speakers = {}
        if os.path.isdir(self.speaker_dir):
            # sorted() gives a deterministic order for list_speakers().
            for speaker_file in sorted(os.listdir(self.speaker_dir)):
                if not speaker_file.endswith(".pt"):
                    continue
                # SECURITY NOTE: weights_only=False unpickles arbitrary
                # objects; only point speaker_dir at trusted files. It is
                # required because the files hold whole Speaker objects,
                # which the weights-only loader (default since PyTorch 2.6)
                # rejects.
                speaker = torch.load(
                    self._path(speaker_file),
                    map_location=torch.device("cpu"),
                    weights_only=False,
                )
                if speaker.fix():
                    torch.save(speaker, self._path(speaker_file))
                speakers[speaker_file] = speaker
        self.speakers = speakers

    def list_speakers(self):
        """Return all loaded speakers as a list."""
        return list(self.speakers.values())

    def create_speaker_from_seed(self, seed, name="", gender="", describe=""):
        """Create, save and return a Speaker whose embedding is sampled from ``seed``.

        ``name`` defaults to ``str(seed)`` and determines the file name
        ``<name>.pt``.
        """
        if name == "":
            # str() so that "<name>.pt" works when seed is an int.
            name = str(seed)
        speaker = Speaker(seed, name=name, gender=gender, describe=describe)
        # Resolves to the module-level sampling function, not this method.
        speaker.emb = create_speaker_from_seed(seed)
        self._save(speaker, name + ".pt")
        self.refresh_speakers()
        return speaker

    def create_speaker_from_tensor(
        self, tensor, filename="", name="", gender="", describe=""
    ):
        """Create, save and return a Speaker wrapping an existing embedding.

        Args:
            tensor: the embedding; a ``torch.Tensor`` is stored as-is, anything
                else (list, tuple, array) is converted with ``torch.as_tensor``.
            filename: file stem; the speaker is saved as ``<filename>.pt``.
            name: display name; defaults to ``filename``.

        Raises:
            Whatever ``torch.as_tensor`` raises for unconvertible input.
        """
        if isinstance(tensor, torch.Tensor):
            emb = tensor
        else:
            emb = torch.as_tensor(tensor)
        if name == "":
            name = filename
        speaker = Speaker(seed=-2, name=name, gender=gender, describe=describe)
        speaker.emb = emb
        self._save(speaker, filename + ".pt")
        self.refresh_speakers()
        return speaker

    def get_speaker(self, name) -> "Union[Speaker, None]":
        """Return the first loaded speaker called ``name``, or None."""
        return next(
            (spk for spk in self.speakers.values() if spk.name == name), None
        )

    def get_speaker_by_id(self, id) -> "Union[Speaker, None]":
        """Return the loaded speaker whose id matches ``id`` (compared as str), or None."""
        return next(
            (spk for spk in self.speakers.values() if str(spk.id) == str(id)),
            None,
        )

    def get_speaker_filename(self, id: str):
        """Return the file name holding the speaker with ``id``, or None."""
        return next(
            (
                fname
                for fname, spk in self.speakers.items()
                if str(spk.id) == str(id)
            ),
            None,
        )

    def update_speaker(self, speaker: "Speaker"):
        """Overwrite the file of an already-loaded speaker and reload.

        Raises:
            ValueError: if no loaded speaker has ``speaker.id``.
        """
        filename = self.get_speaker_filename(speaker.id)
        if filename is None:
            raise ValueError("Speaker not found for update")
        self._save(speaker, filename)
        self.refresh_speakers()
        return speaker

    def save_all(self):
        """Re-save every loaded speaker to the file it was loaded from."""
        # Iterate items() so each speaker goes back to its own file, even if
        # two files happen to contain speakers with the same id.
        for filename, speaker in self.speakers.items():
            self._save(speaker, filename)

    def __len__(self):
        return len(self.speakers)
# Shared module-level manager; constructing it reads the speaker directory
# (via refresh_speakers) as soon as this module is imported.
speaker_mgr: SpeakerManager = SpeakerManager()
|