voxmenthe committed on
Commit
3357f2e
·
1 Parent(s): be92e89

try again on gradio progress update streaming

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. evaluation.py +12 -3
app.py CHANGED
@@ -108,7 +108,7 @@ def run_full_evaluation_gradio():
108
  for update in evaluate(model, test_dataloader_full, device):
109
  if isinstance(update, dict):
110
  # This is the final results dictionary
111
- results_str = "--- Full Evaluation Results ---\n"
112
  for key, value in update.items():
113
  if isinstance(value, float):
114
  results_str += f"{key.capitalize()}: {value:.4f}\n"
@@ -120,7 +120,7 @@ def run_full_evaluation_gradio():
120
  break # Stop after getting the results dict
121
  else:
122
  # This is a progress string
123
- yield update
124
 
125
  # Ensure the final formatted results string is yielded if not already (e.g., if loop broke early)
126
  # However, the logic above should yield it before breaking.
 
108
  for update in evaluate(model, test_dataloader_full, device):
109
  if isinstance(update, dict):
110
  # This is the final results dictionary
111
+ results_str = "\n--- Full Evaluation Results ---\n" # Start with a newline
112
  for key, value in update.items():
113
  if isinstance(value, float):
114
  results_str += f"{key.capitalize()}: {value:.4f}\n"
 
120
  break # Stop after getting the results dict
121
  else:
122
  # This is a progress string
123
+ yield str(update) + "\n" # Append newline to each progress string
124
 
125
  # Ensure the final formatted results string is yielded if not already (e.g., if loop broke early)
126
  # However, the logic above should yield it before breaking.
evaluation.py CHANGED
@@ -13,8 +13,6 @@ def evaluate(model, dataloader, device):
13
  num_batches = len(dataloader)
14
  processed_batches = 0
15
 
16
- yield "Starting evaluation..."
17
-
18
  with torch.no_grad():
19
  for batch in dataloader: # dataloader here should not be pre-wrapped with tqdm by the caller if we yield progress
20
  processed_batches += 1
@@ -54,8 +52,19 @@ def evaluate(model, dataloader, device):
54
  all_preds.extend(preds.cpu().numpy())
55
 
56
  all_labels.extend(labels.cpu().numpy())
 
 
 
 
 
 
 
 
 
 
57
  # Yield progress update
58
- if processed_batches % (num_batches // 20) == 0 or processed_batches == num_batches: # Update roughly 20 times + final
 
59
  yield f"Processed {processed_batches}/{num_batches} batches ({processed_batches/num_batches*100:.2f}%)"
60
 
61
  avg_loss = total_loss / num_batches
 
13
  num_batches = len(dataloader)
14
  processed_batches = 0
15
 
 
 
16
  with torch.no_grad():
17
  for batch in dataloader: # dataloader here should not be pre-wrapped with tqdm by the caller if we yield progress
18
  processed_batches += 1
 
52
  all_preds.extend(preds.cpu().numpy())
53
 
54
  all_labels.extend(labels.cpu().numpy())
55
+
56
+ # Populate probabilities for AUC calculation
57
+ if logits.shape[1] > 1:
58
+ # Multi-class or multi-label, assuming positive class is at index 1 for binary-like AUC
59
+ probs_for_auc = torch.softmax(logits, dim=1)[:, 1]
60
+ else:
61
+ # Binary classification with a single logit output
62
+ probs_for_auc = torch.sigmoid(logits).squeeze()
63
+ all_probs_for_auc.extend(probs_for_auc.cpu().numpy())
64
+
65
  # Yield progress update
66
+ progress_update_frequency = max(1, num_batches // 20) # Ensure at least 1 to avoid modulo zero
67
+ if processed_batches % progress_update_frequency == 0 or processed_batches == num_batches: # Update roughly 20 times + final
68
  yield f"Processed {processed_batches}/{num_batches} batches ({processed_batches/num_batches*100:.2f}%)"
69
 
70
  avg_loss = total_loss / num_batches