# app.py
import os
import random

import gradio as gr
import torch
from transformer_lens import HookedTransformer
from transformer_lens.utils import to_numpy

model_name = "gpt2-small"

# Determine device based on CUDA availability
device = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer.from_pretrained(
    model_name,
    device=device,
)

# Only print GPU info if using CUDA
if device == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")
def get_neuron_acts(text, layer, neuron_index):
    """Run the model on `text` and return one activation per token for the chosen MLP neuron."""
    cache = {}

    def caching_hook(act, hook):
        # act has shape [batch, seq_len, d_mlp]; keep the chosen neuron for the single prompt.
        cache["activation"] = act[0, :, neuron_index]

    model.run_with_hooks(
        text, fwd_hooks=[(f"blocks.{layer}.mlp.hook_post", caching_hook)]
    )
    return to_numpy(cache["activation"])
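# Illustrative sanity check (assumes the model loaded above; layer/neuron values are arbitrary):
# "blocks.{layer}.mlp.hook_post" activations have shape [batch, seq_len, d_mlp], so indexing
# [0, :, neuron_index] gives one value per token of the prompt, e.g.:
#   acts = get_neuron_acts("Hello world", layer=0, neuron_index=5)
#   assert acts.shape[0] == len(model.to_str_tokens("Hello world"))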
def calculate_color(val, max_val, min_val):
    # Normalize val into [0, 1] over the display range, then interpolate from
    # near-white (low activation) to red (high activation).
    normalized_val = (val - min_val) / (max_val - min_val)
    normalized_val = min(max(normalized_val, 0.0), 1.0)  # clamp so out-of-range values can't produce invalid colors
    return f"rgb(240, {240 * (1 - normalized_val)}, {240 * (1 - normalized_val)})"
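# Illustrative values (with the defaults min_val=0.0, max_val=4.0): an activation of 0 maps to
# "rgb(240, 240.0, 240.0)" (near-white) and an activation of 4 maps to "rgb(240, 0.0, 0.0)" (red).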
style_string = """<style>
span.token {
    border: 1px solid rgb(123, 123, 123)
}
</style>"""
def basic_neuron_vis(text, layer, neuron_index, max_val=None, min_val=None):
    """Return an HTML string shading each token of `text` by the neuron's activation.

    If max_val/min_val are given, they override the activation range used for coloring.
    """
    if layer is None:
        return "Please select a Layer"
    if neuron_index is None:
        return "Please select a Neuron"
    acts = get_neuron_acts(text, layer, neuron_index)
    act_max = acts.max()
    act_min = acts.min()
    if max_val is None:
        max_val = act_max
    if min_val is None:
        min_val = act_min
    htmls = [style_string]
    htmls.append(f"<h4>Layer: <b>{layer}</b>. Neuron Index: <b>{neuron_index}</b></h4>")
    htmls.append(f"<h4>Max Range: <b>{max_val:.4f}</b>. Min Range: <b>{min_val:.4f}</b></h4>")
    if act_max != max_val or act_min != min_val:
        htmls.append(
            f"<h4>Custom Range Set. Max Act: <b>{act_max:.4f}</b>. Min Act: <b>{act_min:.4f}</b></h4>"
        )
    str_tokens = model.to_str_tokens(text)
    for tok, act in zip(str_tokens, acts):
        htmls.append(
            f"<span class='token' style='background-color:{calculate_color(act, max_val, min_val)}'>{tok}</span>"
        )
    return "".join(htmls)
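# Illustrative usage: the returned value is plain HTML (the style block, headings, and one shaded
# <span> per token), so it can be rendered directly by a gr.HTML component, as done below.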
default_text = """The sun rises red, sets golden.
Digits flow: 101, 202, 303—cyclic repetition.
"Echo," whispers the shadow, "repeat, revise, reverse."
Blue squares align in a grid: 4x4, then shift to 5x5.
α -> β -> γ: transformations loop endlessly.
If X=12, and Y=34, then Z? Calculate: Z = X² + Y².
Strings dance: "abc", "cab", "bca"—rotational symmetry.
Prime steps skip by: 2, 3, 5, 7, 11…
Noise: "X...Y...Z..." patterns emerge. Silence.
Fractals form: 1, 1.5, 2.25, 3.375… exponential growth.
Colors swirl: red fades to orange, orange to yellow.
Binary murmurs: 1010, 1100, 1110, 1001—bit-flips.
Triangles: 1, 3, 6, 10, 15… T(n) = n(n+1)/2.
"Reverse," whispers the wind, "invert and repeat."
Nested loops:
1 -> (2, 4) -> (8, 16) -> (32, 64)
2 -> (3, 9) -> (27, 81) -> (243, 729).
The moon glows silver, wanes to shadow.
Patterns persist: 11, 22, 33—harmonic echoes.
"Reshape," calls the river, "reflect, refract, renew."
Yellow hexagons tessellate, shifting into orange octagons.
1/3 -> 1/9 -> 1/27: recursive reduction spirals infinitely.
Chords hum: A minor, C major, G7 resolve softly.
The Fibonacci sequence: 1, 1, 2, 3, 5, 8… emerges.
Golden spirals curl inwards, outwards, endlessly.
Hexagons tessellate: one becomes six, becomes many.
In the forest, whispers:
A -> B -> C -> (AB), (BC), (CA).
Axiom: F. Rule: F -> F+F-F-F+F.
The tide ebbs:
12 -> 9 -> 6 -> 3 -> 12.
Modulo cycles: 17 -> 3, 6, 12, 1…
Strange attractors pull:
(0.1, 0.2), (0.3, 0.6), (0.5, 1.0).
Chaos stabilizes into order, and order dissolves.
Infinite regress:
"Who am I?" asked the mirror.
"You are the question," it answered.
Numbers sing:
e ≈ 2.7182818...
π ≈ 3.14159...
i² = -1: imaginary worlds collide.
Recursive paradox:
The serpent bites its tail, and time folds.
Symmetry hums:
Palindromes—"radar", "level", "madam"—appear and fade.
Blue fades to white, white dissolves to black.
Sequences echo: 1, 10, 100, 1000…
"Cycle," whispers the clock, "count forward, reverse.""" # Shortened for example
default_layer = 1
default_neuron_index = 1
default_max_val = 4.0
default_min_val = 0.0
def get_random_active_neuron(text, threshold=2.5):
    """Try random (layer, neuron) pairs until one's max activation on `text` exceeds `threshold`."""
    max_attempts = 100
    for _ in range(max_attempts):
        layer = random.randint(0, model.cfg.n_layers - 1)
        neuron = random.randint(0, model.cfg.d_mlp - 1)
        acts = get_neuron_acts(text, layer, neuron)
        if acts.max() > threshold:
            return layer, neuron
    # If no sufficiently active neuron is found, fall back to default values
    return 0, 0
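# Illustrative usage: get_random_active_neuron(default_text) samples up to 100 random
# (layer, neuron) pairs from gpt2-small's 12 layers and 3072 MLP neurons per layer, returning
# the first pair whose max activation on the text exceeds 2.5, else (0, 0).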
with gr.Blocks() as demo:
    gr.HTML(value=f"Neuroscope for {model_name}")
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Text", value=default_text)
            layer = gr.Number(label="Layer", value=default_layer, precision=0)
            neuron_index = gr.Number(
                label="Neuron Index", value=default_neuron_index, precision=0
            )
            random_btn = gr.Button("Find Random Active Neuron")
            max_val = gr.Number(label="Max Value", value=default_max_val)
            min_val = gr.Number(label="Min Value", value=default_min_val)
            inputs = [text, layer, neuron_index, max_val, min_val]
        with gr.Column():
            out = gr.HTML(
                label="Neuron Acts",
                value=basic_neuron_vis(
                    default_text,
                    default_layer,
                    default_neuron_index,
                    default_max_val,
                    default_min_val,
                ),
            )
    def random_neuron_callback(text):
        layer_num, neuron_num = get_random_active_neuron(text)
        return layer_num, neuron_num

    random_btn.click(
        random_neuron_callback,
        inputs=[text],
        outputs=[layer, neuron_index],
    )

    # Re-render the visualization whenever any input changes
    for inp in inputs:
        inp.change(basic_neuron_vis, inputs, out)

demo.launch()
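# Minimal local-run sketch (assumes an environment with gradio, torch, and transformer_lens
# installed, e.g. via `pip install gradio torch transformer_lens`):
#   python app.py
# On a Gradio Space, app.py is executed and demo.launch() starts the hosted interface.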