imone committed on
Commit
bd62227
·
1 Parent(s): caa00bb
.gitignore ADDED
@@ -0,0 +1,169 @@
1
+ # WandB
2
+ /wandb/
3
+ # checkpoints
4
+ /checkpoints/
5
+ # cache
6
+ /cache/
7
+ # data
8
+ /data/
9
+
10
+ # Byte-compiled / optimized / DLL files
11
+ __pycache__/
12
+ *.py[cod]
13
+ *$py.class
14
+
15
+ # C extensions
16
+ *.so
17
+
18
+ # Distribution / packaging
19
+ .Python
20
+ build/
21
+ develop-eggs/
22
+ dist/
23
+ downloads/
24
+ eggs/
25
+ .eggs/
26
+ lib/
27
+ lib64/
28
+ parts/
29
+ sdist/
30
+ var/
31
+ wheels/
32
+ share/python-wheels/
33
+ *.egg-info/
34
+ .installed.cfg
35
+ *.egg
36
+ MANIFEST
37
+
38
+ # PyInstaller
39
+ # Usually these files are written by a python script from a template
40
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
41
+ *.manifest
42
+ *.spec
43
+
44
+ # Installer logs
45
+ pip-log.txt
46
+ pip-delete-this-directory.txt
47
+
48
+ # Unit test / coverage reports
49
+ htmlcov/
50
+ .tox/
51
+ .nox/
52
+ .coverage
53
+ .coverage.*
54
+ .cache
55
+ nosetests.xml
56
+ coverage.xml
57
+ *.cover
58
+ *.py,cover
59
+ .hypothesis/
60
+ .pytest_cache/
61
+ cover/
62
+
63
+ # Translations
64
+ *.mo
65
+ *.pot
66
+
67
+ # Django stuff:
68
+ *.log
69
+ local_settings.py
70
+ db.sqlite3
71
+ db.sqlite3-journal
72
+
73
+ # Flask stuff:
74
+ instance/
75
+ .webassets-cache
76
+
77
+ # Scrapy stuff:
78
+ .scrapy
79
+
80
+ # Sphinx documentation
81
+ docs/_build/
82
+
83
+ # PyBuilder
84
+ .pybuilder/
85
+ target/
86
+
87
+ # Jupyter Notebook
88
+ .ipynb_checkpoints
89
+
90
+ # IPython
91
+ profile_default/
92
+ ipython_config.py
93
+
94
+ # pyenv
95
+ # For a library or package, you might want to ignore these files since the code is
96
+ # intended to run in multiple environments; otherwise, check them in:
97
+ # .python-version
98
+
99
+ # pipenv
100
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
102
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
103
+ # install all needed dependencies.
104
+ #Pipfile.lock
105
+
106
+ # poetry
107
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
109
+ # commonly ignored for libraries.
110
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111
+ #poetry.lock
112
+
113
+ # pdm
114
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115
+ #pdm.lock
116
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
117
+ # in version control.
118
+ # https://pdm.fming.dev/#use-with-ide
119
+ .pdm.toml
120
+
121
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
122
+ __pypackages__/
123
+
124
+ # Celery stuff
125
+ celerybeat-schedule
126
+ celerybeat.pid
127
+
128
+ # SageMath parsed files
129
+ *.sage.py
130
+
131
+ # Environments
132
+ .env
133
+ .venv
134
+ env/
135
+ venv/
136
+ ENV/
137
+ env.bak/
138
+ venv.bak/
139
+
140
+ # Spyder project settings
141
+ .spyderproject
142
+ .spyproject
143
+
144
+ # Rope project settings
145
+ .ropeproject
146
+
147
+ # mkdocs documentation
148
+ /site
149
+
150
+ # mypy
151
+ .mypy_cache/
152
+ .dmypy.json
153
+ dmypy.json
154
+
155
+ # Pyre type checker
156
+ .pyre/
157
+
158
+ # pytype static type analyzer
159
+ .pytype/
160
+
161
+ # Cython debug symbols
162
+ cython_debug/
163
+
164
+ # PyCharm
165
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
166
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
167
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
168
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
169
+ #.idea/
.vscode/launch.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "Python Debugger: Current File",
9
+ "type": "debugpy",
10
+ "request": "launch",
11
+ "program": "${file}",
12
+ "console": "integratedTerminal"
13
+ },
14
+ {
15
+ "name": "Debug: Single GPU",
16
+ "type": "debugpy",
17
+ "request": "launch",
18
+ "program": "pretrain.py",
19
+ "args": [],
20
+ "env": {
21
+ "OMP_NUM_THREADS": "1",
22
+ "DISABLE_COMPILE": "true"
23
+ }
24
+ }
25
+ ]
26
+ }
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "python.analysis.typeCheckingMode": "standard"
3
+ }
README.md ADDED
@@ -0,0 +1,169 @@
1
+ # Hierarchical Reasoning Model
2
+
3
+ ![](./assets/hrm.png)
4
+
5
+ Reasoning, the process of devising and executing complex goal-oriented action sequences, remains a critical challenge in AI.
6
+ Current large language models (LLMs) primarily employ Chain-of-Thought (CoT) techniques, which suffer from brittle task decomposition, extensive data requirements, and high latency. Inspired by the hierarchical and multi-timescale processing in the human brain, we propose the Hierarchical Reasoning Model (HRM), a novel recurrent architecture that attains significant computational depth while maintaining both training stability and efficiency.
7
+ HRM executes sequential reasoning tasks in a single forward pass without explicit supervision of the intermediate process, through two interdependent recurrent modules: a high-level module responsible for slow, abstract planning, and a low-level module handling rapid, detailed computations. With only 27 million parameters, HRM achieves exceptional performance on complex reasoning tasks using only 1000 training samples. The model operates without pre-training or CoT data, yet achieves nearly perfect performance on challenging tasks including complex Sudoku puzzles and optimal path finding in large mazes.
8
+ Furthermore, HRM outperforms much larger models with significantly longer context windows on the Abstraction and Reasoning Corpus (ARC), a key benchmark for measuring artificial general intelligence capabilities.
9
+ These results underscore HRM’s potential as a transformative advancement toward universal computation and general-purpose reasoning systems.
10
+
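+ At a high level, one HRM forward pass interleaves a few slow high-level updates with several fast low-level updates. The snippet below is only a minimal, illustrative sketch of that two-timescale recurrence (plain MLP update rules and made-up module names); the actual implementation uses Transformer blocks and ACT halting, see `models/hrm/hrm_act_v1.py`.
+
+ ```python
+ import torch
+ from torch import nn
+
+ class TwoTimescaleSketch(nn.Module):
+     """Illustrative sketch of HRM's slow (H) / fast (L) recurrence; not the real model."""
+     def __init__(self, hidden_size: int = 512, H_cycles: int = 2, L_cycles: int = 2):
+         super().__init__()
+         self.H_cycles, self.L_cycles = H_cycles, L_cycles
+         self.L_update = nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, hidden_size))
+         self.H_update = nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, hidden_size))
+
+     def forward(self, x_emb: torch.Tensor, z_H: torch.Tensor, z_L: torch.Tensor):
+         for _ in range(self.H_cycles):
+             for _ in range(self.L_cycles):
+                 # Low-level module: rapid, detailed computation conditioned on the current plan z_H.
+                 z_L = self.L_update(torch.cat([z_L, x_emb + z_H], dim=-1))
+             # High-level module: slow, abstract update driven by the low-level result.
+             z_H = self.H_update(torch.cat([z_H, z_L], dim=-1))
+         return z_H, z_L
+ ```
+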
11
+ ## Quick Start Guide 🚀
12
+
13
+ ### Prerequisites ⚙️
14
+
15
+ Ensure PyTorch and the CUDA toolkit are installed; this repo requires building CUDA extensions. If they are not already set up, run the following commands:
16
+
17
+ ```bash
18
+ # Install CUDA 12.4
19
+ CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run
20
+
21
+ wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL
22
+ sudo sh cuda_installer.run --silent --toolkit --override
23
+
24
+ export CUDA_HOME=/usr/local/cuda-12.4
25
+
26
+ # Install PyTorch with CUDA 12.4
27
+ PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu124
28
+
29
+ pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL
30
+
31
+ # Additional packages for building extensions
32
+ pip3 install packaging ninja wheel setuptools setuptools-scm
33
+ ```
34
+
35
+ ## Install Python Dependencies 🐍
36
+
37
+ ```bash
38
+ pip install -r requirements.txt
39
+ ```
40
+
41
+ ## W&B Integration 📈
42
+
43
+ This project uses [Weights & Biases](https://wandb.ai/) for experiment tracking and metric visualization. Ensure you're logged in:
44
+
45
+ ```bash
46
+ wandb login
47
+ ```
48
+
49
+ ## Run Experiments
50
+
51
+ ### Quick Demo: Sudoku Solver 💻🗲
52
+
53
+ Train a master-level Sudoku AI capable of solving extremely difficult puzzles on a modern laptop GPU. 🧩
54
+
55
+ ```bash
56
+ # Download and build Sudoku dataset
57
+ python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000
58
+
59
+ # Start training (single GPU, smaller batch size)
60
+ OMP_NUM_THREADS=8 python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 global_batch_size=384 lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0
61
+ ```
62
+
63
+ *Runtime:* ~10 hours on an RTX 4070 laptop GPU
64
+
65
+ ## Full-scale Experiments 🔵
66
+
67
+ Experiments below assume an 8-GPU setup.
68
+
69
+ ### Dataset Preparation
70
+
71
+ ```bash
72
+ # Initialize submodules
73
+ git submodule update --init --recursive
74
+
75
+ # ARC-1
76
+ python dataset/build_arc_dataset.py # ARC official + ConceptARC, 960 examples
77
+ # ARC-2
78
+ python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000 # ARC-2 official, 1120 examples
79
+
80
+ # Sudoku-Extreme
81
+ python dataset/build_sudoku_dataset.py # Full version
82
+ python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # 1000 examples
83
+
84
+ # Maze
85
+ python dataset/build_maze_dataset.py # 1000 examples
86
+ ```
87
+
88
+ ### Dataset Visualization
89
+
90
+ Explore the puzzles visually:
91
+
92
+ * Open `puzzle_visualizer.html` in your browser.
93
+ * Upload the generated dataset folder located in `data/...`.
94
+
95
+ ## Launch experiments
96
+
97
+ ### Small-sample (1K)
98
+
99
+ ARC-1:
100
+
101
+ ```bash
102
+ OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py
103
+ ```
104
+
105
+ *Runtime:* ~24 hours
106
+
107
+ ARC-2:
108
+
109
+ ```bash
110
+ OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/arc-2-aug-1000
111
+ ```
112
+
113
+ *Runtime:* ~24 hours (checkpoint after 8 hours is often sufficient)
114
+
115
+ Sudoku Extreme (1k):
116
+
117
+ ```bash
118
+ OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
119
+ ```
120
+
121
+ *Runtime:* ~10 minutes
122
+
123
+ Maze 30x30 Hard (1k):
124
+
125
+ ```bash
126
+ OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/maze-30x30-hard-1k epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
127
+ ```
128
+
129
+ *Runtime:* ~1 hour
130
+
131
+ ### Full Sudoku-Hard
132
+
133
+ ```bash
134
+ OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-hard-full epochs=100 eval_interval=10 lr_min_ratio=0.1 global_batch_size=2304 lr=3e-4 puzzle_emb_lr=3e-4 weight_decay=0.1 puzzle_emb_weight_decay=0.1 arch.loss.loss_type=softmax_cross_entropy arch.L_cycles=8 arch.halt_max_steps=8 arch.pos_encodings=learned
135
+ ```
136
+
137
+ *Runtime:* ~2 hours
138
+
139
+ ## Evaluation
140
+
141
+ Evaluate your trained models:
142
+
143
+ * Check `eval/exact_accuracy` in W&B.
144
+ * For ARC-AGI, follow these additional steps:
145
+
146
+ ```bash
147
+ OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 evaluate.py checkpoint=<CHECKPOINT_PATH>
148
+ ```
149
+
150
+ * Then use the provided `arc_eval.ipynb` notebook to finalize and inspect your results.
151
+
152
+ ## Notes
153
+
154
+ - Small-sample learning typically exhibits accuracy variance of around ±2 points.
155
+ - For Sudoku-Extreme (1,000-example dataset), late-stage overfitting may cause numerical instability during training and Q-learning. It is advisable to use early stopping once the training accuracy approaches 100%; a minimal sketch of such a check appears below.
156
+
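+ A generic early-stopping check along those lines might look like this (illustrative only; it assumes you track a per-evaluation `train_accuracy` value, e.g. read from the W&B logs, and it is not wired into `pretrain.py`):
+
+ ```python
+ def should_stop(train_accuracy_history: list[float], threshold: float = 0.99) -> bool:
+     """Stop once training accuracy approaches 100%, before late-stage
+     overfitting destabilizes training and ACT Q-learning."""
+     return bool(train_accuracy_history) and train_accuracy_history[-1] >= threshold
+ ```
+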
157
+ ## Citation 📜
158
+
159
+ ```
160
+ @misc{wang2025hierarchicalreasoningmodel,
161
+ title={Hierarchical Reasoning Model},
162
+ author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
163
+ year={2025},
164
+ eprint={2506.21734},
165
+ archivePrefix={arXiv},
166
+ primaryClass={cs.AI},
167
+ url={https://arxiv.org/abs/2506.21734},
168
+ }
169
+ ```
arc_eval.ipynb ADDED
@@ -0,0 +1,252 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import json\n",
11
+ "from glob import glob\n",
12
+ "import hashlib\n",
13
+ "import matplotlib.pyplot as plt\n",
14
+ "import matplotlib.colors as mcolors\n",
15
+ "\n",
16
+ "import torch\n",
17
+ "import torch.nn.functional as F\n",
18
+ "import numpy as np\n",
19
+ "from numba import njit\n",
20
+ "\n",
21
+ "from dataset.common import inverse_dihedral_transform\n",
22
+ "\n",
23
+ "\n",
24
+ "DATASET_PATH = \"data/arc-aug-1000\" # ARC-1\n",
25
+ "# DATASET_PATH = \"data/arc-2-aug-1000\" # ARC-2\n",
26
+ "\n",
27
+ "CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n",
28
+ "\n",
29
+ "\n",
30
+ "PAD_PUZZLE_IDENTIFIER = 0\n",
31
+ "\n",
32
+ "# Visualization\n",
33
+ "ARC_COLOR_MAP = mcolors.ListedColormap([\n",
34
+ " \"#000000\", # symbol_0: black\n",
35
+ " \"#0074D9\", # symbol_1: blue\n",
36
+ " \"#FF4136\", # symbol_2: red\n",
37
+ " \"#2ECC40\", # symbol_3: green\n",
38
+ " \"#FFDC00\", # symbol_4: yellow\n",
39
+ " \"#AAAAAA\", # symbol_5: grey\n",
40
+ " \"#F012BE\", # symbol_6: fuchsia\n",
41
+ " \"#FF851B\", # symbol_7: orange\n",
42
+ " \"#7FDBFF\", # symbol_8: teal\n",
43
+ " \"#870C25\" # symbol_9: brown\n",
44
+ "])"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n",
54
+ " # Load puzzle identifiers\n",
55
+ " with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n",
56
+ " identifier_map = json.load(f)\n",
57
+ " \n",
58
+ " # Load preds\n",
59
+ " all_preds = {}\n",
60
+ " for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n",
61
+ " preds = torch.load(filename)\n",
62
+ " for k, v in preds.items():\n",
63
+ " all_preds.setdefault(k, [])\n",
64
+ " all_preds[k].append(v)\n",
65
+ " \n",
66
+ " del preds\n",
67
+ "\n",
68
+ " all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n",
69
+ " \n",
70
+ " # Remove paddings\n",
71
+ " mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n",
72
+ " all_preds = {k: v[mask] for k, v in all_preds.items()}\n",
73
+ "\n",
74
+ " return identifier_map, all_preds\n",
75
+ "\n",
76
+ "\n",
77
+ "def inverse_aug(name: str, grid: np.ndarray):\n",
78
+ " if \"_\" not in name:\n",
79
+ " return grid\n",
80
+ "\n",
81
+ " trans_id, perm = name.split(\"_\")[-2:]\n",
82
+ " trans_id = int(trans_id[1:]) # Remove \"t\" letter\n",
83
+ " inv_perm = np.argsort(list(perm))\n",
84
+ " \n",
85
+ " return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n",
86
+ "\n",
87
+ "\n",
88
+ "def grid_hash(grid: np.ndarray):\n",
89
+ " return hash((grid.tobytes(), grid.shape))\n",
90
+ "\n",
91
+ "\n",
92
+ "@njit\n",
93
+ "def crop(grid: np.ndarray):\n",
94
+ " # Find maximum-sized rectangle without any EOS token inside.\n",
95
+ " grid = grid.reshape(30, 30)\n",
96
+ "\n",
97
+ " max_area = 0\n",
98
+ " max_size = (0, 0)\n",
99
+ " nr, nc = grid.shape\n",
100
+ " \n",
101
+ " num_c = nc\n",
102
+ " for num_r in range(1, nr + 1):\n",
103
+ " # Scan for maximum c\n",
104
+ " for c in range(1, num_c + 1):\n",
105
+ " x = grid[num_r - 1, c - 1]\n",
106
+ " if (x < 2) | (x > 11):\n",
107
+ " num_c = c - 1\n",
108
+ " break\n",
109
+ " \n",
110
+ " area = num_r * num_c\n",
111
+ " if area > max_area:\n",
112
+ " max_area = area\n",
113
+ " max_size = (num_r, num_c)\n",
114
+ "\n",
115
+ " return grid[:max_size[0], :max_size[1]] - 2\n",
116
+ "\n",
117
+ "\n",
118
+ "def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n",
119
+ " identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n",
120
+ " \n",
121
+ " global_hmap = {}\n",
122
+ " \n",
123
+ " # Get puzzles and corresponding answers\n",
124
+ " puzzle_labels = {}\n",
125
+ " for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n",
126
+ " name = identifier_map[identifier]\n",
127
+ " if \"_\" not in name: # Not-augmented\n",
128
+ " puzzle_labels.setdefault(name, {})\n",
129
+ " \n",
130
+ " input = crop(input.numpy())\n",
131
+ " label = crop(label.numpy())\n",
132
+ "\n",
133
+ " input_hash = grid_hash(input)\n",
134
+ " label_hash = grid_hash(label)\n",
135
+ "\n",
136
+ " global_hmap[input_hash] = input\n",
137
+ " global_hmap[label_hash] = label\n",
138
+ "\n",
139
+ " assert input_hash not in puzzle_labels[name]\n",
140
+ " puzzle_labels[name][input_hash] = label_hash\n",
141
+ " \n",
142
+ " print (\"Number of puzzles\", len(puzzle_labels))\n",
143
+ " \n",
144
+ " # Argmax prediction\n",
145
+ " preds = all_preds[\"logits\"].argmax(-1)\n",
146
+ "\n",
147
+ " # Collate\n",
148
+ " pred_answers = {}\n",
149
+ " for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n",
150
+ " name = identifier_map[identifier]\n",
151
+ " orig_name = name.split(\"_\")[0]\n",
152
+ " \n",
153
+ " input = input.numpy()\n",
154
+ " input_hash = grid_hash(inverse_aug(name, crop(input)))\n",
155
+ " assert input_hash in puzzle_labels[orig_name]\n",
156
+ " \n",
157
+ " pred = inverse_aug(name, crop(pred.numpy()))\n",
158
+ " pred_hash = grid_hash(pred)\n",
159
+ " global_hmap[pred_hash] = pred\n",
160
+ " \n",
161
+ " pred_answers.setdefault(orig_name, {})\n",
162
+ " pred_answers[orig_name].setdefault(input_hash, [])\n",
163
+ " pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n",
164
+ "\n",
165
+ " # test-1\n",
166
+ " if visualize:\n",
167
+ " num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n",
168
+ " fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n",
169
+ " \n",
170
+ " fig_id = 0\n",
171
+ " \n",
172
+ " correct = [0 for _ in range(len(Ks))]\n",
173
+ " for name, tests in puzzle_labels.items():\n",
174
+ " num_test_correct = [0 for _ in range(len(Ks))]\n",
175
+ " for input_hash, label_hash in tests.items():\n",
176
+ " p = pred_answers[name][input_hash]\n",
177
+ " p_map = {}\n",
178
+ " \n",
179
+ " for h, q in p:\n",
180
+ " p_map.setdefault(h, [0, 0])\n",
181
+ " p_map[h][0] += 1\n",
182
+ " p_map[h][1] += q\n",
183
+ " \n",
184
+ " for h, stats in p_map.items():\n",
185
+ " stats[1] /= stats[0]\n",
186
+ " \n",
187
+ " p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n",
188
+ "\n",
189
+ " # 2-vote\n",
190
+ " for i, k in enumerate(Ks):\n",
191
+ " ok = False\n",
192
+ " for h, stats in p_map[:k]:\n",
193
+ " ok |= h == label_hash\n",
194
+ " \n",
195
+ " num_test_correct[i] += ok\n",
196
+ "\n",
197
+ " if visualize:\n",
198
+ " # Show input and ground truth\n",
199
+ " axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n",
200
+ " axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n",
201
+ " axes[fig_id, 0].axis('off')\n",
202
+ " \n",
203
+ " axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n",
204
+ " axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n",
205
+ " axes[fig_id, 1].axis('off')\n",
206
+ " \n",
207
+ " trial_id = 2\n",
208
+ " for h, stats in p_map[:2]:\n",
209
+ " ans = global_hmap[h]\n",
210
+ " \n",
211
+ " axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n",
212
+ " axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n",
213
+ " axes[fig_id, trial_id].axis('off')\n",
214
+ " \n",
215
+ " trial_id += 1\n",
216
+ " \n",
217
+ " fig_id += 1\n",
218
+ " \n",
219
+ " # Total correctness\n",
220
+ " for i in range(len(Ks)):\n",
221
+ " correct[i] += num_test_correct[i] == len(tests)\n",
222
+ "\n",
223
+ " for i, k in enumerate(Ks):\n",
224
+ " print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n",
225
+ "\n",
226
+ "\n",
227
+ "test(visualize=False)"
228
+ ]
229
+ }
230
+ ],
231
+ "metadata": {
232
+ "kernelspec": {
233
+ "display_name": "Python 3",
234
+ "language": "python",
235
+ "name": "python3"
236
+ },
237
+ "language_info": {
238
+ "codemirror_mode": {
239
+ "name": "ipython",
240
+ "version": 3
241
+ },
242
+ "file_extension": ".py",
243
+ "mimetype": "text/x-python",
244
+ "name": "python",
245
+ "nbconvert_exporter": "python",
246
+ "pygments_lexer": "ipython3",
247
+ "version": "3.12.10"
248
+ }
249
+ },
250
+ "nbformat": 4,
251
+ "nbformat_minor": 2
252
+ }
assets/hrm.png ADDED
assets/npyjs.js ADDED
@@ -0,0 +1,176 @@
1
+ class npyjs {
2
+
3
+ constructor(opts) {
4
+ if (opts && !('convertFloat16' in opts)) {
5
+ console.warn([
6
+ "npyjs constructor now accepts {convertFloat16?: boolean}.",
7
+ "For usage, go to https://github.com/jhuapl-boss/npyjs."
8
+ ].join(" "));
9
+ }
10
+
11
+ this.convertFloat16 = opts?.convertFloat16 ?? true;
12
+
13
+ this.dtypes = {
14
+ "<u1": {
15
+ name: "uint8",
16
+ size: 8,
17
+ arrayConstructor: Uint8Array,
18
+ },
19
+ "|u1": {
20
+ name: "uint8",
21
+ size: 8,
22
+ arrayConstructor: Uint8Array,
23
+ },
24
+ "<u2": {
25
+ name: "uint16",
26
+ size: 16,
27
+ arrayConstructor: Uint16Array,
28
+ },
29
+ "|i1": {
30
+ name: "int8",
31
+ size: 8,
32
+ arrayConstructor: Int8Array,
33
+ },
34
+ "<i2": {
35
+ name: "int16",
36
+ size: 16,
37
+ arrayConstructor: Int16Array,
38
+ },
39
+ "<u4": {
40
+ name: "uint32",
41
+ size: 32,
42
+ arrayConstructor: Uint32Array,
43
+ },
44
+ "<i4": {
45
+ name: "int32",
46
+ size: 32,
47
+ arrayConstructor: Int32Array,
48
+ },
49
+ "<u8": {
50
+ name: "uint64",
51
+ size: 64,
52
+ arrayConstructor: BigUint64Array,
53
+ },
54
+ "<i8": {
55
+ name: "int64",
56
+ size: 64,
57
+ arrayConstructor: BigInt64Array,
58
+ },
59
+ "<f4": {
60
+ name: "float32",
61
+ size: 32,
62
+ arrayConstructor: Float32Array
63
+ },
64
+ "<f8": {
65
+ name: "float64",
66
+ size: 64,
67
+ arrayConstructor: Float64Array
68
+ },
69
+ "<f2": {
70
+ name: "float16",
71
+ size: 16,
72
+ arrayConstructor: Uint16Array,
73
+ converter: this.convertFloat16 ? this.float16ToFloat32Array : undefined
74
+ },
75
+ };
76
+ }
77
+
78
+ float16ToFloat32Array(float16Array) {
79
+ const length = float16Array.length;
80
+ const float32Array = new Float32Array(length);
81
+
82
+ for (let i = 0; i < length; i++) {
83
+ float32Array[i] = npyjs.float16ToFloat32(float16Array[i]);
84
+ }
85
+
86
+ return float32Array;
87
+ }
88
+
89
+ static float16ToFloat32(float16) {
90
+ // Extract the parts of the float16
91
+ const sign = (float16 >> 15) & 0x1;
92
+ const exponent = (float16 >> 10) & 0x1f;
93
+ const fraction = float16 & 0x3ff;
94
+
95
+ // Handle special cases
96
+ if (exponent === 0) {
97
+ if (fraction === 0) {
98
+ // Zero
99
+ return sign ? -0 : 0;
100
+ }
101
+ // Denormalized number
102
+ return (sign ? -1 : 1) * Math.pow(2, -14) * (fraction / 0x400);
103
+ } else if (exponent === 0x1f) {
104
+ if (fraction === 0) {
105
+ // Infinity
106
+ return sign ? -Infinity : Infinity;
107
+ }
108
+ // NaN
109
+ return NaN;
110
+ }
111
+
112
+ // Normalized number
113
+ return (sign ? -1 : 1) * Math.pow(2, exponent - 15) * (1 + fraction / 0x400);
114
+ }
115
+
116
+ parse(arrayBufferContents) {
117
+ // const version = arrayBufferContents.slice(6, 8); // Uint8-encoded
118
+ const headerLength = new DataView(arrayBufferContents.slice(8, 10)).getUint8(0);
119
+ const offsetBytes = 10 + headerLength;
120
+
121
+ const hcontents = new TextDecoder("utf-8").decode(
122
+ new Uint8Array(arrayBufferContents.slice(10, 10 + headerLength))
123
+ );
124
+ const header = JSON.parse(
125
+ hcontents
126
+ .toLowerCase() // True -> true
127
+ .replace(/'/g, '"')
128
+ .replace("(", "[")
129
+ .replace(/,*\),*/g, "]")
130
+ );
131
+ const shape = header.shape;
132
+ const dtype = this.dtypes[header.descr];
133
+
134
+ if (!dtype) {
135
+ console.error(`Unsupported dtype: ${header.descr}`);
136
+ return null;
137
+ }
138
+
139
+ const nums = new dtype.arrayConstructor(
140
+ arrayBufferContents,
141
+ offsetBytes
142
+ );
143
+
144
+ // Convert float16 to float32 if converter exists
145
+ const data = dtype.converter ? dtype.converter.call(this, nums) : nums;
146
+
147
+ return {
148
+ dtype: dtype.name,
149
+ data: data,
150
+ shape,
151
+ fortranOrder: header.fortran_order
152
+ };
153
+ }
154
+
155
+ async load(filename, callback, fetchArgs) {
156
+ /*
157
+ Loads an array from a stream of bytes.
158
+ */
159
+ fetchArgs = fetchArgs || {};
160
+ let arrayBuf;
161
+ // If filename is ArrayBuffer
162
+ if (filename instanceof ArrayBuffer) {
163
+ arrayBuf = filename;
164
+ }
165
+ // If filename is a file path
166
+ else {
167
+ const resp = await fetch(filename, { ...fetchArgs });
168
+ arrayBuf = await resp.arrayBuffer();
169
+ }
170
+ const result = this.parse(arrayBuf);
171
+ if (callback) {
172
+ return callback(result);
173
+ }
174
+ return result;
175
+ }
176
+ }
config/arch/hrm_v1.yaml ADDED
@@ -0,0 +1,21 @@
1
+ name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1
2
+ loss:
3
+ name: losses@ACTLossHead
4
+ loss_type: stablemax_cross_entropy
5
+
6
+ halt_exploration_prob: 0.1
7
+ halt_max_steps: 16
8
+
9
+ H_cycles: 2
10
+ L_cycles: 2
11
+
12
+ H_layers: 4
13
+ L_layers: 4
14
+
15
+ hidden_size: 512
16
+ num_heads: 8 # max(2, hidden_size // 64)
17
+ expansion: 4
18
+
19
+ puzzle_emb_ndim: ${.hidden_size}
20
+
21
+ pos_encodings: rope
config/cfg_pretrain.yaml ADDED
@@ -0,0 +1,31 @@
1
+ # ARC training config
2
+
3
+ defaults:
4
+ - arch: hrm_v1
5
+ - _self_
6
+
7
+ hydra:
8
+ output_subdir: null
9
+
10
+ # Data path
11
+ data_path: data/arc-aug-1000
12
+
13
+ # Hyperparams - Training
14
+ global_batch_size: 768
15
+
16
+ epochs: 100000
17
+ eval_interval: 10000
18
+ checkpoint_every_eval: True
19
+
20
+ lr: 1e-4
21
+ lr_min_ratio: 1.0
22
+ lr_warmup_steps: 2000
23
+
24
+ # Standard hyperparameter settings for LM, as used in Llama
25
+ beta1: 0.9
26
+ beta2: 0.95
27
+ weight_decay: 0.1
28
+ puzzle_emb_weight_decay: 0.1
29
+
30
+ # Hyperparams - Puzzle embeddings training
31
+ puzzle_emb_lr: 1e-2
dataset/build_arc_dataset.py ADDED
@@ -0,0 +1,291 @@
1
+ from typing import List, Optional, Tuple, Dict
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ import os
5
+ import json
6
+ import hashlib
7
+ import numpy as np
8
+ from glob import glob
9
+
10
+ from argdantic import ArgParser
11
+ from pydantic import BaseModel
12
+
13
+ from common import PuzzleDatasetMetadata, dihedral_transform
14
+
15
+
16
+ cli = ArgParser()
17
+
18
+
19
+ class DataProcessConfig(BaseModel):
20
+ # ARC-1
21
+ dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"]
22
+ output_dir: str = "data/arc-aug-1000"
23
+
24
+ # ARC-2
25
+ # dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"]
26
+ # output_dir: str = "data/arc-2-aug-1000"
27
+
28
+ seed: int = 42
29
+ num_aug: int = 1000
30
+
31
+
32
+ ARCMaxGridSize = 30
33
+ ARCAugmentRetriesFactor = 5
34
+
35
+
36
+ @dataclass
37
+ class ARCPuzzle:
38
+ id: str
39
+
40
+ examples: List[Tuple[np.ndarray, np.ndarray]]
41
+
42
+
43
+ def arc_grid_to_np(grid: List[List[int]]):
44
+ arr = np.array(grid)
45
+
46
+ # Shape check
47
+ assert arr.ndim == 2
48
+ assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
49
+ # Element check
50
+ assert np.all((arr >= 0) & (arr <= 9))
51
+ return arr.astype(np.uint8)
52
+
53
+
54
+ def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
55
+ # PAD: 0, <eos>: 1, digits: 2 ... 11
56
+ # Compute random top-left pad
57
+ if do_translation:
58
+ pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
59
+ pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
60
+ else:
61
+ pad_r = pad_c = 0
62
+
63
+ # Pad grid
64
+ result = []
65
+ for grid in [inp, out]:
66
+ nrow, ncol = grid.shape
67
+ grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)
68
+
69
+ # Add <eos>
70
+ eos_row, eos_col = pad_r + nrow, pad_c + ncol
71
+ if eos_row < ARCMaxGridSize:
72
+ grid[eos_row, pad_c:eos_col] = 1
73
+ if eos_col < ARCMaxGridSize:
74
+ grid[pad_r:eos_row, eos_col] = 1
75
+
76
+ result.append(grid.flatten())
77
+
78
+ return result
79
+
80
+
81
+ def puzzle_hash(puzzle: dict):
82
+ # Hash the puzzle for checking equivalence
83
+ def _grid_hash(grid: np.ndarray):
84
+ buffer = [x.to_bytes(1) for x in grid.shape]
85
+ buffer.append(grid.tobytes())
86
+
87
+ return hashlib.sha256(b"".join(buffer)).hexdigest()
88
+
89
+ hashes = []
90
+ for example_type, example in puzzle.items():
91
+ for input, label in example.examples:
92
+ hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}")
93
+
94
+ hashes.sort()
95
+ return hashlib.sha256("|".join(hashes).encode()).hexdigest()
96
+
97
+
98
+ def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
99
+ # Remove "name"
100
+ name = puzzle.pop("name", default_name)
101
+
102
+ # Convert
103
+ dests = set(dest_mapping.values())
104
+ converted = {dest: ARCPuzzle(name, []) for dest in dests}
105
+ for example_type, examples in puzzle.items():
106
+ dest = dest_mapping[example_type]
107
+ converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])
108
+
109
+ group = [converted]
110
+
111
+ # Augment
112
+ if aug_count > 0:
113
+ hashes = {puzzle_hash(converted)}
114
+
115
+ for _trial in range(ARCAugmentRetriesFactor * aug_count):
116
+ # Augment plan
117
+ trans_id = np.random.randint(0, 8)
118
+ mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))]) # Permute colors, Excluding "0" (black)
119
+
120
+ aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}"
121
+
122
+ def _map_grid(grid: np.ndarray):
123
+ return dihedral_transform(mapping[grid], trans_id)
124
+
125
+ # Check duplicate
126
+ augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
127
+ h = puzzle_hash(augmented)
128
+ if h not in hashes:
129
+ hashes.add(h)
130
+ group.append(augmented)
131
+
132
+ if len(group) >= aug_count + 1:
133
+ break
134
+
135
+ if len(group) < aug_count + 1:
136
+ print (f"[Puzzle {name}] augmentation not full, only {len(group)}")
137
+
138
+ # Append
139
+ for dest in dests:
140
+ # Convert the examples
141
+ dest_split, dest_set = dest
142
+
143
+ results.setdefault(dest_split, {})
144
+ results[dest_split].setdefault(dest_set, [])
145
+ results[dest_split][dest_set].append([converted[dest] for converted in group])
146
+
147
+
148
+ def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig):
149
+ train_examples_dest = ("train", "all")
150
+ test_examples_map = {
151
+ "evaluation": [(1.0, ("test", "all"))],
152
+ "_default": [(1.0, ("train", "all"))]
153
+ }
154
+
155
+ total_puzzles = 0
156
+ for subdir in os.scandir(dataset_path):
157
+ if subdir.is_dir():
158
+ # Load all puzzles in this directory
159
+ puzzles = []
160
+ for filename in glob(os.path.join(subdir.path, "*.json")):
161
+ with open(filename, "r") as f:
162
+ puzzles.append((Path(filename).stem, json.load(f)))
163
+
164
+ # Shuffle puzzles
165
+ np.random.shuffle(puzzles)
166
+
167
+ # Assign by fraction
168
+ for idx, (default_name, puzzle) in enumerate(puzzles):
169
+ fraction = idx / len(puzzles)
170
+ test_examples_dest = None
171
+ for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]):
172
+ if fraction < f:
173
+ test_examples_dest = dest
174
+ break
175
+
176
+ assert test_examples_dest is not None
177
+
178
+ convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
179
+ total_puzzles += 1
180
+
181
+ print (f"[{dataset_path}] total puzzles: {total_puzzles}")
182
+
183
+
184
+ def convert_dataset(config: DataProcessConfig):
185
+ np.random.seed(config.seed)
186
+
187
+ # Read dataset
188
+ data = {}
189
+ for dataset_dir in config.dataset_dirs:
190
+ load_puzzles_arcagi(data, dataset_dir, config)
191
+
192
+ # Map global puzzle identifiers
193
+ num_identifiers = 1 # 0 is blank
194
+ identifier_map = {}
195
+ for split_name, split in data.items():
196
+ for subset_name, subset in split.items():
197
+ for group in subset:
198
+ for puzzle in group:
199
+ if puzzle.id not in identifier_map:
200
+ identifier_map[puzzle.id] = num_identifiers
201
+ num_identifiers += 1
202
+
203
+ print (f"Total puzzle IDs (including <blank>): {num_identifiers}")
204
+
205
+ # Save
206
+ for split_name, split in data.items():
207
+ os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)
208
+
209
+ # Translational augmentations
210
+ enable_translational_augment = split_name == "train"
211
+
212
+ # Statistics
213
+ total_examples = 0
214
+ total_puzzles = 0
215
+ total_groups = 0
216
+
217
+ for subset_name, subset in split.items():
218
+ # Construct subset
219
+ results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
220
+ results["puzzle_indices"].append(0)
221
+ results["group_indices"].append(0)
222
+
223
+ example_id = 0
224
+ puzzle_id = 0
225
+
226
+ for group in subset:
227
+ for puzzle in group:
228
+ # Push puzzle
229
+ no_aug_id = np.random.randint(0, len(puzzle.examples))
230
+ for _idx_ex, (inp, out) in enumerate(puzzle.examples):
231
+ inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)
232
+
233
+ results["inputs"].append(inp)
234
+ results["labels"].append(out)
235
+ example_id += 1
236
+
237
+ total_examples += 1
238
+
239
+ results["puzzle_indices"].append(example_id)
240
+ results["puzzle_identifiers"].append(identifier_map[puzzle.id])
241
+
242
+ puzzle_id += 1
243
+
244
+ total_puzzles += 1
245
+
246
+ # Push group
247
+ results["group_indices"].append(puzzle_id)
248
+ total_groups += 1
249
+
250
+ for k, v in results.items():
251
+ if k in {"inputs", "labels"}:
252
+ v = np.stack(v, 0)
253
+ else:
254
+ v = np.array(v, dtype=np.int32)
255
+
256
+ np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)
257
+
258
+ # Metadata
259
+ metadata = PuzzleDatasetMetadata(
260
+ seq_len=ARCMaxGridSize * ARCMaxGridSize,
261
+ vocab_size=10 + 2, # PAD + EOS + "0" ... "9"
262
+
263
+ pad_id=0,
264
+ ignore_label_id=0,
265
+
266
+ blank_identifier_id=0,
267
+ num_puzzle_identifiers=num_identifiers,
268
+
269
+ total_groups=total_groups,
270
+ mean_puzzle_examples=total_examples / total_puzzles,
271
+ sets=list(split.keys())
272
+ )
273
+
274
+ # Save metadata as JSON.
275
+ with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
276
+ json.dump(metadata.model_dump(), f)
277
+
278
+ # Save IDs mapping
279
+ with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
280
+ ids_mapping = {v: k for k, v in identifier_map.items()}
281
+
282
+ json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)
283
+
284
+
285
+ @cli.command(singleton=True)
286
+ def main(config: DataProcessConfig):
287
+ convert_dataset(config)
288
+
289
+
290
+ if __name__ == "__main__":
291
+ cli()
dataset/build_maze_dataset.py ADDED
@@ -0,0 +1,142 @@
1
+ from typing import Optional
2
+ import math
3
+ import os
4
+ import csv
5
+ import json
6
+ import numpy as np
7
+
8
+ from argdantic import ArgParser
9
+ from pydantic import BaseModel
10
+ from tqdm import tqdm
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ from common import PuzzleDatasetMetadata, dihedral_transform
14
+
15
+
16
+ CHARSET = "# SGo"
17
+
18
+
19
+ cli = ArgParser()
20
+
21
+
22
+ class DataProcessConfig(BaseModel):
23
+ source_repo: str = "imone/small-sample-challenge-maze-30x30-hard"
24
+ output_dir: str = "data/maze-30x30-hard-1k"
25
+
26
+ subsample_size: Optional[int] = None
27
+ aug: bool = False
28
+
29
+
30
+ def convert_subset(set_name: str, config: DataProcessConfig):
31
+ # Read CSV
32
+ all_chars = set()
33
+ grid_size = None
34
+ inputs = []
35
+ labels = []
36
+
37
+ with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore
38
+ reader = csv.reader(csvfile)
39
+ next(reader) # Skip header
40
+ for source, q, a, rating in reader:
41
+ all_chars.update(q)
42
+ all_chars.update(a)
43
+
44
+ if grid_size is None:
45
+ n = int(len(q) ** 0.5)
46
+ grid_size = (n, n)
47
+
48
+ inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
49
+ labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))
50
+
51
+ # If subsample_size is specified for the training set,
52
+ # randomly sample the desired number of examples.
53
+ if set_name == "train" and config.subsample_size is not None:
54
+ total_samples = len(inputs)
55
+ if config.subsample_size < total_samples:
56
+ indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
57
+ inputs = [inputs[i] for i in indices]
58
+ labels = [labels[i] for i in indices]
59
+
60
+ # Generate dataset
61
+ results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
62
+ puzzle_id = 0
63
+ example_id = 0
64
+
65
+ results["puzzle_indices"].append(0)
66
+ results["group_indices"].append(0)
67
+
68
+ for inp, out in zip(tqdm(inputs), labels):
69
+ # Dihedral transformations for augmentation
70
+ for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
71
+ results["inputs"].append(dihedral_transform(inp, aug_idx))
72
+ results["labels"].append(dihedral_transform(out, aug_idx))
73
+ example_id += 1
74
+ puzzle_id += 1
75
+
76
+ results["puzzle_indices"].append(example_id)
77
+ results["puzzle_identifiers"].append(0)
78
+
79
+ # Push group
80
+ results["group_indices"].append(puzzle_id)
81
+
82
+ # Char mappings
83
+ assert len(all_chars - set(CHARSET)) == 0
84
+
85
+ char2id = np.zeros(256, np.uint8)
86
+ char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1
87
+
88
+ # To Numpy
89
+ def _seq_to_numpy(seq):
90
+ arr = np.vstack([char2id[s.reshape(-1)] for s in seq])
91
+
92
+ return arr
93
+
94
+ results = {
95
+ "inputs": _seq_to_numpy(results["inputs"]),
96
+ "labels": _seq_to_numpy(results["labels"]),
97
+
98
+ "group_indices": np.array(results["group_indices"], dtype=np.int32),
99
+ "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
100
+ "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
101
+ }
102
+
103
+ # Metadata
104
+ metadata = PuzzleDatasetMetadata(
105
+ seq_len=int(math.prod(grid_size)), # type: ignore
106
+ vocab_size=len(CHARSET) + 1, # PAD + Charset
107
+
108
+ pad_id=0,
109
+ ignore_label_id=0,
110
+
111
+ blank_identifier_id=0,
112
+ num_puzzle_identifiers=1,
113
+
114
+ total_groups=len(results["group_indices"]) - 1,
115
+ mean_puzzle_examples=1,
116
+ sets=["all"]
117
+ )
118
+
119
+ # Save metadata as JSON.
120
+ save_dir = os.path.join(config.output_dir, set_name)
121
+ os.makedirs(save_dir, exist_ok=True)
122
+
123
+ with open(os.path.join(save_dir, "dataset.json"), "w") as f:
124
+ json.dump(metadata.model_dump(), f)
125
+
126
+ # Save data
127
+ for k, v in results.items():
128
+ np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
129
+
130
+ # Save IDs mapping (for visualization only)
131
+ with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
132
+ json.dump(["<blank>"], f)
133
+
134
+
135
+ @cli.command(singleton=True)
136
+ def preprocess_data(config: DataProcessConfig):
137
+ convert_subset("train", config)
138
+ convert_subset("test", config)
139
+
140
+
141
+ if __name__ == "__main__":
142
+ cli()
dataset/build_sudoku_dataset.py ADDED
@@ -0,0 +1,169 @@
1
+ from typing import Optional
2
+ import os
3
+ import csv
4
+ import json
5
+ import numpy as np
6
+
7
+ from argdantic import ArgParser
8
+ from pydantic import BaseModel
9
+ from tqdm import tqdm
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ from common import PuzzleDatasetMetadata
13
+
14
+
15
+ cli = ArgParser()
16
+
17
+
18
+ class DataProcessConfig(BaseModel):
19
+ source_repo: str = "imone/sudoku-hard-v2"
20
+ output_dir: str = "data/sudoku-extreme-full"
21
+
22
+ subsample_size: Optional[int] = None
23
+ min_difficulty: Optional[int] = None
24
+ num_aug: int = 0
25
+
26
+
27
+ def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
28
+ # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
29
+ digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))
30
+
31
+ # Randomly decide whether to transpose.
32
+ transpose_flag = np.random.rand() < 0.5
33
+
34
+ # Generate a valid row permutation:
35
+ # - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows.
36
+ bands = np.random.permutation(3)
37
+ row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])
38
+
39
+ # Similarly for columns (stacks).
40
+ stacks = np.random.permutation(3)
41
+ col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])
42
+
43
+ # Build an 81->81 mapping. For each new cell at (i, j)
44
+ # (row index = i // 9, col index = i % 9),
45
+ # its value comes from old row = row_perm[i//9] and old col = col_perm[i%9].
46
+ mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])
47
+
48
+ def apply_transformation(x: np.ndarray) -> np.ndarray:
49
+ # Apply transpose flag
50
+ if transpose_flag:
51
+ x = x.T
52
+ # Apply the position mapping.
53
+ new_board = x.flatten()[mapping].reshape(9, 9).copy()
54
+ # Apply digit mapping
55
+ return digit_map[new_board]
56
+
57
+ return apply_transformation(board), apply_transformation(solution)
58
+
59
+
60
+ def convert_subset(set_name: str, config: DataProcessConfig):
61
+ # Read CSV
62
+ inputs = []
63
+ labels = []
64
+
65
+ with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
66
+ reader = csv.reader(csvfile)
67
+ next(reader) # Skip header
68
+ for source, q, a, rating in reader:
69
+ if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
70
+ assert len(q) == 81 and len(a) == 81
71
+
72
+ inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
73
+ labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
74
+
75
+ # If subsample_size is specified for the training set,
76
+ # randomly sample the desired number of examples.
77
+ if set_name == "train" and config.subsample_size is not None:
78
+ total_samples = len(inputs)
79
+ if config.subsample_size < total_samples:
80
+ indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
81
+ inputs = [inputs[i] for i in indices]
82
+ labels = [labels[i] for i in indices]
83
+
84
+ # Generate dataset
85
+ num_augments = config.num_aug if set_name == "train" else 0
86
+
87
+ results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
88
+ puzzle_id = 0
89
+ example_id = 0
90
+
91
+ results["puzzle_indices"].append(0)
92
+ results["group_indices"].append(0)
93
+
94
+ for orig_inp, orig_out in zip(tqdm(inputs), labels):
95
+ for aug_idx in range(1 + num_augments):
96
+ # First index is not augmented
97
+ if aug_idx == 0:
98
+ inp, out = orig_inp, orig_out
99
+ else:
100
+ inp, out = shuffle_sudoku(orig_inp, orig_out)
101
+
102
+ # Push puzzle (only single example)
103
+ results["inputs"].append(inp)
104
+ results["labels"].append(out)
105
+ example_id += 1
106
+ puzzle_id += 1
107
+
108
+ results["puzzle_indices"].append(example_id)
109
+ results["puzzle_identifiers"].append(0)
110
+
111
+ # Push group
112
+ results["group_indices"].append(puzzle_id)
113
+
114
+ # To Numpy
115
+ def _seq_to_numpy(seq):
116
+ arr = np.concatenate(seq).reshape(len(seq), -1)
117
+
118
+ assert np.all((arr >= 0) & (arr <= 9))
119
+ return arr + 1
120
+
121
+ results = {
122
+ "inputs": _seq_to_numpy(results["inputs"]),
123
+ "labels": _seq_to_numpy(results["labels"]),
124
+
125
+ "group_indices": np.array(results["group_indices"], dtype=np.int32),
126
+ "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
127
+ "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
128
+ }
129
+
130
+ # Metadata
131
+ metadata = PuzzleDatasetMetadata(
132
+ seq_len=81,
133
+ vocab_size=10 + 1, # PAD + "0" ... "9"
134
+
135
+ pad_id=0,
136
+ ignore_label_id=0,
137
+
138
+ blank_identifier_id=0,
139
+ num_puzzle_identifiers=1,
140
+
141
+ total_groups=len(results["group_indices"]) - 1,
142
+ mean_puzzle_examples=1,
143
+ sets=["all"]
144
+ )
145
+
146
+ # Save metadata as JSON.
147
+ save_dir = os.path.join(config.output_dir, set_name)
148
+ os.makedirs(save_dir, exist_ok=True)
149
+
150
+ with open(os.path.join(save_dir, "dataset.json"), "w") as f:
151
+ json.dump(metadata.model_dump(), f)
152
+
153
+ # Save data
154
+ for k, v in results.items():
155
+ np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
156
+
157
+ # Save IDs mapping (for visualization only)
158
+ with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
159
+ json.dump(["<blank>"], f)
160
+
161
+
162
+ @cli.command(singleton=True)
163
+ def preprocess_data(config: DataProcessConfig):
164
+ convert_subset("train", config)
165
+ convert_subset("test", config)
166
+
167
+
168
+ if __name__ == "__main__":
169
+ cli()
dataset/common.py ADDED
@@ -0,0 +1,51 @@
1
+ from typing import List, Optional
2
+
3
+ import pydantic
4
+ import numpy as np
5
+
6
+
7
+ # Global list mapping each dihedral transform id to its inverse.
8
+ # Index corresponds to the original tid, and the value is its inverse.
9
+ DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]
10
+
11
+
12
+ class PuzzleDatasetMetadata(pydantic.BaseModel):
13
+ pad_id: int
14
+ ignore_label_id: Optional[int]
15
+ blank_identifier_id: int
16
+
17
+ vocab_size: int
18
+ seq_len: int
19
+ num_puzzle_identifiers: int
20
+
21
+ total_groups: int
22
+ mean_puzzle_examples: float
23
+
24
+ sets: List[str]
25
+
26
+
27
+ def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
28
+ """8 dihedral symmetries by rotate, flip and mirror"""
29
+
30
+ if tid == 0:
31
+ return arr # identity
32
+ elif tid == 1:
33
+ return np.rot90(arr, k=1)
34
+ elif tid == 2:
35
+ return np.rot90(arr, k=2)
36
+ elif tid == 3:
37
+ return np.rot90(arr, k=3)
38
+ elif tid == 4:
39
+ return np.fliplr(arr) # horizontal flip
40
+ elif tid == 5:
41
+ return np.flipud(arr) # vertical flip
42
+ elif tid == 6:
43
+ return arr.T # transpose (reflection along main diagonal)
44
+ elif tid == 7:
45
+ return np.fliplr(np.rot90(arr, k=1)) # anti-diagonal reflection
46
+ else:
47
+ return arr
48
+
49
+
50
+ def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
51
+ return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
evaluate.py ADDED
@@ -0,0 +1,68 @@
1
+ from typing import List
2
+ import yaml
3
+ import os
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ import pydantic
9
+ from omegaconf import OmegaConf
10
+ from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader
11
+
12
+
13
+ class EvalConfig(pydantic.BaseModel):
14
+ checkpoint: str
15
+
16
+ save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"]
17
+
18
+
19
+ def launch():
20
+ eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli())) # type: ignore
21
+
22
+ RANK = 0
23
+ WORLD_SIZE = 1
24
+ # Initialize distributed training if in distributed environment (e.g. torchrun)
25
+ if "LOCAL_RANK" in os.environ:
26
+ # Initialize distributed, default device and dtype
27
+ dist.init_process_group(backend="nccl")
28
+
29
+ RANK = dist.get_rank()
30
+ WORLD_SIZE = dist.get_world_size()
31
+
32
+ torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
33
+
34
+ with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f:
35
+ config = PretrainConfig(**yaml.safe_load(f))
36
+
37
+ config.eval_save_outputs = eval_cfg.save_outputs
38
+ config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint)
39
+
40
+ # Dataloader
41
+ train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
42
+ eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
43
+
44
+ # Models
45
+ train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
46
+ # Try unwrap torch.compile
47
+ try:
48
+ train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True)
49
+ except:
50
+ train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True)
51
+
52
+ train_state.step = 0
53
+ ckpt_filename = os.path.basename(eval_cfg.checkpoint)
54
+ if ckpt_filename.startswith("step_"):
55
+ train_state.step = int(ckpt_filename.removeprefix("step_"))
56
+
57
+ # Evaluate
58
+ print ("Starting evaluation")
59
+
60
+ train_state.model.eval()
61
+ metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)
62
+
63
+ if metrics is not None:
64
+ print (metrics)
65
+
66
+
67
+ if __name__ == "__main__":
68
+ launch()
models/common.py ADDED
@@ -0,0 +1,32 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
8
+ # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct: the requested std is not the actual std dev of the initialized tensor
9
+ # This function is a PyTorch version of jax truncated normal init (default init method in flax)
10
+ # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
11
+ # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199
12
+
13
+ with torch.no_grad():
14
+ if std == 0:
15
+ tensor.zero_()
16
+ else:
17
+ sqrt2 = math.sqrt(2)
18
+ a = math.erf(lower / sqrt2)
19
+ b = math.erf(upper / sqrt2)
20
+ z = (b - a) / 2
21
+
22
+ c = (2 * math.pi) ** -0.5
23
+ pdf_u = c * math.exp(-0.5 * lower ** 2)
24
+ pdf_l = c * math.exp(-0.5 * upper ** 2)
25
+ comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)
26
+
27
+ tensor.uniform_(a, b)
28
+ tensor.erfinv_()
29
+ tensor.mul_(sqrt2 * comp_std)
30
+ tensor.clip_(lower * comp_std, upper * comp_std)
31
+
32
+ return tensor
models/hrm/hrm_act_v1.py ADDED
@@ -0,0 +1,283 @@
1
+ from typing import Tuple, List, Dict, Optional
2
+ from dataclasses import dataclass
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ from models.common import trunc_normal_init_
11
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
12
+ from models.sparse_embedding import CastedSparseEmbedding
13
+
14
+
15
+ @dataclass
16
+ class HierarchicalReasoningModel_ACTV1InnerCarry:
17
+ z_H: torch.Tensor
18
+ z_L: torch.Tensor
19
+
20
+
21
+ @dataclass
22
+ class HierarchicalReasoningModel_ACTV1Carry:
23
+ inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
24
+
25
+ steps: torch.Tensor
26
+ halted: torch.Tensor
27
+
28
+ current_data: Dict[str, torch.Tensor]
29
+
30
+
31
+ class HierarchicalReasoningModel_ACTV1Config(BaseModel):
32
+ batch_size: int
33
+ seq_len: int
34
+ puzzle_emb_ndim: int = 0
35
+ num_puzzle_identifiers: int
36
+ vocab_size: int
37
+
38
+ H_cycles: int
39
+ L_cycles: int
40
+
41
+ H_layers: int
42
+ L_layers: int
43
+
44
+ # Transformer config
45
+ hidden_size: int
46
+ expansion: float
47
+ num_heads: int
48
+ pos_encodings: str
49
+
50
+ rms_norm_eps: float = 1e-5
51
+ rope_theta: float = 10000.0
52
+
53
+ # Halting Q-learning config
54
+ halt_max_steps: int
55
+ halt_exploration_prob: float
56
+
57
+ forward_dtype: str = "bfloat16"
58
+
59
+
60
+ class HierarchicalReasoningModel_ACTV1Block(nn.Module):
61
+ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
62
+ super().__init__()
63
+
64
+ self.self_attn = Attention(
65
+ hidden_size=config.hidden_size,
66
+ head_dim=config.hidden_size // config.num_heads,
67
+ num_heads=config.num_heads,
68
+ num_key_value_heads=config.num_heads,
69
+ causal=False
70
+ )
71
+ self.mlp = SwiGLU(
72
+ hidden_size=config.hidden_size,
73
+ expansion=config.expansion,
74
+ )
75
+ self.norm_eps = config.rms_norm_eps
76
+
77
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
78
+ # Post Norm
79
+ # Self Attention
80
+ hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
81
+ # Fully Connected
82
+ hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
83
+ return hidden_states
84
+
85
+
86
+ class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
87
+ def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
88
+ super().__init__()
89
+
90
+ self.layers = torch.nn.ModuleList(layers)
91
+
92
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
93
+ # Input injection (add)
94
+ hidden_states = hidden_states + input_injection
95
+ # Layers
96
+ for layer in self.layers:
97
+ hidden_states = layer(hidden_states=hidden_states, **kwargs)
98
+
99
+ return hidden_states
100
+
101
+
102
+ class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
103
+ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
104
+ super().__init__()
105
+ self.config = config
106
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
107
+
108
+ # I/O
109
+ self.embed_scale = math.sqrt(self.config.hidden_size)
110
+ embed_init_std = 1.0 / self.embed_scale
111
+
112
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
113
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
114
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
115
+
116
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
117
+ if self.config.puzzle_emb_ndim > 0:
118
+ # Zero init puzzle embeddings
119
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
120
+ batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
121
+
122
+ # LM Blocks
123
+ if self.config.pos_encodings == "rope":
124
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
125
+ max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
126
+ base=self.config.rope_theta)
127
+ elif self.config.pos_encodings == "learned":
128
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
129
+ else:
130
+ raise NotImplementedError()
131
+
132
+ # Reasoning Layers
133
+ self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
134
+ self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
135
+
136
+ # Initial states
137
+ self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
138
+ self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
139
+
140
+ # Q head special init
141
+ # Init Q to (almost) zero for faster learning during bootstrapping
142
+ with torch.no_grad():
143
+ self.q_head.weight.zero_()
144
+ self.q_head.bias.fill_(-5) # type: ignore
145
+
146
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
147
+ # Token embedding
148
+ embedding = self.embed_tokens(input.to(torch.int32))
149
+
150
+ # Puzzle embeddings
151
+ if self.config.puzzle_emb_ndim > 0:
152
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
153
+
154
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
155
+ if pad_count > 0:
156
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
157
+
158
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
159
+
160
+ # Position embeddings
161
+ if self.config.pos_encodings == "learned":
162
+ # scale by 1/sqrt(2) to maintain forward variance
163
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
164
+
165
+ # Scale
166
+ return self.embed_scale * embedding
167
+
168
+ def empty_carry(self, batch_size: int):
169
+ return HierarchicalReasoningModel_ACTV1InnerCarry(
170
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
171
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
172
+ )
173
+
174
+ def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
175
+ return HierarchicalReasoningModel_ACTV1InnerCarry(
176
+ z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
177
+ z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
178
+ )
179
+
180
+ def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
181
+ seq_info = dict(
182
+ cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
183
+ )
184
+
185
+ # Input encoding
186
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
187
+
188
+ # Forward iterations
189
+ with torch.no_grad():
190
+ z_H, z_L = carry.z_H, carry.z_L
191
+
192
+ for _H_step in range(self.config.H_cycles):
193
+ for _L_step in range(self.config.L_cycles):
194
+ if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
195
+ z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
196
+
197
+ if not (_H_step == self.config.H_cycles - 1):
198
+ z_H = self.H_level(z_H, z_L, **seq_info)
199
+
200
+ assert not z_H.requires_grad and not z_L.requires_grad
201
+
202
+ # 1-step grad
203
+ z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
204
+ z_H = self.H_level(z_H, z_L, **seq_info)
205
+
206
+ # LM Outputs
207
+ new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
208
+ output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
209
+
210
+ # Q head
211
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
212
+
213
+ return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
214
+
215
+
216
+ class HierarchicalReasoningModel_ACTV1(nn.Module):
217
+ """ACT wrapper."""
218
+
219
+ def __init__(self, config_dict: dict):
220
+ super().__init__()
221
+ self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
222
+ self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)
223
+
224
+ @property
225
+ def puzzle_emb(self):
226
+ return self.inner.puzzle_emb
227
+
228
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
229
+ batch_size = batch["inputs"].shape[0]
230
+
231
+ return HierarchicalReasoningModel_ACTV1Carry(
232
+ inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it will be reset on the first pass, since all sequences start halted.
233
+
234
+ steps=torch.zeros((batch_size, ), dtype=torch.int32),
235
+ halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
236
+
237
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
238
+ )
239
+
240
+ def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
241
+ # Update data, carry (removing halted sequences)
242
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
243
+
244
+ new_steps = torch.where(carry.halted, 0, carry.steps)
245
+
246
+ new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
247
+
248
+ # Forward inner model
249
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
250
+
251
+ outputs = {
252
+ "logits": logits,
253
+ "q_halt_logits": q_halt_logits,
254
+ "q_continue_logits": q_continue_logits
255
+ }
256
+
257
+ with torch.no_grad():
258
+ # Step
259
+ new_steps = new_steps + 1
260
+ is_last_step = new_steps >= self.config.halt_max_steps
261
+
262
+ halted = is_last_step
263
+
264
+ # if training, and ACT is enabled
265
+ if self.training and (self.config.halt_max_steps > 1):
266
+ # Halt signal
267
+ # NOTE: During evaluation we always run for max steps, which guarantees the same number of steps for every sequence in a batch (needed for batching).
268
+ halted = halted | (q_halt_logits > q_continue_logits)
269
+
270
+ # Exploration
271
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
272
+
273
+ halted = halted & (new_steps >= min_halt_steps)
274
+
275
+ # Compute target Q
276
+ # NOTE: No replay buffer and target networks for computing target Q-value.
277
+ # Since batch_size is large, there are many parallel environments.
278
+ # Similar concept as PQN https://arxiv.org/abs/2407.04811
279
+ next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
280
+
281
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
282
+
283
+ return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
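
Note: an illustrative smoke test of the ACT wrapper above (not part of the commit). All config values below are arbitrary small numbers chosen so the model runs on CPU; they are not the repository's training configuration, and the import assumes the repository root is on PYTHONPATH:

    import torch
    from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1

    config = dict(
        batch_size=2, seq_len=16, puzzle_emb_ndim=0, num_puzzle_identifiers=1,
        vocab_size=11, H_cycles=2, L_cycles=2, H_layers=1, L_layers=1,
        hidden_size=64, expansion=2.0, num_heads=4, pos_encodings="rope",
        halt_max_steps=4, halt_exploration_prob=0.1, forward_dtype="float32",
    )
    model = HierarchicalReasoningModel_ACTV1(config)

    batch = {
        "inputs": torch.randint(0, 11, (2, 16), dtype=torch.int32),
        "puzzle_identifiers": torch.zeros(2, dtype=torch.int32),
    }
    carry = model.initial_carry(batch)      # all sequences start halted
    carry, outputs = model(carry=carry, batch=batch)
    print(outputs["logits"].shape)          # torch.Size([2, 16, 11])
    print(outputs["q_halt_logits"].shape)   # torch.Size([2])
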
models/layers.py ADDED
@@ -0,0 +1,150 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+
7
+ from models.common import trunc_normal_init_
8
+
9
+
10
+ CosSin = Tuple[torch.Tensor, torch.Tensor]
11
+
12
+
13
+ def _find_multiple(a, b):
14
+ return (-(a // -b)) * b
15
+
16
+
17
+ def rotate_half(x: torch.Tensor):
18
+ """Rotates half the hidden dims of the input."""
19
+ x1 = x[..., : x.shape[-1] // 2]
20
+ x2 = x[..., x.shape[-1] // 2 :]
21
+ return torch.cat((-x2, x1), dim=-1)
22
+
23
+
24
+ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
25
+ # q, k: [bs, num_heads, seq_len, head_dim]
26
+ # cos, sin: [seq_len, head_dim]
27
+ orig_dtype = q.dtype
28
+ q = q.to(cos.dtype)
29
+ k = k.to(cos.dtype)
30
+
31
+ q_embed = (q * cos) + (rotate_half(q) * sin)
32
+ k_embed = (k * cos) + (rotate_half(k) * sin)
33
+
34
+ return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
35
+
36
+
37
+ class CastedLinear(nn.Module):
38
+ def __init__(self,
39
+ in_features: int,
40
+ out_features: int,
41
+ bias: bool):
42
+ super().__init__()
43
+ # Truncated LeCun normal init
44
+ self.weight = nn.Parameter(
45
+ trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
46
+ )
47
+ self.bias = None
48
+ if bias:
49
+ # Zero init bias
50
+ self.bias = nn.Parameter(torch.zeros((out_features, )))
51
+
52
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
53
+ return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
54
+
55
+
56
+ class CastedEmbedding(nn.Module):
57
+ def __init__(self,
58
+ num_embeddings: int,
59
+ embedding_dim: int,
60
+ init_std: float,
61
+ cast_to: torch.dtype):
62
+ super().__init__()
63
+ self.cast_to = cast_to
64
+
65
+ # Truncated LeCun normal init
66
+ self.embedding_weight = nn.Parameter(
67
+ trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
68
+ )
69
+
70
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
71
+ return F.embedding(input, self.embedding_weight.to(self.cast_to))
72
+
73
+
74
+ class RotaryEmbedding(nn.Module):
75
+ def __init__(self, dim, max_position_embeddings, base, device=None):
76
+ super().__init__()
77
+
78
+ # RoPE
79
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
80
+ t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
81
+ freqs = torch.outer(t, inv_freq)
82
+
83
+ # Layout differs from the paper, but this permutation yields the same result.
84
+ emb = torch.cat((freqs, freqs), dim=-1)
85
+ self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
86
+ self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
87
+
88
+ def forward(self):
89
+ return self.cos_cached, self.sin_cached
90
+
91
+
92
+ class Attention(nn.Module):
93
+ def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
94
+ super().__init__()
95
+
96
+ self.hidden_size = hidden_size
97
+ self.head_dim = head_dim
98
+ self.output_size = head_dim * num_heads
99
+ self.num_heads = num_heads
100
+ self.num_key_value_heads = num_key_value_heads
101
+ self.causal = causal
102
+
103
+ self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
104
+ self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
105
+
106
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
107
+ batch_size, seq_len, _ = hidden_states.shape
108
+
109
+ # hidden_states: [bs, seq_len, num_heads, head_dim]
110
+ qkv = self.qkv_proj(hidden_states)
111
+
112
+ # Split head
113
+ qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim).transpose(-2, -3)
114
+ query = qkv[:, :self.num_heads]
115
+ key = qkv[:, self.num_heads: self.num_heads + self.num_key_value_heads]
116
+ value = qkv[:, self.num_heads + self.num_key_value_heads:]
117
+
118
+ # RoPE
119
+ if cos_sin is not None:
120
+ cos, sin = cos_sin
121
+ query, key = apply_rotary_pos_emb(query, key, cos, sin)
122
+
123
+ # flash attn
124
+ attn_output = F.scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
125
+
126
+ # attn_output: [batch_size, num_heads, seq_len, head_dim]
127
+ attn_output = attn_output.transpose(-2, -3).view(batch_size, seq_len, self.output_size) # type: ignore
128
+ return self.o_proj(attn_output)
129
+
130
+
131
+ class SwiGLU(nn.Module):
132
+ def __init__(self, hidden_size: int, expansion: float):
133
+ super().__init__()
134
+ inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
135
+
136
+ self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
137
+ self.down_proj = CastedLinear(inter, hidden_size, bias=False)
138
+
139
+ def forward(self, x):
140
+ gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
141
+ return self.down_proj(F.silu(gate) * up)
142
+
143
+
144
+ def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
145
+ input_dtype = hidden_states.dtype
146
+ hidden_states = hidden_states.to(torch.float32)
147
+
148
+ variance = hidden_states.square().mean(-1, keepdim=True)
149
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
150
+ return hidden_states.to(input_dtype)
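
Note: an illustrative shape check for the building blocks above (not part of the commit; the sizes are arbitrary). It wires them together the same way as one post-norm HRM block:

    import torch
    from models.layers import Attention, RotaryEmbedding, SwiGLU, rms_norm

    hidden_size, num_heads, seq_len = 64, 4, 16
    rope = RotaryEmbedding(dim=hidden_size // num_heads, max_position_embeddings=seq_len, base=10000.0)
    attn = Attention(hidden_size=hidden_size, head_dim=hidden_size // num_heads,
                     num_heads=num_heads, num_key_value_heads=num_heads, causal=False)
    mlp = SwiGLU(hidden_size=hidden_size, expansion=2.0)

    x = torch.randn(2, seq_len, hidden_size)
    x = rms_norm(x + attn(cos_sin=rope(), hidden_states=x), variance_epsilon=1e-5)
    x = rms_norm(x + mlp(x), variance_epsilon=1e-5)
    print(x.shape)  # torch.Size([2, 16, 64])
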
models/losses.py ADDED
@@ -0,0 +1,101 @@
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x < 0,
14
+ 1 / (1 - x + epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ valid_mask = labels != ignore_index
28
+ transformed_labels = torch.where(valid_mask, labels, 0)
29
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
30
+
31
+ return -torch.where(valid_mask, prediction_logprobs, 0)
32
+
33
+
34
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
35
+ # Cast logits to f32
36
+ # Flatten logits
37
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
38
+
39
+
40
+ class ACTLossHead(nn.Module):
41
+ def __init__(self, model: nn.Module, loss_type: str):
42
+ super().__init__()
43
+ self.model = model
44
+ self.loss_fn = globals()[loss_type]
45
+
46
+ def initial_carry(self, *args, **kwargs):
47
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
48
+
49
+ def forward(
50
+ self,
51
+ return_keys: Sequence[str],
52
+ # Model args
53
+ **model_kwargs,
54
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
55
+ # Model logits
56
+ # B x SeqLen x D
57
+ new_carry, outputs = self.model(**model_kwargs)
58
+ labels = new_carry.current_data["labels"]
59
+
60
+ # Correctness
61
+ with torch.no_grad():
62
+ mask = labels != IGNORE_LABEL_ID
63
+ loss_counts = mask.sum(-1)
64
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
65
+
66
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
67
+ seq_is_correct = is_correct.sum(-1) == loss_counts
68
+
69
+ # Metrics (halted)
70
+ valid_metrics = new_carry.halted & (loss_counts > 0)
71
+ metrics = {
72
+ "count": valid_metrics.sum(),
73
+
74
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
75
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
76
+
77
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
78
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
79
+ }
80
+
81
+ # Losses
82
+ # FIXME: Assuming the batch is always full
83
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
84
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
85
+
86
+ metrics.update({
87
+ "lm_loss": lm_loss.detach(),
88
+ "q_halt_loss": q_halt_loss.detach(),
89
+ })
90
+
91
+ # Q continue (bootstrapping target loss)
92
+ q_continue_loss = 0
93
+ if "target_q_continue" in outputs:
94
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
95
+
96
+ metrics["q_continue_loss"] = q_continue_loss.detach()
97
+
98
+ # Filter outputs for return
99
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
100
+
101
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
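
Note: a small illustration of the two per-token loss functions above (not part of the commit; the logits and labels are made up). Positions labeled with IGNORE_LABEL_ID contribute zero loss:

    import torch
    from models.losses import IGNORE_LABEL_ID, softmax_cross_entropy, stablemax_cross_entropy

    logits = torch.randn(2, 3, 5)                                # [batch, seq, vocab]
    labels = torch.tensor([[1, 4, IGNORE_LABEL_ID], [0, 2, 3]])

    print(softmax_cross_entropy(logits, labels).shape)    # torch.Size([2, 3]), unreduced
    print(stablemax_cross_entropy(logits, labels).shape)  # torch.Size([2, 3]), float64
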
models/sparse_embedding.py ADDED
@@ -0,0 +1,132 @@
1
+ from typing import Union
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.distributed as dist
6
+ from torch.optim.optimizer import Optimizer, ParamsT
7
+
8
+ from models.common import trunc_normal_init_
9
+
10
+
11
+ class CastedSparseEmbedding(nn.Module):
12
+ def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
13
+ super().__init__()
14
+ self.cast_to = cast_to
15
+
16
+ # Real Weights
17
+ # Truncated LeCun normal init
18
+ self.weights = nn.Buffer(
19
+ trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
20
+ )
21
+
22
+ # Local weights and IDs
23
+ # Local embeddings, with gradient, not persistent
24
+ self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
25
+ # Local embedding IDs, not persistent
26
+ self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
27
+
28
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
29
+ if not self.training:
30
+ # Test mode, no gradient
31
+ return self.weights[inputs].to(self.cast_to)
32
+
33
+ # Training mode, fill puzzle embedding from weights
34
+ with torch.no_grad():
35
+ self.local_weights.copy_(self.weights[inputs])
36
+ self.local_ids.copy_(inputs)
37
+
38
+ return self.local_weights.to(self.cast_to)
39
+
40
+
41
+ class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
42
+ def __init__(
43
+ self,
44
+ params: ParamsT,
45
+
46
+ world_size: int,
47
+ lr: Union[float, torch.Tensor] = 1e-3,
48
+ weight_decay: float = 1e-2,
49
+ ):
50
+ if not 0.0 <= lr:
51
+ raise ValueError(f"Invalid learning rate: {lr}")
52
+ if not 0.0 <= weight_decay:
53
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
54
+
55
+ defaults = dict(
56
+ lr=lr,
57
+ weight_decay=weight_decay,
58
+ world_size=world_size
59
+ )
60
+ super().__init__(params, defaults)
61
+
62
+ @torch.no_grad
63
+ def step(self, closure=None): # type: ignore
64
+ for group in self.param_groups:
65
+ # Find the sparse embedding weights
66
+ local_weights_grad = None
67
+ local_ids = None
68
+ weights = None
69
+
70
+ assert len(group["params"]) == 3
71
+ for p in group["params"]:
72
+ if p.requires_grad:
73
+ local_weights_grad = p.grad
74
+ elif p.ndim == 1:
75
+ local_ids = p
76
+ elif p.ndim == 2:
77
+ weights = p
78
+ else:
79
+ assert False
80
+
81
+ assert local_weights_grad is not None
82
+ assert local_ids is not None
83
+ assert weights is not None
84
+
85
+ # Apply SignSGD
86
+ # Adam ≈ SignSGD if gradient is very sparse
87
+ _sparse_emb_signsgd_dist(
88
+ local_weights_grad,
89
+ local_ids,
90
+ weights,
91
+
92
+ lr=group["lr"],
93
+ weight_decay=group["weight_decay"],
94
+ world_size=group["world_size"]
95
+ )
96
+
97
+
98
+ def _sparse_emb_signsgd_dist(
99
+ local_weights_grad: torch.Tensor,
100
+ local_ids: torch.Tensor,
101
+ weights: torch.Tensor,
102
+
103
+ lr: float,
104
+ weight_decay: float,
105
+ world_size: int
106
+ ) -> None:
107
+ N, D = local_weights_grad.shape
108
+
109
+ # All-gather
110
+ all_weights_grad = local_weights_grad
111
+ all_ids = local_ids
112
+
113
+ if world_size > 1:
114
+ all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
115
+ all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
116
+
117
+ dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
118
+ dist.all_gather_into_tensor(all_ids, local_ids)
119
+
120
+ # Unique
121
+ grad_ids, inv = all_ids.unique(return_inverse=True)
122
+
123
+ grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
124
+ grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)
125
+
126
+ # SignSGD with decoupled weight decay
127
+ p = weights[grad_ids]
128
+
129
+ p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
130
+
131
+ # Write updated slices back
132
+ weights[grad_ids] = p
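
Note: a single-process sketch of how the sparse embedding and its SignSGD optimizer fit together (not part of the commit; the sizes, lr and toy loss are arbitrary). pretrain.py wires them up the same way via puzzle_emb.buffers():

    import torch
    from models.sparse_embedding import (
        CastedSparseEmbedding, CastedSparseEmbeddingSignSGD_Distributed,
    )

    emb = CastedSparseEmbedding(num_embeddings=8, embedding_dim=4,
                                batch_size=2, init_std=0.0, cast_to=torch.float32)
    opt = CastedSparseEmbeddingSignSGD_Distributed(emb.buffers(), world_size=1,
                                                   lr=1e-2, weight_decay=0.0)

    ids = torch.tensor([3, 5], dtype=torch.int32)
    out = emb(ids)            # training mode: returns local_weights, which carry grad
    out.sum().backward()      # toy loss; gradient flows only into the selected rows
    opt.step()
    opt.zero_grad()
    print(emb.weights[ids])   # only rows 3 and 5 were updated (by -lr * sign(grad))
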
pretrain.py ADDED
@@ -0,0 +1,454 @@
1
+ from typing import Optional, Any, Sequence, List
2
+ from dataclasses import dataclass
3
+ import os
4
+ import math
5
+ import yaml
6
+ import shutil
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from torch import nn
11
+ from torch.utils.data import DataLoader
12
+
13
+ import tqdm
14
+ import wandb
15
+ import coolname
16
+ import hydra
17
+ import pydantic
18
+ from omegaconf import DictConfig
19
+ from wandb.util import make_artifact_name_safe
20
+ from adam_atan2 import AdamATan2
21
+
22
+ from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
23
+ from utils.functions import load_model_class, get_model_source_path
24
+ from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
25
+
26
+
27
+ class LossConfig(pydantic.BaseModel):
28
+ model_config = pydantic.ConfigDict(extra='allow')
29
+
30
+ name: str
31
+
32
+
33
+ class ArchConfig(pydantic.BaseModel):
34
+ model_config = pydantic.ConfigDict(extra='allow')
35
+
36
+ name: str
37
+ loss: LossConfig
38
+
39
+
40
+ class PretrainConfig(pydantic.BaseModel):
41
+ # Config
42
+ arch: ArchConfig
43
+ # Data
44
+ data_path: str
45
+
46
+ # Hyperparams
47
+ global_batch_size: int
48
+ epochs: int
49
+
50
+ lr: float
51
+ lr_min_ratio: float
52
+ lr_warmup_steps: int
53
+
54
+ weight_decay: float
55
+ beta1: float
56
+ beta2: float
57
+
58
+ # Puzzle embedding
59
+ puzzle_emb_lr: float
60
+ puzzle_emb_weight_decay: float
61
+
62
+ # Names
63
+ project_name: Optional[str] = None
64
+ run_name: Optional[str] = None
65
+ checkpoint_path: Optional[str] = None
66
+
67
+ # Extras
68
+ seed: int = 0
69
+ checkpoint_every_eval: bool = False
70
+ eval_interval: Optional[int] = None
71
+ eval_save_outputs: List[str] = []
72
+
73
+
74
+ @dataclass
75
+ class TrainState:
76
+ model: nn.Module
77
+ optimizers: Sequence[torch.optim.Optimizer]
78
+ optimizer_lrs: Sequence[float]
79
+ carry: Any
80
+
81
+ step: int
82
+ total_steps: int
83
+
84
+
85
+ def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
86
+ dataset = PuzzleDataset(PuzzleDatasetConfig(
87
+ seed=config.seed,
88
+
89
+ dataset_path=config.data_path,
90
+
91
+ rank=rank,
92
+ num_replicas=world_size,
93
+
94
+ **kwargs
95
+ ), split=split)
96
+ dataloader = DataLoader(
97
+ dataset,
98
+ batch_size=None,
99
+
100
+ num_workers=1,
101
+ prefetch_factor=8,
102
+
103
+ pin_memory=True,
104
+ persistent_workers=True
105
+ )
106
+ return dataloader, dataset.metadata
107
+
108
+
109
+ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
110
+ model_cfg = dict(
111
+ **config.arch.__pydantic_extra__, # type: ignore
112
+
113
+ batch_size=config.global_batch_size // world_size,
114
+
115
+ vocab_size=train_metadata.vocab_size,
116
+ seq_len=train_metadata.seq_len,
117
+ num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
118
+ causal=False # Non-autoregressive
119
+ )
120
+
121
+ # Instantiate model with loss head
122
+ model_cls = load_model_class(config.arch.name)
123
+ loss_head_cls = load_model_class(config.arch.loss.name)
124
+
125
+ with torch.device("cuda"):
126
+ model: nn.Module = model_cls(model_cfg)
127
+ model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore
128
+ if "DISABLE_COMPILE" not in os.environ:
129
+ model = torch.compile(model, dynamic=False, fullgraph=True) # type: ignore
130
+
131
+ # Broadcast parameters from rank 0
132
+ if world_size > 1:
133
+ with torch.no_grad():
134
+ for param in list(model.parameters()) + list(model.buffers()):
135
+ dist.broadcast(param, src=0)
136
+
137
+ # Optimizers and lr
138
+ optimizers = [
139
+ CastedSparseEmbeddingSignSGD_Distributed(
140
+ model.model.puzzle_emb.buffers(), # type: ignore
141
+
142
+ lr=0, # Needs to be set by scheduler
143
+ weight_decay=config.puzzle_emb_weight_decay,
144
+
145
+ world_size=world_size
146
+ ),
147
+ AdamATan2(
148
+ model.parameters(),
149
+
150
+ lr=0, # Needs to be set by scheduler
151
+ weight_decay=config.weight_decay,
152
+ betas=(config.beta1, config.beta2)
153
+ )
154
+ ]
155
+ optimizer_lrs = [
156
+ config.puzzle_emb_lr,
157
+ config.lr
158
+ ]
159
+
160
+ return model, optimizers, optimizer_lrs
161
+
162
+
163
+ def cosine_schedule_with_warmup_lr_lambda(
164
+ current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
165
+ ):
166
+ if current_step < num_warmup_steps:
167
+ return base_lr * float(current_step) / float(max(1, num_warmup_steps))
168
+
169
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
170
+ return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))
171
+
172
+
173
+ def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
174
+ # Estimated total training steps
175
+ total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)
176
+
177
+ # Model
178
+ model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size)
179
+
180
+ return TrainState(
181
+ step=0,
182
+ total_steps=total_steps,
183
+
184
+ model=model,
185
+ optimizers=optimizers,
186
+ optimizer_lrs=optimizer_lrs,
187
+ carry=None
188
+ )
189
+
190
+
191
+ def save_train_state(config: PretrainConfig, train_state: TrainState):
192
+ # FIXME: Only the model is saved.
193
+ if config.checkpoint_path is None:
194
+ return
195
+
196
+ os.makedirs(config.checkpoint_path, exist_ok=True)
197
+ torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))
198
+
199
+
200
+ def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
201
+ return cosine_schedule_with_warmup_lr_lambda(
202
+ current_step=train_state.step,
203
+ base_lr=base_lr,
204
+ num_warmup_steps=round(config.lr_warmup_steps),
205
+ num_training_steps=train_state.total_steps,
206
+ min_ratio=config.lr_min_ratio
207
+ )
208
+
209
+
210
+ def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
211
+ train_state.step += 1
212
+ if train_state.step > train_state.total_steps: # At most train_total_steps
213
+ return
214
+
215
+ # To device
216
+ batch = {k: v.cuda() for k, v in batch.items()}
217
+
218
+ # Init carry if it is None
219
+ if train_state.carry is None:
220
+ with torch.device("cuda"):
221
+ train_state.carry = train_state.model.initial_carry(batch) # type: ignore
222
+
223
+ # Forward
224
+ train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])
225
+
226
+ ((1 / global_batch_size) * loss).backward()
227
+
228
+ # Allreduce
229
+ if world_size > 1:
230
+ for param in train_state.model.parameters():
231
+ if param.grad is not None:
232
+ dist.all_reduce(param.grad)
233
+
234
+ # Apply optimizer
235
+ lr_this_step = None
236
+ for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
237
+ lr_this_step = compute_lr(base_lr, config, train_state)
238
+
239
+ for param_group in optim.param_groups:
240
+ param_group['lr'] = lr_this_step
241
+
242
+ optim.step()
243
+ optim.zero_grad()
244
+
245
+ # Reduce metrics
246
+ if len(metrics):
247
+ assert not any(v.requires_grad for v in metrics.values())
248
+
249
+ metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
250
+ # Reduce and reconstruct
251
+ metric_values = torch.stack([metrics[k] for k in metric_keys])
252
+ if world_size > 1:
253
+ dist.reduce(metric_values, dst=0)
254
+
255
+ if rank == 0:
256
+ metric_values = metric_values.cpu().numpy()
257
+ reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}
258
+
259
+ # Postprocess
260
+ count = max(reduced_metrics["count"], 1) # Avoid NaNs
261
+ reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}
262
+
263
+ reduced_metrics["train/lr"] = lr_this_step
264
+ return reduced_metrics
265
+
266
+
267
+ def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
268
+ with torch.inference_mode():
269
+ set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}
270
+
271
+ all_preds = {}
272
+
273
+ metric_keys = []
274
+ metric_values = None
275
+ metric_global_batch_size = [0 for _ in range(len(set_ids))]
276
+
277
+ carry = None
278
+ for set_name, batch, global_batch_size in eval_loader:
279
+ # To device
280
+ batch = {k: v.cuda() for k, v in batch.items()}
281
+ with torch.device("cuda"):
282
+ carry = train_state.model.initial_carry(batch) # type: ignore
283
+
284
+ # Forward
285
+ while True:
286
+ carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs)
287
+
288
+ if all_finish:
289
+ break
290
+
291
+ for collection in (batch, preds):
292
+ for k, v in collection.items():
293
+ if k in config.eval_save_outputs:
294
+ all_preds.setdefault(k, [])
295
+ all_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory
296
+
297
+ del carry, preds, batch, all_finish
298
+
299
+ # Aggregate
300
+ set_id = set_ids[set_name]
301
+
302
+ if metric_values is None:
303
+ metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
304
+ metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda")
305
+
306
+ metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
307
+ metric_global_batch_size[set_id] += global_batch_size
308
+
309
+ if len(all_preds) and config.checkpoint_path is not None:
310
+ all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}
311
+
312
+ os.makedirs(config.checkpoint_path, exist_ok=True)
313
+ torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}"))
314
+
315
+ # Logging
316
+ # Reduce to rank 0
317
+ if metric_values is not None:
318
+ if world_size > 1:
319
+ dist.reduce(metric_values, dst=0)
320
+
321
+ if rank == 0:
322
+ reduced_metrics = metric_values.cpu().numpy()
323
+ reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)}
324
+ for set_id, set_name in enumerate(set_ids)}
325
+
326
+ # Postprocess
327
+ for set_name, metrics in reduced_metrics.items():
328
+ count = metrics.pop("count")
329
+ reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()}
330
+
331
+ return reduced_metrics
332
+
333
+
334
+ def save_code_and_config(config: PretrainConfig):
335
+ if config.checkpoint_path is None or wandb.run is None:
336
+ return
337
+
338
+ os.makedirs(config.checkpoint_path, exist_ok=True)
339
+
340
+ # Copy code
341
+ code_list = [
342
+ get_model_source_path(config.arch.name),
343
+ get_model_source_path(config.arch.loss.name)
344
+ ]
345
+ for code_file in code_list:
346
+ if code_file is not None:
347
+ code_name = os.path.basename(code_file)
348
+
349
+ shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))
350
+
351
+ # Dump config as yaml
352
+ config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
353
+ with open(config_file, "wt") as f:
354
+ yaml.dump(config.model_dump(), f)
355
+
356
+ # Log code
357
+ wandb.run.log_code(config.checkpoint_path)
358
+
359
+
360
+ def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
361
+ objects = [None]
362
+ if rank == 0:
363
+ config = PretrainConfig(**hydra_config) # type: ignore
364
+
365
+ # Naming
366
+ if config.project_name is None:
367
+ config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch"
368
+ if config.run_name is None:
369
+ config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
370
+ if config.checkpoint_path is None:
371
+ config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)
372
+
373
+ objects = [config]
374
+
375
+ if world_size > 1:
376
+ dist.broadcast_object_list(objects, src=0)
377
+
378
+ return objects[0] # type: ignore
379
+
380
+
381
+ @hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
382
+ def launch(hydra_config: DictConfig):
383
+ RANK = 0
384
+ WORLD_SIZE = 1
385
+
386
+ # Initialize distributed training if in distributed environment (e.g. torchrun)
387
+ if "LOCAL_RANK" in os.environ:
388
+ # Initialize distributed, default device and dtype
389
+ dist.init_process_group(backend="nccl")
390
+
391
+ RANK = dist.get_rank()
392
+ WORLD_SIZE = dist.get_world_size()
393
+
394
+ torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
395
+
396
+ # Load sync'ed config
397
+ config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)
398
+
399
+ # Seed RNGs to ensure consistency
400
+ torch.random.manual_seed(config.seed + RANK)
401
+
402
+ # Dataset
403
+ train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
404
+ total_iters = config.epochs // train_epochs_per_iter
405
+
406
+ assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."
407
+
408
+ train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
409
+ eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
410
+
411
+ # Train state
412
+ train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
413
+
414
+ # Progress bar and logger
415
+ progress_bar = None
416
+ if RANK == 0:
417
+ progress_bar = tqdm.tqdm(total=train_state.total_steps)
418
+
419
+ wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore
420
+ wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
421
+ save_code_and_config(config)
422
+
423
+ # Training Loop
424
+ for _iter_id in range(total_iters):
425
+ print(f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")
426
+
427
+ ############ Train Iter
428
+ train_state.model.train()
429
+ for set_name, batch, global_batch_size in train_loader:
430
+ metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)
431
+
432
+ if RANK == 0 and metrics is not None:
433
+ wandb.log(metrics, step=train_state.step)
434
+ progress_bar.update(train_state.step - progress_bar.n) # type: ignore
435
+
436
+ ############ Evaluation
437
+ train_state.model.eval()
438
+ metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)
439
+
440
+ if RANK == 0 and metrics is not None:
441
+ wandb.log(metrics, step=train_state.step)
442
+
443
+ ############ Checkpointing
444
+ if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
445
+ save_train_state(config, train_state)
446
+
447
+ # finalize
448
+ if dist.is_initialized():
449
+ dist.destroy_process_group()
450
+ wandb.finish()
451
+
452
+
453
+ if __name__ == "__main__":
454
+ launch()
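
Note: the warmup-then-cosine learning-rate schedule above can be probed in isolation; an illustrative sketch (not part of the commit; the step counts and 1e-4 base lr are made up, and importing pretrain assumes its dependencies such as hydra, wandb and adam_atan2 are installed):

    from pretrain import cosine_schedule_with_warmup_lr_lambda

    for step in (0, 100, 200, 5000, 10000):
        lr = cosine_schedule_with_warmup_lr_lambda(
            step, base_lr=1e-4, num_warmup_steps=200,
            num_training_steps=10000, min_ratio=0.1,
        )
        print(step, f"{lr:.2e}")  # ramps to 1e-4 at step 200, decays to 1e-5 by step 10000
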
puzzle_dataset.py ADDED
@@ -0,0 +1,199 @@
1
+ import os
2
+ import json
3
+
4
+ import numpy as np
5
+ import pydantic
6
+
7
+ import torch
8
+ from torch.utils.data import IterableDataset, get_worker_info
9
+
10
+ from models.losses import IGNORE_LABEL_ID
11
+ from dataset.common import PuzzleDatasetMetadata
12
+
13
+
14
+ def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
15
+ # Pack examples into a full batch
16
+ batch = []
17
+ batch_puzzle_indices = []
18
+ current_size = 0
19
+
20
+ while (start_index < group_order.size) and (current_size < global_batch_size):
21
+ # Pick a group and a puzzle from that group
22
+ group_id = group_order[start_index]
23
+ puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
24
+ start_index += 1
25
+
26
+ # Get range of the puzzle
27
+ puzzle_start = puzzle_indices[puzzle_id]
28
+ puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)
29
+
30
+ append_size = min(puzzle_size, global_batch_size - current_size)
31
+
32
+ # Put into batch
33
+ batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
34
+ batch.append(puzzle_start + rng.choice(puzzle_size, append_size, replace=False)) # use the seeded rng for reproducibility
35
+
36
+ current_size += append_size
37
+
38
+ return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)
39
+
40
+
41
+ class PuzzleDatasetConfig(pydantic.BaseModel):
42
+ seed: int
43
+ dataset_path: str
44
+ global_batch_size: int
45
+ test_set_mode: bool
46
+
47
+ epochs_per_iter: int # Batch multiple epochs into one iteration to reduce overhead.
48
+
49
+ rank: int
50
+ num_replicas: int
51
+
52
+
53
+ class PuzzleDataset(IterableDataset):
54
+ def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
55
+ super().__init__()
56
+ self.config = config
57
+ self.split = split
58
+ self.metadata = self._load_metadata()
59
+
60
+ # Checks
61
+ assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}."
62
+ self.local_batch_size = self.config.global_batch_size // self.config.num_replicas
63
+
64
+ # State
65
+ self._data = None
66
+ self._iters = 0
67
+
68
+ def _load_metadata(self) -> PuzzleDatasetMetadata:
69
+ with open(os.path.join(self.config.dataset_path, self.split, "dataset.json"), "r") as f:
70
+ return PuzzleDatasetMetadata(**json.load(f))
71
+
72
+ def _lazy_load_dataset(self):
73
+ if self._data is not None:
74
+ return
75
+
76
+ field_mmap_modes = {
77
+ "inputs": "r",
78
+ "labels": "r",
79
+
80
+ # Keep indices in memory
81
+ "puzzle_identifiers": None,
82
+ "puzzle_indices": None,
83
+ "group_indices": None
84
+ }
85
+
86
+ # Load data
87
+ self._data = {}
88
+ for set_name in self.metadata.sets:
89
+ # Load subset
90
+ self._data[set_name] = {
91
+ field_name: np.load(os.path.join(self.config.dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
92
+ for field_name, mmap_mode in field_mmap_modes.items()
93
+ }
94
+
95
+ def _collate_batch(self, batch):
96
+ # Convert dtype
97
+ batch = {k: v.astype(np.int32) for k, v in batch.items()}
98
+
99
+ # Convert ignore label IDs
100
+ if self.metadata.ignore_label_id is not None:
101
+ batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID
102
+
103
+ # Pad
104
+ if batch["puzzle_identifiers"].size < self.local_batch_size:
105
+ pad_size = self.local_batch_size - batch["puzzle_identifiers"].size
106
+
107
+ pad_values = {
108
+ "inputs": self.metadata.pad_id,
109
+ "labels": IGNORE_LABEL_ID,
110
+
111
+ "puzzle_identifiers": self.metadata.blank_identifier_id
112
+ }
113
+ batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}
114
+
115
+ # To tensor
116
+ return {k: torch.from_numpy(v) for k, v in batch.items()}
117
+
118
+ def _iter_test(self):
119
+ for set_name, dataset in self._data.items(): # type: ignore
120
+ total_examples = len(dataset["inputs"])
121
+
122
+ # Load examples one by one
123
+ start_index = 0
124
+ while start_index < total_examples:
125
+ # Compute indices
126
+ end_index = min(total_examples, start_index + self.config.global_batch_size)
127
+
128
+ local_start = start_index + self.config.rank * self.local_batch_size
129
+ local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)
130
+
131
+ # Get batch of examples, and also puzzle IDs
132
+ puzzle_indices = []
133
+ puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
134
+ for i in range(local_start, local_end):
135
+ while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
136
+ puzzle_index += 1
137
+
138
+ puzzle_indices.append(puzzle_index)
139
+
140
+ batch = self._collate_batch({
141
+ "inputs": dataset["inputs"][local_start: local_end],
142
+ "labels": dataset["labels"][local_start: local_end],
143
+ "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
144
+ })
145
+
146
+ yield set_name, batch, end_index - start_index
147
+
148
+ # Advance to next batch
149
+ start_index += self.config.global_batch_size
150
+
151
+ def _iter_train(self):
152
+ for set_name, dataset in self._data.items(): # type: ignore
153
+ # Increase epoch count
154
+ self._iters += 1
155
+
156
+ # Randomly shuffle groups
157
+ rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))
158
+
159
+ group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])
160
+ start_index = 0
161
+
162
+ while start_index < group_order.size:
163
+ start_index, batch_indices, batch_puzzle_indices = _sample_batch(
164
+ rng,
165
+ group_order=group_order,
166
+ puzzle_indices=dataset["puzzle_indices"],
167
+ group_indices=dataset["group_indices"],
168
+ start_index=start_index,
169
+ global_batch_size=self.config.global_batch_size,
170
+ )
171
+
172
+ # Select current rank and collate
173
+ global_effective_batch_size = batch_puzzle_indices.size # Global effective batch size, excluding pads
174
+
175
+ # Drop last batch
176
+ if global_effective_batch_size < self.config.global_batch_size:
177
+ break
178
+
179
+ batch_indices = batch_indices [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
180
+ batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
181
+ batch = self._collate_batch({
182
+ "inputs": dataset["inputs"][batch_indices],
183
+ "labels": dataset["labels"][batch_indices],
184
+ "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
185
+ })
186
+
187
+ yield set_name, batch, global_effective_batch_size
188
+
189
+ def __iter__(self):
190
+ worker_info = get_worker_info()
191
+ assert worker_info is None or worker_info.num_workers == 1, "Multi-worker data loading is not currently supported."
192
+
193
+ self._lazy_load_dataset()
194
+
195
+ # Iterate using specified mode
196
+ if self.config.test_set_mode:
197
+ yield from self._iter_test()
198
+ else:
199
+ yield from self._iter_train()
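
Note: an illustrative way to drive the dataset above outside of pretrain.py (not part of the commit). The path "data/my-dataset" is a placeholder; it must contain <split>/dataset.json plus the <set>__<field>.npy files this loader expects:

    from torch.utils.data import DataLoader
    from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig

    config = PuzzleDatasetConfig(
        seed=0, dataset_path="data/my-dataset", global_batch_size=256,
        test_set_mode=False, epochs_per_iter=1, rank=0, num_replicas=1,
    )
    dataset = PuzzleDataset(config, split="train")
    loader = DataLoader(dataset, batch_size=None, num_workers=1)

    for set_name, batch, effective_batch_size in loader:
        print(set_name, batch["inputs"].shape, effective_batch_size)
        break
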
puzzle_visualizer.html ADDED
@@ -0,0 +1,426 @@
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <title>ARC‐Converted Dataset Visualizer (Upload Local Folder)</title>
6
+ <style>
7
+ body {
8
+ font-family: sans-serif;
9
+ margin: 16px;
10
+ }
11
+ .selector-area {
12
+ margin-bottom: 1rem;
13
+ }
14
+ .grid-canvas {
15
+ margin: 4px;
16
+ border: 1px solid #ccc;
17
+ }
18
+ .example-container {
19
+ display: inline-block;
20
+ margin: 0 16px 16px 0;
21
+ vertical-align: top;
22
+ }
23
+ .puzzle-display {
24
+ margin-top: 1rem;
25
+ }
26
+ .puzzle-id {
27
+ font-weight: bold;
28
+ margin-bottom: 0.5rem;
29
+ }
30
+ #groupList, #puzzleList {
31
+ margin: 1rem 0;
32
+ }
33
+ .group-item, .puzzle-item {
34
+ cursor: pointer;
35
+ margin: 4px 8px 4px 0;
36
+ padding: 2px 6px;
37
+ border: 1px solid #aaa;
38
+ display: inline-block;
39
+ }
40
+ .group-item:hover, .puzzle-item:hover {
41
+ background: #eef;
42
+ }
43
+ </style>
44
+ </head>
45
+ <body>
46
+ <h1>ARC‐Converted Dataset Visualizer (Local Directory)</h1>
47
+
48
+ <div class="selector-area">
49
+ <!-- 1) Directory input with webkitdirectory, mozdirectory -->
50
+ <label>Upload ARC Folder:</label>
51
+ <input type="file" id="folderInput"
52
+ webkitdirectory mozdirectory multiple
53
+ onchange="onFolderSelected(event)" />
54
+ <br><br>
55
+
56
+ <!-- 2) We'll enable set/subset selection after user chooses a folder and data is validated -->
57
+ <label>Set:</label>
58
+ <select id="setSelect" disabled>
59
+ <option value="train">train</option>
60
+ <option value="test">test</option>
61
+ </select>
62
+
63
+ <label> Subset:</label>
64
+ <select id="subsetSelect" disabled>
65
+ <option value="all">all</option>
66
+ </select>
67
+
68
+ <button id="loadBtn" disabled>Load</button>
69
+ </div>
70
+
71
+ <div>
72
+ <div id="groupList"></div>
73
+ <div id="puzzleList"></div>
74
+ <div class="puzzle-display" id="puzzleView"></div>
75
+ </div>
76
+
77
+ <!--
78
+ 3) Use local 'assets/npyjs.js' from your project folder instead of a CDN.
79
+ Make sure 'assets/npyjs.js' is the unbundled or UMD version that doesn't
80
+ contain "import" statements.
81
+ -->
82
+ <script src="assets/npyjs.js"></script>
83
+
84
+ <script>
85
+ /***************************************************************************
86
+ * Global Maps & Variables
87
+ ***************************************************************************/
88
+
89
+ // Map from "train/all__inputs.npy" => File, etc.
90
+ let filesByPath = {};
91
+
92
+ // Once loaded, we store typed arrays for the chosen set/subset
93
+ let inputsArr, labelsArr;
94
+ let puzzleIndicesArr, groupIndicesArr, puzzleIdentifiersArr;
95
+ let identifiersJson;
96
+
97
+ // The shape of inputs is [N_examples, seqLen], so we discover seqLen & gridSize
98
+ let seqLen = 0;
99
+ let gridSize = 0;
100
+
101
+
102
+ /***************************************************************************
103
+ * 1) Handle folder selection: read all files, find identifiers.json,
104
+ * remove topmost folder from each file path, validate.
105
+ ***************************************************************************/
106
+ function onFolderSelected(event) {
107
+ filesByPath = {};
108
+ const fileList = event.target.files;
109
+ if (!fileList || fileList.length === 0) {
110
+ alert("No files selected!");
111
+ return;
112
+ }
113
+
114
+ // We'll gather all webkitRelativePaths
115
+ const paths = [];
116
+ for (let i = 0; i < fileList.length; i++) {
117
+ // Typically "arc-aug-10/train/all__inputs.npy", etc.
118
+ const file = fileList[i];
119
+ const relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
120
+ paths.push(relPath);
121
+ }
122
+
123
+ // 1. Check if we have "identifiers.json" somewhere.
124
+ const idPath = paths.find(p => p.endsWith("identifiers.json"));
125
+ if (!idPath) {
126
+ alert("Error: No 'identifiers.json' found in the uploaded folder.");
127
+ return;
128
+ }
129
+
130
+ // 2. Derive the top-level directory from that file's path
131
+ // e.g. if idPath = "arc-aug-10/identifiers.json", topDir = "arc-aug-10"
132
+ // If there's no slash, topDir = "" => do nothing
133
+ let topDir = "";
134
+ const lastSlash = idPath.lastIndexOf("/");
135
+ if (lastSlash >= 0) {
136
+ topDir = idPath.substring(0, lastSlash);
137
+ }
138
+
139
+ // 3. Rebuild filesByPath with the top folder removed.
140
+ // For example, if topDir = "arc-aug-10", then "arc-aug-10/train/all__inputs.npy"
141
+ // becomes "train/all__inputs.npy"
142
+ for (let i = 0; i < fileList.length; i++) {
143
+ const file = fileList[i];
144
+ let relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
145
+ // If relPath starts with "arc-aug-10/", remove that prefix
146
+ if (topDir && relPath.startsWith(topDir + "/")) {
147
+ relPath = relPath.substring(topDir.length + 1);
148
+ }
149
+ filesByPath[relPath] = file;
150
+ }
151
+
152
+ // Enable set/subset selection and "Load"
153
+ document.getElementById("setSelect").disabled = false;
154
+ document.getElementById("subsetSelect").disabled = false;
155
+ document.getElementById("loadBtn").disabled = false;
156
+ }
157
+
158
+ // When user clicks "Load," parse the .npy for the chosen set/subset
159
+ document.getElementById("loadBtn").addEventListener("click", async () => {
160
+ document.getElementById("groupList").innerHTML = "";
161
+ document.getElementById("puzzleList").innerHTML = "";
162
+ document.getElementById("puzzleView").innerHTML = "";
163
+
164
+ const setName = document.getElementById("setSelect").value; // e.g. "train"
165
+ const subsetName = document.getElementById("subsetSelect").value; // e.g. "all"
166
+
167
+ try {
168
+ await loadDataset(setName, subsetName);
169
+ buildGroupList(); // show groups
170
+ } catch (err) {
171
+ console.error(err);
172
+ alert("Error while loading dataset: " + err);
173
+ }
174
+ });
175
+
176
+
177
+ /***************************************************************************
178
+ * 2) Load .npy from local files using Npyjs + FileReader (ArrayBuffer)
179
+ ***************************************************************************/
180
+ async function loadDataset(setName, subsetName) {
181
+ const prefix = `${setName}/${subsetName}__`;
182
+ // e.g. "train/all__inputs.npy"
183
+ const inputsPath = prefix + "inputs.npy";
184
+ const labelsPath = prefix + "labels.npy";
185
+ const pIdxPath = prefix + "puzzle_indices.npy";
186
+ const gIdxPath = prefix + "group_indices.npy";
187
+ const pIdsPath = prefix + "puzzle_identifiers.npy";
188
+ const identifiersPath = "identifiers.json";
189
+
190
+ // Check existence
191
+ const needed = [inputsPath, labelsPath, pIdxPath, gIdxPath, pIdsPath, identifiersPath];
192
+ for (const f of needed) {
193
+ if (!filesByPath[f]) {
194
+ throw new Error(`Missing file: ${f}`);
195
+ }
196
+ }
197
+
198
+ // parseNpy => read from File -> ArrayBuffer -> Npyjs => typed array
199
+ const inputsNpy = await parseNpy(filesByPath[inputsPath]);
200
+ const labelsNpy = await parseNpy(filesByPath[labelsPath]);
201
+ const puzzleIndicesNpy= await parseNpy(filesByPath[pIdxPath]);
202
+ const groupIndicesNpy = await parseNpy(filesByPath[gIdxPath]);
203
+ const puzzleIdsNpy = await parseNpy(filesByPath[pIdsPath]);
204
+
205
+ inputsArr = inputsNpy.data;
206
+ labelsArr = labelsNpy.data;
207
+ puzzleIndicesArr = puzzleIndicesNpy.data;
208
+ groupIndicesArr = groupIndicesNpy.data;
209
+ puzzleIdentifiersArr = puzzleIdsNpy.data;
210
+
211
+ // shape e.g. [N_examples, seqLen]
212
+ seqLen = inputsNpy.shape[1];
213
+ gridSize = Math.sqrt(seqLen);
214
+
215
+ // read JSON
216
+ identifiersJson = await readJsonFile(filesByPath[identifiersPath]);
217
+ }
218
+
219
+ /***************************************************************************
220
+ * parseNpy => read a File as ArrayBuffer, parse with npyjs
221
+ ***************************************************************************/
222
+ function parseNpy(file) {
223
+ return new Promise((resolve, reject) => {
224
+ const reader = new FileReader();
225
+ reader.onload = async () => {
226
+ try {
227
+ const arrayBuffer = reader.result;
228
+ const npy = new npyjs();
229
+ resolve(await npy.parse(arrayBuffer));
230
+ } catch (err) {
231
+ reject(err);
232
+ }
233
+ };
234
+ reader.onerror = err => reject(err);
235
+ reader.readAsArrayBuffer(file);
236
+ });
237
+ }
238
+
239
+ /***************************************************************************
240
+ * readJsonFile => read a local JSON file into object
241
+ ***************************************************************************/
242
+ function readJsonFile(file) {
243
+ return new Promise((resolve, reject) => {
244
+ const reader = new FileReader();
245
+ reader.onload = () => {
246
+ try {
247
+ const obj = JSON.parse(reader.result);
248
+ resolve(obj);
249
+ } catch (err) {
250
+ reject(err);
251
+ }
252
+ };
253
+ reader.onerror = (err) => reject(err);
254
+ reader.readAsText(file);
255
+ });
256
+ }
257
+
258
+ /***************************************************************************
259
+ * 3) Build group list in UI
260
+ ***************************************************************************/
261
+ function buildGroupList() {
262
+ document.getElementById("groupList").innerHTML = "<h3>Groups</h3>";
263
+ const groupListDiv = document.getElementById("groupList");
264
+
265
+ const nGroups = groupIndicesArr.length - 1;
266
+ for (let g = 0; g < nGroups; g++) {
267
+ const div = document.createElement("span");
268
+ div.className = "group-item";
269
+ div.textContent = `Group ${g}`;
270
+ div.onclick = () => onSelectGroup(g);
271
+ groupListDiv.appendChild(div);
272
+ }
273
+ }
274
+
275
+ /***************************************************************************
276
+ * onSelectGroup => show puzzles in that group
277
+ ***************************************************************************/
278
+ function onSelectGroup(groupIndex) {
279
+ document.getElementById("puzzleList").innerHTML = "";
280
+ document.getElementById("puzzleView").innerHTML = "";
281
+
282
+ const puzzleListDiv = document.getElementById("puzzleList");
283
+ puzzleListDiv.innerHTML = `<h4>Puzzles in Group ${groupIndex}</h4>`;
284
+
285
+ const firstPuzzle = groupIndicesArr[groupIndex];
286
+ const lastPuzzle = groupIndicesArr[groupIndex + 1];
287
+
288
+ for (let p = firstPuzzle; p < lastPuzzle; p++) {
289
+ const puzzleIntId = puzzleIdentifiersArr[p];
290
+ const puzzleStrId = (puzzleIntId < identifiersJson.length)
291
+ ? identifiersJson[puzzleIntId]
292
+ : "<unknown>";
293
+
294
+ const div = document.createElement("span");
295
+ div.className = "puzzle-item";
296
+ div.textContent = `Puzzle #${p} [ID=${puzzleIntId}: ${puzzleStrId}]`;
297
+ div.onclick = () => onSelectPuzzle(p);
298
+ puzzleListDiv.appendChild(div);
299
+ }
300
+ }
301
+
302
+ /***************************************************************************
303
+ * onSelectPuzzle => show each example
304
+ ***************************************************************************/
305
+ function onSelectPuzzle(puzzleIndex) {
306
+ const puzzleView = document.getElementById("puzzleView");
307
+ puzzleView.innerHTML = "";
308
+
309
+ // puzzle ID
310
+ const puzzleIntId = puzzleIdentifiersArr[puzzleIndex];
311
+ const puzzleStrId = (puzzleIntId < identifiersJson.length)
312
+ ? identifiersJson[puzzleIntId]
313
+ : "<unknown>";
314
+
315
+ const titleDiv = document.createElement("div");
316
+ titleDiv.className = "puzzle-id";
317
+ titleDiv.textContent = `Puzzle #${puzzleIndex} — ID: ${puzzleStrId}`;
318
+ puzzleView.appendChild(titleDiv);
319
+
320
+ // Examples are [puzzleIndicesArr[p], puzzleIndicesArr[p+1])
321
+ const firstExample = puzzleIndicesArr[puzzleIndex];
322
+ const lastExample = puzzleIndicesArr[puzzleIndex + 1];
323
+
324
+ for (let e = firstExample; e < lastExample; e++) {
325
+ const inputSeq = slice1D(inputsArr, e*seqLen, (e+1)*seqLen);
326
+ const outputSeq = slice1D(labelsArr, e*seqLen, (e+1)*seqLen);
327
+
328
+ const inputGrid = decodeGrid(inputSeq);
329
+ const outputGrid = decodeGrid(outputSeq);
330
+
331
+ const exDiv = document.createElement("div");
332
+ exDiv.className = "example-container";
333
+ exDiv.appendChild(document.createTextNode(`Example ${e}`));
334
+ exDiv.appendChild(document.createElement("br"));
335
+
336
+ exDiv.appendChild(renderGrid(inputGrid));
337
+ exDiv.appendChild(renderGrid(outputGrid));
338
+
339
+ puzzleView.appendChild(exDiv);
340
+ }
341
+ }
342
+
343
+ /***************************************************************************
344
+ * slice1D => typed array slicing
345
+ ***************************************************************************/
346
+ function slice1D(arr, start, end) {
347
+ const result = new Uint32Array(end - start);
348
+ for (let i = start; i < end; i++) {
349
+ result[i - start] = Number(arr[i]);
350
+ }
351
+ return result;
352
+ }
353
+
354
+ /***************************************************************************
355
+ * decodeGrid => turn the flattened seq of length=gridSize^2 into 2D
356
+ ***************************************************************************/
357
+ function decodeGrid(seq) {
358
+ const grid = [];
359
+ let idx = 0;
360
+ for (let r = 0; r < gridSize; r++) {
361
+ const row = [];
362
+ for (let c = 0; c < gridSize; c++) {
363
+ row.push(seq[idx]);
364
+ idx++;
365
+ }
366
+ grid.push(row);
367
+ }
368
+ return grid;
369
+ }
370
+
371
+ /***************************************************************************
372
+ * renderGrid => draws a 2D grid to <canvas>
373
+ ***************************************************************************/
374
+ function renderGrid(grid2d) {
375
+ const rows = grid2d.length;
376
+ const cols = grid2d[0].length;
377
+ const scale = 10;
378
+
379
+ const canvas = document.createElement("canvas");
380
+ canvas.width = cols * scale;
381
+ canvas.height = rows * scale;
382
+ canvas.className = "grid-canvas";
383
+ const ctx = canvas.getContext("2d");
384
+
385
+ for (let r = 0; r < rows; r++) {
386
+ for (let c = 0; c < cols; c++) {
387
+ const val = grid2d[r][c];
388
+ ctx.fillStyle = indexToColor(val);
389
+ ctx.fillRect(c * scale, r * scale, scale, scale);
390
+ }
391
+ }
392
+ return canvas;
393
+ }
394
+
395
+ /***************************************************************************
396
+ * indexToColor => color palette:
397
+ * 0 => pad => white
398
+ * 1 => eos => light gray
399
+ * 2..11 => original color(0..9)
400
+ ***************************************************************************/
401
+ function indexToColor(value) {
402
+ if (value === 0) return "#FFFFFF"; // pad => white
403
+ if (value === 1) return "#DDDDDD"; // eos => light gray
404
+
405
+ // shift by 2 => original color in [0..9]
406
+ const colorIdx = value - 2;
407
+ const palette = [
408
+ "#000000", // color0 => black
409
+ "#FF0000", // color1 => red
410
+ "#00FF00", // color2 => green
411
+ "#0000FF", // color3 => blue
412
+ "#FFFF00", // color4 => yellow
413
+ "#FFA500", // color5 => orange
414
+ "#800080", // color6 => purple
415
+ "#00FFFF", // color7 => cyan
416
+ "#FFC0CB", // color8 => pink
417
+ "#808080" // color9 => gray
418
+ ];
419
+ if (colorIdx >= 0 && colorIdx < palette.length) {
420
+ return palette[colorIdx];
421
+ }
422
+ return "#FFFFFF"; // fallback
423
+ }
424
+ </script>
425
+ </body>
426
+ </html>
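For reference, the viewer above expects the chosen directory to contain, for each set/subset pair, five "<subset>__*.npy" files under the set folder plus a top-level identifiers.json. Below is a minimal sketch (not part of this commit) that writes such a layout with NumPy; the folder name reuses the "arc-aug-10" example from the comments above, while the grid size, counts, and array values are made-up placeholders. Only the file names and the index semantics mirror the JavaScript.

# Minimal sketch of the on-disk layout the viewer expects; values are illustrative only.
import json
from pathlib import Path
import numpy as np

root = Path("arc-aug-10")             # example folder name taken from the comments above
train = root / "train"
train.mkdir(parents=True, exist_ok=True)

seq_len = 30 * 30                     # viewer assumes a square grid: gridSize = sqrt(seqLen)
n_examples = 4

# Token values: 0 = pad, 1 = eos, 2..11 = colors 0..9 (see indexToColor)
inputs = np.random.randint(0, 12, size=(n_examples, seq_len), dtype=np.int32)
labels = np.random.randint(0, 12, size=(n_examples, seq_len), dtype=np.int32)

# Puzzle p covers examples [puzzle_indices[p], puzzle_indices[p+1])
puzzle_indices = np.array([0, 2, 4], dtype=np.int32)        # 2 puzzles, 2 examples each
# Group g covers puzzles [group_indices[g], group_indices[g+1])
group_indices = np.array([0, 2], dtype=np.int32)            # 1 group holding both puzzles
# Integer id per puzzle, resolved to a string through identifiers.json
puzzle_identifiers = np.array([0, 1], dtype=np.int32)

np.save(train / "all__inputs.npy", inputs)
np.save(train / "all__labels.npy", labels)
np.save(train / "all__puzzle_indices.npy", puzzle_indices)
np.save(train / "all__group_indices.npy", group_indices)
np.save(train / "all__puzzle_identifiers.npy", puzzle_identifiers)

(root / "identifiers.json").write_text(json.dumps(["puzzle_a", "puzzle_b"]))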
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch
2
+ adam-atan2
3
+ einops
4
+ tqdm
5
+ coolname
6
+ pydantic
7
+ argdantic
8
+ wandb
9
+ omegaconf
10
+ hydra-core
11
+ huggingface_hub
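The dependencies are unpinned (no version constraints) and can be installed into the active environment with pip install -r requirements.txt.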
utils/functions.py ADDED
@@ -0,0 +1,19 @@
1
+ import importlib
2
+ import inspect
3
+
4
+
5
+ def load_model_class(identifier: str, prefix: str = "models."):
6
+ module_path, class_name = identifier.split('@')
7
+
8
+ # Import the module
9
+ module = importlib.import_module(prefix + module_path)
10
+ cls = getattr(module, class_name)
11
+
12
+ return cls
13
+
14
+
15
+ def get_model_source_path(identifier: str, prefix: str = "models."):
16
+ module_path, class_name = identifier.split('@')
17
+
18
+ module = importlib.import_module(prefix + module_path)
19
+ return inspect.getsourcefile(module)
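A hypothetical usage sketch for the helpers above; the module and class names are placeholders, and only the module_path@ClassName identifier format and the default models. prefix come from the code itself.

# Hypothetical usage; assumes a file models/hrm_v1.py defining a class HRM.
from utils.functions import load_model_class, get_model_source_path

ModelCls = load_model_class("hrm_v1@HRM")      # imports models.hrm_v1 and returns its HRM class
source = get_model_source_path("hrm_v1@HRM")   # filesystem path of models/hrm_v1.py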