Ali2206 commited on
Commit
47f0902
·
verified ·
1 Parent(s): 1155704

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -166
app.py CHANGED
@@ -1,56 +1,41 @@
1
- import random
2
  import os
3
- import datetime
4
- import sys
5
- from txagent import TxAgent
6
- import spaces
7
  import gradio as gr
 
 
 
8
 
9
- # Set environment variables
10
  current_dir = os.path.dirname(os.path.abspath(__file__))
11
  os.environ["MKL_THREADING_LAYER"] = "GNU"
12
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
13
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
14
 
 
15
  DESCRIPTION = '''
16
  <div>
17
  <h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools </h1>
18
  </div>
19
  '''
20
- INTRO = """
21
- Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations...
22
- """
23
- LICENSE = """
24
- DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE...
25
- """
26
 
27
- PLACEHOLDER = """
28
- <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
29
  <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
30
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
31
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️
32
- (top-right) to remove previous context before submitting a new question.</p>
33
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
34
  </div>
35
- """
36
 
37
  css = """
38
- h1 {
39
- text-align: center;
40
- display: block;
41
- }
42
  #duplicate-button {
43
  margin: auto;
44
  color: white;
45
  background: #1565c0;
46
  border-radius: 100vh;
47
  }
48
- .small-button button {
49
- font-size: 12px !important;
50
- padding: 4px 8px !important;
51
- height: 6px !important;
52
- width: 4px !important;
53
- }
54
  .gradio-accordion {
55
  margin-top: 0px !important;
56
  margin-bottom: 0px !important;
@@ -62,167 +47,110 @@ chat_css = """
62
  .gr-button svg { width: 32px !important; height: 32px !important; }
