# LLM-Microscope / app.py
# Source: Hugging Face Spaces file view (raw / history / blame), 4.4 kB.
# Last commit: "Update app.py" (07b8186, verified); author avatar: matveymih.
import gradio as gr
import requests
from PIL import Image
import io
from typing import Any, Tuple
import os
class Client:
    """Minimal HTTP client for the remote prediction endpoint."""

    def __init__(self, server_url: str):
        """Store the URL that every request is POSTed to."""
        self.server_url = server_url

    def send_request(self, task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
        """POST one analysis request and decode the server's reply.

        Args:
            task_name: Name of the analysis to run, sent verbatim to the server.
            model_name: Model identifier, sent verbatim to the server.
            text: Input text to analyse.
            normalization_type: Normalization mode, sent verbatim to the server.

        Returns:
            A ``(image, log)`` pair. On success ``image`` is the PIL image
            decoded from the hex-encoded ``"image"`` field of the JSON reply
            and ``log`` is its ``"log"`` field. On failure (transport error or
            a non-200 status) ``image`` is ``None`` and ``log`` is an error
            message.
        """
        payload = {
            "task_name": task_name,
            "model_name": model_name,
            "text": text,
            "normalization_type": normalization_type,
        }
        try:
            response = requests.post(self.server_url, json=payload, timeout=60)
        except requests.RequestException as exc:
            # Report only the exception type: the full message usually embeds
            # the server URL, which should not leak into the user-visible log.
            return None, f"Error: Could not get response from server ({type(exc).__name__})"

        if response.status_code != 200:
            # The image slot is displayed by an image component, which would
            # interpret a plain string as a file path; None is the value that
            # safely means "no image".
            return None, "Error: Could not get response from server"

        response_data = response.json()
        img_data = bytes.fromhex(response_data["image"])
        log_info = response_data["log"]
        img = Image.open(io.BytesIO(img_data))
        return img, log_info
# Module-wide client for the prediction endpoint. The host is read from the
# SERVER environment variable, so a missing variable raises KeyError on import.
client = Client(server_url=f"http://{os.environ['SERVER']}/predict")
def get_layerwise_nonlinearity(task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
    """Forward one analysis request to the module-level ``client``.

    Returns whatever ``client.send_request`` returns for the same arguments:
    an ``(image, log)`` pair.
    """
    return client.send_request(
        task_name=task_name,
        model_name=model_name,
        text=text,
        normalization_type=normalization_type,
    )
with gr.Blocks() as demo:
    # --- Controls: model, task and normalization pickers -------------------
    with gr.Row():
        model_selector = gr.Dropdown(
            choices=[
                "facebook/opt-1.3b",
                "facebook/opt-2.7b",
                # "microsoft/Phi-3-mini-128k-instruct"  (currently disabled)
            ],
            value="facebook/opt-1.3b",
            label="Select Model"
        )
        # Task names are sent verbatim to the server as "task_name" and are
        # compared by exact string below, so their spelling must stay as is.
        task_selector = gr.Dropdown(
            choices=[
                "Layer wise non-linearity (with first layer)",
                "Next-token prediction from intermediate representations",
                "Contextualization mesurment",
                "Layerwise predictions and losses",
                "Tokenwise loss without i-th layer"
            ],
            value="Layer wise non-linearity (with first layer)",
            label="Select Mode"
        )
        normalization_selector = gr.Dropdown(
            choices=["global", "token-wise"],  # "sentence-wise" is currently disabled
            value="token-wise",
            label="Select Normalization"
        )

    # --- Input text, submit button and outputs -----------------------------
    with gr.Column():
        text_message = gr.Textbox(label="Enter your request:", value="I love to live my life")
        submit = gr.Button("Submit")
        box_for_plot = gr.Image(label="Visualization", type="pil")
        log_output = gr.Textbox(label="Log Output", lines=10, interactive=False, value="")

    def update_output(task_name: str, model_name: str, text: str, normalization_type: str, existing_log: str) -> Tuple[Any, str]:
        """Run one request and append its log to the log shown so far.

        Returns the image for the plot (or ``None`` when no image was
        produced) and the accumulated log text.
        """
        img, new_log = get_layerwise_nonlinearity(task_name, model_name, text, normalization_type)
        if isinstance(img, str):
            # A failed request may put an error string in the image slot; the
            # image output would treat it as a file path, so clear the plot
            # instead and let the log carry the error.
            img = None
        combined_log = (existing_log or "") + "---\n" + new_log + "\n"
        return img, combined_log

    def set_default(task_name: str) -> str:
        """Return the normalization preselected for ``task_name``.

        Unknown task names fall back to "token-wise", the dropdown's initial
        value, so the function always returns a valid choice.
        """
        defaults = {
            "Layer wise non-linearity (with first layer)": "token-wise",
            "Next-token prediction from intermediate representations": "token-wise",
            "Contextualization mesurment": "global",
            "Layerwise predictions and losses": "global",
            "Tokenwise loss without i-th layer": "token-wise",
        }
        return defaults.get(task_name, "token-wise")

    def check_normalization(task_name: str, normalization_name: str, existing_log: str = "") -> Tuple[str, str]:
        """Reject the one unsupported task/normalization combination.

        Returns the normalization to apply and the log text to display. When
        token-wise normalization is chosen for the contextualization task the
        choice is reset to "global" and an alert is appended to the log;
        otherwise both values are passed through unchanged.
        """
        current_log = existing_log or ""
        if task_name == "Contextualization mesurment" and normalization_name == "token-wise":
            alert = "\nALERT: Cannot apply token-wise normalization to one sentence, setting global normalization\n"
            return ("global", current_log + alert)
        # Hand the existing log back untouched so that merely changing the
        # dropdown never erases earlier output.
        return (normalization_name, current_log)

    # --- Event wiring (must stay inside the Blocks context) -----------------
    task_selector.select(set_default, [task_selector], [normalization_selector])
    normalization_selector.select(
        check_normalization,
        [task_selector, normalization_selector, log_output],
        [normalization_selector, log_output],
    )
    submit.click(
        fn=update_output,
        inputs=[task_selector, model_selector, text_message, normalization_selector, log_output],
        outputs=[box_for_plot, log_output]
    )
if __name__ == "__main__":
    # Serve on every interface at port 7860 and request a public share link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )