# File: EnglishToucan / Preprocessing/multilinguality/create_distance_lookups.py
# Repository viewer metadata: Flux9665 - "initial commit" - 6faeba1 - 6.76 kB
import argparse
import json
import os.path
import torch
from geopy.distance import geodesic
from tqdm import tqdm
from Preprocessing.multilinguality.MetricMetaLearner import create_learned_cache
from Utility.storage_config import MODELS_DIR
from Utility.utils import load_json_from_path
class CacheCreator:
    """Create the JSON lookup files that hold pairwise language distances.

    On construction the ISO codes are read from ``iso_to_fullname.json`` and
    every unordered pair of codes (each code is also paired with itself) is
    collected in ``self.pairs``. Each ``create_*_cache`` method computes one
    kind of distance and writes it to a JSON file in ``cache_root`` as a nested
    mapping ``{lang_1: {lang_2: distance}}``. Because the pairs are unordered,
    a pair appears only once, keyed by whichever code comes first in
    ``self.iso_codes``.
    """

    def __init__(self, cache_root="."):
        """Load the ISO tables from `cache_root` and enumerate all language pairs.

        Args:
            cache_root: directory that contains ``iso_to_fullname.json`` and
                ``iso_lookup.json``; also used by `create_required_files` as
                the directory in which the caches are looked up and created.
        """
        self.iso_codes = list(load_json_from_path(os.path.join(cache_root, "iso_to_fullname.json")).keys())
        self.iso_lookup = load_json_from_path(os.path.join(cache_root, "iso_lookup.json"))
        self.cache_root = cache_root
        self.pairs = list()  # ignore order, collect all language pairs
        for index_1, lang_1 in enumerate(tqdm(self.iso_codes, desc="Collecting language pairs")):
            # starting the inner range at index_1 keeps (a, a) and drops the mirrored (b, a)
            for lang_2 in self.iso_codes[index_1:]:
                self.pairs.append((lang_1, lang_2))

    @staticmethod
    def _nest_pairs(pair_dist_items):
        """Turn an iterable of ``((lang_1, lang_2), dist)`` into ``{lang_1: {lang_2: dist}}``."""
        nested = dict()
        for (lang_1, lang_2), dist in pair_dist_items:
            nested.setdefault(lang_1, dict())[lang_2] = dist
        return nested

    @staticmethod
    def _write_json(content, path):
        """Write `content` to `path` as indented UTF-8 JSON without ASCII escaping."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(content, f, ensure_ascii=False, indent=4)

    def _tree_dist(self, pair):
        """Return the tree distance of `pair` from the cached similarity and depth values."""
        similarity = self.pair_to_tree_similarity[pair]
        if similarity == 2:
            # Pairs that share exactly two membership entries are treated as maximally distant.
            # NOTE(review): presumably those two entries are generic top-level nodes - confirm.
            return 1.0
        # NOTE: raises ZeroDivisionError if both languages have an empty membership list.
        return 1 - ((similarity * 2) / self.pair_to_depth[pair])

    def create_tree_cache(self, cache_root="."):
        """Compute family-tree distances and write ``lang_1_to_lang_2_to_tree_dist.json``.

        For every pair, the similarity is the number of membership entries both
        languages share according to ``iso_to_memberships.json`` and the depth
        is the summed length of both membership lists. The distance is
        ``1 - 2 * similarity / depth``, except that a similarity of exactly 2
        yields 1.0. The intermediate values stay available as
        ``self.pair_to_tree_similarity`` and ``self.pair_to_depth``.

        Args:
            cache_root: directory to read ``iso_to_memberships.json`` from and
                to write the result into.

        Raises:
            KeyError: if a language in ``self.pairs`` has no membership entry.
        """
        iso_to_family_memberships = load_json_from_path(os.path.join(cache_root, "iso_to_memberships.json"))
        self.pair_to_tree_similarity = dict()
        self.pair_to_depth = dict()
        for pair in tqdm(self.pairs, desc="Generating tree pairs"):
            memberships_1 = iso_to_family_memberships[pair[0]]
            memberships_2 = iso_to_family_memberships[pair[1]]
            self.pair_to_tree_similarity[pair] = len(set(memberships_1).intersection(memberships_2))
            self.pair_to_depth[pair] = len(memberships_1) + len(memberships_2)
        # a generator avoids holding a second flat dict with one entry per pair
        lang_1_to_lang_2_to_tree_dist = self._nest_pairs(
            (pair, self._tree_dist(pair)) for pair in self.pair_to_tree_similarity
        )
        self._write_json(lang_1_to_lang_2_to_tree_dist, os.path.join(cache_root, 'lang_1_to_lang_2_to_tree_dist.json'))

    def create_map_cache(self, cache_root="."):
        """Compute geodesic distances in miles and write ``lang_1_to_lang_2_to_map_dist.json``.

        Coordinates are read from ``iso_to_long_lat.json``, whose values are
        unpacked as ``(longitude, latitude)``. Pairs for which either language
        has no coordinates are left out of the cache. The flat result stays
        available as ``self.pair_to_map_dist``.

        Args:
            cache_root: directory to read ``iso_to_long_lat.json`` from and to
                write the result into.
        """
        self.pair_to_map_dist = dict()
        iso_to_long_lat = load_json_from_path(os.path.join(cache_root, "iso_to_long_lat.json"))
        for pair in tqdm(self.pairs, desc="Generating map pairs"):
            try:
                long_1, lat_1 = iso_to_long_lat[pair[0]]
                long_2, lat_2 = iso_to_long_lat[pair[1]]
            except KeyError:
                # no coordinates known for at least one of the two languages
                continue
            # geodesic expects (latitude, longitude), i.e. the reverse of the stored order
            self.pair_to_map_dist[pair] = geodesic((lat_1, long_1), (lat_2, long_2)).miles
        lang_1_to_lang_2_to_map_dist = self._nest_pairs(self.pair_to_map_dist.items())
        self._write_json(lang_1_to_lang_2_to_map_dist, os.path.join(cache_root, 'lang_1_to_lang_2_to_map_dist.json'))

    def create_oracle_cache(self, model_path, cache_root="."):
        """Oracle language-embedding distance of supervised languages is only used for evaluation, not usable for zero-shot.

        The distance of a pair is the mean squared error between the two rows
        of ``encoder.language_embedding.weight`` in the checkpoint. Row indices
        are taken from the last entry of ``self.iso_lookup``; pairs containing
        a language without such an index are left out. The flat result stays
        available as ``self.pair_to_oracle_dist`` and the nested one is written
        to ``lang_1_to_lang_2_to_oracle_dist.json``.

        Note: The generated oracle cache is only valid for the given `model_path`!

        Args:
            model_path: path of the checkpoint to read the embeddings from.
            cache_root: directory to write the result into.
        """
        loss_fn = torch.nn.MSELoss(reduction="mean")
        self.pair_to_oracle_dist = dict()
        # map_location lets checkpoints that were saved from a GPU load on CPU-only machines
        lang_embs = torch.load(model_path, map_location="cpu")["model"]["encoder.language_embedding.weight"]
        lang_embs.requires_grad_(False)
        iso_to_id = self.iso_lookup[-1]
        for pair in tqdm(self.pairs, desc="Generating oracle pairs"):
            try:
                id_1 = iso_to_id[pair[0]]
                id_2 = iso_to_id[pair[1]]
            except KeyError:
                # at least one of the two languages has no embedding index
                continue
            self.pair_to_oracle_dist[pair] = loss_fn(lang_embs[id_1], lang_embs[id_2]).item()
        lang_1_to_lang_2_oracle_dist = self._nest_pairs(self.pair_to_oracle_dist.items())
        self._write_json(lang_1_to_lang_2_oracle_dist, os.path.join(cache_root, "lang_1_to_lang_2_to_oracle_dist.json"))

    def create_learned_cache(self, model_path, cache_root="."):
        """Note: The generated learned distance cache is only valid for the given `model_path`!"""
        # Inside a method this name resolves to the module-level function imported from
        # MetricMetaLearner, not to this method, so the call does not recurse.
        create_learned_cache(model_path, cache_root=cache_root)

    def create_required_files(self, model_path, create_oracle=False):
        """Create every cache file that does not exist yet in ``self.cache_root``.

        Files are looked up and generated in the same directory, so an existing
        cache is never rebuilt and a missing one ends up where it is expected.

        Args:
            model_path: checkpoint used for the learned (and oracle) distances.
            create_oracle: whether the oracle distance cache is also required.

        Raises:
            FileNotFoundError: if ``asp_dict.pkl`` is missing; it cannot be
                generated here.
            ValueError: if the oracle cache has to be created but `model_path`
                is empty.
        """
        cache_root = self.cache_root
        if not os.path.exists(os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json")):
            self.create_tree_cache(cache_root=cache_root)
        if not os.path.exists(os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json")):
            self.create_map_cache(cache_root=cache_root)
        if not os.path.exists(os.path.join(cache_root, "asp_dict.pkl")):
            raise FileNotFoundError("asp_dict.pkl must be downloaded separately.")
        if not os.path.exists(os.path.join(cache_root, "lang_1_to_lang_2_to_learned_dist.json")):
            self.create_learned_cache(model_path=model_path, cache_root=cache_root)
        if create_oracle:
            if not os.path.exists(os.path.join(cache_root, "lang_1_to_lang_2_to_oracle_dist.json")):
                if not model_path:
                    raise ValueError("model_path is required for creating oracle cache.")
                # use the parameter, not a module-level `args`, so this also works when imported
                self.create_oracle_cache(model_path=model_path, cache_root=cache_root)
        print("All required cache files exist.")
if __name__ == '__main__':
    # MODELS_DIR has to be an absolute path, a relative one does not resolve from this location
    fallback_checkpoint = os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt")

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_path",
        "-m",
        type=str,
        default=fallback_checkpoint,
        help="model path that should be used for creating oracle lang emb distance cache",
    )
    # kept as the module-level name `args`, which the rest of the file may read
    args = cli.parse_args()

    # build every missing cache, including the evaluation-only oracle distances
    CacheCreator().create_required_files(args.model_path, create_oracle=True)