Ali2206 committed
Commit a6968c2 · verified · 1 Parent(s): 25e2c05

Update app.py

Files changed (1)
  1. app.py +279 -13
app.py CHANGED
@@ -1,18 +1,284 @@
- import os
  import sys
  import gradio as gr
- from multiprocessing import freeze_support
- from ui.ui_core import create_ui
- from backend.agent_instance import init_agent

  if __name__ == "__main__":
-     freeze_support()
-     agent = init_agent()
-     demo = create_ui(agent)
-     demo.queue().launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         show_error=True,
-         share=True
-     )

  import sys
+ import os
+ import pandas as pd
+ import pdfplumber
+ import json
  import gradio as gr
+ from typing import List
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ import hashlib
+ import shutil
+ import re
+ import psutil
+ import subprocess
+ import traceback
+ import torch
+
+ os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
+ if not torch.cuda.is_available():
+     print("No GPU detected. Forcing CPU mode by setting CUDA_VISIBLE_DEVICES to an empty string.")
+     os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+ persistent_dir = "/data/hf_cache"
+ os.makedirs(persistent_dir, exist_ok=True)
+ model_cache_dir = os.path.join(persistent_dir, "txagent_models")
+ tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
+ file_cache_dir = os.path.join(persistent_dir, "cache")
+ report_dir = os.path.join(persistent_dir, "reports")
+ vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
+
+ for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
+     os.makedirs(directory, exist_ok=True)
+
+ os.environ["HF_HOME"] = model_cache_dir
+ os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
+ os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ src_path = os.path.abspath(os.path.join(current_dir, "src"))
+ sys.path.insert(0, src_path)
+
+ from txagent.txagent import TxAgent
+
+ MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
+                     'allergies', 'summary', 'impression', 'findings', 'recommendations'}
+
+ def sanitize_utf8(text: str) -> str:
+     return text.encode("utf-8", "ignore").decode("utf-8")
+
+ def file_hash(path: str) -> str:
+     with open(path, "rb") as f:
+         return hashlib.md5(f.read()).hexdigest()
+
+ def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
+     try:
+         text_chunks = []
+         with pdfplumber.open(file_path) as pdf:
+             for i, page in enumerate(pdf.pages[:3]):
+                 text = page.extract_text() or ""
+                 text_chunks.append(f"=== Page {i+1} ===\n{text.strip()}")
+             for i, page in enumerate(pdf.pages[3:max_pages], start=4):
+                 page_text = page.extract_text() or ""
+                 if any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
+                     text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
+         return "\n\n".join(text_chunks)
+     except Exception as e:
+         debug_msg = f"PDF processing error: {str(e)}"
+         print(debug_msg)
+         traceback.print_exc()
+         return debug_msg
+
+ def convert_file_to_json(file_path: str, file_type: str) -> str:
+     try:
+         h = file_hash(file_path)
+         cache_path = os.path.join(file_cache_dir, f"{h}.json")
+         if os.path.exists(cache_path):
+             with open(cache_path, "r", encoding="utf-8") as f:
+                 return f.read()
+
+         if file_type == "pdf":
+             text = extract_priority_pages(file_path)
+             result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
+         elif file_type == "csv":
+             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
+                              skip_blank_lines=False, on_bad_lines="skip")
+             content = df.fillna("").astype(str).values.tolist()
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+         elif file_type in ["xls", "xlsx"]:
+             try:
+                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
+             except Exception:
+                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
+             content = df.fillna("").astype(str).values.tolist()
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+         else:
+             result = json.dumps({"error": f"Unsupported file type: {file_type}"})
+         with open(cache_path, "w", encoding="utf-8") as f:
+             f.write(result)
+         return result
+     except Exception as e:
+         error_msg = f"Error processing {os.path.basename(file_path)}: {str(e)}"
+         print(error_msg)
+         traceback.print_exc()
+         return json.dumps({"error": error_msg})
+
+ def log_system_usage(tag=""):
+     try:
+         cpu = psutil.cpu_percent(interval=1)
+         mem = psutil.virtual_memory()
+         print(f"[{tag}] CPU: {cpu}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
+         result = subprocess.run(
+             ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
+             capture_output=True, text=True
+         )
+         if result.returncode == 0:
+             used, total, util = result.stdout.strip().split(", ")
+             print(f"[{tag}] GPU: {used}MB / {total}MB | Utilization: {util}%")
+     except Exception as e:
+         print(f"[{tag}] GPU/CPU monitor failed: {e}")
+         traceback.print_exc()
+
+ def init_agent():
+     try:
+         print("\U0001F501 Initializing model...")
+         log_system_usage("Before Load")
+         default_tool_path = os.path.abspath("data/new_tool.json")
+         target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+         if not os.path.exists(target_tool_path):
+             shutil.copy(default_tool_path, target_tool_path)
+
+         agent = TxAgent(
+             model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+             rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+             tool_files_dict={"new_tool": target_tool_path},
+             force_finish=True,
+             enable_checker=True,
+             step_rag_num=8,
+             seed=100,
+             additional_default_tools=[],
+         )
+         agent.init_model()
+         log_system_usage("After Load")
+         print("✅ Agent Ready")
+         return agent
+     except Exception as e:
+         print("❌ Error initializing agent:", str(e))
+         traceback.print_exc()
+         raise e
+
+ def create_ui(agent):
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
+         chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
+         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+         msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
+         send_btn = gr.Button("Analyze", variant="primary")
+         download_output = gr.File(label="Download Full Report")
+
+         def analyze(message: str, history: list, files: list):
+             try:
+                 history.append({"role": "user", "content": message})
+                 history.append({"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."})
+                 yield history, None
+
+                 extracted = ""
+                 file_hash_value = ""
+                 if files:
+                     with ThreadPoolExecutor(max_workers=4) as executor:
+                         futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
+                         results = []
+                         for future in as_completed(futures):
+                             try:
+                                 results.append(sanitize_utf8(future.result()))
+                             except Exception as e:
+                                 print("❌ Error in file processing:", str(e))
+                                 traceback.print_exc()
+                     extracted = "\n".join(results)
+                     file_hash_value = file_hash(files[0].name)
+
+                 max_content_length = 8000
+                 prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
+ 1. List potential missed diagnoses
+ 2. Flag any medication conflicts
+ 3. Note incomplete assessments
+ 4. Highlight abnormal results needing follow-up
+ Medical Records:
+ {extracted[:max_content_length]}
+ ### Potential Oversights:
+ """
+
+                 full_response = ""
+                 response_chunks = []
+
+                 for chunk in agent.run_gradio_chat(
+                     message=prompt,
+                     history=[],
+                     temperature=0.2,
+                     max_new_tokens=2048,
+                     max_token=4096,
+                     call_agent=False,
+                     conversation=[]
+                 ):
+                     try:
+                         chunk_content = ""
+                         if isinstance(chunk, str):
+                             chunk_content = chunk
+                         elif hasattr(chunk, 'content'):
+                             chunk_content = chunk.content
+                         elif isinstance(chunk, list):
+                             chunk_content = "".join([c.content for c in chunk if hasattr(c, "content") and c.content])
+
+                         if not chunk_content:
+                             continue
+
+                         response_chunks.append(chunk_content)
+                         full_response = "".join(response_chunks)
+
+                         display_response = re.split(r"\[TOOL_CALLS\].*?$", full_response, flags=re.DOTALL)[0].strip()
+                         display_response = display_response.replace('[TxAgent]', '').strip()
+
+                         if len(history) > 1 and history[-2]["role"] == "assistant" and history[-2]["content"] == display_response:
+                             pass
+                         else:
+                             if len(history) > 0 and history[-1]["role"] == "assistant":
+                                 history[-1]["content"] = display_response
+                             else:
+                                 history.append({"role": "assistant", "content": display_response})
+
+                         yield history, None
+                     except Exception as e:
+                         print("❌ Error processing chunk:", str(e))
+                         traceback.print_exc()
+                         continue
+
+                 if not full_response:
+                     full_response = "⚠️ No clear oversights identified or model output was invalid."
+                 else:
+                     full_response = re.split(r"\[TOOL_CALLS\].*?$", full_response, flags=re.DOTALL)[0].strip()
+                     full_response = full_response.replace('[TxAgent]', '').strip()
+
+                 report_path = None
+                 if file_hash_value:
+                     report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
+                     try:
+                         with open(report_path, "w", encoding="utf-8") as f:
+                             f.write(full_response)
+                     except Exception as e:
+                         print("❌ Error saving report:", str(e))
+                         traceback.print_exc()
+
+                 if len(history) > 0 and history[-1]["role"] == "assistant":
+                     history[-1]["content"] = full_response
+                 else:
+                     history.append({"role": "assistant", "content": full_response})
+
+                 yield history, report_path if report_path and os.path.exists(report_path) else None
+
+             except Exception as e:
+                 error_message = f"❌ An error occurred in analyze: {str(e)}"
+                 print(error_message)
+                 traceback.print_exc()
+                 history.append({"role": "assistant", "content": error_message})
+                 yield history, None
 
+         send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
+         msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
+     return demo

  if __name__ == "__main__":
+     try:
+         print("🚀 Launching app...")
+         agent = init_agent()
+         demo = create_ui(agent)
+         demo.queue(api_open=False).launch(
+             server_name="0.0.0.0",
+             server_port=7860,
+             show_error=True,
+             allowed_paths=[report_dir],
+             share=False
+         )
+     except Exception as e:
+         print("❌ Fatal error during launch:", str(e))
+         traceback.print_exc()
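
Note for reviewers: the new ingestion helpers (file_hash, extract_priority_pages, convert_file_to_json) are defined at module level, and model initialization only runs under the __main__ guard, so the file-parsing path can be smoke-tested without loading TxAgent. A minimal sketch, not part of this commit, assuming the Space's dependencies (torch, pdfplumber, pandas, txagent under src/) are installed, /data/hf_cache is writable, and the sample file name is a placeholder:

# Quick local check of the ingestion path only (not part of the commit).
# Assumptions: app.py's imports resolve, /data/hf_cache is writable on this
# machine, and "sample_records.pdf" stands in for a real local file.
import json

import app  # importing app creates the cache dirs but does not call init_agent()

sample = "sample_records.pdf"  # hypothetical input file
raw = app.convert_file_to_json(sample, sample.rsplit(".", 1)[-1].lower())
payload = json.loads(app.sanitize_utf8(raw))
print(payload.get("filename"), payload.get("status"), len(payload.get("content", "")))

Because convert_file_to_json caches its JSON output under /data/hf_cache/cache keyed by the file's MD5 hash, running the sketch twice should hit the cache on the second run.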