Ali2206 commited on
Commit
742b026
·
verified ·
1 Parent(s): 517d789

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -24
app.py CHANGED
@@ -39,12 +39,22 @@ chat_css = """
39
  .gr-button svg { width: 32px !important; height: 32px !important; }
40
  """
41
 
 
 
 
 
 
 
 
 
42
  def safe_load_embeddings(filepath: str) -> any:
43
  try:
 
44
  return torch.load(filepath, weights_only=True)
45
  except Exception as e:
46
  logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
47
  try:
 
48
  return torch.load(filepath, weights_only=False)
49
  except Exception as e:
50
  logger.error(f"Failed to load embeddings: {str(e)}")
@@ -139,22 +149,18 @@ def create_agent():
139
  logger.error(f"Failed to create agent: {str(e)}")
140
  raise
141
 
142
- def respond(chat_history, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
143
- if not chat_history:
144
- return history + [{"role": "assistant", "content": "Please provide a message."}]
145
-
146
- message = chat_history[-1][1] if isinstance(chat_history[-1], (list, tuple)) else chat_history[-1]
147
-
148
- if not isinstance(message, str) or len(message.strip()) <= 10:
149
- return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "Please provide a valid message with a string longer than 10 characters."}]
150
 
151
  updated_history = history + [{"role": "user", "content": message}]
152
- print("\n==== DEBUG ====")
153
- print("User Message:", message)
154
- print("Full History:", updated_history)
155
- print("================\n")
156
 
157
  try:
 
158
  formatted_history = [(m["role"], m["content"]) for m in updated_history]
159
 
160
  response_generator = agent.run_gradio_chat(
@@ -167,22 +173,29 @@ def respond(chat_history, history, temperature, max_new_tokens, max_tokens, mult
167
  max_round
168
  )
169
  except Exception as e:
170
- return updated_history + [{"role": "assistant", "content": f"Error: {str(e)}"}]
 
 
171
 
172
  collected = ""
173
- for chunk in response_generator:
174
- if isinstance(chunk, dict):
175
- collected += chunk.get("content", "")
176
- else:
177
- collected += str(chunk)
 
 
 
 
 
178
 
179
- return updated_history + [{"role": "assistant", "content": collected}]
180
 
181
  def create_demo(agent):
182
  with gr.Blocks(css=chat_css) as demo:
183
  chatbot = gr.Chatbot(label="TxAgent", type="messages")
184
  with gr.Row():
185
- msg = gr.Textbox(label="Your question") # not used in inputs anymore
186
  with gr.Row():
187
  temp = gr.Slider(0, 1, value=0.3, label="Temperature")
188
  max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
@@ -194,10 +207,9 @@ def create_demo(agent):
194
 
195
  submit.click(
196
  respond,
197
- inputs=[chatbot, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
198
  outputs=[chatbot]
199
  )
200
-
201
  return demo
202
 
203
  def main():
@@ -205,10 +217,10 @@ def main():
205
  global agent
206
  agent = create_agent()
207
  demo = create_demo(agent)
208
- demo.launch()
209
  except Exception as e:
210
  logger.error(f"Application failed to start: {str(e)}")
211
  raise
212
 
213
  if __name__ == "__main__":
214
- main()
 
39
  .gr-button svg { width: 32px !important; height: 32px !important; }
40
  """
41
 
42
+ def validate_message(message: str) -> bool:
43
+ """Validate that the message meets minimum requirements."""
44
+ if not message or not isinstance(message, str):
45
+ return False
46
+ # Remove whitespace and check length
47
+ clean_msg = message.strip()
48
+ return len(clean_msg) >= 10
49
+
50
  def safe_load_embeddings(filepath: str) -> any:
51
  try:
52
+ # First try with weights_only=True (secure mode)
53
  return torch.load(filepath, weights_only=True)
54
  except Exception as e:
55
  logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
56
  try:
57
+ # If that fails, try with weights_only=False (less secure)
58
  return torch.load(filepath, weights_only=False)
59
  except Exception as e:
60
  logger.error(f"Failed to load embeddings: {str(e)}")
 
149
  logger.error(f"Failed to create agent: {str(e)}")
150
  raise
151
 
152
+ def respond(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
153
+ # Validate the message first
154
+ if not validate_message(message):
155
+ error_msg = "Please provide a valid message with a string longer than 10 characters."
156
+ logger.warning(f"Message validation failed: {message}")
157
+ return history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_msg}]
 
 
158
 
159
  updated_history = history + [{"role": "user", "content": message}]
160
+ logger.debug(f"\n==== DEBUG ====\nUser Message: {message}\nFull History: {updated_history}\n================\n")
 
 
 
161
 
162
  try:
163
+ # Ensure correct format for run_gradio_chat
164
  formatted_history = [(m["role"], m["content"]) for m in updated_history]
165
 
166
  response_generator = agent.run_gradio_chat(
 
173
  max_round
174
  )
175
  except Exception as e:
176
+ error_msg = f"Error processing your request: {str(e)}"
177
+ logger.error(f"Error in respond function: {str(e)}")
178
+ return history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_msg}]
179
 
180
  collected = ""
181
+ try:
182
+ for chunk in response_generator:
183
+ if isinstance(chunk, dict):
184
+ collected += chunk.get("content", "")
185
+ else:
186
+ collected += str(chunk)
187
+ except Exception as e:
188
+ error_msg = f"Error generating response: {str(e)}"
189
+ logger.error(f"Error in response generation: {str(e)}")
190
+ return history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_msg}]
191
 
192
+ return history + [{"role": "user", "content": message}, {"role": "assistant", "content": collected}]
193
 
194
  def create_demo(agent):
195
  with gr.Blocks(css=chat_css) as demo:
196
  chatbot = gr.Chatbot(label="TxAgent", type="messages")
197
  with gr.Row():
198
+ msg = gr.Textbox(label="Your question", placeholder="Enter your biomedical question here (minimum 10 characters)...")
199
  with gr.Row():
200
  temp = gr.Slider(0, 1, value=0.3, label="Temperature")
201
  max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
 
207
 
208
  submit.click(
209
  respond,
210
+ inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
211
  outputs=[chatbot]
212
  )
 
213
  return demo
214
 
215
  def main():
 
217
  global agent
218
  agent = create_agent()
219
  demo = create_demo(agent)
220
+ demo.launch(server_name="0.0.0.0", server_port=7860)
221
  except Exception as e:
222
  logger.error(f"Application failed to start: {str(e)}")
223
  raise
224
 
225
  if __name__ == "__main__":
226
+ main()