# CPS-Test-Mobile / app.py
# Ali2206's picture
# Update app.py
# 94b553f verified
# raw
# history blame
# 26.3 kB
import sys
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List, Dict, Generator, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import psutil
import subprocess
import logging
import torch
import gc
from diskcache import Cache
from transformers import AutoTokenizer
# ==================== CONFIGURATION ====================
# Module-wide logger; everything in this file logs at INFO or above.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Every working directory lives under one persistent root.
PERSISTENT_DIR = "/data/hf_cache"
DIRECTORIES = {
    name: os.path.join(PERSISTENT_DIR, subdir)
    for name, subdir in (
        ("models", "txagent_models"),
        ("tools", "tool_cache"),
        ("cache", "cache"),
        ("reports", "reports"),
        ("vllm", "vllm_cache"),
    )
}

# Make sure the whole directory tree exists before anything tries to use it.
for dir_path in DIRECTORIES.values():
    os.makedirs(dir_path, exist_ok=True)

# Point the model/tokenizer caches at the persistent volume and pin runtime flags.
# NOTE(review): these are set after `transformers` and `torch` were imported at
# the top of the file; if those libraries read the cache variables at import
# time the settings below may not take effect - confirm.
os.environ["HF_HOME"] = DIRECTORIES["models"]
os.environ["TRANSFORMERS_CACHE"] = DIRECTORIES["models"]
os.environ["VLLM_CACHE_DIR"] = DIRECTORIES["vllm"]
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Make the bundled `src/` directory importable so `txagent` can be found.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
from txagent.txagent import TxAgent
# ==================== UTILITY FUNCTIONS ====================
def sanitize_text(text: str) -> str:
    """Return *text* with any characters that cannot be UTF-8 encoded dropped.

    The string is round-tripped through UTF-8 with the ``ignore`` error
    handler, so un-encodable code points (e.g. lone surrogates) are removed
    and everything else is returned unchanged.
    """
    utf8_bytes = text.encode("utf-8", "ignore")
    return utf8_bytes.decode("utf-8")
