broadfield-dev committed
Commit 808b711 (verified) · Parent: 23f8201

Update app.py

Files changed (1)
  1. app.py +16 -35
app.py CHANGED
@@ -46,6 +46,7 @@ def get_dataset_splits(dataset_id: str):
         return gr.update(choices=[], value=None, interactive=False)
     try:
         splits = get_dataset_split_names(dataset_id)
+        # Set the first split as the default value to trigger the next event
        return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True)
     except Exception as e:
         gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'. Error: {e}")
@@ -56,9 +57,7 @@ def get_split_columns(dataset_id: str, split: str):
     if not dataset_id or not split:
         return gr.update(choices=[], value=None, interactive=False)
     try:
-        # --- THIS IS THE FIX ---
-        # Instead of iterating, we get the .features property from the dataset info.
-        # This is much faster and more reliable as it only fetches metadata.
+        # Get the .features property from the dataset info.
         features = load_dataset(dataset_id, split=split, streaming=True).features
         columns = list(features.keys())
 
@@ -67,9 +66,7 @@ def get_split_columns(dataset_id: str, split: str):
         best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
         return gr.update(choices=columns, value=best_col, interactive=True)
     except Exception as e:
-        # Adding a print statement here can help debug in the terminal
-        print(f"Error fetching columns for {dataset_id}/{split}: {e}")
-        gr.Warning(f"Could not fetch columns for split '{split}'. Check if the dataset requires special access. Error: {e}")
+        gr.Warning(f"Could not fetch columns for split '{split}'. Error: {e}")
         return gr.update(choices=[], value=None, interactive=False)
 
 # --- Main Atlas Generation Logic ---
@@ -103,7 +100,7 @@ def generate_atlas(
     if text_column not in df.columns:
         raise gr.Error(f"Column '{text_column}' not found. Please select a valid column.")
 
-    progress(0.2, desc="Computing embeddings and UMAP. This may take a while...")
+    progress(0.2, desc="Computing embeddings and UMAP...")
 
     x_col = find_column_name(df.columns, "projection_x")
     y_col = find_column_name(df.columns, "projection_y")
@@ -115,15 +112,13 @@ def generate_atlas(
             umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42},
         )
     except Exception as e:
-        raise gr.Error(f"Failed to compute embeddings. Check model name or try a smaller sample. Error: {e}")
+        raise gr.Error(f"Failed to compute embeddings. Check model name or sample size. Error: {e}")
 
     progress(0.8, desc="Preparing Atlas data source...")
     id_col = find_column_name(df.columns, "_row_index")
     df[id_col] = range(df.shape[0])
 
-    metadata = {
-        "columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col},
-    }
+    metadata = {"columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col}}
     hasher = Hasher()
     hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}")
     identifier = hasher.hexdigest()
@@ -143,10 +138,7 @@ def generate_atlas(
 # --- Gradio UI Definition ---
 with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
     gr.Markdown("# Embedding Atlas Explorer")
-    gr.Markdown(
-        "Interactively select and visualize any text-based dataset from the Hugging Face Hub. "
-        "The app computes embeddings and projects them into a 2D map for exploration."
-    )
+    gr.Markdown("Interactively select and visualize any text-based dataset from the Hugging Face Hub.")
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -171,34 +163,23 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
         output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
 
     # --- Chained Event Listeners for Dynamic UI ---
-    hf_user_input.submit(
-        fn=get_user_datasets,
-        inputs=[hf_user_input],
-        outputs=[dataset_input]
-    )
-    dataset_input.select(
-        fn=get_dataset_splits,
-        inputs=[dataset_input],
-        outputs=[split_input]
-    )
-    split_input.select(
-        fn=get_split_columns,
-        inputs=[dataset_input, split_input],
-        outputs=[text_column_input]
-    )
+    # When the user submits a name, get their datasets
+    hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
+
+    # --- THIS IS THE FIX ---
+    # Use .change() so that when a dataset is selected (by user OR another function), it triggers the next step.
+    dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
+    split_input.change(fn=get_split_columns, inputs=[dataset_input, split_input], outputs=text_column_input)
 
     # --- Button Click Event ---
     generate_button.click(
         fn=generate_atlas,
-        inputs=[
-            dataset_input, split_input, text_column_input,
-            sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input
-        ],
+        inputs=[dataset_input, split_input, text_column_input, sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input],
         outputs=[output_html],
     )
 
     # Load initial example data on app load
-    app.load(fn=get_user_datasets, inputs=[hf_user_input], outputs=[dataset_input])
+    app.load(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
 
 if __name__ == "__main__":
     app.launch(debug=True)
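Note on the event wiring: in Gradio, .select() fires only on direct user interaction, while .change() also fires when another event handler updates the component's value (for example by returning gr.update(value=...)). Switching the dropdown chain to .change() is what lets app.load -> datasets -> splits -> columns cascade on its own, and the comment added in get_dataset_splits ("set the first split as the default value") relies on the same behaviour. Below is a minimal, self-contained sketch of that pattern; the component names and the COUNTRIES data are illustrative only and are not taken from app.py.

import gradio as gr

# Illustrative data only; not part of app.py.
COUNTRIES = {"France": ["Paris", "Lyon"], "Japan": ["Tokyo", "Osaka"]}

def pick_country():
    # Setting the value from a function fires country.change(),
    # but would NOT fire country.select(), which needs user interaction.
    return gr.update(choices=list(COUNTRIES), value="France")

def pick_city(country):
    cities = COUNTRIES.get(country, [])
    return gr.update(choices=cities, value=cities[0] if cities else None)

with gr.Blocks() as demo:
    country = gr.Dropdown(label="Country")
    city = gr.Dropdown(label="City")

    # .change() reacts to both user edits and programmatic updates,
    # so the chain cascades after demo.load sets a country.
    country.change(fn=pick_city, inputs=country, outputs=city)
    demo.load(fn=pick_country, outputs=country)

if __name__ == "__main__":
    demo.launch()

Running this, the City dropdown fills in on page load without any clicks, mirroring how app.load now populates the whole dataset/split/column chain.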
 
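The comments removed from get_split_columns described the column lookup: reading .features from a dataset opened with streaming=True resolves the schema from metadata rather than downloading the split. A rough sketch of that lookup with the datasets library follows; "imdb" is only an example dataset id, and .features can be None for datasets whose schema cannot be resolved lazily, hence the fallback.

from datasets import get_dataset_split_names, load_dataset

dataset_id = "imdb"  # example id; any public Hub dataset works the same way

# Split names come from Hub metadata; no data download is needed.
splits = get_dataset_split_names(dataset_id)

# Streaming mode exposes the schema without materializing the split.
features = load_dataset(dataset_id, split=splits[0], streaming=True).features
columns = list(features.keys()) if features else []

print(splits)
print(columns)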