63
  """
64
 
65
- model_name = 'mims-harvard/TxAgent-T1-Llama-3.1-8B'
66
- rag_model_name = 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B'
67
-
68
- question_examples = [
69
- ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
70
- ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
71
- ['A 30-year-old patient is taking Prozac to treat their depression. They were recently diagnosed with WHIM syndrome and require a treatment for that condition as well. Is Xolremdi suitable for this patient, considering contraindications?'],
72
- ]
73
-
74
  new_tool_files = {
75
- 'new_tool': os.path.join(current_dir, 'data', 'new_tool.json'),
76
  }
77
 
78
- agent = TxAgent(model_name,
79
- rag_model_name,
80
- tool_files_dict=new_tool_files,
81
- force_finish=True,
82
- enable_checker=True,
83
- step_rag_num=10,
84
- seed=100,
85
- additional_default_tools=['DirectResponse', 'RequireClarification'])
86
- agent.init_model()
87
-
88
- def update_model_parameters(enable_finish, enable_rag, enable_summary,
89
- init_rag_num, step_rag_num, skip_last_k,
90
- summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed):
91
- return agent.update_parameters(
92
- enable_finish=enable_finish,
93
- enable_rag=enable_rag,
94
- enable_summary=enable_summary,
95
- init_rag_num=init_rag_num,
96
- step_rag_num=step_rag_num,
97
- skip_last_k=skip_last_k,
98
- summary_mode=summary_mode,
99
- summary_skip_last_k=summary_skip_last_k,
100
- summary_context_length=summary_context_length,
101
- force_finish=force_finish,
102
- seed=seed,
103
- )
104
-
105
- def update_seed():
106
- seed = random.randint(0, 10000)
107
- return agent.update_parameters(seed=seed)
108
-
109
- def handle_retry(history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
110
- update_seed()
111
- new_history = history[:retry_data.index]
112
- previous_prompt = history[retry_data.index]['content']
113
- yield from agent.run_gradio_chat(new_history + [{"role": "user", "content": previous_prompt}], temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
114
-
115
- PASSWORD = "mypassword"
116
-
117
- def check_password(input_password):
118
- if input_password == PASSWORD:
119
- return gr.update(visible=True), ""
120
- else:
121
- return gr.update(visible=False), "Incorrect password, try again!"
122
-
123
 
 
124
  if __name__ == "__main__":
125
- conversation_state = gr.State([])
126
-
127
- chatbot = gr.Chatbot(
128
- height=800, placeholder=PLACEHOLDER, label='TxAgent',
129
- type="messages", show_copy_button=True
 
 
 
 
 
130
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
 
132
  with gr.Blocks(css=css) as demo:
133
  gr.Markdown(DESCRIPTION)
134
  gr.Markdown(INTRO)
135
 
136
- default_temperature = 0.3
137
- default_max_new_tokens = 1024
138
- default_max_tokens = 81920
139
- default_max_round = 30
140
-
141
- temperature_state = gr.State(value=default_temperature)
142
- max_new_tokens_state = gr.State(value=default_max_new_tokens)
143
- max_tokens_state = gr.State(value=default_max_tokens)
144
- max_round_state = gr.State(value=default_max_round)
 
 
 
 
 
145
 
146
  chatbot.retry(
147
  handle_retry,
148
  chatbot, chatbot,
149
- temperature_state, max_new_tokens_state,
150
- max_tokens_state,
151
- gr.Checkbox(value=False, render=False),
152
- conversation_state,
153
- max_round_state
154
  )
155
 
156
  gr.ChatInterface(
157
- fn=agent.run_gradio_chat,
158
  chatbot=chatbot,
159
- fill_height=True, fill_width=True, stop_btn=True,
160
- additional_inputs_accordion=gr.Accordion(label="⚙️ Inference Parameters", open=False, render=False),
161
  additional_inputs=[
162
- temperature_state, max_new_tokens_state, max_tokens_state,
163
- gr.Checkbox(label="Activate multi-agent reasoning mode", value=False, render=False),
164
- conversation_state,
165
- max_round_state,
166
- gr.Number(label="Seed", value=100, render=False)
167
  ],
168
  examples=question_examples,
169
- cache_examples=False,
170
  css=chat_css,
 
 
 
 
171
  )
172
 
173
- with gr.Accordion("Settings", open=False):
174
- temperature_slider = gr.Slider(0, 1, step=0.1, value=default_temperature, label="Temperature")
175
- max_new_tokens_slider = gr.Slider(128, 4096, step=1, value=default_max_new_tokens, label="Max new tokens")
176
- max_tokens_slider = gr.Slider(128, 32000, step=1, value=default_max_tokens, label="Max tokens")
177
- max_round_slider = gr.Slider(0, 50, step=1, value=default_max_round, label="Max round")
178
-
179
- temperature_slider.change(lambda x: x, inputs=temperature_slider, outputs=temperature_state)
180
- max_new_tokens_slider.change(lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
181
- max_tokens_slider.change(lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
182
- max_round_slider.change(lambda x: x, inputs=max_round_slider, outputs=max_round_state)
183
-
184
- password_input = gr.Textbox(label="Enter Password for More Settings", type="password")
185
- incorrect_message = gr.Textbox(visible=False, interactive=False)
186
-
187
- with gr.Accordion("⚙️ Settings", open=False, visible=False) as protected_accordion:
188
- with gr.Row():
189
- with gr.Column(scale=1):
190
- with gr.Accordion("⚙️ Model Loading", open=False):
191
- model_name_input = gr.Textbox(label="Enter model path", value=model_name)
192
- load_model_btn = gr.Button(value="Load Model")
193
- load_model_btn.click(agent.load_models, inputs=model_name_input, outputs=gr.Textbox(label="Status"))
194
- with gr.Column(scale=1):
195
- with gr.Accordion("⚙️ Functional Parameters", open=False):
196
- enable_finish = gr.Checkbox(label="Enable Finish", value=True)
197
- enable_rag = gr.Checkbox(label="Enable RAG", value=True)
198
- enable_summary = gr.Checkbox(label="Enable Summary", value=False)
199
- init_rag_num = gr.Number(label="Initial RAG Num", value=0)
200
- step_rag_num = gr.Number(label="Step RAG Num", value=10)
201
- skip_last_k = gr.Number(label="Skip Last K", value=0)
202
- summary_mode = gr.Textbox(label="Summary Mode", value='step')
203
- summary_skip_last_k = gr.Number(label="Summary Skip Last K", value=0)
204
- summary_context_length = gr.Number(label="Summary Context Length", value=None)
205
- force_finish = gr.Checkbox(label="Force FinalAnswer", value=True)
206
- seed = gr.Number(label="Seed", value=100)
207
-
208
- submit_btn = gr.Button("Update Parameters")
209
- updated_parameters_output = gr.JSON()
210
- submit_btn.click(
211
- fn=update_model_parameters,
212
- inputs=[
213
- enable_finish, enable_rag, enable_summary, init_rag_num, step_rag_num, skip_last_k,
214
- summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed
215
- ],
216
- outputs=updated_parameters_output
217
- )
218
-
219
- submit_button = gr.Button("Submit")
220
- submit_button.click(
221
- check_password,
222
- inputs=password_input,
223
- outputs=[protected_accordion, incorrect_message]
224
- )
225
-
226
  gr.Markdown(LICENSE)
227
 
228
- demo.launch(share=True)
 
 
1
  import os
2
+ import random
 
 
 
3
  import gradio as gr
4
+ from datetime import datetime
5
+
6
+ from txagent import TxAgent
7
 
8
+ # ==== Environment Setup ====
9
  current_dir = os.path.dirname(os.path.abspath(__file__))
10
  os.environ["MKL_THREADING_LAYER"] = "GNU"
11
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
12
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
13
 
14
+ # ==== UI Content ====
15
  DESCRIPTION = '''
16
  <div>
17
  <h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools </h1>
18
  </div>
19
  '''
20
+ INTRO = "Precision therapeutics require multimodal adaptive models..."
21
+ LICENSE = "DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE..."
 
 
 
 
22
 
23
+ PLACEHOLDER = '''
24
+ <div style="padding: 30px; text-align: center;">
25
  <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
26
+ <p style="font-size: 18px;">Click clear 🗑️ before asking a new question.</p>
27
+ <p style="font-size: 18px;">Click retry 🔄 to see another answer.</p>
 
 
28
  </div>
29
+ '''
30
 
31
  css = """
