CPS-Test-Mobile / app.py
Ali2206's picture
Update app.py
57bee88 verified
raw
history blame
14.1 kB
import hashlib
import json
import logging
import os
import re
import shutil
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import List

import gradio as gr
import pandas as pd
import pdfplumber
import psutil
# Configure logging: every record goes both to the console and to a log file
# named clinical_oversight.log in the current working directory.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_log_handlers = [logging.StreamHandler(), logging.FileHandler('clinical_oversight.log')]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_log_handlers)
logger = logging.getLogger(__name__)
# Persistent directory: every cache and generated report lives under this root.
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir = (
    os.path.join(persistent_dir, sub)
    for sub in ("txagent_models", "tool_cache", "cache", "reports", "vllm_cache")
)
for _cache_subdir in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir):
    os.makedirs(_cache_subdir, exist_ok=True)

# Point the Hugging Face / vLLM caches at the persistent root and set runtime flags.
os.environ.update({
    "HF_HOME": model_cache_dir,
    "TRANSFORMERS_CACHE": model_cache_dir,
    "VLLM_CACHE_DIR": vllm_cache_dir,
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_LAUNCH_BLOCKING": "1",
})

# Make the `src/` tree next to this script importable (it provides `txagent`).
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
from txagent.txagent import TxAgent
# Section words that mark a PDF page after the first three as worth extracting.
MEDICAL_KEYWORDS = set(
    "diagnosis assessment plan results medications "
    "allergies summary impression findings recommendations".split()
)
def sanitize_utf8(text: str) -> str:
    """Return *text* with every character that cannot be encoded as UTF-8
    (for example a lone surrogate) silently dropped."""
    raw = text.encode(encoding="utf-8", errors="ignore")
    return raw.decode(encoding="utf-8")
def file_hash(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of the file at *path*.

    The file is read in ``chunk_size``-byte pieces, so large uploads are not
    loaded into memory at once. In this file the digest serves as a content
    key for cache and report file names.

    Args:
        path: Path of the file to hash.
        chunk_size: Number of bytes read per iteration (default 1 MiB).

    Returns:
        The 32-character lowercase hexadecimal digest.
    """
    digest = hashlib.md5()
    with open(path, "rb") as fh:
        for block in iter(lambda: fh.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
    """Extract the clinically relevant text of a PDF.

    The first three pages are always kept. Pages 4 through ``max_pages`` are
    kept only if their lowercased text contains one of ``MEDICAL_KEYWORDS``
    as a whole word.

    Args:
        file_path: Path to the PDF.
        max_pages: Upper bound (exclusive, as a page-list slice index) on the
            pages considered for keyword filtering.

    Returns:
        Kept pages joined by blank lines, each prefixed with
        ``=== Page N ===``. On any error, returns
        ``"PDF processing error: <message>"`` instead of raising.
    """
    try:
        # Built once per call instead of once per page per keyword. Note the
        # single backslash in the raw string: r"\b" is a word boundary, while a
        # doubled backslash would match a literal "\b" and never hit.
        keyword_re = re.compile(
            r"\b(?:" + "|".join(re.escape(kw) for kw in sorted(MEDICAL_KEYWORDS)) + r")\b"
        )
        text_chunks = []
        with pdfplumber.open(file_path) as pdf:
            for i, page in enumerate(pdf.pages[:3], start=1):
                text = page.extract_text() or ""
                text_chunks.append(f"=== Page {i} ===\n{text.strip()}")
            for i, page in enumerate(pdf.pages[3:max_pages], start=4):
                page_text = page.extract_text() or ""
                if keyword_re.search(page_text.lower()):
                    text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
        return "\n\n".join(text_chunks)
    except Exception as e:
        logger.error(f"Error extracting pages from PDF: {str(e)}")
        return f"PDF processing error: {str(e)}"
def _write_cache_atomically(cache_path: str, text: str) -> None:
    """Write *text* to *cache_path* via a temp file + ``os.replace`` so a reader never sees a partial file."""
    import tempfile

    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(cache_path), suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as tmp:
            tmp.write(text)
        os.replace(tmp_path, cache_path)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def _frame_rows(df) -> list:
    """Return a DataFrame's cells as a list of row lists of strings (NaN -> "")."""
    return df.fillna("").astype(str).values.tolist()


def convert_file_to_json(file_path: str, file_type: str) -> str:
    """Convert an uploaded file to a JSON string, caching results by content hash.

    Args:
        file_path: Path of the uploaded file.
        file_type: Lowercase extension: ``"pdf"``, ``"csv"``, ``"xls"`` or ``"xlsx"``.

    Returns:
        A JSON string:
        - PDF: ``{"filename", "content", "status"}``, with the priority-page text.
        - CSV/Excel: ``{"filename", "rows"}``, with every cell as a string.
        - Unsupported type or any failure: ``{"error": ...}``.
        Successful conversions are cached under ``file_cache_dir`` as
        ``<md5>.json``; error results are returned but not cached.
    """
    try:
        h = file_hash(file_path)
        cache_path = os.path.join(file_cache_dir, f"{h}.json")
        if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()

        filename = os.path.basename(file_path)
        cacheable = True
        if file_type == "pdf":
            text = extract_priority_pages(file_path)
            # extract_priority_pages reports failures in-band; don't persist them,
            # otherwise a failed extraction would be served from cache forever.
            cacheable = not text.startswith("PDF processing error:")
            result = json.dumps({"filename": filename, "content": text, "status": "initial"})
        elif file_type == "csv":
            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
                             skip_blank_lines=False, on_bad_lines="skip")
            result = json.dumps({"filename": filename, "rows": _frame_rows(df)})
        elif file_type in ("xls", "xlsx"):
            try:
                df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            except Exception:
                # Fallback for files openpyxl rejects (e.g. legacy .xls).
                df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
            result = json.dumps({"filename": filename, "rows": _frame_rows(df)})
        else:
            cacheable = False
            result = json.dumps({"error": f"Unsupported file type: {file_type}"})

        if cacheable:
            try:
                _write_cache_atomically(cache_path, result)
            except OSError as e:
                # The conversion succeeded; a cache failure must not discard it.
                logger.warning(f"Could not cache {filename}: {e}")
        return result
    except Exception as e:
        logger.error(f"Error converting {file_type} file to JSON: {str(e)}")
        return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