def get_file_hash(file_path: str) -> str:
    """Return the hexadecimal MD5 digest of a file's content.

    The file is read in fixed-size chunks so that large uploads are never
    loaded into memory in one piece; the digest is identical to hashing the
    whole content at once. In this file the digest is used as a content
    fingerprint (cache keys, report file names).

    Args:
        file_path: Path of the file to fingerprint.

    Returns:
        The 32-character lowercase hex digest.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as f:
        # iter(callable, sentinel) keeps reading until read() returns b"" (EOF).
        for block in iter(lambda: f.read(1024 * 1024), b""):
            digest.update(block)
    return digest.hexdigest()
def log_system_resources(tag: str = "") -> None:
    """Log current CPU, RAM and (when available) GPU usage.

    CPU utilisation is sampled over a one-second interval, so this call
    blocks for about a second. GPU figures come from ``nvidia-smi``; one log
    line is written per GPU reported. Any failure is logged and swallowed -
    monitoring is best-effort and must never take the application down.

    Args:
        tag: Label prefixed to every log line (e.g. "Before Load").
    """
    try:
        cpu = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory()
        logger.info(f"[{tag}] CPU: {cpu:.1f}% | RAM: {mem.used//(1024**2)}MB/{mem.total//(1024**2)}MB")
        gpu_info = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu",
             "--format=csv,nounits,noheader"],
            capture_output=True, text=True
        )
        if gpu_info.returncode == 0:
            # nvidia-smi prints one CSV line per GPU; unpacking the whole
            # output at once fails as soon as there is more than one device.
            for line in gpu_info.stdout.strip().splitlines():
                fields = [field.strip() for field in line.split(",")]
                if len(fields) != 3:
                    continue  # skip anything that is not "used, total, util"
                used, total, util = fields
                logger.info(f"[{tag}] GPU: {used}MB/{total}MB | Util: {util}%")
    except Exception as e:
        logger.error(f"[{tag}] Resource monitoring failed: {e}")
# ==================== FILE PROCESSING ====================
class FileProcessor:
    """Extract content from uploaded PDF, Excel and CSV files, with caching.

    Each extractor keys its cache entry on the MD5 of the file content, so
    re-uploading identical bytes is served from ``cache`` without re-parsing.
    Extraction errors are logged and returned as data (an error string or an
    ``{"error": ...}`` record) rather than raised; error results are not cached.
    """

    @staticmethod
    def extract_pdf_text(file_path: str, cache: Cache) -> str:
        """Extract the text of every page of a PDF.

        Pages are processed in batches of 10 on two worker threads; every
        worker opens its own handle on the file. The result is the page texts
        in page order, each prefixed with ``=== Page N ===`` and separated by
        blank lines.

        Args:
            file_path: Path of the PDF file.
            cache: Cache used to store/retrieve the extracted text.

        Returns:
            The extracted text, ``""`` for a PDF without pages, or a
            ``"PDF processing error: ..."`` string if extraction failed.
        """
        cache_key = f"pdf_{get_file_hash(file_path)}"
        if cache_key in cache:
            return cache[cache_key]
        try:
            with pdfplumber.open(file_path) as pdf:
                total_pages = len(pdf.pages)
            if not total_pages:
                return ""

            def process_page_range(start: int, end: int) -> List[tuple]:
                """Return (absolute page index, labelled text) for pages [start, end)."""
                results = []
                with pdfplumber.open(file_path) as pdf:
                    # The slice is relative to `start`, so the absolute index
                    # is start + offset. (Looking the page up with
                    # pdf.pages.index() already yields the absolute index and
                    # adding `start` on top double-counted the offset.)
                    for offset, page in enumerate(pdf.pages[start:end]):
                        page_num = start + offset
                        text = page.extract_text() or ""
                        results.append((page_num, f"=== Page {page_num + 1} ===\n{text.strip()}"))
                return results

            batch_size = 10
            batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
            # Pre-sized so batches finishing out of order still land in page order.
            text_chunks = [""] * total_pages
            with ThreadPoolExecutor(max_workers=2) as executor:
                futures = [executor.submit(process_page_range, start, end) for start, end in batches]
                for future in as_completed(futures):
                    for page_num, text in future.result():
                        text_chunks[page_num] = text
            result = "\n\n".join(filter(None, text_chunks))
            cache[cache_key] = result
            return result
        except Exception as e:
            logger.error(f"PDF processing error: {e}")
            return f"PDF processing error: {str(e)}"

    @staticmethod
    def excel_to_data(file_path: str, cache: Cache) -> List[Dict]:
        """Read an Excel workbook into a single record of string rows.

        Args:
            file_path: Path of the ``.xls`` / ``.xlsx`` file.
            cache: Cache used to store/retrieve the parsed rows.

        Returns:
            A one-element list ``[{"filename", "rows", "type": "excel"}]`` where
            missing cells are ``""``, or ``[{"error": ...}]`` on failure.
        """
        cache_key = f"excel_{get_file_hash(file_path)}"
        if cache_key in cache:
            return cache[cache_key]
        try:
            # openpyxl cannot read the legacy binary .xls format, so let pandas
            # choose the engine for those files; keep openpyxl for the rest.
            engine = None if file_path.lower().endswith(".xls") else "openpyxl"
            df = pd.read_excel(file_path, engine=engine, header=None, dtype=str)
            content = df.where(pd.notnull(df), "").astype(str).values.tolist()
            result = [{"filename": os.path.basename(file_path), "rows": content, "type": "excel"}]
            cache[cache_key] = result
            return result
        except Exception as e:
            logger.error(f"Excel processing error: {e}")
            return [{"error": f"Excel processing error: {str(e)}"}]

    @staticmethod
    def csv_to_data(file_path: str, cache: Cache) -> List[Dict]:
        """Read a CSV file into a single record of string rows.

        The file is read in chunks of 10,000 rows; undecodable bytes are
        replaced and malformed lines are skipped.

        Args:
            file_path: Path of the CSV file.
            cache: Cache used to store/retrieve the parsed rows.

        Returns:
            A one-element list ``[{"filename", "rows", "type": "csv"}]`` where
            missing cells are ``""``, or ``[{"error": ...}]`` on failure.
        """
        cache_key = f"csv_{get_file_hash(file_path)}"
        if cache_key in cache:
            return cache[cache_key]
        try:
            chunks = list(pd.read_csv(
                file_path, header=None, dtype=str,
                encoding_errors='replace', on_bad_lines='skip', chunksize=10000
            ))
            df = pd.concat(chunks) if chunks else pd.DataFrame()
            content = df.where(pd.notnull(df), "").astype(str).values.tolist()
            result = [{"filename": os.path.basename(file_path), "rows": content, "type": "csv"}]
            cache[cache_key] = result
            return result
        except Exception as e:
            logger.error(f"CSV processing error: {e}")
            return [{"error": f"CSV processing error: {str(e)}"}]

    @classmethod
    def process_file(cls, file_path: str, file_type: str, cache: Cache) -> List[Dict]:
        """Dispatch a file to the extractor matching its type.

        Args:
            file_path: Path of the uploaded file.
            file_type: Lower-case extension without the dot
                (``"pdf"``, ``"xls"``, ``"xlsx"`` or ``"csv"``).
            cache: Cache handed through to the extractor.

        Returns:
            A list of record dicts. PDF text is wrapped into a single
            ``{"filename", "content", "status": "initial", "type": "pdf"}``
            record; unsupported types and failures yield ``[{"error": ...}]``.
        """
        processors = {
            "pdf": cls.extract_pdf_text,
            "xls": cls.excel_to_data,
            "xlsx": cls.excel_to_data,
            "csv": cls.csv_to_data
        }
        if file_type not in processors:
            return [{"error": f"Unsupported file type: {file_type}"}]
        try:
            result = processors[file_type](file_path, cache)
            if file_type == "pdf":
                return [{
                    "filename": os.path.basename(file_path),
                    "content": result,
                    "status": "initial",
                    "type": "pdf"
                }]
            return result
        except Exception as e:
            logger.error(f"Error processing {file_type} file: {e}")
            return [{"error": f"Error processing file: {str(e)}"}]
# ==================== TEXT PROCESSING ====================
class TextProcessor:
    """Tokenizer-based chunking plus parsing of the model's markdown output."""

    # Section headings recognised in a model response. Only bullet lines
    # between the "Missed Diagnoses" heading and the next known heading count.
    _DIAGNOSES_HEADING = re.compile(r"###\s*Missed Diagnoses")
    _OTHER_HEADING = re.compile(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)")
    _BULLET = re.compile(r"-\s*.+")
    _BULLET_PREFIX = re.compile(r"^\-\s*")
    _NO_ISSUES = re.compile(r"No issues identified", re.IGNORECASE)

    def __init__(self):
        """Load the model tokenizer and open the on-disk cache (10 GiB limit)."""
        self.tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
        self.cache = Cache(DIRECTORIES["cache"], size_limit=10*1024**3)

    def chunk_text(self, text: str, max_tokens: int = 1200) -> List[str]:
        """Split *text* into pieces of at most *max_tokens* tokens each.

        Args:
            text: Text to split.
            max_tokens: Maximum number of tokens per chunk; must be positive.

        Returns:
            The decoded chunks in order (empty list for text with no tokens).

        Raises:
            ValueError: If *max_tokens* is not positive.
        """
        if max_tokens <= 0:
            raise ValueError("max_tokens must be a positive integer")
        # Encode without special tokens: otherwise the begin-of-text marker is
        # prepended and then decoded back into the first chunk as literal text.
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
        return [
            self.tokenizer.decode(tokens[i:i + max_tokens])
            for i in range(0, len(tokens), max_tokens)
        ]

    @classmethod
    def _extract_diagnoses(cls, text: str) -> List[str]:
        """Return the bullet items listed under a "### Missed Diagnoses" heading."""
        found = []
        in_diagnoses = False
        for raw_line in text.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            if cls._DIAGNOSES_HEADING.match(line):
                in_diagnoses = True
                continue
            if cls._OTHER_HEADING.match(line):
                in_diagnoses = False
                continue
            if in_diagnoses and cls._BULLET.match(line):
                diagnosis = cls._BULLET_PREFIX.sub("", line).strip()
                # "No issues identified" bullets are placeholders, not findings.
                if diagnosis and not cls._NO_ISSUES.match(diagnosis):
                    found.append(diagnosis)
        return found

    def clean_response(self, text: str) -> str:
        """Reduce a raw model response to its missed-diagnosis bullet items.

        Bracketed spans (``[...]``) and the bare word ``None`` are removed
        first; the remaining "Missed Diagnoses" bullets are joined with
        single spaces.

        Returns:
            The joined diagnoses, or ``""`` when the response contains none.
        """
        text = sanitize_text(text)
        text = re.sub(r"\[.*?\]|\bNone\b", "", text)
        return " ".join(self._extract_diagnoses(text))

    def summarize_results(self, analysis: str) -> str:
        """Build a one-sentence summary from the combined per-chunk analysis.

        *analysis* is split on the ``--- Analysis for Chunk`` markers; chunks
        containing "No oversights identified" are skipped, and diagnoses are
        de-duplicated while keeping first-seen order.

        Returns:
            A sentence listing the unique missed diagnoses, or a fixed
            "no missed diagnoses" sentence when none were found.
        """
        diagnoses = []
        for chunk in analysis.split("--- Analysis for Chunk"):
            chunk = chunk.strip()
            if not chunk or "No oversights identified" in chunk:
                continue
            diagnoses.extend(self._extract_diagnoses(chunk))
        # dict.fromkeys de-duplicates while preserving insertion order.
        unique_diagnoses = list(dict.fromkeys(diagnoses))
        if not unique_diagnoses:
            return "No missed diagnoses were identified in the provided records."
        if len(unique_diagnoses) > 1:
            summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
            summary += f", and {unique_diagnoses[-1]}"
        else:
            summary = "Missed diagnoses include " + unique_diagnoses[0]
        return summary + ", all requiring urgent clinical review."
