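"""Client-side `Run` object for Trackio.

Queues metric logs and media uploads locally and sends them in batches to the
Trackio server (a local dashboard or a Hugging Face Space) from a background
thread.
"""
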
import threading
import time

import huggingface_hub
from gradio_client import Client, handle_file

from trackio.media import TrackioImage
from trackio.sqlite_storage import SQLiteStorage
from trackio.typehints import LogEntry, UploadEntry
from trackio.utils import RESERVED_KEYS, fibo, generate_readable_name

BATCH_SEND_INTERVAL = 0.5


class Run:
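    """A single tracked run within a Trackio project.

    Metrics passed to `log()` are queued locally and flushed to the server by a
    background thread every `BATCH_SEND_INTERVAL` seconds. A minimal usage
    sketch (the URL is illustrative and assumes a Trackio dashboard is already
    running there):

        run = Run(url="http://127.0.0.1:7860", project="my-project", client=None)
        run.log({"loss": 0.25}, step=1)
        run.finish()
    """
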
    def __init__(
        self,
        url: str,
        project: str,
        client: Client | None,
        name: str | None = None,
        config: dict | None = None,
        space_id: str | None = None,
    ):
        self.url = url
        self.project = project
        self._client_lock = threading.Lock()
        self._client_thread = None
        self._client = client
        self._space_id = space_id
        self.name = name or generate_readable_name(
            SQLiteStorage.get_runs(project), space_id
        )
        self.config = config or {}
        self._queued_logs: list[LogEntry] = []
        self._queued_uploads: list[UploadEntry] = []
        self._stop_flag = threading.Event()

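        # Connect to the server (and then run the batch sender) on a daemon
        # thread, so constructing a Run does not block the caller.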
        self._client_thread = threading.Thread(target=self._init_client_background)
        self._client_thread.daemon = True
        self._client_thread.start()

    def _batch_sender(self):
        """Send batched logs every BATCH_SEND_INTERVAL."""
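        # Keep looping until a stop has been requested *and* the log queue has
        # been drained, so metrics logged just before finish() are still sent.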
        while not self._stop_flag.is_set() or len(self._queued_logs) > 0:
            if not self._stop_flag.is_set():
                time.sleep(BATCH_SEND_INTERVAL)

            with self._client_lock:
                if self._queued_logs and self._client is not None:
                    logs_to_send = self._queued_logs.copy()
                    self._queued_logs.clear()
                    self._client.predict(
                        api_name="/bulk_log",
                        logs=logs_to_send,
                        hf_token=huggingface_hub.utils.get_token(),
                    )
                if self._queued_uploads and self._client is not None:
                    uploads_to_send = self._queued_uploads.copy()
                    self._queued_uploads.clear()
                    self._client.predict(
                        api_name="/bulk_upload_media",
                        uploads=uploads_to_send,
                        hf_token=huggingface_hub.utils.get_token(),
                    )

    def _init_client_background(self):
        if self._client is None:
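            # The server may not be reachable yet, so retry the connection with
            # Fibonacci-spaced backoff until it succeeds.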
            fib = fibo()
            for sleep_coefficient in fib:
                try:
                    client = Client(self.url, verbose=False)

                    with self._client_lock:
                        self._client = client
                    break
                except Exception:
                    pass
                if sleep_coefficient is not None:
                    time.sleep(0.1 * sleep_coefficient)

        self._batch_sender()

    def _process_media(self, metrics, step: int | None) -> dict:
        """
        Serialize media in metrics and upload to the Space if needed.
        """
        serializable_metrics = {}
        if not step:
            step = 0
        for key, value in metrics.items():
            if isinstance(value, TrackioImage):
                value._save(self.project, self.name, step)
                serializable_metrics[key] = value._to_dict()
                if self._space_id:
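                    # When syncing to a Space, queue the media file so the batch
                    # sender uploads it along with the next batch of logs.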
                    upload_entry: UploadEntry = {
                        "project": self.project,
                        "run": self.name,
                        "step": step,
                        "uploaded_file": handle_file(value._get_absolute_file_path()),
                    }
                    with self._client_lock:
                        self._queued_uploads.append(upload_entry)
            else:
                serializable_metrics[key] = value
        return serializable_metrics

    def log(self, metrics: dict, step: int | None = None):
        for k in metrics.keys():
            if k in RESERVED_KEYS or k.startswith("__"):
                raise ValueError(
                    f"Please do not use this reserved key as a metric: {k}"
                )

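        # Media values (e.g. TrackioImage) are serialized here; when running
        # against a Space, the underlying files are queued for upload.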
        metrics = self._process_media(metrics, step)
        log_entry: LogEntry = {
            "project": self.project,
            "run": self.name,
            "metrics": metrics,
            "step": step,
        }

        with self._client_lock:
            self._queued_logs.append(log_entry)

    def finish(self):
        """Clean up when the run is finished."""
        self._stop_flag.set()

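        # Give the batch sender time to flush anything still queued before
        # joining the background thread.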
        time.sleep(2 * BATCH_SEND_INTERVAL)

        if self._client_thread is not None:
            print(
                f"* Run finished. Uploading logs to Trackio Space: {self.url} (please wait...)"
            )
            self._client_thread.join()