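"""Gradio front end for LLM-Microscope.

The app forwards analysis requests to a separate visualization server and
displays the heatmap it returns. The server address is read from the SERVER
environment variable, e.g. (host:port value is illustrative):

    SERVER=localhost:8000 python app.py
"""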
import io
import os
from typing import Optional, Tuple

import gradio as gr
import requests
from PIL import Image


class Client:
    """Thin HTTP client for the visualization server's /predict endpoint."""

    def __init__(self, server_url: str):
        self.server_url = server_url

    def send_request(
        self, task_name: str, model_name: str, text: str, normalization_type: str
    ) -> Tuple[Optional[Image.Image], str]:
        """Request a visualization; return (image, status). Image is None on failure."""
        try:
            response = requests.post(
                self.server_url,
                json={
                    "task_name": task_name,
                    "model_name": model_name,
                    "text": text,
                    "normalization_type": normalization_type,
                },
                timeout=60,
            )
        except requests.RequestException as exc:
            return None, f"Error: could not reach server ({exc})"
        if response.status_code != 200:
            return None, f"Error: server returned status {response.status_code}"
        # The server sends the rendered heatmap as hex-encoded image bytes.
        img_data = bytes.fromhex(response.json()["image"])
        return Image.open(io.BytesIO(img_data)), "OK"

client = Client(f"http://{os.environ['SERVER']}/predict")
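
# The server behind /predict is assumed to accept the JSON payload built in
# send_request() and to respond with {"image": "<hex-encoded image bytes>"}.
# A minimal sketch of such an endpoint (FastAPI is an assumption here, and
# render_heatmap is a hypothetical helper returning PNG bytes):
#
#     from fastapi import FastAPI
#     server = FastAPI()
#
#     @server.post("/predict")
#     def predict(payload: dict):
#         png = render_heatmap(**payload)  # hypothetical renderer
#         return {"image": png.hex()}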

def get_layerwise_nonlinearity(
    task_name: str, model_name: str, text: str, normalization_type: str
) -> Tuple[Optional[Image.Image], str]:
    return client.send_request(task_name, model_name, text, normalization_type)

def update_output(
    task_name: str, model_name: str, text: str, normalization_type: str
) -> Optional[Image.Image]:
    img, _ = get_layerwise_nonlinearity(task_name, model_name, text, normalization_type)
    return img

def set_default(task_name: str) -> str:
    """Pick the default normalization for a task: token-wise where it applies."""
    if task_name in [
        "Layer wise non-linearity",
        "Next-token prediction from intermediate representations",
        "Tokenwise loss without i-th layer",
    ]:
        return "token-wise"
    return "global"

def check_normalization(task_name: str, normalization_name: str) -> str:
    """Contextualization measurement only supports global normalization."""
    if task_name == "Contextualization measurement" and normalization_name == "token-wise":
        return "global"
    return normalization_name

def update_description(task_name: str) -> str:
    descriptions = {
        "Layer wise non-linearity": "Non-linearity per layer: shows how complex each layer's transformation is. Red = more nonlinear.",
        "Next-token prediction from intermediate representations": "Layerwise token prediction: when does the model start guessing correctly?",
        "Contextualization measurement": "Context stored in each token: how well can the model reconstruct the previous context?",
        "Layerwise predictions (logit lens)": "Logit lens: what does each layer believe the next token should be?",
        "Tokenwise loss without i-th layer": "Layer ablation: how much does performance drop if a layer is removed?"
    }
    return descriptions.get(task_name, "ℹ️ No description available.")

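# Gradio UI: model / mode / normalization selectors, a text input, the returned
# heatmap, and a collapsible legend explaining how to read it.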
with gr.Blocks() as demo:
    gr.Markdown("# 🔬 LLM-Microscope — Understanding Token Representations in Transformers")
    gr.Markdown("Select a model, a mode of analysis, and a sentence. The tool will visualize what’s happening **inside** the language model — layer by layer, token by token.")

    with gr.Row():
        model_selector = gr.Dropdown(
            choices=[
                "facebook/opt-1.3b",
                "TheBloke/Llama-2-7B-fp16"
            ],
            value="facebook/opt-1.3b",
            label="Select Model"
        )
        task_selector = gr.Dropdown(
            choices=[
                "Layer wise non-linearity", 
                "Next-token prediction from intermediate representations", 
                "Contextualization measurement",
                "Layerwise predictions (logit lens)",
                "Tokenwise loss without i-th layer"
            ], 
            value="Layer wise non-linearity",
            label="Select Mode"
        )
        normalization_selector = gr.Dropdown(
            choices=["global", "token-wise"],
            value="token-wise",
            label="Select Normalization"
        )

    task_description = gr.Markdown("ℹ️ Choose a mode to see what it does.")
    
    with gr.Column():
        text_message = gr.Textbox(label="Enter your input text:", value="I love to live my life")
        submit = gr.Button("Submit")
        box_for_plot = gr.Image(label="Visualization", type="pil")

        with gr.Accordion("📘 Full Legend and Interpretation", open=False):
            gr.Markdown("""
This heatmap shows **how each token is processed** across layers of a language model. Here's how to read it:

- **Rows**: layers of the model (bottom = deeper)
- **Columns**: input tokens
- **Colors**: intensity of effect (depends on the selected metric)

---

### Metrics explained:

- `Layer wise non-linearity`: how nonlinear the transformation is at each layer (red = more nonlinear).
- `Next-token prediction from intermediate representations`: shows which layers begin to make good predictions.
- `Contextualization measurement`: tokens with more context info get lower scores (green = more context).
- `Layerwise predictions (logit lens)`: tracks how the model’s guesses evolve at each layer.
- `Tokenwise loss without i-th layer`: shows how much each token depends on a specific layer. Red means performance drops if we skip this layer.

Use this tool to **peek inside the black box** — it reveals which layers matter most, which tokens carry the most memory, and how LLMs evolve their predictions.
""")

    # Keep the description and normalization options in sync with the chosen
    # task, and render the visualization on submit.
    task_selector.change(fn=update_description, inputs=[task_selector], outputs=[task_description])
    task_selector.select(fn=set_default, inputs=[task_selector], outputs=[normalization_selector])
    normalization_selector.select(
        fn=check_normalization,
        inputs=[task_selector, normalization_selector],
        outputs=[normalization_selector],
    )
    submit.click(
        fn=update_output,
        inputs=[task_selector, model_selector, text_message, normalization_selector],
        outputs=[box_for_plot],
    )

if __name__ == "__main__":
    demo.launch(share=True, server_port=7860, server_name="0.0.0.0")