Ali2206 commited on
Commit
0b9e159
·
verified ·
1 Parent(s): dec1312

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -76
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
- from txagent import TxAgent
3
- import logging
4
  import os
 
 
5
 
6
  # Configure logging
7
  logging.basicConfig(level=logging.INFO)
@@ -18,11 +18,19 @@ class TxAgentApp:
18
  agent = TxAgent(
19
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
20
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
21
- # Remove unsupported parameters
22
- force_finish=True,
23
- enable_checker=True,
 
24
  step_rag_num=10,
 
 
 
 
 
25
  seed=42,
 
 
26
  additional_default_tools=["DirectResponse", "RequireClarification"]
27
  )
28
  logger.info("Model loading complete")
@@ -31,87 +39,117 @@ class TxAgentApp:
31
  logger.error(f"Initialization failed: {str(e)}")
32
  raise
33
 
34
- def respond(self, message, history):
35
- """Handle streaming responses with Gradio 5.23"""
36
  try:
 
 
 
 
37
  response_generator = self.agent.run_gradio_chat(
38
- message=message,
39
  history=history,
40
- temperature=0.3,
41
- max_new_tokens=2048,
42
- stream=True
 
 
 
 
43
  )
44
-
 
45
  for chunk in response_generator:
46
- if isinstance(chunk, dict):
47
- yield chunk.get("content", "")
48
  elif isinstance(chunk, str):
49
- yield chunk
 
 
 
 
 
 
50
  except Exception as e:
51
- logger.error(f"Generation error: {str(e)}")
52
  yield f"⚠️ Error: {str(e)}"
53
 
54
- # Initialize the app
55
- app = TxAgentApp()
56
-
57
- # Gradio 5.23 interface
58
- with gr.Blocks(
59
- title="TxAgent Medical AI",
60
- theme=gr.themes.Soft(spacing_size="sm", radius_size="none")
61
- ) as demo:
62
- gr.Markdown("""<h1 style='text-align: center'>TxAgent Biomedical Assistant</h1>""")
63
 
