Update app.py
app.py
CHANGED
@@ -12,6 +12,7 @@ import re
 import psutil
 import subprocess
 from datetime import datetime
+import tiktoken

 # Persistent directory setup
 persistent_dir = "/data/hf_cache"
@@ -44,8 +45,10 @@ MEDICAL_KEYWORDS = {
     'allergies', 'summary', 'impression', 'findings', 'recommendations',
     'conclusion', 'history', 'examination', 'progress', 'discharge'
 }
-
-
+TOKENIZER = "cl100k_base"  # Matches Llama 3's tokenizer
+MAX_MODEL_LEN = 8000  # Conservative estimate for model context
+CHUNK_TOKEN_SIZE = MAX_MODEL_LEN // 2  # Target chunk size
+MEDICAL_SECTION_HEADER = "=== MEDICAL SECTION ==="

 def sanitize_utf8(text: str) -> str:
     """Ensure text is UTF-8 clean."""
@@ -56,14 +59,21 @@ def file_hash(path: str) -> str:
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()

-def extract_all_pages(file_path: str) -> Tuple[str, int]:
+def count_tokens(text: str) -> int:
+    """Count tokens using the same method as the model"""
+    encoding = tiktoken.get_encoding(TOKENIZER)
+    return len(encoding.encode(text))
+
+def extract_all_pages_with_token_count(file_path: str) -> Tuple[str, int, int]:
     """
-    Extract all pages from PDF with
-    Returns (extracted_text, total_pages)
+    Extract all pages from PDF with token counting.
+    Returns (extracted_text, total_pages, total_tokens)
     """
     try:
         text_chunks = []
         total_pages = 0
+        total_tokens = 0
+
         with pdfplumber.open(file_path) as pdf:
             total_pages = len(pdf.pages)

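For reference, a minimal standalone sketch of the token counting these new helpers rely on; it assumes tiktoken is installed and uses the same "cl100k_base" encoding named in this commit (the sample sentence is illustrative):

import tiktoken

# Load the encoding referenced by the TOKENIZER constant above
encoding = tiktoken.get_encoding("cl100k_base")
sample = "Patient presents with chest pain and shortness of breath."
print(len(encoding.encode(sample)))  # number of tokens counted against the budget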
@@ -71,18 +81,22 @@ def extract_all_pages(file_path: str) -> Tuple[str, int]:
                 page_text = page.extract_text() or ""
                 lower_text = page_text.lower()

-                #
+                # Mark medical sections
                 if any(re.search(rf'\b{kw}\b', lower_text) for kw in MEDICAL_KEYWORDS):
-
+                    section_header = f"\n{MEDICAL_SECTION_HEADER} (Page {i+1})\n"
+                    text_chunks.append(section_header + page_text.strip())
+                    total_tokens += count_tokens(section_header)
                 else:
-                    text_chunks.append(f"=== Page {i+1} ===\n{page_text.strip()}")
+                    text_chunks.append(f"\n=== Page {i+1} ===\n{page_text.strip()}")
+
+                total_tokens += count_tokens(page_text)

-        return "\n".join(text_chunks), total_pages
+        return "\n".join(text_chunks), total_pages, total_tokens
     except Exception as e:
-        return f"PDF processing error: {str(e)}", 0
+        return f"PDF processing error: {str(e)}", 0, 0

 def convert_file_to_json(file_path: str, file_type: str) -> str:
-    """Convert file to JSON format with caching
+    """Convert file to JSON format with caching and token counting."""
     try:
         h = file_hash(file_path)
         cache_path = os.path.join(file_cache_dir, f"{h}.json")
@@ -92,11 +106,12 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
                 return f.read()

         if file_type == "pdf":
-            text, total_pages = extract_all_pages(file_path)
+            text, total_pages, total_tokens = extract_all_pages_with_token_count(file_path)
             result = json.dumps({
                 "filename": os.path.basename(file_path),
                 "content": text,
                 "total_pages": total_pages,
+                "total_tokens": total_tokens,
                 "status": "complete"
             })
         elif file_type == "csv":
@@ -106,15 +121,22 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
                                  skip_blank_lines=False, on_bad_lines="skip", chunksize=1000):
                 chunks.append(chunk.fillna("").astype(str).values.tolist())
             content = [item for sublist in chunks for item in sublist]
-            result = json.dumps({
+            result = json.dumps({
+                "filename": os.path.basename(file_path),
+                "rows": content,
+                "total_tokens": count_tokens(str(content))
+            })
         elif file_type in ["xls", "xlsx"]:
             try:
-                # Read Excel in chunks if possible
                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
             except Exception:
                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
             content = df.fillna("").astype(str).values.tolist()
-            result = json.dumps({
+            result = json.dumps({
+                "filename": os.path.basename(file_path),
+                "rows": content,
+                "total_tokens": count_tokens(str(content))
+            })
         else:
             result = json.dumps({"error": f"Unsupported file type: {file_type}"})

@@ -204,6 +226,40 @@ def format_final_report(analysis_results: List[str], filename: str) -> str:

     return "\n".join(report)

+def split_content_by_tokens(content: str, max_tokens: int = CHUNK_TOKEN_SIZE) -> List[str]:
+    """Split content into chunks that fit within token limits"""
+    paragraphs = re.split(r"\n\s*\n", content)
+    chunks = []
+    current_chunk = []
+    current_tokens = 0
+
+    for para in paragraphs:
+        para_tokens = count_tokens(para)
+        if para_tokens > max_tokens:
+            # Handle very long paragraphs by splitting sentences
+            sentences = re.split(r'(?<=[.!?])\s+', para)
+            for sent in sentences:
+                sent_tokens = count_tokens(sent)
+                if current_tokens + sent_tokens > max_tokens:
+                    chunks.append("\n\n".join(current_chunk))
+                    current_chunk = [sent]
+                    current_tokens = sent_tokens
+                else:
+                    current_chunk.append(sent)
+                    current_tokens += sent_tokens
+        elif current_tokens + para_tokens > max_tokens:
+            chunks.append("\n\n".join(current_chunk))
+            current_chunk = [para]
+            current_tokens = para_tokens
+        else:
+            current_chunk.append(para)
+            current_tokens += para_tokens
+
+    if current_chunk:
+        chunks.append("\n\n".join(current_chunk))
+
+    return chunks
+
 def init_agent():
     """Initialize the TxAgent with proper configuration."""
     print("🔁 Initializing model...")
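A simplified, standalone sketch of the paragraph-level chunking idea behind split_content_by_tokens; it omits the sentence-splitting fallback for oversized paragraphs, and the names and limits here are illustrative, not the committed function itself:

import re
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

def count(text: str) -> int:
    # Same token-counting approach as count_tokens() in this commit
    return len(enc.encode(text))

def chunk_by_tokens(text: str, max_tokens: int = 4000) -> list:
    # Accumulate paragraphs until the next one would exceed the token budget
    chunks, current, current_tokens = [], [], 0
    for para in re.split(r"\n\s*\n", text):
        n = count(para)
        if current and current_tokens + n > max_tokens:
            chunks.append("\n\n".join(current))
            current, current_tokens = [], 0
        current.append(para)
        current_tokens += n
    if current:
        chunks.append("\n\n".join(current))
    return chunks

pages = "\n\n".join(f"=== Page {i} ===\nExample note text." for i in range(1, 40))
print([count(c) for c in chunk_by_tokens(pages, max_tokens=200)])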
@@ -229,72 +285,74 @@ def init_agent():
     print("✅ Agent Ready")
     return agent

-def
-    """Analyze
-
-    sections = re.split(r"(=== MEDICAL SECTION|=== Page \d+ ===)", content)
-    sections = [s.strip() for s in sections if s.strip()]
-
+def analyze_complete_document(content: str, filename: str, agent: TxAgent) -> str:
+    """Analyze complete document with proper chunking and token management"""
+    chunks = split_content_by_tokens(content)
     analysis_results = []
-    current_chunk = ""
-
-    for section in sections:
-        # If adding this section would exceed chunk size, analyze current chunk
-        if len(current_chunk) + len(section) > CHUNK_SIZE and current_chunk:
-            analysis_results.append(process_chunk(current_chunk, filename, agent))
-            current_chunk = section
-        else:
-            current_chunk += "\n\n" + section

-
-
-
-
-
+    for i, chunk in enumerate(chunks):
+        try:
+            # Create context-aware prompt
+            prompt = f"""
+Analyze this section ({i+1}/{len(chunks)}) of medical records for clinical oversights.
+Focus on factual evidence from the content only.

-
-
-
-Analyze this section of medical records for clinical oversights. Focus on:
-1. Critical findings needing immediate attention
-2. Potential missed diagnoses
-3. Medication conflicts
-4. Assessment gaps
-5. Follow-up recommendations
+**File:** {filename}
+**Content:**
+{chunk}

-
-
-
+Provide concise findings under these headings:
+1. CRITICAL FINDINGS (urgent issues)
+2. MISSED DIAGNOSES (with supporting evidence)
+3. MEDICATION ISSUES (specific conflicts)
+4. ASSESSMENT GAPS (missing evaluations)
+5. FOLLOW-UP RECOMMENDATIONS (specific actions)

-
-Focus on factual evidence from the content.
+Be concise and evidence-based:
 """
-
-
-
-
-
-
-
-
-
-
-
-
+            # Ensure prompt + chunk doesn't exceed model limits
+            prompt_tokens = count_tokens(prompt)
+            chunk_tokens = count_tokens(chunk)
+
+            if prompt_tokens + chunk_tokens > MAX_MODEL_LEN - 1024:  # Leave room for response
+                # Dynamically adjust chunk size
+                max_chunk_tokens = MAX_MODEL_LEN - prompt_tokens - 1024
+                adjusted_chunk = ""
+                tokens_used = 0
+                for para in re.split(r"\n\s*\n", chunk):
+                    para_tokens = count_tokens(para)
+                    if tokens_used + para_tokens <= max_chunk_tokens:
+                        adjusted_chunk += "\n\n" + para
+                        tokens_used += para_tokens
+                    else:
+                        break
+                chunk = adjusted_chunk.strip()
+
+            response = ""
+            for output in agent.run_gradio_chat(
+                message=prompt,
+                history=[],
+                temperature=0.1,
+                max_new_tokens=1024,
+                max_token=MAX_MODEL_LEN,
+                call_agent=False,
+                conversation=[],
+            ):
+                if output:
+                    if isinstance(output, list):
+                        for m in output:
+                            if hasattr(m, 'content'):
+                                response += clean_response(m.content)
+                    elif isinstance(output, str):
+                        response += clean_response(output)
+
+            if response:
+                analysis_results.append(response)
+        except Exception as e:
+            print(f"Error processing chunk {i}: {str(e)}")
             continue
-
-            if isinstance(output, list):
-                for m in output:
-                    if hasattr(m, 'content') and m.content:
-                        cleaned = clean_response(m.content)
-                        if cleaned:
-                            full_response += cleaned + "\n"
-            elif isinstance(output, str) and output.strip():
-                cleaned = clean_response(output)
-                if cleaned:
-                    full_response += cleaned + "\n"

-    return
+    return format_final_report(analysis_results, filename)

 def create_ui(agent):
     """Create the Gradio interface."""
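The budget check inside analyze_complete_document can be read on its own: trim the chunk paragraph by paragraph until prompt plus chunk leaves room for the response. A standalone sketch under assumed limits (the 8000 and 1024 figures mirror this commit; the function name is hypothetical):

import re
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

def count(text: str) -> int:
    return len(enc.encode(text))

MAX_MODEL_LEN = 8000      # assumed context window, as in this commit
RESPONSE_BUDGET = 1024    # tokens reserved for the model's reply

def trim_to_budget(prompt: str, chunk: str) -> str:
    # Keep whole paragraphs while they still fit in the remaining budget
    budget = MAX_MODEL_LEN - count(prompt) - RESPONSE_BUDGET
    if count(chunk) <= budget:
        return chunk
    kept, used = [], 0
    for para in re.split(r"\n\s*\n", chunk):
        n = count(para)
        if used + n > budget:
            break
        kept.append(para)
        used += n
    return "\n\n".join(kept)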
@@ -316,7 +374,7 @@ def create_ui(agent):
                     label="Analysis Focus"
                 )
                 with gr.Row():
-                    send_btn = gr.Button("Analyze
+                    send_btn = gr.Button("Analyze Complete Documents", variant="primary")
                     clear_btn = gr.Button("Clear")
                 status = gr.Textbox(label="Status", interactive=False)

@@ -338,11 +396,12 @@
                 yield "", None, "⚠️ Please upload at least one file to analyze."
                 return

-            yield "", None, "⏳ Processing documents..."
+            yield "", None, "⏳ Processing documents (this may take several minutes for large files)..."

             # Process all files completely
             file_contents = []
             filenames = []
+            total_tokens = 0

             with ThreadPoolExecutor(max_workers=4) as executor:
                 futures = []
@@ -356,7 +415,14 @@

             results = []
             for future in as_completed(futures):
-
+                result = sanitize_utf8(future.result())
+                results.append(result)
+                try:
+                    data = json.loads(result)
+                    if "total_tokens" in data:
+                        total_tokens += data["total_tokens"]
+                except:
+                    pass

             file_contents = results

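The token totals gathered here come from the per-file JSON payloads produced by convert_file_to_json. A minimal sketch of that aggregation, using made-up payloads in place of real conversion output:

import json

# Illustrative stand-ins for convert_file_to_json() results
results = [
    json.dumps({"filename": "notes.pdf", "content": "...", "total_pages": 3, "total_tokens": 2150}),
    json.dumps({"filename": "labs.csv", "rows": [["WBC", "7.2"]], "total_tokens": 12}),
]

total_tokens = 0
for result in results:
    try:
        data = json.loads(result)
        total_tokens += data.get("total_tokens", 0)
    except json.JSONDecodeError:
        pass

print(f"🔍 Analyzing content ({total_tokens // 1000}k tokens)...")  # matches the status format above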
@@ -367,11 +433,11 @@
                 for fc in file_contents
             ])

-            yield "", None, "🔍 Analyzing content..."
+            yield "", None, f"🔍 Analyzing content ({total_tokens//1000}k tokens)..."

             try:
                 # Process the complete document
-                full_report =
+                full_report = analyze_complete_document(
                     combined_content,
                     combined_filename,
                     agent
@@ -408,6 +474,13 @@

 if __name__ == "__main__":
     print("🚀 Launching app...")
+    # Install tiktoken if not available
+    try:
+        import tiktoken
+    except ImportError:
+        print("Installing tiktoken...")
+        subprocess.run([sys.executable, "-m", "pip", "install", "tiktoken"])
+
     agent = init_agent()
     demo = create_ui(agent)
     demo.queue(
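For completeness, a self-contained sketch of the install-on-import-error pattern added to the __main__ block. It assumes `sys` is imported elsewhere in app.py (the import is not visible in this diff), and the importlib call and retry import are additions for this sketch, not part of the commit:

import importlib
import subprocess
import sys

try:
    import tiktoken
except ImportError:
    print("Installing tiktoken...")
    # Install into the current interpreter's environment
    subprocess.run([sys.executable, "-m", "pip", "install", "tiktoken"], check=False)
    importlib.invalidate_caches()
    import tiktoken  # retry once the package is installed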