Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import sys
|
2 |
import os
|
3 |
import gradio as gr
|
@@ -29,8 +31,7 @@ os.makedirs("/data/vllm_cache", exist_ok=True)
|
|
29 |
# Lazy loading of heavy dependencies
|
30 |
def lazy_load_agent():
|
31 |
from txagent.txagent import TxAgent
|
32 |
-
|
33 |
-
# Initialize agent with optimized settings
|
34 |
agent = TxAgent(
|
35 |
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
36 |
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
@@ -71,11 +72,11 @@ def process_file(file_path: str, file_type: str) -> str:
|
|
71 |
try:
|
72 |
h = file_hash(file_path)
|
73 |
cache_path = f"/data/file_cache/{h}.json"
|
74 |
-
|
75 |
if os.path.exists(cache_path):
|
76 |
with open(cache_path, "r", encoding="utf-8") as f:
|
77 |
return f.read()
|
78 |
-
|
79 |
if file_type == "pdf":
|
80 |
content = extract_priority_pages(file_path)
|
81 |
result = json.dumps({"filename": os.path.basename(file_path), "content": content})
|
@@ -91,7 +92,7 @@ def process_file(file_path: str, file_type: str) -> str:
|
|
91 |
with open(cache_path, "w", encoding="utf-8") as f:
|
92 |
f.write(result)
|
93 |
return result
|
94 |
-
|
95 |
except Exception as e:
|
96 |
return json.dumps({"error": str(e)})
|
97 |
|
@@ -100,7 +101,7 @@ def format_response(response: str) -> str:
|
|
100 |
if "Based on the medical records provided" in response:
|
101 |
parts = response.split("Based on the medical records provided")
|
102 |
response = "Based on the medical records provided" + parts[-1]
|
103 |
-
|
104 |
replacements = {
|
105 |
"1. **Missed Diagnoses**:": "### π Missed Diagnoses",
|
106 |
"2. **Medication Conflicts**:": "\n### π Medication Conflicts",
|
@@ -108,30 +109,27 @@ def format_response(response: str) -> str:
|
|
108 |
"4. **Abnormal Results Needing Follow-up**:": "\n### β οΈ Abnormal Results Needing Follow-up",
|
109 |
"Overall, the patient's medical records": "\n### π Overall Assessment"
|
110 |
}
|
111 |
-
|
112 |
for old, new in replacements.items():
|
113 |
response = response.replace(old, new)
|
114 |
-
|
115 |
return response
|
116 |
|
117 |
def analyze_files(message: str, history: List, files: List):
|
118 |
try:
|
119 |
-
# Wait for agent to load if not ready
|
120 |
while agent is None:
|
121 |
time.sleep(0.1)
|
122 |
-
|
123 |
-
# Append user message to history in correct format
|
124 |
history.append([message, None])
|
125 |
yield history, None
|
126 |
-
|
127 |
-
# Process files in parallel
|
128 |
extracted_data = ""
|
129 |
if files:
|
130 |
with ThreadPoolExecutor(max_workers=4) as executor:
|
131 |
futures = [executor.submit(process_file, f.name, f.name.split(".")[-1].lower())
|
132 |
-
|
133 |
extracted_data = "\n".join(f.result() for f in as_completed(futures))
|
134 |
-
|
135 |
prompt = f"""Review these medical records:
|
136 |
{extracted_data[:10000]}
|
137 |
|
@@ -142,7 +140,7 @@ Identify:
|
|
142 |
4. Abnormal results needing follow-up
|
143 |
|
144 |
Analysis:"""
|
145 |
-
|
146 |
response = ""
|
147 |
for chunk in agent.run_gradio_chat(
|
148 |
message=prompt,
|
@@ -155,88 +153,45 @@ Analysis:"""
|
|
155 |
response += chunk
|
156 |
elif isinstance(chunk, list):
|
157 |
response += "".join(getattr(c, 'content', '') for c in chunk)
|
158 |
-
|
159 |
formatted = format_response(response)
|
160 |
if formatted.strip():
|
161 |
history[-1][1] = formatted
|
162 |
yield history, None
|
163 |
-
|
164 |
final_output = format_response(response) or "No clear oversights identified."
|
165 |
history[-1][1] = final_output
|
166 |
yield history, None
|
167 |
-
|
168 |
except Exception as e:
|
169 |
history[-1][1] = f"β Error: {str(e)}"
|
170 |
yield history, None
|
171 |
|
172 |
-
#
|
173 |
-
with gr.Blocks(title="Clinical Oversight Assistant"
|
174 |
-
.gradio-container {
|
175 |
-
max-width: 1200px !important;
|
176 |
-
margin: auto;
|
177 |
-
}
|
178 |
-
.container {
|
179 |
-
max-width: 1200px !important;
|
180 |
-
}
|
181 |
-
.chatbot {
|
182 |
-
min-height: 500px;
|
183 |
-
}
|
184 |
-
""") as demo:
|
185 |
gr.Markdown("""
|
186 |
-
<div style='text-align: center;
|
187 |
-
<h1
|
188 |
<p>Upload medical records to analyze for potential oversights in patient care</p>
|
189 |
</div>
|
190 |
""")
|
191 |
-
|
192 |
with gr.Row():
|
193 |
-
with gr.Column(scale=1
|
194 |
-
file_upload = gr.File(
|
195 |
-
|
196 |
-
file_types=[".pdf", ".csv", ".xls", ".xlsx"],
|
197 |
-
file_count="multiple",
|
198 |
-
height=100
|
199 |
-
)
|
200 |
-
query = gr.Textbox(
|
201 |
-
label="Your Query",
|
202 |
-
placeholder="Ask about potential oversights...",
|
203 |
-
lines=3
|
204 |
-
)
|
205 |
submit = gr.Button("Analyze", variant="primary")
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
with gr.Column(scale=2, min_width=600):
|
218 |
-
chatbot = gr.Chatbot(
|
219 |
-
label="Analysis Results",
|
220 |
-
height=600,
|
221 |
-
bubble_full_width=False,
|
222 |
-
show_copy_button=True
|
223 |
-
)
|
224 |
-
|
225 |
-
submit.click(
|
226 |
-
analyze_files,
|
227 |
-
inputs=[query, chatbot, file_upload],
|
228 |
-
outputs=[chatbot, gr.File(visible=False)]
|
229 |
-
)
|
230 |
-
|
231 |
-
query.submit(
|
232 |
-
analyze_files,
|
233 |
-
inputs=[query, chatbot, file_upload],
|
234 |
-
outputs=[chatbot, gr.File(visible=False)]
|
235 |
-
)
|
236 |
|
237 |
if __name__ == "__main__":
|
238 |
-
demo.queue(
|
239 |
-
server_name="0.0.0.0",
|
240 |
-
server_port=7860,
|
241 |
-
show_error=True
|
242 |
-
)
|
|
|
1 |
+
# Optimized app.py with lazy loading and preloading thread, fixed chatbot format and startup error handling
|
2 |
+
|
3 |
import sys
|
4 |
import os
|
5 |
import gradio as gr
|
|
|
31 |
# Lazy loading of heavy dependencies
|
32 |
def lazy_load_agent():
|
33 |
from txagent.txagent import TxAgent
|
34 |
+
|
|
|
35 |
agent = TxAgent(
|
36 |
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
37 |
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
|
|
72 |
try:
|
73 |
h = file_hash(file_path)
|
74 |
cache_path = f"/data/file_cache/{h}.json"
|
75 |
+
|
76 |
if os.path.exists(cache_path):
|
77 |
with open(cache_path, "r", encoding="utf-8") as f:
|
78 |
return f.read()
|
79 |
+
|
80 |
if file_type == "pdf":
|
81 |
content = extract_priority_pages(file_path)
|
82 |
result = json.dumps({"filename": os.path.basename(file_path), "content": content})
|
|
|
92 |
with open(cache_path, "w", encoding="utf-8") as f:
|
93 |
f.write(result)
|
94 |
return result
|
95 |
+
|
96 |
except Exception as e:
|
97 |
return json.dumps({"error": str(e)})
|
98 |
|
|
|
101 |
if "Based on the medical records provided" in response:
|
102 |
parts = response.split("Based on the medical records provided")
|
103 |
response = "Based on the medical records provided" + parts[-1]
|
104 |
+
|
105 |
replacements = {
|
106 |
"1. **Missed Diagnoses**:": "### π Missed Diagnoses",
|
107 |
"2. **Medication Conflicts**:": "\n### π Medication Conflicts",
|
|
|
109 |
"4. **Abnormal Results Needing Follow-up**:": "\n### β οΈ Abnormal Results Needing Follow-up",
|
110 |
"Overall, the patient's medical records": "\n### π Overall Assessment"
|
111 |
}
|
112 |
+
|
113 |
for old, new in replacements.items():
|
114 |
response = response.replace(old, new)
|
115 |
+
|
116 |
return response
|
117 |
|
118 |
def analyze_files(message: str, history: List, files: List):
|
119 |
try:
|
|
|
120 |
while agent is None:
|
121 |
time.sleep(0.1)
|
122 |
+
|
|
|
123 |
history.append([message, None])
|
124 |
yield history, None
|
125 |
+
|
|
|
126 |
extracted_data = ""
|
127 |
if files:
|
128 |
with ThreadPoolExecutor(max_workers=4) as executor:
|
129 |
futures = [executor.submit(process_file, f.name, f.name.split(".")[-1].lower())
|
130 |
+
for f in files if hasattr(f, 'name')]
|
131 |
extracted_data = "\n".join(f.result() for f in as_completed(futures))
|
132 |
+
|
133 |
prompt = f"""Review these medical records:
|
134 |
{extracted_data[:10000]}
|
135 |
|
|
|
140 |
4. Abnormal results needing follow-up
|
141 |
|
142 |
Analysis:"""
|
143 |
+
|
144 |
response = ""
|
145 |
for chunk in agent.run_gradio_chat(
|
146 |
message=prompt,
|
|
|
153 |
response += chunk
|
154 |
elif isinstance(chunk, list):
|
155 |
response += "".join(getattr(c, 'content', '') for c in chunk)
|
156 |
+
|
157 |
formatted = format_response(response)
|
158 |
if formatted.strip():
|
159 |
history[-1][1] = formatted
|
160 |
yield history, None
|
161 |
+
|
162 |
final_output = format_response(response) or "No clear oversights identified."
|
163 |
history[-1][1] = final_output
|
164 |
yield history, None
|
165 |
+
|
166 |
except Exception as e:
|
167 |
history[-1][1] = f"β Error: {str(e)}"
|
168 |
yield history, None
|
169 |
|
170 |
+
# UI definition
|
171 |
+
with gr.Blocks(title="Clinical Oversight Assistant") as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
gr.Markdown("""
|
173 |
+
<div style='text-align: center;'>
|
174 |
+
<h1>π©Ί Clinical Oversight Assistant</h1>
|
175 |
<p>Upload medical records to analyze for potential oversights in patient care</p>
|
176 |
</div>
|
177 |
""")
|
178 |
+
|
179 |
with gr.Row():
|
180 |
+
with gr.Column(scale=1):
|
181 |
+
file_upload = gr.File(label="Upload Medical Records", file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
|
182 |
+
query = gr.Textbox(label="Your Query", placeholder="Ask about potential oversights...", lines=3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
submit = gr.Button("Analyze", variant="primary")
|
184 |
+
gr.Examples([
|
185 |
+
["What potential diagnoses might have been missed?"],
|
186 |
+
["Are there any medication conflicts I should be aware of?"],
|
187 |
+
["What assessments appear incomplete in these records?"]
|
188 |
+
], inputs=query)
|
189 |
+
|
190 |
+
with gr.Column(scale=2):
|
191 |
+
chatbot = gr.Chatbot(label="Analysis Results", height=600, type="messages")
|
192 |
+
|
193 |
+
submit.click(analyze_files, inputs=[query, chatbot, file_upload], outputs=[chatbot, gr.File(visible=False)])
|
194 |
+
query.submit(analyze_files, inputs=[query, chatbot, file_upload], outputs=[chatbot, gr.File(visible=False)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
if __name__ == "__main__":
|
197 |
+
demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
|
|
|
|
|
|
|