Spaces:
Runtime error
Runtime error
rusticluftig
commited on
Commit
·
8e90a67
1
Parent(s):
877e6fb
Fix how "stale" is computed when there are multiple competitions
Browse files
utils.py
CHANGED
@@ -6,7 +6,6 @@ import math
|
|
6 |
import os
|
7 |
import time
|
8 |
import traceback
|
9 |
-
from collections import defaultdict
|
10 |
from dataclasses import dataclass
|
11 |
from typing import Any, Dict, List, Optional, Tuple
|
12 |
|
@@ -18,8 +17,6 @@ from bittensor.extrinsics.serving import get_metadata
|
|
18 |
from dotenv import load_dotenv
|
19 |
from wandb.apis.public.history import HistoryScan
|
20 |
|
21 |
-
import competitions
|
22 |
-
|
23 |
NETUID = 37
|
24 |
DELAY_SECS = 3
|
25 |
RETRIES = 3
|
@@ -181,8 +178,15 @@ def get_scores(
|
|
181 |
uids (List[int]): List of UIDs to get scores for.
|
182 |
wandb_runs (List): List of validator runs from Wandb. Requires the runs are provided in descending order.
|
183 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
result = {}
|
185 |
previous_timestamp = None
|
|
|
186 |
# Iterate through the runs until we've processed all the uids.
|
187 |
for i, run in enumerate(wandb_runs):
|
188 |
if not "original_format_json" in run.summary:
|
@@ -196,21 +200,23 @@ def get_scores(
|
|
196 |
), f"Timestamps are not in descending order: {timestamp} >= {previous_timestamp}"
|
197 |
previous_timestamp = timestamp
|
198 |
|
|
|
199 |
for uid in uids:
|
200 |
if uid in result:
|
201 |
continue
|
202 |
if str(uid) in all_uid_data:
|
203 |
uid_data = all_uid_data[str(uid)]
|
204 |
-
# Only the most recent run is fresh.
|
205 |
-
is_fresh =
|
206 |
result[uid] = {
|
207 |
-
"avg_loss": uid_data.get("average_loss", None),
|
208 |
"win_rate": uid_data.get("win_rate", None),
|
209 |
"win_total": uid_data.get("win_total", None),
|
210 |
"weight": uid_data.get("weight", None),
|
211 |
"competition_id": uid_data.get("competition_id", None),
|
212 |
"fresh": is_fresh,
|
213 |
}
|
|
|
214 |
if len(result) == len(uids):
|
215 |
break
|
216 |
return result
|
@@ -266,7 +272,7 @@ def get_losses_over_time(wandb_runs: List, competition_id: int) -> pd.DataFrame:
|
|
266 |
continue
|
267 |
|
268 |
if loss < best_loss:
|
269 |
-
best_loss =
|
270 |
should_add_datapoint = True
|
271 |
# Now that we've processed the run's most recent steps, check if we should add a datapoint.
|
272 |
if should_add_datapoint:
|
|
|
6 |
import os
|
7 |
import time
|
8 |
import traceback
|
|
|
9 |
from dataclasses import dataclass
|
10 |
from typing import Any, Dict, List, Optional, Tuple
|
11 |
|
|
|
17 |
from dotenv import load_dotenv
|
18 |
from wandb.apis.public.history import HistoryScan
|
19 |
|
|
|
|
|
20 |
NETUID = 37
|
21 |
DELAY_SECS = 3
|
22 |
RETRIES = 3
|
|
|
178 |
uids (List[int]): List of UIDs to get scores for.
|
179 |
wandb_runs (List): List of validator runs from Wandb. Requires the runs are provided in descending order.
|
180 |
"""
|
181 |
+
def _maybe_convert_loss(loss: float, comp_id: int) -> float:
|
182 |
+
"""Converts loss to score for competitions that require it."""
|
183 |
+
if comp_id == 2:
|
184 |
+
return 1 - loss if loss else None
|
185 |
+
return loss
|
186 |
+
|
187 |
result = {}
|
188 |
previous_timestamp = None
|
189 |
+
seen_competitions = set()
|
190 |
# Iterate through the runs until we've processed all the uids.
|
191 |
for i, run in enumerate(wandb_runs):
|
192 |
if not "original_format_json" in run.summary:
|
|
|
200 |
), f"Timestamps are not in descending order: {timestamp} >= {previous_timestamp}"
|
201 |
previous_timestamp = timestamp
|
202 |
|
203 |
+
comp_id = data.get("competition_id", None)
|
204 |
for uid in uids:
|
205 |
if uid in result:
|
206 |
continue
|
207 |
if str(uid) in all_uid_data:
|
208 |
uid_data = all_uid_data[str(uid)]
|
209 |
+
# Only the most recent run per competition is fresh.
|
210 |
+
is_fresh = comp_id not in seen_competitions
|
211 |
result[uid] = {
|
212 |
+
"avg_loss": _maybe_convert_loss(uid_data.get("average_loss", None), comp_id),
|
213 |
"win_rate": uid_data.get("win_rate", None),
|
214 |
"win_total": uid_data.get("win_total", None),
|
215 |
"weight": uid_data.get("weight", None),
|
216 |
"competition_id": uid_data.get("competition_id", None),
|
217 |
"fresh": is_fresh,
|
218 |
}
|
219 |
+
seen_competitions.add(comp_id)
|
220 |
if len(result) == len(uids):
|
221 |
break
|
222 |
return result
|
|
|
272 |
continue
|
273 |
|
274 |
if loss < best_loss:
|
275 |
+
best_loss = loss
|
276 |
should_add_datapoint = True
|
277 |
# Now that we've processed the run's most recent steps, check if we should add a datapoint.
|
278 |
if should_add_datapoint:
|