"""app.py - Clinical Oversight Assistant (CPS-Test-Mobile Hugging Face Space)."""
import sys
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List, Tuple
from concurrent.futures import ThreadPoolExecutor
import hashlib
import shutil
import re
import psutil
import subprocess
import logging
from datetime import datetime
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler('clinical_oversight.log')
]
)
logger = logging.getLogger(__name__)
# Persistent directory
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
os.makedirs(directory, exist_ok=True)
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
from txagent.txagent import TxAgent
MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
'allergies', 'summary', 'impression', 'findings', 'recommendations'}
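# Small helpers: sanitize_utf8 drops byte sequences that are not valid UTF-8;
# file_hash reads the whole file into memory and returns its MD5 digest (used
# only as a cache/report key, not for anything security-sensitive).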
def sanitize_utf8(text: str) -> str:
return text.encode("utf-8", "ignore").decode("utf-8")
def file_hash(path: str) -> str:
with open(path, "rb") as f:
return hashlib.md5(f.read()).hexdigest()
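# Page-priority heuristic: always keep the first three pages, then keep any
# page up to max_pages whose text mentions one of the MEDICAL_KEYWORDS, so
# long records are trimmed to their clinically dense sections.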
def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
try:
text_chunks = []
with pdfplumber.open(file_path) as pdf:
for i, page in enumerate(pdf.pages[:3]):
text = page.extract_text() or ""
text_chunks.append(f"=== Page {i+1} ===\n{text.strip()}")
for i, page in enumerate(pdf.pages[3:max_pages], start=4):
page_text = page.extract_text() or ""
                # NOTE: use \b (word boundary), not an escaped "\\b", or no keyword ever matches.
                if any(re.search(rf'\b{re.escape(kw)}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
                    text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
return "\n\n".join(text_chunks)
except Exception as e:
logger.error(f"Error extracting pages from PDF: {str(e)}")
return f"PDF processing error: {str(e)}"
def convert_file_to_json(file_path: str, file_type: str) -> str:
try:
h = file_hash(file_path)
cache_path = os.path.join(file_cache_dir, f"{h}.json")
if os.path.exists(cache_path):
with open(cache_path, "r", encoding="utf-8") as f:
return f.read()
if file_type == "pdf":
text = extract_priority_pages(file_path)
result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
elif file_type == "csv":
df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
skip_blank_lines=False, on_bad_lines="skip")
content = df.fillna("").astype(str).values.tolist()
result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
elif file_type in ["xls", "xlsx"]:
try:
df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
except Exception:
df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
content = df.fillna("").astype(str).values.tolist()
result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
else:
result = json.dumps({"error": f"Unsupported file type: {file_type}"})
with open(cache_path, "w", encoding="utf-8") as f:
f.write(result)
return result
except Exception as e:
logger.error(f"Error converting {file_type} file to JSON: {str(e)}")
return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
def log_system_usage(tag=""):
try:
cpu = psutil.cpu_percent(interval=1)
mem = psutil.virtual_memory()
logger.info(f"[{tag}] CPU: {cpu}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
result = subprocess.run(
["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
capture_output=True, text=True
)
if result.returncode == 0:
used, total, util = result.stdout.strip().split(", ")
logger.info(f"[{tag}] GPU: {used}MB / {total}MB | Utilization: {util}%")
except Exception as e:
logger.error(f"[{tag}] GPU/CPU monitor failed: {e}")
def init_agent():
logger.info("🔁 Initializing model...")
log_system_usage("Before Load")
default_tool_path = os.path.abspath("data/new_tool.json")
target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
if not os.path.exists(target_tool_path):
shutil.copy(default_tool_path, target_tool_path)
agent = TxAgent(
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
tool_files_dict={"new_tool": target_tool_path},
force_finish=True,
enable_checker=True,
step_rag_num=8,
seed=100,
additional_default_tools=[],
)
agent.init_model()
log_system_usage("After Load")
logger.info("✅ Agent Ready")
return agent
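# "[TOOL_CALLS]" is assumed to be the marker TxAgent emits ahead of raw
# tool-call metadata; everything from that marker onward is stripped before
# the text reaches the UI.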
def format_response_for_ui(response: str) -> str:
"""Formats the raw response for clean display in the UI"""
# Remove any tool call metadata
cleaned = response.split("[TOOL_CALLS]")[0].strip()
# If we have a structured response, format it nicely
if "Potential missed diagnoses" in cleaned or "Flagged medication conflicts" in cleaned:
# Add markdown formatting for better readability
        formatted = []
        for line in cleaned.split("\n"):
            if line.startswith("Potential missed diagnoses"):
                formatted.append("### 🔍 Potential Missed Diagnoses")
            elif line.startswith("Flagged medication conflicts"):
                formatted.append("\n### ⚠️ Flagged Medication Conflicts")
            elif line.startswith("Incomplete assessments"):
                formatted.append("\n### 📋 Incomplete Assessments")
            elif line.startswith("Highlighted abnormal results"):
                formatted.append("\n### ❗ Abnormal Results Needing Follow-up")
            else:
                formatted.append(line)
        return "\n".join(formatted)
return cleaned
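# Main Gradio handler, written as a generator: it yields (chat_history,
# report_path) tuples so the Chatbot streams partial output while the model is
# still generating. The UI passes a fresh gr.State([]) as `history`, and the
# None guard below also tolerates a missing history.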
def analyze(message: str, history: List[Tuple[str, str]], files: list):
start_time = datetime.now()
logger.info(f"Starting analysis for message: {message[:100]}...")
if files:
logger.info(f"Processing {len(files)} uploaded files")
# Initialize chat history in the correct format if empty
if history is None:
history = []
# Add user message to history
history.append([message, None])
yield history, None
extracted = ""
file_hash_value = ""
if files:
try:
            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
                # Collect in submission order (not completion order) so the
                # extracted text lines up with the upload order.
                results = [sanitize_utf8(fut.result()) for fut in futures]
extracted = "\n".join(results)
file_hash_value = file_hash(files[0].name)
logger.info(f"Processed {len(files)} files, extracted {len(extracted)} characters")
except Exception as e:
logger.error(f"Error processing files: {str(e)}")
history[-1][1] = f"❌ Error processing files: {str(e)}"
yield history, None
return
prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
1. List potential missed diagnoses
2. Flag any medication conflicts
3. Note incomplete assessments
4. Highlight abnormal results needing follow-up
Medical Records:
{extracted[:12000]}
### Potential Oversights:
"""
logger.info(f"Generated prompt with {len(prompt)} characters")
response_chunks = []
try:
logger.info("Starting model inference...")
for chunk in agent.run_gradio_chat(
message=prompt,
history=[],
temperature=0.2,
max_new_tokens=1024,
max_token=4096,
call_agent=False,
conversation=[]
):
if not chunk:
continue
if isinstance(chunk, str):
response_chunks.append(chunk)
elif isinstance(chunk, list):
response_chunks.extend([c.content for c in chunk if hasattr(c, 'content')])
partial_response = "".join(response_chunks)
formatted_partial = format_response_for_ui(partial_response)
if formatted_partial:
history[-1][1] = formatted_partial
yield history, None
full_response = "".join(response_chunks)
logger.info(f"Full model response received: {full_response[:500]}...")
final_output = format_response_for_ui(full_response)
if not final_output or len(final_output) < 20: # Very short response
final_output = "No clear oversights identified. Recommend comprehensive review."
logger.info("No significant findings detected in analysis")
history[-1][1] = final_output
# Save report
report_path = None
if file_hash_value:
report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
try:
with open(report_path, "w", encoding="utf-8") as f:
f.write(final_output)
logger.info(f"Saved report to {report_path}")
except Exception as e:
logger.error(f"Error saving report: {str(e)}")
elapsed = (datetime.now() - start_time).total_seconds()
logger.info(f"Analysis completed in {elapsed:.2f} seconds")
yield history, report_path if report_path and os.path.exists(report_path) else None
except Exception as e:
logger.error(f"Error during analysis: {str(e)}", exc_info=True)
history[-1][1] = f"❌ Error during analysis: {str(e)}"
yield history, None
def create_ui(agent):
with gr.Blocks(theme=gr.themes.Soft(), title="Clinical Oversight Assistant") as demo:
gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
gr.Markdown("""
<div style='text-align: center; margin-bottom: 20px;'>
Upload medical records and receive analysis of potential oversights, including:<br>
- Missed diagnoses - Medication conflicts - Incomplete assessments - Abnormal results needing follow-up
</div>
""")
with gr.Row():
with gr.Column(scale=2):
file_upload = gr.File(
label="Upload Medical Records",
file_types=[".pdf", ".csv", ".xls", ".xlsx"],
file_count="multiple",
interactive=True
)
msg_input = gr.Textbox(
placeholder="Ask about potential oversights...",
show_label=False,
lines=3,
max_lines=5
)
send_btn = gr.Button("Analyze", variant="primary")
with gr.Column(scale=3):
chatbot = gr.Chatbot(
label="Analysis Results",
height=600,
bubble_full_width=False,
show_copy_button=True
)
download_output = gr.File(
label="Download Full Report",
interactive=False
)
# Examples for quick testing
        gr.Examples(
examples=[
["Are there any potential missed diagnoses in these records?"],
["What medication conflicts should I be aware of?"],
["Are there any incomplete assessments in this case?"]
],
inputs=[msg_input],
label="Example Questions"
)
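        # Both triggers pass a fresh gr.State([]) as the history input, so every
        # analysis starts from an empty conversation instead of accumulating
        # earlier turns.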
send_btn.click(
analyze,
inputs=[msg_input, gr.State([]), file_upload],
outputs=[chatbot, download_output]
)
msg_input.submit(
analyze,
inputs=[msg_input, gr.State([]), file_upload],
outputs=[chatbot, download_output]
)
# Add some footer text
gr.Markdown("""
<div style='text-align: center; margin-top: 20px; color: #666; font-size: 0.9em;'>
Note: This tool provides preliminary analysis only. Always verify findings with complete clinical evaluation.
</div>
""")
return demo
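# Entry point: queue() is required for streaming generator handlers;
# api_open=False keeps the queue's API endpoints closed, and allowed_paths
# whitelists report_dir so the download component can serve saved reports.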
if __name__ == "__main__":
logger.info("🚀 Launching Clinical Oversight Assistant...")
try:
agent = init_agent()
demo = create_ui(agent)
demo.queue(api_open=False).launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True,
allowed_paths=[report_dir],
share=False
)
except Exception as e:
logger.error(f"Failed to launch application: {str(e)}", exc_info=True)
raise