broadfield-dev committed on
Commit
fd3f0f2
·
verified ·
1 Parent(s): a9fb05a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +232 -64
app.py CHANGED
@@ -6,6 +6,10 @@ import os
6
  import pathlib
7
  import uuid
8
  import logging
 
 
 
 
9
 
10
  # --- Setup Logging ---
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -16,52 +20,60 @@ from embedding_atlas.server import make_server
16
  from embedding_atlas.projection import compute_text_projection
17
  from embedding_atlas.utils import Hasher
18
 
19
- # --- Helper function from embedding_atlas/cli.py ---
20
  def find_column_name(existing_names, candidate):
21
  if candidate not in existing_names:
22
  return candidate
23
- else:
24
- index = 1
25
- while True:
26
- s = f"{candidate}_{index}"
27
- if s not in existing_names:
28
- return s
29
- index += 1
30
-
31
- # --- Hugging Face API Helpers for Dynamic UI ---
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  hf_api = HfApi()
33
 
34
  def get_user_datasets(username: str):
35
  logging.info(f"Fetching datasets for user: {username}")
36
- if not username:
37
- return gr.update(choices=[], value=None, interactive=False)
38
  try:
39
  datasets = hf_api.list_datasets(author=username, full=True)
40
  dataset_ids = [d.id for d in datasets if not d.private]
41
- logging.info(f"Found {len(dataset_ids)} datasets for {username}.")
42
  return gr.update(choices=sorted(dataset_ids), value=None, interactive=True)
43
  except Exception as e:
44
- logging.error(f"Failed to fetch datasets for {username}: {e}")
45
- gr.Warning(f"Could not fetch datasets for user '{username}'.")
46
  return gr.update(choices=[], value=None, interactive=False)
47
 
48
  def get_dataset_splits(dataset_id: str):
49
- logging.info(f"Fetching splits for dataset: {dataset_id}")
50
- if not dataset_id:
51
- return gr.update(choices=[], value=None, interactive=False)
52
  try:
53
  splits = get_dataset_split_names(dataset_id)
54
- logging.info(f"Found splits for {dataset_id}: {splits}")
55
  return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True)
56
  except Exception as e:
57
- logging.error(f"Failed to fetch splits for {dataset_id}: {e}")
58
- gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'.")
59
  return gr.update(choices=[], value=None, interactive=False)
60
 
61
  def get_split_columns(dataset_id: str, split: str):
62
- logging.info(f"Fetching columns for: {dataset_id}, split: {split}")
63
- if not dataset_id or not split:
64
- return gr.update(choices=[], value=None, interactive=False)
65
  try:
66
  dataset_sample = load_dataset(dataset_id, split=split, streaming=True)
67
  first_row = next(iter(dataset_sample))
@@ -69,11 +81,9 @@ def get_split_columns(dataset_id: str, split: str):
69
  logging.info(f"Found columns: {columns}")
70
  preferred_cols = ['text', 'content', 'instruction', 'question', 'document', 'prompt']
71
  best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
72
- logging.info(f"Best default column chosen: {best_col}")
73
  return gr.update(choices=columns, value=best_col, interactive=True)
74
  except Exception as e:
75
- logging.error(f"Failed to get columns for {dataset_id}/{split}: {e}", exc_info=True)
76
- gr.Warning(f"Could not fetch columns for split '{split}'. Error: {e}")
77
  return gr.update(choices=[], value=None, interactive=False)
78
 
79
  # --- Main Atlas Generation Logic ---
@@ -85,72 +95,233 @@ def generate_atlas(
85
  model_name: str,
86
  umap_neighbors: int,
87
  umap_min_dist: float,
88
- request: gr.Request,
89
  progress=gr.Progress(track_tqdm=True)
90
  ):
91
- """
92
- Loads data, computes embeddings, and serves the Embedding Atlas UI.
93
- """
94
  if not all([dataset_name, split, text_column]):
