Update app.py
app.py CHANGED
@@ -1,7 +1,6 @@
import sys
import os
import pandas as pd
-import pdfplumber
import json
import gradio as gr
from typing import List
@@ -16,9 +15,15 @@ import torch
import gc
from diskcache import Cache
import time
+import asyncio
+import pypdfium2 as pdfium
+import pytesseract
+from PIL import Image
+import io

-# Configure logging
+# Configure logging and suppress warnings
logging.basicConfig(level=logging.INFO)
+logging.getLogger("pdfminer").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)

# Persistent directory
@@ -56,37 +61,45 @@ def file_hash(path: str) -> str:
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()

-def …
+async def extract_all_pages_async(file_path: str, progress_callback=None, use_ocr=False) -> str:
    try:
-
-
-
-
+        pdf = pdfium.PdfDocument(file_path)
+        total_pages = len(pdf)
+        if total_pages == 0:
+            return ""

-        batch_size = …
+        batch_size = 5
        batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
        text_chunks = [""] * total_pages
        processed_pages = 0

        def extract_batch(start: int, end: int) -> List[tuple]:
            results = []
-
-
-
-
-
+            for i in range(start, end):
+                page = pdf[i]
+                text = page.get_textpage().get_text_range() or ""
+                if not text.strip() and use_ocr:
+                    # Fallback to OCR
+                    bitmap = page.render(scale=2).to_pil()
+                    text = pytesseract.image_to_string(bitmap, lang="eng")
+                results.append((i, f"=== Page {i + 1} ===\n{text.strip()}"))
            return results

-
-
-        for …
-
+        loop = asyncio.get_event_loop()
+        with ThreadPoolExecutor(max_workers=4) as executor:
+            futures = [loop.run_in_executor(executor, extract_batch, start, end) for start, end in batches]
+            for future in await asyncio.gather(*futures):
+                for page_num, text in future:
                    text_chunks[page_num] = text
+                    logger.debug("Page %d extracted: %s...", page_num + 1, text[:50])
                processed_pages += batch_size
                if progress_callback:
                    progress_callback(min(processed_pages, total_pages), total_pages)

-
+        pdf.close()
+        extracted_text = "\n\n".join(filter(None, text_chunks))
+        logger.info("Extracted %d pages, total length: %d chars", total_pages, len(extracted_text))
+        return extracted_text
    except Exception as e:
        logger.error("PDF processing error: %s", e)
        return f"PDF processing error: {str(e)}"
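
For readers unfamiliar with the new extraction path: pages are read through pypdfium2's text API first, and only rasterized and OCR'd with pytesseract when a page has no text layer and use_ocr is set. A minimal standalone sketch of that text-first, OCR-fallback idea (the file name is a placeholder; it assumes pypdfium2, pytesseract, Pillow, and a Tesseract binary are installed):

import pypdfium2 as pdfium
import pytesseract

pdf = pdfium.PdfDocument("sample.pdf")  # placeholder path
pages = []
for i in range(len(pdf)):
    page = pdf[i]
    text = page.get_textpage().get_text_range() or ""
    if not text.strip():
        # No embedded text layer: rasterize the page and OCR it instead.
        text = pytesseract.image_to_string(page.render(scale=2).to_pil(), lang="eng")
    pages.append(f"=== Page {i + 1} ===\n{text.strip()}")
pdf.close()
print("\n\n".join(filter(None, pages)))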
@@ -96,10 +109,15 @@ def convert_file_to_json(file_path: str, file_type: str, progress_callback=None)
        file_h = file_hash(file_path)
        cache_key = f"{file_h}_{file_type}"
        if cache_key in cache:
+            logger.info("Using cached extraction for %s", file_path)
            return cache[cache_key]

        if file_type == "pdf":
-
+            # Try without OCR first, fallback to OCR if empty
+            text = asyncio.run(extract_all_pages_async(file_path, progress_callback, use_ocr=False))
+            if not text.strip() or "PDF processing error" in text:
+                logger.info("Retrying extraction with OCR for %s", file_path)
+                text = asyncio.run(extract_all_pages_async(file_path, progress_callback, use_ocr=True))
            result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
        elif file_type == "csv":
            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
@@ -117,6 +135,7 @@ def convert_file_to_json(file_path: str, file_type: str, progress_callback=None)
            result = json.dumps({"error": f"Unsupported file type: {file_type}"})

        cache[cache_key] = result
+        logger.info("Cached extraction for %s, size: %d bytes", file_path, len(result))
        return result
    except Exception as e:
        logger.error("Error processing %s: %s", os.path.basename(file_path), e)
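
The caching added around conversion keys each result by an MD5 of the file bytes plus the file type, so re-uploading an identical file returns the cached JSON instead of re-extracting. A small sketch of that pattern with diskcache (the directory name and the stand-in conversion step are illustrative):

import hashlib
from diskcache import Cache

cache = Cache("./demo_cache")  # placeholder cache directory

def cached_convert(path: str, file_type: str) -> str:
    with open(path, "rb") as f:
        key = f"{hashlib.md5(f.read()).hexdigest()}_{file_type}"
    if key in cache:
        return cache[key]  # cache hit: skip extraction entirely
    result = f"converted {path}"  # stand-in for the real extraction step
    cache[key] = result
    return result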
@@ -139,66 +158,49 @@ def log_system_usage(tag=""):

def clean_response(text: str) -> str:
    text = sanitize_utf8(text)
-    text = …
-    text = …
-    text = re.sub(r"[^\n#\-\*\w\s\.\,\:\(\)]+", "", text)
-
+    text = text.replace("[", "").replace("]", "").replace("None", "")  # Faster string ops
+    text = text.replace("\n\n\n", "\n\n")
    sections = {}
    current_section = None
-
-    for line in lines:
+    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        section_match = re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line)
        if section_match:
            current_section = section_match.group(1)
-
-            sections[current_section] = []
+            sections.setdefault(current_section, [])
            continue
-
-        if finding_match and current_section and not re.match(r"-\s*No issues identified", line):
+        if current_section and line.startswith("- ") and "No issues identified" not in line:
            sections[current_section].append(line)
-
-
-
-
-        cleaned.append(f"### {heading}\n" + "\n".join(findings))
-
-    text = "\n\n".join(cleaned).strip()
-    return text if text else ""
+    cleaned = [f"### {heading}\n" + "\n".join(findings) for heading, findings in sections.items() if findings]
+    result = "\n\n".join(cleaned).strip()
+    logger.debug("Cleaned response length: %d chars", len(result))
+    return result or ""

def summarize_findings(combined_response: str) -> str:
    if not combined_response or all("No oversights identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
        return "### Summary of Clinical Oversights\nNo critical oversights identified in the provided records."
-
    sections = {}
-    lines = combined_response.splitlines()
    current_section = None
-    for line in …
+    for line in combined_response.splitlines():
        line = line.strip()
        if not line:
            continue
        section_match = re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line)
        if section_match:
            current_section = section_match.group(1)
-
-            sections[current_section] = []
+            sections.setdefault(current_section, [])
            continue
-
-
-
-
-
-
-
-
-
-
-    if not summary_lines:
-        return "### Summary of Clinical Oversights\nNo critical oversights identified."
-
-    return "### Summary of Clinical Oversights\n" + "\n".join(summary_lines)
+        if current_section and line.startswith("- "):
+            sections[current_section].append(line[2:])
+    summary_lines = [
+        f"- **{heading}**: {'; '.join(findings[:1])}. Risks: potential adverse outcomes. Recommend: urgent review."
+        for heading, findings in sections.items() if findings
+    ]
+    result = "### Summary of Clinical Oversights\n" + "\n".join(summary_lines) if summary_lines else "### Summary of Clinical Oversights\nNo critical oversights identified."
+    logger.debug("Summary length: %d chars", len(result))
+    return result

def init_agent():
    logger.info("Initializing model...")
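
Both rewritten helpers scan the model output line by line, keyed on the four ### section headings, and collect only "- " findings. Roughly, the parsing behaves like this (the sample text is invented):

import re

sample = """### Missed Diagnoses
- Possible anemia not worked up
### Medication Conflicts
- No issues identified"""

sections, current = {}, None
for line in sample.splitlines():
    line = line.strip()
    match = re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line)
    if match:
        current = match.group(1)
        sections.setdefault(current, [])
        continue
    if current and line.startswith("- ") and "No issues identified" not in line:
        sections[current].append(line)

print(sections)  # {'Missed Diagnoses': ['- Possible anemia not worked up'], 'Medication Conflicts': []}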
@@ -214,7 +216,9 @@ def init_agent():
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=False,
-
+        enable_rag=False,
+        init_rag_num=0,
+        step_rag_num=0,
        seed=100,
        additional_default_tools=[],
    )
@@ -241,7 +245,7 @@ Patient Record Excerpt (Chunk {0} of {1}):
{chunk}
"""

-def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
+async def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
    history.append({"role": "user", "content": message})
    yield history, None, ""

@@ -252,56 +256,61 @@ Patient Record Excerpt (Chunk {0} of {1}):
        progress(current / total, desc=f"Extracting text... Page {current}/{total}")
        return history, None, ""

-
-
-
-
-    file_hash_value = file_hash(files[0].name) if files else ""
+    futures = [convert_file_to_json(f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
+    results = [sanitize_utf8(future) for future in futures]
+    extracted = "\n".join(results)
+    file_hash_value = file_hash(files[0].name) if files else ""

    history.append({"role": "assistant", "content": "✅ Text extraction complete."})
    yield history, None, ""
+    logger.info("Extracted text length: %d chars", len(extracted))

-    chunk_size = …
+    chunk_size = 4000  # Increased slightly
    chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
+    logger.info("Created %d chunks", len(chunks))
    combined_response = ""
    batch_size = 2

    try:
        for batch_idx in range(0, len(chunks), batch_size):
            batch_chunks = chunks[batch_idx:batch_idx + batch_size]
-            batch_prompts = [prompt_template.format(i + 1, len(chunks), chunk=chunk[: …
+            batch_prompts = [prompt_template.format(i + 1, len(chunks), chunk=chunk[:2000]) for i, chunk in enumerate(batch_chunks)]
            batch_responses = []

            progress((batch_idx + 1) / len(chunks), desc=f"Analyzing chunks {batch_idx + 1}-{min(batch_idx + batch_size, len(chunks))}/{len(chunks)}")

-
-
-            for …
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            async def process_chunk(prompt):
+                chunk_response = ""
+                for chunk_output in agent.run_gradio_chat(
+                    message=prompt, history=[], temperature=0.2, max_new_tokens=128, max_token=768, call_agent=False, conversation=[]
+                ):
+                    if chunk_output is None:
+                        continue
+                    if isinstance(chunk_output, list):
+                        for m in chunk_output:
+                            if hasattr(m, 'content') and m.content:
+                                cleaned = clean_response(m.content)
+                                if cleaned and re.search(r"###\s*\w+", cleaned):
+                                    chunk_response += cleaned + "\n\n"
+                    elif isinstance(chunk_output, str) and chunk_output.strip():
+                        cleaned = clean_response(chunk_output)
+                        if cleaned and re.search(r"###\s*\w+", cleaned):
+                            chunk_response += cleaned + "\n\n"
+                logger.debug("Chunk response length: %d chars", len(chunk_response))
+                return chunk_response
+
+            futures = [process_chunk(prompt) for prompt in batch_prompts]
+            batch_responses = await asyncio.gather(*futures)
+            torch.cuda.empty_cache()
+            gc.collect()

            for chunk_idx, chunk_response in enumerate(batch_responses, batch_idx + 1):
                if chunk_response:
                    combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
                else:
                    combined_response += f"--- Analysis for Chunk {chunk_idx} ---\nNo oversights identified for this chunk.\n\n"
-
-
+            history[-1] = {"role": "assistant", "content": combined_response.strip()}
+            yield history, None, ""

        if combined_response.strip() and not all("No oversights identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
            history[-1]["content"] = combined_response.strip()
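
The analysis loop now splits the extracted text into fixed-size chunks, builds one prompt per chunk, awaits each batch of chunk analyses concurrently with asyncio.gather, and streams partial results back into the chat history. A simplified, self-contained sketch of that control flow, with a dummy coroutine standing in for the real agent.run_gradio_chat call:

import asyncio

async def analyze_chunk(idx: int, chunk: str) -> str:
    # Stand-in for the per-chunk model call in app.py.
    await asyncio.sleep(0)
    return f"### Missed Diagnoses\n- finding from chunk {idx}"

async def analyze_all(text: str, chunk_size: int = 4000, batch_size: int = 2) -> str:
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    combined = ""
    for start in range(0, len(chunks), batch_size):
        batch = chunks[start:start + batch_size]
        responses = await asyncio.gather(
            *(analyze_chunk(start + i + 1, c) for i, c in enumerate(batch))
        )
        for idx, resp in enumerate(responses, start + 1):
            combined += f"--- Analysis for Chunk {idx} ---\n{resp}\n"
    return combined

print(asyncio.run(analyze_all("x" * 9000)))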