Update app.py
app.py
CHANGED
@@ -10,9 +10,10 @@ import re
 import psutil
 import subprocess
 from collections import defaultdict
+import torch
 
-# Persistent directory
-persistent_dir = "/data/hf_cache"
+# Persistent directory for Hugging Face Space
+persistent_dir = os.getenv("HF_HOME", "/data/hf_cache")
 os.makedirs(persistent_dir, exist_ok=True)
 
 model_cache_dir = os.path.join(persistent_dir, "txagent_models")
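The cache root now honors `HF_HOME` and only falls back to the Space's persistent volume when it is unset. A minimal sketch of the resolution order (the `/tmp/hf` override is a hypothetical value for illustration):

```python
import os

# HF_HOME wins when set; otherwise fall back to the persistent volume.
os.environ.pop("HF_HOME", None)
assert os.getenv("HF_HOME", "/data/hf_cache") == "/data/hf_cache"

os.environ["HF_HOME"] = "/tmp/hf"  # hypothetical override
assert os.getenv("HF_HOME", "/data/hf_cache") == "/tmp/hf"
```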
@@ -89,47 +90,55 @@ def log_system_usage(tag=""):
 
 def clean_response(text: str) -> str:
     text = sanitize_utf8(text)
-    # …
-    text = re.sub(r"\[TOOL_CALLS\].*|(?:get_|tool\s|retrieve\s).*?\n", "", text, flags=re.DOTALL | re.IGNORECASE)
+    # Exhaustively remove all unwanted text
+    text = re.sub(r"\[TOOL_CALLS\].*|(?:get_|tool\s|retrieve\s|use\s).*?\n", "", text, flags=re.DOTALL | re.IGNORECASE)
     text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
-    text = re.sub(…)
+    text = re.sub(
+        r"(?i)(to address|analyze|will\s|since\s|no\s|none|previous|attempt|involve|check\s|explore|manually|"
+        r"start|look|use|focus|retrieve|tool|based\s|overall|indicate|mention|consider|ensure|need\s|"
+        r"provide|review|assess|identify|potential|records|patient|history|symptoms|medication|"
+        r"conflict|assessment|follow-up|issue|reasoning|step).*?\n",
+        "", text, flags=re.DOTALL
+    )
     text = re.sub(r"\n{3,}", "\n\n", text).strip()
-    # Only keep …
-    …
+    # Only keep lines under headings or bullet points
+    lines = []
+    valid_heading = False
+    for line in text.split("\n"):
+        line = line.strip()
+        if line.lower() in ["missed diagnoses:", "medication conflicts:", "incomplete assessments:", "urgent follow-up:"]:
+            valid_heading = True
+            lines.append(f"**{line[:-1]}**:")
+        elif valid_heading and line.startswith("-"):
+            lines.append(line)
+        else:
+            valid_heading = False
+    return "\n".join(lines).strip()
 
 def consolidate_findings(responses: List[str]) -> str:
-    # …
+    # Merge findings, keeping only unique points
     findings = defaultdict(set)
     headings = ["Missed Diagnoses", "Medication Conflicts", "Incomplete Assessments", "Urgent Follow-up"]
 
     for response in responses:
         if not response:
             continue
-        # Split response into sections by heading
         current_heading = None
-        current_points = []
         for line in response.split("\n"):
             line = line.strip()
             if not line:
                 continue
-            if any(line.lower().startswith(h.lower()) for h in headings):
-                if current_heading and current_points:
-                    findings[current_heading].update(current_points)
-                current_heading = next(h for h in headings if line.lower().startswith(h.lower()))
-                current_points = []
+            if line.lower().startswith(tuple(h.lower() + ":" for h in headings)):
+                current_heading = next(h for h in headings if line.lower().startswith(h.lower() + ":"))
             elif current_heading and line.startswith("-"):
-                current_points.append(line)
-        if current_heading and current_points:
-            findings[current_heading].update(current_points)
+                findings[current_heading].add(line)
 
-    # Format …
+    # Format final output
     output = []
     for heading in headings:
         if findings[heading]:
             output.append(f"**{heading}**:")
-            output.extend(sorted(findings[heading]))
+            output.extend(sorted(findings[heading], key=lambda x: x.lower()))
     return "\n".join(output).strip() if output else "No oversights identified."
 
 def init_agent():
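The rewritten `clean_response` ends with a whitelist pass: only the four known headings and the bullet lines directly beneath them survive, so any model reasoning that slips past the regexes is still dropped. A standalone sketch of that pass on a fabricated model output:

```python
# Standalone copy of the keep-only-headings pass; the input is fabricated.
raw = """To address this, I will check the records.
Missed Diagnoses:
- Possible hypothyroidism not worked up
Some stray reasoning line.
Medication Conflicts:
- Warfarin + ibuprofen interaction
"""

lines, valid_heading = [], False
for line in raw.split("\n"):
    line = line.strip()
    if line.lower() in ["missed diagnoses:", "medication conflicts:",
                        "incomplete assessments:", "urgent follow-up:"]:
        valid_heading = True
        lines.append(f"**{line[:-1]}**:")  # strip the colon, re-add it bolded
    elif valid_heading and line.startswith("-"):
        lines.append(line)
    else:
        valid_heading = False              # any other line closes the section

print("\n".join(lines))
# **Missed Diagnoses**:
# - Possible hypothyroidism not worked up
# **Medication Conflicts**:
# - Warfarin + ibuprofen interaction
```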
@@ -143,7 +152,8 @@ def init_agent():
         step_rag_num=1,
         seed=100,
     )
-    agent.init_model()
+    # Enable FP16 for A100
+    agent.init_model(dtype=torch.float16)
     log_system_usage("After Load")
     print("✅ Agent Ready")
     return agent
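`dtype=torch.float16` halves weight memory versus FP32 (2 bytes per parameter instead of 4) and is natively supported by the A100's tensor cores. A quick sanity check of the element sizes:

```python
import torch

# FP16 stores each value in 2 bytes, FP32 in 4.
assert torch.tensor([0.0], dtype=torch.float16).element_size() == 2
assert torch.tensor([0.0], dtype=torch.float32).element_size() == 4
# e.g. a 7B-parameter model: roughly 14 GB of weights in FP16 vs 28 GB in FP32.
```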
@@ -171,13 +181,14 @@ def create_ui(agent):
         extracted = "\n".join(results)
         file_hash_value = file_hash(files[0].name) if files else ""
 
-        # Split into …
-        chunk_size = …
+        # Split into tiny chunks of 1,000 characters
+        chunk_size = 1000
         chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
         chunk_responses = []
+        batch_size = 4  # Process 4 chunks at a time on A100
 
         prompt_template = """
-…
+Output only oversights under these headings, one brief point each. No tools, reasoning, or extra text.
 
 **Missed Diagnoses**:
 **Medication Conflicts**:
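The slicing comprehension produces fixed-width windows with a shorter final chunk. With `chunk_size = 1000`, a hypothetical 2,500-character extract splits into three chunks:

```python
extracted = "x" * 2500  # stand-in for the concatenated record text
chunk_size = 1000
chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
assert [len(c) for c in chunks] == [1000, 1000, 500]
```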
@@ -189,35 +200,39 @@ Records:
 """
 
         try:
-            # Process …
-            for …
-            …
+            # Process chunks in batches
+            for i in range(0, len(chunks), batch_size):
+                batch = chunks[i:i + batch_size]
+                batch_responses = []
+                for chunk in batch:
+                    prompt = prompt_template.format(chunk=chunk)
+                    chunk_response = ""
+                    for output in agent.run_gradio_chat(
+                        message=prompt,
+                        history=[],
+                        temperature=0.1,
+                        max_new_tokens=128,
+                        max_token=8192,
+                        call_agent=False,
+                        conversation=[],
+                    ):
+                        if output is None:
+                            continue
+                        if isinstance(output, list):
+                            for m in output:
+                                if hasattr(m, 'content') and m.content:
+                                    cleaned = clean_response(m.content)
+                                    if cleaned:
+                                        chunk_response += cleaned + "\n"
+                        elif isinstance(output, str) and output.strip():
+                            cleaned = clean_response(output)
+                            if cleaned:
+                                chunk_response += cleaned + "\n"
+                    if chunk_response:
+                        batch_responses.append(chunk_response)
+                chunk_responses.extend(batch_responses)
+
+            # Consolidate into one final result
             final_response = consolidate_findings(chunk_responses)
             history[-1]["content"] = final_response
             yield history, None
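One note on the batching: grouping chunks into `batch_size` slices only changes bookkeeping here, since the inner `for chunk in batch` loop still calls the model once per chunk, sequentially. The real deduplication happens in the final merge: because `findings` is a `defaultdict(set)`, bullets repeated across chunk responses collapse to one entry per heading. A worked example with fabricated responses:

```python
from collections import defaultdict

# Two chunk responses repeating one bullet; fabricated for illustration.
responses = [
    "Missed Diagnoses:\n- Anemia not investigated",
    "Missed Diagnoses:\n- Anemia not investigated\n- No HbA1c on record",
]
headings = ["Missed Diagnoses", "Medication Conflicts",
            "Incomplete Assessments", "Urgent Follow-up"]

findings = defaultdict(set)
for response in responses:
    current_heading = None
    for line in (l.strip() for l in response.split("\n") if l.strip()):
        if line.lower().startswith(tuple(h.lower() + ":" for h in headings)):
            current_heading = next(h for h in headings
                                   if line.lower().startswith(h.lower() + ":"))
        elif current_heading and line.startswith("-"):
            findings[current_heading].add(line)

# The duplicate bullet appears only once after the merge.
assert findings["Missed Diagnoses"] == {"- Anemia not investigated",
                                        "- No HbA1c on record"}
```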
|