95
  raise gr.Error("Please ensure a Dataset, Split, and Text Column are selected.")
96
 
97
- progress(0, desc=f"Loading dataset '{dataset_name}' [{split}]...")
98
- try:
99
- dataset = load_dataset(dataset_name, split=split)
100
- df = dataset.to_pandas()
101
- except Exception as e:
102
- raise gr.Error(f"Failed to load data. Error: {e}")
103
-
104
  if sample_size > 0 and sample_size < len(df):
105
- progress(0.1, desc=f"Sampling {sample_size} rows...")
106
  df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
107
 
108
- if text_column not in df.columns:
109
- raise gr.Error(f"Column '{text_column}' not found. Please select a valid column.")
110
-
111
  progress(0.2, desc="Computing embeddings and UMAP...")
112
  x_col = find_column_name(df.columns, "projection_x")
113
  y_col = find_column_name(df.columns, "projection_y")
114
  neighbors_col = find_column_name(df.columns, "__neighbors")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  try:
117
- compute_text_projection(
118
- df, text_column, x=x_col, y=y_col, neighbors=neighbors_col, model=model_name,
119
- umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42},
120
- )
 
 
 
 
 
 
 
 
 
 
 
121
  except Exception as e:
122
- raise gr.Error(f"Failed to compute embeddings. Check model name or sample size. Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  progress(0.8, desc="Preparing Atlas data source...")
125
  id_col = find_column_name(df.columns, "_row_index")
126
  df[id_col] = range(df.shape[0])
127
-
128
  metadata = {"columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col}}
129
  hasher = Hasher()
130
- hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}")
131
  identifier = hasher.hexdigest()
132
  atlas_dataset = DataSource(identifier, df, metadata)
133
 
134
- progress(0.9, desc="Mounting visualization UI...")
135
  static_path = str((pathlib.Path(__import__('embedding_atlas').__file__).parent / "static").resolve())
136
- mount_path = f"/{uuid.uuid4().hex}"
137
  atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")
138
 
139
- logging.info(f"Mounting FastAPI app at path: {mount_path}")
140
- request.app.mount(mount_path, atlas_app)
141
-
142
- progress(1.0, desc="Done!")
143
-
144
- # --- THE FINAL FIX: Add a trailing slash to the iframe src ---
145
- iframe_html = f"<iframe src='{mount_path}/' width='100%' height='800px' frameborder='0'></iframe>"
146
 
 
 
 
 
147
  return gr.HTML(iframe_html)
148
 
149
  # --- Gradio UI Definition ---
150
  with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
 
151
  gr.Markdown("# Embedding Atlas Explorer")
152
- gr.Markdown("Interactively select and visualize any text-based dataset from the Hugging Face Hub.")
153
-
154
  with gr.Row():
155
  with gr.Column(scale=1):
156
  gr.Markdown("### 1. Select Data")
@@ -173,18 +344,15 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
173
  gr.Markdown("### 3. Explore Atlas")
174
  output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
175
 
176
- # --- Chained Event Listeners for Dynamic UI ---
177
  hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
178
  dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
179
  split_input.change(fn=get_split_columns, inputs=[dataset_input, split_input], outputs=text_column_input)
180
-
181
- # --- Button Click Event ---
182
  generate_button.click(
183
  fn=generate_atlas,
184
  inputs=[dataset_input, split_input, text_column_input, sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input],
185
  outputs=[output_html],
186
  )
187
-
188
  app.load(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
189
 
190
  if __name__ == "__main__":
 
6
  import pathlib
7
  import uuid
8
  import logging
9
+ import threading
10
+ import time
11
+ import socket
12
+ import uvicorn
13
 
14
  # --- Setup Logging ---
15
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
20
  from embedding_atlas.projection import compute_text_projection
21
  from embedding_atlas.utils import Hasher
22
 
23
+ # --- Helper functions ---
24
  def find_column_name(existing_names, candidate):
25
  if candidate not in existing_names:
26
  return candidate
27
+ index = 1
28
+ while True:
29
+ s = f"{candidate}_{index}"
30
+ if s not in existing_names:
31
+ return s
32
+ index += 1
33
+
34
def find_available_port(start_port: int, max_attempts: int = 100):
    """Return the first TCP port at or after ``start_port`` that can be bound on 127.0.0.1.

    At most ``max_attempts`` consecutive ports are tried. Availability is
    checked by binding a throwaway socket rather than by attempting a
    connection: a refused connection only shows that nothing is *listening*
    on the port, not that the port is free to bind.

    The probe socket is closed before returning, so another process can still
    grab the port before the caller binds it; callers should handle a failed
    bind.

    Args:
        start_port: first port number to try.
        max_attempts: how many consecutive ports to try.

    Returns:
        The first bindable port number found.

    Raises:
        RuntimeError: if none of the scanned ports can be bound.
    """
    # Stay inside the valid port range: port 0 means "pick any" to the OS and
    # values above 65535 make bind() raise OverflowError instead of OSError.
    first_port = max(start_port, 1)
    stop_port = min(start_port + max_attempts, 65536)
    for port in range(first_port, stop_port):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            try:
                probe.bind(("127.0.0.1", port))
            except OSError:
                continue  # in use or not permitted; try the next port
        logging.info(f"Found available port: {port}")
        return port
    raise RuntimeError("Could not find an available port.")
42
+
43
def run_atlas_server(app, port):
    """Serve ``app`` with Uvicorn on 127.0.0.1:``port``.

    Used as the target of a background thread; the call does not return
    until ``uvicorn.run`` does.
    """
    host = "127.0.0.1"
    logging.info(f"Starting Atlas server on http://{host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="warning")
47
+
48
+ # --- Hugging Face API Helpers ---
49
  hf_api = HfApi()
50
 
51
  def get_user_datasets(username: str):
52
  logging.info(f"Fetching datasets for user: {username}")
53
+ if not username: return gr.update(choices=[], value=None, interactive=False)
 
54
  try:
55
  datasets = hf_api.list_datasets(author=username, full=True)
56
  dataset_ids = [d.id for d in datasets if not d.private]
57
+ logging.info(f"Found {len(dataset_ids)} datasets.")
58
  return gr.update(choices=sorted(dataset_ids), value=None, interactive=True)
59
  except Exception as e:
60
+ logging.error(f"Failed to fetch datasets: {e}")
 
61
  return gr.update(choices=[], value=None, interactive=False)
62
 
63
  def get_dataset_splits(dataset_id: str):
64
+ logging.info(f"Fetching splits for: {dataset_id}")
65
+ if not dataset_id: return gr.update(choices=[], value=None, interactive=False)
 
66
  try:
67
  splits = get_dataset_split_names(dataset_id)
68
+ logging.info(f"Found splits: {splits}")
69
  return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True)
70
  except Exception as e:
71
+ logging.error(f"Failed to fetch splits: {e}")
 
72
  return gr.update(choices=[], value=None, interactive=False)
73
 
74
  def get_split_columns(dataset_id: str, split: str):
75
+ logging.info(f"Fetching columns for: {dataset_id}/{split}")
76
+ if not dataset_id or not split: return gr.update(choices=[], value=None, interactive=False)
 
77
  try:
78
  dataset_sample = load_dataset(dataset_id, split=split, streaming=True)
79
  first_row = next(iter(dataset_sample))
 
81
  logging.info(f"Found columns: {columns}")
82
  preferred_cols = ['text', 'content', 'instruction', 'question', 'document', 'prompt']
83
  best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
 
84
  return gr.update(choices=columns, value=best_col, interactive=True)
85
  except Exception as e:
86
+ logging.error(f"Failed to get columns: {e}", exc_info=True)
 
87
  return gr.update(choices=[], value=None, interactive=False)
88
 
89
  # --- Main Atlas Generation Logic ---
 
95
  model_name: str,
96
  umap_neighbors: int,
97
  umap_min_dist: float,
 
98
  progress=gr.Progress(track_tqdm=True)
99
  ):
 
 
 
100
  if not all([dataset_name, split, text_column]):
101
  raise gr.Error("Please ensure a Dataset, Split, and Text Column are selected.")
102
 
103
+ progress(0, desc="Loading dataset...")
104
+ df = load_dataset(dataset_name, split=split).to_pandas()
 
 
 
 
 
105
  if sample_size > 0 and sample_size < len(df):
 
106
  df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
107
 
 
 
 
108
  progress(0.2, desc="Computing embeddings and UMAP...")
109
  x_col = find_column_name(df.columns, "projection_x")
110
  y_col = find_column_name(df.columns, "projection_y")
111
  neighbors_col = find_column_name(df.columns, "__neighbors")
112
+ compute_text_projection(
113
+ df, text_column, x=x_col, y=y_col, neighbors=neighbors_col, model=model_name,
114
+ umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42},
115
+ )
116
+
117
+ progress(0.8, desc="Preparing Atlas data source...")
118
+ id_col = find_column_name(df.columns, "_row_index")
119
+ df[id_col] = range(df.shape[0])
120
+ metadata = {"columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col}}
121
+ hasher = Hasher()
122
+ hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}-{uuid.uuid4()}")
123
+ identifier = hasher.hexdigest()
124
+ atlas_dataset = DataSource(identifier, df, metadata)
125
+
126
+ progress(0.9, desc="Starting Atlas server...")
127
+ static_path = str((pathlib.Path(__import__('embedding_atlas').__file__).parent / "static").resolve())
128
+ atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")
129
+
130
+ # Find an open port and run the server in a background thread
131
+ port = find_available_port(start_port=8001)
132
+ thread = threading.Thread(target=run_atlas_server, args=(atlas_app, port), daemon=True)
133
+ thread.start()
134
 
