Gil-Simas commited on
Commit
faa98b7
·
2 Parent(s): 5f71887 9f8acd0

Merge branch 'main' of https://huggingface.co/spaces/SEA-AI/user-friendly-metrics into main

Browse files
Files changed (1) hide show
  1. user-friendly-metrics.py +25 -20
user-friendly-metrics.py CHANGED
@@ -165,6 +165,7 @@ class UserFriendlyMetrics(evaluate.Metric):
165
  wandb_project="user_friendly_metrics",
166
  log_plots: bool = True,
167
  debug: bool = False,
 
168
  ):
169
  """
170
  Logs metrics to Weights and Biases (wandb) for tracking and visualization, including categorized bar charts for overall metrics.
@@ -190,9 +191,12 @@ class UserFriendlyMetrics(evaluate.Metric):
190
  self.wandb_run(result = result,
191
  wandb_run_name = wandb_run_name,
192
  wandb_project = wandb_project,
193
- debug = debug)
 
 
 
194
 
195
- def wandb_run(self, result, wandb_run_name, wandb_project, debug, wandb_section = None, log_plots = True):
196
 
197
  run = wandb.init(
198
  project = wandb_project,
@@ -258,24 +262,25 @@ class UserFriendlyMetrics(evaluate.Metric):
258
  }
259
  )
260
 
261
- if "per_sequence" in result:
262
- sorted_sequences = sorted(
263
- result["per_sequence"].items(),
264
- key=lambda x: next(iter(x[1].values()), {}).get("all", {}).get("f1", 0),
265
- reverse=True, # Set to True for descending order
266
- )
267
-
268
- for sequence_name, sequence_data in sorted_sequences:
269
- for metric, value in sequence_data["all"].items():
270
- log_key = (
271
- f"{wandb_section}/per_sequence/{sequence_name}/{metric}"
272
- if wandb_section
273
- else f"per_sequence/{sequence_name}/{metric}"
274
- )
275
- run.log({log_key: value})
276
- if debug:
277
- print(f" {log_key} = {value}")
278
- print("----------------------------------------------------")
 
279
 
280
  if debug:
281
  print("\nDebug Mode: Logging Summary and History")
 
165
  wandb_project="user_friendly_metrics",
166
  log_plots: bool = True,
167
  debug: bool = False,
168
+ log_per_sequence = False
169
  ):
170
  """
171
  Logs metrics to Weights and Biases (wandb) for tracking and visualization, including categorized bar charts for overall metrics.
 
191
  self.wandb_run(result = result,
192
  wandb_run_name = wandb_run_name,
193
  wandb_project = wandb_project,
194
+ debug = debug,
195
+ wandb_section = wandb_section,
196
+ log_plots = log_plots,
197
+ log_per_sequence = log_per_sequence)
198
 
199
+ def wandb_run(self, result, wandb_run_name, wandb_project, debug, wandb_section = None, log_plots = True, log_per_sequence = False):
200
 
201
  run = wandb.init(
202
  project = wandb_project,
 
262
  }
263
  )
264
 
265
+ if log_per_sequence:
266
+ if "per_sequence" in result:
267
+ sorted_sequences = sorted(
268
+ result["per_sequence"].items(),
269
+ key=lambda x: next(iter(x[1].values()), {}).get("all", {}).get("recall", 0),
270
+ reverse=True, # Set to True for descending order
271
+ )
272
+
273
+ for sequence_name, sequence_data in sorted_sequences:
274
+ for metric, value in sequence_data["all"].items():
275
+ log_key = (
276
+ f"{wandb_section}/per_sequence/{sequence_name}/{metric}"
277
+ if wandb_section
278
+ else f"per_sequence/{sequence_name}/{metric}"
279
+ )
280
+ run.log({log_key: value})
281
+ if debug:
282
+ print(f" {log_key} = {value}")
283
+ print("----------------------------------------------------")
284
 
285
  if debug:
286
  print("\nDebug Mode: Logging Summary and History")