WalisonCruz commited on
Commit
8dfb8c1
·
verified ·
1 Parent(s): 32b2401

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. CHANGELOG.md +98 -0
  3. __init__.py +527 -0
  4. __pycache__/__init__.cpython-312.pyc +0 -0
  5. __pycache__/cli.cpython-312.pyc +0 -0
  6. __pycache__/commit_scheduler.cpython-312.pyc +0 -0
  7. __pycache__/context_vars.cpython-312.pyc +0 -0
  8. __pycache__/deploy.cpython-312.pyc +0 -0
  9. __pycache__/dummy_commit_scheduler.cpython-312.pyc +0 -0
  10. __pycache__/histogram.cpython-312.pyc +0 -0
  11. __pycache__/imports.cpython-312.pyc +0 -0
  12. __pycache__/run.cpython-312.pyc +0 -0
  13. __pycache__/sqlite_storage.cpython-312.pyc +0 -0
  14. __pycache__/table.cpython-312.pyc +0 -0
  15. __pycache__/typehints.cpython-312.pyc +0 -0
  16. __pycache__/utils.cpython-312.pyc +0 -0
  17. assets/badge.png +0 -0
  18. assets/trackio_logo_dark.png +0 -0
  19. assets/trackio_logo_light.png +0 -0
  20. assets/trackio_logo_old.png +3 -0
  21. assets/trackio_logo_type_dark.png +0 -0
  22. assets/trackio_logo_type_dark_transparent.png +0 -0
  23. assets/trackio_logo_type_light.png +0 -0
  24. assets/trackio_logo_type_light_transparent.png +0 -0
  25. cli.py +93 -0
  26. commit_scheduler.py +391 -0
  27. context_vars.py +21 -0
  28. deploy.py +346 -0
  29. dummy_commit_scheduler.py +12 -0
  30. histogram.py +71 -0
  31. imports.py +304 -0
  32. media/__init__.py +34 -0
  33. media/__pycache__/__init__.cpython-312.pyc +0 -0
  34. media/__pycache__/audio.cpython-312.pyc +0 -0
  35. media/__pycache__/image.cpython-312.pyc +0 -0
  36. media/__pycache__/media.cpython-312.pyc +0 -0
  37. media/__pycache__/utils.cpython-312.pyc +0 -0
  38. media/__pycache__/video.cpython-312.pyc +0 -0
  39. media/audio.py +171 -0
  40. media/image.py +88 -0
  41. media/media.py +83 -0
  42. media/utils.py +63 -0
  43. media/video.py +251 -0
  44. package.json +6 -0
  45. py.typed +0 -0
  46. run.py +241 -0
  47. sqlite_storage.py +677 -0
  48. table.py +175 -0
  49. typehints.py +19 -0
  50. ui/__init__.py +10 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/trackio_logo_old.png filter=lfs diff=lfs merge=lfs -text
