Update app.py
app.py
CHANGED
@@ -39,12 +39,22 @@ chat_css = """
 .gr-button svg { width: 32px !important; height: 32px !important; }
 """
 
+def validate_message(message: str) -> bool:
+    """Validate that the message meets minimum requirements."""
+    if not message or not isinstance(message, str):
+        return False
+    # Remove whitespace and check length
+    clean_msg = message.strip()
+    return len(clean_msg) >= 10
+
 def safe_load_embeddings(filepath: str) -> any:
     try:
+        # First try with weights_only=True (secure mode)
         return torch.load(filepath, weights_only=True)
     except Exception as e:
         logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
         try:
+            # If that fails, try with weights_only=False (less secure)
             return torch.load(filepath, weights_only=False)
         except Exception as e:
             logger.error(f"Failed to load embeddings: {str(e)}")
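The new validate_message helper rejects anything that is not a string with at least 10 characters once surrounding whitespace is stripped, while the commented torch.load path makes the fallback explicit: weights_only=True only unpickles tensors and plain containers, so files that store arbitrary Python objects drop to the less secure weights_only=False branch. A standalone sketch (not part of app.py) of how the validator behaves on edge cases:

    def validate_message(message: str) -> bool:
        # Same rule as the diff: non-strings and short strings are rejected
        if not message or not isinstance(message, str):
            return False
        return len(message.strip()) >= 10

    assert validate_message("What interacts with warfarin?")   # long enough
    assert not validate_message("   hi   ")                    # too short once stripped
    assert not validate_message("")                            # empty string
    assert not validate_message(None)                          # not a string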
@@ -139,22 +149,18 @@ def create_agent():
         logger.error(f"Failed to create agent: {str(e)}")
         raise
 
-def respond(
-
-
-
-
-
-    if not isinstance(message, str) or len(message.strip()) <= 10:
-        return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "Please provide a valid message with a string longer than 10 characters."}]
+def respond(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
+    # Validate the message first
+    if not validate_message(message):
+        error_msg = "Please provide a valid message with a string longer than 10 characters."
+        logger.warning(f"Message validation failed: {message}")
+        return history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_msg}]
 
     updated_history = history + [{"role": "user", "content": message}]
-
-    print("User Message:", message)
-    print("Full History:", updated_history)
-    print("================\n")
+    logger.debug(f"\n==== DEBUG ====\nUser Message: {message}\nFull History: {updated_history}\n================\n")
 
     try:
+        # Ensure correct format for run_gradio_chat
        formatted_history = [(m["role"], m["content"]) for m in updated_history]
 
         response_generator = agent.run_gradio_chat(
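The rewritten respond now funnels bad input through validate_message and logs the turn at debug level instead of printing. The format bridge that follows is worth noting: the Gradio chatbot keeps history as role/content dicts, and the code converts them to (role, content) tuples before handing them to agent.run_gradio_chat. A standalone sketch of that conversion, with illustrative strings:

    history = [
        {"role": "user", "content": "What is aspirin used for?"},
        {"role": "assistant", "content": "Aspirin is commonly used for pain relief."},
    ]
    # Same comprehension as in the diff
    formatted_history = [(m["role"], m["content"]) for m in history]
    print(formatted_history)
    # [('user', 'What is aspirin used for?'), ('assistant', 'Aspirin is commonly used for pain relief.')]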
@@ -167,22 +173,29 @@ def respond(chat_history, history, temperature, max_new_tokens, max_tokens, mult
             max_round
         )
     except Exception as e:
-
+        error_msg = f"Error processing your request: {str(e)}"
+        logger.error(f"Error in respond function: {str(e)}")
+        return history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_msg}]
 
     collected = ""
-
-
-
-
-
+    try:
+        for chunk in response_generator:
+            if isinstance(chunk, dict):
+                collected += chunk.get("content", "")
+            else:
+                collected += str(chunk)
+    except Exception as e:
+        error_msg = f"Error generating response: {str(e)}"
+        logger.error(f"Error in response generation: {str(e)}")
+        return history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_msg}]
 
-    return
+    return history + [{"role": "user", "content": message}, {"role": "assistant", "content": collected}]
 
 def create_demo(agent):
     with gr.Blocks(css=chat_css) as demo:
         chatbot = gr.Chatbot(label="TxAgent", type="messages")
         with gr.Row():
-            msg = gr.Textbox(label="Your question"
+            msg = gr.Textbox(label="Your question", placeholder="Enter your biomedical question here (minimum 10 characters)...")
         with gr.Row():
             temp = gr.Slider(0, 1, value=0.3, label="Temperature")
             max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
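The new collection loop tolerates both chunk shapes the generator may yield: dicts contribute their "content" field (empty string if absent) and anything else is stringified. Because the loop drains the generator before returning, the chatbot receives the full answer in a single update rather than a token-by-token stream. A standalone sketch with a dummy generator:

    def fake_generator():
        yield {"content": "TxAgent "}   # dict chunk: contributes its "content"
        yield {"role": "assistant"}     # dict without "content": contributes ""
        yield "response."               # non-dict chunk: stringified

    collected = ""
    for chunk in fake_generator():
        if isinstance(chunk, dict):
            collected += chunk.get("content", "")
        else:
            collected += str(chunk)

    print(collected)  # TxAgent response.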
@@ -194,10 +207,9 @@ def create_demo(agent):
 
         submit.click(
             respond,
-            inputs=[
+            inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
             outputs=[chatbot]
         )
-
     return demo
 
 def main():
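The inputs list matters because Gradio passes component values to the event handler positionally, so [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds] has to line up with respond(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round). A minimal standalone sketch of the same wiring pattern (a hypothetical echo app, not TxAgent):

    import gradio as gr

    def echo(message, history):
        # history arrives as the chatbot's current list of role/content dicts
        return history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": f"You said: {message}"},
        ]

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(type="messages")
        msg = gr.Textbox(label="Your question")
        submit = gr.Button("Submit")
        # inputs map positionally: msg -> message, chatbot -> history
        submit.click(echo, inputs=[msg, chatbot], outputs=[chatbot])

    if __name__ == "__main__":
        demo.launch()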
@@ -205,10 +217,10 @@ def main():
         global agent
         agent = create_agent()
         demo = create_demo(agent)
-        demo.launch()
+        demo.launch(server_name="0.0.0.0", server_port=7860)
     except Exception as e:
         logger.error(f"Application failed to start: {str(e)}")
         raise
 
 if __name__ == "__main__":
-    main()
+    main()
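Binding to 0.0.0.0 makes the server reachable from outside the container, and 7860 is the port a Hugging Face Space expects a Gradio app to listen on by default, so the explicit launch arguments are presumably there to match the Space's runtime.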