Skip to content

Commit a043045

Browse files
JacekDabrowski1JacekDabrowski1
authored andcommitted
Add a Visuals tab to display UMAP plots for embedding comparisons
Adds a new 'Visuals' route and template to display UMAP plots of graph embeddings generated by various algorithms across different datasets, including a `generate_umap_plots.py` script and associated static image assets. Replit-Commit-Author: Agent Replit-Commit-Session-Id: ec794acd-c4a5-47f6-b906-d70ac3c316ee Replit-Commit-Checkpoint-Type: full_checkpoint Replit-Commit-Event-Id: 170f5b25-f98a-4f3d-8865-2a7633f52e9d Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/28ec11df-9ccf-40bc-9ff4-d0523e5b6a98/ec794acd-c4a5-47f6-b906-d70ac3c316ee/wSaHZEy Replit-Helium-Checkpoint-Created: true
1 parent befe4d6 commit a043045

27 files changed

Lines changed: 543 additions & 0 deletions

generate_umap_plots.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import numpy as np
2+
import matplotlib
3+
matplotlib.use('Agg')
4+
import matplotlib.pyplot as plt
5+
import os
6+
import sys
7+
import time
8+
9+
from pycleora import SparseMatrix, embed
10+
from pycleora.algorithms import (
11+
embed_prone, embed_randne, embed_hope, embed_netmf,
12+
embed_grarep, embed_deepwalk, embed_node2vec,
13+
)
14+
from pycleora.datasets import load_dataset
15+
from pycleora.community import detect_communities_louvain
16+
17+
DIM = 64
18+
OUTPUT_DIR = "website/static/umap"
19+
os.makedirs(OUTPUT_DIR, exist_ok=True)
20+
21+
ALGO_COLORS = {
22+
"Cleora": "#a78bfa",
23+
"ProNE": "#f59e0b",
24+
"RandNE": "#ef4444",
25+
"NetMF": "#3b82f6",
26+
"DeepWalk": "#f472b6",
27+
"HOPE": "#34d399",
28+
"GraRep": "#fb923c",
29+
"Node2Vec": "#22d3ee",
30+
}
31+
32+
CLASS_PALETTES = {
33+
3: ["#e74c3c", "#3498db", "#2ecc71"],
34+
6: ["#e74c3c", "#3498db", "#2ecc71", "#f39c12", "#9b59b6", "#1abc9c"],
35+
7: ["#e74c3c", "#3498db", "#2ecc71", "#f39c12", "#9b59b6", "#1abc9c", "#e91e63"],
36+
}
37+
38+
39+
def get_class_colors(num_classes):
40+
if num_classes in CLASS_PALETTES:
41+
return CLASS_PALETTES[num_classes]
42+
cmap = plt.cm.get_cmap("tab20", num_classes)
43+
return [matplotlib.colors.to_hex(cmap(i)) for i in range(num_classes)]
44+
45+
46+
def make_algo_fn(algo_name, graph):
47+
if algo_name == "Cleora":
48+
return embed(graph, DIM, num_iterations=40, propagation="left", normalization="l2", whiten=True, seed=42)
49+
elif algo_name == "ProNE":
50+
return embed_prone(graph, DIM)
51+
elif algo_name == "RandNE":
52+
return embed_randne(graph, DIM)
53+
elif algo_name == "HOPE":
54+
return embed_hope(graph, DIM)
55+
elif algo_name == "NetMF":
56+
return embed_netmf(graph, DIM)
57+
elif algo_name == "GraRep":
58+
return embed_grarep(graph, DIM)
59+
elif algo_name == "DeepWalk":
60+
return embed_deepwalk(graph, DIM, num_walks=10, walk_length=20)
61+
elif algo_name == "Node2Vec":
62+
return embed_node2vec(graph, DIM, num_walks=10, walk_length=20, p=1.0, q=0.5)
63+
64+
65+
def save_umap_plot(emb_2d, labels_arr, class_colors, algo_name, dataset_name, algo_color, num_classes):
66+
fig, ax = plt.subplots(figsize=(4, 4), dpi=120)
67+
fig.patch.set_facecolor('#0a0a0f')
68+
ax.set_facecolor('#0a0a0f')
69+
70+
unique_labels = np.unique(labels_arr)
71+
for label in unique_labels:
72+
mask = labels_arr == label
73+
color = class_colors[int(label) % len(class_colors)]
74+
ax.scatter(
75+
emb_2d[mask, 0], emb_2d[mask, 1],
76+
c=color, s=3, alpha=0.6, edgecolors='none', rasterized=True
77+
)
78+
79+
ax.set_xticks([])
80+
ax.set_yticks([])
81+
for spine in ax.spines.values():
82+
spine.set_visible(False)
83+
84+
ax.set_title(algo_name, color=algo_color, fontsize=14, fontweight='bold', pad=8)
85+
86+
fname = f"{dataset_name.lower()}_{algo_name.lower()}.png"
87+
fpath = os.path.join(OUTPUT_DIR, fname)
88+
fig.savefig(fpath, bbox_inches='tight', facecolor='#0a0a0f', edgecolor='none', pad_inches=0.1)
89+
plt.close(fig)
90+
print(f" Saved {fpath}")
91+
return fname
92+
93+
94+
def run_dataset(ds_key, ds_display, algo_names):
95+
import umap
96+
97+
print(f"\n{'='*60}")
98+
print(f"Dataset: {ds_display}")
99+
print(f"{'='*60}")
100+
101+
ds = load_dataset(ds_key)
102+
graph = SparseMatrix.from_iterator(iter(ds["edges"]), ds["columns"])
103+
labels = ds["labels"]
104+
num_classes = ds["num_classes"]
105+
106+
if not labels or len(labels) < 4:
107+
print(f" No labels, using Louvain communities...")
108+
labels = detect_communities_louvain(graph)
109+
num_classes = len(set(labels.values()))
110+
print(f" Found {num_classes} communities")
111+
112+
entity_ids = graph.entity_ids
113+
labels_arr = np.array([labels.get(eid, 0) for eid in entity_ids])
114+
115+
unique_labels = np.unique(labels_arr)
116+
label_remap = {old: new for new, old in enumerate(unique_labels)}
117+
labels_arr = np.array([label_remap[l] for l in labels_arr])
118+
actual_classes = len(unique_labels)
119+
class_colors = get_class_colors(actual_classes)
120+
121+
embeddings = {}
122+
for algo_name in algo_names:
123+
out_path = os.path.join(OUTPUT_DIR, f"{ds_display.lower()}_{algo_name.lower()}.png")
124+
if os.path.exists(out_path):
125+
print(f" {algo_name}: already exists, skipping")
126+
continue
127+
print(f" Running {algo_name}...", end=" ", flush=True)
128+
t0 = time.time()
129+
try:
130+
emb = make_algo_fn(algo_name, graph)
131+
elapsed = time.time() - t0
132+
print(f"done ({elapsed:.2f}s)")
133+
embeddings[algo_name] = emb
134+
except Exception as e:
135+
elapsed = time.time() - t0
136+
print(f"FAILED ({elapsed:.2f}s): {e}")
137+
138+
if embeddings:
139+
print(f"\n Running UMAP for {len(embeddings)} embeddings...")
140+
for algo_name, emb in embeddings.items():
141+
print(f" UMAP {algo_name}...", end=" ", flush=True)
142+
t0 = time.time()
143+
try:
144+
reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=15, min_dist=0.1)
145+
emb_2d = reducer.fit_transform(emb)
146+
elapsed = time.time() - t0
147+
print(f"done ({elapsed:.2f}s)")
148+
save_umap_plot(emb_2d, labels_arr, class_colors, algo_name, ds_display, ALGO_COLORS.get(algo_name, "#ffffff"), actual_classes)
149+
except Exception as e:
150+
elapsed = time.time() - t0
151+
print(f"FAILED ({elapsed:.2f}s): {e}")
152+
153+
154+
if __name__ == "__main__":
155+
dataset = sys.argv[1] if len(sys.argv) > 1 else "all"
156+
157+
configs = {
158+
"cora": ("cora", "Cora", ["Cleora", "NetMF", "ProNE", "RandNE", "HOPE", "DeepWalk"]),
159+
"citeseer": ("citeseer", "CiteSeer", ["Cleora", "NetMF", "ProNE", "RandNE", "HOPE"]),
160+
"facebook": ("facebook", "Facebook", ["Cleora", "NetMF", "ProNE", "RandNE"]),
161+
"pubmed": ("pubmed", "PubMed", ["Cleora", "RandNE", "ProNE"]),
162+
"ppi": ("ppi", "PPI", ["Cleora", "RandNE", "ProNE"]),
163+
}
164+
165+
if dataset == "all":
166+
for key in configs:
167+
run_dataset(*configs[key])
168+
elif dataset in configs:
169+
run_dataset(*configs[dataset])
170+
else:
171+
print(f"Unknown dataset: {dataset}")
172+
print(f"Available: {', '.join(configs.keys())}")

replit.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@ website/ Marketing website (Flask)
4343
app.py Flask server (port 5000) — includes sync API endpoints
4444
sync_worker.py Background benchmark sync worker (threading-based)
4545
static/style.css Dark theme CSS
46+
static/umap/ Pre-generated UMAP scatter plots (23 PNGs, 5 datasets x algos)
4647
templates/
4748
base.html Shared layout (nav, footer)
4849
index.html Landing page (features, comparison, code examples)
50+
visuals.html UMAP embedding projection gallery (dataset filter tabs)
4951
docs.html Documentation (installation, all APIs, guides)
5052
api.html API Reference (all functions, params, returns)
5153
benchmarks.html Benchmark results with interactive Chart.js visualizations

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
flask>=2.0
2+
numpy
3+
umap-learn

website/app.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
def index():
1414
return render_template('index.html')
1515

16+
@app.route('/visuals')
17+
def visuals():
18+
return render_template('visuals.html')
19+
1620
@app.route('/docs')
1721
def docs():
1822
return render_template('docs.html')
95.1 KB
Loading
31.1 KB
Loading
55.6 KB
Loading
118 KB
Loading
120 KB
Loading
90.6 KB
Loading

0 commit comments

Comments
 (0)