Release
Browse files
- .gitignore +169 -0
- .vscode/launch.json +26 -0
- .vscode/settings.json +3 -0
- README.md +169 -0
- arc_eval.ipynb +252 -0
- assets/hrm.png +0 -0
- assets/npyjs.js +176 -0
- config/arch/hrm_v1.yaml +21 -0
- config/cfg_pretrain.yaml +31 -0
- dataset/build_arc_dataset.py +291 -0
- dataset/build_maze_dataset.py +142 -0
- dataset/build_sudoku_dataset.py +169 -0
- dataset/common.py +51 -0
- evaluate.py +68 -0
- models/common.py +32 -0
- models/hrm/hrm_act_v1.py +283 -0
- models/layers.py +150 -0
- models/losses.py +101 -0
- models/sparse_embedding.py +132 -0
- pretrain.py +454 -0
- puzzle_dataset.py +199 -0
- puzzle_visualizer.html +426 -0
- requirements.txt +11 -0
- utils/functions.py +19 -0
.gitignore
ADDED
@@ -0,0 +1,169 @@
```
# WandB
/wandb/
# checkpoints
/checkpoints/
# cache
/cache/
# data
/data/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
```
.vscode/launch.json
ADDED
@@ -0,0 +1,26 @@
```json
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python Debugger: Current File",
            "type": "debugpy",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal"
        },
        {
            "name": "Debug: Single GPU",
            "type": "debugpy",
            "request": "launch",
            "program": "pretrain.py",
            "args": [],
            "env": {
                "OMP_NUM_THREADS": "1",
                "DISABLE_COMPILE": "true"
            }
        }
    ]
}
```
.vscode/settings.json
ADDED
@@ -0,0 +1,3 @@
```json
{
    "python.analysis.typeCheckingMode": "standard"
}
```
README.md
ADDED
@@ -0,0 +1,169 @@
````markdown
# Hierarchical Reasoning Model

![](assets/hrm.png)

Reasoning, the process of devising and executing complex goal-oriented action sequences, remains a critical challenge in AI.
Current large language models (LLMs) primarily employ Chain-of-Thought (CoT) techniques, which suffer from brittle task decomposition, extensive data requirements, and high latency. Inspired by the hierarchical and multi-timescale processing in the human brain, we propose the Hierarchical Reasoning Model (HRM), a novel recurrent architecture that attains significant computational depth while maintaining both training stability and efficiency.
HRM executes sequential reasoning tasks in a single forward pass without explicit supervision of the intermediate process, through two interdependent recurrent modules: a high-level module responsible for slow, abstract planning, and a low-level module handling rapid, detailed computations. With only 27 million parameters, HRM achieves exceptional performance on complex reasoning tasks using only 1000 training samples. The model operates without pre-training or CoT data, yet achieves nearly perfect performance on challenging tasks including complex Sudoku puzzles and optimal path finding in large mazes.
Furthermore, HRM outperforms much larger models with significantly longer context windows on the Abstraction and Reasoning Corpus (ARC), a key benchmark for measuring artificial general intelligence capabilities.
These results underscore HRM's potential as a transformative advancement toward universal computation and general-purpose reasoning systems.

## Quick Start Guide 🚀

### Prerequisites ⚙️

Ensure PyTorch and CUDA are installed. The repo needs CUDA extensions to be built. If they are not present, run the following commands:

```bash
# Install CUDA 12.4
CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run

wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL
sudo sh cuda_installer.run --silent --toolkit --override

export CUDA_HOME=/usr/local/cuda-12.4

# Install PyTorch with CUDA 12.4
PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu124

pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL

# Additional packages for building extensions
pip3 install packaging ninja wheel setuptools setuptools-scm
```

## Install Python Dependencies 🐍

```bash
pip install -r requirements.txt
```

## W&B Integration 📈

This project uses [Weights & Biases](https://wandb.ai/) for experiment tracking and metric visualization. Ensure you're logged in:

```bash
wandb login
```

## Run Experiments

### Quick Demo: Sudoku Solver 💻🗲

Train a master-level Sudoku AI capable of solving extremely difficult puzzles on a modern laptop GPU. 🧩

```bash
# Download and build Sudoku dataset
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000

# Start training (single GPU, smaller batch size)
OMP_NUM_THREADS=8 python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 global_batch_size=384 lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0
```

*Runtime:* ~10 hours on an RTX 4070 laptop GPU

## Full-scale Experiments 🔵

Experiments below assume an 8-GPU setup.

### Dataset Preparation

```bash
# Initialize submodules
git submodule update --init --recursive

# ARC-1
python dataset/build_arc_dataset.py  # ARC official + ConceptARC, 960 examples
# ARC-2
python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000  # ARC-2 official, 1120 examples

# Sudoku-Extreme
python dataset/build_sudoku_dataset.py  # Full version
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000  # 1000 examples

# Maze
python dataset/build_maze_dataset.py  # 1000 examples
```

### Dataset Visualization

Explore the puzzles visually:

* Open `puzzle_visualizer.html` in your browser.
* Upload the generated dataset folder located in `data/...`.

## Launch experiments

### Small-sample (1K)

ARC-1:

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py
```

*Runtime:* ~24 hours

ARC-2:

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/arc-2-aug-1000
```

*Runtime:* ~24 hours (a checkpoint after 8 hours is often sufficient)

Sudoku Extreme (1k):

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
```

*Runtime:* ~10 minutes

Maze 30x30 Hard (1k):

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/maze-30x30-hard-1k epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
```

*Runtime:* ~1 hour

### Full Sudoku-Hard

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-hard-full epochs=100 eval_interval=10 lr_min_ratio=0.1 global_batch_size=2304 lr=3e-4 puzzle_emb_lr=3e-4 weight_decay=0.1 puzzle_emb_weight_decay=0.1 arch.loss.loss_type=softmax_cross_entropy arch.L_cycles=8 arch.halt_max_steps=8 arch.pos_encodings=learned
```

*Runtime:* ~2 hours

## Evaluation

Evaluate your trained models:

* Check `eval/exact_accuracy` in W&B.
* For ARC-AGI, follow these additional steps:

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 evaluate.py checkpoint=<CHECKPOINT_PATH>
```

* Then use the provided `arc_eval.ipynb` notebook to finalize and inspect your results.

## Notes

- Small-sample learning typically exhibits accuracy variance of around ±2 points.
- For Sudoku-Extreme (1,000-example dataset), late-stage overfitting may cause numerical instability during training and Q-learning. It is advisable to use early stopping once the training accuracy approaches 100%.

## Citation 📜

```
@misc{wang2025hierarchicalreasoningmodel,
      title={Hierarchical Reasoning Model},
      author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
      year={2025},
      eprint={2506.21734},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2506.21734},
}
```
````
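The README abstract describes two interdependent recurrent modules that update at different timescales within a single forward pass. As a minimal sketch of that control flow only (module names are assumed, and GRU cells stand in for the transformer blocks of the actual `models/hrm/hrm_act_v1.py`; this is not the repository's implementation):

```python
import torch
import torch.nn as nn


class TwoTimescaleSketch(nn.Module):
    """Hypothetical illustration: a slow high-level planner H and a fast
    low-level worker L, unrolled H_cycles * L_cycles times per forward pass."""

    def __init__(self, hidden_size: int = 512, H_cycles: int = 2, L_cycles: int = 2):
        super().__init__()
        self.H_cycles, self.L_cycles = H_cycles, L_cycles
        # GRU cells stand in for the model's transformer blocks.
        self.H_step = nn.GRUCell(hidden_size, hidden_size)
        self.L_step = nn.GRUCell(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor, z_H: torch.Tensor, z_L: torch.Tensor):
        for _ in range(self.H_cycles):
            for _ in range(self.L_cycles):
                # Fast, detailed computation conditioned on the slow state.
                z_L = self.L_step(x + z_H, z_L)
            # Slow, abstract update once per burst of L-steps.
            z_H = self.H_step(z_L, z_H)
        return z_H, z_L
```

The cycle counts mirror the `H_cycles` and `L_cycles` keys in `config/arch/hrm_v1.yaml` below.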
arc_eval.ipynb
ADDED
@@ -0,0 +1,252 @@
A Jupyter notebook (Python 3 kernel, nbformat 4) containing two code cells; their sources follow.

Cell 1, imports and configuration:

```python
import os
import json
from glob import glob
import hashlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import torch
import torch.nn.functional as F
import numpy as np
from numba import njit

from dataset.common import inverse_dihedral_transform


DATASET_PATH = "data/arc-aug-1000"  # ARC-1
# DATASET_PATH = "data/arc-2-aug-1000"  # ARC-2

CHECKPOINT_PATH = "checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456"


PAD_PUZZLE_IDENTIFIER = 0

# Visualization
ARC_COLOR_MAP = mcolors.ListedColormap([
    "#000000",  # symbol_0: black
    "#0074D9",  # symbol_1: blue
    "#FF4136",  # symbol_2: red
    "#2ECC40",  # symbol_3: green
    "#FFDC00",  # symbol_4: yellow
    "#AAAAAA",  # symbol_5: grey
    "#F012BE",  # symbol_6: fuchsia
    "#FF851B",  # symbol_7: orange
    "#7FDBFF",  # symbol_8: teal
    "#870C25"   # symbol_9: brown
])
```

Cell 2, evaluation:

```python
def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):
    # Load puzzle identifiers
    with open(os.path.join(dataset_path, "identifiers.json"), "r") as f:
        identifier_map = json.load(f)

    # Load preds
    all_preds = {}
    for filename in glob(f"{checkpoint_path}_all_preds.*"):
        preds = torch.load(filename)
        for k, v in preds.items():
            all_preds.setdefault(k, [])
            all_preds[k].append(v)

        del preds

    all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}

    # Remove paddings
    mask = all_preds["puzzle_identifiers"] != PAD_PUZZLE_IDENTIFIER
    all_preds = {k: v[mask] for k, v in all_preds.items()}

    return identifier_map, all_preds


def inverse_aug(name: str, grid: np.ndarray):
    if "_" not in name:
        return grid

    trans_id, perm = name.split("_")[-2:]
    trans_id = int(trans_id[1:])  # Remove "t" letter
    inv_perm = np.argsort(list(perm))

    return inv_perm[inverse_dihedral_transform(grid, trans_id)]


def grid_hash(grid: np.ndarray):
    return hash((grid.tobytes(), grid.shape))


@njit
def crop(grid: np.ndarray):
    # Find maximum-sized rectangle without any EOS token inside.
    grid = grid.reshape(30, 30)

    max_area = 0
    max_size = (0, 0)
    nr, nc = grid.shape

    num_c = nc
    for num_r in range(1, nr + 1):
        # Scan for maximum c
        for c in range(1, num_c + 1):
            x = grid[num_r - 1, c - 1]
            if (x < 2) | (x > 11):
                num_c = c - 1
                break

        area = num_r * num_c
        if area > max_area:
            max_area = area
            max_size = (num_r, num_c)

    return grid[:max_size[0], :max_size[1]] - 2


def test(visualize, Ks=[1, 2, 10, 100, 1000]):
    identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)

    global_hmap = {}

    # Get puzzles and corresponding answers
    puzzle_labels = {}
    for identifier, input, label in zip(all_preds["puzzle_identifiers"], all_preds["inputs"], all_preds["labels"]):
        name = identifier_map[identifier]
        if "_" not in name:  # Not-augmented
            puzzle_labels.setdefault(name, {})

            input = crop(input.numpy())
            label = crop(label.numpy())

            input_hash = grid_hash(input)
            label_hash = grid_hash(label)

            global_hmap[input_hash] = input
            global_hmap[label_hash] = label

            assert input_hash not in puzzle_labels[name]
            puzzle_labels[name][input_hash] = label_hash

    print("Number of puzzles", len(puzzle_labels))

    # Argmax prediction
    preds = all_preds["logits"].argmax(-1)

    # Collate
    pred_answers = {}
    for identifier, input, pred, q in zip(all_preds["puzzle_identifiers"], all_preds["inputs"], preds, all_preds["q_halt_logits"].sigmoid()):
        name = identifier_map[identifier]
        orig_name = name.split("_")[0]

        input = input.numpy()
        input_hash = grid_hash(inverse_aug(name, crop(input)))
        assert input_hash in puzzle_labels[orig_name]

        pred = inverse_aug(name, crop(pred.numpy()))
        pred_hash = grid_hash(pred)
        global_hmap[pred_hash] = pred

        pred_answers.setdefault(orig_name, {})
        pred_answers[orig_name].setdefault(input_hash, [])
        pred_answers[orig_name][input_hash].append((pred_hash, q.item()))

    # test-1
    if visualize:
        num_figs = sum(len(tests) for name, tests in puzzle_labels.items())
        fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))

    fig_id = 0

    correct = [0 for _ in range(len(Ks))]
    for name, tests in puzzle_labels.items():
        num_test_correct = [0 for _ in range(len(Ks))]
        for input_hash, label_hash in tests.items():
            p = pred_answers[name][input_hash]
            p_map = {}

            for h, q in p:
                p_map.setdefault(h, [0, 0])
                p_map[h][0] += 1
                p_map[h][1] += q

            for h, stats in p_map.items():
                stats[1] /= stats[0]

            p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)

            # 2-vote
            for i, k in enumerate(Ks):
                ok = False
                for h, stats in p_map[:k]:
                    ok |= h == label_hash

                num_test_correct[i] += ok

            if visualize:
                # Show input and ground truth
                axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)
                axes[fig_id, 0].set_title(f"{name}\nInput")
                axes[fig_id, 0].axis('off')

                axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)
                axes[fig_id, 1].set_title(f"{name}\nAnswer")
                axes[fig_id, 1].axis('off')

                trial_id = 2
                for h, stats in p_map[:2]:
                    ans = global_hmap[h]

                    axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)
                    axes[fig_id, trial_id].set_title(f"{name}\nTrial {trial_id}")
                    axes[fig_id, trial_id].axis('off')

                    trial_id += 1

                fig_id += 1

        # Total correctness
        for i in range(len(Ks)):
            correct[i] += num_test_correct[i] == len(tests)

    for i, k in enumerate(Ks):
        print(f"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%")


test(visualize=False)
```
assets/hrm.png
ADDED
assets/npyjs.js
ADDED
@@ -0,0 +1,176 @@
```javascript
class npyjs {

    constructor(opts) {
        if (opts && !('convertFloat16' in opts)) {
            console.warn([
                "npyjs constructor now accepts {convertFloat16?: boolean}.",
                "For usage, go to https://github.com/jhuapl-boss/npyjs."
            ].join(" "));
        }

        this.convertFloat16 = opts?.convertFloat16 ?? true;

        this.dtypes = {
            "<u1": {
                name: "uint8",
                size: 8,
                arrayConstructor: Uint8Array,
            },
            "|u1": {
                name: "uint8",
                size: 8,
                arrayConstructor: Uint8Array,
            },
            "<u2": {
                name: "uint16",
                size: 16,
                arrayConstructor: Uint16Array,
            },
            "|i1": {
                name: "int8",
                size: 8,
                arrayConstructor: Int8Array,
            },
            "<i2": {
                name: "int16",
                size: 16,
                arrayConstructor: Int16Array,
            },
            "<u4": {
                name: "uint32",
                size: 32,
                arrayConstructor: Uint32Array,
            },
            "<i4": {
                name: "int32",
                size: 32,
                arrayConstructor: Int32Array,
            },
            "<u8": {
                name: "uint64",
                size: 64,
                arrayConstructor: BigUint64Array,
            },
            "<i8": {
                name: "int64",
                size: 64,
                arrayConstructor: BigInt64Array,
            },
            "<f4": {
                name: "float32",
                size: 32,
                arrayConstructor: Float32Array
            },
            "<f8": {
                name: "float64",
                size: 64,
                arrayConstructor: Float64Array
            },
            "<f2": {
                name: "float16",
                size: 16,
                arrayConstructor: Uint16Array,
                converter: this.convertFloat16 ? this.float16ToFloat32Array : undefined
            },
        };
    }

    float16ToFloat32Array(float16Array) {
        const length = float16Array.length;
        const float32Array = new Float32Array(length);

        for (let i = 0; i < length; i++) {
            float32Array[i] = npyjs.float16ToFloat32(float16Array[i]);
        }

        return float32Array;
    }

    static float16ToFloat32(float16) {
        // Extract the parts of the float16
        const sign = (float16 >> 15) & 0x1;
        const exponent = (float16 >> 10) & 0x1f;
        const fraction = float16 & 0x3ff;

        // Handle special cases
        if (exponent === 0) {
            if (fraction === 0) {
                // Zero
                return sign ? -0 : 0;
            }
            // Denormalized number
            return (sign ? -1 : 1) * Math.pow(2, -14) * (fraction / 0x400);
        } else if (exponent === 0x1f) {
            if (fraction === 0) {
                // Infinity
                return sign ? -Infinity : Infinity;
            }
            // NaN
            return NaN;
        }

        // Normalized number
        return (sign ? -1 : 1) * Math.pow(2, exponent - 15) * (1 + fraction / 0x400);
    }

    parse(arrayBufferContents) {
        // const version = arrayBufferContents.slice(6, 8); // Uint8-encoded
        const headerLength = new DataView(arrayBufferContents.slice(8, 10)).getUint8(0);
        const offsetBytes = 10 + headerLength;

        const hcontents = new TextDecoder("utf-8").decode(
            new Uint8Array(arrayBufferContents.slice(10, 10 + headerLength))
        );
        const header = JSON.parse(
            hcontents
                .toLowerCase() // True -> true
                .replace(/'/g, '"')
                .replace("(", "[")
                .replace(/,*\),*/g, "]")
        );
        const shape = header.shape;
        const dtype = this.dtypes[header.descr];

        if (!dtype) {
            console.error(`Unsupported dtype: ${header.descr}`);
            return null;
        }

        const nums = new dtype.arrayConstructor(
            arrayBufferContents,
            offsetBytes
        );

        // Convert float16 to float32 if converter exists
        const data = dtype.converter ? dtype.converter.call(this, nums) : nums;

        return {
            dtype: dtype.name,
            data: data,
            shape,
            fortranOrder: header.fortran_order
        };
    }

    async load(filename, callback, fetchArgs) {
        /*
        Loads an array from a stream of bytes.
        */
        fetchArgs = fetchArgs || {};
        let arrayBuf;
        // If filename is ArrayBuffer
        if (filename instanceof ArrayBuffer) {
            arrayBuf = filename;
        }
        // If filename is a file path
        else {
            const resp = await fetch(filename, { ...fetchArgs });
            arrayBuf = await resp.arrayBuffer();
        }
        const result = this.parse(arrayBuf);
        if (callback) {
            return callback(result);
        }
        return result;
    }
}
```
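`parse` above hand-decodes the header that NumPy writes at the start of every `.npy` file. For reference, the same fields can be read from Python with NumPy's public `numpy.lib.format` helpers (a sketch using an in-memory buffer):

```python
import io

import numpy as np
from numpy.lib import format as npy_format

# Write a small array to an in-memory .npy buffer.
buf = io.BytesIO()
np.save(buf, np.arange(6, dtype=np.float32).reshape(2, 3))
buf.seek(0)

# These are the fields npyjs.parse() reconstructs client-side.
version = npy_format.read_magic(buf)                                 # (1, 0)
shape, fortran_order, dtype = npy_format.read_array_header_1_0(buf)
print(version, shape, fortran_order, dtype)                          # (1, 0) (2, 3) False float32
```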
config/arch/hrm_v1.yaml
ADDED
@@ -0,0 +1,21 @@
```yaml
name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 2
L_cycles: 2

H_layers: 4
L_layers: 4

hidden_size: 512
num_heads: 8  # max(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
```
config/cfg_pretrain.yaml
ADDED
@@ -0,0 +1,31 @@
```yaml
# ARC training config

defaults:
  - arch: hrm_v1
  - _self_

hydra:
  output_subdir: null

# Data path
data_path: data/arc-aug-1000

# Hyperparams - Training
global_batch_size: 768

epochs: 100000
eval_interval: 10000
checkpoint_every_eval: True

lr: 1e-4
lr_min_ratio: 1.0
lr_warmup_steps: 2000

# Standard hyperparameter settings for LM, as used in Llama
beta1: 0.9
beta2: 0.95
weight_decay: 0.1
puzzle_emb_weight_decay: 0.1

# Hyperparams - Puzzle embeddings training
puzzle_emb_lr: 1e-2
```
dataset/build_arc_dataset.py
ADDED
@@ -0,0 +1,291 @@
```python
from typing import List, Optional, Tuple, Dict
from dataclasses import dataclass
from pathlib import Path
import os
import json
import hashlib
import numpy as np
from glob import glob

from argdantic import ArgParser
from pydantic import BaseModel

from common import PuzzleDatasetMetadata, dihedral_transform


cli = ArgParser()


class DataProcessConfig(BaseModel):
    # ARC-1
    dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"]
    output_dir: str = "data/arc-aug-1000"

    # ARC-2
    # dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"]
    # output_dir: str = "data/arc-2-aug-1000"

    seed: int = 42
    num_aug: int = 1000


ARCMaxGridSize = 30
ARCAugmentRetriesFactor = 5


@dataclass
class ARCPuzzle:
    id: str

    examples: List[Tuple[np.ndarray, np.ndarray]]


def arc_grid_to_np(grid: List[List[int]]):
    arr = np.array(grid)

    # Shape check
    assert arr.ndim == 2
    assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
    # Element check
    assert np.all((arr >= 0) & (arr <= 9))
    return arr.astype(np.uint8)


def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
    # PAD: 0, <eos>: 1, digits: 2 ... 11
    # Compute random top-left pad
    if do_translation:
        pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
        pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
    else:
        pad_r = pad_c = 0

    # Pad grid
    result = []
    for grid in [inp, out]:
        nrow, ncol = grid.shape
        grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)

        # Add <eos>
        eos_row, eos_col = pad_r + nrow, pad_c + ncol
        if eos_row < ARCMaxGridSize:
            grid[eos_row, pad_c:eos_col] = 1
        if eos_col < ARCMaxGridSize:
            grid[pad_r:eos_row, eos_col] = 1

        result.append(grid.flatten())

    return result


def puzzle_hash(puzzle: dict):
    # Hash the puzzle for checking equivalence
    def _grid_hash(grid: np.ndarray):
        buffer = [x.to_bytes(1) for x in grid.shape]
        buffer.append(grid.tobytes())

        return hashlib.sha256(b"".join(buffer)).hexdigest()

    hashes = []
    for example_type, example in puzzle.items():
        for input, label in example.examples:
            hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}")

    hashes.sort()
    return hashlib.sha256("|".join(hashes).encode()).hexdigest()


def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
    # Remove "name"
    name = puzzle.pop("name", default_name)

    # Convert
    dests = set(dest_mapping.values())
    converted = {dest: ARCPuzzle(name, []) for dest in dests}
    for example_type, examples in puzzle.items():
        dest = dest_mapping[example_type]
        converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])

    group = [converted]

    # Augment
    if aug_count > 0:
        hashes = {puzzle_hash(converted)}

        for _trial in range(ARCAugmentRetriesFactor * aug_count):
            # Augment plan
            trans_id = np.random.randint(0, 8)
            mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))])  # Permute colors, excluding "0" (black)

            aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}"

            def _map_grid(grid: np.ndarray):
                return dihedral_transform(mapping[grid], trans_id)

            # Check duplicate
            augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
            h = puzzle_hash(augmented)
            if h not in hashes:
                hashes.add(h)
                group.append(augmented)

            if len(group) >= aug_count + 1:
                break

        if len(group) < aug_count + 1:
            print(f"[Puzzle {name}] augmentation not full, only {len(group)}")

    # Append
    for dest in dests:
        # Convert the examples
        dest_split, dest_set = dest

        results.setdefault(dest_split, {})
        results[dest_split].setdefault(dest_set, [])
        results[dest_split][dest_set].append([converted[dest] for converted in group])


def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig):
    train_examples_dest = ("train", "all")
    test_examples_map = {
        "evaluation": [(1.0, ("test", "all"))],
        "_default": [(1.0, ("train", "all"))]
    }

    total_puzzles = 0
    for subdir in os.scandir(dataset_path):
        if subdir.is_dir():
            # Load all puzzles in this directory
            puzzles = []
            for filename in glob(os.path.join(subdir.path, "*.json")):
                with open(filename, "r") as f:
                    puzzles.append((Path(filename).stem, json.load(f)))

            # Shuffle puzzles
            np.random.shuffle(puzzles)

            # Assign by fraction
            for idx, (default_name, puzzle) in enumerate(puzzles):
                fraction = idx / len(puzzles)
                test_examples_dest = None
                for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]):
                    if fraction < f:
                        test_examples_dest = dest
                        break

                assert test_examples_dest is not None

                convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
                total_puzzles += 1

    print(f"[{dataset_path}] total puzzles: {total_puzzles}")


def convert_dataset(config: DataProcessConfig):
    np.random.seed(config.seed)

    # Read dataset
    data = {}
    for dataset_dir in config.dataset_dirs:
        load_puzzles_arcagi(data, dataset_dir, config)

    # Map global puzzle identifiers
    num_identifiers = 1  # 0 is blank
    identifier_map = {}
    for split_name, split in data.items():
        for subset_name, subset in split.items():
            for group in subset:
                for puzzle in group:
                    if puzzle.id not in identifier_map:
                        identifier_map[puzzle.id] = num_identifiers
                        num_identifiers += 1

    print(f"Total puzzle IDs (including <blank>): {num_identifiers}")

    # Save
    for split_name, split in data.items():
        os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)

        # Translational augmentations
        enable_translational_augment = split_name == "train"

        # Statistics
        total_examples = 0
        total_puzzles = 0
        total_groups = 0

        for subset_name, subset in split.items():
            # Construct subset
            results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
            results["puzzle_indices"].append(0)
            results["group_indices"].append(0)

            example_id = 0
            puzzle_id = 0

            for group in subset:
                for puzzle in group:
                    # Push puzzle
                    no_aug_id = np.random.randint(0, len(puzzle.examples))
                    for _idx_ex, (inp, out) in enumerate(puzzle.examples):
                        inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)

                        results["inputs"].append(inp)
                        results["labels"].append(out)
                        example_id += 1

                        total_examples += 1

                    results["puzzle_indices"].append(example_id)
                    results["puzzle_identifiers"].append(identifier_map[puzzle.id])

                    puzzle_id += 1

                    total_puzzles += 1

                # Push group
                results["group_indices"].append(puzzle_id)
                total_groups += 1

            for k, v in results.items():
                if k in {"inputs", "labels"}:
                    v = np.stack(v, 0)
                else:
                    v = np.array(v, dtype=np.int32)

                np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)

        # Metadata
        metadata = PuzzleDatasetMetadata(
            seq_len=ARCMaxGridSize * ARCMaxGridSize,
            vocab_size=10 + 2,  # PAD + EOS + "0" ... "9"

            pad_id=0,
            ignore_label_id=0,

            blank_identifier_id=0,
            num_puzzle_identifiers=num_identifiers,

            total_groups=total_groups,
            mean_puzzle_examples=total_examples / total_puzzles,
            sets=list(split.keys())
        )

        # Save metadata as JSON.
        with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
            json.dump(metadata.model_dump(), f)

    # Save IDs mapping
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        ids_mapping = {v: k for k, v in identifier_map.items()}

        json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)


@cli.command(singleton=True)
def main(config: DataProcessConfig):
    convert_dataset(config)


if __name__ == "__main__":
    cli()
```
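`np_grid_to_seq_translational_augment` encodes each grid as a flattened 30x30 token map with PAD = 0, `<eos>` = 1, and colors shifted to 2..11. A small round-trip check for the untranslated case (a sketch: `decode_seq` is a hypothetical helper mirroring the `crop` logic from `arc_eval.ipynb`, and the import assumes the script is run with `dataset/` on the path, as the script itself requires):

```python
import numpy as np

from build_arc_dataset import np_grid_to_seq_translational_augment


def decode_seq(seq: np.ndarray) -> np.ndarray:
    """Undo the encoding for an untranslated grid: reshape to 30x30, keep the
    top-left rectangle of digit tokens (2..11), then shift colors back to 0..9."""
    grid = seq.reshape(30, 30)
    nrow = next((r for r in range(30) if not (2 <= grid[r, 0] <= 11)), 30)
    ncol = next((c for c in range(30) if not (2 <= grid[0, c] <= 11)), 30)
    return grid[:nrow, :ncol] - 2


g = np.array([[0, 3, 5],
              [1, 1, 9]], dtype=np.uint8)  # a toy 2x3 ARC grid
inp_seq, out_seq = np_grid_to_seq_translational_augment(g, g, do_translation=False)
assert np.array_equal(decode_seq(inp_seq), g)
```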
dataset/build_maze_dataset.py
ADDED
@@ -0,0 +1,142 @@
```python
from typing import Optional
import math
import os
import csv
import json
import numpy as np

from argdantic import ArgParser
from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from common import PuzzleDatasetMetadata, dihedral_transform


CHARSET = "# SGo"


cli = ArgParser()


class DataProcessConfig(BaseModel):
    source_repo: str = "imone/small-sample-challenge-maze-30x30-hard"
    output_dir: str = "data/maze-30x30-hard-1k"

    subsample_size: Optional[int] = None
    aug: bool = False


def convert_subset(set_name: str, config: DataProcessConfig):
    # Read CSV
    all_chars = set()
    grid_size = None
    inputs = []
    labels = []

    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:  # type: ignore
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            all_chars.update(q)
            all_chars.update(a)

            if grid_size is None:
                n = int(len(q) ** 0.5)
                grid_size = (n, n)

            inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
            labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for inp, out in zip(tqdm(inputs), labels):
        # Dihedral transformations for augmentation
        for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
            results["inputs"].append(dihedral_transform(inp, aug_idx))
            results["labels"].append(dihedral_transform(out, aug_idx))
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        # Push group
        results["group_indices"].append(puzzle_id)

    # Char mappings
    assert len(all_chars - set(CHARSET)) == 0

    char2id = np.zeros(256, np.uint8)
    char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1

    # To Numpy
    def _seq_to_numpy(seq):
        arr = np.vstack([char2id[s.reshape(-1)] for s in seq])

        return arr

    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),

        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=int(math.prod(grid_size)),  # type: ignore
        vocab_size=len(CHARSET) + 1,  # PAD + Charset

        pad_id=0,
        ignore_label_id=0,

        blank_identifier_id=0,
        num_puzzle_identifiers=1,

        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)

    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    # Save IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)


@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()
```
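The maze vocabulary is positional: `CHARSET = "# SGo"` maps `'#'` → 1, `' '` → 2, `'S'` → 3, `'G'` → 4, `'o'` → 5, with 0 reserved for PAD, matching `vocab_size = len(CHARSET) + 1`. A quick check of the `char2id` lookup-table construction (the semantic labels in the comment are an assumption):

```python
import numpy as np

CHARSET = "# SGo"  # assumed semantics: wall, open, start, goal, solution path

# 256-entry lookup table: byte value -> token id (unmapped bytes stay 0 = PAD).
char2id = np.zeros(256, np.uint8)
char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1

row = np.frombuffer(b"#S G#", dtype=np.uint8)
print(char2id[row])  # [1 3 2 4 1]
```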
dataset/build_sudoku_dataset.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
import os
|
3 |
+
import csv
|
4 |
+
import json
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from argdantic import ArgParser
|
8 |
+
from pydantic import BaseModel
|
9 |
+
from tqdm import tqdm
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
|
12 |
+
from common import PuzzleDatasetMetadata
|
13 |
+
|
14 |
+
|
15 |
+
cli = ArgParser()
|
16 |
+
|
17 |
+
|
18 |
+
class DataProcessConfig(BaseModel):
|
19 |
+
source_repo: str = "imone/sudoku-hard-v2"
|
20 |
+
output_dir: str = "data/sudoku-extreme-full"
|
21 |
+
|
22 |
+
subsample_size: Optional[int] = None
|
23 |
+
min_difficulty: Optional[int] = None
|
24 |
+
num_aug: int = 0
|
25 |
+
|
26 |
+
|
27 |
+
def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
|
28 |
+
# Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
|
29 |
+
digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))
|
30 |
+
|
31 |
+
# Randomly decide whether to transpose.
|
32 |
+
transpose_flag = np.random.rand() < 0.5
|
33 |
+
|
34 |
+
# Generate a valid row permutation:
|
35 |
+
# - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows.
|
36 |
+
bands = np.random.permutation(3)
|
37 |
+
row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])
|
38 |
+
|
39 |
+
# Similarly for columns (stacks).
|
40 |
+
stacks = np.random.permutation(3)
|
41 |
+
col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])
|
42 |
+
|
43 |
+
    # Build an 81->81 mapping. For each new cell at (i, j)
    # (row index = i // 9, col index = i % 9),
    # its value comes from old row = row_perm[i//9] and old col = col_perm[i%9].
    mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])

    def apply_transformation(x: np.ndarray) -> np.ndarray:
        # Apply transpose flag
        if transpose_flag:
            x = x.T
        # Apply the position mapping.
        new_board = x.flatten()[mapping].reshape(9, 9).copy()
        # Apply digit mapping
        return digit_map[new_board]

    return apply_transformation(board), apply_transformation(solution)


def convert_subset(set_name: str, config: DataProcessConfig):
    # Read CSV
    inputs = []
    labels = []

    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
                assert len(q) == 81 and len(a) == 81

                inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
                labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    num_augments = config.num_aug if set_name == "train" else 0

    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for orig_inp, orig_out in zip(tqdm(inputs), labels):
        for aug_idx in range(1 + num_augments):
            # First index is not augmented
            if aug_idx == 0:
                inp, out = orig_inp, orig_out
            else:
                inp, out = shuffle_sudoku(orig_inp, orig_out)

            # Push puzzle (only single example)
            results["inputs"].append(inp)
            results["labels"].append(out)
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        # Push group
        results["group_indices"].append(puzzle_id)

    # To Numpy
    def _seq_to_numpy(seq):
        arr = np.concatenate(seq).reshape(len(seq), -1)

        assert np.all((arr >= 0) & (arr <= 9))
        return arr + 1

    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),

        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=81,
        vocab_size=10 + 1,  # PAD + "0" ... "9"

        pad_id=0,
        ignore_label_id=0,

        blank_identifier_id=0,
        num_puzzle_identifiers=1,

        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)

    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    # Save IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)


@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()
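
For orientation, the converter writes a flat token array plus two index arrays: `puzzle_indices` has one entry per puzzle plus a trailing sentinel, and `group_indices` delimits the puzzles belonging to one group (the original example followed by its `num_aug` augmentations). A minimal sketch of reading them back, assuming a hypothetical output directory `data/sudoku/train` produced by the converter above:

    import json
    import os

    import numpy as np

    # Hypothetical output directory from the converter above.
    save_dir = "data/sudoku/train"

    inputs = np.load(os.path.join(save_dir, "all__inputs.npy"))
    puzzle_indices = np.load(os.path.join(save_dir, "all__puzzle_indices.npy"))
    group_indices = np.load(os.path.join(save_dir, "all__group_indices.npy"))

    with open(os.path.join(save_dir, "dataset.json")) as f:
        metadata = json.load(f)

    # Each Sudoku puzzle holds exactly one example, so consecutive
    # entries of puzzle_indices differ by 1 here.
    assert np.all(np.diff(puzzle_indices) == 1)

    # Group g spans puzzles [group_indices[g], group_indices[g + 1]):
    # the original puzzle followed by its augmentations.
    g = 0
    first, last = group_indices[g], group_indices[g + 1]
    print(f"group {g} holds puzzles {first}..{last - 1}")
    print("first board (tokens are digit + 1, so 1 = blank):")
    print(inputs[puzzle_indices[first]].reshape(9, 9))
    print("total groups:", metadata["total_groups"])
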
dataset/common.py
ADDED
@@ -0,0 +1,51 @@
from typing import List, Optional

import pydantic
import numpy as np


# Global list mapping each dihedral transform id to its inverse.
# Index corresponds to the original tid, and the value is its inverse.
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]


class PuzzleDatasetMetadata(pydantic.BaseModel):
    pad_id: int
    ignore_label_id: Optional[int]
    blank_identifier_id: int

    vocab_size: int
    seq_len: int
    num_puzzle_identifiers: int

    total_groups: int
    mean_puzzle_examples: float

    sets: List[str]


def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """The 8 dihedral symmetries of a square: rotations, flips and mirrors."""

    if tid == 0:
        return arr  # identity
    elif tid == 1:
        return np.rot90(arr, k=1)
    elif tid == 2:
        return np.rot90(arr, k=2)
    elif tid == 3:
        return np.rot90(arr, k=3)
    elif tid == 4:
        return np.fliplr(arr)  # horizontal flip
    elif tid == 5:
        return np.flipud(arr)  # vertical flip
    elif tid == 6:
        return arr.T  # transpose (reflection along main diagonal)
    elif tid == 7:
        return np.fliplr(np.rot90(arr, k=1))  # anti-diagonal reflection
    else:
        return arr


def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
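
A quick check that `DIHEDRAL_INVERSE` really inverts every transform; a minimal sketch, not part of the release:

    import numpy as np

    from dataset.common import dihedral_transform, inverse_dihedral_transform

    rng = np.random.default_rng(0)
    arr = rng.integers(0, 10, size=(9, 9))

    # Applying a transform and then its inverse must give back the original array.
    for tid in range(8):
        assert np.array_equal(inverse_dihedral_transform(dihedral_transform(arr, tid), tid), arr)
    print("all 8 dihedral transforms invert correctly")
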
evaluate.py
ADDED
@@ -0,0 +1,68 @@
from typing import List
import yaml
import os

import torch
import torch.distributed as dist

import pydantic
from omegaconf import OmegaConf
from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader


class EvalConfig(pydantic.BaseModel):
    checkpoint: str

    save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"]


def launch():
    eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli()))  # type: ignore

    RANK = 0
    WORLD_SIZE = 1
    # Initialize distributed training if in a distributed environment (e.g. torchrun)
    if "LOCAL_RANK" in os.environ:
        # Initialize distributed, default device and dtype
        dist.init_process_group(backend="nccl")

        RANK = dist.get_rank()
        WORLD_SIZE = dist.get_world_size()

        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f:
        config = PretrainConfig(**yaml.safe_load(f))

    config.eval_save_outputs = eval_cfg.save_outputs
    config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint)

    # Dataloader
    train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
    eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)

    # Models
    train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
    # Try unwrapping torch.compile's "_orig_mod." prefix if a plain load fails
    try:
        train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True)
    except Exception:
        train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True)

    train_state.step = 0
    ckpt_filename = os.path.basename(eval_cfg.checkpoint)
    if ckpt_filename.startswith("step_"):
        train_state.step = int(ckpt_filename.removeprefix("step_"))

    # Evaluate
    print("Starting evaluation")

    train_state.model.eval()
    metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)

    if metrics is not None:
        print(metrics)


if __name__ == "__main__":
    launch()
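
Usage is a single OmegaConf-style override on the command line; the checkpoint path below is a placeholder:

    python evaluate.py checkpoint=checkpoints/<project>/<run>/step_<N>

For multi-GPU evaluation, launch the same command through torchrun (which sets LOCAL_RANK). The script reads `all_config.yaml` from the checkpoint's directory and writes the requested outputs next to it as `step_<N>_all_preds.<rank>`.
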
models/common.py
ADDED
@@ -0,0 +1,32 @@
import math

import torch
from torch import nn


def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct: the requested std is not the actual std of the initialized tensor.
    # This function is a PyTorch port of the JAX truncated normal init (the default init method in flax).
    # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
    # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199

    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2

            c = (2 * math.pi) ** -0.5
            pdf_u = c * math.exp(-0.5 * lower ** 2)
            pdf_l = c * math.exp(-0.5 * upper ** 2)
            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)

            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor
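
A small sanity check of the compensated std (a sketch, not part of the release): the empirical standard deviation of a large tensor initialized this way should match the requested value, while `nn.init.trunc_normal_` with the same ±2 cutoffs undershoots it.

    import torch
    from torch import nn

    from models.common import trunc_normal_init_

    t = torch.empty(1_000_000)

    trunc_normal_init_(t, std=1.0)
    print("corrected init std:", t.std().item())          # ~1.00

    nn.init.trunc_normal_(t, std=1.0, a=-2.0, b=2.0)
    print("nn.init.trunc_normal_ std:", t.std().item())   # ~0.88, smaller than requested
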
models/hrm/hrm_act_v1.py
ADDED
@@ -0,0 +1,283 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math

import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel

from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding


@dataclass
class HierarchicalReasoningModel_ACTV1InnerCarry:
    z_H: torch.Tensor
    z_L: torch.Tensor


@dataclass
class HierarchicalReasoningModel_ACTV1Carry:
    inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry

    steps: torch.Tensor
    halted: torch.Tensor

    current_data: Dict[str, torch.Tensor]


class HierarchicalReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"


class HierarchicalReasoningModel_ACTV1Block(nn.Module):
    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()

        self.self_attn = Attention(
            hidden_size=config.hidden_size,
            head_dim=config.hidden_size // config.num_heads,
            num_heads=config.num_heads,
            num_key_value_heads=config.num_heads,
            causal=False
        )
        self.mlp = SwiGLU(
            hidden_size=config.hidden_size,
            expansion=config.expansion,
        )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # Post Norm
        # Self Attention
        hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
        # Fully Connected
        hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
        return hidden_states


class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
    def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
        super().__init__()

        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        # Input injection (add)
        hidden_states = hidden_states + input_injection
        # Layers
        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states, **kwargs)

        return hidden_states


class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            raise NotImplementedError()

        # Reasoning Layers
        self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
        self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        with torch.no_grad():
            z_H, z_L = carry.z_H, carry.z_L

            for _H_step in range(self.config.H_cycles):
                for _L_step in range(self.config.L_cycles):
                    if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
                        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)

                if not (_H_step == self.config.H_cycles - 1):
                    z_H = self.H_level(z_H, z_L, **seq_info)

        assert not z_H.requires_grad and not z_L.requires_grad

        # 1-step grad
        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.H_level(z_H, z_L, **seq_info)

        # LM Outputs
        new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]

        # Q head
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)

        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
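
The forward pass above runs all but the final L/H updates under `torch.no_grad()` and backpropagates only through the last update of each level, so activation memory stays constant in the number of cycles. A minimal standalone sketch of the same pattern with toy modules (names here are illustrative stand-ins, not from the release):

    import torch
    from torch import nn

    f_L = nn.Linear(16, 16)   # stand-in for the L-level module
    f_H = nn.Linear(16, 16)   # stand-in for the H-level module

    x = torch.randn(4, 16)    # stand-in for input embeddings
    z_L = torch.zeros(4, 16)
    z_H = torch.zeros(4, 16)

    H_cycles, L_cycles = 2, 2

    # All iterations except the very last L and H updates run without grad...
    with torch.no_grad():
        for h in range(H_cycles):
            for l in range(L_cycles):
                if not (h == H_cycles - 1 and l == L_cycles - 1):
                    z_L = f_L(z_L + z_H + x)
            if h != H_cycles - 1:
                z_H = f_H(z_H + z_L)

    # ...then exactly one L and one H update carry gradients (1-step grad).
    z_L = f_L(z_L + z_H + x)
    z_H = f_H(z_H + z_L)

    z_H.sum().backward()   # backprop reaches only the final two updates
    print(f_L.weight.grad is not None, f_H.weight.grad is not None)  # True True
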

class HierarchicalReasoningModel_ACTV1(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
        self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]

        return HierarchicalReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected; it will be reset on the first pass, as all sequences start halted.

            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted

            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        # Update data, carry (replacing halted sequences with fresh ones)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # If training, and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):
                # Halt signal
                # NOTE: During evaluation, always use max steps; this guarantees the same halting step within a batch, for batching purposes
                halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)

                halted = halted & (new_steps >= min_halt_steps)

                # Compute target Q
                # NOTE: No replay buffer and no target network for computing the target Q-value.
                # As batch_size is large, there are many parallel environments.
                # Similar concept to PQN https://arxiv.org/abs/2407.04811
                next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]

                outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
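
The halting target above is plain bootstrapped Q-learning: re-run the inner model once more and take `sigmoid(max(next_halt, next_continue))`, or `sigmoid(next_halt)` when the step budget is exhausted and halting is the only option. A minimal numeric sketch of just that target computation (toy logits, not from the model):

    import torch

    # Toy next-step Q logits for a batch of 3 sequences.
    next_q_halt = torch.tensor([2.0, -1.0, 0.5])
    next_q_cont = torch.tensor([0.0,  3.0, 0.5])
    is_last_step = torch.tensor([False, False, True])

    # If the budget is exhausted, the only option is to halt; otherwise the
    # target is the value of the better action at the next step.
    target_q_continue = torch.sigmoid(
        torch.where(is_last_step, next_q_halt, torch.maximum(next_q_halt, next_q_cont))
    )
    print(target_q_continue)  # sigmoid of [2.0, 3.0, 0.5]
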
models/layers.py
ADDED
@@ -0,0 +1,150 @@
from typing import Tuple

import torch
from torch import nn
import torch.nn.functional as F

from models.common import trunc_normal_init_


CosSin = Tuple[torch.Tensor, torch.Tensor]


def _find_multiple(a, b):
    return (-(a // -b)) * b


def rotate_half(x: torch.Tensor):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [bs, num_heads, seq_len, head_dim]
    # cos, sin: [seq_len, head_dim]
    orig_dtype = q.dtype
    q = q.to(cos.dtype)
    k = k.to(cos.dtype)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)


class CastedLinear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool):
        super().__init__()
        # Truncated LeCun normal init
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
        )
        self.bias = None
        if bias:
            # Zero init bias
            self.bias = nn.Parameter(torch.zeros((out_features, )))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)


class CastedEmbedding(nn.Module):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 init_std: float,
                 cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Truncated LeCun normal init
        self.embedding_weight = nn.Parameter(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.embedding(input, self.embedding_weight.to(self.cast_to))


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings, base, device=None):
        super().__init__()

        # RoPE
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)

        # Differs from the paper's layout, but uses a different permutation to obtain the same computation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)

    def forward(self):
        return self.cos_cached, self.sin_cached


class Attention(nn.Module):
    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.causal = causal

        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape

        # hidden_states: [bs, seq_len, num_heads, head_dim]
        qkv = self.qkv_proj(hidden_states)

        # Split heads
        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim).transpose(-2, -3)
        query = qkv[:, :self.num_heads]
        key = qkv[:, self.num_heads: self.num_heads + self.num_key_value_heads]
        value = qkv[:, self.num_heads + self.num_key_value_heads:]

        # RoPE
        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # flash attn
        attn_output = F.scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)

        # attn_output: [batch_size, num_heads, seq_len, head_dim]
        attn_output = attn_output.transpose(-2, -3).view(batch_size, seq_len, self.output_size)  # type: ignore
        return self.o_proj(attn_output)


class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)


def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.square().mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)
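
Because `cos` and `sin` are built by concatenating `(freqs, freqs)`, coordinate `i` is rotated together with coordinate `i + dim/2` by the same angle, so applying the rotary embedding is an exact block rotation and preserves per-position vector norms. A quick check (a sketch, not part of the release):

    import torch

    from models.layers import RotaryEmbedding, apply_rotary_pos_emb

    rope = RotaryEmbedding(dim=8, max_position_embeddings=16, base=10000.0)
    cos, sin = rope()

    q = torch.randn(1, 2, 16, 8)  # [bs, num_heads, seq_len, head_dim]
    k = torch.randn(1, 2, 16, 8)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

    # Rotation preserves per-position norms.
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
    assert torch.allclose(k.norm(dim=-1), k_rot.norm(dim=-1), atol=1e-5)
    print("norms preserved")
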
models/losses.py
ADDED
@@ -0,0 +1,101 @@
from typing import Any, Tuple, Dict, Sequence, Optional

import torch
import torch.nn.functional as F
from torch import nn


IGNORE_LABEL_ID = -100


def s(x, epsilon=1e-30):
    return torch.where(
        x < 0,
        1 / (1 - x + epsilon),
        x + 1
    )


def log_stablemax(x, dim=-1):
    s_x = s(x)
    return torch.log(s_x / torch.sum(s_x, dim=dim, keepdim=True))


def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)

    valid_mask = labels != ignore_index
    transformed_labels = torch.where(valid_mask, labels, 0)
    prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)

    return -torch.where(valid_mask, prediction_logprobs, 0)


def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
    # Cast logits to f32
    # Flatten logits
    return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)


class ACTLossHead(nn.Module):
    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        self.loss_fn = globals()[loss_type]

    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        # Model logits
        # B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        # Correctness
        with torch.no_grad():
            mask = labels != IGNORE_LABEL_ID
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics (halted)
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),

                "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),

                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses
        # FIXME: Assuming the batch is always full
        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")

        metrics.update({
            "lm_loss": lm_loss.detach(),
            "q_halt_loss": q_halt_loss.detach(),
        })

        # Q continue (bootstrapping target loss)
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")

            metrics["q_continue_loss"] = q_continue_loss.detach()

        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
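
`s(x)` maps logits to strictly positive scores that grow only linearly for x >= 0 (and decay as 1/(1-x) for x < 0), so `log_stablemax` is a softmax-like normalizer without the exponential blow-up. Exponentiating its output must still give a valid probability distribution; a minimal check:

    import torch

    from models.losses import log_stablemax, s

    x = torch.tensor([[-5.0, 0.0, 3.0, 100.0]], dtype=torch.float64)

    probs = log_stablemax(x).exp()
    print(probs)  # weights proportional to s(x) = [1/6, 1, 4, 101]
    assert torch.allclose(probs.sum(dim=-1), torch.ones(1, dtype=torch.float64))

    # Unlike exp(), s() keeps large logits from overflowing:
    print(s(torch.tensor(1e8, dtype=torch.float64)))  # 1e8 + 1, not inf
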
models/sparse_embedding.py
ADDED
@@ -0,0 +1,132 @@
from typing import Union

import torch
from torch import nn
import torch.distributed as dist
from torch.optim.optimizer import Optimizer, ParamsT

from models.common import trunc_normal_init_


class CastedSparseEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Real Weights
        # Truncated LeCun normal init
        self.weights = nn.Buffer(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
        )

        # Local weights and IDs
        # Local embeddings, with gradient, not persistent
        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        # Local embedding IDs, not persistent
        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Test mode, no gradient
            return self.weights[inputs].to(self.cast_to)

        # Training mode, fill puzzle embedding from weights
        with torch.no_grad():
            self.local_weights.copy_(self.weights[inputs])
            self.local_ids.copy_(inputs)

        return self.local_weights.to(self.cast_to)


class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    def __init__(
        self,
        params: ParamsT,

        world_size: int,
        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=world_size
        )
        super().__init__(params, defaults)

    @torch.no_grad
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Find the sparse embedding weights
            local_weights_grad = None
            local_ids = None
            weights = None

            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False

            assert local_weights_grad is not None
            assert local_ids is not None
            assert weights is not None

            # Apply SignSGD
            # Adam ≈ SignSGD if gradient is very sparse
            _sparse_emb_signsgd_dist(
                local_weights_grad,
                local_ids,
                weights,

                lr=group["lr"],
                weight_decay=group["weight_decay"],
                world_size=group["world_size"]
            )


def _sparse_emb_signsgd_dist(
    local_weights_grad: torch.Tensor,
    local_ids: torch.Tensor,
    weights: torch.Tensor,

    lr: float,
    weight_decay: float,
    world_size: int
) -> None:
    N, D = local_weights_grad.shape

    # All-gather
    all_weights_grad = local_weights_grad
    all_ids = local_ids

    if world_size > 1:
        all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)

        dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
        dist.all_gather_into_tensor(all_ids, local_ids)

    # Unique
    grad_ids, inv = all_ids.unique(return_inverse=True)

    grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
    grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)

    # SignSGD with decoupled weight decay
    p = weights[grad_ids]

    p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)

    # Write updated slices back
    weights[grad_ids] = p
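
Because a puzzle id can appear several times in a (gathered) batch, the optimizer first sums gradients per unique id with `scatter_add_`, then applies a sign update with decoupled weight decay to just those rows. A single-process toy run (world_size=1) showing the same update; a sketch, not part of the release:

    import torch

    from models.sparse_embedding import _sparse_emb_signsgd_dist

    weights = torch.zeros(5, 4)                        # 5 embeddings of dim 4
    ids = torch.tensor([1, 3, 1], dtype=torch.int32)   # id 1 appears twice
    grads = torch.tensor([[ 1.0] * 4,
                          [-2.0] * 4,
                          [ 1.0] * 4])

    _sparse_emb_signsgd_dist(grads, ids, weights, lr=0.1, weight_decay=0.0, world_size=1)

    print(weights[1])  # summed grad is +2 -> sign +1 -> row moves by -lr = -0.1
    print(weights[3])  # grad -2 -> sign -1 -> row moves by +0.1
    print(weights[0])  # untouched rows stay zero
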
pretrain.py
ADDED
@@ -0,0 +1,454 @@
from typing import Optional, Any, Sequence, List
from dataclasses import dataclass
import os
import math
import yaml
import shutil

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader

import tqdm
import wandb
import coolname
import hydra
import pydantic
from omegaconf import DictConfig
from wandb.util import make_artifact_name_safe
from adam_atan2 import AdamATan2

from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
from utils.functions import load_model_class, get_model_source_path
from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed


class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')

    name: str


class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')

    name: str
    loss: LossConfig


class PretrainConfig(pydantic.BaseModel):
    # Config
    arch: ArchConfig
    # Data
    data_path: str

    # Hyperparams
    global_batch_size: int
    epochs: int

    lr: float
    lr_min_ratio: float
    lr_warmup_steps: int

    weight_decay: float
    beta1: float
    beta2: float

    # Puzzle embedding
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float

    # Names
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    checkpoint_path: Optional[str] = None

    # Extras
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    eval_save_outputs: List[str] = []


@dataclass
class TrainState:
    model: nn.Module
    optimizers: Sequence[torch.optim.Optimizer]
    optimizer_lrs: Sequence[float]
    carry: Any

    step: int
    total_steps: int


def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
    dataset = PuzzleDataset(PuzzleDatasetConfig(
        seed=config.seed,

        dataset_path=config.data_path,

        rank=rank,
        num_replicas=world_size,

        **kwargs
    ), split=split)
    dataloader = DataLoader(
        dataset,
        batch_size=None,

        num_workers=1,
        prefetch_factor=8,

        pin_memory=True,
        persistent_workers=True
    )
    return dataloader, dataset.metadata


def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
    model_cfg = dict(
        **config.arch.__pydantic_extra__,  # type: ignore

        batch_size=config.global_batch_size // world_size,

        vocab_size=train_metadata.vocab_size,
        seq_len=train_metadata.seq_len,
        num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
        causal=False  # Non-autoregressive
    )

    # Instantiate model with loss head
    model_cls = load_model_class(config.arch.name)
    loss_head_cls = load_model_class(config.arch.loss.name)

    with torch.device("cuda"):
        model: nn.Module = model_cls(model_cfg)
        model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__)  # type: ignore
        if "DISABLE_COMPILE" not in os.environ:
            model = torch.compile(model, dynamic=False, fullgraph=True)  # type: ignore

    # Broadcast parameters from rank 0
    if world_size > 1:
        with torch.no_grad():
            for param in list(model.parameters()) + list(model.buffers()):
                dist.broadcast(param, src=0)

    # Optimizers and lr
    optimizers = [
        CastedSparseEmbeddingSignSGD_Distributed(
            model.model.puzzle_emb.buffers(),  # type: ignore

            lr=0,  # Needs to be set by scheduler
            weight_decay=config.puzzle_emb_weight_decay,

            world_size=world_size
        ),
        AdamATan2(
            model.parameters(),

            lr=0,  # Needs to be set by scheduler
            weight_decay=config.weight_decay,
            betas=(config.beta1, config.beta2)
        )
    ]
    optimizer_lrs = [
        config.puzzle_emb_lr,
        config.lr
    ]

    return model, optimizers, optimizer_lrs


def cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
):
    if current_step < num_warmup_steps:
        return base_lr * float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))
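
The schedule is linear warmup followed by cosine decay to `min_ratio * base_lr`. A few sampled values (a sketch assuming `pretrain.py` is importable from the working directory; the numbers are illustrative only):

    from pretrain import cosine_schedule_with_warmup_lr_lambda

    for step in [0, 50, 100, 550, 1000]:
        lr = cosine_schedule_with_warmup_lr_lambda(
            step, base_lr=1e-4, num_warmup_steps=100, num_training_steps=1000, min_ratio=0.1
        )
        print(f"step {step:>4}: lr = {lr:.2e}")
    # step    0: lr = 0.00e+00   (warmup start)
    # step   50: lr = 5.00e-05   (halfway through warmup)
    # step  100: lr = 1.00e-04   (peak)
    # step  550: lr = 5.50e-05   (mid-decay)
    # step 1000: lr = 1.00e-05   (floor = min_ratio * base_lr)
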
171 |
+
|
172 |
+
|
173 |
+
def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
|
174 |
+
# Estimated total training steps
|
175 |
+
total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)
|
176 |
+
|
177 |
+
# Model
|
178 |
+
model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size)
|
179 |
+
|
180 |
+
return TrainState(
|
181 |
+
step=0,
|
182 |
+
total_steps=total_steps,
|
183 |
+
|
184 |
+
model=model,
|
185 |
+
optimizers=optimizers,
|
186 |
+
optimizer_lrs=optimizer_lrs,
|
187 |
+
carry=None
|
188 |
+
)
|
189 |
+
|
190 |
+
|
191 |
+
def save_train_state(config: PretrainConfig, train_state: TrainState):
|
192 |
+
# FIXME: Only saved model.
|
193 |
+
if config.checkpoint_path is None:
|
194 |
+
return
|
195 |
+
|
196 |
+
os.makedirs(config.checkpoint_path, exist_ok=True)
|
197 |
+
torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))
|
198 |
+
|
199 |
+
|
200 |
+
def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
|
201 |
+
return cosine_schedule_with_warmup_lr_lambda(
|
202 |
+
current_step=train_state.step,
|
203 |
+
base_lr=base_lr,
|
204 |
+
num_warmup_steps=round(config.lr_warmup_steps),
|
205 |
+
num_training_steps=train_state.total_steps,
|
206 |
+
min_ratio=config.lr_min_ratio
|
207 |
+
)
|
208 |
+
|
209 |
+
|
210 |
+
def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
|
211 |
+
train_state.step += 1
|
212 |
+
if train_state.step > train_state.total_steps: # At most train_total_steps
|
213 |
+
return
|
214 |
+
|
215 |
+
# To device
|
216 |
+
batch = {k: v.cuda() for k, v in batch.items()}
|
217 |
+
|
218 |
+
# Init carry if it is None
|
219 |
+
if train_state.carry is None:
|
220 |
+
with torch.device("cuda"):
|
221 |
+
train_state.carry = train_state.model.initial_carry(batch) # type: ignore
|
222 |
+
|
223 |
+
# Forward
|
224 |
+
train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])
|
225 |
+
|
226 |
+
((1 / global_batch_size) * loss).backward()
|
227 |
+
|
228 |
+
# Allreduce
|
229 |
+
if world_size > 1:
|
230 |
+
for param in train_state.model.parameters():
|
231 |
+
if param.grad is not None:
|
232 |
+
dist.all_reduce(param.grad)
|
233 |
+
|
234 |
+
# Apply optimizer
|
235 |
+
lr_this_step = None
|
236 |
+
for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
|
237 |
+
lr_this_step = compute_lr(base_lr, config, train_state)
|
238 |
+
|
239 |
+
for param_group in optim.param_groups:
|
240 |
+
param_group['lr'] = lr_this_step
|
241 |
+
|
242 |
+
optim.step()
|
243 |
+
optim.zero_grad()
|
244 |
+
|
245 |
+
# Reduce metrics
|
246 |
+
if len(metrics):
|
247 |
+
assert not any(v.requires_grad for v in metrics.values())
|
248 |
+
|
249 |
+
metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
|
250 |
+
# Reduce and reconstruct
|
251 |
+
metric_values = torch.stack([metrics[k] for k in metric_keys])
|
252 |
+
if world_size > 1:
|
253 |
+
dist.reduce(metric_values, dst=0)
|
254 |
+
|
255 |
+
if rank == 0:
|
256 |
+
metric_values = metric_values.cpu().numpy()
|
257 |
+
reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}
|
258 |
+
|
259 |
+
# Postprocess
|
260 |
+
count = max(reduced_metrics["count"], 1) # Avoid NaNs
|
261 |
+
reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}
|
262 |
+
|
263 |
+
reduced_metrics["train/lr"] = lr_this_step
|
264 |
+
return reduced_metrics
|
265 |
+
|
266 |
+
|
267 |
+
def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
|
268 |
+
with torch.inference_mode():
|
269 |
+
set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}
|
270 |
+
|
271 |
+
all_preds = {}
|
272 |
+
|
273 |
+
metric_keys = []
|
274 |
+
metric_values = None
|
275 |
+
metric_global_batch_size = [0 for _ in range(len(set_ids))]
|
276 |
+
|
277 |
+
carry = None
|
278 |
+
for set_name, batch, global_batch_size in eval_loader:
|
279 |
+
# To device
|
280 |
+
batch = {k: v.cuda() for k, v in batch.items()}
|
281 |
+
with torch.device("cuda"):
|
282 |
+
carry = train_state.model.initial_carry(batch) # type: ignore
|
283 |
+
|
284 |
+
# Forward
|
285 |
+
while True:
|
286 |
+
carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs)
|
287 |
+
|
288 |
+
if all_finish:
|
289 |
+
break
|
290 |
+
|
291 |
+
for collection in (batch, preds):
|
292 |
+
for k, v in collection.items():
|
293 |
+
if k in config.eval_save_outputs:
|
294 |
+
all_preds.setdefault(k, [])
|
295 |
+
all_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory
|
296 |
+
|
297 |
+
del carry, preds, batch, all_finish
|
298 |
+
|
299 |
+
# Aggregate
|
300 |
+
set_id = set_ids[set_name]
|
301 |
+
|
302 |
+
if metric_values is None:
|
303 |
+
metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
|
304 |
+
metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda")
|
305 |
+
|
306 |
+
metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
|
307 |
+
metric_global_batch_size[set_id] += global_batch_size
|
308 |
+
|
309 |
+
if len(all_preds) and config.checkpoint_path is not None:
|
310 |
+
all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}
|
311 |
+
|
312 |
+
os.makedirs(config.checkpoint_path, exist_ok=True)
|
313 |
+
torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}"))
|
314 |
+
|
315 |
+
# Logging
|
316 |
+
# Reduce to rank 0
|
317 |
+
if metric_values is not None:
|
318 |
+
if world_size > 1:
|
319 |
+
dist.reduce(metric_values, dst=0)
|
320 |
+
|
321 |
+
if rank == 0:
|
322 |
+
reduced_metrics = metric_values.cpu().numpy()
|
323 |
+
reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)}
|
324 |
+
for set_id, set_name in enumerate(set_ids)}
|
325 |
+
|
326 |
+
# Postprocess
|
327 |
+
for set_name, metrics in reduced_metrics.items():
|
328 |
+
count = metrics.pop("count")
|
329 |
+
reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()}
|
330 |
+
|
331 |
+
return reduced_metrics
|
332 |
+
|
333 |
+
|
334 |
+
def save_code_and_config(config: PretrainConfig):
|
335 |
+
if config.checkpoint_path is None or wandb.run is None:
|
336 |
+
return
|
337 |
+
|
338 |
+
os.makedirs(config.checkpoint_path, exist_ok=True)
|
339 |
+
|
340 |
+
# Copy code
|
341 |
+
code_list = [
|
342 |
+
get_model_source_path(config.arch.name),
|
343 |
+
get_model_source_path(config.arch.loss.name)
|
344 |
+
]
|
345 |
+
for code_file in code_list:
|
346 |
+
if code_file is not None:
|
347 |
+
code_name = os.path.basename(code_file)
|
348 |
+
|
349 |
+
shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))
|
350 |
+
|
351 |
+
# Dump config as yaml
|
352 |
+
config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
|
353 |
+
with open(config_file, "wt") as f:
|
354 |
+
yaml.dump(config.model_dump(), f)
|
355 |
+
|
356 |
+
# Log code
|
357 |
+
wandb.run.log_code(config.checkpoint_path)
|
358 |
+
|
359 |
+
|
360 |
+
def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
|
361 |
+
    objects = [None]
    if rank == 0:
        config = PretrainConfig(**hydra_config)  # type: ignore

        # Naming
        if config.project_name is None:
            config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch"
        if config.run_name is None:
            config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
        if config.checkpoint_path is None:
            config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)

        objects = [config]

    if world_size > 1:
        dist.broadcast_object_list(objects, src=0)

    return objects[0]  # type: ignore


@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
def launch(hydra_config: DictConfig):
    RANK = 0
    WORLD_SIZE = 1

    # Initialize distributed training if in distributed environment (e.g. torchrun)
    if "LOCAL_RANK" in os.environ:
        # Initialize distributed, default device and dtype
        dist.init_process_group(backend="nccl")

        RANK = dist.get_rank()
        WORLD_SIZE = dist.get_world_size()

        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Load sync'ed config
    config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)

    # Seed RNGs to ensure consistency
    torch.random.manual_seed(config.seed + RANK)

    # Dataset
    train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
    total_iters = config.epochs // train_epochs_per_iter

    assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."

    train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
    eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)

    # Train state
    train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)

    # Progress bar and logger
    progress_bar = None
    if RANK == 0:
        progress_bar = tqdm.tqdm(total=train_state.total_steps)

        wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True))  # type: ignore
        wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
        save_code_and_config(config)

    # Training Loop
    for _iter_id in range(total_iters):
        print(f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")

        ############ Train Iter
        train_state.model.train()
        for set_name, batch, global_batch_size in train_loader:
            metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)

            if RANK == 0 and metrics is not None:
                wandb.log(metrics, step=train_state.step)
                progress_bar.update(train_state.step - progress_bar.n)  # type: ignore

        ############ Evaluation
        train_state.model.eval()
        metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)

        if RANK == 0 and metrics is not None:
            wandb.log(metrics, step=train_state.step)

        ############ Checkpointing
        if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
            save_train_state(config, train_state)

    # finalize
    if dist.is_initialized():
        dist.destroy_process_group()
    wandb.finish()


if __name__ == "__main__":
    launch()
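The `load_synced_config` helper above uses `torch.distributed.broadcast_object_list`, which requires every rank to pre-allocate a list of the same length before the call; that is why non-zero ranks start from `objects = [None]`. A minimal sketch of the same pattern (not part of the file; the function name is illustrative, and an initialized process group is assumed):

import torch.distributed as dist

def sync_from_rank0(value, rank: int, world_size: int):
    # Every rank allocates a one-slot list; only rank 0 fills in the payload.
    objects = [value if rank == 0 else None]
    if world_size > 1:
        # In-place: after this call, objects[0] holds rank 0's value on every rank.
        dist.broadcast_object_list(objects, src=0)
    return objects[0]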
puzzle_dataset.py
ADDED
@@ -0,0 +1,199 @@
import os
import json

import numpy as np
import pydantic

import torch
from torch.utils.data import IterableDataset, get_worker_info

from models.losses import IGNORE_LABEL_ID
from dataset.common import PuzzleDatasetMetadata


def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
    # Pack examples into a full batch
    batch = []
    batch_puzzle_indices = []
    current_size = 0

    while (start_index < group_order.size) and (current_size < global_batch_size):
        # Pick a group and a puzzle from that group
        group_id = group_order[start_index]
        puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
        start_index += 1

        # Get range of the puzzle
        puzzle_start = puzzle_indices[puzzle_id]
        puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)

        append_size = min(puzzle_size, global_batch_size - current_size)

        # Put into batch; sample with the seeded generator (not the global
        # np.random state) so batch composition stays reproducible.
        batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
        batch.append(puzzle_start + rng.choice(puzzle_size, append_size, replace=False))

        current_size += append_size

    return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)

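
# Illustration of the index layout consumed by _sample_batch (hypothetical
# numbers, added for clarity):
#   group_indices  = [0, 2, 3]     => group 0 spans puzzles 0..1; group 1 spans puzzle 2
#   puzzle_indices = [0, 4, 6, 9]  => puzzle 0 spans examples 0..3; puzzle 1, examples 4..5; puzzle 2, examples 6..8
# Each call walks the shuffled group order, samples one puzzle per visited
# group, and packs examples until global_batch_size is reached.
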
class PuzzleDatasetConfig(pydantic.BaseModel):
    seed: int
    dataset_path: str
    global_batch_size: int
    test_set_mode: bool

    epochs_per_iter: int  # Batch X epochs in an iteration to reduce overhead.

    rank: int
    num_replicas: int


class PuzzleDataset(IterableDataset):
    def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
        super().__init__()
        self.config = config
        self.split = split
        self.metadata = self._load_metadata()

        # Checks
        assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be a multiple of the number of replicas ({self.config.num_replicas})."
        self.local_batch_size = self.config.global_batch_size // self.config.num_replicas

        # State
        self._data = None
        self._iters = 0

    def _load_metadata(self) -> PuzzleDatasetMetadata:
        with open(os.path.join(self.config.dataset_path, self.split, "dataset.json"), "r") as f:
            return PuzzleDatasetMetadata(**json.load(f))

    def _lazy_load_dataset(self):
        if self._data is not None:
            return

        field_mmap_modes = {
            "inputs": "r",
            "labels": "r",

            # Keep indices in memory
            "puzzle_identifiers": None,
            "puzzle_indices": None,
            "group_indices": None
        }

        # Load data
        self._data = {}
        for set_name in self.metadata.sets:
            # Load subset
            self._data[set_name] = {
                field_name: np.load(os.path.join(self.config.dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
                for field_name, mmap_mode in field_mmap_modes.items()
            }

    def _collate_batch(self, batch):
        # Convert dtype
        batch = {k: v.astype(np.int32) for k, v in batch.items()}

        # Convert ignore label IDs
        if self.metadata.ignore_label_id is not None:
            batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID

        # Pad
        if batch["puzzle_identifiers"].size < self.local_batch_size:
            pad_size = self.local_batch_size - batch["puzzle_identifiers"].size

            pad_values = {
                "inputs": self.metadata.pad_id,
                "labels": IGNORE_LABEL_ID,

                "puzzle_identifiers": self.metadata.blank_identifier_id
            }
            batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}

        # To tensor
        return {k: torch.from_numpy(v) for k, v in batch.items()}

    def _iter_test(self):
        for set_name, dataset in self._data.items():  # type: ignore
            total_examples = len(dataset["inputs"])

            # Load examples one by one
            start_index = 0
            while start_index < total_examples:
                # Compute indices
                end_index = min(total_examples, start_index + self.config.global_batch_size)

                local_start = start_index + self.config.rank * self.local_batch_size
                local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)

                # Get batch of examples, and also puzzle IDs
                puzzle_indices = []
                puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
                for i in range(local_start, local_end):
                    while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
                        puzzle_index += 1

                    puzzle_indices.append(puzzle_index)

                batch = self._collate_batch({
                    "inputs": dataset["inputs"][local_start: local_end],
                    "labels": dataset["labels"][local_start: local_end],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
                })

                yield set_name, batch, end_index - start_index

                # Advance to next batch
                start_index += self.config.global_batch_size

    def _iter_train(self):
        for set_name, dataset in self._data.items():  # type: ignore
            # Increase epoch count
            self._iters += 1

            # Randomly shuffle groups
            rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))

            group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])
            start_index = 0

            while start_index < group_order.size:
                start_index, batch_indices, batch_puzzle_indices = _sample_batch(
                    rng,
                    group_order=group_order,
                    puzzle_indices=dataset["puzzle_indices"],
                    group_indices=dataset["group_indices"],
                    start_index=start_index,
                    global_batch_size=self.config.global_batch_size,
                )

                # Select current rank and collate
                global_effective_batch_size = batch_puzzle_indices.size  # Global effective batch size, excluding pads

                # Drop last batch
                if global_effective_batch_size < self.config.global_batch_size:
                    break

                batch_indices = batch_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
                batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
                batch = self._collate_batch({
                    "inputs": dataset["inputs"][batch_indices],
                    "labels": dataset["labels"][batch_indices],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
                })

                yield set_name, batch, global_effective_batch_size

    def __iter__(self):
        worker_info = get_worker_info()
        assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported."

        self._lazy_load_dataset()

        # Iterate using specified mode
        if self.config.test_set_mode:
            yield from self._iter_test()
        else:
            yield from self._iter_train()
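A usage sketch (not part of the file; the dataset path and batch size are illustrative): because `PuzzleDataset` yields fully collated batches, it would typically be wrapped in a `DataLoader` with `batch_size=None` to disable automatic batching, respecting the single-worker assertion in `__iter__`.

from torch.utils.data import DataLoader

config = PuzzleDatasetConfig(
    seed=0,
    dataset_path="data/arc-example",  # illustrative path
    global_batch_size=96,
    test_set_mode=False,
    epochs_per_iter=1,
    rank=0,
    num_replicas=1,
)

dataset = PuzzleDataset(config, split="train")
loader = DataLoader(dataset, batch_size=None, num_workers=1)

for set_name, batch, effective_batch_size in loader:
    print(set_name, batch["inputs"].shape, effective_batch_size)
    break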
puzzle_visualizer.html
ADDED
@@ -0,0 +1,426 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
<title>ARC‐Converted Dataset Visualizer (Upload Local Folder)</title>
<style>
body {
  font-family: sans-serif;
  margin: 16px;
}
.selector-area {
  margin-bottom: 1rem;
}
.grid-canvas {
  margin: 4px;
  border: 1px solid #ccc;
}
.example-container {
  display: inline-block;
  margin: 0 16px 16px 0;
  vertical-align: top;
}
.puzzle-display {
  margin-top: 1rem;
}
.puzzle-id {
  font-weight: bold;
  margin-bottom: 0.5rem;
}
#groupList, #puzzleList {
  margin: 1rem 0;
}
.group-item, .puzzle-item {
  cursor: pointer;
  margin: 4px 8px 4px 0;
  padding: 2px 6px;
  border: 1px solid #aaa;
  display: inline-block;
}
.group-item:hover, .puzzle-item:hover {
  background: #eef;
}
</style>
</head>
<body>
<h1>ARC‐Converted Dataset Visualizer (Local Directory)</h1>

<div class="selector-area">
  <!-- 1) Directory input with webkitdirectory, mozdirectory -->
  <label>Upload ARC Folder:</label>
  <input type="file" id="folderInput"
         webkitdirectory mozdirectory multiple
         onchange="onFolderSelected(event)" />
  <br><br>

  <!-- 2) We'll enable set/subset selection after user chooses a folder and data is validated -->
  <label>Set:</label>
  <select id="setSelect" disabled>
    <option value="train">train</option>
    <option value="test">test</option>
  </select>

  <label> Subset:</label>
  <select id="subsetSelect" disabled>
    <option value="all">all</option>
  </select>

  <button id="loadBtn" disabled>Load</button>
</div>

<div>
  <div id="groupList"></div>
  <div id="puzzleList"></div>
  <div class="puzzle-display" id="puzzleView"></div>
</div>

<!--
  3) Use local 'assets/npyjs.js' from your project folder instead of a CDN.
     Make sure 'assets/npyjs.js' is the unbundled or UMD version that doesn't
     contain "import" statements.
-->
<script src="assets/npyjs.js"></script>

<script>
/***************************************************************************
 * Global Maps & Variables
 ***************************************************************************/

// Map from "train/all__inputs.npy" => File, etc.
let filesByPath = {};

// Once loaded, we store typed arrays for the chosen set/subset
let inputsArr, labelsArr;
let puzzleIndicesArr, groupIndicesArr, puzzleIdentifiersArr;
let identifiersJson;

// The shape of inputs is [N_examples, seqLen], so we discover seqLen & gridSize
let seqLen = 0;
let gridSize = 0;


/***************************************************************************
 * 1) Handle folder selection: read all files, find identifiers.json,
 *    remove topmost folder from each file path, validate.
 ***************************************************************************/
function onFolderSelected(event) {
  filesByPath = {};
  const fileList = event.target.files;
  if (!fileList || fileList.length === 0) {
    alert("No files selected!");
    return;
  }

  // We'll gather all webkitRelativePaths
  const paths = [];
  for (let i = 0; i < fileList.length; i++) {
    // Typically "arc-aug-10/train/all__inputs.npy", etc.
    const file = fileList[i];
    const relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
    paths.push(relPath);
  }

  // 1. Check if we have "identifiers.json" somewhere.
  const idPath = paths.find(p => p.endsWith("identifiers.json"));
  if (!idPath) {
    alert("Error: No 'identifiers.json' found in the uploaded folder.");
    return;
  }

  // 2. Derive the top-level directory from that file's path
  //    e.g. if idPath = "arc-aug-10/identifiers.json", topDir = "arc-aug-10"
  //    If there's no slash, topDir = "" => do nothing
  let topDir = "";
  const lastSlash = idPath.lastIndexOf("/");
  if (lastSlash >= 0) {
    topDir = idPath.substring(0, lastSlash);
  }

  // 3. Rebuild filesByPath with the top folder removed.
  //    For example, if topDir = "arc-aug-10", then "arc-aug-10/train/all__inputs.npy"
  //    becomes "train/all__inputs.npy"
  for (let i = 0; i < fileList.length; i++) {
    const file = fileList[i];
    let relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
    // If relPath starts with "arc-aug-10/", remove that prefix
    if (topDir && relPath.startsWith(topDir + "/")) {
      relPath = relPath.substring(topDir.length + 1);
    }
    filesByPath[relPath] = file;
  }

  // Enable set/subset selection and "Load"
  document.getElementById("setSelect").disabled = false;
  document.getElementById("subsetSelect").disabled = false;
  document.getElementById("loadBtn").disabled = false;
}

// When user clicks "Load," parse the .npy for the chosen set/subset
document.getElementById("loadBtn").addEventListener("click", async () => {
  document.getElementById("groupList").innerHTML = "";
  document.getElementById("puzzleList").innerHTML = "";
  document.getElementById("puzzleView").innerHTML = "";

  const setName = document.getElementById("setSelect").value;       // e.g. "train"
  const subsetName = document.getElementById("subsetSelect").value; // e.g. "all"

  try {
    await loadDataset(setName, subsetName);
    buildGroupList(); // show groups
  } catch (err) {
    console.error(err);
    alert("Error while loading dataset: " + err);
  }
});


/***************************************************************************
 * 2) Load .npy from local files using Npyjs + FileReader (ArrayBuffer)
 ***************************************************************************/
async function loadDataset(setName, subsetName) {
  const prefix = `${setName}/${subsetName}__`;
  // e.g. "train/all__inputs.npy"
  const inputsPath      = prefix + "inputs.npy";
  const labelsPath      = prefix + "labels.npy";
  const pIdxPath        = prefix + "puzzle_indices.npy";
  const gIdxPath        = prefix + "group_indices.npy";
  const pIdsPath        = prefix + "puzzle_identifiers.npy";
  const identifiersPath = "identifiers.json";

  // Check existence
  const needed = [inputsPath, labelsPath, pIdxPath, gIdxPath, pIdsPath, identifiersPath];
  for (const f of needed) {
    if (!filesByPath[f]) {
      throw new Error(`Missing file: ${f}`);
    }
  }

  // parseNpy => read from File -> ArrayBuffer -> Npyjs => typed array
  const inputsNpy        = await parseNpy(filesByPath[inputsPath]);
  const labelsNpy        = await parseNpy(filesByPath[labelsPath]);
  const puzzleIndicesNpy = await parseNpy(filesByPath[pIdxPath]);
  const groupIndicesNpy  = await parseNpy(filesByPath[gIdxPath]);
  const puzzleIdsNpy     = await parseNpy(filesByPath[pIdsPath]);

  inputsArr            = inputsNpy.data;
  labelsArr            = labelsNpy.data;
  puzzleIndicesArr     = puzzleIndicesNpy.data;
  groupIndicesArr      = groupIndicesNpy.data;
  puzzleIdentifiersArr = puzzleIdsNpy.data;

  // shape e.g. [N_examples, seqLen]
  seqLen = inputsNpy.shape[1];
  gridSize = Math.sqrt(seqLen);

  // read JSON
  identifiersJson = await readJsonFile(filesByPath[identifiersPath]);
}

/***************************************************************************
 * parseNpy => read a File as ArrayBuffer, parse with npyjs
 ***************************************************************************/
function parseNpy(file) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = async () => {
      try {
        const arrayBuffer = reader.result;
        const npy = new npyjs();
        resolve(await npy.parse(arrayBuffer));
      } catch (err) {
        reject(err);
      }
    };
    reader.onerror = err => reject(err);
    reader.readAsArrayBuffer(file);
  });
}

/***************************************************************************
 * readJsonFile => read a local JSON file into object
 ***************************************************************************/
function readJsonFile(file) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = () => {
      try {
        const obj = JSON.parse(reader.result);
        resolve(obj);
      } catch (err) {
        reject(err);
      }
    };
    reader.onerror = (err) => reject(err);
    reader.readAsText(file);
  });
}

/***************************************************************************
 * 3) Build group list in UI
 ***************************************************************************/
function buildGroupList() {
  document.getElementById("groupList").innerHTML = "<h3>Groups</h3>";
  const groupListDiv = document.getElementById("groupList");

  const nGroups = groupIndicesArr.length - 1;
  for (let g = 0; g < nGroups; g++) {
    const div = document.createElement("span");
    div.className = "group-item";
    div.textContent = `Group ${g}`;
    div.onclick = () => onSelectGroup(g);
    groupListDiv.appendChild(div);
  }
}

/***************************************************************************
 * onSelectGroup => show puzzles in that group
 ***************************************************************************/
function onSelectGroup(groupIndex) {
  document.getElementById("puzzleList").innerHTML = "";
  document.getElementById("puzzleView").innerHTML = "";

  const puzzleListDiv = document.getElementById("puzzleList");
  puzzleListDiv.innerHTML = `<h4>Puzzles in Group ${groupIndex}</h4>`;

  const firstPuzzle = groupIndicesArr[groupIndex];
  const lastPuzzle = groupIndicesArr[groupIndex + 1];

  for (let p = firstPuzzle; p < lastPuzzle; p++) {
    const puzzleIntId = puzzleIdentifiersArr[p];
    const puzzleStrId = (puzzleIntId < identifiersJson.length)
      ? identifiersJson[puzzleIntId]
      : "<unknown>";

    const div = document.createElement("span");
    div.className = "puzzle-item";
    div.textContent = `Puzzle #${p} [ID=${puzzleIntId}: ${puzzleStrId}]`;
    div.onclick = () => onSelectPuzzle(p);
    puzzleListDiv.appendChild(div);
  }
}

/***************************************************************************
 * onSelectPuzzle => show each example
 ***************************************************************************/
function onSelectPuzzle(puzzleIndex) {
  const puzzleView = document.getElementById("puzzleView");
  puzzleView.innerHTML = "";

  // puzzle ID
  const puzzleIntId = puzzleIdentifiersArr[puzzleIndex];
  const puzzleStrId = (puzzleIntId < identifiersJson.length)
    ? identifiersJson[puzzleIntId]
    : "<unknown>";

  const titleDiv = document.createElement("div");
  titleDiv.className = "puzzle-id";
  titleDiv.textContent = `Puzzle #${puzzleIndex} — ID: ${puzzleStrId}`;
  puzzleView.appendChild(titleDiv);

  // Examples are [puzzleIndicesArr[p], puzzleIndicesArr[p+1])
  const firstExample = puzzleIndicesArr[puzzleIndex];
  const lastExample = puzzleIndicesArr[puzzleIndex + 1];

  for (let e = firstExample; e < lastExample; e++) {
    const inputSeq = slice1D(inputsArr, e*seqLen, (e+1)*seqLen);
    const outputSeq = slice1D(labelsArr, e*seqLen, (e+1)*seqLen);

    const inputGrid = decodeGrid(inputSeq);
    const outputGrid = decodeGrid(outputSeq);

    const exDiv = document.createElement("div");
    exDiv.className = "example-container";
    exDiv.appendChild(document.createTextNode(`Example ${e}`));
    exDiv.appendChild(document.createElement("br"));

    exDiv.appendChild(renderGrid(inputGrid));
    exDiv.appendChild(renderGrid(outputGrid));

    puzzleView.appendChild(exDiv);
  }
}

/***************************************************************************
 * slice1D => typed array slicing
 ***************************************************************************/
function slice1D(arr, start, end) {
  const result = new Uint32Array(end - start);
  for (let i = start; i < end; i++) {
    result[i - start] = Number(arr[i]);
  }
  return result;
}

/***************************************************************************
 * decodeGrid => turn the flattened seq of length=gridSize^2 into 2D
 ***************************************************************************/
function decodeGrid(seq) {
  const grid = [];
  let idx = 0;
  for (let r = 0; r < gridSize; r++) {
    const row = [];
    for (let c = 0; c < gridSize; c++) {
      row.push(seq[idx]);
      idx++;
    }
    grid.push(row);
  }
  return grid;
}

/***************************************************************************
 * renderGrid => draws a 2D grid to <canvas>
 ***************************************************************************/
function renderGrid(grid2d) {
  const rows = grid2d.length;
  const cols = grid2d[0].length;
  const scale = 10;

  const canvas = document.createElement("canvas");
  canvas.width = cols * scale;
  canvas.height = rows * scale;
  canvas.className = "grid-canvas";
  const ctx = canvas.getContext("2d");

  for (let r = 0; r < rows; r++) {
    for (let c = 0; c < cols; c++) {
      const val = grid2d[r][c];
      ctx.fillStyle = indexToColor(val);
      ctx.fillRect(c * scale, r * scale, scale, scale);
    }
  }
  return canvas;
}

/***************************************************************************
 * indexToColor => color palette:
 *   0 => pad => white
 *   1 => eos => light gray
 *   2..11 => original color(0..9)
 ***************************************************************************/
function indexToColor(value) {
  if (value === 0) return "#FFFFFF"; // pad => white
  if (value === 1) return "#DDDDDD"; // eos => light gray

  // shift by 2 => original color in [0..9]
  const colorIdx = value - 2;
  const palette = [
    "#000000", // color0 => black
    "#FF0000", // color1 => red
    "#00FF00", // color2 => green
    "#0000FF", // color3 => blue
    "#FFFF00", // color4 => yellow
    "#FFA500", // color5 => orange
    "#800080", // color6 => purple
    "#00FFFF", // color7 => cyan
    "#FFC0CB", // color8 => pink
    "#808080"  // color9 => gray
  ];
  if (colorIdx >= 0 && colorIdx < palette.length) {
    return palette[colorIdx];
  }
  return "#FFFFFF"; // fallback
}
</script>
</body>
</html>
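The visualizer's decoding step (square grids of side sqrt(seqLen), with token 0 = pad, 1 = eos, and 2..11 mapping to ARC colors 0..9) can also be checked from Python. A small sketch, not part of the file, assuming a converted dataset on disk (the path is illustrative):

import numpy as np

inputs = np.load("data/arc-example/train/all__inputs.npy", mmap_mode="r")

seq_len = inputs.shape[1]
grid_size = int(np.sqrt(seq_len))  # same sqrt(seqLen) assumption as the visualizer

# First example, reshaped to a 2-D grid of token IDs
# (0 = pad, 1 = eos, 2..11 = ARC colors 0..9, per indexToColor above)
grid = np.asarray(inputs[0]).reshape(grid_size, grid_size)
print(grid)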
requirements.txt
ADDED
@@ -0,0 +1,11 @@
torch
adam-atan2
einops
tqdm
coolname
pydantic
argdantic
wandb
omegaconf
hydra-core
huggingface_hub
utils/functions.py
ADDED
@@ -0,0 +1,19 @@
import importlib
import inspect


def load_model_class(identifier: str, prefix: str = "models."):
    module_path, class_name = identifier.split('@')

    # Import the module
    module = importlib.import_module(prefix + module_path)
    cls = getattr(module, class_name)

    return cls


def get_model_source_path(identifier: str, prefix: str = "models."):
    module_path, class_name = identifier.split('@')

    module = importlib.import_module(prefix + module_path)
    return inspect.getsourcefile(module)
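A usage sketch (not part of the file): identifiers follow the `module@ClassName` convention also seen in `config.arch.name.split('@')` in pretrain.py. The identifier below is illustrative; the actual value comes from the `arch.name` field of the Hydra config.

# Resolves "models.hrm.hrm_act_v1" and fetches the named class from it.
cls = load_model_class("hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1")
model = cls(model_config_dict)  # construct with whatever config the class expects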