ZennyKenny committed
Commit 91cbc46 · verified · 1 Parent(s): 349af26

try new dataset handling

Files changed (1)
  1. app.py +115 -100
app.py CHANGED
@@ -1,79 +1,102 @@
+# app.py
+
 import gradio as gr
 import numpy as np
+import pandas as pd
 import matplotlib
 import matplotlib.pyplot as plt
-import pandas as pd

 from datasets import load_dataset
 from sklearn.ensemble import GradientBoostingClassifier
 from sklearn.model_selection import train_test_split
 from sklearn.metrics import accuracy_score, confusion_matrix

-matplotlib.use('Agg') # Avoid issues in some remote environments
+# In some remote environments, Matplotlib needs to be set to 'Agg' backend
+matplotlib.use('Agg')

-# Pre-populate a short list of "recommended" Hugging Face datasets
-# (Replace "datasorg/iris" etc. with real dataset IDs you want to showcase)
+################################################################################
+# SUGGESTED_DATASETS: Must actually exist on huggingface.co/datasets.
+#
+# "scikit-learn/iris" -> a tabular Iris dataset with a "train" split of 150 rows.
+# "uci/wine" -> a tabular Wine dataset with a "train" split of 178 rows.
+################################################################################
 SUGGESTED_DATASETS = [
-    "datasorg/iris", # hypothetical ID
-    "uciml/wine_quality-red", # example from the HF Hub
-    "SKIP/ENTER_CUSTOM" # We'll treat this as a "separator" or "prompt" for custom
+    "scikit-learn/iris",
+    "uci/wine",
+    "SKIP/ENTER_CUSTOM" # a placeholder meaning "use custom_dataset_id"
 ]

-def load_and_prepare_dataset(dataset_id, label_column, feature_columns):
+
+def update_columns(dataset_id, custom_dataset_id):
     """
-    Loads a dataset from the Hugging Face Hub,
-    converts it to a pandas DataFrame,
-    returns X, y as NumPy arrays for modeling.
+    Loads the chosen dataset (train split) and returns its column names,
+    to populate the Label Column & Feature Columns selectors.
     """
-    # Load only the "train" split for simplicity
-    # Many datasets have "train", "test", "validation" splits
-    ds = load_dataset(dataset_id, split="train")
-
-    # Convert to a DataFrame for easy manipulation
-    df = pd.DataFrame(ds)
-
-    # Subset to selected columns
-    if label_column not in df.columns:
-        raise ValueError(f"Label column '{label_column}' not in dataset columns: {df.columns.to_list()}")
-
-    for col in feature_columns:
-        if col not in df.columns:
-            raise ValueError(f"Feature column '{col}' not in dataset columns: {df.columns.to_list()}")
-
-    # Split into X and y
-    X = df[feature_columns].values
-    y = df[label_column].values
-
-    return X, y, df.columns.tolist()
+    # If user picked a suggested dataset (not SKIP), use that
+    if dataset_id != "SKIP/ENTER_CUSTOM":
+        final_id = dataset_id
+    else:
+        # Use the user-supplied dataset ID
+        final_id = custom_dataset_id.strip()

-def train_model(dataset_id, custom_dataset_id, label_column, feature_columns,
+    try:
+        # Load just the "train" split; many HF datasets have train/test/validation
+        ds = load_dataset(final_id, split="train")
+        df = pd.DataFrame(ds)
+        cols = df.columns.tolist()
+
+        message = f"**Loaded dataset**: {final_id}\n\n**Columns found**: {cols}"
+        # Return list of columns for both label & features
+        return (
+            gr.update(choices=cols, value=None), # label_col dropdown
+            gr.update(choices=cols, value=[]), # feature_cols checkbox group
+            message
+        )
+    except Exception as e:
+        # If load fails or dataset doesn't exist
+        err_msg = f"**Error loading** `{final_id}`: {e}"
+        return (
+            gr.update(choices=[], value=None),
+            gr.update(choices=[], value=[]),
+            err_msg
+        )
+
+
+def train_model(dataset_id, custom_dataset_id, label_column, feature_columns,
                 learning_rate, n_estimators, max_depth, test_size):
     """
-    1. Determine final dataset ID (either from dropdown or custom text).
-    2. Load dataset -> DataFrame -> X, y.
-    3. Train a GradientBoostingClassifier.
-    4. Generate plots & metrics (accuracy and confusion matrix).
+    1. Determine the final dataset ID (from dropdown or custom text).
+    2. Load the dataset -> create dataframe -> X, y.
+    3. Train GradientBoostingClassifier.
+    4. Return metrics (accuracy) and a Matplotlib figure with:
+       - Feature importance bar chart
+       - Confusion matrix heatmap
     """
-
-    # Decide which dataset ID to use
     if dataset_id != "SKIP/ENTER_CUSTOM":
         final_id = dataset_id
     else:
-        # Use the user-supplied "custom_dataset_id"
         final_id = custom_dataset_id.strip()

-    # Prepare data
-    X, y, columns_available = load_and_prepare_dataset(
-        final_id,
-        label_column,
-        feature_columns
-    )
-
-    # Train/test split
+    # Load dataset
+    ds = load_dataset(final_id, split="train")
+    df = pd.DataFrame(ds)
+
+    # Basic validation
+    if label_column not in df.columns:
+        raise ValueError(f"Label column '{label_column}' not found in dataset columns.")
+    for fc in feature_columns:
+        if fc not in df.columns:
+            raise ValueError(f"Feature column '{fc}' not found in dataset columns.")
+
+    # Build X, y arrays
+    X = df[feature_columns].values
+    y = df[label_column].values
+
+    # Split
     X_train, X_test, y_train, y_test = train_test_split(
         X, y, test_size=test_size, random_state=42
     )
-
+
     # Train model
     clf = GradientBoostingClassifier(
         learning_rate=learning_rate,
@@ -82,13 +105,15 @@ def train_model(dataset_id, custom_dataset_id, label_column, feature_columns,
         random_state=42
     )
     clf.fit(X_train, y_train)
-
-    # Evaluate
+
+    # Predictions & metrics
     y_pred = clf.predict(X_test)
     accuracy = accuracy_score(y_test, y_pred)
     cm = confusion_matrix(y_test, y_pred)

-    # Plot figure
+    # Build a single figure with 2 subplots:
+    # 1) Feature importances
+    # 2) Confusion matrix heatmap
     fig, axs = plt.subplots(1, 2, figsize=(10, 4))

     # Subplot 1: Feature Importances
@@ -103,90 +128,80 @@ def train_model(dataset_id, custom_dataset_id, label_column, feature_columns,
     im = axs[1].imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
     axs[1].set_title("Confusion Matrix")
     plt.colorbar(im, ax=axs[1])
-    # Labeling
     axs[1].set_xlabel("Predicted")
     axs[1].set_ylabel("True")

-    # If you want to annotate each cell:
+    # Optionally annotate each cell with the count
     thresh = cm.max() / 2.0
     for i in range(cm.shape[0]):
         for j in range(cm.shape[1]):
             color = "white" if cm[i, j] > thresh else "black"
-            axs[1].text(j, i, format(cm[i, j], "d"), ha="center", va="center", color=color)
+            axs[1].text(j, i, str(cm[i, j]), ha="center", va="center", color=color)

     plt.tight_layout()

-    output_text = f"**Dataset used:** {final_id}\n\n"
-    output_text += f"**Accuracy:** {accuracy:.3f}\n\n"
-    output_text += "**Confusion Matrix** (raw counts above)."
+    # Build textual summary
+    text_summary = (
+        f"**Dataset used**: `{final_id}`\n\n"
+        f"**Label column**: `{label_column}`\n\n"
+        f"**Feature columns**: `{feature_columns}`\n\n"
+        f"**Accuracy**: {accuracy:.3f}\n\n"
+    )

-    return output_text, fig, columns_available
+    return text_summary, fig

-def update_columns(dataset_id, dataset_config, custom_dataset_id):
-    """
-    Load the dataset from HF hub, using either the suggested one or the custom user-specified,
-    plus an optional config.
-    """
-    if dataset_id != "SKIP/ENTER_CUSTOM":
-        final_id = dataset_id
-        final_config = dataset_config.strip() if dataset_config else None
-    else:
-        # Use the user-supplied text
-        final_id = custom_dataset_id.strip()
-        final_config = None # or parse from text if you like
-
-    try:
-        if final_config:
-            ds = load_dataset(final_id, final_config, split="train")
-        else:
-            ds = load_dataset(final_id, split="train")
-        df = pd.DataFrame(ds)
-        cols = df.columns.tolist()
-        return gr.update(choices=cols), gr.update(choices=cols), f"Columns found: {cols}"
-    except Exception as e:
-        return gr.update(choices=[]), gr.update(choices=[]), f"Error loading {final_id}: {e}"

+# Build the Gradio Blocks UI
 with gr.Blocks() as demo:
-    gr.Markdown("## Train GradientBoostingClassifier on a Hugging Face dataset of your choice")
-
+    gr.Markdown("# Train a GradientBoostingClassifier on any HF Dataset\n")
+    gr.Markdown(
+        "1. Choose a suggested dataset from the dropdown **or** enter a custom dataset ID in the format `user/dataset`.\n"
+        "2. Click **Load Columns** to inspect the columns.\n"
+        "3. Pick a **Label column** and **Feature columns**.\n"
+        "4. Adjust hyperparameters and click **Train & Evaluate**.\n"
+        "5. Observe accuracy, feature importances, and a confusion matrix heatmap.\n\n"
+        "*(Note: the dataset must have a `train` split!)*"
+    )
+
+    # Row 1: Dataset selection
     with gr.Row():
         dataset_dropdown = gr.Dropdown(
+            label="Choose suggested dataset",
             choices=SUGGESTED_DATASETS,
-            value=SUGGESTED_DATASETS[0],
-            label="Choose a dataset"
+            value=SUGGESTED_DATASETS[0] # default
+        )
+        custom_dataset_id = gr.Textbox(
+            label="Or enter a custom dataset ID",
+            placeholder="e.g. username/my_custom_dataset"
         )
-        custom_dataset_id = gr.Textbox(label="Or enter HF dataset (user/dataset)", value="",
-                                       placeholder="e.g. 'username/my_custom_dataset'")
-
-    # Button to load columns from the chosen dataset
-    load_cols_btn = gr.Button("Load columns")
+
+    load_cols_btn = gr.Button("Load Columns")
     load_cols_info = gr.Markdown()
-
+
+    # Row 2: label & feature columns
    with gr.Row():
         label_col = gr.Dropdown(choices=[], label="Label column (choose 1)")
         feature_cols = gr.CheckboxGroup(choices=[], label="Feature columns (choose 1 or more)")

-    # Once columns are chosen, we can set hyperparams
+    # Hyperparameters
     learning_rate_slider = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
     n_estimators_slider = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
     max_depth_slider = gr.Slider(1, 10, value=3, step=1, label="max_depth")
-    test_size_slider = gr.Slider(0.1, 0.9, value=0.3, step=0.1, label="test_size (fraction)")
+    test_size_slider = gr.Slider(0.1, 0.9, value=0.3, step=0.1, label="test_size fraction (0.1-0.9)")

     train_button = gr.Button("Train & Evaluate")
-
+
     output_text = gr.Markdown()
     output_plot = gr.Plot()
-    # We might also want to show the columns for reference post-training
-    columns_return = gr.Markdown()

-    # When "Load columns" is clicked, we call update_columns to fetch the dataset columns
+    # Link the "Load Columns" button -> update_columns function
     load_cols_btn.click(
         fn=update_columns,
         inputs=[dataset_dropdown, custom_dataset_id],
-        outputs=[label_col, feature_cols, load_cols_info]
+        outputs=[label_col, feature_cols, load_cols_info],
     )

-    # When "Train & Evaluate" is clicked, we train the model
+    # Link "Train & Evaluate" -> train_model function
     train_button.click(
         fn=train_model,
         inputs=[
@@ -199,7 +214,7 @@ with gr.Blocks() as demo:
             max_depth_slider,
             test_size_slider
         ],
-        outputs=[output_text, output_plot, columns_return]
+        outputs=[output_text, output_plot],
    )

     demo.launch()
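
A quick way to review the new dataset handling outside the Gradio UI is to run the same load-and-train path by hand. The sketch below mirrors what `update_columns` and `train_model` do in this commit: the `scikit-learn/iris` ID is taken from `SUGGESTED_DATASETS`, while the `Species` label column and the numeric-feature selection are assumptions about that dataset's schema, not something the diff guarantees.

# Standalone sanity check for the data path introduced in this commit.
# Assumptions: "scikit-learn/iris" resolves on the Hub with a "train" split,
# and its label column is named "Species" (adjust after inspecting the columns).
import pandas as pd
from datasets import load_dataset
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

ds = load_dataset("scikit-learn/iris", split="train")  # same split the app loads
df = pd.DataFrame(ds)
print(df.columns.tolist())  # roughly what "Load Columns" would report

label_column = "Species"  # assumed label column
feature_columns = [c for c in df.columns
                   if c != label_column and df[c].dtype != object]  # numeric features only

X, y = df[feature_columns].values, df[label_column].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

clf = GradientBoostingClassifier(learning_rate=0.1, n_estimators=100,
                                 max_depth=3, random_state=42)
clf.fit(X_train, y_train)
print("accuracy:", accuracy_score(y_test, clf.predict(X_test)))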