neuro / app.py
FagerholmEmil
No changes
511e2ba
raw
history blame
3.14 kB
# app.py
import html
import os

import gradio as gr
from transformer_lens import HookedTransformer
from transformer_lens.utils import to_numpy
# Name of the pretrained model to inspect; it is also shown in the page header.
model_name = "gpt2-small"
# Loaded once, at import time, and shared by every function below that calls
# `model`. NOTE(review): from_pretrained presumably fetches weights on first
# run, so expect a slow cold start.
model = HookedTransformer.from_pretrained(model_name)
def get_neuron_acts(text, layer, neuron_index):
    """Return one MLP neuron's post-activation value at every token of ``text``.

    Runs the module-level ``model`` on ``text`` with a forward hook on
    ``blocks.{layer}.mlp.hook_post`` and keeps the slice for ``neuron_index``.

    Args:
        text: Input string; tokenisation is done by ``model`` itself.
        layer: Index of the transformer block whose MLP output is hooked.
        neuron_index: Index into the last axis of that hook's activation.

    Returns:
        A NumPy array with one activation per token position.
    """
    # Numeric UI inputs may arrive as floats (e.g. 9.0); the hook name and the
    # tensor index both need plain ints.
    layer = int(layer)
    neuron_index = int(neuron_index)
    cache = {}

    def caching_hook(act, hook):
        # Index 0 of the first axis (presumably the batch) and one neuron on the
        # last axis; assumes act is (batch, pos, d_mlp) — TODO confirm.
        cache["activation"] = act[0, :, neuron_index]

    # Only the hooked activation is needed, so skip computing the logits.
    model.run_with_hooks(
        text,
        return_type=None,
        fwd_hooks=[(f"blocks.{layer}.mlp.hook_post", caching_hook)],
    )
    return to_numpy(cache["activation"])
def calculate_color(val, max_val, min_val):
    """Map an activation to a white-to-red CSS colour.

    ``val`` is rescaled linearly so that ``min_val`` gives the lightest shade
    ``rgb(240, 240, 240)`` and ``max_val`` gives red ``rgb(240, 0, 0)``.

    Args:
        val: Activation value to colour.
        max_val: Activation mapped to the most saturated red.
        min_val: Activation mapped to the lightest shade.

    Returns:
        A CSS ``rgb(...)`` colour string.
    """
    span = max_val - min_val
    # A zero (or inverted) range cannot be normalised; fall back to the
    # lightest shade instead of dividing by zero.
    if span <= 0:
        normalized_val = 0.0
    else:
        normalized_val = (val - min_val) / span
    # Activations outside a user-supplied [min_val, max_val] range are clamped
    # so the green/blue channels stay within 0..240.
    normalized_val = min(max(normalized_val, 0.0), 1.0)
    channel = 240 * (1 - normalized_val)
    return f"rgb(240, {channel}, {channel})"
# CSS prepended to the rendered output: gives every token span a thin grey border.
style_string = "\n".join(
    [
        "<style>",
        "span.token {",
        "border: 1px solid rgb(123, 123, 123)",
        "}",
        "</style>",
    ]
)
def basic_neuron_vis(text, layer, neuron_index, max_val=None, min_val=None):
    """Render one neuron's per-token activations on ``text`` as an HTML string.

    Each token becomes a bordered ``<span>`` whose background colour comes from
    ``calculate_color``; header lines report the layer, neuron and colour range.

    Args:
        text: Input text to run through ``model``.
        layer: Transformer block index, or ``None`` if not selected yet.
        neuron_index: Neuron index within that block's MLP, or ``None``.
        max_val: Activation mapped to the strongest colour; defaults to the
            observed maximum activation.
        min_val: Activation mapped to the lightest colour; defaults to the
            observed minimum activation.

    Returns:
        The rendered HTML, or a short prompt string when ``layer`` or
        ``neuron_index`` is missing.
    """
    if layer is None:
        return "Please select a Layer"
    if neuron_index is None:
        return "Please select a Neuron"
    acts = get_neuron_acts(text, layer, neuron_index)
    act_max = acts.max()
    act_min = acts.min()
    # An empty range field falls back to the observed activation extremes.
    if max_val is None:
        max_val = act_max
    if min_val is None:
        min_val = act_min
    htmls = [style_string]
    htmls.append(f"<h4>Layer: <b>{layer}</b>. Neuron Index: <b>{neuron_index}</b></h4>")
    htmls.append(f"<h4>Max Range: <b>{max_val:.4f}</b>. Min Range: <b>{min_val:.4f}</b></h4>")
    # Tell the user when the colour scale differs from the actual activations.
    if act_max != max_val or act_min != min_val:
        htmls.append(
            f"<h4>Custom Range Set. Max Act: <b>{act_max:.4f}</b>. Min Act: <b>{act_min:.4f}</b></h4>"
        )
    str_tokens = model.to_str_tokens(text)
    for tok, act in zip(str_tokens, acts):
        color = calculate_color(act, max_val, min_val)
        # Tokens come from user-supplied text (and include special tokens such
        # as "<|endoftext|>"); escape them so they render literally instead of
        # being parsed as markup.
        htmls.append(
            f"<span class='token' style='background-color:{color}' >{html.escape(tok)}</span>"
        )
    return "".join(htmls)
# Example prompt rendered when the app first loads.
default_text = (
    "The sun rises red, sets golden.\n"
    "Digits flow: 101, 202, 303—cyclic repetition."
)  # Shortened for example
# Neuron shown on first load.
default_layer, default_neuron_index = 9, 652
# Initial values for the Max Value / Min Value colour-range inputs.
default_max_val, default_min_val = 4.0, 0.0
# Build the Gradio UI: input controls in the left column, rendered activations
# in the right column. Component creation order inside the `with` blocks
# defines the page layout.
with gr.Blocks() as demo:
    gr.HTML(value=f"Neuroscope for {model_name}")
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Text", value=default_text)
            # precision=0 presumably makes Gradio deliver whole numbers, which
            # the hook name and neuron index need — confirm against Gradio docs.
            layer = gr.Number(label="Layer", value=default_layer, precision=0)
            neuron_index = gr.Number(
                label="Neuron Index", value=default_neuron_index, precision=0
            )
            max_val = gr.Number(label="Max Value", value=default_max_val)
            min_val = gr.Number(label="Min Value", value=default_min_val)
            # Listed in the same order as basic_neuron_vis's parameters
            # (text, layer, neuron_index, max_val, min_val).
            inputs = [text, layer, neuron_index, max_val, min_val]
        with gr.Column():
            out = gr.HTML(
                label="Neuron Acts",
                # Render the default neuron up front so the output is not empty
                # before the user touches any control.
                value=basic_neuron_vis(
                    default_text,
                    default_layer,
                    default_neuron_index,
                    default_max_val,
                    default_min_val,
                ),
            )
    # Re-render whenever any single control changes, always passing all of
    # the controls' current values.
    for inp in inputs:
        inp.change(basic_neuron_vis, inputs, out)
# Launched unconditionally (there is no __main__ guard), so importing this
# module also starts the app.
demo.launch()