FagerholmEmil committed on
Commit
ebbc219
·
1 Parent(s): cab1ff5

Update README metadata for Neuro project

Browse files
Files changed (3) hide show
  1. README.md +5 -4
  2. app.py +191 -0
  3. requirements.txt +20 -0
README.md CHANGED
@@ -1,13 +1,14 @@
1
  ---
2
- title: Mechinterp
3
- emoji: 🐨
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.13.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Neuro
3
+ emoji: 🏃
4
+ colorFrom: gray
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.13.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: Visualize neurons in GPT-2
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py
import os
import gradio as gr
from transformer_lens import HookedTransformer
from transformer_lens.utils import to_numpy
import torch

model_name = "gpt2-small"

# Pick the inference device: CUDA when a GPU is available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = HookedTransformer.from_pretrained(model_name, device=device)

# Report which device we ended up on (GPU name only when CUDA is active).
if device == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")
22
def get_neuron_acts(text, layer, neuron_index):
    """Run the model on `text` and capture one neuron's post-MLP activations.

    Returns a numpy array with the activation of neuron `neuron_index`
    in layer `layer` at every token position of the (single) prompt.
    """
    captured = {}

    def grab(act, hook):
        # act is indexed [batch, pos, neuron]; keep batch 0, all positions.
        captured["activation"] = act[0, :, neuron_index]

    hook_name = f"blocks.{layer}.mlp.hook_post"
    model.run_with_hooks(text, fwd_hooks=[(hook_name, grab)])
    return to_numpy(captured["activation"])
+
32
def calculate_color(val, max_val, min_val):
    """Map an activation value to a red-tinted CSS color string.

    `val` is normalized into [0, 1] over the [min_val, max_val] range
    (clamped at both ends), where 0 maps to near-white rgb(240, 240, 240)
    and 1 maps to strong red rgb(240, 0.0, 0.0).

    Fixes: the original divided by `max_val` instead of the range
    `(max_val - min_val)`, skewing colors whenever min_val != 0, and did
    not clamp, so out-of-range activations produced invalid rgb values.
    """
    span = max_val - min_val
    # Guard the degenerate range (max == min) to avoid division by zero.
    normalized_val = (val - min_val) / span if span else 0.0
    # Clamp so activations outside the chosen range still yield valid CSS.
    normalized_val = min(max(normalized_val, 0.0), 1.0)
    shade = 240 * (1 - normalized_val)
    return f"rgb(240, {shade}, {shade})"
35
+
36
# CSS injected ahead of the token spans: draws a thin grey border around
# each token so adjacent tokens are visually separated in the HTML output.
style_string = """<style>
span.token {
    border: 1px solid rgb(123, 123, 123)
}
</style>"""
41
+
42
def basic_neuron_vis(text, layer, neuron_index, max_val=None, min_val=None):
    """Render an HTML heat-map of one neuron's activations over `text`.

    Each token is wrapped in a colored <span>; the color scale spans
    [min_val, max_val], defaulting to the observed activation range.
    Returns a plain-string prompt when layer or neuron is unselected.
    """
    if layer is None:
        return "Please select a Layer"
    if neuron_index is None:
        return "Please select a Neuron"

    acts = get_neuron_acts(text, layer, neuron_index)
    act_max, act_min = acts.max(), acts.min()
    # Fall back to the observed range when no explicit bounds were given.
    max_val = act_max if max_val is None else max_val
    min_val = act_min if min_val is None else min_val

    pieces = [
        style_string,
        f"<h4>Layer: <b>{layer}</b>. Neuron Index: <b>{neuron_index}</b></h4>",
        f"<h4>Max Range: <b>{max_val:.4f}</b>. Min Range: <b>{min_val:.4f}</b></h4>",
    ]
    # Show the raw activation extremes when the user overrode the range.
    if act_max != max_val or act_min != min_val:
        pieces.append(
            f"<h4>Custom Range Set. Max Act: <b>{act_max:.4f}</b>. Min Act: <b>{act_min:.4f}</b></h4>"
        )

    for tok, act in zip(model.to_str_tokens(text), acts):
        color = calculate_color(act, max_val, min_val)
        pieces.append(
            f"<span class='token' style='background-color:{color}' >{tok}</span>"
        )

    return "".join(pieces)
