# Hugging Face Space app.py (author: Flux9665, commit caaa066)
import json
import os
import pickle
import re
import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
from tqdm import tqdm
def load_json_from_path(path):
    """Read and parse the JSON file at *path* (UTF-8).

    Uses ``json.load`` on the open file handle instead of slurping the
    whole file into a string first.
    """
    with open(path, "r", encoding="utf8") as f:
        return json.load(f)
class Visualizer:
    """Precomputes normalized pairwise language distances from cached files and
    renders nearest-neighbor graphs with a networkx spring layout.

    Exposes:
        iso_codes_to_names: dict mapping ISO 639-3 codes to display names
        tree_distances / map_distances / asp_distances:
            lists of (name_1, name_2, distance) with distance min-max
            normalized to [0, 1]
        visualize(): Gradio-facing plotting entry point
    """

    def __init__(self, cache_root="."):
        """Load the distance caches from *cache_root* and normalize them.

        Expects in *cache_root*: iso_to_fullname.json,
        lang_1_to_lang_2_to_tree_dist.json, lang_1_to_lang_2_to_map_dist.json
        and asp_dict.pkl.
        """
        self.iso_codes_to_names = self._load_json(os.path.join(cache_root, "iso_to_fullname.json"))
        # Strip parenthesized qualifiers, e.g. "German (Standard)" -> "German "
        # (raw string: "\(" in a plain string is an invalid escape sequence)
        for code in self.iso_codes_to_names:
            self.iso_codes_to_names[code] = re.sub(r"\(.*?\)", "", self.iso_codes_to_names[code])

        tree_dist = self._load_json(os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json"))
        self.tree_distances = self._normalize_pairs(tree_dist)

        map_dist = self._load_json(os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json"))
        self.map_distances = self._normalize_pairs(map_dist)

        # NOTE(review): pickle is only safe because this is a trusted local cache file.
        with open(os.path.join(cache_root, "asp_dict.pkl"), "rb") as dictfile:
            asp_sim = pickle.load(dictfile)
        # asp_sim maps lang -> list of similarities indexed by the key order;
        # convert the (symmetric) similarity matrix to an upper-triangle
        # distance dict, then normalize like the other two measures.
        lang_list = list(asp_sim.keys())
        asp_dist = dict()
        seen_langs = set()
        for lang_1 in lang_list:
            if lang_1 not in seen_langs:
                seen_langs.add(lang_1)
                asp_dist[lang_1] = dict()
                for index, lang_2 in enumerate(lang_list):
                    if lang_2 not in seen_langs:  # it's symmetric, keep one triangle
                        asp_dist[lang_1][lang_2] = 1 - asp_sim[lang_1][index]
        self.asp_distances = self._normalize_pairs(asp_dist)

    @staticmethod
    def _load_json(path):
        """Read and parse a UTF-8 JSON file (kept local so the class is self-contained)."""
        with open(path, "r", encoding="utf8") as f:
            return json.load(f)

    def _normalize_pairs(self, nested_dist):
        """Flatten {lang_1: {lang_2: d}} into (name_1, name_2, d_norm) triples.

        Skips codes without a known name and self-distances; d is min-max
        scaled to [0, 1]. Returns [] for empty input; if all distances are
        equal they all map to 0.0 (the original would divide by zero).
        """
        distances = list()
        for lang_1 in nested_dist:
            if lang_1 not in self.iso_codes_to_names:
                continue
            for lang_2 in nested_dist[lang_1]:
                if lang_2 not in self.iso_codes_to_names:
                    continue
                if lang_1 != lang_2:
                    distances.append((self.iso_codes_to_names[lang_1], self.iso_codes_to_names[lang_2], nested_dist[lang_1][lang_2]))
        if not distances:
            return []
        min_dist = min(d for _, _, d in distances)
        max_dist = max(d for _, _, d in distances)
        span = (max_dist - min_dist) or 1  # guard: all-equal distances
        return [(entity1, entity2, (d - min_dist) / span) for entity1, entity2, d in distances]

    def visualize(self, distance_type, neighbor, num_neighbors):
        """Plot the *num_neighbors* nearest languages to *neighbor* under *distance_type*.

        Returns the matplotlib figure for Gradio's Plot output.
        Raises ValueError for an unknown distance type.
        """
        # Close figures from previous calls instead of leaking one per request
        # (the original plt.clf() + plt.figure() accumulated open figures).
        plt.close("all")
        plt.figure(figsize=(12, 12))
        if distance_type == "Distance to the Lowest Common Ancestor in the Language Family Tree":
            normalized_distances = self.tree_distances
        elif distance_type == "Angular Distance between the Frequencies of Phonemes":
            normalized_distances = self.asp_distances
        elif distance_type == "Physical Distance between Language Centroids on the Globe":
            normalized_distances = self.map_distances
        else:
            raise ValueError(f"unknown distance type: {distance_type}")
        G = nx.Graph()
        d_dist = [d for entity1, entity2, d in normalized_distances
                  if neighbor == entity2 or neighbor == entity1]
        # Clamp the index so an over-large slider value (up to 100) cannot
        # run past the number of available pairs (IndexError in the original).
        thresh = sorted(d_dist)[min(num_neighbors, len(d_dist) - 1)] if d_dist else 0.0
        neighbors = set()
        for entity1, entity2, d in tqdm(normalized_distances):
            if d <= thresh and (neighbor == entity2 or neighbor == entity1) and len(neighbors) < num_neighbors + 1:
                neighbors.add(entity1)
                neighbors.add(entity2)
                spring_tension = ((thresh + 0.1) - d) * 100  # for vis purposes
                G.add_edge(entity1, entity2, weight=spring_tension)
        neighbors.discard(neighbor)  # discard: no KeyError if no edges were found
        # Weak springs between the neighbors themselves keep the layout readable.
        thresh_for_neighbors = max([x for _, _, x in normalized_distances])
        for entity1, entity2, d in tqdm(normalized_distances):
            if entity2 in neighbors and entity1 in neighbors:
                spring_tension = (thresh_for_neighbors + 0.1) - d
                G.add_edge(entity1, entity2, weight=spring_tension)
        pos = nx.spring_layout(G, weight="weight", iterations=200, threshold=1e-6)  # positions for all nodes
        edges = G.edges(data=True)
        nx.draw_networkx_nodes(G, pos, node_size=1, alpha=0.01)
        edges_connected_to_specific_node = [(u, v) for u, v in G.edges() if u == neighbor or v == neighbor]
        nx.draw_networkx_edges(G, pos, edgelist=edges_connected_to_specific_node, edge_color='orange', alpha=0.4, width=3)
        if num_neighbors < 6:
            edges_not_connected_to_specific_node = [(u, v) for u, v in G.edges() if u != neighbor and v != neighbor]
            nx.draw_networkx_edges(G, pos, edgelist=edges_not_connected_to_specific_node, edge_color='gray', alpha=0.05, width=1)
        for u, v, d in edges:
            if u == neighbor or v == neighbor:
                # Reverse the spring-tension transform to label edges with the
                # original normalized distance (rounded).
                nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): round(((thresh + 0.1) - (d['weight'] / 100)) * 100, 2)}, font_color="red", alpha=0.4)
        nx.draw_networkx_labels(G, pos, font_size=14, font_family='sans-serif', font_color='green')
        nx.draw_networkx_labels(G, pos, labels={neighbor: neighbor}, font_size=14, font_family='sans-serif', font_color='red')
        plt.title(f'Graph of {distance_type}')
        plt.subplots_adjust(left=0, right=1, top=0.9, bottom=0)
        plt.tight_layout()
        return plt.gcf()
