# Gradio demo: nearest-neighbor visualization of language distances (runs as a HuggingFace Space).
import json
import os
import pickle
import re

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
from tqdm import tqdm
def load_json_from_path(path):
    """Read the UTF-8 encoded JSON file at *path* and return the parsed object."""
    with open(path, "r", encoding="utf8") as f:
        # json.load streams from the file object directly; no need for loads(f.read()).
        return json.load(f)
class Visualizer:
    """Precomputes pairwise language distances under three measures (family tree,
    geographic centroids, phoneme frequencies) and renders the nearest neighbors
    of a chosen language as a networkx spring-layout graph for the Gradio demo."""

    def __init__(self, cache_root="."):
        """Load lookup files from *cache_root* and precompute normalized
        (name_1, name_2, distance) triples for every distance type.

        Expects in cache_root: iso_to_fullname.json,
        lang_1_to_lang_2_to_tree_dist.json, lang_1_to_lang_2_to_map_dist.json,
        asp_dict.pkl.
        """
        self.iso_codes_to_names = load_json_from_path(os.path.join(cache_root, "iso_to_fullname.json"))
        for code in self.iso_codes_to_names:
            # Drop parenthesised qualifiers from display names, e.g. "Aymara (Central)".
            self.iso_codes_to_names[code] = re.sub(r"\(.*?\)", "", self.iso_codes_to_names[code])

        tree_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json")
        self.tree_distances = self._to_named_normalized_triples(load_json_from_path(tree_lookup_path))

        map_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json")
        self.map_distances = self._to_named_normalized_triples(load_json_from_path(map_lookup_path))

        asp_dict_path = os.path.join(cache_root, "asp_dict.pkl")
        # NOTE(review): pickle.load executes arbitrary code — only load trusted cache files.
        with open(asp_dict_path, 'rb') as dictfile:
            asp_sim = pickle.load(dictfile)
        self.asp_distances = self._to_named_normalized_triples(self._similarity_to_distance(asp_sim))

    @staticmethod
    def _similarity_to_distance(asp_sim):
        """Convert {lang: similarity_vector} (vectors ordered like the key list)
        into a nested {lang_1: {lang_2: 1 - similarity}} dict, filling only one
        triangle because the measure is symmetric."""
        lang_list = list(asp_sim.keys())
        asp_dist = dict()
        seen_langs = set()
        for lang_1 in lang_list:
            if lang_1 not in seen_langs:
                seen_langs.add(lang_1)
                asp_dist[lang_1] = dict()
                for index, lang_2 in enumerate(lang_list):
                    if lang_2 not in seen_langs:  # it's symmetric
                        asp_dist[lang_1][lang_2] = 1 - asp_sim[lang_1][index]
        return asp_dist

    def _to_named_normalized_triples(self, dist_lookup):
        """Flatten {lang_1: {lang_2: d}} into (full_name_1, full_name_2, d) triples,
        keeping only languages with a known full name, then min-max normalize the
        distances to [0, 1]. Shared by all three distance types."""
        distances = list()
        for lang_1 in dist_lookup:
            if lang_1 not in self.iso_codes_to_names:
                continue
            for lang_2 in dist_lookup[lang_1]:
                if lang_2 not in self.iso_codes_to_names:
                    continue
                if lang_1 != lang_2:
                    distances.append((self.iso_codes_to_names[lang_1], self.iso_codes_to_names[lang_2], dist_lookup[lang_1][lang_2]))
        min_dist = min(d for _, _, d in distances)
        max_dist = max(d for _, _, d in distances)
        return [(entity1, entity2, (d - min_dist) / (max_dist - min_dist)) for entity1, entity2, d in distances]

    def visualize(self, distance_type, neighbor, num_neighbors):
        """Draw the *num_neighbors* nearest neighbors of language *neighbor* under
        the selected *distance_type* and return the matplotlib figure for gr.Plot."""
        plt.clf()
        plt.figure(figsize=(12, 12))
        assert distance_type in ["Physical Distance between Language Centroids on the Globe",
                                 "Distance to the Lowest Common Ancestor in the Language Family Tree",
                                 "Angular Distance between the Frequencies of Phonemes"]
        if distance_type == "Distance to the Lowest Common Ancestor in the Language Family Tree":
            normalized_distances = self.tree_distances
        elif distance_type == "Angular Distance between the Frequencies of Phonemes":
            normalized_distances = self.asp_distances
        elif distance_type == "Physical Distance between Language Centroids on the Globe":
            normalized_distances = self.map_distances
        G = nx.Graph()
        # Collect all distances that involve the selected language.
        d_dist = list()
        for entity1, entity2, d in tqdm(normalized_distances):
            if neighbor == entity2 or neighbor == entity1:
                d_dist.append(d)
        # Threshold that admits the num_neighbors closest candidates; clamp the
        # index so requesting more neighbors than exist cannot raise IndexError.
        thresh = sorted(d_dist)[min(num_neighbors, len(d_dist) - 1)]
        neighbors = set()
        for entity1, entity2, d in tqdm(normalized_distances):
            if d <= thresh and (neighbor == entity2 or neighbor == entity1) and len(neighbors) < num_neighbors + 1:
                neighbors.add(entity1)
                neighbors.add(entity2)
                spring_tension = ((thresh + 0.1) - d) * 100  # for vis purposes
                G.add_edge(entity1, entity2, weight=spring_tension)
        neighbors.remove(neighbor)
        # Also connect the neighbors among each other so the layout reflects
        # their mutual distances, not just their distance to the center node.
        thresh_for_neighbors = max([x for _, _, x in normalized_distances])
        for entity1, entity2, d in tqdm(normalized_distances):
            if entity2 in neighbors and entity1 in neighbors:
                spring_tension = (thresh_for_neighbors + 0.1) - d
                G.add_edge(entity1, entity2, weight=spring_tension)
        pos = nx.spring_layout(G, weight="weight", iterations=200, threshold=1e-6)  # Positions for all nodes
        edges = G.edges(data=True)
        nx.draw_networkx_nodes(G, pos, node_size=1, alpha=0.01)
        # Highlight edges touching the selected language in orange.
        edges_connected_to_specific_node = [(u, v) for u, v in G.edges() if u == neighbor or v == neighbor]
        nx.draw_networkx_edges(G, pos, edgelist=edges_connected_to_specific_node, edge_color='orange', alpha=0.4, width=3)
        if num_neighbors < 6:
            # Only with few neighbors do the secondary edges stay legible.
            edges_not_connected_to_specific_node = [(u, v) for u, v in G.edges() if u != neighbor and v != neighbor]
            nx.draw_networkx_edges(G, pos, edgelist=edges_not_connected_to_specific_node, edge_color='gray', alpha=0.05, width=1)
        for u, v, d in edges:
            if u == neighbor or v == neighbor:
                # Reverse the spring-tension transform to label edges with the
                # original normalized distance (scaled by 100, rounded).
                nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): round(((thresh + 0.1) - (d['weight'] / 100)) * 100, 2)}, font_color="red", alpha=0.4)
        nx.draw_networkx_labels(G, pos, font_size=14, font_family='sans-serif', font_color='green')
        # Re-draw the selected language's label in red on top of the green labels.
        nx.draw_networkx_labels(G, pos, labels={neighbor: neighbor}, font_size=14, font_family='sans-serif', font_color='red')
        plt.title(f'Graph of {distance_type}')
        plt.subplots_adjust(left=0, right=1, top=0.9, bottom=0)
        plt.tight_layout()
        return plt.gcf()
