Tarive committed
Commit f80b41f · verified · 1 Parent(s): a18dd4d

Update README.md

Files changed (1)
  1. README.md +9 -191
README.md CHANGED
@@ -1,191 +1,9 @@
- # Hierarchical Reasoning Model
-
- ![](./assets/hrm.png)
-
- Reasoning, the process of devising and executing complex goal-oriented action sequences, remains a critical challenge in AI.
- Current large language models (LLMs) primarily employ Chain-of-Thought (CoT) techniques, which suffer from brittle task decomposition, extensive data requirements, and high latency. Inspired by the hierarchical and multi-timescale processing in the human brain, we propose the Hierarchical Reasoning Model (HRM), a novel recurrent architecture that attains significant computational depth while maintaining both training stability and efficiency.
- HRM executes sequential reasoning tasks in a single forward pass without explicit supervision of the intermediate process, through two interdependent recurrent modules: a high-level module responsible for slow, abstract planning, and a low-level module handling rapid, detailed computations. With only 27 million parameters, HRM achieves exceptional performance on complex reasoning tasks using only 1000 training samples. The model operates without pre-training or CoT data, yet achieves nearly perfect performance on challenging tasks including complex Sudoku puzzles and optimal path finding in large mazes.
- Furthermore, HRM outperforms much larger models with significantly longer context windows on the Abstraction and Reasoning Corpus (ARC), a key benchmark for measuring artificial general intelligence capabilities.
- These results underscore HRM's potential as a transformative advancement toward universal computation and general-purpose reasoning systems.
-
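- The interplay between the two modules can be pictured as nested recurrences: the low-level module runs several fast steps for every single step of the high-level module. Below is a minimal sketch of that loop (illustrative only; the GRU cells, sizes, and update rule are assumptions, not the repository's actual architecture):
-
- ```python
- import torch
- import torch.nn as nn
-
- class TwoTimescaleSketch(nn.Module):
-     # Illustrative only: a slow "planner" cell driving a fast "worker" cell.
-     def __init__(self, dim: int = 128, low_steps: int = 4):
-         super().__init__()
-         self.low = nn.GRUCell(dim, dim)   # rapid, detailed computation
-         self.high = nn.GRUCell(dim, dim)  # slow, abstract planning
-         self.low_steps = low_steps
-
-     def forward(self, x: torch.Tensor, cycles: int = 2) -> torch.Tensor:
-         z_high = torch.zeros_like(x)
-         z_low = torch.zeros_like(x)
-         for _ in range(cycles):              # one slow update per cycle...
-             for _ in range(self.low_steps):  # ...after several fast updates
-                 z_low = self.low(x + z_high, z_low)
-             z_high = self.high(z_low, z_high)
-         return z_high
-
- out = TwoTimescaleSketch()(torch.randn(8, 128))
- print(out.shape)  # torch.Size([8, 128])
- ```
-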
- ## Quick Start Guide 🚀
-
- ### Prerequisites ⚙️
-
- Ensure PyTorch and CUDA are installed; the repo requires its CUDA extensions to be built. If they are not yet installed, run the following commands:
-
- ```bash
- # Install CUDA 12.6
- CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run
-
- wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL
- sudo sh cuda_installer.run --silent --toolkit --override
-
- export CUDA_HOME=/usr/local/cuda-12.6
-
- # Install PyTorch with CUDA 12.6
- PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu126
-
- pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL
-
- # Additional packages for building extensions
- pip3 install packaging ninja wheel setuptools setuptools-scm
- ```
-
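- Before building any extensions, it can help to confirm that the toolchain is visible to PyTorch. A quick sanity check (not part of the repository):
-
- ```python
- import os
- import torch
-
- # All of these should look right before compiling CUDA extensions.
- print("torch:", torch.__version__)
- print("CUDA available:", torch.cuda.is_available())
- print("CUDA_HOME:", os.environ.get("CUDA_HOME", "<not set>"))
- if torch.cuda.is_available():
-     print("device:", torch.cuda.get_device_name(0))
- ```
-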
- Then install FlashAttention. For Hopper GPUs, install FlashAttention 3:
-
- ```bash
- git clone git@github.com:Dao-AILab/flash-attention.git
- cd flash-attention/hopper
- python setup.py install
- ```
-
- For Ampere or earlier GPUs, install FlashAttention 2:
-
- ```bash
- pip3 install flash-attn
- ```
-
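- If you are unsure which generation your GPU is, its CUDA compute capability is a quick guide: Hopper reports 9.x, Ampere 8.x. A small check (independent of this repo):
-
- ```python
- import torch
-
- assert torch.cuda.is_available(), "no CUDA device found"
- major, minor = torch.cuda.get_device_capability(0)  # e.g. (9, 0) on H100
- print(f"compute capability: {major}.{minor}")
- print("install FlashAttention", 3 if major >= 9 else 2)
- ```
-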
- ## Install Python Dependencies 🐍
-
- ```bash
- pip install -r requirements.txt
- ```
-
- ## W&B Integration 📈
-
- This project uses [Weights & Biases](https://wandb.ai/) for experiment tracking and metric visualization. Ensure you're logged in:
-
- ```bash
- wandb login
- ```
-
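- On headless machines you can authenticate non-interactively instead: `wandb.login()` picks up the `WANDB_API_KEY` environment variable when it is set. A minimal sketch:
-
- ```python
- import os
- import wandb
-
- # wandb.login() reads WANDB_API_KEY from the environment, avoiding the prompt.
- assert os.environ.get("WANDB_API_KEY"), "export WANDB_API_KEY=<your key> first"
- wandb.login()
- ```
-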
- ## Run Experiments
-
- ### Quick Demo: Sudoku Solver 💻🗲
-
- Train a master-level Sudoku AI capable of solving extremely difficult puzzles on a modern laptop GPU. 🧩
-
- ```bash
- # Download and build Sudoku dataset
- python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000
-
- # Start training (single GPU, smaller batch size)
- OMP_NUM_THREADS=8 python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 global_batch_size=384 lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0
- ```
-
- *Runtime:* ~10 hours on an RTX 4070 laptop GPU
-
- ## Trained Checkpoints 🚧
-
- - [ARC-AGI-2](https://huggingface.co/sapientinc/HRM-checkpoint-ARC-2)
- - [Sudoku 9x9 Extreme (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-sudoku-extreme)
- - [Maze 30x30 Hard (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-maze-30x30-hard)
-
- To use the checkpoints, see the Evaluation section below.
-
- ## Full-scale Experiments 🔵
-
- The experiments below assume an 8-GPU setup.
-
- ### Dataset Preparation
-
- ```bash
- # Initialize submodules
- git submodule update --init --recursive
-
- # ARC-1
- python dataset/build_arc_dataset.py # ARC official + ConceptARC, 960 examples
- # ARC-2
- python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000 # ARC-2 official, 1120 examples
-
- # Sudoku-Extreme
- python dataset/build_sudoku_dataset.py # Full version
- python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # 1000 examples
-
- # Maze
- python dataset/build_maze_dataset.py # 1000 examples
- ```
-
- ### Dataset Visualization
-
- Explore the puzzles visually:
-
- * Open `puzzle_visualizer.html` in your browser.
- * Upload the generated dataset folder located in `data/...`.
-
- ## Launch Experiments
-
- ### Small-sample (1K)
-
- ARC-1:
-
- ```bash
- OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py
- ```
-
- *Runtime:* ~24 hours
-
- ARC-2:
-
- ```bash
- OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/arc-2-aug-1000
- ```
-
- *Runtime:* ~24 hours (a checkpoint after 8 hours is often sufficient)
-
- Sudoku Extreme (1k):
-
- ```bash
- OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
- ```
-
- *Runtime:* ~10 minutes
-
- Maze 30x30 Hard (1k):
-
- ```bash
- OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/maze-30x30-hard-1k epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
- ```
-
- *Runtime:* ~1 hour
-
- ### Full Sudoku-Hard
-
- ```bash
- OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-hard-full epochs=100 eval_interval=10 lr_min_ratio=0.1 global_batch_size=2304 lr=3e-4 puzzle_emb_lr=3e-4 weight_decay=0.1 puzzle_emb_weight_decay=0.1 arch.loss.loss_type=softmax_cross_entropy arch.L_cycles=8 arch.halt_max_steps=8 arch.pos_encodings=learned
- ```
-
- *Runtime:* ~2 hours
-
- ## Evaluation
-
- Evaluate your trained models:
-
- * Check `eval/exact_accuracy` in W&B.
- * For ARC-AGI, follow these additional steps:
-
- ```bash
- OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 evaluate.py checkpoint=<CHECKPOINT_PATH>
- ```
-
- * Then use the provided `arc_eval.ipynb` notebook to finalize and inspect your results.
-
- ## Notes
-
- - Small-sample learning typically exhibits accuracy variance of around ±2 points.
- - For the Sudoku-Extreme 1,000-example dataset, late-stage overfitting may cause numerical instability during training and Q-learning. It is advisable to use early stopping once training accuracy approaches 100% (see the sketch below).
-
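- A minimal early-stopping check (illustrative; the threshold, patience, and accuracy trace are assumptions, not values from `pretrain.py`):
-
- ```python
- def should_stop(history, threshold=0.99, patience=3):
-     """True once training accuracy has stayed at or above `threshold`
-     for `patience` consecutive evaluations, guarding against the
-     late-stage overfitting described above."""
-     return len(history) >= patience and all(a >= threshold for a in history[-patience:])
-
- # Example: accuracy readings from periodic evals (made-up numbers).
- trace = [0.62, 0.81, 0.93, 0.991, 0.994, 0.996]
- print(should_stop(trace))  # True: stop training here
- ```
-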
- ## Citation 📜
-
- ```bibtex
- @misc{wang2025hierarchicalreasoningmodel,
-   title={Hierarchical Reasoning Model},
-   author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
-   year={2025},
-   eprint={2506.21734},
-   archivePrefix={arXiv},
-   primaryClass={cs.AI},
-   url={https://arxiv.org/abs/2506.21734},
- }
- ```
 
+ ---
+ title: HRM Abstract Finetuning
+ emoji: 🚀
+ colorFrom: blue
+ sdk: docker
+ secrets:
+ - WANDB_API_KEY
+ - GROQ_API_KEY
+ ---