C10X commited on
Commit
9949adf
Β·
verified Β·
1 Parent(s): 531701e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -170
app.py CHANGED
@@ -14,7 +14,7 @@ import unicodedata
14
  from multiprocessing import cpu_count
15
  from transformers import LlamaTokenizerFast
16
  import fasttext
17
- from typing import Tuple, Dict, List
18
  import json
19
  import matplotlib.pyplot as plt
20
  import seaborn as sns
@@ -50,13 +50,6 @@ css = """
50
  background-color: #ff8534 !important;
51
  border-color: #ff8534 !important;
52
  }
53
- .gr-button-secondary {
54
- background-color: #475467 !important;
55
- }
56
- #login-button {
57
- background-color: #FFD21E !important;
58
- color: #000000 !important;
59
- }
60
  """
61
 
62
  # HTML templates
@@ -67,44 +60,32 @@ TITLE = """
67
  </div>
68
  """
69
 
70
- # FIXED: Added `color: #444;` to ensure text is visible on the light background.
71
- DESCRIPTION = """
72
- <div style="padding: 20px; background-color: #f0f0f0; border-radius: 10px; margin-bottom: 20px; color: #444;">
73
- <h3>πŸ“‹ How it works:</h3>
74
- <ol>
75
- <li>Choose a dataset from Hugging Face Hub.</li>
76
- <li>The Ultra-FineWeb classifier will score each text sample.</li>
77
- <li>View quality distribution and download the scored dataset.</li>
78
- <li>Optionally, upload the results to a new repository on your Hugging Face account.</li>
79
- </ol>
80
- <p><strong>Note:</strong> The first run will download the model (~347MB), which may take a moment.</p>
81
- </div>
82
  """
83
 
84
  # --- Helper Functions ---
85
  def escape(s: str) -> str:
86
  """Escape HTML for safe display"""
87
- s = str(s).replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;").replace('"', "&quot;").replace("\n", "<br/>")
88
- return s
89
 
90
  def fasttext_preprocess(content: str, tokenizer) -> str:
91
- """Preprocess text for FastText model"""
92
- if not isinstance(content, str):
93
- return ""
94
  content = re.sub(r'\n{3,}', '\n\n', content).lower()
95
- content = ''.join(c for c in unicodedata.normalize('NFKD', content)
96
- if unicodedata.category(c) != 'Mn')
97
  token_ids = tokenizer.encode(content, add_special_tokens=False)
98
- single_text_list = [tokenizer.decode([token_id]) for token_id in token_ids]
99
- content = ' '.join(single_text_list)
100
- content = re.sub(r'\n', ' n ', content)
101
- content = re.sub(r'\r', '', content)
102
- content = re.sub(r'\t', ' ', content)
103
- content = re.sub(r' +', ' ', content).strip()
104
- return content
105
 
106
  def fasttext_infer(norm_content: str, model) -> Tuple[str, float]:
107
- """Run FastText inference"""
108
  pred_label, pred_prob = model.predict(norm_content)
109
  pred_label = pred_label[0]
110
  _score = min(pred_prob.tolist()[0], 1.0)
@@ -113,62 +94,38 @@ def fasttext_infer(norm_content: str, model) -> Tuple[str, float]:
113
  return pred_label, _score
114
 
115
  def load_models():
116
- """Load models with caching"""
117
  global MODEL_LOADED, fasttext_model, tokenizer
118
- if MODEL_LOADED:
119
- return True
120
-
121
  try:
122
  model_dir = MODEL_CACHE_DIR / "Ultra-FineWeb-classifier"
123
  if not model_dir.exists():
124
- print("Downloading Ultra-FineWeb-classifier...")
125
  snapshot_download(repo_id="openbmb/Ultra-FineWeb-classifier", local_dir=str(model_dir), local_dir_use_symlinks=False)
126
-
127
  fasttext_path = model_dir / "classifiers" / "ultra_fineweb_en.bin"
128
  tokenizer_path = model_dir / "local_tokenizer"
