matveymih committed on
Commit 86ce226 · verified · 1 Parent(s): b6e9c51

Create app.py

Files changed (1): app.py (+117 -0)
app.py ADDED
import gradio as gr
import requests
from PIL import Image
import io
from typing import Any, Tuple
import os


class Client:
    def __init__(self, server_url: str):
        self.server_url = server_url

    def send_request(self, task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
        # POST the request parameters to the analysis server; the response
        # carries the rendered visualization as a hex-encoded image.
        response = requests.post(
            self.server_url,
            json={
                "task_name": task_name,
                "model_name": model_name,
                "text": text,
                "normalization_type": normalization_type
            },
            timeout=60
        )
        if response.status_code == 200:
            response_data = response.json()
            img_data = bytes.fromhex(response_data["image"])
            img = Image.open(io.BytesIO(img_data))
            return img, "OK"
        # Return None rather than an error string: the gr.Image output expects
        # a PIL image (or None), and a bare string would be read as a filepath.
        return None, "Error: Could not get response from server"
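The client assumes a simple contract from the server: a JSON POST carrying the four request fields, answered with JSON whose "image" field holds a hex-encoded image (decoded above with bytes.fromhex). The server itself is not part of this commit; for exercising the UI locally, a minimal mock of the endpoint could look like the sketch below. Flask, the port, and the blank placeholder image are illustrative assumptions, not part of the app.

# Hypothetical mock of the /predict endpoint this client expects.
# Flask and the blank placeholder image are assumptions; the real
# server renders an actual visualization.
import io

from flask import Flask, jsonify, request
from PIL import Image

app = Flask(__name__)


@app.route("/predict", methods=["POST"])
def predict():
    # The real backend would run the selected analysis; here we just
    # acknowledge whatever parameters arrive.
    params = request.get_json()
    print("received:", params)
    img = Image.new("RGB", (640, 480), "white")  # placeholder heatmap
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    # The client decodes this field with bytes.fromhex(), so send hex.
    return jsonify({"image": buf.getvalue().hex()})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)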
client = Client(f"http://{os.environ['SERVER']}/predict")  # backend host comes from the SERVER env var


def get_layerwise_nonlinearity(task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
    # Despite the name, this forwards every analysis mode to the server.
    return client.send_request(task_name, model_name, text, normalization_type)

with gr.Blocks() as demo:
    gr.Markdown("# 🔬 LLM-Microscope — Understanding Token Representations in Transformers")
    gr.Markdown("Select a model, a mode of analysis, and a sentence. The tool will visualize what’s happening **inside** the language model — layer by layer, token by token.")

    with gr.Row():
        model_selector = gr.Dropdown(
            choices=[
                "facebook/opt-1.3b",
                "TheBloke/Llama-2-7B-fp16"
            ],
            value="facebook/opt-1.3b",
            label="Select Model"
        )
        task_selector = gr.Dropdown(
            choices=[
                "Layer wise non-linearity",
                "Next-token prediction from intermediate representations",
                "Contextualization measurement",
                "Layerwise predictions (logit lens)",
                "Tokenwise loss without i-th layer"
            ],
            value="Layer wise non-linearity",
            label="Select Mode"
        )
        normalization_selector = gr.Dropdown(
            choices=["global", "token-wise"],
            value="token-wise",
            label="Select Normalization"
        )

    with gr.Column():
        text_message = gr.Textbox(label="Enter your input text:", value="I love to live my life")
        submit = gr.Button("Submit")
        box_for_plot = gr.Image(label="Visualization", type="pil")

    # 💬 Explanation below the visualization
    explanation_text = gr.Markdown("""
### 📘 Legend and Interpretation

This heatmap shows **how each token is processed** across the layers of a language model. Here's how to read it:

- **Rows**: layers of the model (bottom = deeper)
- **Columns**: input tokens
- **Colors**: intensity of the effect (depends on the selected metric)

**Metrics explained:**

- `Layer wise non-linearity`: how nonlinear the transformation applied at each layer is (red = more nonlinear).
- `Next-token prediction from intermediate representations`: shows which layers already make good predictions.
- `Contextualization measurement`: tokens that absorb more contextual information get lower scores (green = more context).
- `Layerwise predictions (logit lens)`: tracks how the model’s guesses evolve from layer to layer.
- `Tokenwise loss without i-th layer`: shows how much each token depends on a specific layer (red = performance drops if the layer is skipped).

Use this tool to **peek inside the black box** — it reveals which layers matter most, which tokens carry the most memory, and how LLMs refine their predictions.
""")

    def update_output(task_name: str, model_name: str, text: str, normalization_type: str) -> Any:
        # Drop the status string; the UI only displays the image.
        img, _ = get_layerwise_nonlinearity(task_name, model_name, text, normalization_type)
        return img

    def set_default(task_name: str) -> str:
        # Pick the normalization that makes sense for the selected mode.
        if task_name in ["Layer wise non-linearity", "Next-token prediction from intermediate representations", "Tokenwise loss without i-th layer"]:
            return "token-wise"
        return "global"

    def check_normalization(task_name: str, normalization_name: str) -> str:
        # Contextualization measurement only supports global normalization.
        if task_name == "Contextualization measurement" and normalization_name == "token-wise":
            return "global"
        return normalization_name

    task_selector.select(set_default, [task_selector], [normalization_selector])
    normalization_selector.select(check_normalization, [task_selector, normalization_selector], [normalization_selector])
    submit.click(
        fn=update_output,
        inputs=[task_selector, model_selector, text_message, normalization_selector],
        outputs=[box_for_plot]
    )

if __name__ == "__main__":
    demo.launch(share=True, server_port=7860, server_name="0.0.0.0")
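For a quick smoke test without clicking through the UI, the same client can be driven from a Python shell. The host value below is an assumption for illustration and matches the hypothetical mock server sketched earlier; point SERVER at the real backend in practice.

# Hypothetical smoke test; assumes a /predict server is listening on
# localhost:8000 and that the file above is saved as app.py.
import os

os.environ.setdefault("SERVER", "localhost:8000")  # must be set before the import

from app import client

img, status = client.send_request(
    "Layer wise non-linearity",
    "facebook/opt-1.3b",
    "I love to live my life",
    "token-wise",
)
print(status, None if img is None else img.size)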