Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,522 Bytes
6faeba1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import argparse
import os
import matplotlib
import numpy as np
import pandas as pd
import torch
# Global font settings for every figure this script produces: STIX fonts for both
# math text and regular text, at a 7 pt base size.
matplotlib.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
    'font.size': 7,
})
import matplotlib.pyplot as plt
from Utility.utils import load_json_from_path
from Utility.storage_config import MODELS_DIR
def _infer_distance_layout(columns):
    """Return ``(distance_type, n_closest)`` for a distance-dataset header.

    Each dataset has one ``target_lang`` column plus, for every closest language ``i``,
    a ``closest_lang_{i}`` column and a ``{distance_type}_dist_{i}`` column. "combined"
    datasets may additionally store the individual ``map``/``asp``/``tree`` distances
    for every closest language.

    Args:
        columns: the dataset's column names (e.g. ``df.columns``).
    """
    names = list(columns)
    present = set(names)
    features_per_closest_lang = 2  # closest_lang_{i} + {distance_type}_dist_{i}
    if "combined_dist_0" in present:
        # every individual distance stored next to the combined one adds one column per closest lang
        features_per_closest_lang += sum(f"{kind}_dist_0" in present for kind in ("map", "asp", "tree"))
        distance_type = "combined"
    else:
        # first matching type wins; datasets without a known distance column are "random"
        distance_type = next(
            (kind for kind in ("map", "tree", "asp", "learned", "oracle") if f"{kind}_dist_0" in present),
            "random",
        )
    # floor division drops the single extra target_lang column
    n_closest = len(names) // features_per_closest_lang
    return distance_type, n_closest


def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embeddings, weighted_avg=False, min_n_langs=5, max_n_langs=30, threshold_percentile=95, loss_fn="MSE"):
    """Approximate each target language's embedding from its closest languages and score it.

    For every row of the dataset, the embeddings of the selected closest languages are
    averaged and compared against the target language's own embedding.

    Args:
        csv_path: path to a ``|``-separated distance dataset with a ``target_lang`` column
            and ``closest_lang_{i}`` / ``{distance_type}_dist_{i}`` column pairs.
        iso_lookup: sequence whose last element maps a language code to a row index of
            ``language_embeddings``.
        language_embeddings: tensor indexable by those row indices (in this script, the
            encoder's language-embedding weight matrix).
        weighted_avg: if True, weight each neighbour's embedding by its distance value;
            otherwise take a plain mean.
        min_n_langs: the first ``min_n_langs`` neighbours are always used.
        max_n_langs: at most this many neighbours are considered.
        threshold_percentile: percentile, over all rows, of the distance to the furthest
            considered neighbour. Neighbours past ``min_n_langs`` are used only when their
            distance is strictly below this value.
        loss_fn: ``"L1"`` for L1 loss; anything else selects MSE.

    Returns:
        List of per-row loss values (floats). Rows whose target language is missing from
        the lookup, or whose averaging weights sum to zero, are skipped (with a message).
    """
    df = pd.read_csv(csv_path, sep="|")
    criterion = torch.nn.L1Loss() if loss_fn == "L1" else torch.nn.MSELoss()

    distance_type, n_closest = _infer_distance_layout(df.columns)
    closest_lang_columns = [f"closest_lang_{i}" for i in range(n_closest)][:max_n_langs]
    closest_dist_columns = [f"{distance_type}_dist_{i}" for i in range(n_closest)][:max_n_langs]

    threshold = np.percentile(df[closest_dist_columns[-1]], threshold_percentile)
    print(f"threshold: {threshold}")

    lang_to_idx = iso_lookup[-1]
    all_losses = []
    for row in df.itertuples():
        try:
            y = language_embeddings[lang_to_idx[row.target_lang]]
        except KeyError:
            print(f"KeyError: Unable to retrieve language embedding for {row.target_lang}")
            continue

        # Select neighbour indices once so every language stays paired with its own distance,
        # even when the distance columns are not sorted in ascending order.
        selected = [i for i, dist_col in enumerate(closest_dist_columns)
                    if i < min_n_langs or getattr(row, dist_col) < threshold]
        langs = [getattr(row, closest_lang_columns[i]) for i in selected]
        dists = [float(getattr(row, closest_dist_columns[i])) for i in selected]

        # NOTE(review): weights are the raw distance values, so larger distances weigh more;
        # confirm this is intended for each distance type.
        weights = dists if weighted_avg else [1.0] * len(langs)
        normalization_factor = sum(weights)
        if normalization_factor == 0:
            # would otherwise divide by zero and yield a NaN/inf loss
            print(f"Skipping {row.target_lang}: no neighbours to average or all weights are zero")
            continue

        # accumulator matches the target embedding's shape, dtype and device
        avg_emb = torch.zeros_like(y)
        for lang, weight in zip(langs, weights):
            avg_emb += language_embeddings[lang_to_idx[lang]] * weight
        avg_emb /= normalization_factor  # normalize

        all_losses.append(criterion(avg_emb, y).item())
    return all_losses
if __name__ == "__main__":
    # MODELS_DIR must be absolute path, the relative path will fail at this location
    default_model_path = os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt")
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=default_model_path,
                        help="model path that should be used for creating oracle lang emb distance cache")
    parser.add_argument("--min_n_langs", type=int, default=5, help="minimum amount of languages used for averaging")
    parser.add_argument("--max_n_langs", type=int, default=30, help="maximum amount of languages used for averaging")
    parser.add_argument("--threshold_percentile", type=int, default=95,
                        help="percentile of the furthest used languages used as cutoff threshold "
                             "(no langs >= the threshold are used for averaging)")
    parser.add_argument("--loss_fn", choices=["MSE", "L1"], type=str, default="MSE", help="loss function used")
    args = parser.parse_args()

    # (dataset path, box label) pairs, kept together so labels cannot drift out of sync with the data
    datasets = [
        ("distance_datasets/dataset_map_top30_furthest.csv", "map furthest"),
        ("distance_datasets/dataset_random_top30.csv", "random"),
        ("distance_datasets/dataset_asp_top30.csv", "inv. ASP"),
        ("distance_datasets/dataset_tree_top30.csv", "tree"),
        ("distance_datasets/dataset_map_top30.csv", "map"),
        ("distance_datasets/dataset_combined_top30_indiv-dists.csv", "avg"),
        ("distance_datasets/dataset_learned_top30.csv", "meta-learned"),
        ("distance_datasets/dataset_oracle_top30.csv", "oracle"),
    ]
    weighted = [False]

    # map_location="cpu": load the checkpoint on CPU even if it was saved from GPU, so this
    # runs on CPU-only hosts and all tensors used for averaging live on the same device
    lang_embs = torch.load(args.model_path, map_location="cpu")["model"]["encoder.language_embedding.weight"]
    lang_embs.requires_grad_(False)
    iso_lookup = load_json_from_path("iso_lookup.json")

    OUT_DIR = "plots"
    os.makedirs(OUT_DIR, exist_ok=True)
    fig, ax = plt.subplots(figsize=(3.15022, 3.15022 * (2 / 3)), constrained_layout=True)
    ax.set_ylabel(args.loss_fn)

    losses_of_multiple_datasets = []
    box_labels = []
    for csv_path, label in datasets:
        print(f"csv_path: {os.path.basename(csv_path)}")
        for condition in weighted:
            losses = compute_loss_for_approximated_embeddings(csv_path,
                                                              iso_lookup,
                                                              lang_embs,
                                                              condition,
                                                              min_n_langs=args.min_n_langs,
                                                              max_n_langs=args.max_n_langs,
                                                              threshold_percentile=args.threshold_percentile,
                                                              loss_fn=args.loss_fn)
            print(f"weighted average: {condition} | mean loss: {np.mean(losses)}")
            losses_of_multiple_datasets.append(losses)
            # one box per (dataset, condition); only disambiguate when several conditions are plotted
            if len(weighted) == 1:
                box_labels.append(label)
            else:
                box_labels.append(f"{label} ({'weighted' if condition else 'unweighted'})")

    # NOTE: matplotlib >= 3.9 renames `labels` to `tick_labels`; `labels` is kept for older versions
    ax.boxplot(losses_of_multiple_datasets,
               labels=box_labels,
               patch_artist=True,
               boxprops=dict(facecolor="lightblue"),
               showfliers=False,
               widths=0.45)

    # major ticks every 0.1, minor ticks every 0.05, from 0.0 up to (excluding) 0.6
    major_ticks = np.arange(0, 0.6, 0.1)
    minor_ticks = np.arange(0, 0.6, 0.05)
    ax.set_yticks(major_ticks)
    ax.set_yticks(minor_ticks, minor=True)
    # horizontal grid lines for minor and major ticks
    ax.grid(which='both', linestyle='-', color='lightgray', linewidth=0.3, axis='y')
    ax.set_aspect(4.5)
    ax.set_title(f"min. {args.min_n_langs} kNN, max. {args.max_n_langs}\nthreshold: {args.threshold_percentile}th-percentile distance of {args.max_n_langs}th-closest language")
    ax.tick_params(axis="x", labelrotation=45)
    fig.savefig(os.path.join(OUT_DIR, "example_boxplot_release.pdf"), bbox_inches='tight')
    plt.close(fig)
|