129
-
130
- if not fasttext_path.exists():
131
- raise FileNotFoundError(f"FastText model not found at {fasttext_path}")
132
-
133
- print("Loading models...")
134
  fasttext_model = fasttext.load_model(str(fasttext_path))
135
  tokenizer = LlamaTokenizerFast.from_pretrained(str(tokenizer_path) if tokenizer_path.exists() else "meta-llama/Llama-2-7b-hf")
136
-
137
  MODEL_LOADED = True
138
- print("Models loaded successfully!")
139
  return True
140
  except Exception as e:
141
- print(f"Error loading models: {e}")
142
  gr.Warning(f"Failed to load models: {e}")
143
  return False
144
 
145
  def create_quality_plot(scores: List[float], dataset_name: str) -> str:
146
- """Create quality distribution plot"""
147
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
148
  output_path = tmpfile.name
149
-
150
  plt.figure(figsize=(10, 6))
151
- sns.histplot(scores, bins=50, kde=True, color='#6B7FD7', edgecolor='black', line_kws={'linewidth': 2, 'color': 'red'})
152
-
153
- mean_score = np.mean(scores)
154
- median_score = np.median(scores)
155
-
156
  plt.axvline(mean_score, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_score:.3f}')
157
  plt.axvline(median_score, color='orange', linestyle=':', linewidth=2, label=f'Median: {median_score:.3f}')
158
-
159
- plt.xlabel('Quality Score', fontsize=12)
160
- plt.ylabel('Density', fontsize=12)
161
- plt.title(f'Quality Score Distribution - {dataset_name}', fontsize=14, fontweight='bold')
162
- plt.legend()
163
- plt.grid(axis='y', alpha=0.3)
164
- plt.xlim(0, 1)
165
-
166
- plt.tight_layout()
167
- plt.savefig(output_path, dpi=150, bbox_inches='tight')
168
  plt.close()
169
-
170
  return output_path
