broadfield-dev committed on
Commit
c84ca1f
·
verified ·
1 Parent(s): cd69d2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -18
app.py CHANGED
@@ -1,10 +1,14 @@
1
  import gradio as gr
2
  import pandas as pd
3
- from datasets import load_dataset, get_dataset_split_names, get_dataset_config_info
4
  from huggingface_hub import HfApi
5
  import os
6
  import pathlib
7
  import uuid
 
 
 
 
8
 
9
  # --- Embedding Atlas Imports ---
10
  from embedding_atlas.data_source import DataSource
@@ -30,45 +34,54 @@ hf_api = HfApi()
30
 
31
  def get_user_datasets(username: str):
32
  """Fetches all public datasets for a given username or organization."""
 
33
  if not username:
34
  return gr.update(choices=[], value=None, interactive=False)
35
  try:
36
  datasets = hf_api.list_datasets(author=username, full=True)
37
  dataset_ids = [d.id for d in datasets if not d.private]
 
38
  return gr.update(choices=sorted(dataset_ids), value=None, interactive=True)
39
  except Exception as e:
40
- gr.Warning(f"Could not fetch datasets for user '{username}'. Error: {e}")
 
41
  return gr.update(choices=[], value=None, interactive=False)
42
 
43
  def get_dataset_splits(dataset_id: str):
44
  """Gets all available splits for a selected dataset."""
 
45
  if not dataset_id:
46
  return gr.update(choices=[], value=None, interactive=False)
47
  try:
48
- # --- FIX: Removed trust_remote_code=True ---
49
  splits = get_dataset_split_names(dataset_id)
 
50
  return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True)
51
  except Exception as e:
52
- gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'. Error: {e}")
 
53
  return gr.update(choices=[], value=None, interactive=False)
54
 
55
- def get_split_columns(dataset_id: str):
56
- """Gets all columns for a selected dataset by loading its metadata info."""
57
- if not dataset_id:
 
58
  return gr.update(choices=[], value=None, interactive=False)
59
  try:
60
- # --- FIX: Removed trust_remote_code=True ---
61
- info = get_dataset_config_info(dataset_id)
62
- features = info.features
63
-
64
- columns = list(features.keys())
65
 
 
66
  preferred_cols = ['text', 'content', 'instruction', 'question', 'document', 'prompt']
67
  best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
 
68
 
69
  return gr.update(choices=columns, value=best_col, interactive=True)
70
  except Exception as e:
71
- gr.Warning(f"Could not get columns for '{dataset_id}'. It might be a gated dataset or have an unusual structure. Error: {e}")
 
72
  return gr.update(choices=[], value=None, interactive=False)
73
 
74
  # --- Main Atlas Generation Logic ---
