Ali2206 commited on
Commit
849209d
·
verified ·
1 Parent(s): d206f24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -94
app.py CHANGED
@@ -7,10 +7,10 @@ import gradio as gr
7
  from huggingface_hub import snapshot_download
8
  from tooluniverse import ToolUniverse
9
  import time
10
- from functools import partial
11
  from requests.adapters import HTTPAdapter
12
  from requests import Session
13
  from urllib3.util.retry import Retry
 
14
 
15
  # Configuration
16
  CONFIG = {
@@ -21,30 +21,42 @@ CONFIG = {
21
  "tool_files": {
22
  "new_tool": "./data/new_tool.json"
23
  },
24
- "download_timeout": 300, # 5 minutes timeout
25
- "max_retries": 3,
26
- "retry_delay": 10 # seconds between retries
 
 
 
 
27
  }
28
 
29
  # Logging setup
30
- logging.basicConfig(level=logging.INFO)
 
 
 
31
  logger = logging.getLogger(__name__)
32
 
33
- def create_custom_session():
34
- """Create a session with custom timeout and retry settings"""
35
  session = Session()
36
- retries = Retry(
37
- total=CONFIG["max_retries"],
 
38
  backoff_factor=1,
39
- status_forcelist=[500, 502, 503, 504]
40
  )
 
41
  adapter = HTTPAdapter(
42
- max_retries=retries,
43
  pool_connections=10,
44
- pool_maxsize=10
 
45
  )
46
- session.mount("http://", adapter)
47
  session.mount("https://", adapter)
 
 
48
  return session
49
 
50
  def prepare_tool_files():
@@ -57,49 +69,66 @@ def prepare_tool_files():
57
  json.dump(tools, f, indent=2)
58
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
59
 
60
- def download_with_retry(repo_id, local_dir):
61
- retry_count = 0
62
- custom_session = create_custom_session()
63
 
64
- while retry_count < CONFIG["max_retries"]:
65
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  snapshot_download(
67
  repo_id=repo_id,
68
  local_dir=local_dir,
69
  resume_download=True,
70
  local_dir_use_symlinks=False,
71
  use_auth_token=True,
 
 
72
  session=custom_session
73
  )
 
 
74
  return True
 
75
  except Exception as e:
76
- retry_count += 1
77
- logger.error(f"Attempt {retry_count} failed for {repo_id}: {str(e)}")
78
- if retry_count < CONFIG["max_retries"]:
79
- wait_time = CONFIG["retry_delay"] * retry_count
80
  logger.info(f"Waiting {wait_time} seconds before retry...")
81
  time.sleep(wait_time)
82
- return False
 
 
83
 
84
  def download_model_files():
85
  os.makedirs(CONFIG["local_dir"], exist_ok=True)
86
- logger.info("Downloading model files...")
87
 
88
  # Download main model
89
- logger.info(f"Downloading {CONFIG['model_name']}...")
90
- if not download_with_retry(
91
  CONFIG["model_name"],
92
  os.path.join(CONFIG["local_dir"], CONFIG["model_name"])
93
  ):
94
- raise RuntimeError(f"Failed to download {CONFIG['model_name']} after {CONFIG['max_retries']} attempts")
95
 
96
  # Download RAG model
97
- logger.info(f"Downloading {CONFIG['rag_model_name']}...")
98
- if not download_with_retry(
99
  CONFIG["rag_model_name"],
100
  os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"])
101
  ):
102
- raise RuntimeError(f"Failed to download {CONFIG['rag_model_name']} after {CONFIG['max_retries']} attempts")
103
 
104
  logger.info("All model files downloaded successfully")
105
 
@@ -113,8 +142,7 @@ def load_embeddings(agent):
113
  return
114
  except Exception as e:
115
  logger.error(f"Failed to load embeddings: {e}")
116
- # Fall through to generate new embeddings
117
-
118
  logger.info("Generating tool embeddings...")
119
  try:
120
  tools = agent.tooluniverse.get_all_tools()
@@ -134,26 +162,35 @@ class TxAgentApp:
134
 
135
  def initialize(self):
136
  if self.is_initialized:
137
- return "Already initialized"
138
-
139
  try:
140
- logger.info("Initializing TxAgent...")
141
- self.agent = TxAgent(
142
- CONFIG["model_name"],
143
- CONFIG["rag_model_name"],
144
- tool_files_dict=CONFIG["tool_files"],
145
- force_finish=True,
146
- enable_checker=True,
147
- step_rag_num=10,
148
- seed=100,
149
- additional_default_tools=["DirectResponse", "RequireClarification"]
150
- )
151
- logger.info("Initializing models...")
152
- self.agent.init_model()
153
- logger.info("Loading embeddings...")
154
- load_embeddings(self.agent)
155
- self.is_initialized = True
156
- logger.info("✅ TxAgent initialized successfully")
 
 
 
 
 
 
 
 
 
157
  return "✅ TxAgent initialized successfully"
158
  except Exception as e:
159
  logger.error(f"Initialization failed: {str(e)}")
@@ -161,8 +198,8 @@ class TxAgentApp:
161
 
162
  def chat(self, message, history):
163
  if not self.is_initialized:
164
- return history + [(message, "⚠️ Error: Model not initialized. Please click 'Initialize Model' first.")]
165
-
166
  try:
167
  response = ""
168
  for chunk in self.agent.run_gradio_chat(
@@ -176,77 +213,85 @@ class TxAgentApp:
176
  max_round=30
177
  ):
178
  response += chunk
179
-
180
- return history + [(message, response)]
181
  except Exception as e:
182
  logger.error(f"Chat error: {str(e)}")
183
- return history + [(message, f"Error: {str(e)}")]
184
 
185
  def create_interface():
186
  app = TxAgentApp()
187
- with gr.Blocks(title="TxAgent", css=".gradio-container {max-width: 900px !important}") as demo:
 
 
 
 
 
 
 
188
  gr.Markdown("""
189
- # 🧠 TxAgent: Therapeutic Reasoning AI
190
- ### A specialized AI for clinical decision support and therapeutic reasoning
191
  """)
192
-
 
193
  with gr.Row():
194
  init_btn = gr.Button("Initialize Model", variant="primary")
195
- init_status = gr.Textbox(label="Initialization Status", interactive=False)
196
-
197
- with gr.Row():
198
- with gr.Column(scale=3):
199
- chatbot = gr.Chatbot(height=600, label="Conversation", bubble_full_width=False)
200
- msg = gr.Textbox(label="Your Question", placeholder="Enter your clinical question here...")
201
- submit_btn = gr.Button("Submit", variant="primary")
202
- with gr.Column(scale=1):
203
- gr.Markdown("### Example Questions:")
204
- gr.Examples(
205
- examples=[
206
- "How to adjust Journavx dosage for hepatic impairment?",
207
- "Is Xolremdi safe with Prozac for WHIM syndrome?",
208
- "Warfarin-Amiodarone contraindications?",
209
- "Alternative treatments for EGFR-positive NSCLC?"
210
- ],
211
- inputs=msg,
212
- label="Click to try"
213
- )
214
-
215
  init_btn.click(
216
  fn=app.initialize,
217
- outputs=init_status,
218
- api_name="initialize"
219
  )
 
220
  msg.submit(
221
- fn=app.chat,
222
- inputs=[msg, chatbot],
223
- outputs=chatbot,
224
- api_name="chat"
225
- )
226
- submit_btn.click(
227
  fn=app.chat,
228
  inputs=[msg, chatbot],
229
  outputs=chatbot
230
  )
231
-
 
 
 
 
 
232
  return demo
233
 
234
  if __name__ == "__main__":
235
  try:
236
- logger.info("Preparing tool files...")
 
 
237
  prepare_tool_files()
238
 
239
- logger.info("Downloading model files (if needed)...")
240
  download_model_files()
241
 
242
- logger.info("Launching interface...")
243
  interface = create_interface()
244
  interface.launch(
245
  server_name="0.0.0.0",
246
  server_port=7860,
247
- share=False,
248
- show_error=True
249
  )
250
  except Exception as e:
251
- logger.error(f"Application failed to start: {str(e)}")
252
  raise
 
7
  from huggingface_hub import snapshot_download
8
  from tooluniverse import ToolUniverse
9
  import time
 
10
  from requests.adapters import HTTPAdapter
11
  from requests import Session
12
  from urllib3.util.retry import Retry
13
+ from tqdm import tqdm
14
 
15
  # Configuration
16
  CONFIG = {
 
21
  "tool_files": {
22
  "new_tool": "./data/new_tool.json"
23
  },
24
+ "download_settings": {
25
+ "timeout": 600, # 10 minutes per request
26
+ "max_retries": 5,
27
+ "retry_delay": 30, # seconds between retries
28
+ "chunk_size": 1024 * 1024 * 10, # 10MB chunks
29
+ "max_concurrent": 2 # concurrent downloads
30
+ }
31
  }
32
 
33
  # Logging setup
34
+ logging.basicConfig(
35
+ level=logging.INFO,
36
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
37
+ )
38
  logger = logging.getLogger(__name__)
39
 
40
+ def create_optimized_session():
41
+ """Create a session optimized for large file downloads"""
42
  session = Session()
43
+
44
+ retry_strategy = Retry(
45
+ total=CONFIG["download_settings"]["max_retries"],
46
  backoff_factor=1,
47
+ status_forcelist=[408, 429, 500, 502, 503, 504]
48
  )
49
+
50
  adapter = HTTPAdapter(
51
+ max_retries=retry_strategy,
52
  pool_connections=10,
53
+ pool_maxsize=10,
54
+ pool_block=True
55
  )
56
+
57
  session.mount("https://", adapter)
58
+ session.mount("http://", adapter)
59
+
60
  return session
61
 
62
  def prepare_tool_files():
 
69
  json.dump(tools, f, indent=2)
70
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
71
 
72
+ def download_model_with_progress(repo_id, local_dir):
73
+ custom_session = create_optimized_session()
 
74
 
75
+ for attempt in range(CONFIG["download_settings"]["max_retries"] + 1):
76
  try:
77
+ logger.info(f"Download attempt {attempt + 1} for {repo_id}")
78
+
79
+ # Create progress bar
80
+ progress = tqdm(
81
+ unit="B",
82
+ unit_scale=True,
83
+ unit_divisor=1024,
84
+ miniters=1,
85
+ desc=f"Downloading {repo_id.split('/')[-1]}"
86
+ )
87
+
88
+ def update_progress(monitor):
89
+ progress.update(monitor.bytes_read - progress.n)
90
+
91
  snapshot_download(
92
  repo_id=repo_id,
93
  local_dir=local_dir,
94
  resume_download=True,
95
  local_dir_use_symlinks=False,
96
  use_auth_token=True,
97
+ max_workers=CONFIG["download_settings"]["max_concurrent"],
98
+ tqdm_class=None, # We handle progress ourselves
99
  session=custom_session
100
  )
101
+
102
+ progress.close()
103
  return True
104
+
105
  except Exception as e:
106
+ logger.error(f"Attempt {attempt + 1} failed: {str(e)}")
107
+ if attempt < CONFIG["download_settings"]["max_retries"]:
108
+ wait_time = CONFIG["download_settings"]["retry_delay"] * (attempt + 1)
 
109
  logger.info(f"Waiting {wait_time} seconds before retry...")
110
  time.sleep(wait_time)
111
+ else:
112
+ progress.close()
113
+ return False
114
 
115
  def download_model_files():
116
  os.makedirs(CONFIG["local_dir"], exist_ok=True)
117
+ logger.info("Starting model downloads...")
118
 
119
  # Download main model
120
+ if not download_model_with_progress(
 
121
  CONFIG["model_name"],
122
  os.path.join(CONFIG["local_dir"], CONFIG["model_name"])
123
  ):
124
+ raise RuntimeError(f"Failed to download {CONFIG['model_name']}")
125
 
126
  # Download RAG model
127
+ if not download_model_with_progress(
 
128
  CONFIG["rag_model_name"],
129
  os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"])
130
  ):
131
+ raise RuntimeError(f"Failed to download {CONFIG['rag_model_name']}")
132
 
133
  logger.info("All model files downloaded successfully")
134
 
 
142
  return
143
  except Exception as e:
144
  logger.error(f"Failed to load embeddings: {e}")
145
+
 
146
  logger.info("Generating tool embeddings...")
147
  try:
148
  tools = agent.tooluniverse.get_all_tools()
 
162
 
163
  def initialize(self):
164
  if self.is_initialized:
165
+ return "Already initialized"
166
+
167
  try:
168
+ # Initialize with progress tracking
169
+ with tqdm(total=4, desc="Initializing TxAgent") as pbar:
170
+ logger.info("Creating TxAgent instance...")
171
+ self.agent = TxAgent(
172
+ CONFIG["model_name"],
173
+ CONFIG["rag_model_name"],
174
+ tool_files_dict=CONFIG["tool_files"],
175
+ force_finish=True,
176
+ enable_checker=True,
177
+ step_rag_num=10,
178
+ seed=100,
179
+ additional_default_tools=["DirectResponse", "RequireClarification"]
180
+ )
181
+ pbar.update(1)
182
+
183
+ logger.info("Initializing models...")
184
+ self.agent.init_model()
185
+ pbar.update(1)
186
+
187
+ logger.info("Loading embeddings...")
188
+ load_embeddings(self.agent)
189
+ pbar.update(1)
190
+
191
+ self.is_initialized = True
192
+ pbar.update(1)
193
+
194
  return "✅ TxAgent initialized successfully"
195
  except Exception as e:
196
  logger.error(f"Initialization failed: {str(e)}")
 
198
 
199
  def chat(self, message, history):
200
  if not self.is_initialized:
201
+ return history + [(message, "⚠️ Please initialize the model first")]
202
+
203
  try:
204
  response = ""
205
  for chunk in self.agent.run_gradio_chat(
 
213
  max_round=30
214
  ):
215
  response += chunk
216
+ yield history + [(message, response)]
217
+
218
  except Exception as e:
219
  logger.error(f"Chat error: {str(e)}")
220
+ yield history + [(message, f"Error: {str(e)}")]
221
 
222
  def create_interface():
223
  app = TxAgentApp()
224
+
225
+ with gr.Blocks(
226
+ title="TxAgent",
227
+ css="""
228
+ .gradio-container {max-width: 900px !important}
229
+ .progress-bar {height: 20px !important}
230
+ """
231
+ ) as demo:
232
  gr.Markdown("""
233
+ # TxAgent: Therapeutic Reasoning AI
234
+ ### Specialized for clinical decision support
235
  """)
236
+
237
+ # Initialization section
238
  with gr.Row():
239
  init_btn = gr.Button("Initialize Model", variant="primary")
240
+ init_status = gr.Textbox(label="Status", interactive=False)
241
+ download_progress = gr.Textbox(visible=False)
242
+
243
+ # Chat interface
244
+ chatbot = gr.Chatbot(height=500, label="Conversation")
245
+ msg = gr.Textbox(label="Your clinical question", placeholder="Ask about drug interactions, dosing, etc...")
246
+ clear_btn = gr.Button("Clear Chat")
247
+
248
+ # Examples
249
+ gr.Examples(
250
+ examples=[
251
+ "How to adjust Journavx for renal impairment?",
252
+ "Xolremdi and Prozac interaction in WHIM syndrome?",
253
+ "Alternative to Warfarin for patient with amiodarone?"
254
+ ],
255
+ inputs=msg,
256
+ label="Example Questions"
257
+ )
258
+
259
+ # Event handlers
260
  init_btn.click(
261
  fn=app.initialize,
262
+ outputs=init_status
 
263
  )
264
+
265
  msg.submit(
 
 
 
 
 
 
266
  fn=app.chat,
267
  inputs=[msg, chatbot],
268
  outputs=chatbot
269
  )
270
+
271
+ clear_btn.click(
272
+ fn=lambda: ([], ""),
273
+ outputs=[chatbot, msg]
274
+ )
275
+
276
  return demo
277
 
278
  if __name__ == "__main__":
279
  try:
280
+ logger.info("Starting application setup...")
281
+
282
+ # Prepare files
283
  prepare_tool_files()
284
 
285
+ # Download models with progress tracking
286
  download_model_files()
287
 
288
+ # Launch interface
289
  interface = create_interface()
290
  interface.launch(
291
  server_name="0.0.0.0",
292
  server_port=7860,
293
+ share=False
 
294
  )
295
  except Exception as e:
296
+ logger.error(f"Fatal error: {str(e)}")
297
  raise