Ali2206 committed
Commit 3fa2049 · verified · 1 parent: 5eb9bf1

Update app.py

Files changed (1)
1. app.py +87 -99
app.py CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
 import json
 import gradio as gr
 from typing import List
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor, as_completed  # as_completed is still called by the new extraction loop
 import hashlib
 import shutil
 import re
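The `as_completed` helper must stay in this import: the rewritten page-extraction loop further down iterates `as_completed(futures)`, so dropping it would raise `NameError` on the first PDF. A minimal, self-contained demo of the submit/as_completed pattern:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

# Results arrive in completion order, not submission order.
with ThreadPoolExecutor(max_workers=2) as ex:
    futures = [ex.submit(pow, 2, n) for n in range(4)]
    results = sorted(f.result() for f in as_completed(futures))

assert results == [1, 2, 4, 8]
```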
@@ -27,7 +27,7 @@ except ImportError:
     HAS_PYPDFIUM2 = False
 import pdfplumber
 
-# Configure logging and suppress warnings
+# Configure logging
 logging.basicConfig(level=logging.INFO)
 logging.getLogger("pdfminer").setLevel(logging.ERROR)
 logger = logging.getLogger(__name__)
@@ -67,76 +67,58 @@ def file_hash(path: str) -> str:
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()
 
-async def extract_all_pages_async(file_path: str, progress_callback=None, use_ocr=False) -> str:
+async def extract_all_pages_async(file_path: str, progress_callback=None, force_ocr=False) -> str:
     try:
+        extracted_text = ""
+        total_pages = 0
+        text_chunks = []
+
         if HAS_PYPDFIUM2:
             pdf = pdfium.PdfDocument(file_path)
             total_pages = len(pdf)
             if total_pages == 0:
                 return ""
 
-            batch_size = 5
-            batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
-            text_chunks = [""] * total_pages
-            processed_pages = 0
-
-            def extract_batch(start: int, end: int) -> List[tuple]:
-                results = []
-                for i in range(start, end):
-                    page = pdf[i]
-                    text = page.get_textpage().get_text_range() or ""
-                    if not text.strip() and use_ocr and 'pytesseract' in sys.modules:
-                        bitmap = page.render(scale=2).to_pil()
-                        text = pytesseract.image_to_string(bitmap, lang="eng")
-                    results.append((i, f"=== Page {i + 1} ===\n{text.strip()}"))
-                return results
-
-            loop = asyncio.get_event_loop()
+            def extract_page(i):
+                page = pdf[i]
+                text = page.get_textpage().get_text_range() or ""
+                if (not text.strip() or len(text) < 100) and force_ocr and 'pytesseract' in sys.modules:
+                    logger.info("Falling back to OCR for page %d", i + 1)
+                    bitmap = page.render(scale=2).to_pil()
+                    text = pytesseract.image_to_string(bitmap, lang="eng")
+                return (i, f"=== Page {i + 1} ===\n{text.strip()}")
+
             with ThreadPoolExecutor(max_workers=4) as executor:
-                futures = [loop.run_in_executor(executor, extract_batch, start, end) for start, end in batches]
-                for future in await asyncio.gather(*futures):
-                    for page_num, text in future:
-                        text_chunks[page_num] = text
-                        logger.debug("Page %d extracted: %s...", page_num + 1, text[:50])
-                    processed_pages += batch_size
+                futures = [executor.submit(extract_page, i) for i in range(total_pages)]
+                for future in as_completed(futures):
+                    page_num, text = future.result()
+                    text_chunks.append((page_num, text))
+                    logger.debug("Page %d extracted: %s...", page_num + 1, text[:50])
                     if progress_callback:
-                        progress_callback(min(processed_pages, total_pages), total_pages)
+                        progress_callback(page_num + 1, total_pages)
 
+            text_chunks.sort(key=lambda x: x[0])
+            extracted_text = "\n\n".join(chunk[1] for chunk in text_chunks if chunk[1].strip())
             pdf.close()
         else:
-            # Fallback to pdfplumber
             with pdfplumber.open(file_path) as pdf:
                 total_pages = len(pdf.pages)
                 if total_pages == 0:
                     return ""
 
-                batch_size = 5
-                batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
-                text_chunks = [""] * total_pages
-                processed_pages = 0
-
-                def extract_batch(start: int, end: int) -> List[tuple]:
-                    results = []
-                    with pdfplumber.open(file_path) as pdf:
-                        for i in range(start, end):
-                            page = pdf.pages[i]
-                            text = page.extract_text() or ""
-                            results.append((i, f"=== Page {i + 1} ===\n{text.strip()}"))
-                    return results
-
-                loop = asyncio.get_event_loop()
-                with ThreadPoolExecutor(max_workers=4) as executor:
-                    futures = [loop.run_in_executor(executor, extract_batch, start, end) for start, end in batches]
-                    for future in await asyncio.gather(*futures):
-                        for page_num, text in future:
-                            text_chunks[page_num] = text
-                            logger.debug("Page %d extracted: %s...", page_num + 1, text[:50])
-                        processed_pages += batch_size
-                        if progress_callback:
-                            progress_callback(min(processed_pages, total_pages), total_pages)
-
-        extracted_text = "\n\n".join(filter(None, text_chunks))
+                for i, page in enumerate(pdf.pages):
+                    text = page.extract_text() or ""
+                    text_chunks.append((i, f"=== Page {i + 1} ===\n{text.strip()}"))
+                    logger.debug("Page %d extracted: %s...", i + 1, text[:50])
+                    if progress_callback:
+                        progress_callback(i + 1, total_pages)
+
+                extracted_text = "\n\n".join(chunk[1] for chunk in text_chunks if chunk[1].strip())
+
         logger.info("Extracted %d pages, total length: %d chars", total_pages, len(extracted_text))
+        if len(extracted_text) < 1000 and not force_ocr and HAS_PYPDFIUM2 and 'pytesseract' in sys.modules:
+            logger.info("Text too short, retrying with OCR")
+            return await extract_all_pages_async(file_path, progress_callback, force_ocr=True)
         return extracted_text
     except Exception as e:
         logger.error("PDF processing error: %s", e)
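For reference, a stripped-down sketch of the per-page pattern this hunk introduces: one task per page, results gathered as they complete, page order restored by sorting on the index. The OCR fallback is omitted for brevity. One caveat: pypdfium2's documentation warns that PDFium itself is not thread-safe, so sharing one `PdfDocument` across pool workers (as here) is a risk a stricter variant would avoid by serializing access or opening one document per worker.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

import pypdfium2 as pdfium

def extract_pages(path: str, max_workers: int = 4) -> str:
    pdf = pdfium.PdfDocument(path)
    try:
        def extract_page(i):
            text = pdf[i].get_textpage().get_text_range() or ""
            return i, f"=== Page {i + 1} ===\n{text.strip()}"

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(extract_page, i) for i in range(len(pdf))]
            chunks = [f.result() for f in as_completed(futures)]

        chunks.sort(key=lambda pair: pair[0])  # completion order is arbitrary
        return "\n\n".join(text for _, text in chunks if text.strip())
    finally:
        pdf.close()
```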
@@ -151,10 +133,7 @@ def convert_file_to_json(file_path: str, file_type: str, progress_callback=None)
         return cache[cache_key]
 
     if file_type == "pdf":
-        text = asyncio.run(extract_all_pages_async(file_path, progress_callback, use_ocr=False))
-        if not text.strip() or "PDF processing error" in text:
-            logger.info("Retrying extraction with OCR for %s", file_path)
-            text = asyncio.run(extract_all_pages_async(file_path, progress_callback, use_ocr=True))
+        text = asyncio.run(extract_all_pages_async(file_path, progress_callback, force_ocr=False))
         result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
     elif file_type == "csv":
         df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
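The retry that used to live here in `convert_file_to_json` now sits inside the extractor, which re-invokes itself once with `force_ocr=True` when fewer than 1000 characters come back, so the caller makes a single `asyncio.run` call. A schematic of that control flow (the body is a stand-in, not the real extraction):

```python
import asyncio

async def extract_all_pages(path: str, force_ocr: bool = False) -> str:
    text = ""  # stand-in for the real per-page extraction (OCR when force_ocr is set)
    if len(text) < 1000 and not force_ocr:
        # Single self-recursive retry; force_ocr=True guarantees termination.
        return await extract_all_pages(path, force_ocr=True)
    return text

print(asyncio.run(extract_all_pages("record.pdf")))  # retries once, then returns ""
```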
@@ -199,10 +178,12 @@ def clean_response(text: str) -> str:
     text = text.replace("\n\n\n", "\n\n")
     sections = {}
     current_section = None
+    seen_lines = set()
     for line in text.splitlines():
         line = line.strip()
-        if not line:
+        if not line or line in seen_lines:
             continue
+        seen_lines.add(line)
         section_match = re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line)
         if section_match:
             current_section = section_match.group(1)
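The `seen_lines` set added here is a standard order-preserving de-duplication, presumably to drop bullets the model repeats across streamed outputs. In isolation:

```python
def dedupe_lines(text: str) -> str:
    seen = set()
    kept = []
    for line in text.splitlines():
        line = line.strip()
        if line and line not in seen:
            seen.add(line)
            kept.append(line)
    return "\n".join(kept)

assert dedupe_lines("- finding A\n- finding A\n- finding B") == "- finding A\n- finding B"
```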
@@ -213,13 +194,21 @@ def clean_response(text: str) -> str:
     cleaned = [f"### {heading}\n" + "\n".join(findings) for heading, findings in sections.items() if findings]
     result = "\n\n".join(cleaned).strip()
     logger.debug("Cleaned response length: %d chars", len(result))
-    return result or "No issues identified"
+    return result or "No oversights identified"
 
-def summarize_findings(combined_response: str) -> str:
-    if not combined_response or all("No oversights identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
-        return "### Summary of Clinical Oversights\nNo critical oversights identified in the provided records."
-    sections = {}
+def summarize_findings(all_responses: List[str]) -> str:
+    combined_response = "\n\n".join(all_responses)
+    if not combined_response or all("no oversights identified" in resp.lower() for resp in all_responses):  # lowercase literal to match resp.lower()
+        return "### Comprehensive Clinical Oversight Summary\nNo critical oversights were identified across the provided patient records after thorough analysis."
+
+    sections = {
+        "Missed Diagnoses": [],
+        "Medication Conflicts": [],
+        "Incomplete Assessments": [],
+        "Urgent Follow-up": []
+    }
     current_section = None
+    seen_findings = set()
     for line in combined_response.splitlines():
         line = line.strip()
         if not line:
@@ -227,15 +216,19 @@ def summarize_findings(combined_response: str) -> str:
         section_match = re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line)
         if section_match:
             current_section = section_match.group(1)
-            sections.setdefault(current_section, [])
             continue
-        if current_section and line.startswith("- "):
-            sections[current_section].append(line[2:])
-    summary_lines = [
-        f"- **{heading}**: {'; '.join(findings[:1])}. Risks: potential adverse outcomes. Recommend: urgent review."
-        for heading, findings in sections.items() if findings
-    ]
-    result = "### Summary of Clinical Oversights\n" + "\n".join(summary_lines) if summary_lines else "### Summary of Clinical Oversights\nNo critical oversights identified."
+        if current_section and line.startswith("- ") and line not in seen_findings:
+            sections[current_section].append(line)
+            seen_findings.add(line)
+
+    summary_lines = []
+    for heading, findings in sections.items():
+        if findings:
+            summary_lines.append(f"### {heading}")
+            for finding in findings:
+                summary_lines.append(f"{finding}\n - **Risks**: Potential adverse outcomes if not addressed.\n - **Recommendation**: Immediate clinical review and follow-up.")
+
+    result = "### Comprehensive Clinical Oversight Summary\n" + "\n".join(summary_lines) if summary_lines else "### Comprehensive Clinical Oversight Summary\nNo critical oversights identified."
     logger.debug("Summary length: %d chars", len(result))
     return result
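`summarize_findings` now takes the per-chunk responses as a `List[str]` instead of one pre-joined string with `--- Analysis for Chunk ---` markers, and buckets findings into the four fixed headings. A self-contained mirror of that parsing loop, using the same regex as the diff:

```python
import re

SECTION_RE = re.compile(
    r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)"
)

def bucket_findings(responses):
    # Headings switch the active bucket, "- " lines land in it,
    # and a seen-set drops duplicates across chunks.
    sections = {name: [] for name in (
        "Missed Diagnoses", "Medication Conflicts",
        "Incomplete Assessments", "Urgent Follow-up")}
    current, seen = None, set()
    for line in "\n\n".join(responses).splitlines():
        line = line.strip()
        if not line:
            continue
        match = SECTION_RE.match(line)
        if match:
            current = match.group(1)
            continue
        if current and line.startswith("- ") and line not in seen:
            sections[current].append(line)
            seen.add(line)
    return sections

buckets = bucket_findings(["### Urgent Follow-up\n- Repeat troponin not ordered."])
assert buckets["Urgent Follow-up"] == ["- Repeat troponin not ordered."]
```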
 
@@ -267,8 +260,8 @@ def init_agent():
 def create_ui(agent):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
-        chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
-        final_summary = gr.Markdown(label="Summary of Clinical Oversights")
+        chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages", visible=False)
+        final_summary = gr.Markdown(label="Comprehensive Clinical Oversight Summary")
         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
         msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
         send_btn = gr.Button("Analyze", variant="primary")
@@ -276,9 +269,9 @@ def create_ui(agent):
         progress_bar = gr.Progress()
 
         prompt_template = """
-Analyze the patient record excerpt for clinical oversights. Provide a concise, evidence-based summary in markdown with findings grouped under headings (e.g., 'Missed Diagnoses'). For each finding, include clinical context, risks, and recommendations. Output only markdown bullet points under headings. If no issues, state "No issues identified".
+Analyze the patient record excerpt for clinical oversights. Provide a detailed, evidence-based summary in markdown with findings grouped under headings: Missed Diagnoses, Medication Conflicts, Incomplete Assessments, Urgent Follow-up. For each finding, include clinical context, risks, and recommendations. Output only markdown bullet points under headings. If no issues, state "No oversights identified" once.
 
-Patient Record Excerpt (Chunk {0} of {1}):
+Patient Record Excerpt:
 {chunk}
 """
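The template is now filled through a single named placeholder, so `str.format` no longer receives positional chunk indices. One pitfall with `str.format`-based prompt templates: any literal brace in the template text must be doubled, or formatting raises. For example:

```python
template = "Patient Record Excerpt:\n{chunk}"
prompt = template.format(chunk="HbA1c 9.2% with no diabetes dx on problem list"[:2000])
assert prompt.endswith("problem list")

# Literal braces must be escaped as {{ }} in a .format template:
json_hint = 'Return JSON like {{"finding": "..."}}\n{chunk}'.format(chunk="...")
assert json_hint.startswith('Return JSON like {"finding"')
```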
 
@@ -302,69 +295,64 @@ Patient Record Excerpt (Chunk {0} of {1}):
                 yield history, None, ""
             logger.info("Extracted text length: %d chars", len(extracted))
 
-            chunk_size = 3000  # Adjusted for balance
-            chunks = [extracted[i:i + chunk_size] for i in range(0, max(len(extracted), 1), chunk_size)]
-            if not chunks:
-                chunks = [""]  # Ensure at least one chunk
+            chunk_size = 3000
+            chunks = [extracted[i:i + chunk_size] for i in range(0, max(len(extracted), 1), chunk_size)] or [""]
             logger.info("Created %d chunks", len(chunks))
-            combined_response = ""
+            for i, chunk in enumerate(chunks):
+                logger.debug("Chunk %d content: %s...", i + 1, chunk[:100])
+            all_responses = []
             batch_size = 2
 
             try:
                 for batch_idx in range(0, len(chunks), batch_size):
                     batch_chunks = chunks[batch_idx:batch_idx + batch_size]
-                    batch_prompts = [prompt_template.format(i + 1, len(chunks), chunk=chunk[:2000]) for i, chunk in enumerate(batch_chunks)]
+                    batch_prompts = [prompt_template.format(chunk=chunk[:2000]) for chunk in batch_chunks]
                     batch_responses = []
 
                     progress((batch_idx + 1) / len(chunks), desc=f"Analyzing chunks {batch_idx + 1}-{min(batch_idx + batch_size, len(chunks))}/{len(chunks)}")
 
                     async def process_chunk(prompt):
                         chunk_response = ""
+                        raw_outputs = []
                         for chunk_output in agent.run_gradio_chat(
-                            message=prompt, history=[], temperature=0.2, max_new_tokens=256, max_token=1024, call_agent=False, conversation=[]
+                            message=prompt, history=[], temperature=0.2, max_new_tokens=512, max_token=1024, call_agent=False, conversation=[]
                         ):
                             if chunk_output is None:
                                 continue
                             if isinstance(chunk_output, list):
                                 for m in chunk_output:
                                     if hasattr(m, 'content') and m.content:
+                                        raw_outputs.append(m.content)
                                         cleaned = clean_response(m.content)
                                         chunk_response += cleaned + "\n\n"
                             elif isinstance(chunk_output, str) and chunk_output.strip():
+                                raw_outputs.append(chunk_output)
                                 cleaned = clean_response(chunk_output)
                                 chunk_response += cleaned + "\n\n"
+                        logger.debug("Raw outputs: %s", raw_outputs[:100])
                         logger.debug("Chunk response length: %d chars", len(chunk_response))
                         return chunk_response
 
                     futures = [process_chunk(prompt) for prompt in batch_prompts]
                     batch_responses = await asyncio.gather(*futures)
+                    all_responses.extend([resp.strip() for resp in batch_responses if resp.strip()])
                     torch.cuda.empty_cache()
                     gc.collect()
 
-                    for chunk_idx, chunk_response in enumerate(batch_responses, batch_idx + 1):
-                        if chunk_response.strip():
-                            combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
-                        else:
-                            combined_response += f"--- Analysis for Chunk {chunk_idx} ---\nNo oversights identified for this chunk.\n\n"
-                        history[-1] = {"role": "assistant", "content": combined_response.strip()}
-                        yield history, None, ""
-
-                if combined_response.strip() and not all("No oversights identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
-                    history[-1]["content"] = combined_response.strip()
-                else:
-                    history.append({"role": "assistant", "content": "No oversights identified in the provided records."})
 
-                summary = summarize_findings(combined_response)
+                summary = summarize_findings(all_responses)
+                history.append({"role": "assistant", "content": "Analysis complete. See summary below."})
+                yield history, None, summary
+
                 report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
                 if report_path:
                     with open(report_path, "w", encoding="utf-8") as f:
-                        f.write(combined_response + "\n\n" + summary)
+                        f.write(summary)
                 yield history, report_path if report_path and os.path.exists(report_path) else None, summary
 
             except Exception as e:
                 logger.error("Analysis error: %s", e)
                 history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
-                yield history, None, f"### Summary of Clinical Oversights\nError occurred during analysis: {str(e)}"
+                yield history, None, f"### Comprehensive Clinical Oversight Summary\nError occurred during analysis: {str(e)}"
 
         send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
         msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
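Finally, a compact model of the new batch flow in `analyze`: chunks go through `asyncio.gather` two at a time, and only non-empty cleaned responses are kept for the summary. Note that because `agent.run_gradio_chat` is iterated synchronously inside the coroutine (there is no `await` in that loop), `gather` here provides batching rather than real concurrency; true parallelism would need something like `loop.run_in_executor`. A runnable sketch with a stubbed agent call:

```python
import asyncio

async def process_chunk(prompt: str) -> str:
    await asyncio.sleep(0)  # stand-in for the agent call
    return f"### Missed Diagnoses\n- finding for: {prompt}"

async def run_batches(chunks, batch_size=2):
    all_responses = []
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i + batch_size]
        responses = await asyncio.gather(*(process_chunk(c) for c in batch))
        all_responses.extend(r.strip() for r in responses if r.strip())
    return all_responses

print(asyncio.run(run_batches(["chunk 1", "chunk 2", "chunk 3"])))
```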
 