Ali2206 commited on
Commit
35da672
·
verified ·
1 Parent(s): 4b5755b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -51
app.py CHANGED
@@ -6,6 +6,8 @@ from txagent import TxAgent
6
  import gradio as gr
7
  from huggingface_hub import hf_hub_download, snapshot_download
8
  from tooluniverse import ToolUniverse
 
 
9
 
10
  # Configuration
11
  CONFIG = {
@@ -15,7 +17,9 @@ CONFIG = {
15
  "local_dir": "./models",
16
  "tool_files": {
17
  "new_tool": "./data/new_tool.json"
18
- }
 
 
19
  }
20
 
21
  # Logging setup
@@ -32,42 +36,71 @@ def prepare_tool_files():
32
  json.dump(tools, f, indent=2)
33
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def download_model_files():
36
  os.makedirs(CONFIG["local_dir"], exist_ok=True)
37
- print("Downloading model files...")
38
-
39
- snapshot_download(
40
- repo_id=CONFIG["model_name"],
41
- local_dir=os.path.join(CONFIG["local_dir"], CONFIG["model_name"]),
42
- resume_download=True
43
- )
44
-
45
- snapshot_download(
46
- repo_id=CONFIG["rag_model_name"],
47
- local_dir=os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"]),
48
- resume_download=True
49
- )
50
-
51
- # Skip hf_hub_download for embedding since it's already uploaded manually
52
- print("Embedding file assumed to be pre-uploaded, skipping download.")
53
-
54
- def generate_embeddings(agent):
55
- embedding_path = os.path.join(CONFIG["local_dir"], CONFIG["embedding_filename"])
56
-
 
 
57
  if os.path.exists(embedding_path):
58
- print("Embeddings file already exists")
59
- return
 
 
 
 
 
 
60
 
61
- print("Generating missing tool embeddings...")
62
  try:
63
  tools = agent.tooluniverse.get_all_tools()
64
  descriptions = [tool["description"] for tool in tools]
65
  embeddings = agent.rag_model.generate_embeddings(descriptions)
66
  torch.save(embeddings, embedding_path)
67
  agent.rag_model.tool_desc_embedding = embeddings
68
- print(f"Embeddings saved to {embedding_path}")
69
  except Exception as e:
70
- print(f"Failed to generate embeddings: {e}")
71
  raise
72
 
73
  class TxAgentApp:
@@ -80,6 +113,7 @@ class TxAgentApp:
80
  return "Already initialized"
81
 
82
  try:
 
83
  self.agent = TxAgent(
84
  CONFIG["model_name"],
85
  CONFIG["rag_model_name"],
@@ -90,16 +124,20 @@ class TxAgentApp:
90
  seed=100,
91
  additional_default_tools=["DirectResponse", "RequireClarification"]
92
  )
 
93
  self.agent.init_model()
94
- generate_embeddings(self.agent)
 
95
  self.is_initialized = True
 
96
  return "✅ TxAgent initialized successfully"
97
  except Exception as e:
 
98
  return f"❌ Initialization failed: {str(e)}"
99
 
100
  def chat(self, message, history):
101
  if not self.is_initialized:
102
- return history + [(message, "⚠️ Error: Model not initialized")]
103
 
104
  try:
105
  response = ""
@@ -117,38 +155,74 @@ class TxAgentApp:
117
 
118
  return history + [(message, response)]
119
  except Exception as e:
 
120
  return history + [(message, f"Error: {str(e)}")]
121
 
122
  def create_interface():
123
  app = TxAgentApp()
124
- with gr.Blocks(title="TxAgent") as demo:
125
- gr.Markdown("# 🧠 TxAgent: Therapeutic Reasoning AI")
 
 
 
126
 
127
  with gr.Row():
128
  init_btn = gr.Button("Initialize Model", variant="primary")
129
- init_status = gr.Textbox(label="Initialization Status")
130
-
131
- chatbot = gr.Chatbot(height=600, label="Conversation")
132
- msg = gr.Textbox(label="Your Question")
133
- submit_btn = gr.Button("Submit")
134
-
135
- gr.Examples(
136
- examples=[
137
- "How to adjust Journavx dosage for hepatic impairment?",
138
- "Is Xolremdi safe with Prozac for WHIM syndrome?",
139
- "Warfarin-Amiodarone contraindications?"
140
- ],
141
- inputs=msg
142
- )
143
 
144
- init_btn.click(fn=app.initialize, outputs=init_status)
145
- msg.submit(fn=app.chat, inputs=[msg, chatbot], outputs=chatbot)
146
- submit_btn.click(fn=app.chat, inputs=[msg, chatbot], outputs=chatbot)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  return demo
149
 
150
  if __name__ == "__main__":
151
- prepare_tool_files()
152
- download_model_files()
153
- interface = create_interface()
154
- interface.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import gradio as gr
7
  from huggingface_hub import hf_hub_download, snapshot_download
8
  from tooluniverse import ToolUniverse
9
+ from tqdm import tqdm
10
+ import time
11
 
12
  # Configuration
13
  CONFIG = {
 
17
  "local_dir": "./models",
18
  "tool_files": {
19
  "new_tool": "./data/new_tool.json"
20
+ },
21
+ "download_timeout": 300, # Increased timeout to 5 minutes
22
+ "max_retries": 3
23
  }
24
 
25
  # Logging setup
 
36
  json.dump(tools, f, indent=2)
37
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
38
 
39
+ def download_with_retry(repo_id, local_dir):
40
+ retry_count = 0
41
+ while retry_count < CONFIG["max_retries"]:
42
+ try:
43
+ snapshot_download(
44
+ repo_id=repo_id,
45
+ local_dir=local_dir,
46
+ resume_download=True,
47
+ local_dir_use_symlinks=False,
48
+ timeout=CONFIG["download_timeout"]
49
+ )
50
+ return True
51
+ except Exception as e:
52
+ retry_count += 1
53
+ logger.error(f"Attempt {retry_count} failed for {repo_id}: {str(e)}")
54
+ if retry_count < CONFIG["max_retries"]:
55
+ wait_time = 10 * retry_count
56
+ logger.info(f"Waiting {wait_time} seconds before retry...")
57
+ time.sleep(wait_time)
58
+ return False
59
+
60
  def download_model_files():
61
  os.makedirs(CONFIG["local_dir"], exist_ok=True)
62
+ logger.info("Downloading model files...")
63
+
64
+ # Download main model
65
+ logger.info(f"Downloading {CONFIG['model_name']}...")
66
+ if not download_with_retry(
67
+ CONFIG["model_name"],
68
+ os.path.join(CONFIG["local_dir"], CONFIG["model_name"])
69
+ ):
70
+ raise RuntimeError(f"Failed to download {CONFIG['model_name']} after {CONFIG['max_retries']} attempts")
71
+
72
+ # Download RAG model
73
+ logger.info(f"Downloading {CONFIG['rag_model_name']}...")
74
+ if not download_with_retry(
75
+ CONFIG["rag_model_name"],
76
+ os.path.join(CONFIG["local_dir"], CONFIG["rag_model_name"])
77
+ ):
78
+ raise RuntimeError(f"Failed to download {CONFIG['rag_model_name']} after {CONFIG['max_retries']} attempts")
79
+
80
+ logger.info("All model files downloaded successfully")
81
+
82
+ def load_embeddings(agent):
83
+ embedding_path = CONFIG["embedding_filename"]
84
  if os.path.exists(embedding_path):
85
+ logger.info(" Loading pre-generated embeddings file")
86
+ try:
87
+ embeddings = torch.load(embedding_path)
88
+ agent.rag_model.tool_desc_embedding = embeddings
89
+ return
90
+ except Exception as e:
91
+ logger.error(f"Failed to load embeddings: {e}")
92
+ # Fall through to generate new embeddings
93
 
94
+ logger.info("Generating tool embeddings...")
95
  try:
96
  tools = agent.tooluniverse.get_all_tools()
97
  descriptions = [tool["description"] for tool in tools]
98
  embeddings = agent.rag_model.generate_embeddings(descriptions)
99
  torch.save(embeddings, embedding_path)
100
  agent.rag_model.tool_desc_embedding = embeddings
101
+ logger.info(f"Embeddings saved to {embedding_path}")
102
  except Exception as e:
103
+ logger.error(f"Failed to generate embeddings: {e}")
104
  raise
105
 
106
  class TxAgentApp:
 
113
  return "Already initialized"
114
 
115
  try:
116
+ logger.info("Initializing TxAgent...")
117
  self.agent = TxAgent(
118
  CONFIG["model_name"],
119
  CONFIG["rag_model_name"],
 
124
  seed=100,
125
  additional_default_tools=["DirectResponse", "RequireClarification"]
126
  )
127
+ logger.info("Initializing models...")
128
  self.agent.init_model()
129
+ logger.info("Loading embeddings...")
130
+ load_embeddings(self.agent)
131
  self.is_initialized = True
132
+ logger.info("✅ TxAgent initialized successfully")
133
  return "✅ TxAgent initialized successfully"
134
  except Exception as e:
135
+ logger.error(f"Initialization failed: {str(e)}")
136
  return f"❌ Initialization failed: {str(e)}"
137
 
138
  def chat(self, message, history):
139
  if not self.is_initialized:
140
+ return history + [(message, "⚠️ Error: Model not initialized. Please click 'Initialize Model' first.")]
141
 
142
  try:
143
  response = ""
 
155
 
156
  return history + [(message, response)]
157
  except Exception as e:
158
+ logger.error(f"Chat error: {str(e)}")
159
  return history + [(message, f"Error: {str(e)}")]
160
 
161
  def create_interface():
162
  app = TxAgentApp()
163
+ with gr.Blocks(title="TxAgent", css=".gradio-container {max-width: 900px !important}") as demo:
164
+ gr.Markdown("""
165
+ # 🧠 TxAgent: Therapeutic Reasoning AI
166
+ ### A specialized AI for clinical decision support and therapeutic reasoning
167
+ """)
168
 
169
  with gr.Row():
170
  init_btn = gr.Button("Initialize Model", variant="primary")
171
+ init_status = gr.Textbox(label="Initialization Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
+ with gr.Row():
174
+ with gr.Column(scale=3):
175
+ chatbot = gr.Chatbot(height=600, label="Conversation", bubble_full_width=False)
176
+ msg = gr.Textbox(label="Your Question", placeholder="Enter your clinical question here...")
177
+ submit_btn = gr.Button("Submit", variant="primary")
178
+ with gr.Column(scale=1):
179
+ gr.Markdown("### Example Questions:")
180
+ gr.Examples(
181
+ examples=[
182
+ "How to adjust Journavx dosage for hepatic impairment?",
183
+ "Is Xolremdi safe with Prozac for WHIM syndrome?",
184
+ "Warfarin-Amiodarone contraindications?",
185
+ "Alternative treatments for EGFR-positive NSCLC?"
186
+ ],
187
+ inputs=msg,
188
+ label="Click to try"
189
+ )
190
+
191
+ init_btn.click(
192
+ fn=app.initialize,
193
+ outputs=init_status,
194
+ api_name="initialize"
195
+ )
196
+ msg.submit(
197
+ fn=app.chat,
198
+ inputs=[msg, chatbot],
199
+ outputs=chatbot,
200
+ api_name="chat"
201
+ )
202
+ submit_btn.click(
203
+ fn=app.chat,
204
+ inputs=[msg, chatbot],
205
+ outputs=chatbot
206
+ )
207
 
208
  return demo
209
 
210
  if __name__ == "__main__":
211
+ try:
212
+ logger.info("Preparing tool files...")
213
+ prepare_tool_files()
214
+
215
+ logger.info("Downloading model files (if needed)...")
216
+ download_model_files()
217
+
218
+ logger.info("Launching interface...")
219
+ interface = create_interface()
220
+ interface.launch(
221
+ server_name="0.0.0.0",
222
+ server_port=7860,
223
+ share=False,
224
+ show_error=True
225
+ )
226
+ except Exception as e:
227
+ logger.error(f"Application failed to start: {str(e)}")
228
+ raise