broadfield-dev commited on
Commit
98fe021
·
verified ·
1 Parent(s): 808b711

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -16
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import pandas as pd
3
- from datasets import load_dataset, get_dataset_split_names
4
  from huggingface_hub import HfApi
5
  import os
6
  import pathlib
@@ -45,28 +45,33 @@ def get_dataset_splits(dataset_id: str):
45
  if not dataset_id:
46
  return gr.update(choices=[], value=None, interactive=False)
47
  try:
48
- splits = get_dataset_split_names(dataset_id)
49
- # Set the first split as the default value to trigger the next event
50
  return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True)
51
  except Exception as e:
52
  gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'. Error: {e}")
53
  return gr.update(choices=[], value=None, interactive=False)
54
 
55
- def get_split_columns(dataset_id: str, split: str):
56
- """Gets all columns for a selected split by loading its metadata."""
57
- if not dataset_id or not split:
58
  return gr.update(choices=[], value=None, interactive=False)
59
  try:
60
- # Get the .features property from the dataset info.
61
- features = load_dataset(dataset_id, split=split, streaming=True).features
 
 
 
 
 
62
  columns = list(features.keys())
63
 
64
- # Heuristically find the best text column
65
  preferred_cols = ['text', 'content', 'instruction', 'question', 'document', 'prompt']
66
  best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
 
67
  return gr.update(choices=columns, value=best_col, interactive=True)
68
  except Exception as e:
69
- gr.Warning(f"Could not fetch columns for split '{split}'. Error: {e}")
70
  return gr.update(choices=[], value=None, interactive=False)
71
 
72
  # --- Main Atlas Generation Logic ---
@@ -88,7 +93,7 @@ def generate_atlas(
88
 
89
  progress(0, desc=f"Loading dataset '{dataset_name}' [{split}]...")
90
  try:
91
- dataset = load_dataset(dataset_name, split=split)
92
  df = dataset.to_pandas()
93
  except Exception as e:
94
  raise gr.Error(f"Failed to load data. Error: {e}")
@@ -163,14 +168,14 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
163
  output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
164
 
165
  # --- Chained Event Listeners for Dynamic UI ---
166
- # When the user submits a name, get their datasets
167
  hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
168
 
169
- # --- THIS IS THE FIX ---
170
- # Use .change() so that when a dataset is selected (by user OR another function), it triggers the next step.
171
  dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
172
- split_input.change(fn=get_split_columns, inputs=[dataset_input, split_input], outputs=text_column_input)
173
-
 
 
174
  # --- Button Click Event ---
175
  generate_button.click(
176
  fn=generate_atlas,
 
1
  import gradio as gr
2
  import pandas as pd
3
+ from datasets import load_dataset, get_dataset_split_names, get_dataset_config_info
4
  from huggingface_hub import HfApi
5
  import os
6
  import pathlib
 
45
  if not dataset_id:
46
  return gr.update(choices=[], value=None, interactive=False)
47
  try:
48
+ splits = get_dataset_split_names(dataset_id, trust_remote_code=True)
 
49
  return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True)
50
  except Exception as e:
51
  gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'. Error: {e}")
52
  return gr.update(choices=[], value=None, interactive=False)
53
 
54
+ def get_split_columns(dataset_id: str):
55
+ """Gets all columns for a selected dataset by loading its metadata info."""
56
+ if not dataset_id:
57
  return gr.update(choices=[], value=None, interactive=False)
58
  try:
59
+ # --- THIS IS THE ROBUST FIX ---
60
+ # Use get_dataset_config_info to get schema without loading data.
61
+ # This is the official and most reliable way.
62
+ info = get_dataset_config_info(dataset_id, trust_remote_code=True)
63
+ features = info.features
64
+
65
+ # The user is right, we should show ALL columns.
66
  columns = list(features.keys())
67
 
68
+ # We can still be helpful by guessing the best default.
69
  preferred_cols = ['text', 'content', 'instruction', 'question', 'document', 'prompt']
70
  best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
71
+
72
  return gr.update(choices=columns, value=best_col, interactive=True)
73
  except Exception as e:
74
+ gr.Warning(f"Could not get columns for '{dataset_id}'. It might be a gated dataset or have an unusual structure. Error: {e}")
75
  return gr.update(choices=[], value=None, interactive=False)
76
 
77
  # --- Main Atlas Generation Logic ---
 
93
 
94
  progress(0, desc=f"Loading dataset '{dataset_name}' [{split}]...")
95
  try:
96
+ dataset = load_dataset(dataset_name, split=split, trust_remote_code=True)
97
  df = dataset.to_pandas()
98
  except Exception as e:
99
  raise gr.Error(f"Failed to load data. Error: {e}")
 
168
  output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
169
 
170
  # --- Chained Event Listeners for Dynamic UI ---
 
171
  hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
172
 
173
+ # When a dataset is selected, get its splits.
 
174
  dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
175
+
176
+ # When a dataset is selected, ALSO get its columns. The split doesn't matter for column schema.
177
+ dataset_input.change(fn=get_split_columns, inputs=dataset_input, outputs=text_column_input)
178
+
179
  # --- Button Click Event ---
180
  generate_button.click(
181
  fn=generate_atlas,