Ali2206 committed on
Commit 499e72e · verified · 1 Parent(s): 0ea3469

Update app.py

Files changed (1)
  1. app.py +254 -254
app.py CHANGED
@@ -2,34 +2,26 @@ import sys
import os
import pandas as pd
import pdfplumber
import gradio as gr
- from typing import List, Dict
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import logging
import torch
import gc
from diskcache import Cache
- from transformers import AutoTokenizer
- from functools import lru_cache
- from difflib import SequenceMatcher

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

- # Constants
- MAX_TOKENS = 1800
- BATCH_SIZE = 1
- MAX_WORKERS = 2
- CHUNK_SIZE = 5
- MODEL_MAX_TOKENS = 131072
- MAX_TEXT_LENGTH = 500000
- MAX_ROWS_TO_PROCESS = 1000  # Limit for Excel/CSV rows
-
- # Persistent directory setup
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

@@ -37,12 +29,16 @@ model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
- os.makedirs(report_dir, exist_ok=True)

- os.environ.update({
-     "HF_HOME": model_cache_dir,
-     "TOKENIZERS_PARALLELISM": "false",
- })

current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
@@ -50,291 +46,295 @@ sys.path.insert(0, src_path)

from txagent.txagent import TxAgent

- # Initialize cache
cache = Cache(file_cache_dir, size_limit=10 * 1024**3)

- @lru_cache(maxsize=1)
- def get_tokenizer():
-     return AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
-
def sanitize_utf8(text: str) -> str:
    return text.encode("utf-8", "ignore").decode("utf-8")

def file_hash(path: str) -> str:
-     hash_md5 = hashlib.md5()
    with open(path, "rb") as f:
-         for chunk in iter(lambda: f.read(4096), b""):
-             hash_md5.update(chunk)
-     return hash_md5.hexdigest()
-
- def extract_pdf_page(page, tokenizer, max_tokens=MAX_TOKENS) -> List[str]:
-     try:
-         text = page.extract_text() or ""
-         text = sanitize_utf8(text)
-         if len(text) > MAX_TEXT_LENGTH // 10:
-             text = text[:MAX_TEXT_LENGTH // 10]
-
-         tokens = tokenizer.encode(text, add_special_tokens=False)
-         if len(tokens) > max_tokens:
-             chunks = []
-             current_chunk = []
-             current_length = 0
-             for token in tokens:
-                 if current_length + 1 > max_tokens:
-                     chunks.append(tokenizer.decode(current_chunk))
-                     current_chunk = [token]
-                     current_length = 1
-                 else:
-                     current_chunk.append(token)
-                     current_length += 1
-             if current_chunk:
-                 chunks.append(tokenizer.decode(current_chunk))
-             return chunks
-         return [text]
-     except Exception as e:
-         logger.warning(f"Error extracting page {page.page_number}: {str(e)}")
-         return []

- def extract_all_pages(file_path: str) -> List[str]:
    try:
-         tokenizer = get_tokenizer()
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
            if total_pages == 0:
-                 return ["PDF appears to be empty"]
-
-             results = []
-             for i in range(0, min(total_pages, 50)):  # Limit to first 50 pages
-                 try:
-                     page = pdf.pages[i]
-                     chunks = extract_pdf_page(page, tokenizer)
-                     for chunk in chunks:
-                         results.append(f"=== Page {i+1} ===\n{chunk}")
-                 except Exception as e:
-                     logger.warning(f"Error processing page {i+1}: {str(e)}")
-                     continue
-
-             return results if results else ["Could not extract text from PDF"]
    except Exception as e:
-         logger.error(f"PDF processing error: {e}")
-         return [f"PDF processing error: {str(e)}"]
-
- def excel_to_json(file_path: str) -> List[Dict]:
-     engines = ['openpyxl', 'xlrd']
-     for engine in engines:
-         try:
-             with pd.ExcelFile(file_path, engine=engine) as excel_file:
-                 sheets = excel_file.sheet_names
-                 if not sheets:
-                     return [{"error": "No sheets found"}]
-
-                 results = []
-                 for sheet_name in sheets[:3]:  # Limit to first 3 sheets
-                     try:
-                         df = pd.read_excel(
-                             excel_file,
-                             sheet_name=sheet_name,
-                             header=None,
-                             dtype=str,
-                             na_filter=False,
-                             nrows=MAX_ROWS_TO_PROCESS  # Limit rows
-                         )
-                         if not df.empty:
-                             rows = df.head(MAX_ROWS_TO_PROCESS).values.tolist()
-                             results.append({
-                                 "filename": os.path.basename(file_path),
-                                 "sheet": sheet_name,
-                                 "rows": rows,
-                                 "type": "excel"
-                             })
-                     except Exception as e:
-                         logger.warning(f"Error processing sheet {sheet_name}: {str(e)}")
-                         continue
-
-                 return results if results else [{"error": "No readable data found"}]
-         except Exception as e:
-             logger.warning(f"Excel engine {engine} failed: {str(e)}")
-             continue
-
-     return [{"error": "Could not process Excel file with any engine"}]

- def csv_to_json(file_path: str) -> List[Dict]:
    try:
-         df = pd.read_csv(
-             file_path,
-             header=None,
-             dtype=str,
-             encoding_errors='replace',
-             on_bad_lines='skip',
-             nrows=MAX_ROWS_TO_PROCESS  # Limit rows
-         )
-         if df.empty:
-             return [{"error": "CSV file is empty"}]
-
-         return [{
-             "filename": os.path.basename(file_path),
-             "rows": df.values.tolist(),
-             "type": "csv"
-         }]
-     except Exception as e:
-         logger.error(f"CSV processing error: {e}")
-         return [{"error": f"CSV processing error: {str(e)}"}]

- def process_file_cached(file_path: str, file_type: str) -> List[Dict]:
-     try:
-         logger.info(f"Processing {file_type} file: {os.path.basename(file_path)}")
-
        if file_type == "pdf":
-             chunks = extract_all_pages(file_path)
-             return [{
-                 "filename": os.path.basename(file_path),
-                 "content": chunk,
-                 "type": "pdf"
-             } for chunk in chunks]
-
-         elif file_type in ["xls", "xlsx"]:
-             return excel_to_json(file_path)
-
        elif file_type == "csv":
-             return csv_to_json(file_path)
-
-         return [{"error": f"Unsupported file type: {file_type}"}]
    except Exception as e:
-         logger.error(f"Error processing file: {e}")
-         return [{"error": f"Error processing file: {str(e)}"}]

def clean_response(text: str) -> str:
-     if not text:
-         return ""
-
-     patterns = [
-         (re.compile(r"\[.*?\]|\bNone\b", re.IGNORECASE), ""),
-         (re.compile(r"\s+"), " "),
-     ]
-
-     for pattern, repl in patterns:
-         text = pattern.sub(repl, text)
-
-     return text.strip()
-
- @lru_cache(maxsize=1)
def init_agent():
    logger.info("Initializing model...")
-
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-         tool_files_dict={"new_tool": os.path.join(tool_cache_dir, "new_tool.json")},
        force_finish=True,
        enable_checker=False,
        step_rag_num=4,
        seed=100,
    )
    agent.init_model()
    logger.info("Agent Ready")
    return agent

def create_ui(agent):
-     PROMPT_TEMPLATE = """
- Analyze this patient record excerpt for missed diagnoses (limit response to 500 tokens):
{chunk}
"""

-     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
-
-         with gr.Row():
-             with gr.Column(scale=3):
-                 chatbot = gr.Chatbot(label="Analysis", height=500, type="messages")
-                 msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
-                 send_btn = gr.Button("Analyze", variant="primary")
-                 file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="single")
-
-             with gr.Column(scale=1):
-                 final_summary = gr.Markdown("## Summary")
-                 status = gr.Textbox(label="Status", interactive=False)
-
-         def analyze(message: str, history: List[Dict], file_obj) -> tuple:
            try:
-                 if not file_obj:
-                     return history, "Please upload a file first", "No file uploaded"
-
-                 file_path = file_obj.name
-                 file_type = os.path.splitext(file_path)[-1].lower().replace(".", "")
-                 history.append({"role": "user", "content": message})
-
-                 # Process file
-                 processed = process_file_cached(file_path, file_type)
-                 if "error" in processed[0]:
-                     history.append({"role": "assistant", "content": processed[0]["error"]})
-                     return history, processed[0]["error"], "File processing failed"
-
-                 # Prepare chunks
-                 chunks = []
-                 for item in processed:
-                     if "content" in item:
-                         chunks.append(item["content"])
-                     elif "rows" in item:
-                         rows_text = "\n".join([", ".join(map(str, row)) for row in item["rows"][:100]])
-                         chunks.append(f"=== {item.get('sheet', 'Data')} ===\n{rows_text}")
-
-                 if not chunks:
-                     history.append({"role": "assistant", "content": "No processable content found."})
-                     return history, "No processable content found", "Content extraction failed"
-
-                 # Analyze each chunk
-                 responses = []
-                 for i, chunk in enumerate(chunks[:5]):
-                     try:
-                         prompt = PROMPT_TEMPLATE.format(chunk=chunk[:5000])
-                         response = agent.run_quick_summary(prompt, 0.2, 256, 500)
-                         cleaned = clean_response(response)
-                         if cleaned:
-                             responses.append(f"Analysis {i+1}:\n{cleaned}")
-                     except Exception as e:
-                         logger.warning(f"Error analyzing chunk {i+1}: {str(e)}")
-                         continue
-
-                 if not responses:
-                     history.append({"role": "assistant", "content": "No valid analysis generated."})
-                     return history, "No valid analysis generated", "Analysis failed"
-
-                 summary = "\n\n".join(responses)
-                 history.append({"role": "assistant", "content": summary})
-                 return history, "Analysis completed", "Success"

-             except Exception as e:
-                 logger.error(f"Analysis error: {e}")
-                 history.append({"role": "assistant", "content": f"Error: {str(e)}"})
-                 return history, f"Error: {str(e)}", "Failed"
-             finally:
-                 torch.cuda.empty_cache()
-                 gc.collect()
-
-         send_btn.click(
-             analyze,
-             inputs=[msg_input, chatbot, file_upload],
-             outputs=[chatbot, final_summary, status]
-         )

-         msg_input.submit(
-             analyze,
-             inputs=[msg_input, chatbot, file_upload],
-             outputs=[chatbot, final_summary, status]
-         )

    return demo

-
if __name__ == "__main__":
    try:
        agent = init_agent()
        demo = create_ui(agent)
-         demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
-     except Exception as e:
-         logger.error(f"Fatal error: {e}")
-         raise
 
import os
import pandas as pd
import pdfplumber
+ import json
import gradio as gr
+ from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
+ import psutil
+ import subprocess
import logging
import torch
import gc
from diskcache import Cache
+ import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

+ # Persistent directory
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
+ vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
+
+ for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
+     os.makedirs(directory, exist_ok=True)

+ os.environ["HF_HOME"] = model_cache_dir
+ os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
+ os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))


from txagent.txagent import TxAgent

+ # Initialize cache with 10GB limit
cache = Cache(file_cache_dir, size_limit=10 * 1024**3)

def sanitize_utf8(text: str) -> str:
    return text.encode("utf-8", "ignore").decode("utf-8")

def file_hash(path: str) -> str:
    with open(path, "rb") as f:
+         return hashlib.md5(f.read()).hexdigest()

+ def extract_all_pages(file_path: str, progress_callback=None) -> str:
    try:
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
            if total_pages == 0:
+                 return ""
+
+             batch_size = 10
+             batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
+             text_chunks = [""] * total_pages
+             processed_pages = 0
+
+             def extract_batch(start: int, end: int) -> List[tuple]:
+                 results = []
+                 with pdfplumber.open(file_path) as pdf:
+                     for page in pdf.pages[start:end]:
+                         page_num = page.page_number - 1  # 0-based index of this page in the document
+                         page_text = page.extract_text() or ""
+                         results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
+                 return results
+
+             with ThreadPoolExecutor(max_workers=6) as executor:
+                 futures = [executor.submit(extract_batch, start, end) for start, end in batches]
+                 for future in as_completed(futures):
+                     for page_num, text in future.result():
+                         text_chunks[page_num] = text
+                     processed_pages += batch_size
+                     if progress_callback:
+                         progress_callback(min(processed_pages, total_pages), total_pages)
+
+             return "\n\n".join(filter(None, text_chunks))
    except Exception as e:
+         logger.error("PDF processing error: %s", e)
+         return f"PDF processing error: {str(e)}"

+ def convert_file_to_json(file_path: str, file_type: str, progress_callback=None) -> str:
    try:
+         file_h = file_hash(file_path)
+         cache_key = f"{file_h}_{file_type}"
+         if cache_key in cache:
+             return cache[cache_key]

        if file_type == "pdf":
+             text = extract_all_pages(file_path, progress_callback)
+             result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
        elif file_type == "csv":
+             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
+                              skip_blank_lines=False, on_bad_lines="skip")
+             content = df.fillna("").astype(str).values.tolist()
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+         elif file_type in ["xls", "xlsx"]:
+             try:
+                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
+             except Exception:
+                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
+             content = df.fillna("").astype(str).values.tolist()
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+         else:
+             result = json.dumps({"error": f"Unsupported file type: {file_type}"})
+
+         cache[cache_key] = result
+         return result
    except Exception as e:
+         logger.error("Error processing %s: %s", os.path.basename(file_path), e)
+         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
+
+ def log_system_usage(tag=""):
+     try:
+         cpu = psutil.cpu_percent(interval=1)
+         mem = psutil.virtual_memory()
+         logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB", tag, cpu, mem.used // (1024**2), mem.total // (1024**2))
+         result = subprocess.run(
+             ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
+             capture_output=True, text=True
+         )
+         if result.returncode == 0:
+             used, total, util = result.stdout.strip().split(", ")
+             logger.info("[%s] GPU: %sMB / %sMB | Utilization: %s%%", tag, used, total, util)
+     except Exception as e:
+         logger.error("[%s] GPU/CPU monitor failed: %s", tag, e)

def clean_response(text: str) -> str:
+     text = sanitize_utf8(text)
+     text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
+     text = re.sub(r"\n{3,}", "\n\n", text)
+     text = re.sub(r"[^\n#\-\*\w\s\.\,\:\(\)]+", "", text)
+
+     sections = {}
+     current_section = None
+     lines = text.splitlines()
+     for line in lines:
+         line = line.strip()
+         if not line:
+             continue
+         section_match = re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line)
+         if section_match:
+             current_section = section_match.group(1)
+             if current_section not in sections:
+                 sections[current_section] = []
+             continue
+         finding_match = re.match(r"-\s*.+", line)
+         if finding_match and current_section and not re.match(r"-\s*No issues identified", line):
+             sections[current_section].append(line)
+
+     cleaned = []
+     for heading, findings in sections.items():
+         if findings:
+             cleaned.append(f"### {heading}\n" + "\n".join(findings))
+
+     text = "\n\n".join(cleaned).strip()
+     return text if text else ""
+
+ def summarize_findings(combined_response: str) -> str:
+     if not combined_response or all("No oversights identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
+         return "### Summary of Clinical Oversights\nNo critical oversights identified in the provided records."
+
+     sections = {}
+     lines = combined_response.splitlines()
+     current_section = None
+     for line in lines:
+         line = line.strip()
+         if not line:
+             continue
+         section_match = re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line)
+         if section_match:
+             current_section = section_match.group(1)
+             if current_section not in sections:
+                 sections[current_section] = []
+             continue
+         finding_match = re.match(r"-\s*(.+)", line)
+         if finding_match and current_section:
+             sections[current_section].append(finding_match.group(1))
+
+     summary_lines = []
+     for heading, findings in sections.items():
+         if findings:
+             summary = f"- **{heading}**: {'; '.join(findings[:2])}. Risks: {heading.lower()} may lead to adverse outcomes. Recommend: urgent review and specialist referral."
+             summary_lines.append(summary)
+
+     if not summary_lines:
+         return "### Summary of Clinical Oversights\nNo critical oversights identified."
+
+     return "### Summary of Clinical Oversights\n" + "\n".join(summary_lines)
+
def init_agent():
    logger.info("Initializing model...")
+     log_system_usage("Before Load")
+     default_tool_path = os.path.abspath("data/new_tool.json")
+     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+     if not os.path.exists(target_tool_path):
+         shutil.copy(default_tool_path, target_tool_path)
+
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+         tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=False,
        step_rag_num=4,
        seed=100,
+         additional_default_tools=[],
    )
    agent.init_model()
+     log_system_usage("After Load")
    logger.info("Agent Ready")
    return agent

def create_ui(agent):
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
+         chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
+         final_summary = gr.Markdown(label="Summary of Clinical Oversights")
+         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+         msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
+         send_btn = gr.Button("Analyze", variant="primary")
+         download_output = gr.File(label="Download Full Report")
+         progress_bar = gr.Progress()
+
+         prompt_template = """
+ Analyze the patient record excerpt for clinical oversights. Provide a concise, evidence-based summary in markdown with findings grouped under headings (e.g., 'Missed Diagnoses'). For each finding, include clinical context, risks, and recommendations. Output only markdown bullet points under headings. If no issues, state "No issues identified".
+ Patient Record Excerpt (Chunk {0} of {1}):
{chunk}
"""

+         def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
+             history.append({"role": "user", "content": message})
+             yield history, None, ""
+
+             extracted = ""
+             file_hash_value = ""
+             if files:
+                 def update_extraction_progress(current, total):
+                     progress(current / total, desc=f"Extracting text... Page {current}/{total}")
+                     return history, None, ""
+
+                 with ThreadPoolExecutor(max_workers=6) as executor:
+                     futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
+                     results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
+                     extracted = "\n".join(results)
+                 file_hash_value = file_hash(files[0].name) if files else ""
+
+                 history.append({"role": "assistant", "content": "✅ Text extraction complete."})
+                 yield history, None, ""
+
+             chunk_size = 6000
+             chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
+             combined_response = ""
+             batch_size = 2

            try:
+                 for batch_idx in range(0, len(chunks), batch_size):
+                     batch_chunks = chunks[batch_idx:batch_idx + batch_size]
+                     batch_prompts = [prompt_template.format(batch_idx + i + 1, len(chunks), chunk=chunk[:4000]) for i, chunk in enumerate(batch_chunks)]
+                     batch_responses = []
+
+                     progress((batch_idx + 1) / len(chunks), desc=f"Analyzing chunks {batch_idx + 1}-{min(batch_idx + batch_size, len(chunks))}/{len(chunks)}")
+
+                     with ThreadPoolExecutor(max_workers=len(batch_chunks)) as executor:
+                         futures = [executor.submit(agent.run_gradio_chat, prompt, [], 0.2, 512, 2048, False, []) for prompt in batch_prompts]
+                         for future in as_completed(futures):
+                             chunk_response = ""
+                             for chunk_output in future.result():
+                                 if chunk_output is None:
+                                     continue
+                                 if isinstance(chunk_output, list):
+                                     for m in chunk_output:
+                                         if hasattr(m, 'content') and m.content:
+                                             cleaned = clean_response(m.content)
+                                             if cleaned and re.search(r"###\s*\w+", cleaned):
+                                                 chunk_response += cleaned + "\n\n"
+                                 elif isinstance(chunk_output, str) and chunk_output.strip():
+                                     cleaned = clean_response(chunk_output)
+                                     if cleaned and re.search(r"###\s*\w+", cleaned):
+                                         chunk_response += cleaned + "\n\n"
+                             batch_responses.append(chunk_response)
+                             torch.cuda.empty_cache()
+                             gc.collect()
+
+                     for chunk_idx, chunk_response in enumerate(batch_responses, batch_idx + 1):
+                         if chunk_response:
+                             combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
+                         else:
+                             combined_response += f"--- Analysis for Chunk {chunk_idx} ---\nNo oversights identified for this chunk.\n\n"
+                     history[-1] = {"role": "assistant", "content": combined_response.strip()}
+                     yield history, None, ""
+
+                 if combined_response.strip() and not all("No oversights identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
+                     history[-1]["content"] = combined_response.strip()
+                 else:
+                     history.append({"role": "assistant", "content": "No oversights identified in the provided records."})

+                 summary = summarize_findings(combined_response)
+                 report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
+                 if report_path:
+                     with open(report_path, "w", encoding="utf-8") as f:
+                         f.write(combined_response + "\n\n" + summary)
+                 yield history, report_path if report_path and os.path.exists(report_path) else None, summary

+             except Exception as e:
+                 logger.error("Analysis error: %s", e)
+                 history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
+                 yield history, None, f"### Summary of Clinical Oversights\nError occurred during analysis: {str(e)}"

+         send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
+         msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
    return demo

if __name__ == "__main__":
    try:
+         logger.info("Launching app...")
        agent = init_agent()
        demo = create_ui(agent)
+         demo.queue(api_open=False).launch(
            server_name="0.0.0.0",
            server_port=7860,
+             show_error=True,
+             allowed_paths=[report_dir],
            share=False
        )
+     finally:
+         if torch.distributed.is_initialized():
+             torch.distributed.destroy_process_group()
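
For reference, a minimal standalone sketch of the batched page-extraction pattern the new extract_all_pages() relies on: pages are grouped into ranges, each range is extracted in its own worker thread, and completion is reported as batches finish. This is not part of the commit; "sample.pdf" is a hypothetical input path and pdfplumber must be installed.

# Standalone sketch (not part of the commit) of the batched extraction pattern.
from concurrent.futures import ThreadPoolExecutor, as_completed
import pdfplumber

def extract_range(path, start, end):
    # Re-open the PDF inside each worker so every thread has its own file handle.
    with pdfplumber.open(path) as pdf:
        return [(p.page_number - 1, p.extract_text() or "") for p in pdf.pages[start:end]]

def extract_pdf(path, batch_size=10, workers=6):
    with pdfplumber.open(path) as pdf:
        total = len(pdf.pages)
    ranges = [(i, min(i + batch_size, total)) for i in range(0, total, batch_size)]
    chunks = [""] * total  # one slot per page, filled as batches complete
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(extract_range, path, s, e) for s, e in ranges]
        for done, fut in enumerate(as_completed(futures), 1):
            for page_num, text in fut.result():
                chunks[page_num] = text
            print(f"extracted {done}/{len(ranges)} page batches")
    return "\n\n".join(filter(None, chunks))

if __name__ == "__main__":
    print(extract_pdf("sample.pdf")[:500])  # "sample.pdf" is a placeholder path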