matveymih commited on
Commit
c3d8da6
·
verified ·
1 Parent(s): d2ebc57

Create app.py

Browse files

Initial commit

Files changed (1) hide show
  1. app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ from PIL import Image
4
+ import io
5
+ from typing import Any
6
+ import os
7
+
8
+ class Client:
9
+ def __init__(self, server_url: str):
10
+ self.server_url = server_url
11
+
12
+ def send_request(self, model_name: str, text: str) -> Any:
13
+ response = requests.post(self.server_url, json={"model_name": model_name, "text": text})
14
+ if response.status_code == 200:
15
+ img_data = response.content
16
+ img = Image.open(io.BytesIO(img_data))
17
+ return img
18
+ else:
19
+ return "Error, please retry"
20
+
21
+ client = Client(f"http://{os.environ['SERVER']}/predict")
22
+
23
+ def get_layerwise_nonlinearity(model_name: str, text: str) -> Any:
24
+ return client.send_request(model_name, text)
25
+
26
+ with gr.Blocks() as demo:
27
+ with gr.Column():
28
+ model_selector = gr.Dropdown(choices=["mistralai/Mistral-7B-v0.1", "facebook/opt-125m"], label="Select Model")
29
+ text_message = gr.Textbox(label="Enter your request:")
30
+ submit = gr.Button("Submit")
31
+ box_for_plot = gr.Image(label="Layer wise non-linearity (with first layer)", type="pil")
32
+
33
+ submit.click(get_layerwise_nonlinearity, [model_selector, text_message], [box_for_plot])
34
+
35
+ if __name__ == "__main__":
36
+ demo.launch(share=True, server_port=7860, server_name="0.0.0.0")