Lohia, Aditya committed
Commit f24a24a · 1 Parent(s): e553ed7

update space

Files changed (2)
  1. app.py +41 -29
  2. gateway.py +84 -57
app.py CHANGED
@@ -1,15 +1,34 @@
 import os
+import logging
 import gradio as gr
 from typing import Iterator

 from dialog import get_dialog_box
 from gateway import check_server_health, request_generation

+# Setup logging
+logging.basicConfig(level=logging.INFO)
+
 # CONSTANTS
-MAX_NEW_TOKENS: int = 2048
+# Get max new tokens from environment variable, if it is not set, default to 2048
+MAX_NEW_TOKENS: int = os.getenv("MAX_NEW_TOKENS", 2048)

-# GET ENVIRONMENT VARIABLES
+# Validate environment variables
 CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")
+if not CLOUD_GATEWAY_API:
+    raise EnvironmentError("API_ENDPOINT is not set.")
+
+MODEL_NAME: str = os.getenv("MODEL_NAME")
+if not MODEL_NAME:
+    raise EnvironmentError("MODEL_NAME is not set.")
+
+# Get API Key
+API_KEY = os.getenv("API_KEY")
+if not API_KEY:  # simple check to validate API Key
+    raise Exception("API Key not valid.")
+
+# Create a header, avoid declaring multiple times
+HEADER = {"x-api-key": f"{API_KEY}"}


 def toggle_ui():
@@ -18,7 +37,7 @@ def toggle_ui():
     Returns:
         hide/show main ui/dialog
     """
-    health = check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API)
+    health = check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API, header=HEADER)
     if health:
         return gr.update(visible=True), gr.update(
             visible=False
@@ -35,9 +54,8 @@ def generate(
     system_prompt: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
+    frequency_penalty: float = 0.0,
+    presence_penalty: float = 0.0,
 ) -> Iterator[str]:
     """Send a request to backend, fetch the streaming responses and emit to the UI.

@@ -61,14 +79,15 @@ def generate(
     # sample method to yield responses from the llm model
     outputs = []
     for text in request_generation(
+        header=HEADER,
         message=message,
         system_prompt=system_prompt,
         max_new_tokens=max_new_tokens,
         temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
+        presence_penalty=presence_penalty,
+        frequency_penalty=frequency_penalty,
         cloud_gateway_api=CLOUD_GATEWAY_API,
+        model_name=MODEL_NAME,
     ):
         outputs.append(text)
         yield "".join(outputs)
@@ -94,28 +113,21 @@ chat_interface = gr.ChatInterface(
             minimum=0.1,
             maximum=4.0,
             step=0.1,
-            value=1.0,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.95,
+            value=0.3,
         ),
         gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=64,
+            label="Frequency penalty",
+            minimum=-2.0,
+            maximum=2.0,
+            step=0.1,
+            value=0.0,
         ),
         gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
+            label="Presence penalty",
+            minimum=-2.0,
             maximum=2.0,
-            step=0.05,
-            value=1.0,
+            step=0.1,
+            value=0.0,
         ),
     ],
     stop_btn=None,
@@ -134,14 +146,14 @@ chat_interface = gr.ChatInterface(

 with gr.Blocks(css="style.css", fill_height=True) as demo:
     # Get the server status before displaying UI
-    visibility = check_server_health(CLOUD_GATEWAY_API)
+    visibility = check_server_health(CLOUD_GATEWAY_API, header=HEADER)

     # Container for the main interface
     with gr.Column(visible=visibility, elem_id="main_ui") as main_ui:
         gr.Markdown(
             f"""
-            # Gemma-3 27B Chat
-            This Space is an Alpha release that demonstrates [Gemma-3-27B-It](https://huggingface.co/google/gemma-3-27b-it) model running on AMD MI210 infrastructure. The space is built with Google Gemma 3 [License](https://ai.google.dev/gemma/terms). Feel free to play with it!
+            # Gemma 3 27b Instruct
+            This Space is an Alpha release that demonstrates [Gemma-3-27B-It](https://huggingface.co/google/gemma-3-27b-it) model running on AMD MI300 infrastructure. The space is built with Google Gemma 3 [License](https://ai.google.dev/gemma/terms). Feel free to play with it!
             """
         )
     chat_interface.render()
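
Note on the new constants block: os.getenv returns a string whenever the variable is set in the environment (the int default 2048 is only returned when it is unset), so the int annotation on MAX_NEW_TOKENS is not enforced and any arithmetic or slider bound built from it can fail at runtime. A minimal sketch of an explicit cast, assuming the same variable name (this fix is not part of the commit):

import os

# os.getenv returns str when MAX_NEW_TOKENS is set, so cast explicitly;
# keeping the default as a string lets int() handle both branches uniformly.
MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", "2048"))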
gateway.py CHANGED
@@ -1,41 +1,54 @@
 import json
+import logging
 import requests
+import urllib3

+urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

-def check_server_health(cloud_gateway_api: str):
+# Setup logging
+logging.basicConfig(level=logging.INFO)
+
+
+def check_server_health(cloud_gateway_api: str, header: dict) -> bool:
     """
     Use the appropriate API endpoint to check the server health.
     Args:
         cloud_gateway_api: API endpoint to probe.
+        header: Header for Authorization.

     Returns:
         True if server is active, false otherwise.
     """
     try:
-        response = requests.get(cloud_gateway_api + "/health")
-        if response.status_code == 200:
-            return True
-    except requests.ConnectionError:
-        print("Failed to establish connection to the server.")
-
-    return False
+        response = requests.get(
+            cloud_gateway_api + "model/info",
+            headers=header,
+            verify=False,
+        )
+        response.raise_for_status()
+        return True
+    except requests.RequestException as e:
+        logging.error(f"Failed to check server health: {e}")
+        return False


 def request_generation(
+    header: dict,
     message: str,
     system_prompt: str,
     cloud_gateway_api: str,
+    model_name: str,
     max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
+    temperature: float = 0.3,
+    frequency_penalty: float = 0.0,
+    presence_penalty: float = 0.0,
 ):
     """
     Request streaming generation from the cloud gateway API. Uses the simple requests module with stream=True to utilize
     token-by-token generation from LLM.

     Args:
+        header: authorization header for the API.
         message: prompt from the user.
         system_prompt: system prompt to append.
         cloud_gateway_api (str): API endpoint to send the request.
@@ -43,7 +56,6 @@ def request_generation(
         temperature: the value used to module the next token probabilities.
         top_p: if set to float<1, only the smallest set of most probable tokens with probabilities that add up to top_p
             or higher are kept for generation.
-        top_k: the number of highest probability vocabulary tokens to keep for top-k-filtering.
         repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.

     Returns:
@@ -51,54 +63,69 @@ def request_generation(
     """

     payload = {
-        "model": "google/gemma-3-27b-it",
+        "model": model_name,
         "messages": [
-            *(
-                [
-                    {
-                        "role": "system",
-                        "content": [{"type": "text", "text": system_prompt}],
-                    }
-                ]
-                if system_prompt
-                else []
-            ),
-            {"role": "user", "content": [{"type": "text", "text": message}]},
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": message},
         ],
         "max_tokens": max_new_tokens,
         "temperature": temperature,
-        "top_p": top_p,
-        "repetition_penalty": repetition_penalty,
-        "top_k": top_k,
+        "frequency_penalty": frequency_penalty,
+        "presence_penalty": presence_penalty,
         "stream": True,  # Enable streaming
+        "serving_runtime": "vllm",
     }

-    with requests.post(
-        cloud_gateway_api + "/v1/chat/completions", json=payload, stream=True
-    ) as response:
-        for chunk in response.iter_lines():
-            if chunk:
-                # Convert the chunk from bytes to a string and then parse it as json
-                chunk_str = chunk.decode("utf-8")
-
-                # Remove the `data: ` prefix from the chunk if it exists
-                if chunk_str.startswith("data: "):
-                    chunk_str = chunk_str[len("data: ") :]
-
-                # Skip empty chunks
-                if chunk_str.strip() == "[DONE]":
-                    break
-
-                # Parse the chunk into a JSON object
-                try:
-                    chunk_json = json.loads(chunk_str)
-
-                    # Extract the "content" field from the choices
-                    content = chunk_json["choices"][0]["delta"].get("content", "")
-
-                    # Print the generated content as it's streamed
-                    if content:
-                        yield content
-                except json.JSONDecodeError:
-                    # Handle any potential errors in decoding
-                    continue
+    try:
+        response = requests.post(
+            cloud_gateway_api + "chat/conversation",
+            headers=header,
+            json=payload,
+            verify=False,
+        )
+        response.raise_for_status()
+
+        # Append the conversation ID with the key X-Conversation-ID to the header
+        header["X-Conversation-ID"] = response.json()["conversationId"]
+
+        with requests.get(
+            cloud_gateway_api + f"conversation/stream",
+            headers=header,
+            verify=False,
+            stream=True,
+        ) as response:
+            for chunk in response.iter_lines():
+                if chunk:
+                    # Convert the chunk from bytes to a string and then parse it as json
+                    chunk_str = chunk.decode("utf-8")
+
+                    # Remove the `data: ` prefix from the chunk if it exists
+                    for _ in range(2):
+                        if chunk_str.startswith("data: "):
+                            chunk_str = chunk_str[len("data: ") :]
+
+                    # Skip empty chunks
+                    if chunk_str.strip() == "[DONE]":
+                        break
+
+                    # Parse the chunk into a JSON object
+                    try:
+                        chunk_json = json.loads(chunk_str)
+
+                        # Extract the "content" field from the choices
+                        if "choices" in chunk_json and chunk_json["choices"]:
+                            content = chunk_json["choices"][0]["delta"].get(
+                                "content", ""
+                            )
+                        else:
+                            content = ""
+
+                        # Print the generated content as it's streamed
+                        if content:
+                            yield content
+                    except json.JSONDecodeError:
+                        # Handle any potential errors in decoding
+                        continue
+    except requests.RequestException as e:
+        logging.error(f"Failed to generate response: {e}")
+        yield "Server not responding. Please try again later."