# CPS-Test-Mobile / app.py: Clinical Oversight Assistant, a Gradio app for
# flagging potential missed diagnoses in uploaded patient records.
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List, Dict, Generator, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import psutil
import subprocess
import logging
import torch
import gc
from diskcache import Cache
from transformers import AutoTokenizer
# ==================== CONFIGURATION ====================
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Setup directories
PERSISTENT_DIR = "/data/hf_cache"
DIRECTORIES = {
"models": os.path.join(PERSISTENT_DIR, "txagent_models"),
"tools": os.path.join(PERSISTENT_DIR, "tool_cache"),
"cache": os.path.join(PERSISTENT_DIR, "cache"),
"reports": os.path.join(PERSISTENT_DIR, "reports"),
"vllm": os.path.join(PERSISTENT_DIR, "vllm_cache")
}
# Create directories
for dir_path in DIRECTORIES.values():
os.makedirs(dir_path, exist_ok=True)
# Environment variables
os.environ.update({
"HF_HOME": DIRECTORIES["models"],
"TRANSFORMERS_CACHE": DIRECTORIES["models"],
"VLLM_CACHE_DIR": DIRECTORIES["vllm"],
"TOKENIZERS_PARALLELISM": "false",
"CUDA_LAUNCH_BLOCKING": "1"
})
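# TxAgent (and the transformers/vLLM stack underneath it) is imported only after
# the cache environment variables above are set, so model downloads land in the
# persistent /data volume rather than the default home directory.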
from txagent.txagent import TxAgent
# ==================== UTILITY FUNCTIONS ====================
def sanitize_text(text: str) -> str:
"""Clean and sanitize text input"""
return text.encode("utf-8", "ignore").decode("utf-8")
def get_file_hash(file_path: str) -> str:
    """Generate an MD5 hash of file content (used as a cache key, not for security)."""
    md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        # Hash in 1 MiB blocks so large uploads need not fit in memory at once.
        for block in iter(lambda: f.read(1 << 20), b""):
            md5.update(block)
    return md5.hexdigest()
def log_system_resources(tag: str = "") -> None:
"""Log system resource usage"""
try:
cpu = psutil.cpu_percent(interval=1)
mem = psutil.virtual_memory()
logger.info(f"[{tag}] CPU: {cpu:.1f}% | RAM: {mem.used//(1024**2)}MB/{mem.total//(1024**2)}MB")
gpu_info = subprocess.run(
["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu",
"--format=csv,nounits,noheader"],
capture_output=True, text=True
)
if gpu_info.returncode == 0:
used, total, util = gpu_info.stdout.strip().split(", ")
logger.info(f"[{tag}] GPU: {used}MB/{total}MB | Util: {util}%")
except Exception as e:
logger.error(f"[{tag}] Resource monitoring failed: {e}")
# ==================== FILE PROCESSING ====================
class FileProcessor:
@staticmethod
def extract_pdf_text(file_path: str) -> str:
"""Extract text from PDF with parallel processing"""
try:
with pdfplumber.open(file_path) as pdf:
total_pages = len(pdf.pages)
if not total_pages:
return ""
            def process_page_range(start: int, end: int) -> List[tuple]:
                """Extract pages [start, end); each worker opens its own handle,
                since sharing one pdfplumber object across threads is unsafe."""
                results = []
                with pdfplumber.open(file_path) as pdf:
                    # Enumerate from `start` so page numbers are global indices.
                    for page_num, page in enumerate(pdf.pages[start:end], start):
                        text = page.extract_text() or ""
                        results.append((page_num, f"=== Page {page_num + 1} ===\n{text.strip()}"))
return results
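            # Fan pages out in fixed-size batches across the thread pool and write
            # results back by global page index, so the final text keeps document
            # order even when batches complete out of order.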
batch_size = 10
batches = [(i, min(i+batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
text_chunks = [""] * total_pages
with ThreadPoolExecutor(max_workers=6) as executor:
futures = [executor.submit(process_page_range, start, end) for start, end in batches]
for future in as_completed(futures):
for page_num, text in future.result():
text_chunks[page_num] = text
return "\n\n".join(filter(None, text_chunks))
except Exception as e:
logger.error(f"PDF processing error: {e}")
return f"PDF processing error: {str(e)}"
@staticmethod
def excel_to_data(file_path: str) -> List[Dict]:
"""Convert Excel file to structured data"""
try:
df = pd.read_excel(file_path, engine='openpyxl', header=None, dtype=str)
content = df.where(pd.notnull(df), "").astype(str).values.tolist()
return [{"filename": os.path.basename(file_path), "rows": content, "type": "excel"}]
except Exception as e:
logger.error(f"Excel processing error: {e}")
return [{"error": f"Excel processing error: {str(e)}"}]
@staticmethod
def csv_to_data(file_path: str) -> List[Dict]:
"""Convert CSV file to structured data"""
try:
chunks = []
for chunk in pd.read_csv(
file_path, header=None, dtype=str,
encoding_errors='replace', on_bad_lines='skip', chunksize=10000
):
chunks.append(chunk)
df = pd.concat(chunks) if chunks else pd.DataFrame()
content = df.where(pd.notnull(df), "").astype(str).values.tolist()
return [{"filename": os.path.basename(file_path), "rows": content, "type": "csv"}]
except Exception as e:
logger.error(f"CSV processing error: {e}")
return [{"error": f"CSV processing error: {str(e)}"}]
@classmethod
def process_file(cls, file_path: str, file_type: str) -> List[Dict]:
"""Route file processing based on type"""
processors = {
"pdf": cls.extract_pdf_text,
"xls": cls.excel_to_data,
"xlsx": cls.excel_to_data,
"csv": cls.csv_to_data
}
if file_type not in processors:
return [{"error": f"Unsupported file type: {file_type}"}]
try:
result = processors[file_type](file_path)
if file_type == "pdf":
return [{
"filename": os.path.basename(file_path),
"content": result,
"status": "initial",
"type": "pdf"
}]
return result
except Exception as e:
logger.error(f"Error processing {file_type} file: {e}")
return [{"error": f"Error processing file: {str(e)}"}]
# ==================== TEXT PROCESSING ====================
class TextProcessor:
def __init__(self):
self.tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
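        # 10 GiB bounded on-disk cache; diskcache evicts older entries once the
        # limit is reached. Not yet consumed elsewhere in this module.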
self.cache = Cache(DIRECTORIES["cache"], size_limit=10*1024**3)
    def chunk_text(self, text: str, max_tokens: int = 1800) -> List[str]:
        """Split text into chunks of at most max_tokens tokens."""
        # Encode without special tokens so BOS/EOS markers aren't injected mid-document.
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
return [
self.tokenizer.decode(tokens[i:i+max_tokens])
for i in range(0, len(tokens), max_tokens)
]
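    # Example: a ~5,000-token record yields three chunks (1800 + 1800 + 1400
    # tokens), each analyzed independently and recombined by summarize_results().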
def clean_response(self, text: str) -> str:
"""Clean and format model response"""
text = sanitize_text(text)
text = re.sub(
r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\."
r"|Since the previous attempts.*?\.|I need to.*?medications\."
r"|Retrieving tools.*?\.", "", text, flags=re.DOTALL
)
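        # The pattern above strips bracketed tool-call artifacts, stray "None"
        # literals, and the boilerplate reasoning preambles the agent sometimes
        # emits before its structured findings.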
        diagnoses = self._extract_diagnoses(text.splitlines())
        return " ".join(diagnoses) if diagnoses else ""

    @staticmethod
    def _extract_diagnoses(lines: List[str]) -> List[str]:
        """Collect bullet items listed under a '### Missed Diagnoses' heading,
        ignoring other section headings and 'No issues identified' placeholders."""
        diagnoses = []
        in_diagnoses = False
        for line in lines:
            line = line.strip()
            if not line:
                continue
            if re.match(r"###\s*Missed Diagnoses", line):
                in_diagnoses = True
                continue
            if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
                in_diagnoses = False
                continue
            if in_diagnoses and re.match(r"-\s*.+", line):
                diagnosis = re.sub(r"^-\s*", "", line).strip()
                if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
                    diagnoses.append(diagnosis)
        return diagnoses
def summarize_results(self, analysis: str) -> str:
"""Generate concise summary from full analysis"""
chunks = analysis.split("--- Analysis for Chunk")
        diagnoses = []
        for chunk in chunks:
            chunk = chunk.strip()
            if not chunk or "No oversights identified" in chunk:
                continue
            diagnoses.extend(self._extract_diagnoses(chunk.splitlines()))
unique_diagnoses = list(dict.fromkeys(diagnoses)) # Remove duplicates
if not unique_diagnoses:
return "No missed diagnoses were identified in the provided records."
if len(unique_diagnoses) > 1:
summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
summary += f", and {unique_diagnoses[-1]}"
else:
summary = "Missed diagnoses include " + unique_diagnoses[0]
return summary + ", all requiring urgent clinical review."
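    # Example: chunks listing "hypothyroidism" and "vitamin B12 deficiency" yield
    # "Missed diagnoses include hypothyroidism, and vitamin B12 deficiency, all
    # requiring urgent clinical review."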
# ==================== CORE APPLICATION ====================
class ClinicalOversightApp:
def __init__(self):
self.agent = self._initialize_agent()
self.text_processor = TextProcessor()
self.file_processor = FileProcessor()
def _initialize_agent(self):
"""Initialize the TxAgent with proper configuration"""
logger.info("Initializing AI model...")
log_system_resources("Before Load")
tool_path = os.path.join(DIRECTORIES["tools"], "new_tool.json")
if not os.path.exists(tool_path):
default_tools = os.path.abspath("data/new_tool.json")
shutil.copy(default_tools, tool_path)
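        # TxAgent-specific knobs (see the upstream TxAgent project for details):
        # step_rag_num bounds how many tools ToolRAG retrieves per reasoning step,
        # and force_finish pushes the agent to commit to a final answer.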
agent = TxAgent(
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
tool_files_dict={"new_tool": tool_path},
force_finish=True,
enable_checker=False,
step_rag_num=4,
seed=100,
additional_default_tools=[],
)
agent.init_model()
log_system_resources("After Load")
logger.info("AI Agent Ready")
return agent
def process_response_stream(self, prompt: str, history: List[dict]) -> Generator[dict, None, None]:
"""Stream the agent's response with proper formatting"""
full_response = ""
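        # The positional arguments to run_gradio_chat are assumed to follow the
        # TxAgent API: (message, history, temperature, max_new_tokens, max_token,
        # call_agent, conversation).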
for chunk in self.agent.run_gradio_chat(prompt, [], 0.2, 512, 2048, False, []):
if not chunk:
continue
if isinstance(chunk, list):
for message in chunk:
if hasattr(message, 'content') and message.content:
cleaned = self.text_processor.clean_response(message.content)
if cleaned:
full_response += cleaned + " "
yield {"role": "assistant", "content": full_response}
elif isinstance(chunk, str) and chunk.strip():
cleaned = self.text_processor.clean_response(chunk)
if cleaned:
full_response += cleaned + " "
yield {"role": "assistant", "content": full_response}
    def analyze(self, message: str, history: List[dict], files: List) -> Generator[tuple, None, None]:
        """Main analysis pipeline; streams UI updates to the output components."""
        # State for each output component, emitted positionally via render().
        outputs = {
            "chatbot": history.copy(),
            "download_output": None,
            "final_summary": "",
            "progress_text": gr.update(value="Starting analysis...", visible=True)
        }

        def render() -> tuple:
            # Gradio matches yielded values to the handler's `outputs` list by
            # position, so keep this order in sync with the event registration.
            return (outputs["chatbot"], outputs["download_output"],
                    outputs["final_summary"], outputs["progress_text"])

        yield render()
try:
# Add user message to history
history.append({"role": "user", "content": message})
outputs["chatbot"] = history
            yield render()
# Process uploaded files
extracted = []
file_hash_value = ""
if files:
with ThreadPoolExecutor(max_workers=4) as executor:
futures = []
for f in files:
file_type = f.name.split(".")[-1].lower()
futures.append(executor.submit(self.file_processor.process_file, f.name, file_type))
for i, future in enumerate(as_completed(futures), 1):
try:
extracted.extend(future.result())
outputs["progress_text"] = self._update_progress(i, len(files), "Processing files")
                            yield render()
except Exception as e:
logger.error(f"File processing error: {e}")
extracted.append({"error": f"Error processing file: {str(e)}"})
file_hash_value = get_file_hash(files[0].name) if files else ""
history.append({"role": "assistant", "content": "✅ File processing complete"})
outputs.update({
"chatbot": history,
"progress_text": self._update_progress(len(files), len(files), "Files processed")
})
            yield render()
# Analyze content
text_content = "\n".join(json.dumps(item) for item in extracted)
chunks = self.text_processor.chunk_text(text_content)
combined_response = ""
for chunk_idx, chunk in enumerate(chunks, 1):
prompt = f"""
Analyze this patient record for missed diagnoses. Provide a concise, evidence-based summary
as a single paragraph without headings or bullet points. Include specific clinical findings
with their potential implications and urgent review recommendations. If no missed diagnoses
are found, state 'No missed diagnoses identified'.
Patient Record (Chunk {chunk_idx}/{len(chunks)}):
{chunk[:1800]}
"""
history.append({"role": "assistant", "content": ""})
outputs.update({
"chatbot": history,
"progress_text": self._update_progress(chunk_idx, len(chunks), "Analyzing")
})
                yield render()
# Stream response
chunk_response = ""
for update in self.process_response_stream(prompt, history):
history[-1] = update
chunk_response = update["content"]
outputs.update({
"chatbot": history,
"progress_text": self._update_progress(chunk_idx, len(chunks), "Analyzing")
})
                    yield render()
combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
torch.cuda.empty_cache()
gc.collect()
# Generate final outputs
summary = self.text_processor.summarize_results(combined_response)
report_path = os.path.join(DIRECTORIES["reports"], f"{file_hash_value}_report.txt") if file_hash_value else None
if report_path:
with open(report_path, "w", encoding="utf-8") as f:
f.write(combined_response + "\n\n" + summary)
            outputs.update({
                # Reveal the download component only once a report file exists.
                "download_output": gr.update(value=report_path, visible=True) if report_path else None,
                "final_summary": summary,
                "progress_text": gr.update(visible=False)
            })
            yield render()
except Exception as e:
logger.error(f"Analysis error: {e}")
history.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
outputs.update({
"chatbot": history,
"final_summary": f"Error occurred: {str(e)}",
"progress_text": {"visible": False}
})
            yield render()
    def _update_progress(self, current: int, total: int, stage: str = "") -> Dict[str, Any]:
        """Build a gr.update payload for the progress status textbox."""
        progress = f"{stage} - {current}/{total}" if stage else f"{current}/{total}"
        return gr.update(value=progress, visible=True, label=f"Progress: {progress}")
def create_interface(self):
"""Create Gradio interface with improved layout"""
with gr.Blocks(
theme=gr.themes.Soft(
primary_hue="indigo",
secondary_hue="blue",
neutral_hue="slate"
),
title="Clinical Oversight Assistant",
css="""
.diagnosis-summary {
border-left: 4px solid #4f46e5;
padding: 12px;
background: #f8fafc;
border-radius: 4px;
}
.file-upload {
border: 2px dashed #cbd5e1;
border-radius: 8px;
padding: 20px;
}
"""
) as app:
# Header Section
gr.Markdown("""
<div style='text-align: center; margin-bottom: 20px;'>
<h1 style='color: #4f46e5;'>🩺 Clinical Oversight Assistant</h1>
<p style='color: #64748b;'>
AI-powered analysis of patient records for potential missed diagnoses
</p>
</div>
""")
with gr.Row(equal_height=False):
# Main Chat Column
with gr.Column(scale=3):
chatbot = gr.Chatbot(
label="Clinical Analysis",
height=600,
show_copy_button=True,
avatar_images=(
"assets/user.png",
"assets/assistant.png"
) if os.path.exists("assets/user.png") else None,
bubble_full_width=False,
type="messages",
elem_classes=["chat-container"]
)
# Results Column
with gr.Column(scale=1):
with gr.Group():
gr.Markdown("### 📝 Summary of Findings")
final_summary = gr.Markdown(
"Analysis results will appear here...",
elem_classes=["diagnosis-summary"]
)
with gr.Group():
gr.Markdown("### 📂 Report Download")
download_output = gr.File(
label="Full Report",
visible=False,
interactive=False
)
# Input Section
with gr.Row():
file_upload = gr.File(
file_types=[".pdf", ".csv", ".xls", ".xlsx"],
file_count="multiple",
label="Upload Patient Records",
elem_classes=["file-upload"]
)
# Interaction Section
with gr.Row():
msg_input = gr.Textbox(
placeholder="Ask about potential oversights or upload files...",
show_label=False,
container=False,
scale=7,
autofocus=True
)
send_btn = gr.Button(
"Analyze",
variant="primary",
scale=1,
min_width=100
)
# Progress Indicator
progress_text = gr.Textbox(
label="Progress Status",
visible=False,
interactive=False
)
# Event Handlers
send_btn.click(
self.analyze,
inputs=[msg_input, chatbot, file_upload],
outputs=[chatbot, download_output, final_summary, progress_text],
show_progress="hidden"
)
msg_input.submit(
self.analyze,
inputs=[msg_input, chatbot, file_upload],
outputs=[chatbot, download_output, final_summary, progress_text],
show_progress="hidden"
)
app.load(
lambda: [
[], None, "", "", None, {"visible": False}
],
outputs=[chatbot, download_output, final_summary, msg_input, file_upload, progress_text],
queue=False
)
return app
# ==================== APPLICATION ENTRY POINT ====================
if __name__ == "__main__":
try:
logger.info("Starting Clinical Oversight Assistant...")
app = ClinicalOversightApp()
interface = app.create_interface()
interface.queue(
api_open=False,
max_size=20
).launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True,
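            # allowed_paths lets Gradio serve the generated report files, which
            # live outside its default temporary directory.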
allowed_paths=[DIRECTORIES["reports"]],
share=False
)
except Exception as e:
logger.error(f"Application failed to start: {e}")
raise
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()