if __name__ == '__main__':
    vis = Visualizer(cache_root=".")

    # Every cleaned full language name becomes a selectable dropdown entry.
    language_choices = [vis.iso_codes_to_names[iso_code] for iso_code in vis.iso_codes_to_names]
    distance_choices = ["Physical Distance between Language Centroids on the Globe",
                        "Distance to the Lowest Common Ancestor in the Language Family Tree",
                        "Angular Distance between the Frequencies of Phonemes"]

    distance_dropdown = gr.Dropdown(distance_choices,
                                    type="value",
                                    value='Physical Distance between Language Centroids on the Globe',
                                    label="Select the Type of Distance")
    language_dropdown = gr.Dropdown(language_choices,
                                    type="value",
                                    value="German",
                                    label="Select the second Language (type on your keyboard to find it quickly)")
    neighbor_slider = gr.Slider(minimum=0, maximum=100, step=1,
                                value=12,
                                label="How many Nearest Neighbors should be displayed?")
    description_text = ("<br><br> This demo allows you to find the nearest neighbors of a language from the ISO 639-3 list according to several distance measurement functions. "
                        "For more information, check out our paper: https://arxiv.org/abs/2406.06403 and our text-to-speech tool, in which we make use of "
                        "this technique: https://github.com/DigitalPhonetics/IMS-Toucan <br><br>")

    # Wire the widgets to Visualizer.visualize and serve the demo.
    iface = gr.Interface(fn=vis.visualize,
                         inputs=[distance_dropdown, language_dropdown, neighbor_slider],
                         outputs=[gr.Plot(label="", show_label=False, format="png", container=True)],
                         description=description_text,
                         allow_flagging="never")
    iface.launch()