171
 
 
172
  def process_dataset(
173
  model_id: str,
174
  dataset_split: str,
@@ -176,82 +133,80 @@ def process_dataset(
176
  sample_size: int,
177
  batch_size: int,
178
  progress=gr.Progress(track_tqdm=True)
179
- ) -> Tuple[str, str, str, str, gr.update, gr.update]:
180
- """Process dataset and return results, including visibility updates for UI components."""
181
-
 
 
 
 
 
 
 
182
  try:
183
- progress(0, desc="Loading models...")
 
 
184
  if not load_models():
185
- raise gr.Error("Failed to load scoring models. Please check the logs.")
 
186
 
187
- progress(0.1, desc="Loading dataset...")
188
  dataset = load_dataset(model_id, split=dataset_split, streaming=False)
 
189
 
190
  if text_column not in dataset.column_names:
191
- raise gr.Error(f"Column '{text_column}' not found. Available columns: {', '.join(dataset.column_names)}")
192
 
193
- total_samples = len(dataset)
194
- actual_samples = min(sample_size, total_samples)
195
  dataset = dataset.select(range(actual_samples))
196
 
 
197
  scores, scored_data = [], []
198
-
199
  for i in tqdm(range(0, actual_samples, batch_size), desc="Scoring batches"):
200
  batch = dataset[i:min(i + batch_size, actual_samples)]
201
  for text in batch[text_column]:
202
  norm_content = fasttext_preprocess(text, tokenizer)
203
- label, score = (0.0, "__label__neg") if not norm_content else fasttext_infer(norm_content, fasttext_model)
204
  scores.append(score)
205
  scored_data.append({'text': text, 'quality_score': score, 'predicted_label': label})
206
 
207
- progress(0.9, desc="Generating statistics and plot...")
208
- stats_dict = {
209
- 'dataset_id': model_id,
210
- 'dataset_split': dataset_split,
211
- 'processed_samples': actual_samples,
212
- 'statistics': {
213
- 'mean': float(np.mean(scores)), 'median': float(np.median(scores)),
214
- 'std': float(np.std(scores)), 'min': float(np.min(scores)), 'max': float(np.max(scores)),
215
- 'p90': float(np.percentile(scores, 90)),
216
- },
217
- }
218
 
219
  plot_file = create_quality_plot(scores, model_id.split('/')[-1])
220
 
221
- with tempfile.NamedTemporaryFile(mode='w', suffix=".jsonl", delete=False, encoding='utf-8') as f_out:
222
- output_file_path = f_out.name
223
- for item in scored_data:
224
- f_out.write(json.dumps(item, ensure_ascii=False) + '\n')
225
 
226
- with tempfile.NamedTemporaryFile(mode='w', suffix=".json", delete=False, encoding='utf-8') as f_stats:
227
- stats_file_path = f_stats.name
228
- json.dump(stats_dict, f_stats, indent=2)
229
 
230
- summary_html = f"""
231
- <div style="padding: 15px; background-color: #f9f9f9; border-radius: 10px;">
232
- <h4>βœ… Scoring Completed!</h4>
233
- <p><strong>Dataset:</strong> {escape(model_id)}<br>
234
- <strong>Processed Samples:</strong> {actual_samples:,}<br>
235
- <strong>Mean Score:</strong> {stats_dict['statistics']['mean']:.3f}<br>
236
- <strong>Median Score:</strong> {stats_dict['statistics']['median']:.3f}</p>
237
- </div>
238
  """
239
 
240
- return summary_html, output_file_path, stats_file_path, plot_file, gr.update(visible=True), gr.update(visible=True)
241
-
 
 
 
242
  except Exception as e:
243
- error_html = f"""
244
- <div style="padding: 20px; background-color: #fee; border: 1px solid #d00; border-radius: 10px;">
245
- <h4>❌ Error</h4><pre style="white-space: pre-wrap; font-size: 14px;">{escape(e)}</pre>
246
- </div>
247
- """
248
- return error_html, None, None, None, gr.update(visible=False), gr.update(visible=False)
249
 
 
250
  def upload_to_hub(
251
  scored_file: str, stats_file: str, plot_file: str, new_dataset_id: str,
252
  private: bool, hf_token: str, progress=gr.Progress(track_tqdm=True)
253
  ) -> str:
254
- """Upload results to Hugging Face Hub"""
255
  if not hf_token: return '❌ <span style="color: red;">Please provide your Hugging Face token.</span>'
256
  if not all([scored_file, new_dataset_id]): return '❌ <span style="color: red;">Missing scored file or new dataset ID.</span>'
257
 
@@ -264,95 +219,99 @@ def upload_to_hub(
264
  progress(0.2, desc=f"Creating repo: {repo_id}")
265
  repo_url = create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=private, token=hf_token).repo_url
266
 
267
- progress(0.4, desc="Uploading scored dataset...")
268
  upload_file(path_or_fileobj=scored_file, path_in_repo="data/scored_dataset.jsonl", repo_id=repo_id, repo_type="dataset", token=hf_token)
269
-
270
- progress(0.6, desc="Uploading assets...")
271
  if stats_file and os.path.exists(stats_file):
272
  upload_file(path_or_fileobj=stats_file, path_in_repo="statistics.json", repo_id=repo_id, repo_type="dataset", token=hf_token)
273
  if plot_file and os.path.exists(plot_file):
274
  upload_file(path_or_fileobj=plot_file, path_in_repo="quality_distribution.png", repo_id=repo_id, repo_type="dataset", token=hf_token)
275
 
276
- readme_content = dedent(f"""
277
- ---
278
- license: apache-2.0
279
- ---
280
- # Quality-Scored Dataset: {repo_id.split('/')[-1]}
281
- This dataset was scored for quality using the [Dataset Quality Scorer Space](https://huggingface.co/spaces/ggml-org/dataset-quality-scorer).
282
- ![Quality Distribution](quality_distribution.png)
283
- ## Usage
284
- ```python
285
- from datasets import load_dataset
286
- dataset = load_dataset("{repo_id}", split="train")
287
- ```
288
- """).strip()
289
 
290
  upload_file(path_or_fileobj=readme_content.encode(), path_in_repo="README.md", repo_id=repo_id, repo_type="dataset", token=hf_token)
291
  progress(1.0, "Done!")
292
  return f'βœ… <span style="color: green;">Successfully uploaded to <a href="{repo_url}" target="_blank">{repo_id}</a></span>'
293
 
294
  except Exception as e:
295
- return f'❌ <span style="color: red;">Upload failed: {escape(e)}</span>'
 
296
 
297
  def create_demo():
298
  with gr.Blocks(css=css, title="Dataset Quality Scorer") as demo:
299
  gr.HTML(TITLE)
300
- gr.HTML(DESCRIPTION)
301
 
302
- gr.Markdown("### 1. Configure & Score Dataset")
303
  with gr.Row():
304
  with gr.Column(scale=3):
 
305
  dataset_search = HuggingfaceHubSearch(label="Hub Dataset ID", search_type="dataset", value="roneneldan/TinyStories")
306
- text_column = gr.Textbox(label="Text Column Name", value="text", info="The column containing the text to score.")
307
  with gr.Column(scale=2):
 
308
  dataset_split = gr.Dropdown(["train", "validation", "test"], label="Split", value="train")
309
  with gr.Row():
310
- sample_size = gr.Number(label="Sample Size", value=1000, minimum=100, step=100, info="Max samples.")
311
- batch_size = gr.Number(label="Batch Size", value=32, minimum=1, step=1, info="Processing batch.")
312
 
 
 
313
  with gr.Row():
314
  clear_btn = gr.Button("Clear", variant="secondary")
315
  process_btn = gr.Button("πŸš€ Start Scoring", variant="primary", size="lg")
316
 
317
  # --- Results and Upload Sections (Initially Hidden) ---
318
- with gr.Accordion("βœ… Results", open=True, visible=False) as results_accordion:
319
- gr.Markdown("### 2. Review Results")
320
  with gr.Row():
321
- with gr.Column(scale=2):
322
- summary_output = gr.HTML(label="Summary")
 
 
323
  with gr.Column(scale=1):
324
  plot_output = gr.Image(label="Quality Distribution", show_label=True)
325
- with gr.Row():
326
- scored_file_output = gr.File(label="πŸ“„ Download Scored Dataset (.jsonl)", type="filepath")
327
- stats_file_output = gr.File(label="πŸ“Š Download Statistics (.json)", type="filepath")
328
 
329
- with gr.Accordion("☁️ Upload to Hub", open=False, visible=False) as upload_accordion:
330
- gr.Markdown("### 3. (Optional) Upload to Hugging Face Hub")
331
- hf_token_input = gr.Textbox(label="Hugging Face Token", type="password", placeholder="hf_...", value=HF_TOKEN or "", info="Your HF token with 'write' permissions.")
332
- new_dataset_id = gr.Textbox(label="New Dataset Name", placeholder="my-scored-dataset", info="Will be created under your username.")
333
  private_checkbox = gr.Checkbox(label="Make dataset private", value=False)
334
  upload_btn = gr.Button("πŸ“€ Upload to Hub", variant="primary")
335
  upload_status = gr.HTML()
336
 
337
  # --- Event Handlers ---
338
  def clear_form():
339
- return "roneneldan/TinyStories", "train", "text", 1000, 32, None, None, None, None, gr.update(visible=False), gr.update(visible=False), ""
 
 
 
 
 
 
 
 
 
 
 
340
 
341
  clear_btn.click(
342
  fn=clear_form,
343
  outputs=[
344
  dataset_search, dataset_split, text_column, sample_size, batch_size,
345
- summary_output, scored_file_output, stats_file_output, plot_output,
346
- results_accordion, upload_accordion, upload_status
347
  ]
348
  )
349
 
350
- process_btn.click(
351
- fn=process_dataset,
352
- inputs=[dataset_search, dataset_split, text_column, sample_size, batch_size],
353
- outputs=[summary_output, scored_file_output, stats_file_output, plot_output, results_accordion, upload_accordion]
354
- )
355
-
356
  upload_btn.click(
357
  fn=upload_to_hub,
358
  inputs=[scored_file_output, stats_file_output, plot_output, new_dataset_id, private_checkbox, hf_token_input],
@@ -363,20 +322,5 @@ def create_demo():
363
  # --- App Execution ---
364
  demo = create_demo()
365
 
366
- if os.environ.get("SPACE_ID"):
367
- def restart_space():
368
- if HF_TOKEN:
369
- try:
370
- print("Scheduler: Triggering space restart...")
371
- api = HfApi()
372
- api.restart_space(repo_id=os.environ["SPACE_ID"], token=HF_TOKEN)
373
- except Exception as e:
374
- print(f"Scheduler: Failed to restart space: {e}")
375
-
376
- scheduler = BackgroundScheduler()
377
- scheduler.add_job(restart_space, "interval", hours=6)
378
- scheduler.start()
379
- print("Background scheduler for periodic restarts is active.")
380
-
381
  if __name__ == "__main__":
382
  demo.queue().launch(debug=False, show_api=False)
 
14
  from multiprocessing import cpu_count
15
  from transformers import LlamaTokenizerFast
16
  import fasttext
17
+ from typing import Tuple, Dict, List, Generator
18
  import json
19
  import matplotlib.pyplot as plt
20
  import seaborn as sns
 
50
  background-color: #ff8534 !important;
51
  border-color: #ff8534 !important;
52
  }
 
 
 
 
 
 
 
53
  """
54
 
55
  # HTML templates
 
60
  </div>
61
  """
62
 
63
+ # Switched to Markdown for better theme compatibility (dark/light mode)
64
+ DESCRIPTION_MD = """
65
+ ### πŸ“‹ How it works:
66
+ 1. Choose a dataset from Hugging Face Hub.
67
+ 2. The Ultra-FineWeb classifier will score each text sample.
68
+ 3. View quality distribution and download the scored dataset.
69
+ 4. Optionally, upload the results to a new repository on your Hugging Face account.
70
+
71
+ **Note:** The first run will download the model (~347MB), which may take a moment.
 
 
 
72
  """
73
 
74
  # --- Helper Functions ---
75
  def escape(s: str) -> str:
76
  """Escape HTML for safe display"""
77
+ return str(s).replace("&", "&").replace("<", "<").replace(">", ">").replace('"', """).replace("\n", "<br/>")
 
78
 
79
  def fasttext_preprocess(content: str, tokenizer) -> str:
80
+ if not isinstance(content, str): return ""
 
 
81
  content = re.sub(r'\n{3,}', '\n\n', content).lower()
82
+ content = ''.join(c for c in unicodedata.normalize('NFKD', content) if unicodedata.category(c) != 'Mn')
 
83
  token_ids = tokenizer.encode(content, add_special_tokens=False)
84
+ content = ' '.join([tokenizer.decode([token_id]) for token_id in token_ids])
85
+ content = re.sub(r'\n', ' n ', content).replace('\r', '').replace('\t', ' ')
86
+ return re.sub(r' +', ' ', content).strip()
 
 
 
 
87
 
88
  def fasttext_infer(norm_content: str, model) -> Tuple[str, float]:
 
89
  pred_label, pred_prob = model.predict(norm_content)
90
  pred_label = pred_label[0]
91
  _score = min(pred_prob.tolist()[0], 1.0)
 
94
  return pred_label, _score
95
 
96
  def load_models():
 
97
  global MODEL_LOADED, fasttext_model, tokenizer
98
+ if MODEL_LOADED: return True
 
 
99
  try:
100
  model_dir = MODEL_CACHE_DIR / "Ultra-FineWeb-classifier"
101
  if not model_dir.exists():
 
102
  snapshot_download(repo_id="openbmb/Ultra-FineWeb-classifier", local_dir=str(model_dir), local_dir_use_symlinks=False)
 
103
  fasttext_path = model_dir / "classifiers" / "ultra_fineweb_en.bin"
104
  tokenizer_path = model_dir / "local_tokenizer"
 
 
 
 
 
105
  fasttext_model = fasttext.load_model(str(fasttext_path))
106
  tokenizer = LlamaTokenizerFast.from_pretrained(str(tokenizer_path) if tokenizer_path.exists() else "meta-llama/Llama-2-7b-hf")
 
107
  MODEL_LOADED = True
 
108
  return True
109
  except Exception as e:
 
110
  gr.Warning(f"Failed to load models: {e}")
111
  return False
112
 
113
  def create_quality_plot(scores: List[float], dataset_name: str) -> str:
 
114
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
115
  output_path = tmpfile.name
 
116
  plt.figure(figsize=(10, 6))
117
+ sns.histplot(scores, bins=50, kde=True, color='#6B7FD7', edgecolor='black')
118
+ mean_score, median_score = np.mean(scores), np.median(scores)
 
 
 
119
  plt.axvline(mean_score, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_score:.3f}')
120
  plt.axvline(median_score, color='orange', linestyle=':', linewidth=2, label=f'Median: {median_score:.3f}')
121
+ plt.xlabel('Quality Score'); plt.ylabel('Density')
122
+ plt.title(f'Quality Score Distribution - {dataset_name}', fontweight='bold')
123
+ plt.legend(); plt.grid(axis='y', alpha=0.3); plt.xlim(0, 1)
124
+ plt.tight_layout(); plt.savefig(output_path, dpi=150)
 
 
 
 
 
 
125
  plt.close()
 
126
  return output_path
127
 
128
+ # UPDATED: This function is now a generator to yield live log updates.
129
  def process_dataset(
130
  model_id: str,
131
  dataset_split: str,
 
133
  sample_size: int,
134
  batch_size: int,
135
  progress=gr.Progress(track_tqdm=True)
136
+ ) -> Generator:
137
+ log_text = ""
138
+ # Helper to update and yield log messages
139
+ def update_log(msg):
140
+ nonlocal log_text
141
+ timestamp = datetime.now().strftime('%H:%M:%S')
142
+ log_text += f"[{timestamp}] {msg}\n"
143
+ # Yield updates for the log, keep other components hidden/empty
144
+ return (log_text, None, None, None, None, gr.update(visible=False), gr.update(visible=False))
145
+
146
  try:
147
+ yield update_log("Starting process...")
148
+
149
+ yield update_log("Loading scoring models...")
150
  if not load_models():
151
+ raise gr.Error("Failed to load scoring models. Please check logs.")
152
+ yield update_log("Models loaded successfully.")
153
 
154
+ yield update_log(f"Loading dataset '{model_id}' split '{dataset_split}'...")
155
  dataset = load_dataset(model_id, split=dataset_split, streaming=False)
156
+ yield update_log("Dataset loaded.")
157
 
158
  if text_column not in dataset.column_names:
159
+ raise gr.Error(f"Column '{text_column}' not found. Available: {', '.join(dataset.column_names)}")
160
 
161
+ actual_samples = min(sample_size, len(dataset))
 
162
  dataset = dataset.select(range(actual_samples))
163
 
164
+ yield update_log(f"Starting to score {actual_samples:,} samples...")
165
  scores, scored_data = [], []
 
166
  for i in tqdm(range(0, actual_samples, batch_size), desc="Scoring batches"):
167
  batch = dataset[i:min(i + batch_size, actual_samples)]
168
  for text in batch[text_column]:
169
  norm_content = fasttext_preprocess(text, tokenizer)
170
+ label, score = fasttext_infer(norm_content, fasttext_model) if norm_content else ("__label__neg", 0.0)
171
  scores.append(score)
172
  scored_data.append({'text': text, 'quality_score': score, 'predicted_label': label})
173
 
174
+ yield update_log("Scoring complete. Generating results and plot...")
175
+ stats_dict = {'dataset_id': model_id, 'processed_samples': actual_samples, 'statistics': {'mean': float(np.mean(scores)), 'median': float(np.median(scores))}}
 
 
 
 
 
 
 
 
 
176
 
177
  plot_file = create_quality_plot(scores, model_id.split('/')[-1])
178
 
179
+ with tempfile.NamedTemporaryFile('w', suffix=".jsonl", delete=False, encoding='utf-8') as f:
180
+ output_file_path = f.name
181
+ for item in scored_data: f.write(json.dumps(item, ensure_ascii=False) + '\n')
 
182
 
183
+ with tempfile.NamedTemporaryFile('w', suffix=".json", delete=False, encoding='utf-8') as f:
184
+ stats_file_path = f.name
185
+ json.dump(stats_dict, f, indent=2)
186
 
187
+ summary_md = f"""
188
+ #### βœ… Scoring Completed!
189
+ - **Dataset:** `{model_id}`
190
+ - **Processed Samples:** `{actual_samples:,}`
191
+ - **Mean Score:** `{stats_dict['statistics']['mean']:.3f}`
192
+ - **Median Score:** `{stats_dict['statistics']['median']:.3f}`
 
 
193
  """
194
 
195
+ yield update_log("Process finished successfully!")
196
+
197
+ # Final return with all components visible and populated
198
+ yield (log_text, summary_md, output_file_path, stats_file_path, plot_file, gr.update(visible=True), gr.update(visible=True))
199
+
200
  except Exception as e:
201
+ error_log = update_log(f"ERROR: {e}")[0]
202
+ error_summary_md = f"### ❌ Error\n```\n{escape(str(e))}\n```"
203
+ yield (error_log, error_summary_md, None, None, None, gr.update(visible=True), gr.update(visible=False))
 
 
 
204
 
205
+ # This function remains the same
206
  def upload_to_hub(
207
  scored_file: str, stats_file: str, plot_file: str, new_dataset_id: str,
208
  private: bool, hf_token: str, progress=gr.Progress(track_tqdm=True)
209
  ) -> str:
 
210
  if not hf_token: return '❌ <span style="color: red;">Please provide your Hugging Face token.</span>'
211
  if not all([scored_file, new_dataset_id]): return '❌ <span style="color: red;">Missing scored file or new dataset ID.</span>'
212
 
 
219
  progress(0.2, desc=f"Creating repo: {repo_id}")
220
  repo_url = create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=private, token=hf_token).repo_url
221
 
222
+ progress(0.4, desc="Uploading files...")
223
  upload_file(path_or_fileobj=scored_file, path_in_repo="data/scored_dataset.jsonl", repo_id=repo_id, repo_type="dataset", token=hf_token)
 
 
224
  if stats_file and os.path.exists(stats_file):
225
  upload_file(path_or_fileobj=stats_file, path_in_repo="statistics.json", repo_id=repo_id, repo_type="dataset", token=hf_token)
226
  if plot_file and os.path.exists(plot_file):
227
  upload_file(path_or_fileobj=plot_file, path_in_repo="quality_distribution.png", repo_id=repo_id, repo_type="dataset", token=hf_token)
228
 
229
+ readme_content = dedent(f"""---
230
+ license: apache-2.0
231
+ ---
232
+ # Quality-Scored Dataset: {repo_id.split('/')[-1]}
233
+ This dataset was scored for quality using the [Dataset Quality Scorer Space](https://huggingface.co/spaces/ggml-org/dataset-quality-scorer).
234
+ ![Quality Distribution](quality_distribution.png)
235
+ ## Usage
236
+ ```python
237
+ from datasets import load_dataset
238
+ dataset = load_dataset("{repo_id}", split="train")
239
+ ```""").strip()
 
 
240
 
241
  upload_file(path_or_fileobj=readme_content.encode(), path_in_repo="README.md", repo_id=repo_id, repo_type="dataset", token=hf_token)
242
  progress(1.0, "Done!")
243
  return f'βœ… <span style="color: green;">Successfully uploaded to <a href="{repo_url}" target="_blank">{repo_id}</a></span>'
244
 
245
  except Exception as e:
246
+ return f'❌ <span style="color: red;">Upload failed: {escape(str(e))}</span>'
247
+
248
 
249
  def create_demo():
250
  with gr.Blocks(css=css, title="Dataset Quality Scorer") as demo:
251
  gr.HTML(TITLE)
252
+ gr.Markdown(DESCRIPTION_MD)
253
 
 
254
  with gr.Row():
255
  with gr.Column(scale=3):
256
+ gr.Markdown("### 1. Configure Dataset")
257
  dataset_search = HuggingfaceHubSearch(label="Hub Dataset ID", search_type="dataset", value="roneneldan/TinyStories")
258
+ text_column = gr.Textbox(label="Text Column Name", value="text")
259
  with gr.Column(scale=2):
260
+ gr.Markdown("### 2. Configure Scoring")
261
  dataset_split = gr.Dropdown(["train", "validation", "test"], label="Split", value="train")
262
  with gr.Row():
263
+ sample_size = gr.Number(label="Sample Size", value=1000, minimum=100, step=100)
264
+ batch_size = gr.Number(label="Batch Size", value=32, minimum=1, step=1)
265
 
266
+ live_log = gr.Textbox(label="Live Log", interactive=False, lines=8, max_lines=20)
267
+
268
  with gr.Row():
269
  clear_btn = gr.Button("Clear", variant="secondary")
270
  process_btn = gr.Button("πŸš€ Start Scoring", variant="primary", size="lg")
271
 
272
  # --- Results and Upload Sections (Initially Hidden) ---
273
+ with gr.Group(visible=False) as results_group:
274
+ gr.Markdown("--- \n ### 3. Review Results")
275
  with gr.Row():
276
+ with gr.Column(scale=1):
277
+ summary_output = gr.Markdown(label="Summary")
278
+ scored_file_output = gr.File(label="πŸ“„ Download Scored Dataset (.jsonl)", type="filepath")
279
+ stats_file_output = gr.File(label="πŸ“Š Download Statistics (.json)", type="filepath")
280
  with gr.Column(scale=1):
281
  plot_output = gr.Image(label="Quality Distribution", show_label=True)
 
 
 
282
 
283
+ with gr.Group(visible=False) as upload_group:
284
+ gr.Markdown("--- \n ### 4. (Optional) Upload to Hugging Face Hub")
285
+ hf_token_input = gr.Textbox(label="Hugging Face Token", type="password", placeholder="hf_...", value=HF_TOKEN or "")
286
+ new_dataset_id = gr.Textbox(label="New Dataset Name", placeholder="my-scored-dataset")
287
  private_checkbox = gr.Checkbox(label="Make dataset private", value=False)
288
  upload_btn = gr.Button("πŸ“€ Upload to Hub", variant="primary")
289
  upload_status = gr.HTML()
290
 
291
  # --- Event Handlers ---
292
  def clear_form():
293
+ return "roneneldan/TinyStories", "train", "text", 1000, 32, "", None, None, None, None, gr.update(visible=False), gr.update(visible=False), ""
294
+
295
+ outputs_list = [
296
+ live_log, summary_output, scored_file_output, stats_file_output, plot_output,
297
+ results_group, upload_group
298
+ ]
299
+
300
+ process_btn.click(
301
+ fn=process_dataset,
302
+ inputs=[dataset_search, dataset_split, text_column, sample_size, batch_size],
303
+ outputs=outputs_list
304
+ )
305
 
306
  clear_btn.click(
307
  fn=clear_form,
308
  outputs=[
309
  dataset_search, dataset_split, text_column, sample_size, batch_size,
310
+ live_log, summary_output, scored_file_output, stats_file_output, plot_output,
311
+ results_group, upload_group, upload_status
312
  ]
313
  )
314
 
 
 
 
 
 
 
315
  upload_btn.click(
316
  fn=upload_to_hub,
317
  inputs=[scored_file_output, stats_file_output, plot_output, new_dataset_id, private_checkbox, hf_token_input],
 
322
  # --- App Execution ---
323
  demo = create_demo()
324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  if __name__ == "__main__":
326
  demo.queue().launch(debug=False, show_api=False)