Update app.py
app.py CHANGED
@@ -1,12 +1,10 @@
-#
+# ✅ Fully optimized app.py for Hugging Face Space with persistent 150GB storage
+
+import sys, os, json, gradio as gr, pandas as pd, pdfplumber, hashlib, shutil, re, time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from threading import Thread
 
-#
-current_dir = os.path.dirname(os.path.abspath(__file__))
-src_path = os.path.join(current_dir, "src")
-sys.path.insert(0, src_path)
-
+# Use /data for persistent HF storage
 base_dir = "/data"
 model_cache_dir = os.path.join(base_dir, "txagent_models")
 tool_cache_dir = os.path.join(base_dir, "tool_cache")
@@ -17,7 +15,7 @@ vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
 for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
     os.makedirs(d, exist_ok=True)
 
-#
+# Set persistent HF + VLLM cache
 os.environ.update({
     "HF_HOME": model_cache_dir,
     "TRANSFORMERS_CACHE": model_cache_dir,
@@ -26,13 +24,21 @@ os.environ.update({
     "CUDA_LAUNCH_BLOCKING": "1"
 })
 
-
+# Force local loading only
+LOCAL_TXAGENT_PATH = os.path.join(model_cache_dir, "mims-harvard", "TxAgent-T1-Llama-3.1-8B")
+LOCAL_RAG_PATH = os.path.join(model_cache_dir, "mims-harvard", "ToolRAG-T1-GTE-Qwen2-1.5B")
 
-
-
+# Manual download using snapshot_download (only if needed)
+# from huggingface_hub import snapshot_download
+# snapshot_download("mims-harvard/TxAgent-T1-Llama-3.1-8B", local_dir=LOCAL_TXAGENT_PATH, local_dir_use_symlinks=False)
+# snapshot_download("mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B", local_dir=LOCAL_RAG_PATH, local_dir_use_symlinks=False)
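+# (Note: recent huggingface_hub releases deprecate and ignore local_dir_use_symlinks,
+#  so snapshot_download(repo_id, local_dir=...) alone is enough on those versions.)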
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
+from txagent.txagent import TxAgent
 
-def sanitize_utf8(text): return text.encode("utf-8", "ignore").decode("utf-8")
 def file_hash(path): return hashlib.md5(open(path, "rb").read()).hexdigest()
+def sanitize_utf8(text): return text.encode("utf-8", "ignore").decode("utf-8")
+MEDICAL_KEYWORDS = {"diagnosis", "assessment", "plan", "results", "medications", "summary", "findings"}
 
 def extract_priority_pages(file_path, max_pages=20):
     try:
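+        # (pages are collected in two passes: the opening pages verbatim, then
+        #  pages 4..max_pages only when they contain a MEDICAL_KEYWORDS term)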
@@ -42,7 +48,7 @@ def extract_priority_pages(file_path, max_pages=20):
             pages.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
         for i, page in enumerate(pdf.pages[3:max_pages], start=4):
             text = page.extract_text() or ""
-            if any(re.search(rf'
+            if any(re.search(rf'\b{kw}\b', text.lower()) for kw in MEDICAL_KEYWORDS):  # whole-word match; \b must be a single backslash in the raw f-string
                 pages.append(f"=== Page {i} ===\n{text.strip()}")
         return "\n\n".join(pages)
     except Exception as e:
@@ -59,43 +65,36 @@ def convert_file_to_json(file_path, file_type):
             result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
             Thread(target=full_pdf_processing, args=(file_path, h)).start()
         elif file_type == "csv":
-            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str
-            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna(
+            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str)
+            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna('').astype(str).values.tolist()})
         elif file_type in ["xls", "xlsx"]:
-            try:
-                df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
-            except:
-                df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
-            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").astype(str).values.tolist()})
+            # openpyxl only reads .xlsx; legacy .xls still needs xlrd
+            df = pd.read_excel(file_path, engine="openpyxl" if file_type == "xlsx" else "xlrd", header=None, dtype=str)
+            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna('').astype(str).values.tolist()})
         else:
             return json.dumps({"error": f"Unsupported file type: {file_type}"})
 
         with open(cache_path, "w", encoding="utf-8") as f: f.write(result)
         return result
     except Exception as e:
-        return json.dumps({"error":
+        return json.dumps({"error": str(e)})
 
-def full_pdf_processing(file_path, file_hash_value):
+def full_pdf_processing(file_path, h):
     try:
-        cache_path = os.path.join(file_cache_dir, f"{file_hash_value}_full.json")
+        cache_path = os.path.join(file_cache_dir, f"{h}_full.json")
         if os.path.exists(cache_path): return
         with pdfplumber.open(file_path) as pdf:
-            full_text = "\n".join([f"=== Page {i+1} ===\n{(
-
-
-        with open(os.path.join(report_dir, f"{file_hash_value}_report.txt"), "w", encoding="utf-8") as out: out.write(full_text)
-    except Exception as e:
-        print("PDF processing error:", e)
+            full_text = "\n".join([f"=== Page {i+1} ===\n{(p.extract_text() or '').strip()}" for i, p in enumerate(pdf.pages)])
+        with open(cache_path, "w", encoding="utf-8") as f: f.write(json.dumps({"content": full_text}))
+    except Exception: pass
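+# (Note: nothing shown in this diff reads the {h}_full.json cache back yet;
+#  a consumer would have to look it up by the same MD5 file hash.)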
 
 def init_agent():
-    default_tool_path = os.path.abspath("data/new_tool.json")
     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
     if not os.path.exists(target_tool_path):
-        shutil.copy(default_tool_path, target_tool_path)
+        shutil.copy(os.path.abspath("data/new_tool.json"), target_tool_path)
 
     agent = TxAgent(
-        model_name=
-        rag_model_name=
+        model_name=LOCAL_TXAGENT_PATH,
+        rag_model_name=LOCAL_RAG_PATH,
         tool_files_dict={"new_tool": target_tool_path},
         force_finish=True,
         enable_checker=True,
@@ -105,82 +104,59 @@ def init_agent():
     agent.init_model()
     return agent
 
 # Lazy load
 agent_container = {"agent": None}
 def get_agent():
     if agent_container["agent"] is None:
         agent_container["agent"] = init_agent()
     return agent_container["agent"]
 
-def create_ui(
+def create_ui():
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        gr.Markdown("<h1 style='text-align:
-
-
+        gr.Markdown("""<h1 style='text-align:center;'>🩺 Clinical Oversight Assistant</h1>""")
+        chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")  # dict-style history below assumes a messages-mode Chatbot (Gradio 4.44+)
+        msg_input = gr.Textbox(placeholder="Ask a question about the patient...")
         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
-        msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
         send_btn = gr.Button("Analyze", variant="primary")
         state = gr.State([])
         download_output = gr.File(label="Download Report")
 
         def analyze(message, history, conversation, files):
             try:
-
+                extracted, hval = "", ""
                 if files:
-                    with ThreadPoolExecutor(max_workers=
+                    with ThreadPoolExecutor(max_workers=3) as pool:
                         futures = [pool.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
-
-
-
-                prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
-1. List potential missed diagnoses
-2. Flag any medication conflicts
-3. Note incomplete assessments
-4. Highlight abnormal results needing follow-up
-
-
-
-
+                        extracted = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
+                    hval = file_hash(files[0].name)
+
+                prompt = f"""Review these medical records and identify exactly what might have been missed:
+1. Missed diagnoses
+2. Medication conflicts
+3. Incomplete assessments
+4. Abnormal results needing follow-up
+
+Medical Records:\n{extracted[:15000]}
+"""
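+                # (only the first 15,000 characters of the extracted record text
+                #  reach the model; longer uploads are silently truncated here)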
|
145 |
final_response = ""
|
146 |
-
for chunk in
|
147 |
-
|
148 |
-
|
149 |
-
temperature=0.2,
|
150 |
-
max_new_tokens=1024,
|
151 |
-
max_token=4096,
|
152 |
-
call_agent=False,
|
153 |
-
conversation=conversation
|
154 |
-
):
|
155 |
-
if isinstance(chunk, str):
|
156 |
-
final_response += chunk
|
157 |
-
elif isinstance(chunk, list):
|
158 |
-
final_response += "".join([c.content for c in chunk if hasattr(c, "content")])
|
159 |
-
|
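+                # (run_gradio_chat streams two chunk shapes: plain strings, and
+                #  lists of message objects exposing .content; both are folded in)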
                 cleaned = final_response.replace("[TOOL_CALLS]", "").strip()
-                if not cleaned:
-                    cleaned = "No oversights found. Consider further review."
-
                 updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": cleaned}]
-
-                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value and os.path.exists(os.path.join(report_dir, f"{file_hash_value}_report.txt")) else None
-                yield updated_history, report_path
+                return updated_history, None
             except Exception as e:
-
-                yield updated_history, None
-
+                return history + [{"role": "user", "content": message}, {"role": "assistant", "content": f"❌ Error: {str(e)}"}], None
 
         send_btn.click(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
         msg_input.submit(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
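+        # (click and Enter share one handler; the download slot always receives
+        #  None for now, since the per-hash report files are no longer written)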
 
     return demo
 
 if __name__ == "__main__":
-
-    ui = create_ui(get_agent)
+    ui = create_ui()
     ui.queue(api_open=False).launch(
         server_name="0.0.0.0",
         server_port=7860,
         show_error=True,
         allowed_paths=["/data/reports"],
         share=False
-)
+    )