64
- with gr.Row(equal_height=False):
65
- with gr.Column(scale=2):
66
- chatbot = gr.Chatbot(
67
- height=650,
68
- bubble_full_width=False
69
- )
70
- with gr.Column(scale=1):
71
- with gr.Accordion("⚙️ Parameters", open=False):
72
- temperature = gr.Slider(0, 1, value=0.3, label="Creativity")
73
- max_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max Response Length")
74
- rag_toggle = gr.Checkbox(value=True, label="Enable RAG")
75
-
76
- msg = gr.Textbox(
77
- label="Your medical query",
78
- placeholder="Enter your biomedical question...",
79
- lines=5,
80
- max_lines=10
81
- )
82
- submit_btn = gr.Button("Submit", variant="primary")
83
- clear_btn = gr.Button("Clear History")
84
-
85
- # Chat interface
86
- msg.submit(
87
- app.respond,
88
- [msg, chatbot],
89
- chatbot,
90
- api_name="chat"
91
- ).then(
92
- lambda: "", None, msg
93
- )
94
-
95
- submit_btn.click(
96
- app.respond,
97
- [msg, chatbot],
98
- chatbot,
99
- api_name="chat"
100
- ).then(
101
- lambda: "", None, msg
102
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- clear_btn.click(
105
- lambda: [], None, chatbot
106
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- # Launch configuration
109
  if __name__ == "__main__":
110
- demo.queue(
111
- concurrency_count=5,
112
- max_size=20
113
- ).launch(
114
- server_name="0.0.0.0",
115
- server_port=7860,
116
- share=False
117
- )
 
1
  import gradio as gr
 
 
2
  import os
3
+ import logging
4
+ from txagent import TxAgent
5
 
6
  # Configure logging
7
  logging.basicConfig(level=logging.INFO)
 
18
  agent = TxAgent(
19
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
20
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
21
+ enable_finish=True,
22
+ enable_rag=True,
23
+ enable_summary=False,
24
+ init_rag_num=0,
25
  step_rag_num=10,
26
+ summary_mode='step',
27
+ summary_skip_last_k=0,
28
+ summary_context_length=None,
29
+ force_finish=True,
30
+ avoid_repeat=True,
31
  seed=42,
32
+ enable_checker=True,
33
+ enable_chat=False,
34
  additional_default_tools=["DirectResponse", "RequireClarification"]
35
  )
36
  logger.info("Model loading complete")
 
39
  logger.error(f"Initialization failed: {str(e)}")
40
  raise
41
 
42
+ def respond(self, message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
43
+ """Handle streaming responses with Gradio"""
44
  try:
45
+ if not isinstance(message, str) or len(message.strip()) <= 10:
46
+ yield "Please provide a valid message longer than 10 characters."
47
+ return
48
+
49
  response_generator = self.agent.run_gradio_chat(
50
+ message=message.strip(),
51
  history=history,
52
+ temperature=temperature,
53
+ max_new_tokens=max_new_tokens,
54
+ max_token=max_tokens,
55
+ call_agent=multi_agent,
56
+ conversation=conversation_state,
57
+ max_round=max_round,
58
+ seed=42
59
  )
60
+
61
+ full_response = ""
62
  for chunk in response_generator:
63
+ if isinstance(chunk, dict) and "content" in chunk:
64
+ content = chunk["content"]
65
  elif isinstance(chunk, str):
66
+ content = chunk
67
+ else:
68
+ content = str(chunk)
69
+
70
+ full_response += content
71
+ yield full_response
72
+
73
  except Exception as e:
74
+ logger.error(f"Error in respond function: {str(e)}")
75
  yield f"⚠️ Error: {str(e)}"
76
 
77
+ def create_demo():
78
+ """Create and configure the Gradio interface"""
79
+ app = TxAgentApp()
 
 
 
 
 
 
80
 
81
+ with gr.Blocks(
82
+ title="TxAgent Medical AI",
83
+ theme=gr.themes.Soft(spacing_size="sm", radius_size="none")
84
+ ) as demo:
85
+ gr.Markdown("""<h1 style='text-align: center'>TxAgent Biomedical Assistant</h1>""")
86
+
87
+ with gr.Row(equal_height=False):
88
+ with gr.Column(scale=2):
89
+ chatbot = gr.Chatbot(
90
+ height=650,
91
+ bubble_full_width=False,
92
+ render_markdown=True
93
+ )
94
+ msg = gr.Textbox(
95
+ label="Your medical query",
96
+ placeholder="Enter your biomedical question...",
97
+ lines=5,
98
+ max_lines=10
99
+ )
100
+
101
+ with gr.Column(scale=1):
102
+ with gr.Accordion("⚙️ Parameters", open=False):
103
+ temperature = gr.Slider(0, 1, value=0.3, label="Creativity")
104
+ max_new_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max Response Length")
105
+ max_tokens = gr.Slider(128, 81920, value=81920, step=1024, label="Max Total Tokens")
106
+ max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
107
+ multi_agent = gr.Checkbox(value=False, label="Multi-Agent Mode")
108
+
109
+ submit_btn = gr.Button("Submit", variant="primary")
110
+ clear_btn = gr.Button("Clear History")
111
+
112
+ conversation_state = gr.State([])
113
+
114
+ # Chat interface
115
+ msg.submit(
116
+ app.respond,
117
+ [msg, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
118
+ chatbot
119
+ ).then(lambda: "", None, msg)
120
+
121
+ submit_btn.click(
122
+ app.respond,
123
+ [msg, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
124
+ chatbot
125
+ ).then(lambda: "", None, msg)
126
+
127
+ clear_btn.click(
128
+ lambda: ([], []),
129
+ None,
130
+ [chatbot, conversation_state]
131
+ )
132
+
133
+ return demo
134
 
135
+ def main():
136
+ """Main entry point for the application"""
137
+ try:
138
+ logger.info("Starting TxAgent application...")
139
+ demo = create_demo()
140
+
141
+ logger.info("Launching Gradio interface...")
142
+ demo.queue(
143
+ concurrency_count=5,
144
+ max_size=20
145
+ ).launch(
146
+ server_name="0.0.0.0",
147
+ server_port=7860,
148
+ share=False
149
+ )
150
+ except Exception as e:
151
+ logger.error(f"Application failed to start: {str(e)}")
152
+ raise
153
 
 
154
  if __name__ == "__main__":
155
+ main()