import os
from typing import Iterator

import gradio as gr

from dialog import get_dialog_box
from gateway import check_server_health, request_generation

# CONSTANTS
MAX_NEW_TOKENS: int = 4096

# GET ENVIRONMENT VARIABLES
CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")
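# NOTE: API_ENDPOINT is assumed to be supplied as a Space secret/environment variable pointing
# at the inference gateway used by check_server_health and request_generation; if it is unset,
# CLOUD_GATEWAY_API is None and the health check below cannot succeed.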


def toggle_ui():
    """
    Toggle the visibility of the main UI and the error dialog based on server health.

    Returns:
        A pair of gr.update objects: one for the main UI, one for the dialog box.
    """
    health = check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API)
    if health:
        return gr.update(visible=True), gr.update(
            visible=False
        )  # Show main UI, hide dialog
    else:
        return gr.update(visible=False), gr.update(
            visible=True
        )  # Hide main UI, show dialog
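
# toggle_ui is polled by the gr.Timer defined at the bottom of this file; its two return
# values update the components listed in the timer's outputs, [main_ui, dialog_box], in order.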


def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
"""Send a request to backend, fetch the streaming responses and emit to the UI. | |
Args: | |
message (str): input message from the user | |
chat_history (list[tuple[str, str]]): entire chat history of the session | |
system_prompt (str): system prompt | |
max_new_tokens (int, optional): maximum number of tokens to generate, ignoring the number of tokens in the | |
prompt. Defaults to 1024. | |
temperature (float, optional): the value used to module the next token probabilities. Defaults to 0.6. | |
top_p (float, optional): if set to float<1, only the smallest set of most probable tokens with probabilities | |
that add up to top_p or higher are kept for generation. Defaults to 0.9. | |
top_k (int, optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. | |
Defaults to 50. | |
repetition_penalty (float, optional): the parameter for repetition penalty. 1.0 means no penalty. | |
Defaults to 1.2. | |
Yields: | |
Iterator[str]: Streaming responses to the UI | |
""" | |
    # Stream partial outputs from the backend and emit the accumulated text so the UI
    # updates the response in place.
    outputs = []
    for text in request_generation(
        message=message,
        system_prompt=system_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
    ):
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=64,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
    stop_btn=None,
    examples=[
        [
            "I need to be in Japan for 10 days, going to Tokyo, Kyoto and Osaka for Cherry Blossom. Think about number of attractions in each of them and allocate number of days to each city. Make public transport recommendations."
        ],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'."],
    ],
    cache_examples=False,
)
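
# Note: gr.ChatInterface passes the additional_inputs values to `generate` positionally, after
# the message and chat history, so the component order above must match the parameter order
# in the function signature.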


with gr.Blocks(css="style.css", fill_height=True) as demo:
    # Get the server status before displaying the UI
    visibility = check_server_health(CLOUD_GATEWAY_API)

    # Container for the main interface
    with gr.Column(visible=visibility, elem_id="main_ui") as main_ui:
        gr.Markdown(
            f"""
            # Gemma-3 27B Chat
            This Space is an Alpha release that demonstrates the [Gemma-3-27B-It](https://huggingface.co/google/gemma-3-27b-it) model running on AMD MI210 infrastructure. This Space is provided under the Google Gemma [License](https://ai.google.dev/gemma/terms). Feel free to play with it!
            """
        )
        chat_interface.render()

    # Dialog box using Markdown for the error message
    with gr.Row(visible=(not visibility), elem_id="dialog_box") as dialog_box:
        # Add spinner and message
        get_dialog_box()

    # Timer to check server health every 10 seconds and update the UI
    timer = gr.Timer(value=10)
    timer.tick(fn=toggle_ui, outputs=[main_ui, dialog_box])


if __name__ == "__main__":
    # QUEUE and CONCURRENCY_LIMIT must be set in the environment; int(None) would raise a
    # TypeError otherwise.
    demo.queue(
        max_size=int(os.getenv("QUEUE")),
        default_concurrency_limit=int(os.getenv("CONCURRENCY_LIMIT")),
    ).launch()
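
# Example local launch (values are illustrative; adjust to your gateway and hardware):
#   API_ENDPOINT=<gateway-url> QUEUE=20 CONCURRENCY_LIMIT=4 python app.py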