broadfield-dev commited on
Commit
701c41c
·
verified ·
1 Parent(s): f8c307f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -121
app.py CHANGED
@@ -1,21 +1,18 @@
1
  import gradio as gr
2
  import pandas as pd
3
- from datasets import load_dataset
 
4
  import os
5
  import pathlib
6
  import uuid
7
- import hashlib
8
 
9
  # --- Embedding Atlas Imports ---
10
- # We will import the necessary components directly from the library
11
  from embedding_atlas.data_source import DataSource
12
  from embedding_atlas.server import make_server
13
  from embedding_atlas.projection import compute_text_projection
14
- # Hasher is correctly located in the utils module
15
  from embedding_atlas.utils import Hasher
16
 
17
  # --- Helper function from embedding_atlas/cli.py ---
18
- # To make the script self-contained, we copy this small helper function here.
19
  def find_column_name(existing_names, candidate):
20
  """Finds a unique column name, appending '_1', '_2', etc. if the candidate name already exists."""
21
  if candidate not in existing_names:
@@ -28,20 +25,54 @@ def find_column_name(existing_names, candidate):
28
  return s
29
  index += 1
30
 
31
- # --- Global State ---
32
- # We need to keep track of the mounted app to avoid errors on re-runs.
33
- # A dictionary to store unique app instances for each run.
34
- mounted_apps = {}
35
 
36
- def get_atlas_static_path():
37
- """Finds the path to the static files for the embedding-atlas frontend."""
38
- import embedding_atlas
39
- return str((pathlib.Path(embedding_atlas.__file__).parent / "static").resolve())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
 
41
  def generate_atlas(
42
  dataset_name: str,
43
- text_column: str,
44
  split: str,
 
45
  sample_size: int,
46
  model_name: str,
47
  umap_neighbors: int,
@@ -51,169 +82,123 @@ def generate_atlas(
51
  """
52
  Loads data, computes embeddings, and serves the Embedding Atlas UI.
53
  """
54
- global mounted_apps
55
-
56
- # --- 1. Load Data ---
57
- progress(0, desc=f"Loading dataset '{dataset_name}'...")
58
  try:
59
- # Load the dataset from Hugging Face
60
  dataset = load_dataset(dataset_name, split=split)
61
  df = dataset.to_pandas()
62
  except Exception as e:
63
- raise gr.Error(f"Failed to load dataset. Please check the name and split. Error: {e}")
64
 
65
- # --- 2. Sample Data (if requested) ---
66
  if sample_size > 0 and sample_size < len(df):
67
  progress(0.1, desc=f"Sampling {sample_size} rows...")
68
  df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
69
 
70
- # Check if the text column exists
71
  if text_column not in df.columns:
72
- raise gr.Error(f"Column '{text_column}' not found in the dataset. Available columns: {', '.join(df.columns)}")
73
 
74
- # --- 3. Compute Embeddings & UMAP Projection ---
75
- progress(0.2, desc="Computing embeddings and UMAP projection. This may take a while...")
76
 
77
- x_column = find_column_name(df.columns, "projection_x")
78
- y_column = find_column_name(df.columns, "projection_y")
79
- neighbors_column = find_column_name(df.columns, "__neighbors")
80
 
81
  try:
82
  compute_text_projection(
83
- df,
84
- text_column,
85
- x=x_column,
86
- y=y_column,
87
- neighbors=neighbors_column,
88
- model=model_name,
89
- umap_args={
90
- "n_neighbors": umap_neighbors,
91
- "min_dist": umap_min_dist,
92
- "metric": "cosine",
93
- "random_state": 42,
94
- },
95
  )
96
  except Exception as e:
97
  raise gr.Error(f"Failed to compute embeddings. Check model name or try a smaller sample. Error: {e}")
98
 
99
- # --- 4. Prepare Atlas DataSource ---
100
  progress(0.8, desc="Preparing Atlas data source...")
101
- id_column = find_column_name(df.columns, "_row_index")
102
- df[id_column] = range(df.shape[0])
103
 
104
  metadata = {
105
- "columns": {
106
- "id": id_column,
107
- "text": text_column,
108
- "embedding": {"x": x_column, "y": y_column},
109
- "neighbors": neighbors_column,
110
- },
111
  }
112
-
113
- # Create a unique identifier for the dataset to avoid conflicts
114
  hasher = Hasher()
115
- hasher.update(f"{dataset_name}-{text_column}-{sample_size}-{model_name}")
116
  identifier = hasher.hexdigest()
117
-
118
  atlas_dataset = DataSource(identifier, df, metadata)
119
- static_path = get_atlas_static_path()
120
 
121
- # --- 5. Create and Mount the FastAPI App ---
122
  progress(0.9, desc="Mounting visualization UI...")
123
-
124
- # Generate a unique path for this instance to avoid conflicts on remounting
125
  mount_path = f"/{uuid.uuid4().hex}"
126
-
127
- # Create the server instance
128
  atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")
129
 
130
- # The `blocks` object is global in the Gradio script context.
131
- # We mount the atlas server onto the main Gradio FastAPI app.
132
- gr.mount_gradio_app(app, atlas_app, path=mount_path)
133
-
134
- mounted_apps[mount_path] = atlas_app # Store it for potential cleanup later
135
 
136
  progress(1.0, desc="Done!")
137
-
138
- # --- 6. Return an IFrame pointing to the mounted path ---
139
  iframe_html = f"<iframe src='{mount_path}' width='100%' height='800px' frameborder='0'></iframe>"
140
-
141
  return gr.HTML(iframe_html)
142
 
143
-
144
  # --- Gradio UI Definition ---
145
  with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
146
  gr.Markdown("# Embedding Atlas Explorer")
147
  gr.Markdown(
148
- "Visualize any text column from a Hugging Face dataset. This app loads the data, "
149
- "computes embeddings using Sentence Transformers, reduces dimensionality with UMAP, "
150
- "and displays the result in an interactive Embedding Atlas."
151
  )
152
 
153
  with gr.Row():
154
  with gr.Column(scale=1):
155
- gr.Markdown("### 1. Configuration")
156
- dataset_input = gr.Textbox(
157
- label="Hugging Face Dataset Name",
158
- value="Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset"
159
- )
160
- text_column_input = gr.Textbox(
161
- label="Text Column to Visualize",
162
- value="instruction"
163
- )
164
- split_input = gr.Textbox(label="Dataset Split", value="train")
165
- sample_size_input = gr.Slider(
166
- label="Number of Samples (0 for all)",
167
- minimum=0,
168
- maximum=5000,
169
- value=2000,
170
- step=100
171
- )
172
-
173
  with gr.Accordion("Advanced Settings", open=False):
174
- model_input = gr.Dropdown(
175
- label="Embedding Model",
176
- choices=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"],
177
- value="all-MiniLM-L6-v2",
178
- )
179
- umap_neighbors_input = gr.Slider(
180
- label="UMAP Neighbors",
181
- minimum=2,
182
- maximum=100,
183
- value=15,
184
- step=1,
185
- info="Controls how UMAP balances local vs. global structure."
186
- )
187
- umap_min_dist_input = gr.Slider(
188
- label="UMAP Min Distance",
189
- minimum=0.0,
190
- maximum=0.99,
191
- value=0.1,
192
- step=0.01,
193
- info="Controls how tightly UMAP packs points together."
194
- )
195
 
196
  generate_button = gr.Button("Generate Atlas", variant="primary")
197
 
198
  with gr.Column(scale=3):
199
- gr.Markdown("### 2. Visualization")
200
- output_html = gr.HTML(
201
- "<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>"
202
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
 
204
  generate_button.click(
205
  fn=generate_atlas,
206
  inputs=[
207
- dataset_input,
208
- text_column_input,
209
- split_input,
210
- sample_size_input,
211
- model_input,
212
- umap_neighbors_input,
213
- umap_min_dist_input,
214
  ],
215
  outputs=[output_html],
216
  )
 
 
 
217
 
218
  if __name__ == "__main__":
 
 
219
  app.launch()
 
1
  import gradio as gr
2
  import pandas as pd
3
+ from datasets import load_dataset, get_dataset_split_names
4
+ from huggingface_hub import HfApi, HfFolder
5
  import os
6
  import pathlib
7
  import uuid
 
8
 
9
  # --- Embedding Atlas Imports ---
 
10
  from embedding_atlas.data_source import DataSource
11
  from embedding_atlas.server import make_server
12
  from embedding_atlas.projection import compute_text_projection
 
13
  from embedding_atlas.utils import Hasher
14
 
15
  # --- Helper function from embedding_atlas/cli.py ---
 
16
  def find_column_name(existing_names, candidate):
17
  """Finds a unique column name, appending '_1', '_2', etc. if the candidate name already exists."""
18
  if candidate not in existing_names:
 
25
  return s
26
  index += 1
27
 
28
+ # --- Hugging Face API Helpers for Dynamic UI ---
29
+ hf_api = HfApi()
 
 
30
 
31
+ def get_user_datasets(username: str):
32
+ """Fetches all public datasets for a given username or organization."""
33
+ if not username:
34
+ return gr.Dropdown.update(choices=[], value=None, interactive=False)
35
+ try:
36
+ datasets = hf_api.list_datasets(author=username, cardData=True)
37
+ dataset_ids = [d.id for d in datasets if not d.private]
38
+ return gr.Dropdown.update(choices=sorted(dataset_ids), value=None, interactive=True)
39
+ except Exception as e:
40
+ gr.Warning(f"Could not fetch datasets for user '{username}'. Error: {e}")
41
+ return gr.Dropdown.update(choices=[], value=None, interactive=False)
42
+
43
+ def get_dataset_splits(dataset_id: str):
44
+ """Gets all available splits for a selected dataset."""
45
+ if not dataset_id:
46
+ return gr.Dropdown.update(choices=[], value=None, interactive=False)
47
+ try:
48
+ splits = get_dataset_split_names(dataset_id)
49
+ return gr.Dropdown.update(choices=splits, value=splits[0] if splits else None, interactive=True)
50
+ except Exception as e:
51
+ gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'. Error: {e}")
52
+ return gr.Dropdown.update(choices=[], value=None, interactive=False)
53
+
54
+ def get_split_columns(dataset_id: str, split: str):
55
+ """Gets all columns for a selected split by loading one row."""
56
+ if not dataset_id or not split:
57
+ return gr.Dropdown.update(choices=[], value=None, interactive=False)
58
+ try:
59
+ # Stream one row to get column names without downloading the whole dataset
60
+ dataset_sample = load_dataset(dataset_id, split=split, streaming=True)
61
+ first_row = next(iter(dataset_sample))
62
+ columns = list(first_row.keys())
63
+ # Heuristically find the best text column
64
+ preferred_cols = ['text', 'content', 'instruction', 'question', 'document']
65
+ best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
66
+ return gr.Dropdown.update(choices=columns, value=best_col, interactive=True)
67
+ except Exception as e:
68
+ gr.Warning(f"Could not fetch columns for split '{split}'. Error: {e}")
69
+ return gr.Dropdown.update(choices=[], value=None, interactive=False)
70
 
71
+ # --- Main Atlas Generation Logic ---
72
  def generate_atlas(
73
  dataset_name: str,
 
74
  split: str,
75
+ text_column: str,
76
  sample_size: int,
77
  model_name: str,
78
  umap_neighbors: int,
 
82
  """
83
  Loads data, computes embeddings, and serves the Embedding Atlas UI.
84
  """
85
+ if not all([dataset_name, split, text_column]):
86
+ raise gr.Error("Please ensure a Dataset, Split, and Text Column are selected.")
87
+
88
+ progress(0, desc=f"Loading dataset '{dataset_name}' [{split}]...")
89
  try:
 
90
  dataset = load_dataset(dataset_name, split=split)
91
  df = dataset.to_pandas()
92
  except Exception as e:
93
+ raise gr.Error(f"Failed to load data. Error: {e}")
94
 
 
95
  if sample_size > 0 and sample_size < len(df):
96
  progress(0.1, desc=f"Sampling {sample_size} rows...")
97
  df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
98
 
 
99
  if text_column not in df.columns:
100
+ raise gr.Error(f"Column '{text_column}' not found. Please select a valid column.")
101
 
102
+ progress(0.2, desc="Computing embeddings and UMAP. This may take a while...")
 
103
 
104
+ x_col = find_column_name(df.columns, "projection_x")
105
+ y_col = find_column_name(df.columns, "projection_y")
106
+ neighbors_col = find_column_name(df.columns, "__neighbors")
107
 
108
  try:
109
  compute_text_projection(
110
+ df, text_column, x=x_col, y=y_col, neighbors=neighbors_col, model=model_name,
111
+ umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42},
 
 
 
 
 
 
 
 
 
 
112
  )
113
  except Exception as e:
114
  raise gr.Error(f"Failed to compute embeddings. Check model name or try a smaller sample. Error: {e}")
115
 
 
116
  progress(0.8, desc="Preparing Atlas data source...")
117
+ id_col = find_column_name(df.columns, "_row_index")
118
+ df[id_col] = range(df.shape[0])
119
 
120
  metadata = {
121
+ "columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col},
 
 
 
 
 
122
  }
 
 
123
  hasher = Hasher()
124
+ hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}")
125
  identifier = hasher.hexdigest()
 
126
  atlas_dataset = DataSource(identifier, df, metadata)
 
127
 
 
128
  progress(0.9, desc="Mounting visualization UI...")
129
+ static_path = str((pathlib.Path(__import__('embedding_atlas').__file__).parent / "static").resolve())
 
130
  mount_path = f"/{uuid.uuid4().hex}"
 
 
131
  atlas_app = make_server(atlas_dataset, static_path=static_path, duckdb_uri="wasm")
132
 
133
+ # --- THIS IS THE FIX ---
134
+ # Call mount_gradio_app on the Blocks instance `app`
135
+ app.mount_gradio_app(atlas_app, path=mount_path)
 
 
136
 
137
  progress(1.0, desc="Done!")
 
 
138
  iframe_html = f"<iframe src='{mount_path}' width='100%' height='800px' frameborder='0'></iframe>"
 
139
  return gr.HTML(iframe_html)
140
 
 
141
  # --- Gradio UI Definition ---
142
  with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
143
  gr.Markdown("# Embedding Atlas Explorer")
144
  gr.Markdown(
145
+ "Interactively select and visualize any text-based dataset from the Hugging Face Hub. "
146
+ "The app computes embeddings and projects them into a 2D map for exploration."
 
147
  )
148
 
149
  with gr.Row():
150
  with gr.Column(scale=1):
151
+ gr.Markdown("### 1. Select Data")
152
+ hf_user_input = gr.Textbox(label="Hugging Face User or Org Name", value="Trendyol", placeholder="e.g., 'gradio' or 'google'")
153
+ dataset_input = gr.Dropdown(label="Select a Dataset", interactive=False)
154
+ split_input = gr.Dropdown(label="Select a Split", interactive=False)
155
+ text_column_input = gr.Dropdown(label="Select a Text Column", interactive=False)
156
+
157
+ gr.Markdown("### 2. Configure Visualization")
158
+ sample_size_input = gr.Slider(label="Number of Samples", minimum=0, maximum=10000, value=2000, step=100)
159
+
 
 
 
 
 
 
 
 
 
160
  with gr.Accordion("Advanced Settings", open=False):
161
+ model_input = gr.Dropdown(label="Embedding Model", choices=["all-MiniLM-L6-v2", "all-mpnet-base-v2", "multi-qa-MiniLM-L6-cos-v1"], value="all-MiniLM-L6-v2")
162
+ umap_neighbors_input = gr.Slider(label="UMAP Neighbors", minimum=2, maximum=100, value=15, step=1, info="Controls local vs. global structure.")
163
+ umap_min_dist_input = gr.Slider(label="UMAP Min Distance", minimum=0.0, maximum=0.99, value=0.1, step=0.01, info="Controls how tightly points are packed.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  generate_button = gr.Button("Generate Atlas", variant="primary")
166
 
167
  with gr.Column(scale=3):
168
+ gr.Markdown("### 3. Explore Atlas")
169
+ output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
170
+
171
+ # --- Chained Event Listeners for Dynamic UI ---
172
+ hf_user_input.submit(
173
+ fn=get_user_datasets,
174
+ inputs=[hf_user_input],
175
+ outputs=[dataset_input]
176
+ )
177
+ dataset_input.select(
178
+ fn=get_dataset_splits,
179
+ inputs=[dataset_input],
180
+ outputs=[split_input]
181
+ )
182
+ split_input.select(
183
+ fn=get_split_columns,
184
+ inputs=[dataset_input, split_input],
185
+ outputs=[text_column_input]
186
+ )
187
 
188
+ # --- Button Click Event ---
189
  generate_button.click(
190
  fn=generate_atlas,
191
  inputs=[
192
+ dataset_input, split_input, text_column_input,
193
+ sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input
 
 
 
 
 
194
  ],
195
  outputs=[output_html],
196
  )
197
+
198
+ # Load initial example data on app load
199
+ app.load(fn=get_user_datasets, inputs=[hf_user_input], outputs=[dataset_input])
200
 
201
  if __name__ == "__main__":
202
+ # To run locally, you might need to log in to Hugging Face Hub
203
+ # HfFolder.save_token("YOUR_HF_TOKEN")
204
  app.launch()