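"""
Trains an ensemble of small KANs (Kolmogorov-Arnold Networks) that combine
three hand-crafted language distances (tree distance, map distance, and an
ASP-based distance) into a single learned distance that mimics the distance
between the language embeddings of a trained ToucanTTS model, then caches
the learned distance for every pair of languages as JSON.
"""
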
import json
import os
import pickle
import random

import kan
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Utility.utils import load_json_from_path
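

# A single scoring model: a small KAN that maps the three metric distances of
# a language pair to one scalar that should match the embedding distance.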
class MetricsCombiner(torch.nn.Module):

    def __init__(self, m):
        super().__init__()
        self.scoring_function = kan.KAN(width=[3, 5, 1], grid=5, k=5, seed=m)

    def forward(self, x):
        # the KAN expects a 2d batch, so make sure single samples also get a batch dimension
        return self.scoring_function(x.reshape(-1, 3))
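

# Averages the scores of several independently seeded MetricsCombiners to
# smooth out the quirks of any individual model.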
class EnsembleModel(torch.nn.Module):

    def __init__(self, models):
        super().__init__()
        # a ModuleList registers the scorers as proper submodules
        self.models = torch.nn.ModuleList(models)

    def forward(self, x):
        distances = list()
        for model in self.models:
            distances.append(model(x))
        return sum(distances) / len(distances)
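

# Trains the ensemble on randomly sampled pairs of supervised languages,
# sanity-checks the resulting ensemble, and writes the predicted distance for
# every language pair to lang_1_to_lang_2_to_learned_dist.json in cache_root.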
def create_learned_cache(model_path, cache_root="."):
    # the language embedding table of the trained TTS model provides the
    # target distances that the learned metric is supposed to approximate
    checkpoint = torch.load(model_path, map_location='cpu')
    embedding_provider = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"]).encoder.language_embedding
    embedding_provider.requires_grad_(False)

    language_list = load_json_from_path(os.path.join(cache_root, "supervised_languages.json"))
    tree_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json")
    map_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json")
    asp_dict_path = os.path.join(cache_root, "asp_dict.pkl")
    if not os.path.exists(tree_lookup_path) or not os.path.exists(map_lookup_path):
        raise FileNotFoundError("Please ensure the caches exist!")
    if not os.path.exists(asp_dict_path):
        raise FileNotFoundError(f"{asp_dict_path} must be downloaded separately.")
    tree_dist = load_json_from_path(tree_lookup_path)
    map_dist = load_json_from_path(map_lookup_path)
    with open(asp_dict_path, 'rb') as dictfile:
        asp_sim = pickle.load(dictfile)
    lang_list = list(asp_sim.keys())

    # determine the largest map distance, so the map distances can be normalized
    largest_value_map_dist = 0.0
    for _, values in map_dist.items():
        for _, value in values.items():
            largest_value_map_dist = max(largest_value_map_dist, value)

    iso_codes_to_ids = load_json_from_path(os.path.join(cache_root, "iso_lookup.json"))[-1]
    train_set = language_list
    batch_size = 128
    model_list = list()
    print_intermediate_results = False

    # ensemble preparation
    n_models = 5
    print(f"Training ensemble of {n_models} models for learned distance metric.")
    for m in range(n_models):
        model_list.append(MetricsCombiner(m))
        optim = torch.optim.Adam(model_list[-1].parameters(), lr=0.0005)
        running_loss = list()
        for epoch in tqdm(range(35), desc=f"MetricsCombiner {m + 1}/{n_models} - Epoch"):
            for _ in range(1000):
                # we have no dataloader, so first we build a batch
                embedding_distance_batch = list()
                metric_distance_batch = list()
                for _ in range(batch_size):
                    lang_1 = random.sample(train_set, 1)[0]
                    lang_2 = random.sample(train_set, 1)[0]
                    embedding_distance_batch.append(torch.nn.functional.mse_loss(embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_1]])).squeeze(),
                                                                                 embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_2]])).squeeze()))
                    try:
                        _tree_dist = tree_dist[lang_2][lang_1]
                    except KeyError:
                        _tree_dist = tree_dist[lang_1][lang_2]
                    try:
                        _map_dist = map_dist[lang_2][lang_1] / largest_value_map_dist
                    except KeyError:
                        _map_dist = map_dist[lang_1][lang_2] / largest_value_map_dist
                    _asp_dist = 1.0 - asp_sim[lang_1][lang_list.index(lang_2)]
                    metric_distance_batch.append(torch.tensor([_tree_dist, _map_dist, _asp_dist], dtype=torch.float32))

                # ok, now we have a batch prepared. Time to feed it to the model.
                scores = model_list[-1](torch.stack(metric_distance_batch).squeeze())
                if print_intermediate_results:
                    print("==================================")
                    print(scores.detach().squeeze()[:9])
                    print(torch.stack(embedding_distance_batch).squeeze()[:9])
                loss = torch.nn.functional.mse_loss(scores.squeeze(), torch.stack(embedding_distance_batch).squeeze(), reduction="none")
                # normalize each squared error by the magnitude of its target,
                # so that large distances don't dominate the loss
                loss = loss / (torch.stack(embedding_distance_batch).squeeze() + 0.0001)
                loss = loss.mean()
                running_loss.append(loss.item())
                optim.zero_grad()
                loss.backward()
                optim.step()
            print("\n\n")
            print(sum(running_loss) / len(running_loss))
            print("\n\n")
            running_loss = list()
        model_list[-1].scoring_function.plot(folder=f"kan_vis_{m}", beta=5000)
        plt.show()

    # Time to see if the final ensemble is any good
    ensemble = EnsembleModel(model_list)
    running_loss = list()
    for _ in range(100):
        # we have no dataloader, so first we build a batch
        embedding_distance_batch = list()
        metric_distance_batch = list()
        for _ in range(batch_size):
            lang_1 = random.sample(train_set, 1)[0]
            lang_2 = random.sample(train_set, 1)[0]
            embedding_distance_batch.append(torch.nn.functional.mse_loss(embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_1]])).squeeze(),
                                                                         embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_2]])).squeeze()))
            try:
                _tree_dist = tree_dist[lang_2][lang_1]
            except KeyError:
                _tree_dist = tree_dist[lang_1][lang_2]
            try:
                _map_dist = map_dist[lang_2][lang_1] / largest_value_map_dist
            except KeyError:
                _map_dist = map_dist[lang_1][lang_2] / largest_value_map_dist
            _asp_dist = 1.0 - asp_sim[lang_1][lang_list.index(lang_2)]
            metric_distance_batch.append(torch.tensor([_tree_dist, _map_dist, _asp_dist], dtype=torch.float32))
        scores = ensemble(torch.stack(metric_distance_batch).squeeze())
        print("==================================")
        print(scores.detach().squeeze()[:9])
        print(torch.stack(embedding_distance_batch).squeeze()[:9])
        loss = torch.nn.functional.mse_loss(scores.squeeze(), torch.stack(embedding_distance_batch).squeeze())
        running_loss.append(loss.item())
    print("\n\n")
    print(sum(running_loss) / len(running_loss))
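
    # finally, predict the learned distance for every pair of languages in the
    # tree-distance cache and store it (only once per pair, since it is symmetric)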
    language_to_language_to_learned_distance = dict()
    for lang_1 in tqdm(tree_dist):
        for lang_2 in tree_dist:
            try:
                if lang_2 in language_to_language_to_learned_distance:
                    if lang_1 in language_to_language_to_learned_distance[lang_2]:
                        continue  # it's symmetric
                if lang_1 not in language_to_language_to_learned_distance:
                    language_to_language_to_learned_distance[lang_1] = dict()
                try:
                    _tree_dist = tree_dist[lang_2][lang_1]
                except KeyError:
                    _tree_dist = tree_dist[lang_1][lang_2]
                try:
                    _map_dist = map_dist[lang_2][lang_1] / largest_value_map_dist
                except KeyError:
                    _map_dist = map_dist[lang_1][lang_2] / largest_value_map_dist
                _asp_dist = 1.0 - asp_sim[lang_1][lang_list.index(lang_2)]
                metric_distance = torch.tensor([_tree_dist, _map_dist, _asp_dist], dtype=torch.float32)
                with torch.inference_mode():
                    predicted_distance = ensemble(metric_distance)
                language_to_language_to_learned_distance[lang_1][lang_2] = predicted_distance.item()
            except ValueError:
                # this language is not in the ASP lookup
                continue
            except KeyError:
                # one of the distance lookups does not contain this pair
                continue

    with open(os.path.join(cache_root, 'lang_1_to_lang_2_to_learned_dist.json'), 'w', encoding='utf-8') as f:
        json.dump(language_to_language_to_learned_distance, f, ensure_ascii=False, indent=4)


if __name__ == '__main__':
    create_learned_cache("../../Models/ToucanTTS_Meta/best.pt")