File size: 1,300 Bytes
c3d8da6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import gradio as gr
import requests
from PIL import Image
import io
from typing import Any
import os

class Client:
    """HTTP client for the backend that renders the layer-wise non-linearity plot.

    Each request POSTs ``{"model_name": ..., "text": ...}`` as JSON to
    ``server_url`` and decodes the response body as an image.
    """

    def __init__(self, server_url: str, timeout: float = 600.0):
        """Store the endpoint URL and the per-request timeout.

        Args:
            server_url: Full URL of the prediction endpoint.
            timeout: Seconds passed to ``requests.post`` as ``timeout`` so a
                stalled backend cannot block the request forever. Generous by
                default because the backend may be running large models.
        """
        self.server_url = server_url
        self.timeout = timeout

    def send_request(self, model_name: str, text: str) -> Any:
        """Ask the backend to plot ``text`` under ``model_name``.

        Args:
            model_name: Identifier of the model the backend should use.
            text: Input text forwarded to the backend.

        Returns:
            A decoded ``PIL.Image.Image`` on success. Otherwise the string
            ``"Error, please retry"``, returned (not raised) when:
            the request fails or times out, the server answers with a
            non-200 status, or the body cannot be decoded as an image.
        """
        try:
            response = requests.post(
                self.server_url,
                json={"model_name": model_name, "text": text},
                timeout=self.timeout,
            )
        except requests.RequestException:
            # Connection failures / timeouts are reported exactly like HTTP errors.
            return "Error, please retry"

        if response.status_code != 200:
            return "Error, please retry"

        try:
            img = Image.open(io.BytesIO(response.content))
            # Image.open only reads the header; force full decoding now so a
            # corrupt or non-image body is caught here instead of in the UI.
            img.load()
        except OSError:
            # PIL.UnidentifiedImageError and truncated-image errors derive from OSError.
            return "Error, please retry"
        return img

# Backend address (host or host:port) comes from the SERVER environment
# variable; fail fast with an actionable message instead of a bare KeyError
# or a malformed "http:///predict" URL when it is missing or empty.
_server = os.environ.get("SERVER")
if not _server:
    raise KeyError("The SERVER environment variable must be set to the backend host[:port]")
client = Client(f"http://{_server}/predict")

def get_layerwise_nonlinearity(model_name: str, text: str) -> Any:
    """Gradio callback: fetch the layer-wise non-linearity plot from the backend.

    Args:
        model_name: Model chosen in the dropdown; empty/None when the user
            submits without choosing one.
        text: Text entered by the user.

    Returns:
        The image returned by ``client.send_request``.

    Raises:
        gr.Error: If no model is selected, or if the backend call returns an
            error string. Raising ``gr.Error`` makes Gradio show the message
            to the user. Returning the string to a ``gr.Image`` output would
            not display it, because ``gr.Image`` treats ``str`` values as
            image paths/URLs.
    """
    if not model_name:
        raise gr.Error("Please select a model")
    result = client.send_request(model_name, text)
    if isinstance(result, str):
        raise gr.Error(result)
    return result

with gr.Blocks() as demo:
    with gr.Column():
        # Inputs: which model to probe and the text to send through it.
        model_selector = gr.Dropdown(
            label="Select Model",
            choices=[
                "mistralai/Mistral-7B-v0.1",
                "facebook/opt-125m",
            ],
        )
        text_message = gr.Textbox(label="Enter your request:")
        submit = gr.Button("Submit")
        # Output: the plot returned by the backend, delivered as a PIL image.
        box_for_plot = gr.Image(
            type="pil",
            label="Layer wise non-linearity (with first layer)",
        )

        # Clicking the button sends (model, text) to the backend and fills the image box.
        submit.click(
            fn=get_layerwise_nonlinearity,
            inputs=[model_selector, text_message],
            outputs=[box_for_plot],
        )

if __name__ == "__main__":
    # Bind on all interfaces at port 7860.
    # NOTE(review): share=True also requests a public Gradio share link;
    # confirm that exposing this app publicly is intended.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)