Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -46,6 +46,7 @@ def get_dataset_splits(dataset_id: str):
|
|
46 |
return gr.update(choices=[], value=None, interactive=False)
|
47 |
try:
|
48 |
splits = get_dataset_split_names(dataset_id)
|
|
|
49 |
return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True)
|
50 |
except Exception as e:
|
51 |
gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'. Error: {e}")
|
@@ -56,9 +57,7 @@ def get_split_columns(dataset_id: str, split: str):
|
|
56 |
if not dataset_id or not split:
|
57 |
return gr.update(choices=[], value=None, interactive=False)
|
58 |
try:
|
59 |
-
#
|
60 |
-
# Instead of iterating, we get the .features property from the dataset info.
|
61 |
-
# This is much faster and more reliable as it only fetches metadata.
|
62 |
features = load_dataset(dataset_id, split=split, streaming=True).features
|
63 |
columns = list(features.keys())
|
64 |
|
@@ -67,9 +66,7 @@ def get_split_columns(dataset_id: str, split: str):
|
|
67 |
best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
|
68 |
return gr.update(choices=columns, value=best_col, interactive=True)
|
69 |
except Exception as e:
|
70 |
-
|
71 |
-
print(f"Error fetching columns for {dataset_id}/{split}: {e}")
|
72 |
-
gr.Warning(f"Could not fetch columns for split '{split}'. Check if the dataset requires special access. Error: {e}")
|
73 |
return gr.update(choices=[], value=None, interactive=False)
|
74 |
|
75 |
# --- Main Atlas Generation Logic ---
|
@@ -103,7 +100,7 @@ def generate_atlas(
|
|
103 |
if text_column not in df.columns:
|
104 |
raise gr.Error(f"Column '{text_column}' not found. Please select a valid column.")
|
105 |
|
106 |
-
progress(0.2, desc="Computing embeddings and UMAP
|
107 |
|
108 |
x_col = find_column_name(df.columns, "projection_x")
|
109 |
y_col = find_column_name(df.columns, "projection_y")
|
@@ -115,15 +112,13 @@ def generate_atlas(
|
|
115 |
umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42},
|
116 |
)
|
117 |
except Exception as e:
|
118 |
-
raise gr.Error(f"Failed to compute embeddings. Check model name or
|
119 |
|
120 |
progress(0.8, desc="Preparing Atlas data source...")
|
121 |
id_col = find_column_name(df.columns, "_row_index")
|
122 |
df[id_col] = range(df.shape[0])
|
123 |
|
124 |
-
metadata = {
|
125 |
-
"columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col},
|
126 |
-
}
|
127 |
hasher = Hasher()
|
128 |
hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}")
|
129 |
identifier = hasher.hexdigest()
|
@@ -143,10 +138,7 @@ def generate_atlas(
|
|
143 |
# --- Gradio UI Definition ---
|
144 |
with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
|
145 |
gr.Markdown("# Embedding Atlas Explorer")
|
146 |
-
gr.Markdown(
|
147 |
-
"Interactively select and visualize any text-based dataset from the Hugging Face Hub. "
|
148 |
-
"The app computes embeddings and projects them into a 2D map for exploration."
|
149 |
-
)
|
150 |
|
151 |
with gr.Row():
|
152 |
with gr.Column(scale=1):
|
@@ -171,34 +163,23 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
|
|
171 |
output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
|
172 |
|
173 |
# --- Chained Event Listeners for Dynamic UI ---
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
)
|
179 |
-
dataset_input.
|
180 |
-
|
181 |
-
inputs=[dataset_input],
|
182 |
-
outputs=[split_input]
|
183 |
-
)
|
184 |
-
split_input.select(
|
185 |
-
fn=get_split_columns,
|
186 |
-
inputs=[dataset_input, split_input],
|
187 |
-
outputs=[text_column_input]
|
188 |
-
)
|
189 |
|
190 |
# --- Button Click Event ---
|
191 |
generate_button.click(
|
192 |
fn=generate_atlas,
|
193 |
-
inputs=[
|
194 |
-
dataset_input, split_input, text_column_input,
|
195 |
-
sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input
|
196 |
-
],
|
197 |
outputs=[output_html],
|
198 |
)
|
199 |
|
200 |
# Load initial example data on app load
|
201 |
-
app.load(fn=get_user_datasets, inputs=
|
202 |
|
203 |
if __name__ == "__main__":
|
204 |
app.launch(debug=True)
|
|
|
46 |
return gr.update(choices=[], value=None, interactive=False)
|
47 |
try:
|
48 |
splits = get_dataset_split_names(dataset_id)
|
49 |
+
# Set the first split as the default value to trigger the next event
|
50 |
return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True)
|
51 |
except Exception as e:
|
52 |
gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'. Error: {e}")
|
|
|
57 |
if not dataset_id or not split:
|
58 |
return gr.update(choices=[], value=None, interactive=False)
|
59 |
try:
|
60 |
+
# Get the .features property from the dataset info.
|
|
|
|
|
61 |
features = load_dataset(dataset_id, split=split, streaming=True).features
|
62 |
columns = list(features.keys())
|
63 |
|
|
|
66 |
best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
|
67 |
return gr.update(choices=columns, value=best_col, interactive=True)
|
68 |
except Exception as e:
|
69 |
+
gr.Warning(f"Could not fetch columns for split '{split}'. Error: {e}")
|
|
|
|
|
70 |
return gr.update(choices=[], value=None, interactive=False)
|
71 |
|
72 |
# --- Main Atlas Generation Logic ---
|
|
|
100 |
if text_column not in df.columns:
|
101 |
raise gr.Error(f"Column '{text_column}' not found. Please select a valid column.")
|
102 |
|
103 |
+
progress(0.2, desc="Computing embeddings and UMAP...")
|
104 |
|
105 |
x_col = find_column_name(df.columns, "projection_x")
|
106 |
y_col = find_column_name(df.columns, "projection_y")
|
|
|
112 |
umap_args={"n_neighbors": umap_neighbors, "min_dist": umap_min_dist, "metric": "cosine", "random_state": 42},
|
113 |
)
|
114 |
except Exception as e:
|
115 |
+
raise gr.Error(f"Failed to compute embeddings. Check model name or sample size. Error: {e}")
|
116 |
|
117 |
progress(0.8, desc="Preparing Atlas data source...")
|
118 |
id_col = find_column_name(df.columns, "_row_index")
|
119 |
df[id_col] = range(df.shape[0])
|
120 |
|
121 |
+
metadata = {"columns": {"id": id_col, "text": text_column, "embedding": {"x": x_col, "y": y_col}, "neighbors": neighbors_col}}
|
|
|
|
|
122 |
hasher = Hasher()
|
123 |
hasher.update(f"{dataset_name}-{split}-{text_column}-{sample_size}-{model_name}")
|
124 |
identifier = hasher.hexdigest()
|
|
|
138 |
# --- Gradio UI Definition ---
|
139 |
with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
|
140 |
gr.Markdown("# Embedding Atlas Explorer")
|
141 |
+
gr.Markdown("Interactively select and visualize any text-based dataset from the Hugging Face Hub.")
|
|
|
|
|
|
|
142 |
|
143 |
with gr.Row():
|
144 |
with gr.Column(scale=1):
|
|
|
163 |
output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
|
164 |
|
165 |
# --- Chained Event Listeners for Dynamic UI ---
|
166 |
+
# When the user submits a name, get their datasets
|
167 |
+
hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
|
168 |
+
|
169 |
+
# --- THIS IS THE FIX ---
|
170 |
+
# Use .change() so that when a dataset is selected (by user OR another function), it triggers the next step.
|
171 |
+
dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
|
172 |
+
split_input.change(fn=get_split_columns, inputs=[dataset_input, split_input], outputs=text_column_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
# --- Button Click Event ---
|
175 |
generate_button.click(
|
176 |
fn=generate_atlas,
|
177 |
+
inputs=[dataset_input, split_input, text_column_input, sample_size_input, model_input, umap_neighbors_input, umap_min_dist_input],
|
|
|
|
|
|
|
178 |
outputs=[output_html],
|
179 |
)
|
180 |
|
181 |
# Load initial example data on app load
|
182 |
+
app.load(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
|
183 |
|
184 |
if __name__ == "__main__":
|
185 |
app.launch(debug=True)
|