Ali2206 committed
Commit 63950ea · verified · 1 Parent(s): 849209d

Update app.py

Files changed (1):
  app.py +34 -144
app.py CHANGED
@@ -4,30 +4,19 @@ import logging
 import torch
 from txagent import TxAgent
 import gradio as gr
-from huggingface_hub import snapshot_download
+from huggingface_hub import hf_hub_download
 from tooluniverse import ToolUniverse
-import time
-from requests.adapters import HTTPAdapter
-from requests import Session
-from urllib3.util.retry import Retry
 from tqdm import tqdm

-# Configuration
+# Configuration - Now using remote Hugging Face models
 CONFIG = {
     "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
-    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
     "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
-    "local_dir": "./models",
     "tool_files": {
         "new_tool": "./data/new_tool.json"
     },
-    "download_settings": {
-        "timeout": 600,  # 10 minutes per request
-        "max_retries": 5,
-        "retry_delay": 30,  # seconds between retries
-        "chunk_size": 1024 * 1024 * 10,  # 10MB chunks
-        "max_concurrent": 2  # concurrent downloads
-    }
+    "load_from_hub": True  # Flag to load directly from Hugging Face
 }

 # Logging setup
@@ -37,28 +26,6 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)

-def create_optimized_session():
-    """Create a session optimized for large file downloads"""
-    session = Session()
-
-    retry_strategy = Retry(
-        total=CONFIG["download_settings"]["max_retries"],
-        backoff_factor=1,
-        status_forcelist=[408, 429, 500, 502, 503, 504]
-    )
-
-    adapter = HTTPAdapter(
-        max_retries=retry_strategy,
-        pool_connections=10,
-        pool_maxsize=10,
-        pool_block=True
-    )
-
-    session.mount("https://", adapter)
-    session.mount("http://", adapter)
-
-    return session
-
 def prepare_tool_files():
     os.makedirs("./data", exist_ok=True)
     if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
@@ -69,69 +36,6 @@ def prepare_tool_files():
             json.dump(tools, f, indent=2)
         logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")

-def download_model_with_progress(repo_id, local_dir):
-    custom_session = create_optimized_session()
-
-    for attempt in range(CONFIG["download_settings"]["max_retries"] + 1):
-        try:
-            logger.info(f"Download attempt {attempt + 1} for {repo_id}")
-
-            # Create progress bar
-            progress = tqdm(
-                unit="B",
-                unit_scale=True,
-                unit_divisor=1024,
-                miniters=1,
-                desc=f"Downloading {repo_id.split('/')[-1]}"
-            )
-
-            def update_progress(monitor):
-                progress.update(monitor.bytes_read - progress.n)
-
-            snapshot_download(
-                repo_id=repo_id,
-                local_dir=local_dir,
-                resume_download=True,
-                local_dir_use_symlinks=False,
-                use_auth_token=True,
-                max_workers=CONFIG["download_settings"]["max_concurrent"],
-                tqdm_class=None,  # We handle progress ourselves
-                session=custom_session
-            )
-
-            progress.close()
-            return True
-
-        except Exception as e:
-            logger.error(f"Attempt {attempt + 1} failed: {str(e)}")
-            if attempt < CONFIG["download_settings"]["max_retries"]:
-                wait_time = CONFIG["download_settings"]["retry_delay"] * (attempt + 1)
-                logger.info(f"Waiting {wait_time} seconds before retry...")
-                time.sleep(wait_time)
-            else:
-                progress.close()
-                return False
-
-def download_model_files():
-    os.makedirs(CONFIG["local_dir"], exist_ok=True)
-    logger.info("Starting model downloads...")
-
-    # Download main model
-    if not download_model_with_progress(
-        CONFIG["model_name"],
-        os.path.join(CONFIG["local_dir"], CONFIG["model_name"])
-    ):
-        raise RuntimeError(f"Failed to download {CONFIG['model_name']}")
-
-    # Download RAG model
-    if not download_model_with_progress(
-        CONFIG["rag_model_name"],
-        os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"])
-    ):
-        raise RuntimeError(f"Failed to download {CONFIG['rag_model_name']}")
-
-    logger.info("All model files downloaded successfully")
-
 def load_embeddings(agent):
     embedding_path = CONFIG["embedding_filename"]
     if os.path.exists(embedding_path):
@@ -143,7 +47,7 @@ def load_embeddings(agent):
         except Exception as e:
             logger.error(f"Failed to load embeddings: {e}")

-    logger.info("Generating tool embeddings...")
+    logger.info("Generating tool embeddings from remote model...")
     try:
         tools = agent.tooluniverse.get_all_tools()
         descriptions = [tool["description"] for tool in tools]
@@ -165,33 +69,29 @@ class TxAgentApp:
             return "✅ Already initialized"

         try:
-            # Initialize with progress tracking
-            with tqdm(total=4, desc="Initializing TxAgent") as pbar:
-                logger.info("Creating TxAgent instance...")
-                self.agent = TxAgent(
-                    CONFIG["model_name"],
-                    CONFIG["rag_model_name"],
-                    tool_files_dict=CONFIG["tool_files"],
-                    force_finish=True,
-                    enable_checker=True,
-                    step_rag_num=10,
-                    seed=100,
-                    additional_default_tools=["DirectResponse", "RequireClarification"]
-                )
-                pbar.update(1)
-
-                logger.info("Initializing models...")
-                self.agent.init_model()
-                pbar.update(1)
-
-                logger.info("Loading embeddings...")
-                load_embeddings(self.agent)
-                pbar.update(1)
-
-                self.is_initialized = True
-                pbar.update(1)
-
-            return "✅ TxAgent initialized successfully"
+            logger.info("Initializing TxAgent with remote models...")
+
+            # Initialize with remote models
+            self.agent = TxAgent(
+                CONFIG["model_name"],
+                CONFIG["rag_model_name"],
+                tool_files_dict=CONFIG["tool_files"],
+                force_finish=True,
+                enable_checker=True,
+                step_rag_num=10,
+                seed=100,
+                additional_default_tools=["DirectResponse", "RequireClarification"],
+                local_files_only=False  # Force loading from Hugging Face Hub
+            )
+
+            logger.info("Loading remote models...")
+            self.agent.init_model()
+
+            logger.info("Preparing embeddings...")
+            load_embeddings(self.agent)
+
+            self.is_initialized = True
+            return "✅ TxAgent initialized successfully (using remote models)"
         except Exception as e:
             logger.error(f"Initialization failed: {str(e)}")
             return f"❌ Initialization failed: {str(e)}"
@@ -226,37 +126,30 @@ def create_interface():
         title="TxAgent",
         css="""
         .gradio-container {max-width: 900px !important}
-        .progress-bar {height: 20px !important}
         """
     ) as demo:
         gr.Markdown("""
-        # TxAgent: Therapeutic Reasoning AI
-        ### Specialized for clinical decision support
+        # 🧠 TxAgent: Therapeutic Reasoning AI
+        ### (Running with remote Hugging Face models)
         """)

-        # Initialization section
         with gr.Row():
             init_btn = gr.Button("Initialize Model", variant="primary")
             init_status = gr.Textbox(label="Status", interactive=False)
-            download_progress = gr.Textbox(visible=False)

-        # Chat interface
         chatbot = gr.Chatbot(height=500, label="Conversation")
-        msg = gr.Textbox(label="Your clinical question", placeholder="Ask about drug interactions, dosing, etc...")
+        msg = gr.Textbox(label="Your clinical question")
         clear_btn = gr.Button("Clear Chat")

-        # Examples
         gr.Examples(
             examples=[
                 "How to adjust Journavx for renal impairment?",
                 "Xolremdi and Prozac interaction in WHIM syndrome?",
                 "Alternative to Warfarin for patient with amiodarone?"
             ],
-            inputs=msg,
-            label="Example Questions"
+            inputs=msg
         )

-        # Event handlers
         init_btn.click(
             fn=app.initialize,
             outputs=init_status
@@ -277,14 +170,11 @@ def create_interface():

 if __name__ == "__main__":
     try:
-        logger.info("Starting application setup...")
+        logger.info("Starting application...")

-        # Prepare files
+        # Prepare local tool files
         prepare_tool_files()

-        # Download models with progress tracking
-        download_model_files()
-
         # Launch interface
         interface = create_interface()
         interface.launch(
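
The new version imports hf_hub_download from huggingface_hub, while load_embeddings() still only checks the local filesystem for the precomputed embedding file named in CONFIG["embedding_filename"]. Below is a minimal sketch of how that file could be fetched lazily from the Hub, assuming it is hosted in the ToolRAG model repository; the repo id and filename are copied from CONFIG, and the helper name fetch_tool_embeddings is hypothetical:

import os
from huggingface_hub import hf_hub_download

# Values copied from CONFIG in app.py
RAG_REPO = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
EMBEDDING_FILE = "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt"

def fetch_tool_embeddings() -> str:
    """Return a local path to the precomputed tool-embedding file.

    Assumption: the .pt file is hosted in the ToolRAG repo on the Hub; if it
    is not, the caller can fall back to regenerating the embeddings, which is
    what load_embeddings() already does when the file is missing locally.
    """
    if os.path.exists(EMBEDDING_FILE):
        return EMBEDDING_FILE
    # hf_hub_download caches the file and returns the path to the cached copy
    return hf_hub_download(repo_id=RAG_REPO, filename=EMBEDDING_FILE)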
 