Update app.py
app.py CHANGED
@@ -10,7 +10,7 @@ import re
 import psutil
 import subprocess
 from collections import defaultdict
-from vllm import LLM, SamplingParams
+from vllm import LLM, SamplingParams
 
 # Persistent directory
 persistent_dir = os.getenv("HF_HOME", "/data/hf_cache")
@@ -35,7 +35,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
 sys.path.insert(0, src_path)
 
-from txagent.txagent import TxAgent
+from txagent.txagent import TxAgent, clean_response  # MODIFIED: Import clean_response
 
 def sanitize_utf8(text: str) -> str:
     return text.encode("utf-8", "ignore").decode("utf-8")
@@ -88,31 +88,6 @@ def log_system_usage(tag=""):
     except Exception as e:
         print(f"[{tag}] GPU/CPU monitor failed: {e}")
 
-def clean_response(text: str) -> str:
-    text = sanitize_utf8(text)
-    text = re.sub(r"\[TOOL_CALLS\].*?\n|\[.*?\].*?\n|(?:get_|tool\s|retrieve\s|use\s|rag\s).*?\n", "", text, flags=re.DOTALL | re.IGNORECASE)
-    text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
-    text = re.sub(
-        r"(?i)(to\s|analyze|will\s|since\s|no\s|none|previous|attempt|involve|check\s|explore|manually|"
-        r"start|look|use|focus|retrieve|tool|based\s|overall|indicate|mention|consider|ensure|need\s|"
-        r"provide|review|assess|identify|potential|records|patient|history|symptoms|medication|"
-        r"conflict|assessment|follow-up|issue|reasoning|step|prompt|address|rag|thought|try|john\sdoe|nkma).*?\n",
-        "", text, flags=re.DOTALL
-    )
-    text = re.sub(r"\n{2,}", "\n", text).strip()
-    lines = []
-    valid_heading = False
-    for line in text.split("\n"):
-        line = line.strip()
-        if line.lower() in ["missed diagnoses:", "medication conflicts:", "incomplete assessments:", "urgent follow-up:"]:
-            valid_heading = True
-            lines.append(f"**{line[:-1]}**:")
-        elif valid_heading and line.startswith("-"):
-            lines.append(line)
-        else:
-            valid_heading = False
-    return "\n".join(lines).strip()
-
 def normalize_text(text: str) -> str:
     return re.sub(r"\s+", " ", text.lower().strip())
 
@@ -146,10 +121,11 @@ def init_agent():
     log_system_usage("Before Load")
     model = LLM(
         model="mims-harvard/TxAgent-T1-Llama-3.1-8B",
-        max_model_len=4096,  # MODIFIED:
+        max_model_len=4096,  # MODIFIED: Enforce low VRAM
         enforce_eager=True,
         enable_chunked_prefill=True,
         max_num_batched_tokens=8192,
+        gpu_memory_utilization=0.5,  # MODIFIED: Limit VRAM
     )
     log_system_usage("After Load")
     print("✅ Model Ready")
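For reference, the hunk above amounts to the following standalone vLLM setup. This is a minimal sketch, assuming a recent vLLM release that accepts these constructor arguments; the model name and numeric values mirror the diff, while the sample prompt and variable names are purely illustrative.

```python
# Minimal sketch of the post-commit vLLM configuration; values mirror the
# hunk above, the prompt below is illustrative only.
from vllm import LLM, SamplingParams

llm = LLM(
    model="mims-harvard/TxAgent-T1-Llama-3.1-8B",
    max_model_len=4096,            # cap context length to bound KV-cache size
    enforce_eager=True,            # skip CUDA graph capture to save VRAM
    enable_chunked_prefill=True,   # prefill long prompts in chunks
    max_num_batched_tokens=8192,
    gpu_memory_utilization=0.5,    # reserve roughly half the GPU for this model
)

sampling_params = SamplingParams(temperature=0.3, max_tokens=64, seed=100)

outputs = llm.generate(
    ["Summarize: patient on warfarin and ibuprofen."],
    sampling_params,
    use_tqdm=True,
)
print(outputs[0].outputs[0].text)
```

vLLM's default `gpu_memory_utilization` is 0.9, so dropping it to 0.5 roughly halves the memory pool the engine reserves, which is the point of this hunk.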
@@ -178,44 +154,66 @@ def create_ui(model):
         extracted = "\n".join([json.loads(r).get("content", "") for r in results if "content" in json.loads(r)])
         file_hash_value = file_hash(files[0].name) if files else ""
 
-        chunk_size = 800
+        chunk_size = 800  # MODIFIED: Enforce correct size
         chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
         chunk_responses = []
-        batch_size =
+        batch_size = 4  # MODIFIED: Lower for VRAM
         total_chunks = len(chunks)
 
         prompt_template = """
-
+Strictly output oversights under these exact headings, one point per line, starting with "-". No other text, reasoning, or tools.
+
+**Missed Diagnoses**:
+**Medication Conflicts**:
+**Incomplete Assessments**:
+**Urgent Follow-up**:
 
-
-
-
-**Urgent Follow-up**:
+Records:
+{chunk}
+"""  # MODIFIED: Stronger instructions
 
-Records:
-{chunk}
-"""
         sampling_params = SamplingParams(
-            temperature=0.
-            max_tokens=
+            temperature=0.3,  # MODIFIED: Improve output quality
+            max_tokens=64,  # MODIFIED: Allow full responses
             seed=100,
         )
 
         try:
+            findings = defaultdict(list)  # MODIFIED: Track per batch
             for i in range(0, len(chunks), batch_size):
                 batch = chunks[i:i + batch_size]
                 prompts = [prompt_template.format(chunk=chunk) for chunk in batch]
                 log_system_usage(f"Batch {i//batch_size + 1}")
-                outputs = model.generate(prompts, sampling_params)  # MODIFIED:
+                outputs = model.generate(prompts, sampling_params, use_tqdm=True)  # MODIFIED: Stream progress
                 batch_responses = []
-                with ThreadPoolExecutor(max_workers=
+                with ThreadPoolExecutor(max_workers=4) as executor:
                     futures = [executor.submit(clean_response, output.outputs[0].text) for output in outputs]
                     batch_responses.extend(f.result() for f in as_completed(futures))
-
+
                 processed = min(i + len(batch), total_chunks)
-
+                batch_output = []
+                for response in batch_responses:
+                    if response:
+                        chunk_responses.append(response)
+                        current_heading = None
+                        for line in response.split("\n"):
+                            line = line.strip()
+                            if line.lower().startswith(tuple(h.lower() + ":" for h in ["missed diagnoses", "medication conflicts", "incomplete assessments", "urgent follow-up"])):
+                                current_heading = line[:-1]
+                                if current_heading not in batch_output:
+                                    batch_output.append(current_heading + ":")
+                            elif current_heading and line.startswith("-"):
+                                findings[current_heading].append(line)
+                                batch_output.append(line)
+
+                # MODIFIED: Stream partial results
+                if batch_output:
+                    history[-1]["content"] = "\n".join(batch_output) + f"\n\n🔄 Processing chunk {processed}/{total_chunks}..."
+                else:
+                    history[-1]["content"] = f"🔄 Processing chunk {processed}/{total_chunks}..."
                 yield history, None
 
+            # MODIFIED: Final consolidation
             final_response = consolidate_findings(chunk_responses)
             history[-1]["content"] = final_response
             yield history, None
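To make the new streaming behaviour easier to follow, here is a rough standalone rendering of the parsing the batch loop above performs on each cleaned generation. The heading list mirrors the prompt template; the `parse_findings` helper and the sample response are illustrative, and the sketch assumes `clean_response` has already stripped any markdown around the headings (it also deduplicates headings by comparing the colon-terminated form, a slight simplification of the hunk).

```python
# Standalone sketch of the per-batch parsing added in create_ui: headings are
# kept once, "-" bullet lines are grouped under the most recent heading.
from collections import defaultdict

HEADINGS = ["missed diagnoses", "medication conflicts",
            "incomplete assessments", "urgent follow-up"]

def parse_findings(response: str):
    findings = defaultdict(list)   # heading -> bullet lines
    batch_output = []              # display-ready lines for streaming
    current_heading = None
    for line in response.split("\n"):
        line = line.strip()
        if line.lower().startswith(tuple(h + ":" for h in HEADINGS)):
            current_heading = line[:-1]              # drop the trailing colon
            if current_heading + ":" not in batch_output:
                batch_output.append(current_heading + ":")
        elif current_heading and line.startswith("-"):
            findings[current_heading].append(line)
            batch_output.append(line)
    return batch_output, findings

sample = """missed diagnoses:
- possible hypothyroidism given TSH trend
urgent follow-up:
- repeat INR within 48 hours"""

lines, findings = parse_findings(sample)
print("\n".join(lines))
print(dict(findings))
```

In the hunk itself, `batch_output` is what gets streamed with the 🔄 progress marker after each batch, while `chunk_responses` carries the raw cleaned text forward to the final `consolidate_findings` call.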