@@ -90,9 +103,7 @@ def generate_atlas(
90
 
91
  progress(0, desc=f"Loading dataset '{dataset_name}' [{split}]...")
92
  try:
93
- # Here, trust_remote_code can be useful if the dataset actually needs it.
94
- # It's less likely to crash here than in the metadata functions.
95
- dataset = load_dataset(dataset_name, split=split, trust_remote_code=True)
96
  df = dataset.to_pandas()
97
  except Exception as e:
98
  raise gr.Error(f"Failed to load data. Error: {e}")
@@ -166,11 +177,13 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Embedding Atlas Explorer") as app:
166
  gr.Markdown("### 3. Explore Atlas")
167
  output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
168
 
169
- # --- Chained Event Listeners for Dynamic UI ---
170
  hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
171
 
172
  dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
173
- dataset_input.change(fn=get_split_columns, inputs=dataset_input, outputs=text_column_input)
 
 
174
 
175
  # --- Button Click Event ---
176
  generate_button.click(
 
1
  import gradio as gr
2
  import pandas as pd
3
+ from datasets import load_dataset, get_dataset_split_names
4
  from huggingface_hub import HfApi
5
  import os
6
  import pathlib
7
  import uuid
8
+ import logging
9
+
10
+ # --- Setup Logging ---
11
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
 
13
  # --- Embedding Atlas Imports ---
14
  from embedding_atlas.data_source import DataSource
 
34
 
35
  def get_user_datasets(username: str):
36
  """Fetches all public datasets for a given username or organization."""
37
+ logging.info(f"Fetching datasets for user: {username}")
38
  if not username:
39
  return gr.update(choices=[], value=None, interactive=False)
40
  try:
41
  datasets = hf_api.list_datasets(author=username, full=True)
42
  dataset_ids = [d.id for d in datasets if not d.private]
43
+ logging.info(f"Found {len(dataset_ids)} datasets for {username}.")
44
  return gr.update(choices=sorted(dataset_ids), value=None, interactive=True)
45
  except Exception as e:
46
+ logging.error(f"Failed to fetch datasets for {username}: {e}")
47
+ gr.Warning(f"Could not fetch datasets for user '{username}'.")
48
  return gr.update(choices=[], value=None, interactive=False)
49
 
50
  def get_dataset_splits(dataset_id: str):
51
  """Gets all available splits for a selected dataset."""
52
+ logging.info(f"Fetching splits for dataset: {dataset_id}")
53
  if not dataset_id:
54
  return gr.update(choices=[], value=None, interactive=False)
55
  try:
 
56
  splits = get_dataset_split_names(dataset_id)
57
+ logging.info(f"Found splits for {dataset_id}: {splits}")
58
  return gr.update(choices=splits, value=splits[0] if splits else None, interactive=True)
59
  except Exception as e:
60
+ logging.error(f"Failed to fetch splits for {dataset_id}: {e}")
61
+ gr.Warning(f"Could not fetch splits for dataset '{dataset_id}'.")
62
  return gr.update(choices=[], value=None, interactive=False)
63
 
64
+ def get_split_columns(dataset_id: str, split: str):
65
+ """Gets all columns for a selected split by loading one row."""
66
+ logging.info(f"Fetching columns for: {dataset_id}, split: {split}")
67
+ if not dataset_id or not split:
68
  return gr.update(choices=[], value=None, interactive=False)
69
  try:
70
+ # This is the most robust method: stream one row and get its keys.
71
+ dataset_sample = load_dataset(dataset_id, split=split, streaming=True)
72
+ first_row = next(iter(dataset_sample))
73
+ columns = list(first_row.keys())
74
+ logging.info(f"Found columns: {columns}")
75
 
76
+ # Heuristically find the best text column
77
  preferred_cols = ['text', 'content', 'instruction', 'question', 'document', 'prompt']
78
  best_col = next((col for col in preferred_cols if col in columns), columns[0] if columns else None)
79
+ logging.info(f"Best default column chosen: {best_col}")
80
 
81
  return gr.update(choices=columns, value=best_col, interactive=True)
82
  except Exception as e:
83
+ logging.error(f"Failed to get columns for {dataset_id}/{split}: {e}", exc_info=True)
84
+ gr.Warning(f"Could not fetch columns for split '{split}'. Error: {e}")
85
  return gr.update(choices=[], value=None, interactive=False)
86
 
87
  # --- Main Atlas Generation Logic ---
 
103
 
104
  progress(0, desc=f"Loading dataset '{dataset_name}' [{split}]...")
105
  try:
106
+ dataset = load_dataset(dataset_name, split=split)
 
 
107
  df = dataset.to_pandas()
108
  except Exception as e:
109
  raise gr.Error(f"Failed to load data. Error: {e}")
 
177
  gr.Markdown("### 3. Explore Atlas")
178
  output_html = gr.HTML("<div style='display:flex; justify-content:center; align-items:center; height:800px; border: 1px solid #ddd; border-radius: 5px;'><p>Atlas will be displayed here after generation.</p></div>")
179
 
180
+ # --- Chained Event Listeners for Dynamic UI (CORRECTED LOGIC) ---
181
  hf_user_input.submit(fn=get_user_datasets, inputs=hf_user_input, outputs=dataset_input)
182
 
183
  dataset_input.change(fn=get_dataset_splits, inputs=dataset_input, outputs=split_input)
184
+
185
+ # This is the critical fix: The columns are populated only AFTER a split is chosen.
186
+ split_input.change(fn=get_split_columns, inputs=[dataset_input, split_input], outputs=text_column_input)
187
 
188
  # --- Button Click Event ---
189
  generate_button.click(