Ali2206 committed
Commit d313543 · verified · 1 Parent(s): 455d1f0

Update app.py

Files changed (1): app.py (+28 -48)
app.py CHANGED

@@ -10,6 +10,7 @@ import re
 import psutil
 import subprocess
 from collections import defaultdict
+from vllm import LLM, SamplingParams  # MODIFIED: Direct vLLM for batching
 
 # Persistent directory
 persistent_dir = os.getenv("HF_HOME", "/data/hf_cache")
@@ -143,23 +144,18 @@ def consolidate_findings(responses: List[str]) -> str:
 def init_agent():
     print("🔍 Initializing model...")
     log_system_usage("Before Load")
-    agent = TxAgent(
-        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
-        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-        force_finish=True,
-        enable_checker=False,
-        enable_rag=False,
-        enable_finish=False,  # MODIFIED: Disable Finish tool
-        tool_files_dict=None,
-        step_rag_num=0,
-        seed=100,
+    model = LLM(
+        model="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+        max_model_len=4096,  # MODIFIED: Reduce KV cache
+        enforce_eager=True,
+        enable_chunked_prefill=True,
+        max_num_batched_tokens=8192,
     )
-    agent.init_model()
     log_system_usage("After Load")
-    print("✅ Agent Ready")
-    return agent
+    print("✅ Model Ready")
+    return model
 
-def create_ui(agent):
+def create_ui(model):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
         chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
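For reference, a minimal, self-contained sketch of what the new init_agent() amounts to once TxAgent is bypassed. The model name and engine flags are taken from the diff; the __main__ smoke test at the bottom is only an illustrative assumption, not part of app.py.

```python
# Sketch of the direct-vLLM initialization used in the diff above.
# The __main__ smoke test is illustrative only; it is not part of app.py.
from vllm import LLM, SamplingParams

def init_agent() -> LLM:
    return LLM(
        model="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        max_model_len=4096,            # smaller context window -> smaller KV cache
        enforce_eager=True,            # skip CUDA graph capture (saves VRAM and startup time)
        enable_chunked_prefill=True,   # split long prompts into prefill chunks
        max_num_batched_tokens=8192,   # cap on tokens scheduled per engine step
    )

if __name__ == "__main__":
    llm = init_agent()
    out = llm.generate(["Say OK."], SamplingParams(temperature=0.1, max_tokens=8, seed=100))
    print(out[0].outputs[0].text)
```

The combination of `enforce_eager=True` and the reduced `max_model_len` trades some peak throughput for a smaller memory footprint and faster startup, which fits a single-GPU Space.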
@@ -185,7 +181,7 @@ def create_ui(agent):
         chunk_size = 800
         chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
         chunk_responses = []
-        batch_size = 8  # MODIFIED: Increase for parallelism
+        batch_size = 8
         total_chunks = len(chunks)
 
         prompt_template = """
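As a quick illustration of the chunk/batch arithmetic used above (chunk_size and batch_size come from the diff; the 2,000-character input is made up):

```python
# Fixed-size character chunking plus batch windows, mirroring the app's loop structure.
extracted = "x" * 2000                      # hypothetical extracted report text
chunk_size = 800
batch_size = 8
chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
print(len(chunks))                          # -> 3 (800 + 800 + 400 characters)

total_chunks = len(chunks)
for i in range(0, len(chunks), batch_size):
    batch = chunks[i:i + batch_size]
    processed = min(i + len(batch), total_chunks)
    print(f"batch {i // batch_size + 1}: {len(batch)} chunk(s), {processed}/{total_chunks} done")
```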
@@ -199,42 +195,26 @@ Output only oversights under these headings, one point each. No tools, reasoning
 Records:
 {chunk}
 """
+        sampling_params = SamplingParams(
+            temperature=0.1,
+            max_tokens=32,  # MODIFIED: Reduce for speed
+            seed=100,
+        )
 
         try:
             for i in range(0, len(chunks), batch_size):
                 batch = chunks[i:i + batch_size]
+                prompts = [prompt_template.format(chunk=chunk) for chunk in batch]
+                log_system_usage(f"Batch {i//batch_size + 1}")
+                outputs = model.generate(prompts, sampling_params)  # MODIFIED: Batch inference
                 batch_responses = []
-                log_system_usage(f"Batch {i//batch_size + 1}")  # MODIFIED: Log VRAM
-                for j, chunk in enumerate(batch):
-                    prompt = prompt_template.format(chunk=chunk)
-                    chunk_response = ""
-                    for output in agent.run_gradio_chat(
-                        message=prompt,
-                        history=[],
-                        temperature=0.1,
-                        max_new_tokens=64,  # MODIFIED: Reduce for speed
-                        max_token=4096,
-                        call_agent=False,
-                        conversation=[],
-                    ):
-                        if output is None:
-                            continue
-                        if isinstance(output, list):
-                            for m in output:
-                                if hasattr(m, 'content') and m.content:
-                                    cleaned = clean_response(m.content)
-                                    if cleaned:
-                                        chunk_response += cleaned + "\n"
-                        elif isinstance(output, str) and output.strip():
-                            cleaned = clean_response(output)
-                            if cleaned:
-                                chunk_response += cleaned + "\n"
-                    if chunk_response:
-                        batch_responses.append(chunk_response)
-                    processed = min(i + j + 1, total_chunks)
-                    history[-1]["content"] = f"🔄 Analyzing... ({processed}/{total_chunks} chunks)"
-                    yield history, None
-                chunk_responses.extend(batch_responses)
+                with ThreadPoolExecutor(max_workers=8) as executor:  # MODIFIED: Parallel cleanup
+                    futures = [executor.submit(clean_response, output.outputs[0].text) for output in outputs]
+                    batch_responses.extend(f.result() for f in as_completed(futures))
+                chunk_responses.extend([r for r in batch_responses if r])
+                processed = min(i + len(batch), total_chunks)
+                history[-1]["content"] = f"🔄 Analyzing... ({processed}/{total_chunks} chunks)"
+                yield history, None
 
             final_response = consolidate_findings(chunk_responses)
             history[-1]["content"] = final_response
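A self-contained sketch of the new per-batch flow (one batched vLLM generate call followed by threaded post-processing). It presumes `ThreadPoolExecutor` and `as_completed` are already imported from `concurrent.futures` elsewhere in app.py, since this diff uses both without adding the import; `clean_response` is stubbed here purely for illustration.

```python
# Sketch of the batched-generate + parallel-cleanup pattern introduced above.
from concurrent.futures import ThreadPoolExecutor, as_completed
from vllm import LLM, SamplingParams

def clean_response(text: str) -> str:
    # Stand-in for app.py's clean_response(); here it just trims whitespace.
    return text.strip()

def run_batch(model: LLM, prompts: list[str]) -> list[str]:
    sampling_params = SamplingParams(temperature=0.1, max_tokens=32, seed=100)
    # One generate() call per batch: vLLM schedules all prompts together.
    outputs = model.generate(prompts, sampling_params)
    # Post-process each completion in a thread pool.
    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = [executor.submit(clean_response, out.outputs[0].text) for out in outputs]
        results = [f.result() for f in as_completed(futures)]
    return [r for r in results if r]
```

One side effect worth noting: `as_completed()` yields results in completion order, so per-chunk ordering is not preserved; if consolidate_findings() depends on order, iterating `futures` in submission order and calling `f.result()` on each would keep it.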
@@ -257,8 +237,8 @@ Records:
 
 if __name__ == "__main__":
     print("🚀 Launching app...")
-    agent = init_agent()
-    demo = create_ui(agent)
+    model = init_agent()
+    demo = create_ui(model)
     demo.queue(api_open=False).launch(
         server_name="0.0.0.0",
         server_port=7860,
 