71
+
72
# Default prompt for the textbox. It is deliberately long and pattern-dense
# (digit sequences, quotes, symbols, simple math) so that many different
# neurons fire visibly on the initial render.
default_text = """The sun rises red, sets golden.
Digits flow: 101, 202, 303—cyclic repetition.
"Echo," whispers the shadow, "repeat, revise, reverse."
Blue squares align in a grid: 4x4, then shift to 5x5.
α -> β -> γ: transformations loop endlessly.

If X=12, and Y=34, then Z? Calculate: Z = X² + Y².
Strings dance: "abc", "cab", "bca"—rotational symmetry.
Prime steps skip by: 2, 3, 5, 7, 11…
Noise: "X...Y...Z..." patterns emerge. Silence.

Fractals form: 1, 1.5, 2.25, 3.375… exponential growth.
Colors swirl: red fades to orange, orange to yellow.
Binary murmurs: 1010, 1100, 1110, 1001—bit-flips.
Triangles: 1, 3, 6, 10, 15… T(n) = n(n+1)/2.
"Reverse," whispers the wind, "invert and repeat."

Nested loops:
1 -> (2, 4) -> (8, 16) -> (32, 64)
2 -> (3, 9) -> (27, 81) -> (243, 729).

The moon glows silver, wanes to shadow.
Patterns persist: 11, 22, 33—harmonic echoes.
"Reshape," calls the river, "reflect, refract, renew."
Yellow hexagons tessellate, shifting into orange octagons.
1/3 -> 1/9 -> 1/27: recursive reduction spirals infinitely.

Chords hum: A minor, C major, G7 resolve softly.
The Fibonacci sequence: 1, 1, 2, 3, 5, 8… emerges.
Golden spirals curl inwards, outwards, endlessly.
Hexagons tessellate: one becomes six, becomes many.

In the forest, whispers:
A -> B -> C -> (AB), (BC), (CA).
Axiom: F. Rule: F -> F+F-F-F+F.

The tide ebbs:
12 -> 9 -> 6 -> 3 -> 12.
Modulo cycles: 17 -> 3, 6, 12, 1…

Strange attractors pull:
(0.1, 0.2), (0.3, 0.6), (0.5, 1.0).
Chaos stabilizes into order, and order dissolves.

Infinite regress:
"Who am I?" asked the mirror.
"You are the question," it answered.

Numbers sing:
e ≈ 2.7182818...
π ≈ 3.14159...
i² = -1: imaginary worlds collide.

Recursive paradox:
The serpent bites its tail, and time folds.

Symmetry hums:
Palindromes—"radar", "level", "madam"—appear and fade.
Blue fades to white, white dissolves to black.
Sequences echo: 1, 10, 100, 1000…
"Cycle," whispers the clock, "count forward, reverse.""" # Shortened for example
# Initial neuron selection and fixed color-scale range for the first render.
default_layer = 1
default_neuron_index = 1
default_max_val = 4.0
default_min_val = 0.0
137
+
138
def get_random_active_neuron(text, threshold=2.5):
    """Randomly search for a neuron that fires strongly on `text`.

    Samples (layer, neuron) pairs uniformly and returns the first pair
    whose maximum activation exceeds `threshold`; falls back to (0, 0)
    after 100 unsuccessful attempts.
    """
    import random

    for _ in range(100):
        candidate_layer = random.randint(0, model.cfg.n_layers - 1)
        candidate_neuron = random.randint(0, model.cfg.d_mlp - 1)
        activations = get_neuron_acts(text, candidate_layer, candidate_neuron)
        if activations.max() > threshold:
            return candidate_layer, candidate_neuron

    # Nothing sufficiently active was found — return a fixed default neuron.
    return 0, 0
152
+
153
# Build the Gradio UI: inputs (text, layer, neuron, color range) on the left,
# the rendered HTML heat-map on the right. Any input change re-renders it.
with gr.Blocks() as demo:
    gr.HTML(value=f"Neuroscope for {model_name}")
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Text", value=default_text)
            # precision=0 forces integer input for the layer/neuron indices.
            layer = gr.Number(label="Layer", value=default_layer, precision=0)
            neuron_index = gr.Number(
                label="Neuron Index", value=default_neuron_index, precision=0
            )
            random_btn = gr.Button("Find Random Active Neuron")
            max_val = gr.Number(label="Max Value", value=default_max_val)
            min_val = gr.Number(label="Min Value", value=default_min_val)
            # Components that trigger a re-render, in the positional argument
            # order expected by basic_neuron_vis.
            inputs = [text, layer, neuron_index, max_val, min_val]
        with gr.Column():
            # Pre-render the default selection so the page is not empty on load.
            out = gr.HTML(
                label="Neuron Acts",
                value=basic_neuron_vis(
                    default_text,
                    default_layer,
                    default_neuron_index,
                    default_max_val,
                    default_min_val,
                ),
            )

    def random_neuron_callback(text):
        # Button handler: pick a random neuron that fires strongly on `text`
        # and push it into the layer/neuron number boxes.
        layer_num, neuron_num = get_random_active_neuron(text)
        return layer_num, neuron_num

    random_btn.click(
        random_neuron_callback,
        inputs=[text],
        outputs=[layer, neuron_index]
    )

    # Re-render the visualization whenever any input component changes.
    for inp in inputs:
        inp.change(basic_neuron_vis, inputs, out)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate>=0.23.0
2
+ beartype>=0.14.1
3
+ better-abc>=0.0.3
4
+ datasets>=2.7.1
5
+ einops>=0.6.0
6
+ fancy-einsum>=0.0.3
7
+ jaxtyping>=0.2.11
8
+ numpy>=1.24 # Assuming Python 3.9+
9
+ pandas>=1.1.5
10
+ rich>=12.6.0
11
+ sentencepiece
12
+ torch>=1.10
13
+ tqdm>=4.64.1
14
+ transformers>=4.41,<4.42 # Assuming Python 3.9+
15
+ typing-extensions
16
+ wandb>=0.13.5
17
+ typeguard>=4.2
18
+ pytest>=7.2.0 # From dev dependencies, needed by transformer_lens
19
+ gradio==5.13.1 # Match the Space's sdk_version declared in README.md
20
+ transformer_lens==1.13.0