Ali2206 committed · verified
Commit 13df505 · 1 Parent(s): 5707e8d

Update app.py

Files changed (1)
  1. app.py +42 -44

app.py CHANGED
@@ -10,7 +10,7 @@ import re
 import psutil
 import subprocess
 from collections import defaultdict
-from vllm import LLM, SamplingParams  # MODIFIED: Direct vLLM for batching
+from vllm import LLM, SamplingParams
 
 # Persistent directory
 persistent_dir = os.getenv("HF_HOME", "/data/hf_cache")
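
Note: the direct vLLM import above is what enables the offline batched generation used further down in the file. For reference, a minimal sketch of that usage pattern (the prompts below are invented; only the model name and sampling settings are taken from this commit):

# Minimal sketch of direct vLLM batch inference, assuming a CUDA GPU with the
# model weights available locally or via the Hugging Face Hub.
from vllm import LLM, SamplingParams

llm = LLM(model="mims-harvard/TxAgent-T1-Llama-3.1-8B", enforce_eager=True)
params = SamplingParams(temperature=0.3, max_tokens=64, seed=100)

prompts = [  # invented example prompts
    "List medication conflicts: patient takes warfarin and ibuprofen.",
    "List urgent follow-up items: abnormal ECG, no repeat scheduled.",
]
outputs = llm.generate(prompts, params)  # the whole prompt list is scheduled as one batch
for out in outputs:
    print(out.outputs[0].text)  # first completion for each prompt

llm.generate() schedules all prompts together, which is why the app no longer needs to issue one request per chunk.
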
@@ -35,7 +35,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
 sys.path.insert(0, src_path)
 
-from txagent.txagent import TxAgent
+from txagent.txagent import TxAgent, clean_response  # MODIFIED: Import clean_response
 
 def sanitize_utf8(text: str) -> str:
     return text.encode("utf-8", "ignore").decode("utf-8")
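
clean_response is now imported from txagent.txagent instead of being defined locally (the local copy is removed in the next hunk). If the deployed txagent version might not export it yet, a guarded import such as the following would keep app.py importable; this is only a suggestion, not part of the commit:

# Hypothetical fallback, not part of this commit: tolerate an older txagent
# that does not export clean_response.
try:
    from txagent.txagent import TxAgent, clean_response
except ImportError:
    from txagent.txagent import TxAgent

    def clean_response(text: str) -> str:
        # Degraded behaviour: pass the model output through unfiltered.
        return text
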
@@ -88,31 +88,6 @@ def log_system_usage(tag=""):
     except Exception as e:
         print(f"[{tag}] GPU/CPU monitor failed: {e}")
 
-def clean_response(text: str) -> str:
-    text = sanitize_utf8(text)
-    text = re.sub(r"\[TOOL_CALLS\].*?\n|\[.*?\].*?\n|(?:get_|tool\s|retrieve\s|use\s|rag\s).*?\n", "", text, flags=re.DOTALL | re.IGNORECASE)
-    text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
-    text = re.sub(
-        r"(?i)(to\s|analyze|will\s|since\s|no\s|none|previous|attempt|involve|check\s|explore|manually|"
-        r"start|look|use|focus|retrieve|tool|based\s|overall|indicate|mention|consider|ensure|need\s|"
-        r"provide|review|assess|identify|potential|records|patient|history|symptoms|medication|"
-        r"conflict|assessment|follow-up|issue|reasoning|step|prompt|address|rag|thought|try|john\sdoe|nkma).*?\n",
-        "", text, flags=re.DOTALL
-    )
-    text = re.sub(r"\n{2,}", "\n", text).strip()
-    lines = []
-    valid_heading = False
-    for line in text.split("\n"):
-        line = line.strip()
-        if line.lower() in ["missed diagnoses:", "medication conflicts:", "incomplete assessments:", "urgent follow-up:"]:
-            valid_heading = True
-            lines.append(f"**{line[:-1]}**:")
-        elif valid_heading and line.startswith("-"):
-            lines.append(line)
-        else:
-            valid_heading = False
-    return "\n".join(lines).strip()
-
 def normalize_text(text: str) -> str:
     return re.sub(r"\s+", " ", text.lower().strip())
 
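The removed helper stripped tool-call artifacts and free-form reasoning from the model output and kept only the four expected headings plus their "-" bullets. A condensed sketch of that heading/bullet filter, to show the intended behaviour (the regex-based pre-cleaning is omitted here; the real implementation now lives in txagent.txagent):

# Condensed sketch of the heading/bullet filter clean_response applied; the
# real helper (now in txagent.txagent) also strips tool-call text and stray
# reasoning with regexes before this step.
HEADINGS = {"missed diagnoses:", "medication conflicts:",
            "incomplete assessments:", "urgent follow-up:"}

def keep_headed_bullets(text: str) -> str:
    kept, under_heading = [], False
    for line in (raw.strip() for raw in text.split("\n")):
        if line.lower() in HEADINGS:
            under_heading = True
            kept.append(f"**{line[:-1]}**:")  # "Missed Diagnoses:" -> "**Missed Diagnoses**:"
        elif under_heading and line.startswith("-"):
            kept.append(line)
        else:
            under_heading = False
    return "\n".join(kept)

sample = "I will analyze the records.\nMedication Conflicts:\n- Warfarin with ibuprofen\nstray reasoning"
print(keep_headed_bullets(sample))
# **Medication Conflicts**:
# - Warfarin with ibuprofen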
 
@@ -146,10 +121,11 @@ def init_agent():
     log_system_usage("Before Load")
     model = LLM(
         model="mims-harvard/TxAgent-T1-Llama-3.1-8B",
-        max_model_len=4096,  # MODIFIED: Reduce KV cache
+        max_model_len=4096,  # MODIFIED: Enforce low VRAM
         enforce_eager=True,
         enable_chunked_prefill=True,
         max_num_batched_tokens=8192,
+        gpu_memory_utilization=0.5,  # MODIFIED: Limit VRAM
     )
     log_system_usage("After Load")
     print("✅ Model Ready")
@@ -178,44 +154,66 @@ def create_ui(model):
         extracted = "\n".join([json.loads(r).get("content", "") for r in results if "content" in json.loads(r)])
         file_hash_value = file_hash(files[0].name) if files else ""
 
-        chunk_size = 800
+        chunk_size = 800  # MODIFIED: Enforce correct size
         chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
         chunk_responses = []
-        batch_size = 8
+        batch_size = 4  # MODIFIED: Lower for VRAM
         total_chunks = len(chunks)
 
         prompt_template = """
-Output only oversights under these headings, one point each. No tools, reasoning, or extra text.
-
-**Missed Diagnoses**:
-**Medication Conflicts**:
-**Incomplete Assessments**:
-**Urgent Follow-up**:
+Strictly output oversights under these exact headings, one point per line, starting with "-". No other text, reasoning, or tools.
+
+**Missed Diagnoses**:
+**Medication Conflicts**:
+**Incomplete Assessments**:
+**Urgent Follow-up**:
 
-Records:
-{chunk}
-"""
+Records:
+{chunk}
+"""  # MODIFIED: Stronger instructions
+
         sampling_params = SamplingParams(
-            temperature=0.1,
-            max_tokens=32,  # MODIFIED: Reduce for speed
+            temperature=0.3,  # MODIFIED: Improve output quality
+            max_tokens=64,  # MODIFIED: Allow full responses
             seed=100,
         )
 
         try:
+            findings = defaultdict(list)  # MODIFIED: Track per batch
             for i in range(0, len(chunks), batch_size):
                 batch = chunks[i:i + batch_size]
                 prompts = [prompt_template.format(chunk=chunk) for chunk in batch]
                 log_system_usage(f"Batch {i//batch_size + 1}")
-                outputs = model.generate(prompts, sampling_params)  # MODIFIED: Batch inference
+                outputs = model.generate(prompts, sampling_params, use_tqdm=True)  # MODIFIED: Stream progress
                 batch_responses = []
-                with ThreadPoolExecutor(max_workers=8) as executor:  # MODIFIED: Parallel cleanup
+                with ThreadPoolExecutor(max_workers=4) as executor:
                     futures = [executor.submit(clean_response, output.outputs[0].text) for output in outputs]
                     batch_responses.extend(f.result() for f in as_completed(futures))
-                chunk_responses.extend([r for r in batch_responses if r])
+
                 processed = min(i + len(batch), total_chunks)
-                history[-1]["content"] = f"🔄 Analyzing... ({processed}/{total_chunks} chunks)"
+                batch_output = []
+                for response in batch_responses:
+                    if response:
+                        chunk_responses.append(response)
+                        current_heading = None
+                        for line in response.split("\n"):
+                            line = line.strip()
+                            if line.lower().startswith(tuple(h.lower() + ":" for h in ["missed diagnoses", "medication conflicts", "incomplete assessments", "urgent follow-up"])):
+                                current_heading = line[:-1]
+                                if current_heading not in batch_output:
+                                    batch_output.append(current_heading + ":")
+                            elif current_heading and line.startswith("-"):
+                                findings[current_heading].append(line)
+                                batch_output.append(line)
+
+                # MODIFIED: Stream partial results
+                if batch_output:
+                    history[-1]["content"] = "\n".join(batch_output) + f"\n\n🔄 Processing chunk {processed}/{total_chunks}..."
+                else:
+                    history[-1]["content"] = f"🔄 Processing chunk {processed}/{total_chunks}..."
                 yield history, None
 
+            # MODIFIED: Final consolidation
             final_response = consolidate_findings(chunk_responses)
             history[-1]["content"] = final_response
             yield history, None
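
The rewritten loop does more than count chunks: each cleaned response is parsed under the four headings into the findings defaultdict, and whatever has accumulated so far is streamed back to the chat while later batches are still running. The bookkeeping is easiest to see on a toy input (responses invented; heading and bullet handling mirror the diff):

# Toy walk-through of the per-batch parsing added above (responses invented).
from collections import defaultdict

HEADINGS = ["missed diagnoses", "medication conflicts",
            "incomplete assessments", "urgent follow-up"]
findings = defaultdict(list)
batch_output = []

batch_responses = [
    "Medication Conflicts:\n- Warfarin with ibuprofen",
    "Urgent Follow-up:\n- Repeat ECG never scheduled",
]

for response in batch_responses:
    current_heading = None
    for line in (raw.strip() for raw in response.split("\n")):
        if line.lower().startswith(tuple(h + ":" for h in HEADINGS)):
            current_heading = line[:-1]
            if current_heading not in batch_output:
                batch_output.append(current_heading + ":")
        elif current_heading and line.startswith("-"):
            findings[current_heading].append(line)
            batch_output.append(line)

print("\n".join(batch_output))
# Medication Conflicts:
# - Warfarin with ibuprofen
# Urgent Follow-up:
# - Repeat ECG never scheduled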
 
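consolidate_findings is called on the accumulated chunk_responses, but its body is outside this diff. Purely to illustrate the expected data flow, a hypothetical consolidation that de-duplicates bullets per heading could look like this (the actual function in app.py may differ):

# Hypothetical sketch only: the real consolidate_findings in app.py is not
# shown in this diff and may work differently.
from typing import Dict, List

def consolidate_findings_sketch(chunk_responses: List[str]) -> str:
    merged: Dict[str, List[str]] = {}
    current = None
    for response in chunk_responses:
        for line in (raw.strip() for raw in response.split("\n")):
            if line.endswith(":") and not line.startswith("-"):
                current = line
                merged.setdefault(current, [])
            elif current and line.startswith("-") and line not in merged[current]:
                merged[current].append(line)  # de-duplicate repeated bullets per heading
    sections = [f"{heading}\n" + "\n".join(bullets)
                for heading, bullets in merged.items() if bullets]
    return "\n\n".join(sections) if sections else "No oversights identified."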