def log_system_usage(tag=""):
    """Log CPU, RAM and GPU usage under a ``[tag]`` prefix. Best-effort: never raises.

    CPU load is sampled with ``psutil.cpu_percent(interval=1)``, so each call
    blocks for about one second. GPU figures come from ``nvidia-smi``, with one
    log line per reported GPU. A host without ``nvidia-smi`` is logged as
    informational, not as an error.

    Args:
        tag: Label prefixed to every log line (e.g. "Before Load").
    """
    try:
        cpu = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory()
        logger.info(f"[{tag}] CPU: {cpu}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
            capture_output=True, text=True, timeout=30,
        )
    except FileNotFoundError:
        # nvidia-smi is not installed (e.g. CPU-only host): nothing to report.
        logger.info(f"[{tag}] GPU: nvidia-smi not available")
        return
    except Exception as e:
        logger.error(f"[{tag}] GPU/CPU monitor failed: {e}")
        return

    if result.returncode != 0:
        return
    # One CSV line per GPU; a multi-GPU host returns several lines.
    gpu_lines = [line for line in result.stdout.splitlines() if line.strip()]
    for idx, line in enumerate(gpu_lines):
        fields = [field.strip() for field in line.split(",")]
        if len(fields) != 3:
            logger.error(f"[{tag}] GPU/CPU monitor failed: unexpected nvidia-smi output {line!r}")
            continue
        used, total, util = fields
        label = "GPU" if len(gpu_lines) == 1 else f"GPU{idx}"
        logger.info(f"[{tag}] {label}: {used}MB / {total}MB | Utilization: {util}%")
def init_agent():
    """Create and initialize the TxAgent used by the UI.

    On first run, copies the bundled tool definition ``data/new_tool.json``
    into the tool cache. It then builds the agent and loads its model, logging
    resource usage before and after the load.

    Returns:
        The initialized ``TxAgent`` instance.
    """
    logger.info("🔁 Initializing model...")
    log_system_usage("Before Load")
    # Resolve the bundled tool file relative to this script (as src_path is),
    # so launching from another working directory still finds it. The old
    # CWD-relative location is kept as a fallback.
    default_tool_path = os.path.join(current_dir, "data", "new_tool.json")
    if not os.path.exists(default_tool_path):
        default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    # NOTE(review): an existing cached copy is never refreshed, so later edits
    # to data/new_tool.json are ignored until the cached file is deleted.
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=8,
        seed=100,
        additional_default_tools=[],
    )
    agent.init_model()
    log_system_usage("After Load")
    logger.info("✅ Agent Ready")
    return agent
