Ali2206 committed (verified)
Commit 6741b3e · Parent(s): 9277e15

Update app.py

Files changed (1): app.py (+96 −87)

app.py CHANGED
@@ -1,7 +1,6 @@
 import sys
 import os
 import pandas as pd
-import pdfplumber
 import json
 import gradio as gr
 from typing import List
@@ -16,9 +15,15 @@ import torch
 import gc
 from diskcache import Cache
 import time
+import asyncio
+import pypdfium2 as pdfium
+import pytesseract
+from PIL import Image
+import io
 
-# Configure logging
+# Configure logging and suppress warnings
 logging.basicConfig(level=logging.INFO)
+logging.getLogger("pdfminer").setLevel(logging.ERROR)
 logger = logging.getLogger(__name__)
 
 # Persistent directory
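The new extraction path swaps pdfplumber for pypdfium2 plus a pytesseract OCR fallback. Worth noting: pytesseract is only a thin wrapper that shells out to the system `tesseract` binary, so the runtime image needs that binary available. A minimal startup sanity check could look like the sketch below (hypothetical helper, not part of this commit):

```python
# Hypothetical startup check for the new extraction dependencies (not part of this commit).
import shutil

import pypdfium2 as pdfium  # ships its own PDFium build, no system package needed
import pytesseract


def check_extraction_deps() -> None:
    # pytesseract only wraps the external `tesseract` binary, which must be on PATH.
    if shutil.which("tesseract") is None:
        raise RuntimeError("tesseract binary not found; the OCR fallback will fail")
    print("tesseract:", pytesseract.get_tesseract_version())


check_extraction_deps()
```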
@@ -56,37 +61,45 @@ def file_hash(path: str) -> str:
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()
 
-def extract_all_pages(file_path: str, progress_callback=None) -> str:
+async def extract_all_pages_async(file_path: str, progress_callback=None, use_ocr=False) -> str:
     try:
-        with pdfplumber.open(file_path) as pdf:
-            total_pages = len(pdf.pages)
-            if total_pages == 0:
-                return ""
+        pdf = pdfium.PdfDocument(file_path)
+        total_pages = len(pdf)
+        if total_pages == 0:
+            return ""
 
-        batch_size = 10
+        batch_size = 5
         batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
         text_chunks = [""] * total_pages
         processed_pages = 0
 
         def extract_batch(start: int, end: int) -> List[tuple]:
             results = []
-            with pdfplumber.open(file_path) as pdf:
-                for page in pdf.pages[start:end]:
-                    page_num = start + pdf.pages.index(page)
-                    page_text = page.extract_text() or ""
-                    results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
+            for i in range(start, end):
+                page = pdf[i]
+                text = page.get_textpage().get_text_range() or ""
+                if not text.strip() and use_ocr:
+                    # Fallback to OCR
+                    bitmap = page.render(scale=2).to_pil()
+                    text = pytesseract.image_to_string(bitmap, lang="eng")
+                results.append((i, f"=== Page {i + 1} ===\n{text.strip()}"))
             return results
 
-        with ThreadPoolExecutor(max_workers=6) as executor:
-            futures = [executor.submit(extract_batch, start, end) for start, end in batches]
-            for future in as_completed(futures):
-                for page_num, text in future.result():
+        loop = asyncio.get_event_loop()
+        with ThreadPoolExecutor(max_workers=4) as executor:
+            futures = [loop.run_in_executor(executor, extract_batch, start, end) for start, end in batches]
+            for future in await asyncio.gather(*futures):
+                for page_num, text in future:
                     text_chunks[page_num] = text
+                    logger.debug("Page %d extracted: %s...", page_num + 1, text[:50])
                 processed_pages += batch_size
                 if progress_callback:
                     progress_callback(min(processed_pages, total_pages), total_pages)
 
-        return "\n\n".join(filter(None, text_chunks))
+        pdf.close()
+        extracted_text = "\n\n".join(filter(None, text_chunks))
+        logger.info("Extracted %d pages, total length: %d chars", total_pages, len(extracted_text))
+        return extracted_text
     except Exception as e:
         logger.error("PDF processing error: %s", e)
        return f"PDF processing error: {str(e)}"
@@ -96,10 +109,15 @@ def convert_file_to_json(file_path: str, file_type: str, progress_callback=None)
         file_h = file_hash(file_path)
         cache_key = f"{file_h}_{file_type}"
         if cache_key in cache:
+            logger.info("Using cached extraction for %s", file_path)
             return cache[cache_key]
 
         if file_type == "pdf":
-            text = extract_all_pages(file_path, progress_callback)
+            # Try without OCR first, fallback to OCR if empty
+            text = asyncio.run(extract_all_pages_async(file_path, progress_callback, use_ocr=False))
+            if not text.strip() or "PDF processing error" in text:
+                logger.info("Retrying extraction with OCR for %s", file_path)
+                text = asyncio.run(extract_all_pages_async(file_path, progress_callback, use_ocr=True))
             result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
         elif file_type == "csv":
             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
@@ -117,6 +135,7 @@ def convert_file_to_json(file_path: str, file_type: str, progress_callback=None)
             result = json.dumps({"error": f"Unsupported file type: {file_type}"})
 
         cache[cache_key] = result
+        logger.info("Cached extraction for %s, size: %d bytes", file_path, len(result))
         return result
     except Exception as e:
         logger.error("Error processing %s: %s", os.path.basename(file_path), e)
@@ -139,66 +158,49 @@ def log_system_usage(tag=""):
 
 def clean_response(text: str) -> str:
     text = sanitize_utf8(text)
-    text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
-    text = re.sub(r"\n{3,}", "\n\n", text)
-    text = re.sub(r"[^\n#\-\*\w\s\.\,\:\(\)]+", "", text)
-
+    text = text.replace("[", "").replace("]", "").replace("None", "")  # Faster string ops
+    text = text.replace("\n\n\n", "\n\n")
     sections = {}
     current_section = None
-    lines = text.splitlines()
-    for line in lines:
+    for line in text.splitlines():
         line = line.strip()
         if not line:
             continue
         section_match = re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line)
         if section_match:
             current_section = section_match.group(1)
-            if current_section not in sections:
-                sections[current_section] = []
+            sections.setdefault(current_section, [])
             continue
-        finding_match = re.match(r"-\s*.+", line)
-        if finding_match and current_section and not re.match(r"-\s*No issues identified", line):
+        if current_section and line.startswith("- ") and "No issues identified" not in line:
             sections[current_section].append(line)
-
-    cleaned = []
-    for heading, findings in sections.items():
-        if findings:
-            cleaned.append(f"### {heading}\n" + "\n".join(findings))
-
-    text = "\n\n".join(cleaned).strip()
-    return text if text else ""
+    cleaned = [f"### {heading}\n" + "\n".join(findings) for heading, findings in sections.items() if findings]
+    result = "\n\n".join(cleaned).strip()
+    logger.debug("Cleaned response length: %d chars", len(result))
+    return result or ""
 
 def summarize_findings(combined_response: str) -> str:
     if not combined_response or all("No oversights identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
         return "### Summary of Clinical Oversights\nNo critical oversights identified in the provided records."
-
     sections = {}
-    lines = combined_response.splitlines()
     current_section = None
-    for line in lines:
+    for line in combined_response.splitlines():
         line = line.strip()
         if not line:
             continue
         section_match = re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line)
         if section_match:
             current_section = section_match.group(1)
-            if current_section not in sections:
-                sections[current_section] = []
+            sections.setdefault(current_section, [])
             continue
-        finding_match = re.match(r"-\s*(.+)", line)
-        if finding_match and current_section:
-            sections[current_section].append(finding_match.group(1))
-
-    summary_lines = []
-    for heading, findings in sections.items():
-        if findings:
-            summary = f"- **{heading}**: {'; '.join(findings[:2])}. Risks: {heading.lower()} may lead to adverse outcomes. Recommend: urgent review and specialist referral."
-            summary_lines.append(summary)
-
-    if not summary_lines:
-        return "### Summary of Clinical Oversights\nNo critical oversights identified."
-
-    return "### Summary of Clinical Oversights\n" + "\n".join(summary_lines)
+        if current_section and line.startswith("- "):
+            sections[current_section].append(line[2:])
+    summary_lines = [
+        f"- **{heading}**: {'; '.join(findings[:1])}. Risks: potential adverse outcomes. Recommend: urgent review."
+        for heading, findings in sections.items() if findings
+    ]
+    result = "### Summary of Clinical Oversights\n" + "\n".join(summary_lines) if summary_lines else "### Summary of Clinical Oversights\nNo critical oversights identified."
+    logger.debug("Summary length: %d chars", len(result))
+    return result
 
 def init_agent():
     logger.info("Initializing model...")
@@ -214,7 +216,9 @@ def init_agent():
         tool_files_dict={"new_tool": target_tool_path},
         force_finish=True,
         enable_checker=False,
-        step_rag_num=4,
+        enable_rag=False,
+        init_rag_num=0,
+        step_rag_num=0,
         seed=100,
         additional_default_tools=[],
     )
@@ -241,7 +245,7 @@ Patient Record Excerpt (Chunk {0} of {1}):
 {chunk}
 """
 
-    def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
+    async def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
         history.append({"role": "user", "content": message})
         yield history, None, ""
 
@@ -252,56 +256,61 @@ Patient Record Excerpt (Chunk {0} of {1}):
             progress(current / total, desc=f"Extracting text... Page {current}/{total}")
             return history, None, ""
 
-        with ThreadPoolExecutor(max_workers=6) as executor:
-            futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
-            results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
-            extracted = "\n".join(results)
-            file_hash_value = file_hash(files[0].name) if files else ""
+        futures = [convert_file_to_json(f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
+        results = [sanitize_utf8(future) for future in futures]
+        extracted = "\n".join(results)
+        file_hash_value = file_hash(files[0].name) if files else ""
 
         history.append({"role": "assistant", "content": "✅ Text extraction complete."})
         yield history, None, ""
+        logger.info("Extracted text length: %d chars", len(extracted))
 
-        chunk_size = 6000
+        chunk_size = 4000  # Increased slightly
         chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
+        logger.info("Created %d chunks", len(chunks))
         combined_response = ""
         batch_size = 2
 
         try:
             for batch_idx in range(0, len(chunks), batch_size):
                 batch_chunks = chunks[batch_idx:batch_idx + batch_size]
-                batch_prompts = [prompt_template.format(i + 1, len(chunks), chunk=chunk[:4000]) for i, chunk in enumerate(batch_chunks)]
+                batch_prompts = [prompt_template.format(i + 1, len(chunks), chunk=chunk[:2000]) for i, chunk in enumerate(batch_chunks)]
                 batch_responses = []
 
                 progress((batch_idx + 1) / len(chunks), desc=f"Analyzing chunks {batch_idx + 1}-{min(batch_idx + batch_size, len(chunks))}/{len(chunks)}")
 
-                with ThreadPoolExecutor(max_workers=len(batch_chunks)) as executor:
-                    futures = [executor.submit(agent.run_gradio_chat, prompt, [], 0.2, 512, 2048, False, []) for prompt in batch_prompts]
-                    for future in as_completed(futures):
-                        chunk_response = ""
-                        for chunk_output in future.result():
-                            if chunk_output is None:
-                                continue
-                            if isinstance(chunk_output, list):
-                                for m in chunk_output:
-                                    if hasattr(m, 'content') and m.content:
-                                        cleaned = clean_response(m.content)
-                                        if cleaned and re.search(r"###\s*\w+", cleaned):
-                                            chunk_response += cleaned + "\n\n"
-                            elif isinstance(chunk_output, str) and chunk_output.strip():
-                                cleaned = clean_response(m.content)
-                                if cleaned and re.search(r"###\s*\w+", cleaned):
-                                    chunk_response += cleaned + "\n\n"
-                        batch_responses.append(chunk_response)
-                        torch.cuda.empty_cache()
-                        gc.collect()
+                async def process_chunk(prompt):
+                    chunk_response = ""
+                    for chunk_output in agent.run_gradio_chat(
+                        message=prompt, history=[], temperature=0.2, max_new_tokens=128, max_token=768, call_agent=False, conversation=[]
+                    ):
+                        if chunk_output is None:
+                            continue
+                        if isinstance(chunk_output, list):
+                            for m in chunk_output:
+                                if hasattr(m, 'content') and m.content:
+                                    cleaned = clean_response(m.content)
+                                    if cleaned and re.search(r"###\s*\w+", cleaned):
+                                        chunk_response += cleaned + "\n\n"
+                        elif isinstance(chunk_output, str) and chunk_output.strip():
+                            cleaned = clean_response(chunk_output)
+                            if cleaned and re.search(r"###\s*\w+", cleaned):
+                                chunk_response += cleaned + "\n\n"
+                    logger.debug("Chunk response length: %d chars", len(chunk_response))
+                    return chunk_response
+
+                futures = [process_chunk(prompt) for prompt in batch_prompts]
+                batch_responses = await asyncio.gather(*futures)
+                torch.cuda.empty_cache()
+                gc.collect()
 
                 for chunk_idx, chunk_response in enumerate(batch_responses, batch_idx + 1):
                     if chunk_response:
                         combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
                     else:
                         combined_response += f"--- Analysis for Chunk {chunk_idx} ---\nNo oversights identified for this chunk.\n\n"
-                history[-1] = {"role": "assistant", "content": combined_response.strip()}
-                yield history, None, ""
+                history[-1] = {"role": "assistant", "content": combined_response.strip()}
+                yield history, None, ""
 
             if combined_response.strip() and not all("No oversights identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
                 history[-1]["content"] = combined_response.strip()
 