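"""Gradio front-end for LLM-Microscope.

Collects a model name, an analysis mode, a normalization type, and an input
text, POSTs them as JSON to an inference backend (see Client.send_request),
and renders the hex-encoded PNG heatmap returned in the response.
"""
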
import gradio as gr
import requests
from PIL import Image
import io
from typing import Any, Tuple
import os


class Client:
    """Thin HTTP client for the LLM-Microscope inference server."""

    def __init__(self, server_url: str):
        self.server_url = server_url

    def send_request(self, task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
        response = requests.post(
            self.server_url,
            json={
                "task_name": task_name,
                "model_name": model_name,
                "text": text,
                "normalization_type": normalization_type,
            },
            timeout=60,
        )
        if response.status_code == 200:
            response_data = response.json()
            # The server returns the rendered plot as a hex-encoded PNG.
            img_data = bytes.fromhex(response_data["image"])
            img = Image.open(io.BytesIO(img_data))
            return img, "OK"
        # Return None for the image slot so gr.Image clears instead of
        # trying to interpret an error string as a file path.
        return None, "Error: Could not get response from server"
client = Client(f"http://{os.environ['SERVER']}/predict")


def get_layerwise_nonlinearity(task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
    return client.send_request(task_name, model_name, text, normalization_type)


def update_output(task_name: str, model_name: str, text: str, normalization_type: str) -> Any:
    # Only the image is shown in the UI; the status string is discarded.
    img, _ = get_layerwise_nonlinearity(task_name, model_name, text, normalization_type)
    return img


def set_default(task_name: str) -> str:
    # Token-wise normalization is the default for the per-token metrics.
    if task_name in ["Layer wise non-linearity", "Next-token prediction from intermediate representations", "Tokenwise loss without i-th layer"]:
        return "token-wise"
    return "global"


def check_normalization(task_name: str, normalization_name: str) -> str:
    # Contextualization measurement only supports global normalization.
    if task_name == "Contextualization measurement" and normalization_name == "token-wise":
        return "global"
    return normalization_name


def update_description(task_name: str) -> str:
    descriptions = {
        "Layer wise non-linearity": "Non-linearity per layer: shows how complex each layer's transformation is. Red = more nonlinear.",
        "Next-token prediction from intermediate representations": "Layerwise token prediction: when does the model start guessing correctly?",
        "Contextualization measurement": "Context stored in each token: how well can the model reconstruct the previous context?",
        "Layerwise predictions (logit lens)": "Logit lens: what does each layer believe the next token should be?",
        "Tokenwise loss without i-th layer": "Layer ablation: how much does performance drop if a layer is removed?",
    }
    return descriptions.get(task_name, "ℹ️ No description available.")
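

# UI: selectors for model, analysis mode, and normalization on top; text
# input, submit button, and the resulting heatmap below.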
with gr.Blocks() as demo:
gr.Markdown("# πŸ”¬ LLM-Microscope β€” A Look Inside the Black Box")
gr.Markdown("Select a model, analysis mode, and input β€” then peek inside the black box of an LLM to see which layers matter most, which tokens carry the most memory, and how predictions evolve.")
    with gr.Row():
        model_selector = gr.Dropdown(
            choices=[
                "facebook/opt-1.3b",
                "TheBloke/Llama-2-7B-fp16",
                "Qwen/Qwen3-8B",
            ],
            value="facebook/opt-1.3b",
            label="Select Model",
        )
        task_selector = gr.Dropdown(
            choices=[
                "Layer wise non-linearity",
                "Next-token prediction from intermediate representations",
                "Contextualization measurement",
                "Layerwise predictions (logit lens)",
                "Tokenwise loss without i-th layer",
            ],
            value="Layer wise non-linearity",
            label="Select Mode",
        )
        normalization_selector = gr.Dropdown(
            choices=["global", "token-wise"],
            value="token-wise",
            label="Select Normalization",
        )

    task_description = gr.Markdown("ℹ️ Choose a mode to see what it does.")

    with gr.Column():
        text_message = gr.Textbox(label="Enter your input text:", value="I love to live my life")
        submit = gr.Button("Submit")
        box_for_plot = gr.Image(label="Visualization", type="pil")
    with gr.Accordion("📘 More Info and Explanation", open=False):
        gr.Markdown("""
This heatmap shows **how each token is processed** across the layers of a language model. Here's how to read it:

- **Rows**: layers of the model (bottom = deeper)
- **Columns**: input tokens
- **Colors**: intensity of effect (depends on the selected metric)

---

### Metrics explained:

- `Layer wise non-linearity`: how nonlinear the transformation is at each layer (red = more nonlinear).
- `Next-token prediction from intermediate representations`: shows which layers begin to make good predictions.
- `Contextualization measurement`: tokens with more context info get lower scores (green = more context).
- `Layerwise predictions (logit lens)`: tracks how the model's guesses evolve at each layer.
- `Tokenwise loss without i-th layer`: shows how much each token depends on a specific layer. Red means performance drops if we skip this layer.

Use this tool to **peek inside the black box** — it reveals which layers matter most, which tokens carry the most memory, and how LLMs evolve their predictions.

---

You can also use `llm-microscope` as a Python library to run these analyses on **your own models and data**.
Just install it with: `pip install llm-microscope`

More details are provided in the [GitHub repo](https://github.com/AIRI-Institute/LLM-Microscope).
""")
    task_selector.change(fn=update_description, inputs=[task_selector], outputs=[task_description])
    task_selector.select(set_default, [task_selector], [normalization_selector])
    normalization_selector.select(check_normalization, [task_selector, normalization_selector], [normalization_selector])

    submit.click(
        fn=update_output,
        inputs=[task_selector, model_selector, text_message, normalization_selector],
        outputs=[box_for_plot],
    )


if __name__ == "__main__":
    demo.launch(share=True, server_port=7860, server_name="0.0.0.0")