32
+ h1 { text-align: center; }
 
 
 
33
  #duplicate-button {
34
  margin: auto;
35
  color: white;
36
  background: #1565c0;
37
  border-radius: 100vh;
38
  }
 
 
 
 
 
 
39
  .gradio-accordion {
40
  margin-top: 0px !important;
41
  margin-bottom: 0px !important;
 
47
  .gr-button svg { width: 32px !important; height: 32px !important; }
48
  """
49
 
50
+ # ==== Model Settings ====
51
+ model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
52
+ rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
 
 
 
 
 
 
53
  new_tool_files = {
54
+ "new_tool": os.path.join(current_dir, "data", "new_tool.json")
55
  }
56
 
57
+ question_examples = [
58
+ ["Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering moderate hepatic impairment?"],
59
+ ["A 30-year-old patient is on Prozac for depression and now diagnosed with WHIM syndrome. Is Xolremdi suitable?"]
60
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ # ====== Main Application Entrypoint ======
63
  if __name__ == "__main__":
64
+ # === Initialize the model (inside __main__) ===
65
+ agent = TxAgent(
66
+ model_name,
67
+ rag_model_name,
68
+ tool_files_dict=new_tool_files,
69
+ force_finish=True,
70
+ enable_checker=True,
71
+ step_rag_num=10,
72
+ seed=100,
73
+ additional_default_tools=["DirectResponse", "RequireClarification"]
74
  )
75
+ agent.init_model()
76
+
77
+ # === Gradio interface logic ===
78
+ def handle_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
79
+ return agent.run_gradio_chat(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
80
+
81
+ def update_model_parameters(enable_finish, enable_rag, enable_summary,
82
+ init_rag_num, step_rag_num, skip_last_k,
83
+ summary_mode, summary_skip_last_k, summary_context_length,
84
+ force_finish, seed):
85
+ return agent.update_parameters(
86
+ enable_finish=enable_finish,
87
+ enable_rag=enable_rag,
88
+ enable_summary=enable_summary,
89
+ init_rag_num=init_rag_num,
90
+ step_rag_num=step_rag_num,
91
+ skip_last_k=skip_last_k,
92
+ summary_mode=summary_mode,
93
+ summary_skip_last_k=summary_skip_last_k,
94
+ summary_context_length=summary_context_length,
95
+ force_finish=force_finish,
96
+ seed=seed
97
+ )
98
+
99
+ def update_seed():
100
+ seed = random.randint(0, 10000)
101
+ return agent.update_parameters(seed=seed)
102
+
103
+ def handle_retry(history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
104
+ update_seed()
105
+ new_history = history[:retry_data.index]
106
+ previous_prompt = history[retry_data.index]["content"]
107
+ yield from agent.run_gradio_chat(
108
+ new_history + [{"role": "user", "content": previous_prompt}],
109
+ temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round
110
+ )
111
 
112
+ # ===== Build Gradio Interface =====
113
  with gr.Blocks(css=css) as demo:
114
  gr.Markdown(DESCRIPTION)
115
  gr.Markdown(INTRO)
116
 
117
+ temperature = gr.Slider(0, 1, step=0.1, value=0.3, label="Temperature")
118
+ max_new_tokens = gr.Slider(128, 4096, step=1, value=1024, label="Max New Tokens")
119
+ max_tokens = gr.Slider(128, 32000, step=1, value=8192, label="Max Total Tokens")
120
+ max_round = gr.Slider(1, 50, step=1, value=30, label="Max Rounds")
121
+ multi_agent = gr.Checkbox(label="Enable Multi-agent Reasoning", value=False)
122
+ conversation_state = gr.State([])
123
+
124
+ chatbot = gr.Chatbot(
125
+ label="TxAgent",
126
+ placeholder=PLACEHOLDER,
127
+ height=700,
128
+ type="messages",
129
+ show_copy_button=True
130
+ )
131
 
132
  chatbot.retry(
133
  handle_retry,
134
  chatbot, chatbot,
135
+ temperature, max_new_tokens, max_tokens,
136
+ multi_agent, conversation_state, max_round
 
 
 
137
  )
138
 
139
  gr.ChatInterface(
140
+ fn=handle_chat,
141
  chatbot=chatbot,
 
 
142
  additional_inputs=[
143
+ temperature, max_new_tokens, max_tokens,
144
+ multi_agent, conversation_state, max_round
 
 
 
145
  ],
146
  examples=question_examples,
 
147
  css=chat_css,
148
+ cache_examples=False,
149
+ fill_height=True,
150
+ fill_width=True,
151
+ stop_btn=True
152
  )
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  gr.Markdown(LICENSE)
155
 
156
+ demo.launch()