# Section prefixes the model is prompted to produce, mapped to their UI headings.
_SECTION_HEADINGS = (
    ("Potential missed diagnoses", "### 🔍 Potential Missed Diagnoses"),
    ("Flagged medication conflicts", "\n### ⚠️ Flagged Medication Conflicts"),
    ("Incomplete assessments", "\n### 📋 Incomplete Assessments"),
    ("Highlighted abnormal results", "\n### ❗ Abnormal Results Needing Follow-up"),
)


def format_response_for_ui(response: str) -> str:
    """Format a raw model response for display in the chat UI.

    Drops everything from the first ``[TOOL_CALLS]`` marker onward and strips
    surrounding whitespace. If any known section prefix appears in the text,
    every line that starts with such a prefix is replaced by its markdown
    heading; all other lines are kept unchanged.

    Args:
        response: Raw (possibly partial) model output.

    Returns:
        The cleaned, markdown-formatted text.
    """
    cleaned = response.split("[TOOL_CALLS]")[0].strip()
    # Enable formatting when ANY of the four sections is present, not just the first two.
    if not any(prefix in cleaned for prefix, _ in _SECTION_HEADINGS):
        return cleaned
    formatted = [
        next((heading for prefix, heading in _SECTION_HEADINGS if line.startswith(prefix)), line)
        for line in cleaned.split("\n")
    ]
    return "\n".join(formatted)
def _extract_uploaded_files(files):
    """Convert uploaded files to JSON text in upload order and derive a report key.

    Args:
        files: Gradio upload values; each is a path string or an object whose
            ``.name`` attribute holds the path.

    Returns:
        ``(extracted, report_key)``: the sanitized per-file JSON strings joined
        by newlines, and a hex key used to name the saved report.

    Raises:
        Whatever ``file_hash`` or a worker thread raises; the caller reports it.
    """
    paths = [getattr(f, "name", f) for f in files]
    with ThreadPoolExecutor(max_workers=4) as executor:
        # executor.map yields results in submission order, so the prompt lists
        # records in upload order (as_completed would use completion order).
        results = list(
            executor.map(lambda p: convert_file_to_json(p, p.split(".")[-1].lower()), paths)
        )
    extracted = "\n".join(sanitize_utf8(r) for r in results)
    hashes = [file_hash(p) for p in paths]
    if len(hashes) == 1:
        report_key = hashes[0]
    else:
        # Combine all file hashes so different batches that share a first file
        # do not overwrite each other's report.
        report_key = hashlib.md5("".join(hashes).encode("ascii")).hexdigest()
    return extracted, report_key


def _build_prompt(extracted):
    """Return the fixed oversight-review prompt with the first 12000 characters of *extracted*."""
    return (
        "Review these medical records and identify EXACTLY what might have been missed:\n"
        "1. List potential missed diagnoses\n"
        "2. Flag any medication conflicts\n"
        "3. Note incomplete assessments\n"
        "4. Highlight abnormal results needing follow-up\n"
        "Medical Records:\n"
        f"{extracted[:12000]}\n"
        "### Potential Oversights:\n"
    )


def _save_report(report_key, text):
    """Write *text* to ``report_dir/<report_key>_report.txt``; return the path, or None on failure."""
    report_path = os.path.join(report_dir, f"{report_key}_report.txt")
    try:
        with open(report_path, "w", encoding="utf-8") as f:
            f.write(text)
    except (OSError, UnicodeError) as e:
        logger.error(f"Error saving report: {str(e)}")
        return None
    logger.info(f"Saved report to {report_path}")
    return report_path


