try again on gradio progress update streaming
Files changed:
- app.py (+2, -2)
- evaluation.py (+12, -3)
app.py
CHANGED
@@ -108,7 +108,7 @@ def run_full_evaluation_gradio():
     for update in evaluate(model, test_dataloader_full, device):
         if isinstance(update, dict):
             # This is the final results dictionary
-            results_str = "--- Full Evaluation Results ---\n"
+            results_str = "\n--- Full Evaluation Results ---\n"  # Start with a newline
             for key, value in update.items():
                 if isinstance(value, float):
                     results_str += f"{key.capitalize()}: {value:.4f}\n"
@@ -120,7 +120,7 @@ def run_full_evaluation_gradio():
             break  # Stop after getting the results dict
         else:
             # This is a progress string
-            yield update
+            yield str(update) + "\n"  # Append newline to each progress string

     # Ensure the final formatted results string is yielded if not already (e.g., if loop broke early)
     # However, the logic above should yield it before breaking.
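The app.py side of this change relies on Gradio's generator streaming: when a click handler is a generator, each yielded string is pushed to the output component as it arrives. Below is a minimal, self-contained sketch of that wiring, assuming a Blocks layout with a Button and a Textbox; the component names (output_box, run_btn) and the stand-in generator body are illustrative assumptions, not the Space's actual app.py.

import gradio as gr

def run_full_evaluation_gradio():
    # Stand-in generator: the real function wraps evaluate(...) and yields
    # progress strings first, then the formatted results block.
    yield "Processed 10/100 batches (10.00%)\n"
    yield "Processed 20/100 batches (20.00%)\n"
    yield "\n--- Full Evaluation Results ---\nAccuracy: 0.9100\n"

with gr.Blocks() as demo:
    # Assumed layout: one button that triggers evaluation, one textbox that
    # shows whatever the generator yields.
    output_box = gr.Textbox(label="Evaluation log", lines=15)
    run_btn = gr.Button("Run full evaluation")
    run_btn.click(fn=run_full_evaluation_gradio, inputs=None, outputs=output_box)

demo.launch()

Note that each yield replaces the Textbox's current value rather than appending to it, which is why app.py assembles the whole results_str before yielding it.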
evaluation.py
CHANGED
@@ -13,8 +13,6 @@ def evaluate(model, dataloader, device):
     num_batches = len(dataloader)
     processed_batches = 0

-    yield "Starting evaluation..."
-
     with torch.no_grad():
         for batch in dataloader:  # dataloader here should not be pre-wrapped with tqdm by the caller if we yield progress
             processed_batches += 1
@@ -54,8 +52,19 @@ def evaluate(model, dataloader, device):
             all_preds.extend(preds.cpu().numpy())

             all_labels.extend(labels.cpu().numpy())
+
+            # Populate probabilities for AUC calculation
+            if logits.shape[1] > 1:
+                # Multi-class or multi-label, assuming positive class is at index 1 for binary-like AUC
+                probs_for_auc = torch.softmax(logits, dim=1)[:, 1]
+            else:
+                # Binary classification with a single logit output
+                probs_for_auc = torch.sigmoid(logits).squeeze()
+            all_probs_for_auc.extend(probs_for_auc.cpu().numpy())
+
             # Yield progress update
-
+            progress_update_frequency = max(1, num_batches // 20)  # Ensure at least 1 to avoid modulo zero
+            if processed_batches % progress_update_frequency == 0 or processed_batches == num_batches:  # Update roughly 20 times + final
                 yield f"Processed {processed_batches}/{num_batches} batches ({processed_batches/num_batches*100:.2f}%)"

     avg_loss = total_loss / num_batches
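For context on where all_probs_for_auc ends up: once the loop finishes, evaluate() can aggregate the accumulated lists into the final metrics dictionary that app.py detects with isinstance(update, dict). The sketch below shows one way to do that with scikit-learn; the helper name finish_evaluation and the exact metric keys are assumptions for illustration (in the real generator this would simply be the last yield).

from sklearn.metrics import accuracy_score, roc_auc_score

def finish_evaluation(total_loss, num_batches, all_labels, all_preds, all_probs_for_auc):
    # Aggregate the per-batch accumulators into the final metrics dict.
    avg_loss = total_loss / num_batches
    results = {
        "loss": avg_loss,
        "accuracy": accuracy_score(all_labels, all_preds),
    }
    try:
        # AUC uses the positive-class probabilities collected in the loop; it
        # raises ValueError when all_labels contains only one class.
        results["auc"] = roc_auc_score(all_labels, all_probs_for_auc)
    except ValueError:
        results["auc"] = float("nan")
    return results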