# ==================== CORE APPLICATION ====================
class ClinicalOversightApp:
    """Gradio application that runs uploaded patient records through TxAgent."""

    def __init__(self):
        """Load the agent first, then the text and file processing helpers."""
        self.agent = self._initialize_agent()
        self.text_processor = TextProcessor()
        self.file_processor = FileProcessor()

    def _initialize_agent(self):
        """Create and initialise the TxAgent instance.

        The tool definition file is copied from ``data/new_tool.json``
        (resolved against the current working directory) into the tool cache
        directory the first time it is needed. Resource usage is logged
        before and after the model is loaded.

        Returns:
            The initialised agent.
        """
        logger.info("Initializing AI model...")
        log_system_resources("Before Load")
        tool_path = os.path.join(DIRECTORIES["tools"], "new_tool.json")
        if not os.path.exists(tool_path):
            default_tools = os.path.abspath("data/new_tool.json")
            shutil.copy(default_tools, tool_path)
        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={"new_tool": tool_path},
            force_finish=True,
            enable_checker=False,
            step_rag_num=4,
            seed=100,
            additional_default_tools=[],
        )
        agent.init_model()
        log_system_resources("After Load")
        logger.info("AI Agent Ready")
        return agent

    def process_response_stream(self, prompt: str, history: List[dict]) -> Generator[dict, None, None]:
        """Stream the agent's answer to *prompt* as growing assistant messages.

        Each item produced by the agent is either a list of message objects
        (only those with a non-empty ``content`` attribute are used) or a
        plain string. Every piece is passed through
        ``TextProcessor.clean_response``; non-empty results are appended to
        the running response.

        Args:
            prompt: Prompt sent to the agent.
            history: Accepted for interface compatibility; currently unused.

        Yields:
            ``{"role": "assistant", "content": <response so far>}`` after
            every piece that contributed text.
        """
        full_response = ""
        for chunk in self.agent.run_gradio_chat(prompt, [], 0.2, 512, 2048, False, []):
            if not chunk:
                continue
            if isinstance(chunk, list):
                for message in chunk:
                    if hasattr(message, 'content') and message.content:
                        cleaned = self.text_processor.clean_response(message.content)
                        if cleaned:
                            full_response += cleaned + " "
                            yield {"role": "assistant", "content": full_response}
            elif isinstance(chunk, str) and chunk.strip():
                cleaned = self.text_processor.clean_response(chunk)
                if cleaned:
                    full_response += cleaned + " "
                    yield {"role": "assistant", "content": full_response}

    def analyze(self, message: str, history: List[dict], files: List) -> Generator[tuple, None, None]:
        """Run the full pipeline: extract files, analyse chunks, build the report.

        Args:
            message: Text the user typed.
            history: Current chat history (list of role/content dicts); may
                be ``None`` or empty.
            files: Uploaded files, or ``None``. Entries may be objects with a
                ``name`` attribute or plain path strings.

        Yields:
            ``(chat history, report download update, summary text,
            progress update)`` tuples matching the four outputs wired up in
            ``create_interface``. Failures are reported in the chat rather
            than raised.
        """
        # history can be None on a fresh session; always work on our own copy.
        chatbot_output = list(history or [])
        download_output = None
        final_summary = ""
        progress_text = gr.update(value="Starting analysis...", visible=True)
        try:
            # Add user message to history
            chatbot_output.append({"role": "user", "content": message})
            yield (chatbot_output, download_output, final_summary, progress_text)

            # Process uploaded files
            extracted = []
            file_hash_value = ""
            if files:
                # Accept both file objects (``.name``) and plain path strings.
                paths = [getattr(f, "name", f) for f in files]
                with ThreadPoolExecutor(max_workers=2) as executor:
                    futures = []
                    for path in paths:
                        file_type = path.split(".")[-1].lower()
                        futures.append(executor.submit(
                            self.file_processor.process_file, path, file_type, self.text_processor.cache
                        ))
                    for i, future in enumerate(as_completed(futures), 1):
                        try:
                            extracted.extend(future.result())
                            progress_text = self._update_progress(i, len(files), "Processing files")
                            yield (chatbot_output, download_output, final_summary, progress_text)
                        except Exception as e:
                            logger.error(f"File processing error: {e}")
                            extracted.append({"error": f"Error processing file: {str(e)}"})
                # The report file is named after the first uploaded file's hash.
                file_hash_value = get_file_hash(paths[0])
                chatbot_output.append({"role": "assistant", "content": "✅ File processing complete"})
                progress_text = self._update_progress(len(files), len(files), "Files processed")
                yield (chatbot_output, download_output, final_summary, progress_text)

            # Analyze content
            text_content = "\n".join(json.dumps(item) for item in extracted)
            chunks = self.text_processor.chunk_text(text_content)
            combined_response = ""
            for chunk_idx, chunk in enumerate(chunks, 1):
                # NOTE(review): chunks are split by tokens (1200 by default)
                # but only the first 1200 *characters* of each chunk go into
                # the prompt, so part of every chunk may be dropped - confirm
                # whether this truncation is intended.
                # NOTE(review): the prompt asks for a single paragraph without
                # headings or bullets, while clean_response() keeps only
                # bullet lines under a "### Missed Diagnoses" heading -
                # confirm the two are meant to agree.
                prompt = f"""
Analyze this patient record for missed diagnoses. Provide a concise, evidence-based summary
as a single paragraph without headings or bullet points. Include specific clinical findings
with their potential implications and urgent review recommendations. If no missed diagnoses
are found, state 'No missed diagnoses identified'.
Patient Record (Chunk {chunk_idx}/{len(chunks)}):
{chunk[:1200]}
"""
                # Placeholder message that the streamed response overwrites.
                chatbot_output.append({"role": "assistant", "content": ""})
                progress_text = self._update_progress(chunk_idx, len(chunks), "Analyzing")
                yield (chatbot_output, download_output, final_summary, progress_text)

                # Stream response
                chunk_response = ""
                for update in self.process_response_stream(prompt, chatbot_output):
                    chatbot_output[-1] = update
                    chunk_response = update["content"]
                    progress_text = self._update_progress(chunk_idx, len(chunks), "Analyzing")
                    yield (chatbot_output, download_output, final_summary, progress_text)
                combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
                # Release memory between chunks.
                torch.cuda.empty_cache()
                gc.collect()

            # Generate final outputs
            final_summary = self.text_processor.summarize_results(combined_response)
            report_path = os.path.join(DIRECTORIES["reports"], f"{file_hash_value}_report.txt") if file_hash_value else None
            if report_path:
                with open(report_path, "w", encoding="utf-8") as f:
                    f.write(combined_response + "\n\n" + final_summary)
            if report_path and os.path.exists(report_path):
                # The download component starts hidden; reveal it with the report.
                download_output = gr.update(value=report_path, visible=True)
            else:
                download_output = None
            progress_text = gr.update(visible=False)
            yield (chatbot_output, download_output, final_summary, progress_text)
        except Exception as e:
            logger.error(f"Analysis error: {e}")
            chatbot_output.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
            final_summary = f"Error occurred: {str(e)}"
            progress_text = gr.update(visible=False)
            yield (chatbot_output, download_output, final_summary, progress_text)

    def _update_progress(self, current: int, total: int, stage: str = "") -> Dict[str, Any]:
        """Build the component update that shows ``"<stage> - <current>/<total>"``.

        A proper ``gr.update`` is returned (not a bare dict) so Gradio applies
        it as a property update - value and visibility - instead of treating
        the dict itself as the textbox value.
        """
        progress = f"{stage} - {current}/{total}" if stage else f"{current}/{total}"
        return gr.update(value=progress, visible=True)

    def create_interface(self):
        """Build the Gradio Blocks UI (chat view, tools sidebar, input row).

        Returns:
            The ``gr.Blocks`` application, not yet launched.
        """
        css = """
body, .gradio-container {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: var(--background);
color: var(--text-color);
}
.gradio-container {
max-width: 800px;
margin: 0 auto;
padding: 20px;
}
.chat-container {
background: var(--chat-bg);
border-radius: 12px;
padding: 20px;
height: 80vh;
overflow-y: auto;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
}
.message {
margin: 10px 0;
padding: 12px 16px;
border-radius: 12px;
max-width: 80%;
transition: all 0.2s ease;
}
.message.user {
background: #007bff;
color: white;
margin-left: auto;
}
.message.assistant {
background: var(--message-bg);
color: var(--text-color);
}
.input-container {
display: flex;
align-items: center;
margin-top: 20px;
background: var(--chat-bg);
padding: 10px 20px;
border-radius: 25px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.input-textbox {
flex-grow: 1;
border: none;
background: transparent;
color: var(--text-color);
outline: none;
}
.send-btn {
background: #007bff;
color: white;
border: none;
border-radius: 20px;
padding: 8px 16px;
margin-left: 10px;
}
.send-btn:hover {
background: #0056b3;
}
.sidebar {
background: var(--sidebar-bg);
padding: 20px;
border-radius: 12px;
margin-top: 20px;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
}
.sidebar-hidden {
display: none;
}
.header {
text-align: center;
margin-bottom: 20px;
}
.theme-toggle {
position: absolute;
top: 20px;
right: 20px;
background: #007bff;
color: white;
border: none;
border-radius: 20px;
padding: 8px 16px;
}
:root {
--background: #ffffff;
--text-color: #333333;
--chat-bg: #f7f7f8;
--message-bg: #e5e5ea;
--sidebar-bg: #f1f1f1;
}
@media (prefers-color-scheme: dark) {
:root {
--background: #1e2a44;
--text-color: #ffffff;
--chat-bg: #2d3b55;
--message-bg: #3e4c6a;
--sidebar-bg: #2a3650;
}
}
@media (max-width: 600px) {
.gradio-container {
padding: 10px;
}
.chat-container {
height: 70vh;
}
.input-container {
flex-direction: column;
gap: 10px;
}
.send-btn {
width: 100%;
margin-left: 0;
}
}
"""
        # NOTE(review): Gradio documents the Blocks-level `js` argument as a
        # single function executed on page load; this snippet declares several
        # functions instead - confirm toggleTheme/toggleSidebar are reachable
        # from the button handlers below.
        js = """
function toggleTheme() {
const root = document.documentElement;
const isDark = root.style.getPropertyValue('--background') === '#1e2a44';
root.style.setProperty('--background', isDark ? '#ffffff' : '#1e2a44');
root.style.setProperty('--text-color', isDark ? '#333333' : '#ffffff');
root.style.setProperty('--chat-bg', isDark ? '#f7f7f8' : '#2d3b55');
root.style.setProperty('--message-bg', isDark ? '#e5e5ea' : '#3e4c6a');
root.style.setProperty('--sidebar-bg', isDark ? '#f1f1f1' : '#2a3650');
localStorage.setItem('theme', isDark ? 'light' : 'dark');
}
function toggleSidebar() {
const sidebar = document.querySelector('.sidebar');
sidebar.classList.toggle('sidebar-hidden');
}
document.addEventListener('DOMContentLoaded', () => {
const savedTheme = localStorage.getItem('theme');
if (savedTheme === 'dark') toggleTheme();
document.querySelector('.sidebar').classList.add('sidebar-hidden');
});
"""
        with gr.Blocks(theme=gr.themes.Default(), css=css, js=js, title="Clinical Oversight Assistant") as app:
            gr.HTML("""
<div class='header'>
<h1 style='color: var(--text-color);'>🩺 Clinical Oversight Assistant</h1>
<p style='color: var(--text-color); opacity: 0.7;'>
AI-powered analysis of patient records for missed diagnoses
</p>
</div>
""")
            # Event listeners take `js=`; the old `_js=` keyword is rejected by
            # Gradio releases that support Blocks(js=...) / Chatbot(type="messages").
            gr.Button("Toggle Light/Dark Mode", elem_classes="theme-toggle").click(
                None, None, None, js="toggleTheme"
            )
            with gr.Column(elem_classes="chat-container"):
                chatbot = gr.Chatbot(
                    label="Clinical Analysis",
                    height="100%",
                    show_copy_button=True,
                    type="messages",
                    elem_classes="chatbot"
                )
            with gr.Row():
                gr.Button("Show/Hide Tools", variant="secondary").click(
                    None, None, None, js="toggleSidebar"
                )
            with gr.Column(elem_classes="sidebar"):
                file_upload = gr.File(
                    file_types=[".pdf", ".csv", ".xls", ".xlsx"],
                    file_count="multiple",
                    label="Upload Patient Records"
                )
                gr.Markdown("### 📝 Summary of Findings")
                final_summary = gr.Markdown(
                    "Analysis results will appear here..."
                )
                gr.Markdown("### 📂 Report Download")
                download_output = gr.File(
                    label="Full Report",
                    visible=False,
                    interactive=False
                )
            with gr.Row(elem_classes="input-container"):
                msg_input = gr.Textbox(
                    placeholder="Ask about potential oversights or upload files...",
                    show_label=False,
                    container=False,
                    elem_classes="input-textbox",
                    autofocus=True
                )
                send_btn = gr.Button(
                    "Analyze",
                    variant="primary",
                    elem_classes="send-btn"
                )
            progress_text = gr.Textbox(
                label="Progress Status",
                visible=False,
                interactive=False
            )
            # The button and the Enter key trigger the same analysis pipeline.
            send_btn.click(
                self.analyze,
                inputs=[msg_input, chatbot, file_upload],
                outputs=[chatbot, download_output, final_summary, progress_text],
                show_progress="hidden"
            )
            msg_input.submit(
                self.analyze,
                inputs=[msg_input, chatbot, file_upload],
                outputs=[chatbot, download_output, final_summary, progress_text],
                show_progress="hidden"
            )
            # Reset every component to a clean state on page load.
            app.load(
                lambda: [
                    [], None, "", "", None, gr.update(visible=False)
                ],
                outputs=[chatbot, download_output, final_summary, msg_input, file_upload, progress_text],
                queue=False
            )
        return app
# ==================== APPLICATION ENTRY POINT ====================
if __name__ == "__main__":
    # Script entry point: build the app, serve it, and always clean up any
    # torch.distributed process group on the way out.
    try:
        logger.info("Starting Clinical Oversight Assistant...")
        app = ClinicalOversightApp()
        interface = app.create_interface()
        # Bounded request queue with the queue API closed.
        queued_interface = interface.queue(api_open=False, max_size=20)
        # Listen on all interfaces, port 7860; only the reports directory is
        # exposed for file downloads.
        queued_interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            allowed_paths=[DIRECTORIES["reports"]],
            share=False,
        )
    except Exception as e:
        # Log, then re-raise so the process still exits with a failure.
        logger.error(f"Application failed to start: {e}")
        raise
    finally:
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()