Lohia, Aditya committed on
Commit 7faf2cf · 1 Parent(s): bf9d9a5
Files changed (1)
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
import os
import gradio as gr
from typing import Iterator

from dialog import get_dialog_box
from gateway import check_server_health, request_generation

# CONSTANTS
MAX_NEW_TOKENS: int = 2048

# GET ENVIRONMENT VARIABLES
CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")


def toggle_ui():
    """
    Toggle the visibility of the main UI and the dialog box based on server health.

    Returns:
        Two gr.update objects: one for the main UI container, one for the dialog box.
    """
    health = check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API)
    if health:
        # Show main UI, hide dialog
        return gr.update(visible=True), gr.update(visible=False)
    else:
        # Hide main UI, show dialog
        return gr.update(visible=False), gr.update(visible=True)


def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Send a request to the backend, fetch the streaming responses, and emit them to the UI.

    Args:
        message (str): input message from the user
        chat_history (list[tuple[str, str]]): entire chat history of the session
        system_prompt (str): system prompt
        max_new_tokens (int, optional): maximum number of tokens to generate, ignoring the number of tokens in the
            prompt. Defaults to 1024.
        temperature (float, optional): the value used to modulate the next-token probabilities. Defaults to 0.6.
        top_p (float, optional): if set to a float < 1, only the smallest set of most probable tokens with
            probabilities that add up to top_p or higher are kept for generation. Defaults to 0.9.
        top_k (int, optional): the number of highest-probability vocabulary tokens to keep for top-k filtering.
            Defaults to 50.
        repetition_penalty (float, optional): the parameter for repetition penalty. 1.0 means no penalty.
            Defaults to 1.2.

    Yields:
        Iterator[str]: the accumulated response text, re-emitted after each streamed chunk
    """
    # Stream chunks from the backend gateway and yield the accumulated text so far.
    outputs = []
    for text in request_generation(
        message=message,
        system_prompt=system_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
    ):
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.1,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'."],
    ],
    cache_examples=False,
)

with gr.Blocks(css="style.css", fill_height=True) as demo:
    # Get the server status before displaying the UI
    visibility = check_server_health(CLOUD_GATEWAY_API)

    # Container for the main interface
    with gr.Column(visible=visibility, elem_id="main_ui") as main_ui:
        gr.Markdown(
            """
            # Gemma-3 27B Chat
            This Space is an alpha release demonstrating the [Gemma-3-27B-It](https://huggingface.co/google/gemma-3-27b-it) model running on AMD MI210 infrastructure. The Space is provided under the Google Gemma 3 [License](https://ai.google.dev/gemma/terms). Feel free to play with it!
            """
        )
        chat_interface.render()

    # Dialog box using Markdown for the error message
    with gr.Row(visible=(not visibility), elem_id="dialog_box") as dialog_box:
        # Add spinner and message
        get_dialog_box()

    # Timer to re-check server health every 10 seconds and update the UI
    timer = gr.Timer(value=10)
    timer.tick(fn=toggle_ui, outputs=[main_ui, dialog_box])


if __name__ == "__main__":
    # QUEUE and CONCURRENCY_LIMIT must be set in the environment; int(None) would raise a TypeError.
    demo.queue(
        max_size=int(os.getenv("QUEUE")),
        default_concurrency_limit=int(os.getenv("CONCURRENCY_LIMIT")),
    ).launch()
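Note (not part of this commit): app.py imports check_server_health and request_generation from a local gateway module that is not included in this change, and get_dialog_box from a local dialog module. The sketch below only illustrates the interface those gateway calls appear to expect, inferred from the keyword arguments used above; the endpoint paths, payload fields, and streaming format are assumptions, not the actual implementation.

# gateway.py (illustrative sketch only; the real module is not part of this commit)
import requests


def check_server_health(cloud_gateway_api: str) -> bool:
    """Return True when the backend answers a health probe (endpoint path is assumed)."""
    try:
        response = requests.get(f"{cloud_gateway_api}/health", timeout=5)
        return response.status_code == 200
    except requests.RequestException:
        return False


def request_generation(
    message: str,
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    cloud_gateway_api: str,
):
    """Yield text chunks streamed from the backend; payload shape and route are assumed."""
    payload = {
        "message": message,
        "system_prompt": system_prompt,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    with requests.post(
        f"{cloud_gateway_api}/generate", json=payload, stream=True, timeout=300
    ) as response:
        response.raise_for_status()
        # Stream the response line by line and hand each non-empty chunk to the caller.
        for line in response.iter_lines(decode_unicode=True):
            if line:
                yield line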
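Operational note: demo.queue() converts the QUEUE and CONCURRENCY_LIMIT environment variables with int(), so startup fails with a TypeError if either is unset; API_ENDPOINT is likewise expected to be configured in the Space settings for the gateway calls. A defensive variant, shown only as a suggestion with hypothetical fallback values rather than what this commit does, would be:

# Hypothetical fallbacks; the committed code expects both variables to be set.
max_size = int(os.getenv("QUEUE", "20"))
concurrency_limit = int(os.getenv("CONCURRENCY_LIMIT", "2"))

demo.queue(max_size=max_size, default_concurrency_limit=concurrency_limit).launch()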