135
+ # Give the server a moment to start up
136
+ time.sleep(2)
137
+
138
+ iframe_html = f"<iframe src='http://127.0.0.1:{port}' width='100%' height='800px' frameborder='0'></iframe>"
139
+ return gr.HTML(iframe_html)
140
+
141
+ # --- Gradio UI Definition ---
142
+ with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
143
+ # Page title
144
+ gr.Markdown("# Embedding Atlas Explorer")
145
+ # One row holding the controls column (scale=1) and the Atlas output column (scale=3)
146
+ with gr.Row():
147
+ with gr.Column(scale=1):
148
+ gr.Markdown("### 1. Select Data")
149
+ hf_user_input = gr.Textbox(label="Hugging Face User or Org Name", value="Trendyol", placeholder="e.g., 'gradio' or 'google'")
150
+ dataset_input = gr.Dropdown(label="Select a Dataset", interactive=False)
151
+ split_input = gr.Dropdown(label="Select a Split", interactive=False)
152
+ text_column_input = gr.Dropdown(label="Select a Text Column", interactive=False)
153
+
154
+ gr.Markdown("### 2. Configure Visualization")
155
+ sample_size_input = gr.Slider(label="Number of Samples", minimum=0, maximum=10000, value=2000, step=100)
156
+
157
+ with gr.Accordion("Advanced Settings", open=False):
158
+ model_input = gr.Dropdown(label="Embedding Model", choices=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"], value="all-MiniLM-L6-v2")
159
+ umap_neighbors_input = gr.Slider(label="UMAP Neighbors", minimum=2, maximum=100, value=15, step=1, info="Controls local vs. global structure.")
160
+ umap_min_dist_input = gr.Slider(label="UMAP Min Distance", minimum=0.0, maximum=0.99, value=0.1, step=0.01, info="Controls how tightly points are packed.")
161
+
162
+ generate_button = gr.Button("Generate Atlas", variant="primary")
163
+
164
+ with gr.Column(scale=3):
165
+ gr.Markdown("### 3. Explore Atlas")
166
+ output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
167
+
168
+ # --- Event Listeners ---
169
+ hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
170
+ dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
171
+ split_input.change(fn=get_split_columns, inputs=[dataset_input, split_input], outputs=text_column_input)
172
+ generate_button.click(
173
+ fn=generate_atlas,
174
+ inputs=[dataset_input, split_input, text_column_input, sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input],
175
+ outputs=[output_html],
176
+ )
177
+ app.load(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
178
+
179
if __name__ == "__main__":
    # Launch the Gradio app only when this file is run as a script.
    app.launch(debug=True)

# NOTE(review): the launch call and the import below were fused on one line
# ("app.launch(debug=True)import gradio as gr"), which is a SyntaxError.
# Everything from this import onward appears to repeat the module above, as if
# the file had been pasted twice — confirm and remove the duplicate copy.
import gradio as gr
181
+ import pandas as pd
182
+ from datasets import load_dataset, get_dataset_split_names
183
+ from huggingface_hub import HfApi
184
+ import os
185
+ import pathlib
186
+ import uuid
187
+ import logging
188
+ import threading
189
+ import time
190
+ import socket
191
+ import uvicorn
192
+
193
+ # --- Setup Logging ---
194
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
195
+
196
+ # --- Embedding Atlas Imports ---
197
+ from embedding_atlas.data_source import DataSource
198
+ from embedding_atlas.server import make_server
199
+ from embedding_atlas.projection import compute_text_projection
200
+ from embedding_atlas.utils import Hasher
201
+
202
+ # --- Helper functions ---
203
def find_column_name(existing_names, candidate):
    """Return ``candidate`` if unused, else the first free ``candidate_<n>``.

    ``n`` counts up from 1; membership is tested with ``in`` against
    ``existing_names``.
    """
    name = candidate
    suffix = 0
    while name in existing_names:
        suffix += 1
        name = f"{candidate}_{suffix}"
    return name
212
+
213
def find_available_port(start_port: int, max_attempts: int = 100):
    """Return the first TCP port at or after ``start_port`` that can be bound on 127.0.0.1.

    At most ``max_attempts`` consecutive ports are tried. Availability is
    checked by binding a throwaway socket rather than by attempting a
    connection: a refused connection only shows that nothing is *listening*
    on the port, not that the port is free to bind.

    The probe socket is closed before returning, so another process can still
    grab the port before the caller binds it; callers should handle a failed
    bind.

    Args:
        start_port: first port number to try.
        max_attempts: how many consecutive ports to try.

    Returns:
        The first bindable port number found.

    Raises:
        RuntimeError: if none of the scanned ports can be bound.
    """
    # Stay inside the valid port range: port 0 means "pick any" to the OS and
    # values above 65535 make bind() raise OverflowError instead of OSError.
    first_port = max(start_port, 1)
    stop_port = min(start_port + max_attempts, 65536)
    for port in range(first_port, stop_port):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            try:
                probe.bind(("127.0.0.1", port))
            except OSError:
                continue  # in use or not permitted; try the next port
        logging.info(f"Found available port: {port}")
        return port
    raise RuntimeError("Could not find an available port.")
221
+
222
def run_atlas_server(app, port):
    """Serve ``app`` with Uvicorn on 127.0.0.1:``port``.

    Used as the target of a background thread; the call does not return
    until ``uvicorn.run`` does.
    """
    host = "127.0.0.1"
    logging.info(f"Starting Atlas server on http://{host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="warning")
226
+
227
+ # --- Hugging Face API Helpers ---
228
+ hf_api = HfApi()
229
+
230
def get_user_datasets(username: str):
    """Build a dropdown update listing a Hub user's (or org's) public datasets.

    Returns an enabled ``gr.update`` with the sorted dataset ids. When
    ``username`` is empty or the lookup raises, returns a disabled, empty
    update instead; lookup failures are logged, not raised.
    """
    logging.info(f"Fetching datasets for user: {username}")
    if not username:
        return gr.update(choices=[], value=None, interactive=False)
    try:
        public_ids = []
        for entry in hf_api.list_datasets(author=username, full=True):
            if not entry.private:
                public_ids.append(entry.id)
        logging.info(f"Found {len(public_ids)} datasets.")
        return gr.update(choices=sorted(public_ids), value=None, interactive=True)
    except Exception as err:
        logging.error(f"Failed to fetch datasets: {err}")
        return gr.update(choices=[], value=None, interactive=False)
241
+
242
def get_dataset_splits(dataset_id: str):
    """Build a dropdown update listing the splits of ``dataset_id``.

    The first split (if any) is preselected. When ``dataset_id`` is empty or
    the lookup raises, a disabled, empty update is returned; lookup failures
    are logged, not raised.
    """
    logging.info(f"Fetching splits for: {dataset_id}")
    if not dataset_id:
        return gr.update(choices=[], value=None, interactive=False)
    try:
        names = get_dataset_split_names(dataset_id)
        logging.info(f"Found splits: {names}")
        if names:
            initial = names[0]
        else:
            initial = None
        return gr.update(choices=names, value=initial, interactive=True)
    except Exception as err:
        logging.error(f"Failed to fetch splits: {err}")
        return gr.update(choices=[], value=None, interactive=False)
252
+
253
def get_split_columns(dataset_id: str, split: str):
    """Build a dropdown update listing the columns of one dataset split.

    The split is opened in streaming mode and only its first row is read; the
    row's keys become the choices. The preselected column is the first match
    from a fixed list of common text-column names, else the first column.
    When either argument is empty or anything raises, a disabled, empty
    update is returned; failures are logged with a traceback, not raised.
    """
    logging.info(f"Fetching columns for: {dataset_id}/{split}")
    if not (dataset_id and split):
        return gr.update(choices=[], value=None, interactive=False)
    try:
        stream = load_dataset(dataset_id, split=split, streaming=True)
        sample_row = next(iter(stream))
        columns = list(sample_row.keys())
        logging.info(f"Found columns: {columns}")
        # Fall back to the first column unless a well-known text column exists.
        default_col = columns[0] if columns else None
        for wanted in ['text', 'content', 'instruction', 'question', 'document', 'prompt']:
            if wanted in columns:
                default_col = wanted
                break
        return gr.update(choices=columns, value=default_col, interactive=True)
    except Exception as err:
        logging.error(f"Failed to get columns: {err}", exc_info=True)
        return gr.update(choices=[], value=None, interactive=False)
267
+
268
+ # --- Main Atlas Generation Logic ---
269
def generate_atlas(
    dataset_name: str,
    split: str,
    text_column: str,
    sample_size: int,
    model_name: str,
    umap_neighbors: int,
    umap_min_dist: float,
    progress=gr.Progress(track_tqdm=True)
):
    """Build an Embedding Atlas for one dataset split and return it in an iframe.

    Steps: load the split into a DataFrame, optionally down-sample it, compute
    the text projection, wrap the result in a ``DataSource``, start a
    dedicated Uvicorn server for it on a free local port in a daemon thread,
    and return ``gr.HTML`` whose iframe points at that server.

    Args:
        dataset_name: dataset id passed to ``load_dataset``.
        split: split name passed to ``load_dataset``.
        text_column: column of the DataFrame to embed.
        sample_size: if positive and smaller than the split, the number of
            rows to sample (fixed seed 42); otherwise all rows are used.
        model_name: forwarded to ``compute_text_projection`` as ``model``.
        umap_neighbors: UMAP ``n_neighbors``.
        umap_min_dist: UMAP ``min_dist``.
        progress: Gradio progress tracker.

    Returns:
        ``gr.HTML`` containing the iframe markup.

    Raises:
        gr.Error: if a selection is missing, the data cannot be loaded, the
            column does not exist, the projection fails, or the Atlas server
            does not come up.
    """
    if not all([dataset_name, split, text_column]):
        raise gr.Error("Please ensure a Dataset, Split, and Text Column are selected.")

    progress(0, desc="Loading dataset...")
    try:
        df = load_dataset(dataset_name, split=split).to_pandas()
    except Exception as e:
        # Show a readable message in the UI instead of a raw traceback.
        raise gr.Error(f"Failed to load data. Error: {e}") from e

    if text_column not in df.columns:
        raise gr.Error(f"Column '{text_column}' not found. Please select a valid column.")

    if 0 < sample_size < len(df):
        df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)

    progress(0.2, desc="Computing embeddings and UMAP...")
    x_col = find_column_name(df.columns, "projection_x")
    y_col = find_column_name(df.columns, "projection_y")
    neighbors_col = find_column_name(df.columns, "__neighbors")
    try:
        # The return value is not used; the results are expected in the
        # x / y / neighbors columns of ``df``.
        compute_text_projection(
            df, text_column, x=x_col, y=y_col, neighbors=neighbors_col, model=model_name,
            umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42},
        )
    except Exception as e:
        raise gr.Error(f"Failed to compute embeddings. Check model name or sample size. Error: {e}") from e

    progress(0.8, desc="Preparing Atlas data source...")
    id_col = find_column_name(df.columns, "_row_index")
    df[id_col] = range(df.shape[0])
    metadata = {"columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col}}
    hasher = Hasher()
    # The random UUID makes the identifier unique per run, even for identical settings.
    hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}-{uuid.uuid4()}")
    identifier = hasher.hexdigest()
    atlas_dataset = DataSource(identifier, df, metadata)

    progress(0.9, desc="Starting Atlas server...")
    static_path = str((pathlib.Path(__import__('embedding_atlas').__file__).parent / "static").resolve())
    atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")

    # Find an open port and run the server in a background thread.
    # NOTE(review): every call starts another server thread that is never
    # stopped; consider reusing or shutting down earlier servers.
    port = find_available_port(start_port=8001)
    thread = threading.Thread(target=run_atlas_server, args=(atlas_app, port), daemon=True)
    thread.start()

    # Wait until the server accepts connections instead of sleeping for a
    # fixed time, and stop early if the server thread has already exited.
    deadline = time.monotonic() + 15
    while True:
        try:
            with socket.create_connection(("127.0.0.1", port), timeout=0.5):
                break
        except OSError:
            if not thread.is_alive() or time.monotonic() >= deadline:
                raise gr.Error("The Atlas server did not start. Please try again.") from None
            time.sleep(0.1)

    # NOTE(review): the iframe points at 127.0.0.1, which the browser resolves
    # on the viewer's machine — this presumably only works when the app runs
    # locally; confirm for hosted deployments.
    iframe_html = f"<iframe src='http://127.0.0.1:{port}' width='100%' height='800px' frameborder='0'></iframe>"
    return gr.HTML(iframe_html)
319
 
320
  # --- Gradio UI Definition ---
321
  with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
322
+ # Page title
323
  gr.Markdown("# Embedding Atlas Explorer")
324
+ # One row holding the controls column (scale=1) and the Atlas output column (scale=3)
 
325
  with gr.Row():
326
  with gr.Column(scale=1):
327
  gr.Markdown("### 1. Select Data")
 
344
  gr.Markdown("### 3. Explore Atlas")
345
  output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
346
 
347
+ # --- Event Listeners ---
348
  hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
349
  dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
350
  split_input.change(fn=get_split_columns, inputs=[dataset_input, split_input], outputs=text_column_input)
 
 
351
  generate_button.click(
352
  fn=generate_atlas,
353
  inputs=[dataset_input, split_input, text_column_input, sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input],
354
  outputs=[output_html],
355
  )
 
356
  app.load(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
357
 
358
  if __name__ == "__main__":