CHANGELOG.md ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # trackio
2
+
3
+ ## 0.13.0
4
+
5
+ ### Features
6
+
7
+ - [#358](https://github.com/gradio-app/trackio/pull/358) [`073715d`](https://github.com/gradio-app/trackio/commit/073715d1caf8282f68890117f09c3ac301205312) - Improvements to `trackio.sync()`. Thanks @abidlabs!
8
+
9
+ ## 0.12.0
10
+
11
+ ### Features
12
+
13
+ - [#357](https://github.com/gradio-app/trackio/pull/357) [`02ba815`](https://github.com/gradio-app/trackio/commit/02ba815358060f1966052de051a5bdb09702920e) - Redesign media and tables to show up on separate page. Thanks @abidlabs!
14
+ - [#359](https://github.com/gradio-app/trackio/pull/359) [`08fe9c9`](https://github.com/gradio-app/trackio/commit/08fe9c9ddd7fe99ee811555fdfb62df9ab88e939) - docs: Improve docstrings. Thanks @qgallouedec!
15
+
16
+ ## 0.11.0
17
+
18
+ ### Features
19
+
20
+ - [#355](https://github.com/gradio-app/trackio/pull/355) [`ea51f49`](https://github.com/gradio-app/trackio/commit/ea51f4954922f21be76ef828700420fe9a912c4b) - Color code run checkboxes and match with plot lines. Thanks @abidlabs!
21
+ - [#353](https://github.com/gradio-app/trackio/pull/353) [`8abe691`](https://github.com/gradio-app/trackio/commit/8abe6919aeefe21fc7a23af814883efbb037c21f) - Remove show_api from demo.launch. Thanks @sergiopaniego!
22
+ - [#351](https://github.com/gradio-app/trackio/pull/351) [`8a8957e`](https://github.com/gradio-app/trackio/commit/8a8957e530dd7908d1fef7f2df030303f808101f) - Add `trackio.save()`. Thanks @abidlabs!
23
+
24
+ ## 0.10.0
25
+
26
+ ### Features
27
+
28
+ - [#305](https://github.com/gradio-app/trackio/pull/305) [`e64883a`](https://github.com/gradio-app/trackio/commit/e64883a51f7b8b93f7d48b8afe55acdb62238b71) - bump to gradio 6.0, make `trackio` compatible, and fix related issues. Thanks @abidlabs!
29
+
30
+ ## 0.9.1
31
+
32
+ ### Features
33
+
34
+ - [#344](https://github.com/gradio-app/trackio/pull/344) [`7e01024`](https://github.com/gradio-app/trackio/commit/7e010241d9a34794e0ce0dc19c1a6f0cf94ba856) - Avoid redundant calls to /whoami-v2. Thanks @Wauplin!
35
+
36
+ ## 0.9.0
37
+
38
+ ### Features
39
+
40
+ - [#343](https://github.com/gradio-app/trackio/pull/343) [`51bea30`](https://github.com/gradio-app/trackio/commit/51bea30f2877adff8e6497466d3a799400a0a049) - Sync offline projects to Hugging Face spaces. Thanks @candemircan!
41
+ - [#341](https://github.com/gradio-app/trackio/pull/341) [`4fd841f`](https://github.com/gradio-app/trackio/commit/4fd841fa190e15071b02f6fba7683ef4f393a654) - Adds a basic UI test to `trackio`. Thanks @abidlabs!
42
+ - [#339](https://github.com/gradio-app/trackio/pull/339) [`011d91b`](https://github.com/gradio-app/trackio/commit/011d91bb6ae266516fd250a349285670a8049d05) - Allow customzing the trackio color palette. Thanks @abidlabs!
43
+
44
+ ## 0.8.1
45
+
46
+ ### Features
47
+
48
+ - [#336](https://github.com/gradio-app/trackio/pull/336) [`5f9f51d`](https://github.com/gradio-app/trackio/commit/5f9f51dac8677f240d7c42c3e3b2660a22aee138) - Support a list of `Trackio.Image` in a `trackio.Table` cell. Thanks @abidlabs!
49
+
50
+ ## 0.8.0
51
+
52
+ ### Features
53
+
54
+ - [#331](https://github.com/gradio-app/trackio/pull/331) [`2c02d0f`](https://github.com/gradio-app/trackio/commit/2c02d0fd0a5824160528782402bb0dd4083396d5) - Truncate table string values that are greater than 250 characters (configuirable via env variable). Thanks @abidlabs!
55
+ - [#324](https://github.com/gradio-app/trackio/pull/324) [`50b2122`](https://github.com/gradio-app/trackio/commit/50b2122e7965ac82a72e6cb3b7d048bc10a2a6b1) - Add log y-axis functionality to UI. Thanks @abidlabs!
56
+ - [#326](https://github.com/gradio-app/trackio/pull/326) [`61dc1f4`](https://github.com/gradio-app/trackio/commit/61dc1f40af2f545f8e70395ddf0dbb8aee6b60d5) - Fix: improve table rendering for metrics in Trackio Dashboard. Thanks @vigneshwaran!
57
+ - [#328](https://github.com/gradio-app/trackio/pull/328) [`6857cbb`](https://github.com/gradio-app/trackio/commit/6857cbbe557a59a4642f210ec42566d108294e63) - Support trackio.Table with trackio.Image columns. Thanks @abidlabs!
58
+
59
+ ## 0.7.0
60
+
61
+ ### Features
62
+
63
+ - [#277](https://github.com/gradio-app/trackio/pull/277) [`db35601`](https://github.com/gradio-app/trackio/commit/db35601b9c023423c4654c9909b8ab73e58737de) - fix: make grouped runs view reflect live updates. Thanks @Saba9!
64
+ - [#320](https://github.com/gradio-app/trackio/pull/320) [`24ae739`](https://github.com/gradio-app/trackio/commit/24ae73969b09fb3126acd2f91647cdfbf8cf72a1) - Add additional query parms for xmin, xmax, and smoothing. Thanks @abidlabs!
65
+ - [#270](https://github.com/gradio-app/trackio/pull/270) [`cd1dfc3`](https://github.com/gradio-app/trackio/commit/cd1dfc3dc641b4499ac6d4a1b066fa8e2b52c57b) - feature: add support for logging audio. Thanks @Saba9!
66
+
67
+ ## 0.6.0
68
+
69
+ ### Features
70
+
71
+ - [#309](https://github.com/gradio-app/trackio/pull/309) [`1df2353`](https://github.com/gradio-app/trackio/commit/1df23534d6c01938c8db9c0f584ffa23e8d6021d) - Add histogram support with wandb-compatible API. Thanks @abidlabs!
72
+ - [#315](https://github.com/gradio-app/trackio/pull/315) [`76ba060`](https://github.com/gradio-app/trackio/commit/76ba06055dc43ca8f03b79f3e72d761949bd19a8) - Add guards to avoid silent fails. Thanks @Xmaster6y!
73
+ - [#313](https://github.com/gradio-app/trackio/pull/313) [`a606b3e`](https://github.com/gradio-app/trackio/commit/a606b3e1c5edf3d4cf9f31bd50605226a5a1c5d0) - No longer prevent certain keys from being used. Instead, dunderify them to prevent collisions with internal usage. Thanks @abidlabs!
74
+ - [#317](https://github.com/gradio-app/trackio/pull/317) [`27370a5`](https://github.com/gradio-app/trackio/commit/27370a595d0dbdf7eebbe7159d2ba778f039da44) - quick fixes for trackio.histogram. Thanks @abidlabs!
75
+ - [#312](https://github.com/gradio-app/trackio/pull/312) [`aa0f3bf`](https://github.com/gradio-app/trackio/commit/aa0f3bf372e7a0dd592a38af699c998363830eeb) - Fix video logging by adding TRACKIO_DIR to allowed_paths. Thanks @abidlabs!
76
+
77
+ ## 0.5.3
78
+
79
+ ### Features
80
+
81
+ - [#300](https://github.com/gradio-app/trackio/pull/300) [`5e4cacf`](https://github.com/gradio-app/trackio/commit/5e4cacf2e7ce527b4ce60de3a5bc05d2c02c77fb) - Adds more environment variables to allow customization of Trackio dashboard. Thanks @abidlabs!
82
+
83
+ ## 0.5.2
84
+
85
+ ### Features
86
+
87
+ - [#293](https://github.com/gradio-app/trackio/pull/293) [`64afc28`](https://github.com/gradio-app/trackio/commit/64afc28d3ea1dfd821472dc6bf0b8ed35a9b74be) - Ensures that the TRACKIO_DIR environment variable is respected. Thanks @abidlabs!
88
+ - [#287](https://github.com/gradio-app/trackio/pull/287) [`cd3e929`](https://github.com/gradio-app/trackio/commit/cd3e9294320949e6b8b829239069a43d5d7ff4c1) - fix(sqlite): unify .sqlite extension, allow export when DBs exist, clean WAL sidecars on import. Thanks @vaibhav-research!
89
+
90
+ ### Fixes
91
+
92
+ - [#291](https://github.com/gradio-app/trackio/pull/291) [`3b5adc3`](https://github.com/gradio-app/trackio/commit/3b5adc3d1f452dbab7a714d235f4974782f93730) - Fix the wheel build. Thanks @pngwn!
93
+
94
+ ## 0.5.1
95
+
96
+ ### Fixes
97
+
98
+ - [#278](https://github.com/gradio-app/trackio/pull/278) [`314c054`](https://github.com/gradio-app/trackio/commit/314c05438007ddfea3383e06fd19143e27468e2d) - Fix row orientation of metrics plots. Thanks @abidlabs!
__init__.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import logging
4
+ import os
5
+ import warnings
6
+ import webbrowser
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import huggingface_hub
11
+ from gradio.themes import ThemeClass
12
+ from gradio.utils import TupleNoPrint
13
+ from gradio_client import Client, handle_file
14
+ from huggingface_hub import SpaceStorage
15
+ from huggingface_hub.errors import LocalTokenNotFoundError
16
+
17
+ from trackio import context_vars, deploy, utils
18
+ from trackio.deploy import sync
19
+ from trackio.histogram import Histogram
20
+ from trackio.imports import import_csv, import_tf_events
21
+ from trackio.media import TrackioAudio, TrackioImage, TrackioVideo
22
+ from trackio.run import Run
23
+ from trackio.sqlite_storage import SQLiteStorage
24
+ from trackio.table import Table
25
+ from trackio.typehints import UploadEntry
26
+ from trackio.ui.main import CSS, HEAD, demo
27
+ from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR
28
+
29
# Quiet the httpx request logging that gradio_client would otherwise emit
# on every API call.
logging.getLogger("httpx").setLevel(logging.WARNING)

# Suppress a known, harmless UserWarning raised by gradio when the optional
# OAuth extras are not installed.
warnings.filterwarnings(
    "ignore",
    message="Empty session being created. Install gradio\\[oauth\\]",
    category=UserWarning,
    module="gradio.helpers",
)

# The version is read from the adjacent package.json so it stays in sync with
# the published package metadata.
__version__ = json.loads(Path(__file__).parent.joinpath("package.json").read_text())[
    "version"
]

# Explicit public API of the trackio package.
__all__ = [
    "init",
    "log",
    "finish",
    "show",
    "sync",
    "delete_project",
    "import_csv",
    "import_tf_events",
    "save",
    "Image",
    "Video",
    "Audio",
    "Table",
    "Histogram",
]

# wandb-style public aliases for the media classes.
Image = TrackioImage
Video = TrackioVideo
Audio = TrackioAudio


# Module-level run config; replaced with the active run's config by init().
config = {}

# Fallback Gradio theme used when TRACKIO_THEME is not set.
DEFAULT_THEME = "default"
67
+
68
+
69
def init(
    project: str,
    name: str | None = None,
    group: str | None = None,
    space_id: str | None = None,
    space_storage: SpaceStorage | None = None,
    dataset_id: str | None = None,
    config: dict | None = None,
    resume: str = "never",
    settings: Any = None,
    private: bool | None = None,
    embed: bool = True,
) -> Run:
    """
    Creates a new Trackio project and returns a [`Run`] object.

    Args:
        project (`str`):
            The name of the project (can be an existing project to continue tracking or
            a new project to start tracking from scratch).
        name (`str`, *optional*):
            The name of the run (if not provided, a default name will be generated).
        group (`str`, *optional*):
            The name of the group which this run belongs to in order to help organize
            related runs together. You can toggle the entire group's visibility in the
            dashboard.
        space_id (`str`, *optional*):
            If provided, the project will be logged to a Hugging Face Space instead of
            a local directory. Should be a complete Space name like
            `"username/reponame"` or `"orgname/reponame"`, or just `"reponame"` in which
            case the Space will be created in the currently-logged-in Hugging Face
            user's namespace. If the Space does not exist, it will be created. If the
            Space already exists, the project will be logged to it.
        space_storage ([`~huggingface_hub.SpaceStorage`], *optional*):
            Choice of persistent storage tier.
        dataset_id (`str`, *optional*):
            If a `space_id` is provided, a persistent Hugging Face Dataset will be
            created and the metrics will be synced to it every 5 minutes. Specify a
            Dataset with name like `"username/datasetname"` or `"orgname/datasetname"`,
            or `"datasetname"` (uses currently-logged-in Hugging Face user's namespace),
            or `None` (uses the same name as the Space but with the `"_dataset"`
            suffix). If the Dataset does not exist, it will be created. If the Dataset
            already exists, the project will be appended to it.
        config (`dict`, *optional*):
            A dictionary of configuration options. Provided for compatibility with
            `wandb.init()`.
        resume (`str`, *optional*, defaults to `"never"`):
            Controls how to handle resuming a run. Can be one of:

            - `"must"`: Must resume the run with the given name, raises error if run
              doesn't exist
            - `"allow"`: Resume the run if it exists, otherwise create a new run
            - `"never"`: Never resume a run, always create a new one
        private (`bool`, *optional*):
            Whether to make the Space private. If None (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
        settings (`Any`, *optional*):
            Not used. Provided for compatibility with `wandb.init()`.
        embed (`bool`, *optional*, defaults to `True`):
            If running inside a jupyter/Colab notebook, whether the dashboard should
            automatically be embedded in the cell when trackio.init() is called.

    Returns:
        `Run`: A [`Run`] object that can be used to log metrics and finish the run.
    """
    if settings is not None:
        warnings.warn(
            "* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented."
        )

    if space_id is None and dataset_id is not None:
        raise ValueError("Must provide a `space_id` when `dataset_id` is provided.")
    try:
        space_id, dataset_id = utils.preprocess_space_and_dataset_ids(
            space_id, dataset_id
        )
    except LocalTokenNotFoundError as e:
        raise LocalTokenNotFoundError(
            f"You must be logged in to Hugging Face locally when `space_id` is provided to deploy to a Space. {e}"
        ) from e
    url = context_vars.current_server.get()
    share_url = context_vars.current_share_server.get()

    # Launch a local dashboard server only once per process; subsequent
    # init() calls reuse the server recorded in the context vars.
    if url is None:
        if space_id is None:
            _, url, share_url = demo.launch(
                css=CSS,
                head=HEAD,
                footer_links=["gradio", "settings"],
                inline=False,
                quiet=True,
                prevent_thread_lock=True,
                show_error=True,
                favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
                allowed_paths=[TRACKIO_LOGO_DIR, TRACKIO_DIR],
            )
            context_vars.current_space_id.set(None)
        else:
            # When logging to a Space, the Space id doubles as the "server url".
            url = space_id
            share_url = None
            context_vars.current_space_id.set(space_id)

    context_vars.current_server.set(url)
    context_vars.current_share_server.set(share_url)
    # Print setup info (and create the Space, if requested) only when the
    # project changes, so repeated init() calls stay quiet.
    if (
        context_vars.current_project.get() is None
        or context_vars.current_project.get() != project
    ):
        print(f"* Trackio project initialized: {project}")

        if dataset_id is not None:
            os.environ["TRACKIO_DATASET_ID"] = dataset_id
            print(
                f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}"
            )
        if space_id is None:
            print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
            if utils.is_in_notebook() and embed:
                base_url = share_url + "/" if share_url else url
                full_url = utils.get_full_url(
                    base_url, project=project, write_token=demo.write_token, footer=True
                )
                utils.embed_url_in_notebook(full_url)
            else:
                utils.print_dashboard_instructions(project)
        else:
            deploy.create_space_if_not_exists(
                space_id, space_storage, dataset_id, private
            )
            user_name, space_name = space_id.split("/")
            space_url = deploy.SPACE_HOST_URL.format(
                user_name=user_name, space_name=space_name
            )
            print(f"* View dashboard by going to: {space_url}")
            if utils.is_in_notebook() and embed:
                utils.embed_url_in_notebook(space_url)
    context_vars.current_project.set(project)

    # A local gradio client is only needed when logging locally; Space-backed
    # runs talk to the Space instead.
    client = None
    if not space_id:
        client = Client(url, verbose=False)

    if resume == "must":
        if name is None:
            raise ValueError("Must provide a run name when resume='must'")
        if name not in SQLiteStorage.get_runs(project):
            raise ValueError(f"Run '{name}' does not exist in project '{project}'")
        resumed = True
    elif resume == "allow":
        resumed = name is not None and name in SQLiteStorage.get_runs(project)
    elif resume == "never":
        if name is not None and name in SQLiteStorage.get_runs(project):
            warnings.warn(
                f"* Warning: resume='never' but a run '{name}' already exists in "
                f"project '{project}'. Generating a new name instead. If you want "
                "to resume this run, call init() with resume='must' or resume='allow'."
            )
            name = None
        resumed = False
    else:
        raise ValueError("resume must be one of: 'must', 'allow', or 'never'")

    run = Run(
        url=url,
        project=project,
        client=client,
        name=name,
        group=group,
        config=config,
        space_id=space_id,
    )

    if resumed:
        print(f"* Resumed existing run: {run.name}")
    else:
        print(f"* Created new run: {run.name}")

    context_vars.current_run.set(run)
    # Expose the active run's config as the module-level `trackio.config`,
    # mirroring wandb's API.
    globals()["config"] = run.config
    return run
250
+
251
+
252
def log(metrics: dict, step: int | None = None) -> None:
    """
    Logs metrics to the current run.

    Args:
        metrics (`dict`):
            A dictionary of metrics to log.
        step (`int`, *optional*):
            The step number. If not provided, the step will be incremented
            automatically.

    Raises:
        RuntimeError: If no run is active (i.e. `trackio.init()` was not called).
    """
    active_run = context_vars.current_run.get()
    if active_run is None:
        raise RuntimeError("Call trackio.init() before trackio.log().")
    active_run.log(metrics=metrics, step=step)
270
+
271
+
272
def finish():
    """
    Finishes the current run.

    Raises:
        RuntimeError: If no run is active (i.e. `trackio.init()` was not called).
    """
    active_run = context_vars.current_run.get()
    if active_run is None:
        raise RuntimeError("Call trackio.init() before trackio.finish().")
    active_run.finish()
280
+
281
+
282
def delete_project(project: str, force: bool = False) -> bool:
    """
    Deletes a project by removing its local SQLite database.

    Args:
        project (`str`):
            The name of the project to delete.
        force (`bool`, *optional*, defaults to `False`):
            If `True`, deletes the project without prompting for confirmation.
            If `False`, prompts the user to confirm before deleting.

    Returns:
        `bool`: `True` if the project was deleted, `False` otherwise.
    """
    db_path = SQLiteStorage.get_project_db_path(project)

    if not db_path.exists():
        print(f"* Project '{project}' does not exist.")
        return False

    if not force:
        answer = input(
            f"Are you sure you want to delete project '{project}'? "
            "This will permanently delete all runs and metrics. (y/N): "
        )
        if answer.lower() not in ("y", "yes"):
            print("* Deletion cancelled.")
            return False

    try:
        db_path.unlink()

        # Also remove SQLite WAL/SHM sidecar files, if any were left behind.
        for ext in ("-wal", "-shm"):
            companion = Path(str(db_path) + ext)
            if companion.exists():
                companion.unlink()

        print(f"* Project '{project}' has been deleted.")
        return True
    except Exception as e:
        print(f"* Error deleting project '{project}': {e}")
        return False
324
+
325
+
326
def save(
    glob_str: str | Path,
    project: str | None = None,
) -> str:
    """
    Saves files to a project (not linked to a specific run). If Trackio is running
    locally, the file(s) will be moved to the project's files directory. If Trackio is
    running in a Space, the file(s) will be uploaded to the Space's files directory.

    Args:
        glob_str (`str` or `Path`):
            The file path or glob pattern to save. Can be a single file or a pattern
            matching multiple files (e.g., `"*.py"`, `"models/**/*.pth"`).
        project (`str`, *optional*):
            The name of the project to save files to. If not provided, uses the current
            project from `trackio.init()`. If no project is initialized, raises an
            error.

    Returns:
        `str`: The path where the file(s) were saved (project's files directory).

    Raises:
        RuntimeError: If no project is specified/initialized, or if files must be
            uploaded but no server is available.
        ValueError: If no files match `glob_str`.

    Example:
        ```python
        import trackio

        trackio.init(project="my-project")
        trackio.save("config.yaml")
        trackio.save("models/*.pth")
        ```
    """
    if project is None:
        project = context_vars.current_project.get()
        if project is None:
            raise RuntimeError(
                "No project specified. Either call trackio.init() first or provide a "
                "project parameter to trackio.save()."
            )

    glob_str = Path(glob_str)
    # Relative paths inside the project are computed against the cwd at call time.
    base_path = Path.cwd().resolve()

    matched_files = []
    if glob_str.is_file():
        # Fast path: a single concrete file, no glob expansion needed.
        matched_files = [glob_str.resolve()]
    else:
        pattern = str(glob_str)
        if not glob_str.is_absolute():
            # Anchor relative patterns at the cwd so glob matches are absolute.
            pattern = str((Path.cwd() / glob_str).resolve())
        matched_files = [
            Path(f).resolve()
            for f in glob.glob(pattern, recursive=True)
            if Path(f).is_file()
        ]

    if not matched_files:
        raise ValueError(f"No files found matching pattern: {glob_str}")

    url = context_vars.current_server.get()
    current_run = context_vars.current_run.get()

    upload_entries = []

    for file_path in matched_files:
        try:
            relative_to_base = file_path.relative_to(base_path)
        except ValueError:
            # File lives outside the cwd tree; fall back to a flat name.
            relative_to_base = Path(file_path.name)

        if current_run is not None:
            # If a run is active, use its queue to upload the file to the project's files directory
            # as it's more efficient than uploading files one by one. But we should not use the run name
            # as the files should be stored in the project's files directory, not the run's, hence
            # the use_run_name flag is set to False.
            current_run._queue_upload(
                file_path,
                step=None,
                relative_path=str(relative_to_base.parent),
                use_run_name=False,
            )
        else:
            # No active run: collect entries and push them in one bulk call below.
            upload_entry: UploadEntry = {
                "project": project,
                "run": None,
                "step": None,
                "relative_path": str(relative_to_base),
                "uploaded_file": handle_file(file_path),
            }
            upload_entries.append(upload_entry)

    if upload_entries:
        if url is None:
            raise RuntimeError(
                "No server available. Call trackio.init() before trackio.save() to start the server."
            )

        try:
            # Generous timeout: a bulk upload of several files can be slow.
            client = Client(url, verbose=False, httpx_kwargs={"timeout": 90})
            client.predict(
                api_name="/bulk_upload_media",
                uploads=upload_entries,
                hf_token=huggingface_hub.utils.get_token(),
            )
        except Exception as e:
            # Best-effort: warn rather than crash if the upload fails.
            warnings.warn(
                f"Failed to upload files: {e}. "
                "Files may not be available in the dashboard."
            )

    return str(utils.MEDIA_DIR / project / "files")
435
+
436
+
437
def show(
    project: str | None = None,
    *,
    theme: str | ThemeClass | None = None,
    mcp_server: bool | None = None,
    footer: bool = True,
    color_palette: list[str] | None = None,
    open_browser: bool = True,
    block_thread: bool | None = None,
):
    """
    Launches the Trackio dashboard.

    Args:
        project (`str`, *optional*):
            The name of the project whose runs to show. If not provided, all projects
            will be shown and the user can select one.
        theme (`str` or `ThemeClass`, *optional*):
            A Gradio Theme to use for the dashboard instead of the default Gradio theme,
            can be a built-in theme (e.g. `'soft'`, `'citrus'`), a theme from the Hub
            (e.g. `"gstaff/xkcd"`), or a custom Theme class. If not provided, the
            `TRACKIO_THEME` environment variable will be used, or if that is not set,
            the default Gradio theme will be used.
        mcp_server (`bool`, *optional*):
            If `True`, the Trackio dashboard will be set up as an MCP server and certain
            functions will be added as MCP tools. If `None` (default behavior), then the
            `GRADIO_MCP_SERVER` environment variable will be used to determine if the
            MCP server should be enabled (which is `"True"` on Hugging Face Spaces).
        footer (`bool`, *optional*, defaults to `True`):
            Whether to show the Gradio footer. When `False`, the footer will be hidden.
            This can also be controlled via the `footer` query parameter in the URL.
        color_palette (`list[str]`, *optional*):
            A list of hex color codes to use for plot lines. If not provided, the
            `TRACKIO_COLOR_PALETTE` environment variable will be used (comma-separated
            hex codes), or if that is not set, the default color palette will be used.
            Example: `['#FF0000', '#00FF00', '#0000FF']`
        open_browser (`bool`, *optional*, defaults to `True`):
            If `True` and not in a notebook, a new browser tab will be opened with the
            dashboard. If `False`, the browser will not be opened.
        block_thread (`bool`, *optional*):
            If `True`, the main thread will be blocked until the dashboard is closed.
            If `None` (default behavior), then the main thread will not be blocked if the
            dashboard is launched in a notebook, otherwise the main thread will be blocked.

    Returns:
        `app`: The Gradio app object corresponding to the dashboard launched by Trackio.
        `url`: The local URL of the dashboard.
        `share_url`: The public share URL of the dashboard.
        `full_url`: The full URL of the dashboard including the write token (will use the public share URL if launched publicly, otherwise the local URL).
    """
    # The palette is passed to the UI process via an environment variable.
    if color_palette is not None:
        os.environ["TRACKIO_COLOR_PALETTE"] = ",".join(color_palette)

    theme = theme or os.environ.get("TRACKIO_THEME", DEFAULT_THEME)

    # Resolve the MCP flag: explicit argument wins, else the env var.
    _mcp_server = (
        mcp_server
        if mcp_server is not None
        else os.environ.get("GRADIO_MCP_SERVER", "False") == "True"
    )

    app, url, share_url = demo.launch(
        css=CSS,
        head=HEAD,
        footer_links=["gradio", "settings"] + (["api"] if _mcp_server else []),
        quiet=True,
        inline=False,
        prevent_thread_lock=True,
        favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
        allowed_paths=[TRACKIO_LOGO_DIR, TRACKIO_DIR],
        mcp_server=_mcp_server,
        theme=theme,
    )

    # Prefer the public share URL (with trailing slash) when one exists.
    base_url = share_url + "/" if share_url else url
    full_url = utils.get_full_url(
        base_url, project=project, write_token=demo.write_token, footer=footer
    )

    if not utils.is_in_notebook():
        print(f"* Trackio UI launched at: {full_url}")
        if open_browser:
            webbrowser.open(full_url)
        # Outside notebooks, default to blocking so the server stays alive.
        block_thread = block_thread if block_thread is not None else True
    else:
        utils.embed_url_in_notebook(full_url)
        block_thread = block_thread if block_thread is not None else False

    if block_thread:
        utils.block_main_thread_until_keyboard_interrupt()
    # NOTE(review): the docstring says the first element is `app`, but `demo`
    # (the Blocks object) is returned here, not the launched `app` — confirm
    # which is intended. TupleNoPrint suppresses repr echo in notebooks.
    return TupleNoPrint((demo, url, share_url, full_url))
__pycache__/__init__.cpython-312.pyc ADDED
Binary file (22.1 kB). View file
 
__pycache__/cli.cpython-312.pyc ADDED
Binary file (3.82 kB). View file
 
__pycache__/commit_scheduler.cpython-312.pyc ADDED
Binary file (18.8 kB). View file
 
__pycache__/context_vars.cpython-312.pyc ADDED
Binary file (1.09 kB). View file
 
__pycache__/deploy.cpython-312.pyc ADDED
Binary file (14.6 kB). View file
 
__pycache__/dummy_commit_scheduler.cpython-312.pyc ADDED
Binary file (1.03 kB). View file
 
__pycache__/histogram.cpython-312.pyc ADDED
Binary file (3.24 kB). View file
 
__pycache__/imports.cpython-312.pyc ADDED
Binary file (13.3 kB). View file
 
__pycache__/run.cpython-312.pyc ADDED
Binary file (11.4 kB). View file
 
__pycache__/sqlite_storage.cpython-312.pyc ADDED
Binary file (31.4 kB). View file
 
__pycache__/table.cpython-312.pyc ADDED
Binary file (8.73 kB). View file
 
__pycache__/typehints.cpython-312.pyc ADDED
Binary file (971 Bytes). View file
 
__pycache__/utils.cpython-312.pyc ADDED
Binary file (29.8 kB). View file
 
assets/badge.png ADDED
assets/trackio_logo_dark.png ADDED
assets/trackio_logo_light.png ADDED
assets/trackio_logo_old.png ADDED

Git LFS Details

  • SHA256: 3922c4d1e465270ad4d8abb12023f3beed5d9f7f338528a4c0ac21dcf358a1c8
  • Pointer size: 131 Bytes
  • Size of remote file: 487 kB
assets/trackio_logo_type_dark.png ADDED
assets/trackio_logo_type_dark_transparent.png ADDED
assets/trackio_logo_type_light.png ADDED
assets/trackio_logo_type_light_transparent.png ADDED
cli.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from trackio import show, sync
4
+
5
+
6
def main():
    """Command-line entry point for Trackio.

    Exposes two subcommands:
      * ``trackio show`` - launch the dashboard UI for a project.
      * ``trackio sync`` - push a local project's database to a HF Space.
    With no subcommand, the top-level help text is printed.
    """
    parser = argparse.ArgumentParser(description="Trackio CLI")
    subparsers = parser.add_subparsers(dest="command")

    show_parser = subparsers.add_parser(
        "show", help="Show the Trackio dashboard UI for a project"
    )
    show_parser.add_argument(
        "--project", required=False, help="Project name to show in the dashboard"
    )
    show_parser.add_argument(
        "--theme",
        required=False,
        default="default",
        help="A Gradio Theme to use for the dashboard instead of the default, can be a built-in theme (e.g. 'soft', 'citrus'), or a theme from the Hub (e.g. 'gstaff/xkcd').",
    )
    show_parser.add_argument(
        "--mcp-server",
        action="store_true",
        help="Enable MCP server functionality. The Trackio dashboard will be set up as an MCP server and certain functions will be exposed as MCP tools.",
    )
    # `--footer` (the default) and `--no-footer` both write into the same
    # `footer` destination; the latter flips it to False.
    show_parser.add_argument(
        "--footer",
        action="store_true",
        default=True,
        help="Show the Gradio footer. Use --no-footer to hide it.",
    )
    show_parser.add_argument(
        "--no-footer",
        dest="footer",
        action="store_false",
        help="Hide the Gradio footer.",
    )
    show_parser.add_argument(
        "--color-palette",
        required=False,
        help="Comma-separated list of hex color codes for plot lines (e.g. '#FF0000,#00FF00,#0000FF'). If not provided, the TRACKIO_COLOR_PALETTE environment variable will be used, or the default palette if not set.",
    )

    sync_parser = subparsers.add_parser(
        "sync",
        help="Sync a local project's database to a Hugging Face Space. If the Space does not exist, it will be created.",
    )
    sync_parser.add_argument(
        "--project", required=True, help="The name of the local project."
    )
    sync_parser.add_argument(
        "--space-id",
        required=True,
        help="The Hugging Face Space ID where the project will be synced (e.g. username/space_id).",
    )
    sync_parser.add_argument(
        "--private",
        action="store_true",
        help="Make the Hugging Face Space private if creating a new Space. By default, the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.",
    )
    sync_parser.add_argument(
        "--force",
        action="store_true",
        help="Overwrite the existing database without prompting for confirmation.",
    )

    args = parser.parse_args()

    if args.command == "show":
        # An empty/missing --color-palette means "use env var or default".
        palette = (
            [color.strip() for color in args.color_palette.split(",")]
            if args.color_palette
            else None
        )
        show(
            project=args.project,
            theme=args.theme,
            mcp_server=args.mcp_server,
            footer=args.footer,
            color_palette=palette,
        )
    elif args.command == "sync":
        sync(
            project=args.project,
            space_id=args.space_id,
            private=args.private,
            force=args.force,
        )
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
commit_scheduler.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Originally copied from https://github.com/huggingface/huggingface_hub/blob/d0a948fc2a32ed6e557042a95ef3e4af97ec4a7c/src/huggingface_hub/_commit_scheduler.py
2
+
3
+ import atexit
4
+ import logging
5
+ import os
6
+ import time
7
+ from concurrent.futures import Future
8
+ from dataclasses import dataclass
9
+ from io import SEEK_END, SEEK_SET, BytesIO
10
+ from pathlib import Path
11
+ from threading import Lock, Thread
12
+ from typing import Callable, Dict, List, Union
13
+
14
+ from huggingface_hub.hf_api import (
15
+ DEFAULT_IGNORE_PATTERNS,
16
+ CommitInfo,
17
+ CommitOperationAdd,
18
+ HfApi,
19
+ )
20
+ from huggingface_hub.utils import filter_repo_objects
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
@dataclass(frozen=True)
class _FileToUpload:
    """Temporary dataclass to store info about files to upload. Not meant to be used directly."""

    # Absolute path of the file on the local filesystem.
    local_path: Path
    # Destination path inside the repo (path_in_repo prefix + relative path).
    path_in_repo: str
    # File size in bytes at scan time (taken from `stat.st_size`).
    size_limit: int
    # `stat.st_mtime` at scan time; compared against `last_uploaded` to skip
    # files that have not changed since the previous scheduled commit.
    last_modified: float
33
+
34
+
35
class CommitScheduler:
    """
    Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).

    The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
    properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
    with the `stop` method. Checkout the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
    to learn more about how to use it.

    Args:
        repo_id (`str`):
            The id of the repo to commit to.
        folder_path (`str` or `Path`):
            Path to the local folder to upload regularly.
        every (`int` or `float`, *optional*):
            The number of minutes between each commit. Defaults to 5 minutes.
        path_in_repo (`str`, *optional*):
            Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
            of the repository.
        repo_type (`str`, *optional*):
            The type of the repo to commit to. Defaults to `model`.
        revision (`str`, *optional*):
            The revision of the repo to commit to. Defaults to `main`.
        private (`bool`, *optional*):
            Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
        token (`str`, *optional*):
            The token to use to commit to the repo. Defaults to the token saved on the machine.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are uploaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not uploaded.
        squash_history (`bool`, *optional*):
            Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
            useful to avoid degraded performances on the repo when it grows too large.
        hf_api (`HfApi`, *optional*):
            The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).
        on_before_commit (`Callable[[], None]`, *optional*):
            If specified, a function that will be called before the CommitScheduler lists files to create a commit.

    Example:
    ```py
    >>> from pathlib import Path
    >>> from huggingface_hub import CommitScheduler

    # Scheduler uploads every 10 minutes
    >>> csv_path = Path("watched_folder/data.csv")
    >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)

    >>> with csv_path.open("a") as f:
    ...     f.write("first line")

    # Some time later (...)
    >>> with csv_path.open("a") as f:
    ...     f.write("second line")
    ```

    Example using a context manager:
    ```py
    >>> from pathlib import Path
    >>> from huggingface_hub import CommitScheduler

    >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
    ...     csv_path = Path("watched_folder/data.csv")
    ...     with csv_path.open("a") as f:
    ...         f.write("first line")
    ...     (...)
    ...     with csv_path.open("a") as f:
    ...         f.write("second line")

    # Scheduler is now stopped and last commit have been triggered
    ```
    """

    def __init__(
        self,
        *,
        repo_id: str,
        folder_path: Union[str, Path],
        every: Union[int, float] = 5,
        path_in_repo: str | None = None,
        repo_type: str | None = None,
        revision: str | None = None,
        private: bool | None = None,
        token: str | None = None,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
        squash_history: bool = False,
        hf_api: HfApi | None = None,
        on_before_commit: Callable[[], None] | None = None,
    ) -> None:
        self.api = hf_api or HfApi(token=token)
        self.on_before_commit = on_before_commit

        # Folder
        self.folder_path = Path(folder_path).expanduser().resolve()
        self.path_in_repo = path_in_repo or ""
        self.allow_patterns = allow_patterns

        if ignore_patterns is None:
            ignore_patterns = []
        elif isinstance(ignore_patterns, str):
            ignore_patterns = [ignore_patterns]
        self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS

        if self.folder_path.is_file():
            raise ValueError(
                f"'folder_path' must be a directory, not a file: '{self.folder_path}'."
            )
        self.folder_path.mkdir(parents=True, exist_ok=True)

        # Repository
        repo_url = self.api.create_repo(
            repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True
        )
        self.repo_id = repo_url.repo_id
        self.repo_type = repo_type
        self.revision = revision
        self.token = token

        # Maps local file path -> mtime at the time of its last successful upload.
        self.last_uploaded: Dict[Path, float] = {}
        self.last_push_time: float | None = None

        if not every > 0:
            raise ValueError(f"'every' must be a positive integer, not '{every}'.")
        self.lock = Lock()
        self.every = every
        self.squash_history = squash_history

        logger.info(
            f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes."
        )
        # BUGFIX: the stop flag must be initialized *before* the background
        # thread is started (and before the atexit hook is registered).
        # Previously it was assigned last, so `_run_scheduler`/`_push_to_hub`
        # could read the attribute before it existed and crash with
        # AttributeError.
        self.__stopped = False
        self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
        self._scheduler_thread.start()
        atexit.register(self._push_to_hub)

    def stop(self) -> None:
        """Stop the scheduler.

        A stopped scheduler cannot be restarted. Mostly for tests purposes.
        """
        self.__stopped = True

    def __enter__(self) -> "CommitScheduler":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Upload last changes before exiting, then stop the scheduler.
        self.trigger().result()
        self.stop()
        return

    def _run_scheduler(self) -> None:
        """Dumb thread waiting between each scheduled push to Hub."""
        while True:
            self.last_future = self.trigger()
            time.sleep(self.every * 60)
            if self.__stopped:
                break

    def trigger(self) -> Future:
        """Trigger a `push_to_hub` and return a future.

        This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
        immediately, without waiting for the next scheduled commit.
        """
        return self.api.run_as_future(self._push_to_hub)

    def _push_to_hub(self) -> CommitInfo | None:
        """Internal wrapper around `push_to_hub` that honors the stop flag,
        optionally squashes the repo history, and logs (then re-raises) errors."""
        if self.__stopped:  # If stopped, already scheduled commits are ignored
            return None

        logger.info("(Background) scheduled commit triggered.")
        try:
            value = self.push_to_hub()
            if self.squash_history:
                logger.info("(Background) squashing repo history.")
                self.api.super_squash_history(
                    repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision
                )
            return value
        except Exception as e:
            logger.error(
                f"Error while pushing to Hub: {e}"
            )  # Depending on the setup, error might be silenced
            raise

    def push_to_hub(self) -> CommitInfo | None:
        """
        Push folder to the Hub and return the commit info.

        <Tip warning={true}>

        This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
        queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
        issues.

        </Tip>

        The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
        uploads only changed files. If no changes are found, the method returns without committing anything. If you want
        to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
        for example to compress data together in a single file before committing. For more details and examples, check
        out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
        """
        # Check files to upload (with lock)
        with self.lock:
            if self.on_before_commit is not None:
                self.on_before_commit()

            logger.debug("Listing files to upload for scheduled commit.")

            # List files from folder (taken from `_prepare_upload_folder_additions`)
            relpath_to_abspath = {
                path.relative_to(self.folder_path).as_posix(): path
                for path in sorted(
                    self.folder_path.glob("**/*")
                )  # sorted to be deterministic
                if path.is_file()
            }
            prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""

            # Filter with pattern + filter out unchanged files + retrieve current file size
            files_to_upload: List[_FileToUpload] = []
            for relpath in filter_repo_objects(
                relpath_to_abspath.keys(),
                allow_patterns=self.allow_patterns,
                ignore_patterns=self.ignore_patterns,
            ):
                local_path = relpath_to_abspath[relpath]
                stat = local_path.stat()
                if (
                    self.last_uploaded.get(local_path) is None
                    or self.last_uploaded[local_path] != stat.st_mtime
                ):
                    files_to_upload.append(
                        _FileToUpload(
                            local_path=local_path,
                            path_in_repo=prefix + relpath,
                            size_limit=stat.st_size,
                            last_modified=stat.st_mtime,
                        )
                    )

        # Return if nothing to upload
        if len(files_to_upload) == 0:
            logger.debug("Dropping schedule commit: no changed file to upload.")
            return None

        # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size)
        logger.debug("Removing unchanged files since previous scheduled commit.")
        add_operations = [
            CommitOperationAdd(
                # TODO: Cap the file to its current size, even if the user append data to it while a scheduled commit is happening
                # (requires an upstream fix for XET-535: `hf_xet` should support `BinaryIO` for upload)
                path_or_fileobj=file_to_upload.local_path,
                path_in_repo=file_to_upload.path_in_repo,
            )
            for file_to_upload in files_to_upload
        ]

        # Upload files (append mode expected - no need for lock)
        logger.debug("Uploading files for scheduled commit.")
        commit_info = self.api.create_commit(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            operations=add_operations,
            commit_message="Scheduled Commit",
            revision=self.revision,
        )

        # Record mtimes only after the commit succeeded, so a failed push
        # retries the same files next time.
        for file in files_to_upload:
            self.last_uploaded[file.local_path] = file.last_modified

        self.last_push_time = time.time()

        return commit_info
313
+
314
+
315
+ class PartialFileIO(BytesIO):
316
+ """A file-like object that reads only the first part of a file.
317
+
318
+ Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
319
+ file is uploaded (i.e. the part that was available when the filesystem was first scanned).
320
+
321
+ In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
322
+ disturbance for the user. The object is passed to `CommitOperationAdd`.
323
+
324
+ Only supports `read`, `tell` and `seek` methods.
325
+
326
+ Args:
327
+ file_path (`str` or `Path`):
328
+ Path to the file to read.
329
+ size_limit (`int`):
330
+ The maximum number of bytes to read from the file. If the file is larger than this, only the first part
331
+ will be read (and uploaded).
332
+ """
333
+
334
+ def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
335
+ self._file_path = Path(file_path)
336
+ self._file = self._file_path.open("rb")
337
+ self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)
338
+
339
+ def __del__(self) -> None:
340
+ self._file.close()
341
+ return super().__del__()
342
+
343
+ def __repr__(self) -> str:
344
+ return (
345
+ f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"
346
+ )
347
+
348
+ def __len__(self) -> int:
349
+ return self._size_limit
350
+
351
+ def __getattribute__(self, name: str):
352
+ if name.startswith("_") or name in (
353
+ "read",
354
+ "tell",
355
+ "seek",
356
+ ): # only 3 public methods supported
357
+ return super().__getattribute__(name)
358
+ raise NotImplementedError(f"PartialFileIO does not support '{name}'.")
359
+
360
+ def tell(self) -> int:
361
+ """Return the current file position."""
362
+ return self._file.tell()
363
+
364
+ def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
365
+ """Change the stream position to the given offset.
366
+
367
+ Behavior is the same as a regular file, except that the position is capped to the size limit.
368
+ """
369
+ if __whence == SEEK_END:
370
+ # SEEK_END => set from the truncated end
371
+ __offset = len(self) + __offset
372
+ __whence = SEEK_SET
373
+
374
+ pos = self._file.seek(__offset, __whence)
375
+ if pos > self._size_limit:
376
+ return self._file.seek(self._size_limit)
377
+ return pos
378
+
379
+ def read(self, __size: int | None = -1) -> bytes:
380
+ """Read at most `__size` bytes from the file.
381
+
382
+ Behavior is the same as a regular file, except that it is capped to the size limit.
383
+ """
384
+ current = self._file.tell()
385
+ if __size is None or __size < 0:
386
+ # Read until file limit
387
+ truncated_size = self._size_limit - current
388
+ else:
389
+ # Read until file limit or __size
390
+ truncated_size = min(__size, self._size_limit - current)
391
+ return self._file.read(truncated_size)
context_vars.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import contextvars
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for type checking to avoid importing trackio.run at runtime.
    from trackio.run import Run

# Context-local state shared across the trackio package. `contextvars.ContextVar`
# is used instead of plain module globals so that each execution context sees
# (and can override) its own values independently.

# The currently active `Run`, or None if no run has been initialized.
current_run: contextvars.ContextVar["Run | None"] = contextvars.ContextVar(
    "current_run", default=None
)
# Name of the currently active project, or None if unset.
current_project: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_project", default=None
)
# Address of the server currently in use, or None if unset.
current_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_server", default=None
)
# Hugging Face Space ID currently in use, or None when unset (e.g. local only).
current_space_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_space_id", default=None
)
# Share server address if one is active, or None otherwise.
current_share_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_share_server", default=None
)
deploy.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.metadata
2
+ import io
3
+ import os
4
+ import threading
5
+ import time
6
+ from importlib.resources import files
7
+ from pathlib import Path
8
+
9
+ import gradio
10
+ import huggingface_hub
11
+ from gradio_client import Client, handle_file
12
+ from httpx import ReadTimeout
13
+ from huggingface_hub.errors import HfHubHTTPError, RepositoryNotFoundError
14
+
15
+ import trackio
16
+ from trackio.sqlite_storage import SQLiteStorage
17
+ from trackio.utils import get_or_create_project_hash, preprocess_space_and_dataset_ids
18
+
19
+ SPACE_HOST_URL = "https://{user_name}-{space_name}.hf.space/"
20
+ SPACE_URL = "https://huggingface.co/spaces/{space_id}"
21
+
22
+
23
def _is_trackio_installed_from_source() -> bool:
    """Check if trackio is installed from source/editable install vs PyPI.

    Returns True for a source/editable install (or whenever the check cannot
    be completed), False for a regular site-packages install.
    """
    try:
        # A module living outside site-packages is a source checkout.
        if "site-packages" not in trackio.__file__:
            return True

        # An editable install still lives in site-packages but ships a .pth
        # redirection file in its distribution metadata.
        dist = importlib.metadata.distribution("trackio")
        dist_files = list(dist.files) if dist.files else []
        return any(".pth" in str(entry) for entry in dist_files)
    except (
        AttributeError,
        importlib.metadata.PackageNotFoundError,
        importlib.metadata.MetadataError,
        ValueError,
        TypeError,
    ):
        # When in doubt, treat it as a source install.
        return True
46
+
47
+
48
def deploy_as_space(
    space_id: str,
    space_storage: huggingface_hub.SpaceStorage | None = None,
    dataset_id: str | None = None,
    private: bool | None = None,
):
    """
    Create (if needed) and populate a Hugging Face Space that serves the Trackio dashboard.

    Args:
        space_id (`str`):
            The ID of the Space to deploy to (e.g. `"username/space_name"`).
        space_storage (`huggingface_hub.SpaceStorage`, *optional*):
            Persistent storage tier to request for the Space.
        dataset_id (`str`, *optional*):
            If provided, stored on the Space as the `TRACKIO_DATASET_ID` variable.
        private (`bool`, *optional*):
            Whether to make the Space private. If `None` (default), the repo will be
            public unless the organization's default is private. Ignored if the repo
            already exists.
    """
    if (
        os.getenv("SYSTEM") == "spaces"
    ):  # in case a repo with this function is uploaded to spaces
        return

    trackio_path = files("trackio")

    hf_api = huggingface_hub.HfApi()

    try:
        huggingface_hub.create_repo(
            space_id,
            private=private,
            space_sdk="gradio",
            space_storage=space_storage,
            repo_type="space",
            exist_ok=True,
        )
    except HfHubHTTPError as e:
        if e.response.status_code in [401, 403]:  # unauthorized or forbidden
            # Prompt for a write token interactively, then retry once.
            print("Need 'write' access token to create a Spaces repo.")
            huggingface_hub.login(add_to_git_credential=False)
            huggingface_hub.create_repo(
                space_id,
                private=private,
                space_sdk="gradio",
                space_storage=space_storage,
                repo_type="space",
                exist_ok=True,
            )
        else:
            raise ValueError(f"Failed to create Space: {e}")

    # Render the packaged README template with the running Gradio version and
    # upload it from an in-memory buffer (no temporary file needed).
    with open(Path(trackio_path, "README.md"), "r") as f:
        readme_content = f.read()
    readme_content = readme_content.replace("{GRADIO_VERSION}", gradio.__version__)
    readme_buffer = io.BytesIO(readme_content.encode("utf-8"))
    hf_api.upload_file(
        path_or_fileobj=readme_buffer,
        path_in_repo="README.md",
        repo_id=space_id,
        repo_type="space",
    )

    # We can assume pandas, gradio, and huggingface-hub are already installed in a Gradio Space.
    # Make sure necessary dependencies are installed by creating a requirements.txt.
    is_source_install = _is_trackio_installed_from_source()

    if is_source_install:
        requirements_content = """pyarrow>=21.0
plotly>=6.0.0,<7.0.0"""
    else:
        # Pin the Space to the same trackio version as the local install.
        requirements_content = f"""pyarrow>=21.0
trackio=={trackio.__version__}
plotly>=6.0.0,<7.0.0"""

    requirements_buffer = io.BytesIO(requirements_content.encode("utf-8"))
    hf_api.upload_file(
        path_or_fileobj=requirements_buffer,
        path_in_repo="requirements.txt",
        repo_id=space_id,
        repo_type="space",
    )

    huggingface_hub.utils.disable_progress_bars()

    if is_source_install:
        # Source/editable install: upload the local trackio package itself so
        # the Space runs exactly this code.
        hf_api.upload_folder(
            repo_id=space_id,
            repo_type="space",
            folder_path=trackio_path,
            ignore_patterns=["README.md"],
        )
    else:
        # PyPI install: trackio comes from requirements.txt, so only a tiny
        # launcher script is uploaded.
        app_file_content = """import trackio
trackio.show()"""
        app_file_buffer = io.BytesIO(app_file_content.encode("utf-8"))
        hf_api.upload_file(
            path_or_fileobj=app_file_buffer,
            path_in_repo="ui/main.py",
            repo_id=space_id,
            repo_type="space",
        )

    # Forward local credentials and configuration to the Space as
    # secrets/variables so the dashboard behaves like the local one.
    if hf_token := huggingface_hub.utils.get_token():
        huggingface_hub.add_space_secret(space_id, "HF_TOKEN", hf_token)
    if dataset_id is not None:
        huggingface_hub.add_space_variable(space_id, "TRACKIO_DATASET_ID", dataset_id)

    if logo_light_url := os.environ.get("TRACKIO_LOGO_LIGHT_URL"):
        huggingface_hub.add_space_variable(
            space_id, "TRACKIO_LOGO_LIGHT_URL", logo_light_url
        )
    if logo_dark_url := os.environ.get("TRACKIO_LOGO_DARK_URL"):
        huggingface_hub.add_space_variable(
            space_id, "TRACKIO_LOGO_DARK_URL", logo_dark_url
        )

    if plot_order := os.environ.get("TRACKIO_PLOT_ORDER"):
        huggingface_hub.add_space_variable(space_id, "TRACKIO_PLOT_ORDER", plot_order)

    if theme := os.environ.get("TRACKIO_THEME"):
        huggingface_hub.add_space_variable(space_id, "TRACKIO_THEME", theme)

    huggingface_hub.add_space_variable(space_id, "GRADIO_MCP_SERVER", "True")
+ huggingface_hub.add_space_variable(space_id, "GRADIO_MCP_SERVER", "True")
159
+
160
+
161
def create_space_if_not_exists(
    space_id: str,
    space_storage: huggingface_hub.SpaceStorage | None = None,
    dataset_id: str | None = None,
    private: bool | None = None,
) -> None:
    """
    Creates a new Hugging Face Space if it does not exist.

    Args:
        space_id (`str`):
            The ID of the Space to create.
        space_storage ([`~huggingface_hub.SpaceStorage`], *optional*):
            Choice of persistent storage tier for the Space.
        dataset_id (`str`, *optional*):
            The ID of the Dataset to add to the Space as a space variable.
        private (`bool`, *optional*):
            Whether to make the Space private. If `None` (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
    """
    if "/" not in space_id:
        raise ValueError(
            f"Invalid space ID: {space_id}. Must be in the format: username/reponame or orgname/reponame."
        )
    if dataset_id is not None and "/" not in dataset_id:
        raise ValueError(
            f"Invalid dataset ID: {dataset_id}. Must be in the format: username/datasetname or orgname/datasetname."
        )
    try:
        huggingface_hub.repo_info(space_id, repo_type="space")
        print(f"* Found existing space: {SPACE_URL.format(space_id=space_id)}")
        # Space already exists: only refresh its configuration variables.
        if dataset_id is not None:
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_DATASET_ID", dataset_id
            )
        if logo_light_url := os.environ.get("TRACKIO_LOGO_LIGHT_URL"):
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_LOGO_LIGHT_URL", logo_light_url
            )
        if logo_dark_url := os.environ.get("TRACKIO_LOGO_DARK_URL"):
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_LOGO_DARK_URL", logo_dark_url
            )

        if plot_order := os.environ.get("TRACKIO_PLOT_ORDER"):
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_PLOT_ORDER", plot_order
            )

        if theme := os.environ.get("TRACKIO_THEME"):
            huggingface_hub.add_space_variable(space_id, "TRACKIO_THEME", theme)
        return
    except RepositoryNotFoundError:
        # Space does not exist yet: fall through and create it below.
        pass
    except HfHubHTTPError as e:
        if e.response.status_code in [401, 403]:  # unauthorized or forbidden
            print("Need 'write' access token to create a Spaces repo.")
            huggingface_hub.login(add_to_git_credential=False)
            # NOTE(review): this branch sets the dataset variable even when
            # `dataset_id` is None and does not retry `repo_info`, so a
            # pre-existing Space will still fall through to `deploy_as_space`
            # below — confirm this is intended.
            huggingface_hub.add_space_variable(
                space_id, "TRACKIO_DATASET_ID", dataset_id
            )
        else:
            raise ValueError(f"Failed to create Space: {e}")

    print(f"* Creating new space: {SPACE_URL.format(space_id=space_id)}")
    deploy_as_space(space_id, space_storage, dataset_id, private)
228
+
229
+
230
def wait_until_space_exists(
    space_id: str,
) -> None:
    """
    Blocks the current thread until the Space exists.

    Polls `space_info` up to 30 times with exponential backoff
    (1s, 2s, 4s, ... capped at 60s) between failed attempts.

    Args:
        space_id (`str`):
            The ID of the Space to wait for.

    Raises:
        `TimeoutError`: If waiting for the Space takes longer than expected.
    """
    hf_api = huggingface_hub.HfApi()
    wait_seconds = 1
    attempts_left = 30
    while attempts_left > 0:
        try:
            hf_api.space_info(space_id)
            return
        except (huggingface_hub.utils.HfHubHTTPError, ReadTimeout):
            time.sleep(wait_seconds)
            wait_seconds = min(wait_seconds * 2, 60)
        attempts_left -= 1
    raise TimeoutError("Waiting for space to exist took longer than expected")
253
+
254
+
255
def upload_db_to_space(project: str, space_id: str, force: bool = False) -> None:
    """
    Uploads the database of a local Trackio project to a Hugging Face Space.

    This uses the Gradio Client to upload since we do not want to trigger a new build of
    the Space, which would happen if we used `huggingface_hub.upload_file`.

    Args:
        project (`str`):
            The name of the project to upload.
        space_id (`str`):
            The ID of the Space to upload to.
        force (`bool`, *optional*, defaults to `False`):
            If `True`, overwrites the existing database without prompting. If `False`,
            prompts for confirmation.
    """
    db_path = SQLiteStorage.get_project_db_path(project)
    # Generous timeout: the Space may still be cold-starting when we connect.
    client = Client(space_id, verbose=False, httpx_kwargs={"timeout": 90})

    if not force:
        try:
            existing_projects = client.predict(api_name="/get_all_projects")
            if project in existing_projects:
                # Interactive confirmation before clobbering remote data.
                response = input(
                    f"Database for project '{project}' already exists on Space '{space_id}'. "
                    f"Overwrite it? (y/N): "
                )
                if response.lower() not in ["y", "yes"]:
                    print("* Upload cancelled.")
                    return
        except Exception as e:
            # Best-effort existence check: if the Space cannot be queried, warn
            # and proceed with the upload instead of failing the sync.
            print(f"* Warning: Could not check if project exists on Space: {e}")
            print("* Proceeding with upload...")

    client.predict(
        api_name="/upload_db_to_space",
        project=project,
        uploaded_db=handle_file(db_path),
        hf_token=huggingface_hub.utils.get_token(),
    )
295
+
296
+
297
def sync(
    project: str,
    space_id: str | None = None,
    private: bool | None = None,
    force: bool = False,
    run_in_background: bool = False,
) -> str:
    """
    Syncs a local Trackio project's database to a Hugging Face Space.
    If the Space does not exist, it will be created.

    Args:
        project (`str`): The name of the project to upload.
        space_id (`str`, *optional*): The ID of the Space to upload to (e.g., `"username/space_id"`).
            If not provided, a random space_id (e.g. "username/project-2ac3z2aA") will be used.
        private (`bool`, *optional*):
            Whether to make the Space private. If None (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
        force (`bool`, *optional*, defaults to `False`):
            If `True`, overwrite the existing database without prompting for confirmation.
            If `False`, prompt the user before overwriting an existing database.
        run_in_background (`bool`, *optional*, defaults to `False`):
            If `True`, the Space creation and database upload will be run in a background thread.
            If `False`, all the steps will be run synchronously.
    Returns:
        `str`: The Space ID of the synced project.
    """
    if space_id is None:
        # Derive a per-project Space name from a stored project hash so that
        # repeated syncs of the same project reuse the same Space.
        space_id = f"{project}-{get_or_create_project_hash(project)}"
    space_id, _ = preprocess_space_and_dataset_ids(space_id, None)

    def space_creation_and_upload(
        space_id: str, private: bool | None = None, force: bool = False
    ):
        # Full pipeline: create the Space (if needed), wait until it is
        # reachable, then push the local SQLite database.
        print(
            f"* Syncing local Trackio project to: {SPACE_URL.format(space_id=space_id)} (please wait...)"
        )
        create_space_if_not_exists(space_id, private=private)
        wait_until_space_exists(space_id)
        upload_db_to_space(project, space_id, force=force)
        print(f"* Synced successfully to space: {SPACE_URL.format(space_id=space_id)}")

    if run_in_background:
        # NOTE: in background mode this function returns before the upload has
        # completed (or failed); errors surface only from the worker thread.
        threading.Thread(
            target=space_creation_and_upload, args=(space_id, private, force)
        ).start()
    else:
        space_creation_and_upload(space_id, private, force)
    return space_id
dummy_commit_scheduler.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# A dummy object to fit the interface of huggingface_hub's CommitScheduler
class DummyCommitSchedulerLock:
    """No-op context manager standing in for a real CommitScheduler lock."""

    def __enter__(self):
        # Nothing to acquire; expose no resource to the `with` body.
        return None

    def __exit__(self, exception_type, exception_value, exception_traceback):
        # Nothing to release; returning None lets exceptions propagate.
        return None
8
+
9
+
10
class DummyCommitScheduler:
    """Stand-in for huggingface_hub's CommitScheduler that never commits.

    Only the `lock` attribute is provided, so callers written against the real
    scheduler (`with scheduler.lock: ...`) work unchanged.
    """

    def __init__(self):
        self.lock = DummyCommitSchedulerLock()
histogram.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+
3
+ import numpy as np
4
+
5
+
6
+ class Histogram:
7
+ """
8
+ Histogram data type for Trackio, compatible with wandb.Histogram.
9
+
10
+ Args:
11
+ sequence (`np.ndarray` or `Sequence[float]` or `Sequence[int]`, *optional*):
12
+ Sequence of values to create the histogram from.
13
+ np_histogram (`tuple`, *optional*):
14
+ Pre-computed NumPy histogram as a `(hist, bins)` tuple.
15
+ num_bins (`int`, *optional*, defaults to `64`):
16
+ Number of bins for the histogram (maximum `512`).
17
+
18
+ Example:
19
+ ```python
20
+ import trackio
21
+ import numpy as np
22
+
23
+ # Create histogram from sequence
24
+ data = np.random.randn(1000)
25
+ trackio.log({"distribution": trackio.Histogram(data)})
26
+
27
+ # Create histogram from numpy histogram
28
+ hist, bins = np.histogram(data, bins=30)
29
+ trackio.log({"distribution": trackio.Histogram(np_histogram=(hist, bins))})
30
+
31
+ # Specify custom number of bins
32
+ trackio.log({"distribution": trackio.Histogram(data, num_bins=50)})
33
+ ```
34
+ """
35
+
36
+ TYPE = "trackio.histogram"
37
+
38
+ def __init__(
39
+ self,
40
+ sequence: np.ndarray | Sequence[float] | Sequence[int] | None = None,
41
+ np_histogram: tuple | None = None,
42
+ num_bins: int = 64,
43
+ ):
44
+ if sequence is None and np_histogram is None:
45
+ raise ValueError("Must provide either sequence or np_histogram")
46
+
47
+ if sequence is not None and np_histogram is not None:
48
+ raise ValueError("Cannot provide both sequence and np_histogram")
49
+
50
+ num_bins = min(num_bins, 512)
51
+
52
+ if np_histogram is not None:
53
+ self.histogram, self.bins = np_histogram
54
+ self.histogram = np.asarray(self.histogram)
55
+ self.bins = np.asarray(self.bins)
56
+ else:
57
+ data = np.asarray(sequence).flatten()
58
+ data = data[np.isfinite(data)]
59
+ if len(data) == 0:
60
+ self.histogram = np.array([])
61
+ self.bins = np.array([])
62
+ else:
63
+ self.histogram, self.bins = np.histogram(data, bins=num_bins)
64
+
65
+ def _to_dict(self) -> dict:
66
+ """Convert histogram to dictionary for storage."""
67
+ return {
68
+ "_type": self.TYPE,
69
+ "bins": self.bins.tolist(),
70
+ "values": self.histogram.tolist(),
71
+ }
imports.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+
6
+ from trackio import deploy, utils
7
+ from trackio.sqlite_storage import SQLiteStorage
8
+
9
+
10
def import_csv(
    csv_path: str | Path,
    project: str,
    name: str | None = None,
    space_id: str | None = None,
    dataset_id: str | None = None,
    private: bool | None = None,
    force: bool = False,
) -> None:
    """
    Imports a CSV file into a Trackio project. The CSV file must contain a `"step"`
    column, may optionally contain a `"timestamp"` column, and any other columns will be
    treated as metrics. It should also include a header row with the column names.

    TODO: call init() and return a Run object so that the user can continue to log metrics to it.

    Args:
        csv_path (`str` or `Path`):
            The str or Path to the CSV file to import.
        project (`str`):
            The name of the project to import the CSV file into. Must not be an existing
            project.
        name (`str`, *optional*):
            The name of the Run to import the CSV file into. If not provided, a default
            name is derived from the CSV filename.
        space_id (`str`, *optional*):
            If provided, the project will be logged to a Hugging Face Space instead of a
            local directory. Should be a complete Space name like `"username/reponame"`
            or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
            be created in the currently-logged-in Hugging Face user's namespace. If the
            Space does not exist, it will be created. If the Space already exists, the
            project will be logged to it.
        dataset_id (`str`, *optional*):
            If provided, a persistent Hugging Face Dataset will be created and the
            metrics will be synced to it every 5 minutes. Should be a complete Dataset
            name like `"username/datasetname"` or `"orgname/datasetname"`, or just
            `"datasetname"` in which case the Dataset will be created in the
            currently-logged-in Hugging Face user's namespace. If the Dataset does not
            exist, it will be created. If the Dataset already exists, the project will
            be appended to it. If not provided, the metrics will be logged to a local
            SQLite database, unless a `space_id` is provided, in which case a Dataset
            will be automatically created with the same name as the Space but with the
            `"_dataset"` suffix.
        private (`bool`, *optional*):
            Whether to make the Space private. If None (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
        force (`bool`, *optional*, defaults to `False`):
            If `True`, overwrite an existing database on the Space without prompting
            for confirmation.
    """
    if SQLiteStorage.get_runs(project):
        raise ValueError(
            f"Project '{project}' already exists. Cannot import CSV into existing project."
        )

    csv_path = Path(csv_path)
    if not csv_path.exists():
        raise FileNotFoundError(f"CSV file not found: {csv_path}")

    df = pd.read_csv(csv_path)
    if df.empty:
        raise ValueError("CSV file is empty")

    column_mapping = utils.simplify_column_names(df.columns.tolist())
    df = df.rename(columns=column_mapping)

    # Locate the step column case-insensitively ("step", "Step", ...).
    step_column = None
    for col in df.columns:
        if col.lower() == "step":
            step_column = col
            break

    if step_column is None:
        raise ValueError("CSV file must contain a 'step' or 'Step' column")

    if name is None:
        name = csv_path.stem

    metrics_list = []
    steps = []
    timestamps = []

    # Only columns that parse cleanly as numbers are treated as metrics.
    numeric_columns = []
    for column in df.columns:
        if column == step_column:
            continue
        if column == "timestamp":
            continue

        try:
            pd.to_numeric(df[column], errors="raise")
            numeric_columns.append(column)
        except (ValueError, TypeError):
            continue

    for _, row in df.iterrows():
        metrics = {}
        for column in numeric_columns:
            value = row[column]
            if bool(pd.notna(value)):
                metrics[column] = float(value)

        # Rows with no usable metric values are skipped entirely.
        if metrics:
            metrics_list.append(metrics)
            steps.append(int(row[step_column]))

            if "timestamp" in df.columns and bool(pd.notna(row["timestamp"])):
                timestamps.append(str(row["timestamp"]))
            else:
                timestamps.append("")

    if metrics_list:
        SQLiteStorage.bulk_log(
            project=project,
            run=name,
            metrics_list=metrics_list,
            steps=steps,
            timestamps=timestamps,
        )

    print(
        f"* Imported {len(metrics_list)} rows from {csv_path} into project '{project}' as run '{name}'"
    )
    if metrics_list:
        # Fix: metrics_list may be empty (e.g. no numeric columns); indexing
        # metrics_list[0] unconditionally raised IndexError in that case.
        print(f"* Metrics found: {', '.join(metrics_list[0].keys())}")

    space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
    if dataset_id is not None:
        os.environ["TRACKIO_DATASET_ID"] = dataset_id
        print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")

    if space_id is None:
        utils.print_dashboard_instructions(project)
    else:
        deploy.create_space_if_not_exists(
            space_id=space_id, dataset_id=dataset_id, private=private
        )
        deploy.wait_until_space_exists(space_id=space_id)
        deploy.upload_db_to_space(project=project, space_id=space_id, force=force)
        print(
            f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
        )
151
+
152
+
153
def import_tf_events(
    log_dir: str | Path,
    project: str,
    name: str | None = None,
    space_id: str | None = None,
    dataset_id: str | None = None,
    private: bool | None = None,
    force: bool = False,
) -> None:
    """
    Imports TensorFlow Events files from a directory into a Trackio project. Each
    subdirectory in the log directory will be imported as a separate run.

    Args:
        log_dir (`str` or `Path`):
            The str or Path to the directory containing TensorFlow Events files.
        project (`str`):
            The name of the project to import the TensorFlow Events files into. Must not
            be an existing project.
        name (`str`, *optional*):
            The name prefix for runs (if not provided, will use directory names). Each
            subdirectory will create a separate run.
        space_id (`str`, *optional*):
            If provided, the project will be logged to a Hugging Face Space instead of a
            local directory. Should be a complete Space name like `"username/reponame"`
            or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
            be created in the currently-logged-in Hugging Face user's namespace. If the
            Space does not exist, it will be created. If the Space already exists, the
            project will be logged to it.
        dataset_id (`str`, *optional*):
            If provided, a persistent Hugging Face Dataset will be created and the
            metrics will be synced to it every 5 minutes. Should be a complete Dataset
            name like `"username/datasetname"` or `"orgname/datasetname"`, or just
            `"datasetname"` in which case the Dataset will be created in the
            currently-logged-in Hugging Face user's namespace. If the Dataset does not
            exist, it will be created. If the Dataset already exists, the project will
            be appended to it. If not provided, the metrics will be logged to a local
            SQLite database, unless a `space_id` is provided, in which case a Dataset
            will be automatically created with the same name as the Space but with the
            `"_dataset"` suffix.
        private (`bool`, *optional*):
            Whether to make the Space private. If None (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
        force (`bool`, *optional*, defaults to `False`):
            If `True`, overwrite an existing database on the Space without prompting
            for confirmation.

    Raises:
        ImportError: If the optional `tbparse` dependency is not installed.
        ValueError: If the project already exists, no scalar data is found, or no
            run could be imported.
        FileNotFoundError: If `log_dir` does not exist.
    """
    # `tbparse` is an optional dependency; import lazily so the rest of the
    # module works without the `tensorboard` extra installed.
    try:
        from tbparse import SummaryReader
    except ImportError:
        raise ImportError(
            "The `tbparse` package is not installed but is required for `import_tf_events`. Please install trackio with the `tensorboard` extra: `pip install trackio[tensorboard]`."
        )

    if SQLiteStorage.get_runs(project):
        raise ValueError(
            f"Project '{project}' already exists. Cannot import TF events into existing project."
        )

    path = Path(log_dir)
    if not path.exists():
        raise FileNotFoundError(f"TF events directory not found: {path}")

    # Use tbparse to read all tfevents files in the directory structure;
    # "dir_name" identifies the subdirectory each scalar came from.
    reader = SummaryReader(str(path), extra_columns={"dir_name"})
    df = reader.scalars

    if df.empty:
        raise ValueError(f"No TensorFlow events data found in {path}")

    total_imported = 0
    imported_runs = []

    # Group by dir_name to create separate runs
    for dir_name, group_df in df.groupby("dir_name"):
        try:
            # Determine run name based on directory name
            if dir_name == "":
                run_name = "main"  # For files in the root directory
            else:
                run_name = dir_name  # Use directory name

            # Optional user-supplied prefix is prepended to every run name.
            if name:
                run_name = f"{name}_{run_name}"

            if group_df.empty:
                print(f"* Skipping directory {dir_name}: no scalar data found")
                continue

            metrics_list = []
            steps = []
            timestamps = []

            for _, row in group_df.iterrows():
                # Convert row values to appropriate types
                tag = str(row["tag"])
                value = float(row["value"])
                step = int(row["step"])

                # One single-metric dict per scalar event.
                metrics = {tag: value}
                metrics_list.append(metrics)
                steps.append(step)

                # Use wall_time if present, else fallback
                if "wall_time" in group_df.columns and not bool(
                    pd.isna(row["wall_time"])
                ):
                    timestamps.append(str(row["wall_time"]))
                else:
                    timestamps.append("")

            if metrics_list:
                SQLiteStorage.bulk_log(
                    project=project,
                    run=str(run_name),
                    metrics_list=metrics_list,
                    steps=steps,
                    timestamps=timestamps,
                )

                total_imported += len(metrics_list)
                imported_runs.append(run_name)

                print(
                    f"* Imported {len(metrics_list)} scalar events from directory '{dir_name}' as run '{run_name}'"
                )
                print(f"* Metrics in this run: {', '.join(set(group_df['tag']))}")

        except Exception as e:
            # Best-effort import: one corrupt directory should not abort the rest.
            print(f"* Error processing directory {dir_name}: {e}")
            continue

    if not imported_runs:
        raise ValueError("No valid TensorFlow events data could be imported")

    print(f"* Total imported events: {total_imported}")
    print(f"* Created runs: {', '.join(imported_runs)}")

    space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
    if dataset_id is not None:
        # The dashboard process reads this env var to locate the backing Dataset.
        os.environ["TRACKIO_DATASET_ID"] = dataset_id
        print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")

    if space_id is None:
        utils.print_dashboard_instructions(project)
    else:
        deploy.create_space_if_not_exists(
            space_id, dataset_id=dataset_id, private=private
        )
        deploy.wait_until_space_exists(space_id)
        deploy.upload_db_to_space(project, space_id, force=force)
        print(
            f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
        )
media/__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Media module for Trackio.

This module contains all media-related functionality including:
- TrackioImage, TrackioVideo, TrackioAudio classes
- Video writing utilities
- Audio conversion utilities
"""

# Absolute imports work when trackio is installed as a package; the fallback
# supports running the modules directly (e.g. local execution on Spaces),
# where the `trackio.` package prefix is not importable.
try:
    from trackio.media.audio import TrackioAudio
    from trackio.media.image import TrackioImage
    from trackio.media.media import TrackioMedia
    from trackio.media.utils import get_project_media_path
    from trackio.media.video import TrackioVideo
except ImportError:
    from media.audio import TrackioAudio
    from media.image import TrackioImage
    from media.media import TrackioMedia
    from media.utils import get_project_media_path
    from media.video import TrackioVideo

# Module-level aliases so callers can use `write_audio` / `write_video`
# without going through the class attributes.
write_audio = TrackioAudio.write_audio
write_video = TrackioVideo.write_video

__all__ = [
    "TrackioMedia",
    "TrackioImage",
    "TrackioVideo",
    "TrackioAudio",
    "get_project_media_path",
    "write_video",
    "write_audio",
]
media/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.05 kB). View file
 
media/__pycache__/audio.cpython-312.pyc ADDED
Binary file (9.15 kB). View file
 
media/__pycache__/image.cpython-312.pyc ADDED
Binary file (4.86 kB). View file
 
media/__pycache__/media.cpython-312.pyc ADDED
Binary file (4.62 kB). View file
 
media/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.67 kB). View file
 
media/__pycache__/video.cpython-312.pyc ADDED
Binary file (11.1 kB). View file
 
media/audio.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Literal
6
+
7
+ import numpy as np
8
+ from pydub import AudioSegment
9
+
10
+ try:
11
+ from trackio.media.media import TrackioMedia
12
+ from trackio.media.utils import check_ffmpeg_installed, check_path
13
+ except ImportError:
14
+ from media.media import TrackioMedia
15
+ from media.utils import check_ffmpeg_installed, check_path
16
+
17
SUPPORTED_FORMATS = ["wav", "mp3"]  # container formats write_audio can emit
AudioFormatType = Literal["wav", "mp3"]  # type alias mirroring SUPPORTED_FORMATS
TrackioAudioSourceType = str | Path | np.ndarray  # accepted `value` types for TrackioAudio
20
+
21
+
22
class TrackioAudio(TrackioMedia):
    """
    Audio media type for Trackio.

    Accepts either a path to an existing audio file or a raw numpy array
    (`(samples,)` mono or `(samples, 2)` stereo). Float arrays are
    peak-normalized and converted to 16-bit PCM; integer arrays are converted
    to 16-bit PCM as needed.

    Args:
        value (`str`, `Path`, or `numpy.ndarray`):
            A path to an audio file, or a numpy array of samples.
        caption (`str`, *optional*):
            A string caption for the audio.
        sample_rate (`int`, *optional*):
            Sample rate in Hz. Required when `value` is a numpy array.
        format (`Literal["wav", "mp3"]`, *optional*):
            Audio format used when `value` is a numpy array. Defaults to "wav".

    Example:
        ```python
        import trackio
        import numpy as np

        sr = 16000
        t = np.linspace(0, 1, sr, endpoint=False)
        wave = 0.2 * np.sin(2 * np.pi * 440 * t)
        trackio.log({"tone": trackio.Audio(wave, caption="A4 sine", sample_rate=sr)})

        stereo = np.stack([wave, wave], axis=1)
        trackio.log({"stereo": trackio.Audio(stereo, sample_rate=sr, format="mp3")})

        trackio.log({"file_audio": trackio.Audio("path/to/audio.wav")})
        ```
    """

    TYPE = "trackio.audio"

    def __init__(
        self,
        value: TrackioAudioSourceType,
        caption: str | None = None,
        sample_rate: int | None = None,
        format: AudioFormatType | None = None,
    ):
        super().__init__(value, caption)
        if isinstance(value, np.ndarray):
            # Raw samples carry no intrinsic rate or container; require the
            # former and default the latter.
            if sample_rate is None:
                raise ValueError("Sample rate is required when value is an ndarray")
            if format is None:
                format = "wav"
        self._format = format
        self._sample_rate = sample_rate

    def _save_media(self, file_path: Path):
        """Write the audio payload to `file_path` (encode array or copy file)."""
        source = self._value
        if isinstance(source, np.ndarray):
            TrackioAudio.write_audio(
                data=source,
                sample_rate=self._sample_rate,
                filename=file_path,
                format=self._format,
            )
            return
        if isinstance(source, str | Path):
            if not os.path.isfile(source):
                raise ValueError(f"File not found: {source}")
            shutil.copy(source, file_path)

    @staticmethod
    def ensure_int16_pcm(data: np.ndarray) -> np.ndarray:
        """
        Convert input audio array to contiguous int16 PCM.
        Peak normalization is applied to floating inputs.
        """
        samples = np.asarray(data)
        if samples.ndim not in (1, 2):
            raise ValueError("Audio data must be 1D (mono) or 2D ([samples, channels])")

        if samples.dtype != np.int16:
            warnings.warn(
                f"Converting {samples.dtype} audio to int16 PCM; pass int16 to avoid conversion.",
                stacklevel=2,
            )

        # Replace NaN/inf before any scaling so they cannot poison the output.
        samples = np.nan_to_num(samples, copy=False)

        # Floating types: normalize to peak 1.0, then scale to the int16 range.
        if np.issubdtype(samples.dtype, np.floating):
            peak = float(np.max(np.abs(samples))) if samples.size else 0.0
            if peak > 0.0:
                samples = samples / peak
            scaled = (samples * 32767.0).clip(-32768, 32767).astype(np.int16, copy=False)
            return np.ascontiguousarray(scaled)

        # Integer types: rescale/offset into the int16 range per source dtype.
        if samples.dtype == np.dtype(np.int16):
            converted = samples
        elif samples.dtype == np.dtype(np.int32):
            converted = (samples.astype(np.int32) // 65536).astype(np.int16, copy=False)
        elif samples.dtype == np.dtype(np.uint16):
            converted = (samples.astype(np.int32) - 32768).astype(np.int16, copy=False)
        elif samples.dtype == np.dtype(np.uint8):
            converted = (samples.astype(np.int32) * 257 - 32768).astype(
                np.int16, copy=False
            )
        elif samples.dtype == np.dtype(np.int8):
            converted = (samples.astype(np.int32) * 256).astype(np.int16, copy=False)
        else:
            raise TypeError(f"Unsupported audio dtype: {samples.dtype}")
        return np.ascontiguousarray(converted)

    @staticmethod
    def write_audio(
        data: np.ndarray,
        sample_rate: int,
        filename: str | Path,
        format: AudioFormatType = "wav",
    ) -> None:
        """Encode `data` as int16 PCM and write it to `filename` via pydub."""
        if not isinstance(sample_rate, int) or sample_rate <= 0:
            raise ValueError(f"Invalid sample_rate: {sample_rate}")
        if format not in SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported format: {format}. Supported: {SUPPORTED_FORMATS}"
            )

        check_path(filename)

        pcm = TrackioAudio.ensure_int16_pcm(data)

        # Only compressed (non-wav) exports require ffmpeg.
        if format != "wav":
            check_ffmpeg_installed()

        n_channels = pcm.shape[1] if pcm.ndim == 2 else 1
        segment = AudioSegment(
            pcm.tobytes(),
            frame_rate=sample_rate,
            sample_width=2,  # int16
            channels=n_channels,
        )

        exported = segment.export(str(filename), format=format)
        exported.close()
media/image.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from PIL import Image as PILImage
7
+
8
+ try:
9
+ from trackio.media.media import TrackioMedia
10
+ except ImportError:
11
+ from media.media import TrackioMedia
12
+
13
+
14
TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image  # accepted `value` types for TrackioImage
15
+
16
+
17
class TrackioImage(TrackioMedia):
    """
    Image media type for Trackio.

    Accepts a path to an image file, a PIL Image, or a numpy array of shape
    (height, width, channels). Numpy arrays must be `np.uint8` with RGB values
    in `[0, 255]`. In-memory images are serialized as PNG; file paths are
    copied as-is.

    Args:
        value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`):
            The image source.
        caption (`str`, *optional*):
            A string caption for the image.

    Example:
        ```python
        import trackio
        import numpy as np
        from PIL import Image

        pixels = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
        trackio.log({"my_image": trackio.Image(pixels, caption="Random image")})

        trackio.log({"red_image": trackio.Image(Image.new('RGB', (100, 100), color='red'))})

        trackio.log({"file_image": trackio.Image("path/to/image.jpg")})
        ```
    """

    TYPE = "trackio.image"

    def __init__(self, value: TrackioImageSourceType, caption: str | None = None):
        super().__init__(value, caption)
        self._format: str | None = None

        if not isinstance(self._value, TrackioImageSourceType):
            raise ValueError(
                f"Invalid value type, expected {TrackioImageSourceType}, got {type(self._value)}"
            )
        if isinstance(self._value, np.ndarray) and self._value.dtype != np.uint8:
            raise ValueError(
                f"Invalid value dtype, expected np.uint8, got {self._value.dtype}"
            )
        # In-memory sources default to PNG serialization.
        if (
            isinstance(self._value, np.ndarray | PILImage.Image)
            and self._format is None
        ):
            self._format = "png"

    def _as_pil(self) -> PILImage.Image | None:
        """Return the in-memory value as an RGBA PIL image, or None for path sources."""
        try:
            source = self._value
            if isinstance(source, np.ndarray):
                pixels = np.asarray(source).astype("uint8")
                return PILImage.fromarray(pixels).convert("RGBA")
            if isinstance(source, PILImage.Image):
                return source.convert("RGBA")
        except Exception as e:
            raise ValueError(f"Failed to process image data: {self._value}") from e
        return None

    def _save_media(self, file_path: Path):
        """Encode the in-memory image, or copy the source file, to `file_path`."""
        pil = self._as_pil()
        if pil is not None:
            pil.save(file_path, format=self._format)
            return
        if isinstance(self._value, str | Path):
            if not os.path.isfile(self._value):
                raise ValueError(f"File not found: {self._value}")
            shutil.copy(self._value, file_path)
media/media.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ from abc import ABC, abstractmethod
4
+ from pathlib import Path
5
+
6
+ try: # absolute imports when installed
7
+ from trackio.media.utils import get_project_media_path
8
+ from trackio.utils import MEDIA_DIR
9
+ except ImportError: # relative imports for local execution on Spaces
10
+ from media.utils import get_project_media_path
11
+ from utils import MEDIA_DIR
12
+
13
+
14
class TrackioMedia(ABC):
    """
    Abstract base class for Trackio media objects.
    Provides shared functionality for file handling and serialization.
    """

    # Subclasses must set this to their serialization type tag.
    TYPE: str

    def __init_subclass__(cls, **kwargs):
        """Reject subclasses that fail to declare a TYPE identifier."""
        super().__init_subclass__(**kwargs)
        if getattr(cls, "TYPE", None) is None:
            raise TypeError(f"Class {cls.__name__} must define TYPE attribute")

    def __init__(self, value, caption: str | None = None):
        """
        Saves the value and caption, and if the value is a file path, checks if the file exists.
        """
        self.caption = caption
        self._value = value
        self._file_path: Path | None = None  # set by _save(); relative to MEDIA_DIR

        if isinstance(self._value, str | Path) and not os.path.isfile(self._value):
            raise ValueError(f"File not found: {self._value}")

    def _file_extension(self) -> str:
        # Prefer the persisted file's suffix, then the source path's suffix,
        # then any explicit format recorded by a subclass.
        if self._file_path:
            return self._file_path.suffix[1:].lower()
        if isinstance(self._value, str | Path):
            return Path(self._value).suffix[1:].lower()
        if getattr(self, "_format", None):
            return self._format
        return "unknown"

    def _get_relative_file_path(self) -> Path | None:
        return self._file_path

    def _get_absolute_file_path(self) -> Path | None:
        return MEDIA_DIR / self._file_path if self._file_path else None

    def _save(self, project: str, run: str, step: int = 0):
        # Saving is idempotent: once persisted, keep the existing file.
        if self._file_path:
            return

        target_dir = get_project_media_path(project=project, run=run, step=step)
        target = target_dir / f"{uuid.uuid4()}.{self._file_extension()}"

        self._save_media(target)
        self._file_path = target.relative_to(MEDIA_DIR)

    @abstractmethod
    def _save_media(self, file_path: Path):
        """
        Performs the actual media saving logic.
        """

    def _to_dict(self) -> dict:
        """Serialize to a storable dict; requires the media to be saved first."""
        if not self._file_path:
            raise ValueError("Media must be saved to file before serialization")
        return {
            "_type": self.TYPE,
            "file_path": str(self._get_relative_file_path()),
            "caption": self.caption,
        }
media/utils.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from pathlib import Path
3
+
4
+ try:
5
+ from trackio.utils import MEDIA_DIR
6
+ except ImportError:
7
+ from utils import MEDIA_DIR
8
+
9
+
10
+ def check_path(file_path: str | Path) -> None:
11
+ """Raise an error if the parent directory does not exist."""
12
+ file_path = Path(file_path)
13
+ if not file_path.parent.exists():
14
+ try:
15
+ file_path.parent.mkdir(parents=True, exist_ok=True)
16
+ except OSError as e:
17
+ raise ValueError(
18
+ f"Failed to create parent directory {file_path.parent}: {e}"
19
+ )
20
+
21
+
22
def check_ffmpeg_installed() -> None:
    """Raise an error if ffmpeg is not available on the system PATH."""
    if shutil.which("ffmpeg") is not None:
        return
    raise RuntimeError(
        "ffmpeg is required to write video but was not found on your system. "
        "Please install ffmpeg and ensure it is available on your PATH."
    )
29
+
30
+
31
def get_project_media_path(
    project: str,
    run: str | None = None,
    step: int | None = None,
    relative_path: str | Path | None = None,
) -> Path:
    """
    Get the full path where uploaded files are stored for a Trackio project,
    creating the directory if it does not exist.

    When `run` is given, files live under `<MEDIA_DIR>/<project>/<run>[/<step>]`;
    otherwise they go to a project-level `files` directory, optionally extended
    by `relative_path`.

    Args:
        project: The project name
        run: The run name
        step: The step number (requires `run`)
        relative_path: The relative path within the directory (only used if run is not provided)

    Returns:
        The full path to the media file

    Raises:
        ValueError: If `step` is given without `run`.
    """
    if step is not None and run is None:
        raise ValueError("Uploading files at a specific step requires a run")

    media_path = MEDIA_DIR / project
    if run:
        media_path = media_path / run
        if step is not None:
            media_path = media_path / str(step)
    else:
        media_path = media_path / "files"
        if relative_path:
            media_path = media_path / relative_path
    media_path.mkdir(parents=True, exist_ok=True)
    return media_path
media/video.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import subprocess
4
+ from pathlib import Path
5
+ from typing import Literal
6
+
7
+ import numpy as np
8
+
9
+ try:
10
+ from trackio.media.media import TrackioMedia
11
+ from trackio.media.utils import check_ffmpeg_installed, check_path
12
+ except ImportError:
13
+ from media.media import TrackioMedia
14
+ from media.utils import check_ffmpeg_installed, check_path
15
+
16
+
17
+ TrackioVideoSourceType = str | Path | np.ndarray
18
+ TrackioVideoFormatType = Literal["gif", "mp4", "webm"]
19
+ VideoCodec = Literal["h264", "vp9", "gif"]
20
+
21
+
22
class TrackioVideo(TrackioMedia):
    """
    Initializes a Video object.

    Example:
        ```python
        import trackio
        import numpy as np

        # Create a simple video from numpy array
        frames = np.random.randint(0, 255, (10, 3, 64, 64), dtype=np.uint8)
        video = trackio.Video(frames, caption="Random video", fps=30)

        # Create a batch of videos
        batch_frames = np.random.randint(0, 255, (3, 10, 3, 64, 64), dtype=np.uint8)
        batch_video = trackio.Video(batch_frames, caption="Batch of videos", fps=15)

        # Create video from file path
        video = trackio.Video("path/to/video.mp4", caption="Video from file")
        ```

    Args:
        value (`str`, `Path`, or `numpy.ndarray`, *optional*):
            A path to a video file, or a numpy array.
            If numpy array, should be of type `np.uint8` with RGB values in the range `[0, 255]`.
            It is expected to have shape of either (frames, channels, height, width) or (batch, frames, channels, height, width).
            For the latter, the videos will be tiled into a grid.
        caption (`str`, *optional*):
            A string caption for the video.
        fps (`int`, *optional*):
            Frames per second for the video. Only used when value is an ndarray. Default is `24`.
        format (`Literal["gif", "mp4", "webm"]`, *optional*):
            Video format ("gif", "mp4", or "webm"). Only used when value is an ndarray. Default is "gif".
    """

    # Type tag stored with the serialized payload so the dashboard can route
    # it to the video renderer.
    TYPE = "trackio.video"

    def __init__(
        self,
        value: TrackioVideoSourceType,
        caption: str | None = None,
        fps: int | None = None,
        format: TrackioVideoFormatType | None = None,
    ):
        super().__init__(value, caption)

        # isinstance() accepts the PEP 604 union alias directly (Python 3.10+).
        if not isinstance(self._value, TrackioVideoSourceType):
            raise ValueError(
                f"Invalid value type, expected {TrackioVideoSourceType}, got {type(self._value)}"
            )
        if isinstance(self._value, np.ndarray):
            if self._value.dtype != np.uint8:
                raise ValueError(
                    f"Invalid value dtype, expected np.uint8, got {self._value.dtype}"
                )
        # fps/format only affect ndarray encoding; file inputs are copied as-is.
        if format is None:
            format = "gif"
        if fps is None:
            fps = 24
        self._fps = fps
        self._format = format

    @staticmethod
    def _check_array_format(video: np.ndarray) -> None:
        """Raise an error if the array is not in the expected format."""
        # write_video expects frame-major RGB: (frames, height, width, 3).
        if not (video.ndim == 4 and video.shape[-1] == 3):
            raise ValueError(
                f"Expected RGB input shaped (F, H, W, 3), got {video.shape}. "
                f"Input has {video.ndim} dimensions, expected 4."
            )
        if video.dtype != np.uint8:
            raise TypeError(
                f"Expected dtype=uint8, got {video.dtype}. "
                "Please convert your video data to uint8 format."
            )

    @staticmethod
    def write_video(
        file_path: str | Path, video: np.ndarray, fps: float, codec: VideoCodec
    ) -> None:
        """RGB uint8 only, shape (F, H, W, 3)."""
        check_ffmpeg_installed()
        check_path(file_path)

        if codec not in {"h264", "vp9", "gif"}:
            raise ValueError("Unsupported codec. Use h264, vp9, or gif.")

        arr = np.asarray(video)
        TrackioVideo._check_array_format(arr)

        # Frames are streamed to ffmpeg as raw bytes, so they must be contiguous.
        frames = np.ascontiguousarray(arr)
        _, height, width, _ = frames.shape
        out_path = str(file_path)

        # Common input options: raw RGB24 frames piped over stdin ("-i -"),
        # no audio ("-an"), overwrite output ("-y").
        cmd = [
            "ffmpeg",
            "-y",
            "-f",
            "rawvideo",
            "-s",
            f"{width}x{height}",
            "-pix_fmt",
            "rgb24",
            "-r",
            str(fps),
            "-i",
            "-",
            "-an",
        ]

        if codec == "gif":
            # Two-pass palette generation (palettegen/paletteuse) for GIF quality.
            video_filter = "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse"
            cmd += [
                "-vf",
                video_filter,
                "-loop",
                "0",
            ]
        elif codec == "h264":
            cmd += [
                "-vcodec",
                "libx264",
                "-pix_fmt",
                "yuv420p",
                "-movflags",
                "+faststart",
            ]
        elif codec == "vp9":
            # Heuristic target bitrate: ~0.08 bits per pixel per second,
            # rendered in ffmpeg's M/k bitrate notation.
            bpp = 0.08
            bps = int(width * height * fps * bpp)
            if bps >= 1_000_000:
                bitrate = f"{round(bps / 1_000_000)}M"
            elif bps >= 1_000:
                bitrate = f"{round(bps / 1_000)}k"
            else:
                bitrate = str(max(bps, 1))
            cmd += [
                "-vcodec",
                "libvpx-vp9",
                "-b:v",
                bitrate,
                "-pix_fmt",
                "yuv420p",
            ]
        cmd += [out_path]
        # NOTE(review): stderr is a pipe that is only drained after all frames
        # are written; a very verbose ffmpeg run could fill the pipe buffer and
        # stall the stdin writes — confirm, or redirect stderr to a file.
        proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
        try:
            for frame in frames:
                proc.stdin.write(frame.tobytes())
        finally:
            # Close stdin to signal EOF, then collect diagnostics and exit code.
            if proc.stdin:
                proc.stdin.close()
            stderr = (
                proc.stderr.read().decode("utf-8", errors="ignore")
                if proc.stderr
                else ""
            )
            ret = proc.wait()
            if ret != 0:
                raise RuntimeError(f"ffmpeg failed with code {ret}\n{stderr}")

    @property
    def _codec(self) -> str:
        """Map the user-facing container format to the ffmpeg codec name."""
        match self._format:
            case "gif":
                return "gif"
            case "mp4":
                return "h264"
            case "webm":
                return "vp9"
            case _:
                raise ValueError(f"Unsupported format: {self._format}")

    def _save_media(self, file_path: Path):
        # Encode arrays with ffmpeg; copy already-encoded files verbatim.
        if isinstance(self._value, np.ndarray):
            video = TrackioVideo._process_ndarray(self._value)
            TrackioVideo.write_video(file_path, video, fps=self._fps, codec=self._codec)
        elif isinstance(self._value, str | Path):
            if os.path.isfile(self._value):
                shutil.copy(self._value, file_path)
            else:
                raise ValueError(f"File not found: {self._value}")

    @staticmethod
    def _process_ndarray(value: np.ndarray) -> np.ndarray:
        """Convert user input (FCHW or BFCHW) to write_video's (F, H, W, C) layout."""
        # Verify value is either 4D (single video) or 5D array (batched videos).
        # Expected format: (frames, channels, height, width) or (batch, frames, channels, height, width)
        if value.ndim < 4:
            raise ValueError(
                "Video requires at least 4 dimensions (frames, channels, height, width)"
            )
        if value.ndim > 5:
            raise ValueError(
                "Videos can have at most 5 dimensions (batch, frames, channels, height, width)"
            )
        if value.ndim == 4:
            # Reshape to 5D with single batch: (1, frames, channels, height, width)
            value = value[np.newaxis, ...]

        value = TrackioVideo._tile_batched_videos(value)
        return value

    @staticmethod
    def _tile_batched_videos(video: np.ndarray) -> np.ndarray:
        """
        Tiles a batch of videos into a grid of videos.

        Input format: (batch, frames, channels, height, width) - original FCHW format
        Output format: (frames, total_height, total_width, channels)
        """
        batch_size, frames, channels, height, width = video.shape

        # Pad the batch with black videos up to the next power of two so the
        # grid factors cleanly into rows x cols.
        next_pow2 = 1 << (batch_size - 1).bit_length()
        if batch_size != next_pow2:
            pad_len = next_pow2 - batch_size
            pad_shape = (pad_len, frames, channels, height, width)
            padding = np.zeros(pad_shape, dtype=video.dtype)
            video = np.concatenate((video, padding), axis=0)
            batch_size = next_pow2

        # Near-square grid: n_rows = 2 ** floor(log2(batch_size) / 2).
        n_rows = 1 << ((batch_size.bit_length() - 1) // 2)
        n_cols = batch_size // n_rows

        # Reshape to grid layout: (n_rows, n_cols, frames, channels, height, width)
        video = video.reshape(n_rows, n_cols, frames, channels, height, width)

        # Rearrange dimensions to (frames, total_height, total_width, channels)
        video = video.transpose(2, 0, 4, 1, 5, 3)
        video = video.reshape(frames, n_rows * height, n_cols * width, channels)
        return video
package.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "name": "trackio",
3
+ "version": "0.13.0",
4
+ "description": "",
5
+ "python": "true"
6
+ }
py.typed ADDED
File without changes
run.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ import time
3
+ import warnings
4
+ from datetime import datetime, timezone
5
+
6
+ import huggingface_hub
7
+ from gradio_client import Client, handle_file
8
+
9
+ from trackio import utils
10
+ from trackio.histogram import Histogram
11
+ from trackio.media import TrackioMedia
12
+ from trackio.sqlite_storage import SQLiteStorage
13
+ from trackio.table import Table
14
+ from trackio.typehints import LogEntry, UploadEntry
15
+ from trackio.utils import _get_default_namespace
16
+
17
+ BATCH_SEND_INTERVAL = 0.5
18
+
19
+
20
class Run:
    """
    A single Trackio run.

    Logged metrics and media are queued in memory and flushed to the dashboard
    (a local Gradio app or a Hugging Face Space) by a background daemon thread,
    so calls to ``log()`` never block on the network.
    """

    def __init__(
        self,
        url: str,
        project: str,
        client: Client | None,
        name: str | None = None,
        group: str | None = None,
        config: dict | None = None,
        space_id: str | None = None,
    ):
        self.url = url
        self.project = project
        # Guards self._client and both queues across threads.
        self._client_lock = threading.Lock()
        self._client_thread = None
        self._client = client
        self._space_id = space_id
        # Auto-generate a readable run name when none is supplied.
        self.name = name or utils.generate_readable_name(
            SQLiteStorage.get_runs(project), space_id
        )
        self.group = group
        self.config = utils.to_json_safe(config or {})

        # Keys starting with "_" are reserved for Trackio-internal metadata.
        if isinstance(self.config, dict):
            for key in self.config:
                if key.startswith("_"):
                    raise ValueError(
                        f"Config key '{key}' is reserved (keys starting with '_' are reserved for internal use)"
                    )

        self.config["_Username"] = self._get_username()
        self.config["_Created"] = datetime.now(timezone.utc).isoformat()
        self.config["_Group"] = self.group

        self._queued_logs: list[LogEntry] = []
        self._queued_uploads: list[UploadEntry] = []
        self._stop_flag = threading.Event()
        self._config_logged = False

        # Daemon thread: connects to the dashboard, then drains the queues.
        self._client_thread = threading.Thread(target=self._init_client_background)
        self._client_thread.daemon = True
        self._client_thread.start()

    def _get_username(self) -> str | None:
        """Get the current HuggingFace username if logged in, otherwise None."""
        try:
            return _get_default_namespace()
        except Exception:
            return None

    def _batch_sender(self):
        """Send batched logs every BATCH_SEND_INTERVAL."""
        # Keep draining after stop is requested until the log queue is empty.
        while not self._stop_flag.is_set() or len(self._queued_logs) > 0:
            if not self._stop_flag.is_set():
                time.sleep(BATCH_SEND_INTERVAL)

            with self._client_lock:
                if self._client is None:
                    return
                if self._queued_logs:
                    # Copy-and-clear under the lock; predict() runs the upload.
                    logs_to_send = self._queued_logs.copy()
                    self._queued_logs.clear()
                    self._client.predict(
                        api_name="/bulk_log",
                        logs=logs_to_send,
                        hf_token=huggingface_hub.utils.get_token(),
                    )
                if self._queued_uploads:
                    uploads_to_send = self._queued_uploads.copy()
                    self._queued_uploads.clear()
                    self._client.predict(
                        api_name="/bulk_upload_media",
                        uploads=uploads_to_send,
                        hf_token=huggingface_hub.utils.get_token(),
                    )

    def _init_client_background(self):
        """Connect to the dashboard with Fibonacci backoff, then start sending."""
        if self._client is None:
            fib = utils.fibo()
            for sleep_coefficient in fib:
                try:
                    client = Client(self.url, verbose=False)

                    with self._client_lock:
                        self._client = client
                    break
                except Exception:
                    # Dashboard may not be up yet; retry after a short sleep.
                    pass
                if sleep_coefficient is not None:
                    time.sleep(0.1 * sleep_coefficient)

        self._batch_sender()

    def _queue_upload(
        self,
        file_path,
        step: int | None,
        relative_path: str | None = None,
        use_run_name: bool = True,
    ):
        """
        Queues a media file for upload to a Space.

        Args:
            file_path:
                The path to the file to upload.
            step (`int` or `None`, *optional*):
                The step number associated with this upload.
            relative_path (`str` or `None`, *optional*):
                The relative path within the project's files directory. Used when
                uploading files via `trackio.save()`.
            use_run_name (`bool`, *optional*):
                Whether to use the run name for the uploaded file. This is set to
                `False` when uploading files via `trackio.save()`.
        """
        upload_entry: UploadEntry = {
            "project": self.project,
            "run": self.name if use_run_name else None,
            "step": step,
            "relative_path": relative_path,
            "uploaded_file": handle_file(file_path),
        }
        with self._client_lock:
            self._queued_uploads.append(upload_entry)

    def _process_media(self, value: TrackioMedia, step: int | None) -> dict:
        """
        Serialize media in metrics and upload to space if needed.
        """
        value._save(self.project, self.name, step)
        if self._space_id:
            self._queue_upload(value._get_absolute_file_path(), step)
        return value._to_dict()

    def _scan_and_queue_media_uploads(self, table_dict: dict, step: int | None):
        """
        Scan a serialized table for media objects and queue them for upload to space.
        """
        # Only Spaces need uploads; local runs read media from disk directly.
        if not self._space_id:
            return

        table_data = table_dict.get("_value", [])
        for row in table_data:
            for value in row.values():
                if isinstance(value, dict) and value.get("_type") in [
                    "trackio.image",
                    "trackio.video",
                    "trackio.audio",
                ]:
                    file_path = value.get("file_path")
                    if file_path:
                        from trackio.utils import MEDIA_DIR

                        absolute_path = MEDIA_DIR / file_path
                        self._queue_upload(absolute_path, step)
                elif isinstance(value, list):
                    # A cell may hold a list of media payloads.
                    for item in value:
                        if isinstance(item, dict) and item.get("_type") in [
                            "trackio.image",
                            "trackio.video",
                            "trackio.audio",
                        ]:
                            file_path = item.get("file_path")
                            if file_path:
                                from trackio.utils import MEDIA_DIR

                                absolute_path = MEDIA_DIR / file_path
                                self._queue_upload(absolute_path, step)

    def log(self, metrics: dict, step: int | None = None):
        """Queue a metrics dict for the dashboard.

        Reserved keys are renamed with a ``__`` prefix (with a warning);
        Table/Histogram/media values are serialized, and media is queued for
        upload when running against a Space.
        """
        renamed_keys = []
        new_metrics = {}

        for k, v in metrics.items():
            if k in utils.RESERVED_KEYS or k.startswith("__"):
                new_key = f"__{k}"
                renamed_keys.append(k)
                new_metrics[new_key] = v
            else:
                new_metrics[k] = v

        if renamed_keys:
            warnings.warn(f"Reserved keys renamed: {renamed_keys} → '__{{key}}'")

        metrics = new_metrics
        for key, value in metrics.items():
            if isinstance(value, Table):
                metrics[key] = value._to_dict(
                    project=self.project, run=self.name, step=step
                )
                self._scan_and_queue_media_uploads(metrics[key], step)
            elif isinstance(value, Histogram):
                metrics[key] = value._to_dict()
            elif isinstance(value, TrackioMedia):
                metrics[key] = self._process_media(value, step)
        metrics = utils.serialize_values(metrics)

        # The run config is attached to the first log entry only.
        config_to_log = None
        if not self._config_logged and self.config:
            config_to_log = utils.to_json_safe(self.config)
            self._config_logged = True

        log_entry: LogEntry = {
            "project": self.project,
            "run": self.name,
            "metrics": metrics,
            "step": step,
            "config": config_to_log,
        }

        with self._client_lock:
            self._queued_logs.append(log_entry)

    def finish(self):
        """Cleanup when run is finished."""
        self._stop_flag.set()

        # Give the sender one extra cycle to pick up the stop flag.
        time.sleep(2 * BATCH_SEND_INTERVAL)

        if self._client_thread is not None:
            print("* Run finished. Uploading logs to Trackio (please wait...)")
            self._client_thread.join()
sqlite_storage.py ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import platform
3
+ import sqlite3
4
+ import time
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ from threading import Lock
8
+
9
+ try:
10
+ import fcntl
11
+ except ImportError: # fcntl is not available on Windows
12
+ fcntl = None
13
+
14
+ import huggingface_hub as hf
15
+ import orjson
16
+ import pandas as pd
17
+
18
+ try: # absolute imports when installed from PyPI
19
+ from trackio.commit_scheduler import CommitScheduler
20
+ from trackio.dummy_commit_scheduler import DummyCommitScheduler
21
+ from trackio.utils import (
22
+ TRACKIO_DIR,
23
+ deserialize_values,
24
+ serialize_values,
25
+ )
26
+ except ImportError: # relative imports when installed from source on Spaces
27
+ from commit_scheduler import CommitScheduler
28
+ from dummy_commit_scheduler import DummyCommitScheduler
29
+ from utils import TRACKIO_DIR, deserialize_values, serialize_values
30
+
31
+ DB_EXT = ".db"
32
+
33
+
34
class ProcessLock:
    """A file-based lock that works across processes. Is a no-op on Windows.

    Uses ``fcntl.flock`` on POSIX systems; on Windows (where ``fcntl`` is
    unavailable) both acquire and release do nothing.
    """

    def __init__(self, lockfile_path: Path):
        """
        Args:
            lockfile_path: Path of the lock file (created on first acquire).
        """
        self.lockfile_path = lockfile_path
        self.lockfile = None
        self.is_windows = platform.system() == "Windows"

    def __enter__(self):
        """Acquire the lock with retry logic (up to ~10 seconds).

        Raises:
            IOError: If the lock cannot be acquired within the retry window.
        """
        if self.is_windows:
            return self
        self.lockfile_path.parent.mkdir(parents=True, exist_ok=True)
        self.lockfile = open(self.lockfile_path, "w")

        max_retries = 100
        for attempt in range(max_retries):
            try:
                fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                return self
            except IOError:
                if attempt < max_retries - 1:
                    time.sleep(0.1)
                else:
                    # Close the handle before giving up so the file descriptor
                    # is not leaked when acquisition fails.
                    self.lockfile.close()
                    self.lockfile = None
                    raise IOError("Could not acquire database lock after 10 seconds")

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the lock and close the underlying file handle."""
        if self.is_windows:
            return

        if self.lockfile:
            fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
            self.lockfile.close()
            # Drop the stale handle so a repeated __exit__ is a safe no-op.
            self.lockfile = None
+
69
+
70
+ class SQLiteStorage:
71
+ _dataset_import_attempted = False
72
+ _current_scheduler: CommitScheduler | DummyCommitScheduler | None = None
73
+ _scheduler_lock = Lock()
74
+
75
    @staticmethod
    def _get_connection(db_path: Path) -> sqlite3.Connection:
        """Open a tuned SQLite connection to ``db_path``.

        Returns a connection with column-name row access (``sqlite3.Row``)
        and pragmas tuned for many small write transactions.
        """
        conn = sqlite3.connect(str(db_path), timeout=30.0)
        # Keep WAL for concurrency + performance on many small writes
        conn.execute("PRAGMA journal_mode = WAL")
        # ---- Minimal perf tweaks for many tiny transactions ----
        # NORMAL = fsync at critical points only (safer than OFF, much faster than FULL)
        conn.execute("PRAGMA synchronous = NORMAL")
        # Keep temp data in memory to avoid disk hits during small writes
        conn.execute("PRAGMA temp_store = MEMORY")
        # Give SQLite a bit more room for cache (negative = KB, engine-managed)
        conn.execute("PRAGMA cache_size = -20000")
        # --------------------------------------------------------
        conn.row_factory = sqlite3.Row
        return conn
90
+
91
    @staticmethod
    def _get_process_lock(project: str) -> ProcessLock:
        """Return the cross-process file lock guarding ``project``'s database."""
        lockfile_path = TRACKIO_DIR / f"{project}.lock"
        return ProcessLock(lockfile_path)
95
+
96
+ @staticmethod
97
+ def get_project_db_filename(project: str) -> str:
98
+ """Get the database filename for a specific project."""
99
+ safe_project_name = "".join(
100
+ c for c in project if c.isalnum() or c in ("-", "_")
101
+ ).rstrip()
102
+ if not safe_project_name:
103
+ safe_project_name = "default"
104
+ return f"{safe_project_name}{DB_EXT}"
105
+
106
    @staticmethod
    def get_project_db_path(project: str) -> Path:
        """Get the database path for a specific project."""
        filename = SQLiteStorage.get_project_db_filename(project)
        return TRACKIO_DIR / filename
111
+
112
    @staticmethod
    def init_db(project: str) -> Path:
        """
        Initialize the SQLite database with required tables.
        Returns the database path.
        """
        db_path = SQLiteStorage.get_project_db_path(project)
        db_path.parent.mkdir(parents=True, exist_ok=True)
        # Cross-process lock so concurrent writers don't race on schema creation.
        with SQLiteStorage._get_process_lock(project):
            with sqlite3.connect(str(db_path), timeout=30.0) as conn:
                # Same pragma tuning as _get_connection (WAL + relaxed fsync).
                conn.execute("PRAGMA journal_mode = WAL")
                conn.execute("PRAGMA synchronous = NORMAL")
                conn.execute("PRAGMA temp_store = MEMORY")
                conn.execute("PRAGMA cache_size = -20000")
                cursor = conn.cursor()
                # Metrics are stored as one JSON blob per (run, step) row.
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS metrics (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        timestamp TEXT NOT NULL,
                        run_name TEXT NOT NULL,
                        step INTEGER NOT NULL,
                        metrics TEXT NOT NULL
                    )
                    """
                )
                # One config row per run; UNIQUE enables INSERT OR REPLACE upserts.
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS configs (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        run_name TEXT NOT NULL,
                        config TEXT NOT NULL,
                        created_at TEXT NOT NULL,
                        UNIQUE(run_name)
                    )
                    """
                )
                cursor.execute(
                    """
                    CREATE INDEX IF NOT EXISTS idx_metrics_run_step
                    ON metrics(run_name, step)
                    """
                )
                cursor.execute(
                    """
                    CREATE INDEX IF NOT EXISTS idx_configs_run_name
                    ON configs(run_name)
                    """
                )
                cursor.execute(
                    """
                    CREATE INDEX IF NOT EXISTS idx_metrics_run_timestamp
                    ON metrics(run_name, timestamp)
                    """
                )
                conn.commit()
        return db_path
169
+
170
    @staticmethod
    def export_to_parquet():
        """
        Exports all projects' DB files as Parquet under the same path but with extension ".parquet".
        """
        # don't attempt to export (potentially wrong/blank) data before importing for the first time
        if not SQLiteStorage._dataset_import_attempted:
            return
        if not TRACKIO_DIR.exists():
            return

        all_paths = os.listdir(TRACKIO_DIR)
        db_names = [f for f in all_paths if f.endswith(DB_EXT)]
        for db_name in db_names:
            db_path = TRACKIO_DIR / db_name
            parquet_path = db_path.with_suffix(".parquet")
            # Only re-export when the DB changed since the last export (mtime check).
            if (not parquet_path.exists()) or (
                db_path.stat().st_mtime > parquet_path.stat().st_mtime
            ):
                with sqlite3.connect(str(db_path)) as conn:
                    df = pd.read_sql("SELECT * FROM metrics", conn)
                    # break out the single JSON metrics column into individual columns
                    metrics = df["metrics"].copy()
                    metrics = pd.DataFrame(
                        metrics.apply(
                            lambda x: deserialize_values(orjson.loads(x))
                        ).values.tolist(),
                        index=df.index,
                    )
                    del df["metrics"]
                    for col in metrics.columns:
                        df[col] = metrics[col]

                    df.to_parquet(parquet_path)
204
+
205
+ @staticmethod
206
+ def _cleanup_wal_sidecars(db_path: Path) -> None:
207
+ """Remove leftover -wal/-shm files for a DB basename (prevents disk I/O errors)."""
208
+ for suffix in ("-wal", "-shm"):
209
+ sidecar = Path(str(db_path) + suffix)
210
+ try:
211
+ if sidecar.exists():
212
+ sidecar.unlink()
213
+ except Exception:
214
+ pass
215
+
216
    @staticmethod
    def import_from_parquet():
        """
        Imports to all DB files that have matching files under the same path but with extension ".parquet".
        """
        if not TRACKIO_DIR.exists():
            return

        all_paths = os.listdir(TRACKIO_DIR)
        parquet_names = [f for f in all_paths if f.endswith(".parquet")]
        for pq_name in parquet_names:
            parquet_path = TRACKIO_DIR / pq_name
            db_path = parquet_path.with_suffix(DB_EXT)

            # Stale WAL/SHM sidecars from a previous run would break the rewrite.
            SQLiteStorage._cleanup_wal_sidecars(db_path)

            df = pd.read_parquet(parquet_path)
            # fix up df to have a single JSON metrics column
            if "metrics" not in df.columns:
                # separate other columns from metrics
                metrics = df.copy()
                other_cols = ["id", "timestamp", "run_name", "step"]
                df = df[other_cols]
                for col in other_cols:
                    del metrics[col]
                # combine them all into a single metrics col
                metrics = orjson.loads(metrics.to_json(orient="records"))
                df["metrics"] = [orjson.dumps(serialize_values(row)) for row in metrics]

            # if_exists="replace" rebuilds the metrics table from the parquet data.
            with sqlite3.connect(str(db_path), timeout=30.0) as conn:
                df.to_sql("metrics", conn, if_exists="replace", index=False)
                conn.commit()
248
+
249
    @staticmethod
    def get_scheduler():
        """
        Get the scheduler for the database based on the environment variables.
        This applies to both local and Spaces.
        """
        # Singleton: create once, then reuse (guarded by a thread lock).
        with SQLiteStorage._scheduler_lock:
            if SQLiteStorage._current_scheduler is not None:
                return SQLiteStorage._current_scheduler
            hf_token = os.environ.get("HF_TOKEN")
            dataset_id = os.environ.get("TRACKIO_DATASET_ID")
            space_repo_name = os.environ.get("SPACE_REPO_NAME")
            # Outside a Space (or without a backing dataset) there is nothing
            # to commit to, so fall back to a no-op scheduler.
            if dataset_id is None or space_repo_name is None:
                scheduler = DummyCommitScheduler()
            else:
                scheduler = CommitScheduler(
                    repo_id=dataset_id,
                    repo_type="dataset",
                    folder_path=TRACKIO_DIR,
                    private=True,
                    allow_patterns=["*.parquet", "media/**/*"],
                    squash_history=True,
                    token=hf_token,
                    # Refresh parquet exports right before each commit.
                    on_before_commit=SQLiteStorage.export_to_parquet,
                )
            SQLiteStorage._current_scheduler = scheduler
            return scheduler
276
+
277
    @staticmethod
    def log(project: str, run: str, metrics: dict, step: int | None = None):
        """
        Safely log metrics to the database. Before logging, this method will ensure the database exists
        and is set up with the correct tables. It also uses a cross-process lock to prevent
        database locking errors when multiple processes access the same database.

        This method is not used in the latest versions of Trackio (replaced by bulk_log) but
        is kept for backwards compatibility for users who are connecting to a newer version of
        a Trackio Spaces dashboard with an older version of Trackio installed locally.
        """
        db_path = SQLiteStorage.init_db(project)
        with SQLiteStorage._get_process_lock(project):
            with SQLiteStorage._get_connection(db_path) as conn:
                cursor = conn.cursor()
                # When no step is given, continue from the run's MAX(step).
                cursor.execute(
                    """
                    SELECT MAX(step)
                    FROM metrics
                    WHERE run_name = ?
                    """,
                    (run,),
                )
                last_step = cursor.fetchone()[0]
                # First-ever entry defaults to step 0; otherwise last_step + 1.
                current_step = (
                    0
                    if step is None and last_step is None
                    else (step if step is not None else last_step + 1)
                )
                current_timestamp = datetime.now().isoformat()
                cursor.execute(
                    """
                    INSERT INTO metrics
                    (timestamp, run_name, step, metrics)
                    VALUES (?, ?, ?, ?)
                    """,
                    (
                        current_timestamp,
                        run,
                        current_step,
                        # Metrics dict is stored as a single JSON blob.
                        orjson.dumps(serialize_values(metrics)),
                    ),
                )
                conn.commit()
321
+
322
    @staticmethod
    def bulk_log(
        project: str,
        run: str,
        metrics_list: list[dict],
        steps: list[int] | None = None,
        timestamps: list[str] | None = None,
        config: dict | None = None,
    ):
        """
        Safely log bulk metrics to the database. Before logging, this method will ensure the database exists
        and is set up with the correct tables. It also uses a cross-process lock to prevent
        database locking errors when multiple processes access the same database.
        """
        if not metrics_list:
            return

        # Default: stamp every entry with "now".
        if timestamps is None:
            timestamps = [datetime.now().isoformat()] * len(metrics_list)

        db_path = SQLiteStorage.init_db(project)
        with SQLiteStorage._get_process_lock(project):
            with SQLiteStorage._get_connection(db_path) as conn:
                cursor = conn.cursor()

                if steps is None:
                    # No steps at all: number entries 0..N-1.
                    steps = list(range(len(metrics_list)))
                elif any(s is None for s in steps):
                    # Fill in missing steps, continuing from the run's MAX(step).
                    cursor.execute(
                        "SELECT MAX(step) FROM metrics WHERE run_name = ?", (run,)
                    )
                    last_step = cursor.fetchone()[0]
                    current_step = 0 if last_step is None else last_step + 1
                    processed_steps = []
                    for step in steps:
                        if step is None:
                            processed_steps.append(current_step)
                            current_step += 1
                        else:
                            processed_steps.append(step)
                    steps = processed_steps

                if len(metrics_list) != len(steps) or len(metrics_list) != len(
                    timestamps
                ):
                    raise ValueError(
                        "metrics_list, steps, and timestamps must have the same length"
                    )

                # One (timestamp, run, step, JSON metrics) tuple per entry.
                data = []
                for i, metrics in enumerate(metrics_list):
                    data.append(
                        (
                            timestamps[i],
                            run,
                            steps[i],
                            orjson.dumps(serialize_values(metrics)),
                        )
                    )

                cursor.executemany(
                    """
                    INSERT INTO metrics
                    (timestamp, run_name, step, metrics)
                    VALUES (?, ?, ?, ?)
                    """,
                    data,
                )

                # Upsert the run's config (UNIQUE(run_name) on configs).
                if config:
                    current_timestamp = datetime.now().isoformat()
                    cursor.execute(
                        """
                        INSERT OR REPLACE INTO configs
                        (run_name, config, created_at)
                        VALUES (?, ?, ?)
                        """,
                        (
                            run,
                            orjson.dumps(serialize_values(config)),
                            current_timestamp,
                        ),
                    )

                conn.commit()
407
+
408
+ @staticmethod
409
+ def get_logs(project: str, run: str) -> list[dict]:
410
+ """Retrieve logs for a specific run. Logs include the step count (int) and the timestamp (datetime object)."""
411
+ db_path = SQLiteStorage.get_project_db_path(project)
412
+ if not db_path.exists():
413
+ return []
414
+
415
+ with SQLiteStorage._get_connection(db_path) as conn:
416
+ cursor = conn.cursor()
417
+ cursor.execute(
418
+ """
419
+ SELECT timestamp, step, metrics
420
+ FROM metrics
421
+ WHERE run_name = ?
422
+ ORDER BY timestamp
423
+ """,
424
+ (run,),
425
+ )
426
+
427
+ rows = cursor.fetchall()
428
+ results = []
429
+ for row in rows:
430
+ metrics = orjson.loads(row["metrics"])
431
+ metrics = deserialize_values(metrics)
432
+ metrics["timestamp"] = row["timestamp"]
433
+ metrics["step"] = row["step"]
434
+ results.append(metrics)
435
+ return results
436
+
437
+ @staticmethod
438
+ def load_from_dataset():
439
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
440
+ space_repo_name = os.environ.get("SPACE_REPO_NAME")
441
+ if dataset_id is not None and space_repo_name is not None:
442
+ hfapi = hf.HfApi()
443
+ updated = False
444
+ if not TRACKIO_DIR.exists():
445
+ TRACKIO_DIR.mkdir(parents=True, exist_ok=True)
446
+ with SQLiteStorage.get_scheduler().lock:
447
+ try:
448
+ files = hfapi.list_repo_files(dataset_id, repo_type="dataset")
449
+ for file in files:
450
+ # Download parquet and media assets
451
+ if not (file.endswith(".parquet") or file.startswith("media/")):
452
+ continue
453
+ if (TRACKIO_DIR / file).exists():
454
+ continue
455
+ hf.hf_hub_download(
456
+ dataset_id, file, repo_type="dataset", local_dir=TRACKIO_DIR
457
+ )
458
+ updated = True
459
+ except hf.errors.EntryNotFoundError:
460
+ pass
461
+ except hf.errors.RepositoryNotFoundError:
462
+ pass
463
+ if updated:
464
+ SQLiteStorage.import_from_parquet()
465
+ SQLiteStorage._dataset_import_attempted = True
466
+
467
+ @staticmethod
468
+ def get_projects() -> list[str]:
469
+ """
470
+ Get list of all projects by scanning the database files in the trackio directory.
471
+ """
472
+ if not SQLiteStorage._dataset_import_attempted:
473
+ SQLiteStorage.load_from_dataset()
474
+
475
+ projects: set[str] = set()
476
+ if not TRACKIO_DIR.exists():
477
+ return []
478
+
479
+ for db_file in TRACKIO_DIR.glob(f"*{DB_EXT}"):
480
+ project_name = db_file.stem
481
+ projects.add(project_name)
482
+ return sorted(projects)
483
+
484
+ @staticmethod
485
+ def get_runs(project: str) -> list[str]:
486
+ """Get list of all runs for a project."""
487
+ db_path = SQLiteStorage.get_project_db_path(project)
488
+ if not db_path.exists():
489
+ return []
490
+
491
+ with SQLiteStorage._get_connection(db_path) as conn:
492
+ cursor = conn.cursor()
493
+ cursor.execute(
494
+ "SELECT DISTINCT run_name FROM metrics",
495
+ )
496
+ return [row[0] for row in cursor.fetchall()]
497
+
498
+ @staticmethod
499
+ def get_max_steps_for_runs(project: str) -> dict[str, int]:
500
+ """Get the maximum step for each run in a project."""
501
+ db_path = SQLiteStorage.get_project_db_path(project)
502
+ if not db_path.exists():
503
+ return {}
504
+
505
+ with SQLiteStorage._get_connection(db_path) as conn:
506
+ cursor = conn.cursor()
507
+ cursor.execute(
508
+ """
509
+ SELECT run_name, MAX(step) as max_step
510
+ FROM metrics
511
+ GROUP BY run_name
512
+ """
513
+ )
514
+
515
+ results = {}
516
+ for row in cursor.fetchall():
517
+ results[row["run_name"]] = row["max_step"]
518
+
519
+ return results
520
+
521
+ @staticmethod
522
+ def store_config(project: str, run: str, config: dict) -> None:
523
+ """Store configuration for a run."""
524
+ db_path = SQLiteStorage.init_db(project)
525
+
526
+ with SQLiteStorage._get_process_lock(project):
527
+ with SQLiteStorage._get_connection(db_path) as conn:
528
+ cursor = conn.cursor()
529
+ current_timestamp = datetime.now().isoformat()
530
+
531
+ cursor.execute(
532
+ """
533
+ INSERT OR REPLACE INTO configs
534
+ (run_name, config, created_at)
535
+ VALUES (?, ?, ?)
536
+ """,
537
+ (run, orjson.dumps(serialize_values(config)), current_timestamp),
538
+ )
539
+ conn.commit()
540
+
541
+ @staticmethod
542
+ def get_run_config(project: str, run: str) -> dict | None:
543
+ """Get configuration for a specific run."""
544
+ db_path = SQLiteStorage.get_project_db_path(project)
545
+ if not db_path.exists():
546
+ return None
547
+
548
+ with SQLiteStorage._get_connection(db_path) as conn:
549
+ cursor = conn.cursor()
550
+ try:
551
+ cursor.execute(
552
+ """
553
+ SELECT config FROM configs WHERE run_name = ?
554
+ """,
555
+ (run,),
556
+ )
557
+
558
+ row = cursor.fetchone()
559
+ if row:
560
+ config = orjson.loads(row["config"])
561
+ return deserialize_values(config)
562
+ return None
563
+ except sqlite3.OperationalError as e:
564
+ if "no such table: configs" in str(e):
565
+ return None
566
+ raise
567
+
568
+ @staticmethod
569
+ def delete_run(project: str, run: str) -> bool:
570
+ """Delete a run from the database (both metrics and config)."""
571
+ db_path = SQLiteStorage.get_project_db_path(project)
572
+ if not db_path.exists():
573
+ return False
574
+
575
+ with SQLiteStorage._get_process_lock(project):
576
+ with SQLiteStorage._get_connection(db_path) as conn:
577
+ cursor = conn.cursor()
578
+ try:
579
+ cursor.execute("DELETE FROM metrics WHERE run_name = ?", (run,))
580
+ cursor.execute("DELETE FROM configs WHERE run_name = ?", (run,))
581
+ conn.commit()
582
+ return True
583
+ except sqlite3.Error:
584
+ return False
585
+
586
+ @staticmethod
587
+ def get_all_run_configs(project: str) -> dict[str, dict]:
588
+ """Get configurations for all runs in a project."""
589
+ db_path = SQLiteStorage.get_project_db_path(project)
590
+ if not db_path.exists():
591
+ return {}
592
+
593
+ with SQLiteStorage._get_connection(db_path) as conn:
594
+ cursor = conn.cursor()
595
+ try:
596
+ cursor.execute(
597
+ """
598
+ SELECT run_name, config FROM configs
599
+ """
600
+ )
601
+
602
+ results = {}
603
+ for row in cursor.fetchall():
604
+ config = orjson.loads(row["config"])
605
+ results[row["run_name"]] = deserialize_values(config)
606
+ return results
607
+ except sqlite3.OperationalError as e:
608
+ if "no such table: configs" in str(e):
609
+ return {}
610
+ raise
611
+
612
+ @staticmethod
613
+ def get_metric_values(project: str, run: str, metric_name: str) -> list[dict]:
614
+ """Get all values for a specific metric in a project/run."""
615
+ db_path = SQLiteStorage.get_project_db_path(project)
616
+ if not db_path.exists():
617
+ return []
618
+
619
+ with SQLiteStorage._get_connection(db_path) as conn:
620
+ cursor = conn.cursor()
621
+ cursor.execute(
622
+ """
623
+ SELECT timestamp, step, metrics
624
+ FROM metrics
625
+ WHERE run_name = ?
626
+ ORDER BY timestamp
627
+ """,
628
+ (run,),
629
+ )
630
+
631
+ rows = cursor.fetchall()
632
+ results = []
633
+ for row in rows:
634
+ metrics = orjson.loads(row["metrics"])
635
+ metrics = deserialize_values(metrics)
636
+ if metric_name in metrics:
637
+ results.append(
638
+ {
639
+ "timestamp": row["timestamp"],
640
+ "step": row["step"],
641
+ "value": metrics[metric_name],
642
+ }
643
+ )
644
+ return results
645
+
646
+ @staticmethod
647
+ def get_all_metrics_for_run(project: str, run: str) -> list[str]:
648
+ """Get all metric names for a specific project/run."""
649
+ db_path = SQLiteStorage.get_project_db_path(project)
650
+ if not db_path.exists():
651
+ return []
652
+
653
+ with SQLiteStorage._get_connection(db_path) as conn:
654
+ cursor = conn.cursor()
655
+ cursor.execute(
656
+ """
657
+ SELECT metrics
658
+ FROM metrics
659
+ WHERE run_name = ?
660
+ ORDER BY timestamp
661
+ """,
662
+ (run,),
663
+ )
664
+
665
+ rows = cursor.fetchall()
666
+ all_metrics = set()
667
+ for row in rows:
668
+ metrics = orjson.loads(row["metrics"])
669
+ metrics = deserialize_values(metrics)
670
+ for key in metrics.keys():
671
+ if key not in ["timestamp", "step"]:
672
+ all_metrics.add(key)
673
+ return sorted(list(all_metrics))
674
+
675
    def finish(self):
        """Cleanup when run is finished."""
        # No per-run resources are held here, so there is nothing to release;
        # presumably kept so this backend satisfies a common storage interface —
        # TODO confirm against callers.
        pass
table.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Literal
3
+
4
+ from pandas import DataFrame
5
+
6
+ try:
7
+ from trackio.media.media import TrackioMedia
8
+ from trackio.utils import MEDIA_DIR
9
+ except ImportError:
10
+ from media.media import TrackioMedia
11
+ from utils import MEDIA_DIR
12
+
13
+
14
+ class Table:
15
+ """
16
+ Initializes a Table object.
17
+
18
+ Tables can be used to log tabular data including images, numbers, and text.
19
+
20
+ Args:
21
+ columns (`list[str]`, *optional*):
22
+ Names of the columns in the table. Optional if `data` is provided. Not
23
+ expected if `dataframe` is provided. Currently ignored.
24
+ data (`list[list[Any]]`, *optional*):
25
+ 2D row-oriented array of values. Each value can be a number, a string
26
+ (treated as Markdown and truncated if too long), or a `Trackio.Image` or
27
+ list of `Trackio.Image` objects.
28
+ dataframe (`pandas.DataFrame`, *optional*):
29
+ DataFrame used to create the table. When set, `data` and `columns`
30
+ arguments are ignored.
31
+ rows (`list[list[Any]]`, *optional*):
32
+ Currently ignored.
33
+ optional (`bool` or `list[bool]`, *optional*, defaults to `True`):
34
+ Currently ignored.
35
+ allow_mixed_types (`bool`, *optional*, defaults to `False`):
36
+ Currently ignored.
37
+ log_mode: (`Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]` or `None`, *optional*, defaults to `"IMMUTABLE"`):
38
+ Currently ignored.
39
+ """
40
+
41
+ TYPE = "trackio.table"
42
+
43
+ def __init__(
44
+ self,
45
+ columns: list[str] | None = None,
46
+ data: list[list[Any]] | None = None,
47
+ dataframe: DataFrame | None = None,
48
+ rows: list[list[Any]] | None = None,
49
+ optional: bool | list[bool] = True,
50
+ allow_mixed_types: bool = False,
51
+ log_mode: Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"] | None = "IMMUTABLE",
52
+ ):
53
+ # TODO: implement support for columns, dtype, optional, allow_mixed_types, and log_mode.
54
+ # for now (like `rows`) they are included for API compat but don't do anything.
55
+ if dataframe is None:
56
+ self.data = DataFrame(data) if data is not None else DataFrame()
57
+ else:
58
+ self.data = dataframe
59
+
60
+ def _has_media_objects(self, dataframe: DataFrame) -> bool:
61
+ """Check if dataframe contains any TrackioMedia objects or lists of TrackioMedia objects."""
62
+ for col in dataframe.columns:
63
+ if dataframe[col].apply(lambda x: isinstance(x, TrackioMedia)).any():
64
+ return True
65
+ if (
66
+ dataframe[col]
67
+ .apply(
68
+ lambda x: isinstance(x, list)
69
+ and len(x) > 0
70
+ and isinstance(x[0], TrackioMedia)
71
+ )
72
+ .any()
73
+ ):
74
+ return True
75
+ return False
76
+
77
+ def _process_data(self, project: str, run: str, step: int = 0):
78
+ """Convert dataframe to dict format, processing any TrackioMedia objects if present."""
79
+ df = self.data
80
+ if not self._has_media_objects(df):
81
+ return df.to_dict(orient="records")
82
+
83
+ processed_df = df.copy()
84
+ for col in processed_df.columns:
85
+ for idx in processed_df.index:
86
+ value = processed_df.at[idx, col]
87
+ if isinstance(value, TrackioMedia):
88
+ value._save(project, run, step)
89
+ processed_df.at[idx, col] = value._to_dict()
90
+ if (
91
+ isinstance(value, list)
92
+ and len(value) > 0
93
+ and isinstance(value[0], TrackioMedia)
94
+ ):
95
+ [v._save(project, run, step) for v in value]
96
+ processed_df.at[idx, col] = [v._to_dict() for v in value]
97
+
98
+ return processed_df.to_dict(orient="records")
99
+
100
+ @staticmethod
101
+ def to_display_format(table_data: list[dict]) -> list[dict]:
102
+ """
103
+ Converts stored table data to display format for UI rendering.
104
+
105
+ Note:
106
+ This does not use the `self.data` attribute, but instead uses the
107
+ `table_data` parameter, which is what the UI receives.
108
+
109
+ Args:
110
+ table_data (`list[dict]`):
111
+ List of dictionaries representing table rows (from stored `_value`).
112
+
113
+ Returns:
114
+ `list[dict]`: Table data with images converted to markdown syntax and long
115
+ text truncated.
116
+ """
117
+ truncate_length = int(os.getenv("TRACKIO_TABLE_TRUNCATE_LENGTH", "250"))
118
+
119
+ def convert_image_to_markdown(image_data: dict) -> str:
120
+ relative_path = image_data.get("file_path", "")
121
+ caption = image_data.get("caption", "")
122
+ absolute_path = MEDIA_DIR / relative_path
123
+ return f'<img src="/gradio_api/file={absolute_path}" alt="{caption}" />'
124
+
125
+ processed_data = []
126
+ for row in table_data:
127
+ processed_row = {}
128
+ for key, value in row.items():
129
+ if isinstance(value, dict) and value.get("_type") == "trackio.image":
130
+ processed_row[key] = convert_image_to_markdown(value)
131
+ elif (
132
+ isinstance(value, list)
133
+ and len(value) > 0
134
+ and isinstance(value[0], dict)
135
+ and value[0].get("_type") == "trackio.image"
136
+ ):
137
+ # This assumes that if the first item is an image, all items are images. Ok for now since we don't support mixed types in a single cell.
138
+ processed_row[key] = (
139
+ '<div style="display: flex; gap: 10px;">'
140
+ + "".join([convert_image_to_markdown(item) for item in value])
141
+ + "</div>"
142
+ )
143
+ elif isinstance(value, str) and len(value) > truncate_length:
144
+ truncated = value[:truncate_length]
145
+ full_text = value.replace("<", "&lt;").replace(">", "&gt;")
146
+ processed_row[key] = (
147
+ f'<details style="display: inline;">'
148
+ f'<summary style="display: inline; cursor: pointer;">{truncated}…<span><em>(truncated, click to expand)</em></span></summary>'
149
+ f'<div style="margin-top: 10px; padding: 10px; background: #f5f5f5; border-radius: 4px; max-height: 400px; overflow: auto;">'
150
+ f'<pre style="white-space: pre-wrap; word-wrap: break-word; margin: 0;">{full_text}</pre>'
151
+ f"</div>"
152
+ f"</details>"
153
+ )
154
+ else:
155
+ processed_row[key] = value
156
+ processed_data.append(processed_row)
157
+ return processed_data
158
+
159
+ def _to_dict(self, project: str, run: str, step: int = 0):
160
+ """
161
+ Converts the table to a dictionary representation.
162
+
163
+ Args:
164
+ project (`str`):
165
+ Project name for saving media files.
166
+ run (`str`):
167
+ Run name for saving media files.
168
+ step (`int`, *optional*, defaults to `0`):
169
+ Step number for saving media files.
170
+ """
171
+ data = self._process_data(project, run, step)
172
+ return {
173
+ "_type": self.TYPE,
174
+ "_value": data,
175
+ }
typehints.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, TypedDict
2
+
3
+ from gradio import FileData
4
+
5
+
6
+ class LogEntry(TypedDict):
7
+ project: str
8
+ run: str
9
+ metrics: dict[str, Any]
10
+ step: int | None
11
+ config: dict[str, Any] | None
12
+
13
+
14
class UploadEntry(TypedDict):
    """Typed mapping describing an uploaded media file and where it belongs."""

    # Project the upload is associated with.
    project: str
    # Run name; None presumably means "not tied to a specific run" — TODO confirm.
    run: str | None
    # Step at which the media was logged; None when unspecified.
    step: int | None
    # Destination path relative to the media directory; None when unspecified.
    relative_path: str | None
    # The uploaded file payload (gradio FileData).
    uploaded_file: FileData
ui/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from trackio.ui.main import demo
3
+ from trackio.ui.run_detail import run_detail_page
4
+ from trackio.ui.runs import run_page
5
+ except ImportError:
6
+ from ui.main import demo
7
+ from ui.run_detail import run_detail_page
8
+ from ui.runs import run_page
9
+
10
+ __all__ = ["demo", "run_page", "run_detail_page"]