File size: 4,395 Bytes
c3d8da6
 
 
 
d00e7d9
c3d8da6
 
07b8186
c3d8da6
 
 
 
d00e7d9
07b8186
 
 
 
 
 
 
 
 
 
c3d8da6
d00e7d9
 
 
c3d8da6
d00e7d9
c3d8da6
d00e7d9
c3d8da6
 
 
d00e7d9
 
c3d8da6
 
d00e7d9
 
 
 
 
07b8186
d00e7d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3d8da6
d00e7d9
c3d8da6
d00e7d9
 
c3d8da6
d00e7d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07b8186
 
 
 
 
d00e7d9
07b8186
d00e7d9
 
 
 
 
c3d8da6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import requests
from PIL import Image
import io
from typing import Any, Tuple
import os


class Client:
    """Minimal HTTP client for the visualization backend.

    Each request is a JSON POST to ``server_url``; the JSON reply is expected
    to carry a hex-encoded image under ``"image"`` and log text under ``"log"``.
    """

    def __init__(self, server_url: str):
        # Full endpoint URL that requests are POSTed to.
        self.server_url = server_url

    def send_request(self, task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
        """POST one visualization request and decode the reply.

        Args:
            task_name: Task label, forwarded verbatim as ``"task_name"``.
            model_name: Model identifier, forwarded as ``"model_name"``.
            text: Input text, forwarded as ``"text"``.
            normalization_type: Normalization label, forwarded as
                ``"normalization_type"``.

        Returns:
            ``(image, log)`` on success, where ``image`` is the PIL image
            decoded from the hex ``"image"`` field and ``log`` is the
            ``"log"`` field. On any failure (network error or timeout,
            non-200 status, malformed reply) returns ``(None, message)`` so
            the caller can show the message and clear the plot instead of
            crashing.
        """
        try:
            response = requests.post(
                self.server_url,
                json={
                    "task_name": task_name,
                    "model_name": model_name,
                    "text": text,
                    "normalization_type": normalization_type,
                },
                timeout=60,
            )
        except requests.RequestException as err:
            # Connection failures and the 60 s timeout would otherwise escape
            # as an unhandled exception from the UI callback.
            return None, f"Error: Could not get response from server ({err})"

        if response.status_code != 200:
            return None, f"Error: Could not get response from server (HTTP {response.status_code})"

        try:
            response_data = response.json()
            img_data = bytes.fromhex(response_data["image"])
            log_info = response_data["log"]
            img = Image.open(io.BytesIO(img_data))
            # Image.open is lazy; force decoding here so corrupt data is
            # reported now rather than when the image is rendered later.
            img.load()
        except (ValueError, KeyError, TypeError, OSError) as err:
            # ValueError: invalid JSON (requests' JSON decode error subclasses
            # it) or invalid hex; KeyError/TypeError: unexpected reply shape;
            # OSError: Pillow could not identify/decode the image bytes.
            return None, f"Error: Malformed response from server ({err!r})"
        return img, log_info

# One shared client for the backend's /predict endpoint. The backend address
# (presumably host[:port]) is read from the SERVER environment variable; a
# missing variable raises KeyError at import time.
client = Client("http://{}/predict".format(os.environ["SERVER"]))

def get_layerwise_nonlinearity(task_name: str, model_name: str, text: str, normalization_type: str) -> Tuple[Any, str]:
    """Forward a visualization request to the module-level ``client``.

    Despite its name, this is used for every task in the UI: ``task_name``
    selects what the backend computes. Returns whatever
    ``client.send_request`` returns, i.e. an ``(image, log)`` pair.
    """
    result = client.send_request(
        task_name=task_name,
        model_name=model_name,
        text=text,
        normalization_type=normalization_type,
    )
    return result

# Default normalization selected automatically when the user switches task.
# The keys double as the task dropdown's choices (dict order is preserved)
# and are sent verbatim to the server as ``task_name``, so the spelling
# "mesurment" is kept: presumably the backend matches on the exact string —
# verify against the server before correcting it.
_DEFAULT_NORMALIZATION = {
    "Layer wise non-linearity (with first layer)": "token-wise",
    "Next-token prediction from intermediate representations": "token-wise",
    "Contextualization mesurment": "global",
    "Layerwise predictions and losses": "global",
    "Tokenwise loss without i-th layer": "token-wise",
}

# The contextualization task works on one sentence, so token-wise
# normalization is rejected for it (see check_normalization below).
_CONTEXTUALIZATION_TASK = "Contextualization mesurment"
_TOKEN_WISE_ALERT = "\nALERT: Cannot apply token-wise normalization to one sentence, setting global normalization\n"

with gr.Blocks() as demo:
    with gr.Row():
        model_selector = gr.Dropdown(
            choices=[
                "facebook/opt-1.3b",
                "facebook/opt-2.7b",
                # "microsoft/Phi-3-mini-128k-instruct"
            ],
            value="facebook/opt-1.3b",
            label="Select Model"
        )
        task_selector = gr.Dropdown(
            choices=list(_DEFAULT_NORMALIZATION),
            value="Layer wise non-linearity (with first layer)",
            label="Select Mode"
        )
        normalization_selector = gr.Dropdown(
            choices=["global", "token-wise"],  # , "sentence-wise"],
            value="token-wise",
            label="Select Normalization"
        )
    with gr.Column():
        text_message = gr.Textbox(label="Enter your request:", value="I love to live my life")
        submit = gr.Button("Submit")
        box_for_plot = gr.Image(label="Visualization", type="pil")
        log_output = gr.Textbox(label="Log Output", lines=10, interactive=False, value="")

        def update_output(task_name: str, model_name: str, text: str, normalization_type: str, existing_log: str) -> Tuple[Any, str]:
            """Run the selected task and append its log to the log box.

            Returns:
                ``(image, combined_log)`` where ``combined_log`` is the
                previous log text followed by a ``---`` separator and the new
                log entry.
            """
            img, new_log = get_layerwise_nonlinearity(task_name, model_name, text, normalization_type)
            if isinstance(img, str):
                # An error sentinel string is not an image; clear the plot
                # and rely on the log entry to explain the failure.
                img = None
            combined_log = (existing_log or "") + "---\n" + new_log + "\n"
            return img, combined_log

        def set_default(task_name: str) -> str:
            """Return the default normalization for ``task_name``.

            Returns None for labels missing from the table (presumably
            unreachable through the fixed task dropdown).
            """
            return _DEFAULT_NORMALIZATION.get(task_name)

        def check_normalization(task_name: str, normalization_name: str, existing_log: str = "") -> Tuple[str, str]:
            """Validate a normalization choice against the current task.

            Args:
                task_name: Currently selected task label.
                normalization_name: Normalization the user just picked.
                existing_log: Current log text; preserved in the output.

            Returns:
                ``(normalization, log)``. For the contextualization task with
                token-wise normalization, the choice is forced to
                ``"global"`` and an alert is appended to the log; otherwise
                both values are returned unchanged.
            """
            existing_log = existing_log or ""
            if task_name == _CONTEXTUALIZATION_TASK and normalization_name == "token-wise":
                return ("global", existing_log + _TOKEN_WISE_ALERT)
            return (normalization_name, existing_log)

        # Switching task resets normalization to that task's default.
        task_selector.select(set_default, [task_selector], [normalization_selector])
        # The current log is passed in so an alert is appended instead of
        # overwriting (or, with no alert, clearing) the log history.
        normalization_selector.select(
            check_normalization,
            [task_selector, normalization_selector, log_output],
            [normalization_selector, log_output],
        )
        submit.click(
            fn=update_output,
            inputs=[task_selector, model_selector, text_message, normalization_selector, log_output],
            outputs=[box_for_plot, log_output]
        )

if __name__ == "__main__":
    # Listen on all interfaces at port 7860; share=True additionally asks
    # Gradio for a public share link.
    # NOTE(review): together these expose the app beyond localhost — confirm
    # that is intended.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )