Flux9665 committed
Commit 904dd18 • 1 Parent(s): 7231468

Update app.py

Files changed (1)
  1. app.py +58 -43
app.py CHANGED
@@ -1,52 +1,83 @@
+import json
 import os
 import pickle
 import re
-import json
 
 import gradio as gr
 import matplotlib.pyplot as plt
 import networkx as nx
 from tqdm import tqdm
 
+
 def load_json_from_path(path):
     with open(path, "r", encoding="utf8") as f:
         obj = json.loads(f.read())
 
     return obj
 
+
 class Visualizer:
 
     def __init__(self, cache_root="."):
+        self.iso_codes_to_names = load_json_from_path(os.path.join(cache_root, "iso_to_fullname.json"))
+        for code in self.iso_codes_to_names:
+            self.iso_codes_to_names[code] = re.sub("\(.*?\)", "", self.iso_codes_to_names[code])
+
         tree_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json")
-        self.tree_dist = load_json_from_path(tree_lookup_path)
+        tree_dist = load_json_from_path(tree_lookup_path)
+        distances = list()
+        for lang_1 in tree_dist:
+            if lang_1 not in self.iso_codes_to_names:
+                continue
+            for lang_2 in tree_dist[lang_1]:
+                if lang_2 not in self.iso_codes_to_names:
+                    continue
+                if lang_1 != lang_2:
+                    distances.append((self.iso_codes_to_names[lang_1], self.iso_codes_to_names[lang_2], tree_dist[lang_1][lang_2]))
+        min_dist = min(d for _, _, d in distances)
+        max_dist = max(d for _, _, d in distances)
+        self.tree_distances = [(entity1, entity2, (d - min_dist) / (max_dist - min_dist)) for entity1, entity2, d in distances]
 
         map_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json")
-        self.map_dist = load_json_from_path(map_lookup_path)
-        largest_value_map_dist = 0.0
-        for _, values in self.map_dist.items():
-            for _, value in values.items():
-                largest_value_map_dist = max(largest_value_map_dist, value)
-        for key1 in self.map_dist:
-            for key2 in self.map_dist[key1]:
-                self.map_dist[key1][key2] = self.map_dist[key1][key2] / largest_value_map_dist
+        map_dist = load_json_from_path(map_lookup_path)
+        distances = list()
+        for lang_1 in map_dist:
+            if lang_1 not in self.iso_codes_to_names:
+                continue
+            for lang_2 in map_dist[lang_1]:
+                if lang_2 not in self.iso_codes_to_names:
+                    continue
+                if lang_1 != lang_2:
+                    distances.append((self.iso_codes_to_names[lang_1], self.iso_codes_to_names[lang_2], map_dist[lang_1][lang_2]))
+        min_dist = min(d for _, _, d in distances)
+        max_dist = max(d for _, _, d in distances)
+        self.map_distances = [(entity1, entity2, (d - min_dist) / (max_dist - min_dist)) for entity1, entity2, d in distances]
 
         asp_dict_path = os.path.join(cache_root, "asp_dict.pkl")
         with open(asp_dict_path, 'rb') as dictfile:
             asp_sim = pickle.load(dictfile)
         lang_list = list(asp_sim.keys())
-        self.asp_dist = dict()
+        asp_dist = dict()
         seen_langs = set()
         for lang_1 in lang_list:
             if lang_1 not in seen_langs:
                 seen_langs.add(lang_1)
-                self.asp_dist[lang_1] = dict()
+                asp_dist[lang_1] = dict()
             for index, lang_2 in enumerate(lang_list):
                 if lang_2 not in seen_langs:  # it's symmetric
-                    self.asp_dist[lang_1][lang_2] = 1 - asp_sim[lang_1][index]
-
-        self.iso_codes_to_names = load_json_from_path(os.path.join(cache_root, "iso_to_fullname.json"))
-        for code in self.iso_codes_to_names:
-            self.iso_codes_to_names[code] = re.sub("\(.*?\)", "", self.iso_codes_to_names[code])
+                    asp_dist[lang_1][lang_2] = 1 - asp_sim[lang_1][index]
+        distances = list()
+        for lang_1 in asp_dist:
+            if lang_1 not in self.iso_codes_to_names:
+                continue
+            for lang_2 in asp_dist[lang_1]:
+                if lang_2 not in self.iso_codes_to_names:
+                    continue
+                if lang_1 != lang_2:
+                    distances.append((self.iso_codes_to_names[lang_1], self.iso_codes_to_names[lang_2], asp_dist[lang_1][lang_2]))
+        min_dist = min(d for _, _, d in distances)
+        max_dist = max(d for _, _, d in distances)
+        self.asp_distances = [(entity1, entity2, (d - min_dist) / (max_dist - min_dist)) for entity1, entity2, d in distances]
 
     def visualize(self, distance_type, neighbor, num_neighbors):
         plt.clf()
@@ -56,46 +87,30 @@ class Visualizer:
                                  "Distance to the Lowest Common Ancestor in the Language Family Tree",
                                  "Angular Distance between the Frequencies of Phonemes"]
         if distance_type == "Distance to the Lowest Common Ancestor in the Language Family Tree":
-            distance_measure = self.tree_dist
+            normalized_distances = self.tree_distances
         elif distance_type == "Angular Distance between the Frequencies of Phonemes":
-            distance_measure = self.asp_dist
+            normalized_distances = self.asp_distances
         elif distance_type == "Physical Distance between Language Centroids on the Globe":
-            distance_measure = self.map_dist
-
-        distances = list()
-
-        for lang_1 in distance_measure:
-            if lang_1 not in self.iso_codes_to_names:
-                continue
-            for lang_2 in distance_measure[lang_1]:
-                if lang_2 not in self.iso_codes_to_names:
-                    continue
-                distances.append((self.iso_codes_to_names[lang_1], self.iso_codes_to_names[lang_2], distance_measure[lang_1][lang_2]))
+            normalized_distances = self.map_distances
 
         G = nx.Graph()
-        min_dist = min(d for _, _, d in distances)
-        max_dist = max(d for _, _, d in distances)
-        normalized_distances = [(entity1, entity2, (d - min_dist) / (max_dist - min_dist)) for entity1, entity2, d in distances]
-
         d_dist = list()
         for entity1, entity2, d in tqdm(normalized_distances):
             if neighbor == entity2 or neighbor == entity1:
-                if entity1 != entity2:
-                    d_dist.append(d)
+                d_dist.append(d)
         thresh = sorted(d_dist)[num_neighbors]
         neighbors = set()
        for entity1, entity2, d in tqdm(normalized_distances):
-            if d < thresh and (neighbor == entity2 or neighbor == entity1) and (entity1 != entity2):
+            if d <= thresh and (neighbor == entity2 or neighbor == entity1) and len(neighbors) < num_neighbors + 1:
                 neighbors.add(entity1)
                 neighbors.add(entity2)
-                spring_tension = (thresh - d) * 10  # for vis purposes
+                spring_tension = (thresh - d) * 100  # for vis purposes
                 G.add_edge(entity1, entity2, weight=spring_tension)
         neighbors.remove(neighbor)
         for entity1, entity2, d in tqdm(normalized_distances):
             if entity2 in neighbors and entity1 in neighbors:
-                if entity1 != entity2:
-                    spring_tension = thresh - d
-                    G.add_edge(entity1, entity2, weight=spring_tension)
+                spring_tension = thresh - d
+                G.add_edge(entity1, entity2, weight=spring_tension)
 
         pos = nx.spring_layout(G, weight="weight")  # Positions for all nodes
         edges = G.edges(data=True)
@@ -106,7 +121,7 @@ class Visualizer:
         # nx.draw_networkx_edges(G, pos, edgelist=edges_not_connected_to_specific_node, edge_color='gray', alpha=0.1, width=1)
         for u, v, d in edges:
             if u == neighbor or v == neighbor:
-                nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): round((thresh - (d['weight'] / 10)) * 10, 2)}, font_color="red", alpha=0.4)  # reverse modifications
+                nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): round((thresh - (d['weight'] / 100)) * 10, 2)}, font_color="red", alpha=0.4)  # reverse modifications
         nx.draw_networkx_labels(G, pos, font_size=14, font_family='sans-serif', font_color='green')
         nx.draw_networkx_labels(G, pos, labels={neighbor: neighbor}, font_size=14, font_family='sans-serif', font_color='red')
         plt.title(f'Graph of {distance_type}')
@@ -116,7 +131,7 @@
 
 
 if __name__ == '__main__':
-    vis = Visualizer(cache_root=".")
+    vis = Visualizer(cache_root="Preprocessing/multilinguality")
     text_selection = [f"{vis.iso_codes_to_names[iso_code]}" for iso_code in vis.iso_codes_to_names]
     iface = gr.Interface(fn=vis.visualize,
                          inputs=[gr.Dropdown(["Physical Distance between Language Centroids on the Globe",
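
For reference, the precomputation this commit moves into Visualizer.__init__ repeats the same three steps for each of the three lookups: filter the nested ISO-code distance dict against iso_codes_to_names, skip self-pairs, and min-max scale the remaining pair distances to [0, 1]. Below is a minimal standalone sketch of that step; the helper name and toy data are hypothetical and not part of app.py:

def normalize_pairwise_distances(dist_lookup, iso_to_name):
    # Turn {lang_1: {lang_2: dist}} into (name_1, name_2, dist) tuples scaled to [0, 1].
    distances = []
    for lang_1, inner in dist_lookup.items():
        if lang_1 not in iso_to_name:
            continue
        for lang_2, d in inner.items():
            if lang_2 in iso_to_name and lang_1 != lang_2:
                distances.append((iso_to_name[lang_1], iso_to_name[lang_2], d))
    min_d = min(d for _, _, d in distances)
    max_d = max(d for _, _, d in distances)
    return [(a, b, (d - min_d) / (max_d - min_d)) for a, b, d in distances]

if __name__ == '__main__':
    lookup = {"deu": {"eng": 2.0, "fra": 4.0}, "eng": {"fra": 3.0}}
    names = {"deu": "German", "eng": "English", "fra": "French"}
    print(normalize_pairwise_distances(lookup, names))
    # [('German', 'English', 0.0), ('German', 'French', 1.0), ('English', 'French', 0.5)]

Doing this once per lookup at startup means visualize() only picks the precomputed list for the selected distance type instead of rebuilding and renormalizing it on every request.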