Ali2206 committed
Commit 25e2c05 · verified · 1 Parent(s): 9baeed7

Update app.py

Files changed (1):
  app.py +13 -276
app.py CHANGED
@@ -1,281 +1,18 @@
- import sys
  import os
- import pandas as pd
- import pdfplumber
- import json
  import gradio as gr
- from typing import List
- from concurrent.futures import ThreadPoolExecutor, as_completed
- import hashlib
- import shutil
- import re
- import psutil
- import subprocess
- import traceback
- import torch
- import copy
- from gradio import ChatMessage
-
- os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
- if not torch.cuda.is_available():
-     print("No GPU detected. Forcing CPU mode by setting CUDA_VISIBLE_DEVICES to an empty string.")
-     os.environ["CUDA_VISIBLE_DEVICES"] = ""
-
- persistent_dir = "/data/hf_cache"
- os.makedirs(persistent_dir, exist_ok=True)
-
- model_cache_dir = os.path.join(persistent_dir, "txagent_models")
- tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
- file_cache_dir = os.path.join(persistent_dir, "cache")
- report_dir = os.path.join(persistent_dir, "reports")
- vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
-
- for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
-     os.makedirs(directory, exist_ok=True)
-
- os.environ["HF_HOME"] = model_cache_dir
- os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
- os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
- current_dir = os.path.dirname(os.path.abspath(__file__))
- src_path = os.path.abspath(os.path.join(current_dir, "src"))
- sys.path.insert(0, src_path)
-
- from txagent.txagent import TxAgent
-
- MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
-                     'allergies', 'summary', 'impression', 'findings', 'recommendations'}
-
- def sanitize_utf8(text: str) -> str:
-     return text.encode("utf-8", "ignore").decode("utf-8")
-
- def file_hash(path: str) -> str:
-     with open(path, "rb") as f:
-         return hashlib.md5(f.read()).hexdigest()
-
- def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
-     try:
-         text_chunks = []
-         with pdfplumber.open(file_path) as pdf:
-             for i, page in enumerate(pdf.pages[:3]):
-                 text = page.extract_text() or ""
-                 text_chunks.append(f"=== Page {i+1} ===\n{text.strip()}")
-             for i, page in enumerate(pdf.pages[3:max_pages], start=4):
-                 page_text = page.extract_text() or ""
-                 if any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
-                     text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
-         return "\n\n".join(text_chunks)
-     except Exception as e:
-         print("PDF processing error:", str(e))
-         traceback.print_exc()
-         return str(e)
-
- def convert_file_to_json(file_path: str, file_type: str) -> str:
-     try:
-         h = file_hash(file_path)
-         cache_path = os.path.join(file_cache_dir, f"{h}.json")
-         if os.path.exists(cache_path):
-             with open(cache_path, "r", encoding="utf-8") as f:
-                 return f.read()
-
-         if file_type == "pdf":
-             text = extract_priority_pages(file_path)
-             result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
-         elif file_type == "csv":
-             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
-             content = df.fillna("").astype(str).values.tolist()
-             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
-         elif file_type in ["xls", "xlsx"]:
-             try:
-                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
-             except Exception:
-                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
-             content = df.fillna("").astype(str).values.tolist()
-             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
-         else:
-             result = json.dumps({"error": f"Unsupported file type: {file_type}"})
-
-         with open(cache_path, "w", encoding="utf-8") as f:
-             f.write(result)
-         return result
-     except Exception as e:
-         print("Error processing", file_path, str(e))
-         traceback.print_exc()
-         return json.dumps({"error": str(e)})
-
- def log_system_usage(tag=""):
-     try:
-         cpu = psutil.cpu_percent(interval=1)
-         mem = psutil.virtual_memory()
-         print(f"[{tag}] CPU: {cpu}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
-         result = subprocess.run(
-             ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
-             capture_output=True, text=True
-         )
-         if result.returncode == 0:
-             used, total, util = result.stdout.strip().split(", ")
-             print(f"[{tag}] GPU: {used}MB / {total}MB | Utilization: {util}%")
-     except Exception as e:
-         print(f"[{tag}] GPU/CPU monitor failed: {e}")
-         traceback.print_exc()
-
- def init_agent():
-     try:
-         print("🔁 Initializing model...")
-         log_system_usage("Before Load")
-         default_tool_path = os.path.abspath("data/new_tool.json")
-         target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-         if not os.path.exists(target_tool_path):
-             shutil.copy(default_tool_path, target_tool_path)
-
-         agent = TxAgent(
-             model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
-             rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-             tool_files_dict={"new_tool": target_tool_path},
-             enable_finish=True,
-             enable_rag=True,
-             enable_summary=False,
-             init_rag_num=0,
-             step_rag_num=8,
-             summary_mode='step',
-             summary_skip_last_k=0,
-             summary_context_length=None,
-             force_finish=True,
-             avoid_repeat=True,
-             seed=100,
-             enable_checker=True,
-             enable_chat=False,
-             additional_default_tools=[],
-         )
-         agent.init_model()
-         log_system_usage("After Load")
-         print("✅ Agent Ready")
-         return agent
-     except Exception as e:
-         print("❌ Error initializing agent:", str(e))
-         traceback.print_exc()
-         raise e
-
- def create_ui(agent):
-     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
-         chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
-         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
-         msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
-         send_btn = gr.Button("Analyze", variant="primary")
-         download_output = gr.File(label="Download Full Report")
-
-         def analyze(message: str, history: List[List[str]], files: list):
-             try:
-                 # Initialize with loading message
-                 history.append([message, None])
-                 yield history, None
-
-                 extracted = ""
-                 file_hash_value = ""
-                 if files:
-                     with ThreadPoolExecutor(max_workers=4) as executor:
-                         futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
-                         results = []
-                         for future in as_completed(futures):
-                             try:
-                                 res = future.result()
-                                 results.append(sanitize_utf8(res))
-                             except Exception as e:
-                                 print("❌ Error in file processing:", str(e))
-                                 traceback.print_exc()
-                     extracted = "\n".join(results)
-                     file_hash_value = file_hash(files[0].name)
-
-                 prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
- 1. List potential missed diagnoses
- 2. Flag any medication conflicts
- 3. Note incomplete assessments
- 4. Highlight abnormal results needing follow-up
-
- Medical Records:
- {extracted[:8000]}
-
- ### Potential Oversights:
- """
-                 print("🔎 Generated prompt:")
-                 print(prompt)
-
-                 # Remove loading message before streaming actual response
-                 history[-1][1] = "⏳ Analyzing records..."
-                 yield history, None
-
-                 # Initialize conversation state
-                 conversation = []
-                 full_response = ""
-
-                 # Stream responses from the agent
-                 for chunk in agent.run_gradio_chat(
-                     message=prompt,
-                     history=[],
-                     temperature=0.2,
-                     max_new_tokens=2048,
-                     max_token=4096,
-                     call_agent=False,
-                     conversation=conversation
-                 ):
-                     if isinstance(chunk, str):
-                         # Update the last message in history with the new chunk
-                         full_response = chunk
-                         history[-1][1] = full_response
-                         yield history, None
-                     elif isinstance(chunk, list):
-                         # Handle tool calls or other structured responses
-                         for item in chunk:
-                             if isinstance(item, ChatMessage):
-                                 # Add tool call messages to history
-                                 if item.role == "assistant":
-                                     history.append([None, item.content])
-                                 else:
-                                     history.append([None, f"⚒️ {item.content}"])
-                                 yield history, None
-
-                 # Final cleanup and report generation
-                 full_response = full_response.replace('[TxAgent]', '').strip()
-                 full_response = re.sub(r"\[TOOL_CALLS\].*?\n*", "", full_response, flags=re.DOTALL).strip()
-
-                 # Update the final response
-                 history[-1][1] = full_response
-
-                 # Generate report file if we have files
-                 report_path = None
-                 if file_hash_value:
-                     report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
-                     with open(report_path, "w", encoding="utf-8") as f:
-                         f.write(full_response)
-
-                 yield history, report_path if report_path and os.path.exists(report_path) else None
-
-             except Exception as e:
-                 error_msg = f"❌ An error occurred: {str(e)}"
-                 print(error_msg)
-                 traceback.print_exc()
-                 history[-1][1] = error_msg
-                 yield history, None

-         send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
-         msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
-         return demo

  if __name__ == "__main__":
-     try:
-         print("🚀 Launching app...")
-         agent = init_agent()
-         demo = create_ui(agent)
-         demo.queue(api_open=False).launch(
-             server_name="0.0.0.0",
-             server_port=7860,
-             show_error=True,
-             allowed_paths=[report_dir],
-             share=False
-         )
-     except Exception as e:
-         print("❌ Fatal error during launch:", str(e))
-         traceback.print_exc()
 
 
  import os
+ import sys
  import gradio as gr
+ from multiprocessing import freeze_support

+ from ui.ui_core import create_ui
+ from backend.agent_instance import init_agent

  if __name__ == "__main__":
+     freeze_support()
+     agent = init_agent()
+     demo = create_ui(agent)
+     demo.queue().launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_error=True,
+         share=True
+     )
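
The slimmed-down app.py relies on two companion modules that are not part of this diff: ui.ui_core (providing create_ui) and backend.agent_instance (providing init_agent). Below is a minimal sketch of what backend/agent_instance.py might contain, assuming it simply relocates the removed init_agent() logic; the module path comes from the new imports above, but the body is an assumption, not code confirmed by this commit.

# backend/agent_instance.py -- hypothetical sketch, assuming the removed
# init_agent() from the old app.py was moved here largely unchanged.
import os
import shutil

from txagent.txagent import TxAgent  # same TxAgent import the old app.py used

# Cache layout mirroring the directories the old app.py created.
persistent_dir = "/data/hf_cache"
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
os.makedirs(tool_cache_dir, exist_ok=True)

def init_agent() -> TxAgent:
    """Copy the bundled tool file into the cache, build the agent, load weights."""
    default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)

    # Constructor arguments taken from the removed code (abridged here).
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": target_tool_path},
        enable_finish=True,
        enable_rag=True,
        force_finish=True,
        avoid_repeat=True,
        seed=100,
    )
    agent.init_model()  # loads the model, as in the removed init_agent()
    return agent

One note on the new entry point: freeze_support() only has an effect inside a frozen Windows executable (for example, one built with PyInstaller); on a Linux server such as a Hugging Face Space it is a harmless no-op, so calling it first under the main guard is safe.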