def analyze(message: str, history: list, files: list):
    """Gradio streaming handler that reviews uploaded records for potential oversights.

    Yields ``(history, report_path)`` tuples. ``history`` is a new list of
    role/content dicts whose last entry is the assistant reply so far.
    ``report_path`` is None until the final yield, which carries the saved
    report's path when one was written.

    NOTE(review): ``message`` appears in the chat history but is not part of
    the prompt; the same review prompt runs for every question. The model is
    reached through the module-level ``agent`` created in the ``__main__`` block.
    """
    start_time = datetime.now()
    logger.info(f"Starting analysis for message: {(message or '')[:100]}...")
    if files:
        logger.info(f"Processing {len(files)} uploaded files")
    history = history + [{"role": "user", "content": message},
                         {"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."}]
    yield history, None

    extracted = ""
    report_key = ""
    if files:
        try:
            extracted, report_key = _extract_uploaded_files(files)
            logger.info(f"Processed {len(files)} files, extracted {len(extracted)} characters")
        except Exception as e:
            logger.error(f"Error processing files: {str(e)}")
            history[-1] = {"role": "assistant", "content": f"❌ Error processing files: {str(e)}"}
            yield history, None
            return

    prompt = _build_prompt(extracted)
    logger.info(f"Generated prompt with {len(prompt)} characters")
    response_chunks = []
    try:
        logger.info("Starting model inference...")
        for chunk in agent.run_gradio_chat(
            message=prompt,
            history=[],
            temperature=0.2,
            max_new_tokens=1024,
            max_token=4096,
            call_agent=False,
            conversation=[],
        ):
            if not chunk:
                continue
            if isinstance(chunk, str):
                response_chunks.append(chunk)
            elif isinstance(chunk, list):
                # Only string contents can be joined; skip anything else.
                response_chunks.extend(
                    c.content for c in chunk if isinstance(getattr(c, "content", None), str)
                )
            formatted_partial = format_response_for_ui("".join(response_chunks))
            if formatted_partial:
                history[-1] = {"role": "assistant", "content": formatted_partial}
                yield history, None

        full_response = "".join(response_chunks)
        logger.info(f"Full model response received: {full_response[:500]}...")
        final_output = format_response_for_ui(full_response)
        if not final_output or len(final_output) < 20:  # Very short response
            final_output = "No clear oversights identified. Recommend comprehensive review."
            logger.info("No significant findings detected in analysis")
        history[-1] = {"role": "assistant", "content": final_output}

        report_path = _save_report(report_key, final_output) if report_key else None
        elapsed = (datetime.now() - start_time).total_seconds()
        logger.info(f"Analysis completed in {elapsed:.2f} seconds")
        yield history, report_path if report_path and os.path.exists(report_path) else None
    except Exception as e:
        logger.error(f"Error during analysis: {str(e)}", exc_info=True)
        history[-1] = {"role": "assistant", "content": f"❌ Error during analysis: {str(e)}"}
        yield history, None
def create_ui(agent):
    """Build the Clinical Oversight Assistant Gradio interface.

    Args:
        agent: The initialized agent. NOTE(review): not referenced here; the
            ``analyze`` handler uses the module-level ``agent`` instead.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(theme=gr.themes.Soft(), title="Clinical Oversight Assistant") as demo:
        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
        gr.Markdown("""
<div style='text-align: center; margin-bottom: 20px;'>
Upload medical records and receive analysis of potential oversights, including:<br>
- Missed diagnoses - Medication conflicts - Incomplete assessments - Abnormal results needing follow-up
</div>
""")
        with gr.Row():
            with gr.Column(scale=2):
                file_upload = gr.File(
                    label="Upload Medical Records",
                    file_types=[".pdf", ".csv", ".xls", ".xlsx"],
                    file_count="multiple",
                    interactive=True,
                )
                msg_input = gr.Textbox(
                    placeholder="Ask about potential oversights...",
                    show_label=False,
                    lines=3,
                    max_lines=5,
                )
                send_btn = gr.Button("Analyze", variant="primary")
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Analysis Results",
                    height=600,
                    bubble_full_width=False,
                    show_copy_button=True,
                    # analyze() yields {"role": ..., "content": ...} dicts, which
                    # Gradio's Chatbot only accepts in the "messages" format.
                    type="messages",
                )
                download_output = gr.File(
                    label="Download Full Report",
                    interactive=False,
                )
        # Examples for quick testing
        gr.Examples(
            examples=[
                ["Are there any potential missed diagnoses in these records?"],
                ["What medication conflicts should I be aware of?"],
                ["Are there any incomplete assessments in this case?"],
            ],
            inputs=[msg_input],
            label="Example Questions",
        )
        # analyze() builds a fresh history list each call and never returns the
        # state, so one shared empty State serves both triggers.
        history_state = gr.State([])
        for trigger in (send_btn.click, msg_input.submit):
            trigger(
                analyze,
                inputs=[msg_input, history_state, file_upload],
                outputs=[chatbot, download_output],
            )
        # Add some footer text
        gr.Markdown("""
<div style='text-align: center; margin-top: 20px; color: #666; font-size: 0.9em;'>
Note: This tool provides preliminary analysis only. Always verify findings with complete clinical evaluation.
</div>
""")
    return demo
if __name__ == "__main__":
    logger.info("🚀 Launching Clinical Oversight Assistant...")
    try:
        # Bound at module level on purpose: analyze() looks up the global `agent`.
        agent = init_agent()
        demo = create_ui(agent)
        launch_options = dict(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            allowed_paths=[report_dir],  # lets the UI serve the saved report files
            share=False,
        )
        demo.queue(api_open=False).launch(**launch_options)
    except Exception as exc:
        logger.error(f"Failed to launch application: {str(exc)}", exc_info=True)
        raise