Ali2206 committed (verified)
Commit d2dfc7e · 1 Parent(s): 7a596d9

Update app.py

Files changed (1):
  1. app.py +70 -55
app.py CHANGED
@@ -10,9 +10,10 @@ import re
 import psutil
 import subprocess
 from collections import defaultdict
+import torch
 
-# Persistent directory
-persistent_dir = "/data/hf_cache"
+# Persistent directory for Hugging Face Space
+persistent_dir = os.getenv("HF_HOME", "/data/hf_cache")
 os.makedirs(persistent_dir, exist_ok=True)
 
 model_cache_dir = os.path.join(persistent_dir, "txagent_models")
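The cache root now honors the HF_HOME environment variable before falling back to the Space's persistent /data/hf_cache volume. A minimal sketch of the fallback behavior, assuming only that HF_HOME may or may not be set (the override path below is hypothetical):

```python
import os

# When HF_HOME is set it wins; otherwise the Space's persistent volume is used.
os.environ["HF_HOME"] = "/tmp/demo_cache"  # hypothetical override for illustration
persistent_dir = os.getenv("HF_HOME", "/data/hf_cache")
os.makedirs(persistent_dir, exist_ok=True)
print(persistent_dir)  # -> /tmp/demo_cache
```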
@@ -89,47 +90,55 @@ def log_system_usage(tag=""):
 
 def clean_response(text: str) -> str:
     text = sanitize_utf8(text)
-    # Remove all tool-related and reasoning text
-    text = re.sub(r"\[TOOL_CALLS\].*|(?:get_|tool\s|retrieve\s).*?\n", "", text, flags=re.DOTALL | re.IGNORECASE)
+    # Exhaustively remove all unwanted text
+    text = re.sub(r"\[TOOL_CALLS\].*|(?:get_|tool\s|retrieve\s|use\s).*?\n", "", text, flags=re.DOTALL | re.IGNORECASE)
     text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
-    text = re.sub(r"(?i)(to address|analyze the|will (start|look|use|focus)|since the|no (drug|clinical|information)|none|previous|attempt|involve|check for|explore|manually).*?\n", "", text, flags=re.DOTALL)
+    text = re.sub(
+        r"(?i)(to address|analyze|will\s|since\s|no\s|none|previous|attempt|involve|check\s|explore|manually|"
+        r"start|look|use|focus|retrieve|tool|based\s|overall|indicate|mention|consider|ensure|need\s|"
+        r"provide|review|assess|identify|potential|records|patient|history|symptoms|medication|"
+        r"conflict|assessment|follow-up|issue|reasoning|step).*?\n",
+        "", text, flags=re.DOTALL
+    )
     text = re.sub(r"\n{3,}", "\n\n", text).strip()
-    # Only keep text under specific headings
-    if not re.search(r"^(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", text, re.MULTILINE | re.IGNORECASE):
-        return ""
-    return text
+    # Only keep lines under headings or bullet points
+    lines = []
+    valid_heading = False
+    for line in text.split("\n"):
+        line = line.strip()
+        if line.lower() in ["missed diagnoses:", "medication conflicts:", "incomplete assessments:", "urgent follow-up:"]:
+            valid_heading = True
+            lines.append(f"**{line[:-1]}**:")
+        elif valid_heading and line.startswith("-"):
+            lines.append(line)
+        else:
+            valid_heading = False
+    return "\n".join(lines).strip()
 
 def consolidate_findings(responses: List[str]) -> str:
-    # Aggregate findings under each heading, removing duplicates
+    # Merge findings, keeping only unique points
    findings = defaultdict(set)
    headings = ["Missed Diagnoses", "Medication Conflicts", "Incomplete Assessments", "Urgent Follow-up"]
 
    for response in responses:
        if not response:
            continue
-        # Split response into sections by heading
        current_heading = None
-        current_points = []
        for line in response.split("\n"):
            line = line.strip()
            if not line:
                continue
-            if any(line.lower().startswith(h.lower()) for h in headings):
-                if current_heading and current_points:
-                    findings[current_heading].update(current_points)
-                current_heading = next(h for h in headings if line.lower().startswith(h.lower()))
-                current_points = []
+            if line.lower().startswith(tuple(h.lower() + ":" for h in headings)):
+                current_heading = next(h for h in headings if line.lower().startswith(h.lower() + ":"))
            elif current_heading and line.startswith("-"):
-                current_points.append(line)
-        if current_heading and current_points:
-            findings[current_heading].update(current_points)
+                findings[current_heading].add(line)
 
-    # Format consolidated output
+    # Format final output
    output = []
    for heading in headings:
        if findings[heading]:
            output.append(f"**{heading}**:")
-            output.extend(sorted(findings[heading]))
+            output.extend(sorted(findings[heading], key=lambda x: x.lower()))
    return "\n".join(output).strip() if output else "No oversights identified."
 
 def init_agent():
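A quick, self-contained sketch of the new heading/bullet whitelist loop in `clean_response`, run against a hypothetical model reply (the sample text is invented for illustration). One caveat worth noting: the broadened removal regex earlier in the function also matches common clinical words such as "medication" and "patient", so bullet text containing them may already be stripped before this loop runs; the sketch exercises only the whitelist step.

```python
sample = """To address the query, I will analyze the records.
Missed Diagnoses:
- Possible hypothyroidism not worked up
Some stray reasoning line.
Medication Conflicts:
- Warfarin plus ibuprofen
"""

# Keep only recognized headings and the dash bullets directly under them.
lines, valid_heading = [], False
for line in sample.split("\n"):
    line = line.strip()
    if line.lower() in ["missed diagnoses:", "medication conflicts:",
                        "incomplete assessments:", "urgent follow-up:"]:
        valid_heading = True
        lines.append(f"**{line[:-1]}**:")
    elif valid_heading and line.startswith("-"):
        lines.append(line)
    else:
        valid_heading = False

print("\n".join(lines))
# **Missed Diagnoses**:
# - Possible hypothyroidism not worked up
# **Medication Conflicts**:
# - Warfarin plus ibuprofen
```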
@@ -143,7 +152,8 @@ def init_agent():
         step_rag_num=1,
         seed=100,
     )
-    agent.init_model()
+    # Enable FP16 for A100
+    agent.init_model(dtype=torch.float16)
     log_system_usage("After Load")
     print("✅ Agent Ready")
     return agent
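Loading weights in half precision roughly halves GPU memory versus FP32 (2 bytes per element instead of 4), assuming `init_model` forwards the dtype to the underlying model load. A quick check of the per-element storage sizes:

```python
import torch

# FP16 stores 2 bytes per element, FP32 stores 4, hence the ~2x memory saving.
print(torch.empty(1, dtype=torch.float16).element_size())  # 2
print(torch.empty(1, dtype=torch.float32).element_size())  # 4
```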
@@ -171,13 +181,14 @@ def create_ui(agent):
         extracted = "\n".join(results)
         file_hash_value = file_hash(files[0].name) if files else ""
 
-        # Split into small chunks of 1,500 characters
-        chunk_size = 1500
+        # Split into tiny chunks of 1,000 characters
+        chunk_size = 1000
         chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
         chunk_responses = []
+        batch_size = 4  # Process 4 chunks at a time on A100
 
         prompt_template = """
-List doctor oversights under these headings only, with one brief point each. No tools or reasoning steps.
+Output only oversights under these headings, one brief point each. No tools, reasoning, or extra text.
 
 **Missed Diagnoses**:
 **Medication Conflicts**:
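The slicing comprehension yields fixed-size windows with a shorter tail chunk at the end. A minimal sketch with a toy string and a small chunk size (values chosen purely for illustration; the commit uses 1000):

```python
extracted = "x" * 25  # toy stand-in for the extracted records
chunk_size = 10

# Same comprehension as the diff: stride by chunk_size, slice a window each time.
chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
print([len(c) for c in chunks])  # [10, 10, 5]
```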
@@ -189,35 +200,39 @@ Records:
 """
 
         try:
-            # Process all chunks, collecting responses
-            for chunk in chunks:
-                prompt = prompt_template.format(chunk=chunk)
-                chunk_response = ""
-                for output in agent.run_gradio_chat(
-                    message=prompt,
-                    history=[],
-                    temperature=0.1,
-                    max_new_tokens=256,
-                    max_token=4096,
-                    call_agent=False,
-                    conversation=[],
-                ):
-                    if output is None:
-                        continue
-                    if isinstance(output, list):
-                        for m in output:
-                            if hasattr(m, 'content') and m.content:
-                                cleaned = clean_response(m.content)
-                                if cleaned:
-                                    chunk_response += cleaned + "\n"
-                    elif isinstance(output, str) and output.strip():
-                        cleaned = clean_response(output)
-                        if cleaned:
-                            chunk_response += cleaned + "\n"
-                if chunk_response:
-                    chunk_responses.append(chunk_response)
-
-            # Consolidate all responses into one final output
+            # Process chunks in batches
+            for i in range(0, len(chunks), batch_size):
+                batch = chunks[i:i + batch_size]
+                batch_responses = []
+                for chunk in batch:
+                    prompt = prompt_template.format(chunk=chunk)
+                    chunk_response = ""
+                    for output in agent.run_gradio_chat(
+                        message=prompt,
+                        history=[],
+                        temperature=0.1,
+                        max_new_tokens=128,
+                        max_token=8192,
+                        call_agent=False,
+                        conversation=[],
+                    ):
+                        if output is None:
+                            continue
+                        if isinstance(output, list):
+                            for m in output:
+                                if hasattr(m, 'content') and m.content:
+                                    cleaned = clean_response(m.content)
+                                    if cleaned:
+                                        chunk_response += cleaned + "\n"
+                        elif isinstance(output, str) and output.strip():
+                            cleaned = clean_response(output)
+                            if cleaned:
+                                chunk_response += cleaned + "\n"
+                    if chunk_response:
+                        batch_responses.append(chunk_response)
+                chunk_responses.extend(batch_responses)
+
+            # Consolidate into one final result
             final_response = consolidate_findings(chunk_responses)
             history[-1]["content"] = final_response
             yield history, None
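The new outer loop walks the chunk list in strides of `batch_size`, but each chunk in a batch is still sent to `run_gradio_chat` one at a time, so the grouping organizes bookkeeping rather than running parallel inference. A sketch of the index arithmetic, with integer stand-ins for the text chunks:

```python
chunks = list(range(10))  # stand-ins for ten text chunks
batch_size = 4

# Stride the list in steps of batch_size; the last batch may be short.
for i in range(0, len(chunks), batch_size):
    batch = chunks[i:i + batch_size]
    print(i, batch)
# 0 [0, 1, 2, 3]
# 4 [4, 5, 6, 7]
# 8 [8, 9]
```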
 