if __name__ == '__main__':
    # Build the visualizer from the cache files next to this script and
    # wire it into a Gradio interface.
    vis = Visualizer(cache_root=".")
    language_names = list(vis.iso_codes_to_names.values())
    distance_choices = ["Physical Distance between Language Centroids on the Globe",
                        "Distance to the Lowest Common Ancestor in the Language Family Tree",
                        "Angular Distance between the Frequencies of Phonemes"]
    distance_dropdown = gr.Dropdown(distance_choices,
                                    type="value",
                                    value='Physical Distance between Language Centroids on the Globe',
                                    label="Select the Type of Distance")
    language_dropdown = gr.Dropdown(language_names,
                                    type="value",
                                    value="German",
                                    label="Select the second Language (type on your keyboard to find it quickly)")
    neighbor_slider = gr.Slider(minimum=0, maximum=100, step=1,
                                value=12,
                                label="How many Nearest Neighbors should be displayed?")
    plot_output = gr.Plot(label="", show_label=False, format="png", container=True)
    iface = gr.Interface(fn=vis.visualize,
                         inputs=[distance_dropdown, language_dropdown, neighbor_slider],
                         outputs=[plot_output],
                         description="<br><br> This demo allows you to find the nearest neighbors of a language from the ISO 639-3 list according to several distance measurement functions. "
                                     "For more information, check out our paper: https://arxiv.org/abs/2406.06403 and our text-to-speech tool, in which we make use of "
                                     "this technique: https://github.com/DigitalPhonetics/IMS-Toucan <br><br>",
                         allow_flagging="never")
    iface.launch()