Ali2206 committed on
Commit 34a564f · verified · 1 Parent(s): f13efd1

Update app.py

Files changed (1)
  1. app.py +40 -75
app.py CHANGED
@@ -1,3 +1,5 @@
  import sys
  import os
  import pandas as pd
@@ -14,18 +16,13 @@ import subprocess
  import traceback
  import torch

- # Set VLLM logging level to DEBUG for detailed output
  os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
-
- # If no GPU is available, force CPU usage by hiding CUDA devices
  if not torch.cuda.is_available():
  print("No GPU detected. Forcing CPU mode by setting CUDA_VISIBLE_DEVICES to an empty string.")
  os.environ["CUDA_VISIBLE_DEVICES"] = ""

- # Persistent directory setup
  persistent_dir = "/data/hf_cache"
  os.makedirs(persistent_dir, exist_ok=True)
-
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
  file_cache_dir = os.path.join(persistent_dir, "cache")
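
For context, a minimal standalone sketch of the GPU-detection fallback used above (illustrative only, not the full app.py; the environment variable has to be set before any library initializes CUDA for it to take effect):

    import os
    import torch

    # Hiding all CUDA devices makes downstream libraries (vLLM, transformers) fall back to CPU.
    if not torch.cuda.is_available():
        print("No GPU detected; forcing CPU mode.")
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    else:
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
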
@@ -35,7 +32,6 @@ vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
  for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
  os.makedirs(directory, exist_ok=True)

- # Update environment variables
  os.environ["HF_HOME"] = model_cache_dir
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
  os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
@@ -48,7 +44,6 @@ sys.path.insert(0, src_path)

  from txagent.txagent import TxAgent

- # Medical keywords for processing PDF files
  MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
  'allergies', 'summary', 'impression', 'findings', 'recommendations'}

@@ -68,7 +63,7 @@ def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
  text_chunks.append(f"=== Page {i+1} ===\n{text.strip()}")
  for i, page in enumerate(pdf.pages[3:max_pages], start=4):
  page_text = page.extract_text() or ""
- if any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
  text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
  return "\n\n".join(text_chunks)
  except Exception as e:
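
As a standalone illustration of the keyword-gated page scan in extract_priority_pages, here is a minimal sketch assuming pdfplumber is installed (keyword set abbreviated; the function name is illustrative):

    import re
    import pdfplumber

    KEYWORDS = {"diagnosis", "assessment", "plan", "medications"}

    def pages_with_keywords(path: str, max_pages: int = 20) -> list:
        """Return the text of pages (up to max_pages) that mention any keyword."""
        chunks = []
        with pdfplumber.open(path) as pdf:
            for i, page in enumerate(pdf.pages[:max_pages], start=1):
                text = page.extract_text() or ""
                # \b word boundaries give whole-word matches, so 'plan' does not fire on 'planning'
                if any(re.search(rf"\b{kw}\b", text.lower()) for kw in KEYWORDS):
                    chunks.append(f"=== Page {i} ===\n{text.strip()}")
        return chunks
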
@@ -129,7 +124,7 @@ def log_system_usage(tag=""):

  def init_agent():
  try:
- print("🔁 Initializing model...")
  log_system_usage("Before Load")
  default_tool_path = os.path.abspath("data/new_tool.json")
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
@@ -166,33 +161,26 @@ def create_ui(agent):

  def analyze(message: str, history: list, files: list):
  try:
- # Initialize response with loading message
  history.append({"role": "user", "content": message})
  history.append({"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."})
  yield history, None

- # Process files in parallel
  extracted = ""
  file_hash_value = ""
  if files:
  with ThreadPoolExecutor(max_workers=4) as executor:
- futures = [
- executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower())
- for f in files
- ]
  results = []
  for future in as_completed(futures):
  try:
- res = future.result()
- results.append(sanitize_utf8(res))
  except Exception as e:
  print("❌ Error in file processing:", str(e))
  traceback.print_exc()
  extracted = "\n".join(results)
  file_hash_value = file_hash(files[0].name)

- # Truncate extracted content to avoid token limit issues
- max_content_length = 8000 # Reduced from 12000 to prevent token overflow
  prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
  1. List potential missed diagnoses
  2. Flag any medication conflicts
@@ -205,78 +193,56 @@ Medical Records:
  ### Potential Oversights:
  """

- print("🔎 Generated prompt:")
- print(prompt)
-
- # Initialize response tracking
  full_response = ""
  response_chunks = []

- # Process streaming response with error handling
- try:
- for chunk in agent.run_gradio_chat(
- message=prompt,
- history=[],
- temperature=0.2,
- max_new_tokens=2048,
- max_token=4096,
- call_agent=False,
- conversation=[]
- ):
- try:
- if chunk is None:
- continue
-
- # Handle different chunk types
- if isinstance(chunk, str):
- chunk_content = chunk
- elif hasattr(chunk, 'content'):
- chunk_content = chunk.content
- elif isinstance(chunk, list):
- chunk_content = "".join([c.content for c in chunk if hasattr(c, "content") and c.content])
- else:
- print("DEBUG: Received unknown type chunk", type(chunk))
- continue
-
- if not chunk_content:
- continue

- response_chunks.append(chunk_content)
- full_response = "".join(response_chunks)

- # Clean the response for display
- display_response = full_response.split('[TOOL_CALLS]')[0].strip()
- display_response = display_response.replace('[TxAgent]', '').strip()

- # Update the chat history with the latest response
  if len(history) > 0 and history[-1]["role"] == "assistant":
  history[-1]["content"] = display_response
  else:
  history.append({"role": "assistant", "content": display_response})

- yield history, None
-
- except Exception as e:
- print("❌ Error processing chunk:", str(e))
- traceback.print_exc()
- continue
-
- except Exception as e:
- print("❌ Error in model streaming:", str(e))
- traceback.print_exc()
- history.append({"role": "assistant", "content": f"Error in model response: {str(e)}"})
- yield history, None
- return

- # Final response handling
  if not full_response:
  full_response = "⚠️ No clear oversights identified or model output was invalid."
  else:
- # Clean up the final response
- full_response = full_response.split('[TOOL_CALLS]')[0].strip()
  full_response = full_response.replace('[TxAgent]', '').strip()

- # Save report if we have files
  report_path = None
  if file_hash_value:
  report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
@@ -287,7 +253,6 @@ Medical Records:
  print("❌ Error saving report:", str(e))
  traceback.print_exc()

- # Ensure the final response is in the history
  if len(history) > 0 and history[-1]["role"] == "assistant":
  history[-1]["content"] = full_response
  else:
 
+ # (Full Updated Code Snippet with Proper Final Response Handling)
+
  import sys
  import os
  import pandas as pd
 
  import traceback
  import torch

  os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
  if not torch.cuda.is_available():
  print("No GPU detected. Forcing CPU mode by setting CUDA_VISIBLE_DEVICES to an empty string.")
  os.environ["CUDA_VISIBLE_DEVICES"] = ""

  persistent_dir = "/data/hf_cache"
  os.makedirs(persistent_dir, exist_ok=True)
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
  file_cache_dir = os.path.join(persistent_dir, "cache")
 
  for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
  os.makedirs(directory, exist_ok=True)

  os.environ["HF_HOME"] = model_cache_dir
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
  os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
 

  from txagent.txagent import TxAgent

  MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
  'allergies', 'summary', 'impression', 'findings', 'recommendations'}

 
  text_chunks.append(f"=== Page {i+1} ===\n{text.strip()}")
  for i, page in enumerate(pdf.pages[3:max_pages], start=4):
  page_text = page.extract_text() or ""
+ if any(re.search(rf'\\b{kw}\\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
  text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
  return "\n\n".join(text_chunks)
  except Exception as e:
 

  def init_agent():
  try:
+ print("\U0001F501 Initializing model...")
  log_system_usage("Before Load")
  default_tool_path = os.path.abspath("data/new_tool.json")
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
 

  def analyze(message: str, history: list, files: list):
  try:
  history.append({"role": "user", "content": message})
  history.append({"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."})
  yield history, None

  extracted = ""
  file_hash_value = ""
  if files:
  with ThreadPoolExecutor(max_workers=4) as executor:
+ futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
  results = []
  for future in as_completed(futures):
  try:
+ results.append(sanitize_utf8(future.result()))
  except Exception as e:
  print("❌ Error in file processing:", str(e))
  traceback.print_exc()
  extracted = "\n".join(results)
  file_hash_value = file_hash(files[0].name)
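
The parallel conversion above is the standard submit / as_completed pattern from concurrent.futures; a self-contained sketch with a stand-in converter (convert_file_to_json below is a hypothetical placeholder for the app's real helper):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def convert_file_to_json(path: str, ext: str) -> str:
        # hypothetical stand-in for the real converter in app.py
        return f'{{"file": "{path}", "type": "{ext}"}}'

    def convert_all(paths):
        results = []
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [executor.submit(convert_file_to_json, p, p.split(".")[-1].lower()) for p in paths]
            # as_completed yields futures in completion order, not submission order
            for future in as_completed(futures):
                try:
                    results.append(future.result())
                except Exception as exc:
                    print("conversion failed:", exc)
        return results
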
 
+ max_content_length = 8000
  prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
  1. List potential missed diagnoses
  2. Flag any medication conflicts

  ### Potential Oversights:
  """

  full_response = ""
  response_chunks = []
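
max_content_length presumably caps how much extracted text is interpolated into the prompt; a hedged sketch of that truncation step (the character-level slice here is an assumption, not necessarily how app.py applies the limit):

    max_content_length = 8000  # rough character budget to stay under the model's context limit

    def build_prompt(extracted: str, limit: int = max_content_length) -> str:
        clipped = extracted[:limit]  # naive cut; a tokenizer-aware truncation would be more precise
        return (
            "Review these medical records and identify EXACTLY what might have been missed:\n"
            "1. List potential missed diagnoses\n"
            "2. Flag any medication conflicts\n\n"
            f"Medical Records:\n{clipped}\n\n"
            "### Potential Oversights:\n"
        )
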
 
+ for chunk in agent.run_gradio_chat(
+ message=prompt,
+ history=[],
+ temperature=0.2,
+ max_new_tokens=2048,
+ max_token=4096,
+ call_agent=False,
+ conversation=[]
+ ):
+ try:
+ chunk_content = ""
+ if isinstance(chunk, str):
+ chunk_content = chunk
+ elif hasattr(chunk, 'content'):
+ chunk_content = chunk.content
+ elif isinstance(chunk, list):
+ chunk_content = "".join([c.content for c in chunk if hasattr(c, "content") and c.content])
+
+ if not chunk_content:
+ continue

+ response_chunks.append(chunk_content)
+ full_response = "".join(response_chunks)

+ display_response = re.split(r"\\[TOOL_CALLS\\].*?$", full_response, flags=re.DOTALL)[0].strip()
+ display_response = display_response.replace('[TxAgent]', '').strip()

+ if len(history) > 1 and history[-2]["role"] == "assistant" and history[-2]["content"] == display_response:
+ pass
+ else:
  if len(history) > 0 and history[-1]["role"] == "assistant":
  history[-1]["content"] = display_response
  else:
  history.append({"role": "assistant", "content": display_response})

+ yield history, None
+ except Exception as e:
+ print("❌ Error processing chunk:", str(e))
+ traceback.print_exc()
+ continue
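
The chunk handling above has to cope with run_gradio_chat yielding plain strings, single message objects, or lists of them; a small normalization helper makes the intent explicit (a sketch, not part of the TxAgent API):

    def chunk_to_text(chunk) -> str:
        """Flatten whatever the streaming generator yields into plain text."""
        if chunk is None:
            return ""
        if isinstance(chunk, str):
            return chunk
        if hasattr(chunk, "content"):
            return chunk.content or ""
        if isinstance(chunk, list):
            return "".join(c.content for c in chunk if hasattr(c, "content") and c.content)
        return ""

    # inside the streaming loop:
    #     response_chunks.append(chunk_to_text(chunk))
    #     full_response = "".join(response_chunks)
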
 
 
  if not full_response:
  full_response = "⚠️ No clear oversights identified or model output was invalid."
  else:
+ full_response = re.split(r"\\[TOOL_CALLS\\].*?$", full_response, flags=re.DOTALL)[0].strip()
  full_response = full_response.replace('[TxAgent]', '').strip()
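
Both the streaming and the final paths strip the model's tool-call markup before display; a minimal sketch of that cleanup (marker strings taken from the diff, the helper itself is illustrative):

    import re

    def clean_response(text: str) -> str:
        # drop everything from the first [TOOL_CALLS] marker onward, then remove [TxAgent] tags
        text = re.split(r"\[TOOL_CALLS\]", text, maxsplit=1)[0]
        return text.replace("[TxAgent]", "").strip()

    assert clean_response("Findings noted. [TxAgent] [TOOL_CALLS]{...}") == "Findings noted."
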
 
 
  report_path = None
  if file_hash_value:
  report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")

  print("❌ Error saving report:", str(e))
  traceback.print_exc()
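
The report filename is keyed on a hash of the first uploaded file; file_hash is the app's own helper, so the version below is only a plausible stand-in (MD5 over the file bytes), shown to illustrate the naming scheme:

    import hashlib
    import os

    def file_hash(path: str) -> str:
        # hypothetical stand-in for the file_hash helper used in app.py
        digest = hashlib.md5()
        with open(path, "rb") as fh:
            for block in iter(lambda: fh.read(8192), b""):
                digest.update(block)
        return digest.hexdigest()

    # usage, mirroring the diff:
    #     report_path = os.path.join(report_dir, f"{file_hash(files[0].name)}_report.txt")
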
 
 
  if len(history) > 0 and history[-1]["role"] == "assistant":
  history[-1]["content"] = full_response
  else: