import gc
import hashlib
import json
import logging
import os
import re
import shutil
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional

import gradio as gr
import pandas as pd
import pdfplumber
import psutil
import torch
from diskcache import Cache
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

BASE_DIR = Path("/data/hf_cache")
DIRECTORIES = {
    "models": BASE_DIR / "txagent_models",
    "tools": BASE_DIR / "tool_cache",
    "cache": BASE_DIR / "cache",
    "reports": BASE_DIR / "reports",
    "vllm": BASE_DIR / "vllm_cache",
}

for dir_path in DIRECTORIES.values():
    dir_path.mkdir(parents=True, exist_ok=True)

os.environ.update({
    "HF_HOME": str(DIRECTORIES["models"]),
    "TRANSFORMERS_CACHE": str(DIRECTORIES["models"]),
    "VLLM_CACHE_DIR": str(DIRECTORIES["vllm"]),
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_LAUNCH_BLOCKING": "1",
})
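# NOTE: HF_HOME / TRANSFORMERS_CACHE / VLLM_CACHE_DIR point model and cache
# downloads at the shared /data/hf_cache volume. TOKENIZERS_PARALLELISM=false
# silences the Hugging Face tokenizers fork warning, and CUDA_LAUNCH_BLOCKING=1
# makes CUDA errors surface synchronously (easier to debug, slightly slower).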

# Make the bundled src/ package importable before pulling in TxAgent.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent

class FileProcessor:
    """Handles all file processing operations"""

    @staticmethod
    def extract_pdf_content(file_path: str) -> str:
        """Extract text from a PDF, processing page batches in parallel"""
        try:
            with pdfplumber.open(file_path) as pdf:
                total_pages = len(pdf.pages)
            if not total_pages:
                return ""

            def process_batch(start: int, end: int) -> List[tuple]:
                results = []
                with pdfplumber.open(file_path) as pdf:
                    # enumerate() keeps the absolute page index so results can be
                    # re-ordered after the batches complete out of order.
                    for page_num, page in enumerate(pdf.pages[start:end], start=start):
                        text = page.extract_text() or ""
                        results.append((page_num, f"=== Page {page_num + 1} ===\n{text.strip()}"))
                return results

            batch_size = min(10, total_pages)
            batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
            text_chunks = [""] * total_pages

            with ThreadPoolExecutor(max_workers=min(6, os.cpu_count() or 4)) as executor:
                futures = [executor.submit(process_batch, start, end) for start, end in batches]
                for future in as_completed(futures):
                    for page_num, text in future.result():
                        text_chunks[page_num] = text

            return "\n\n".join(filter(None, text_chunks))
        except Exception as e:
            logger.error(f"PDF extraction failed: {e}")
            return f"PDF processing error: {str(e)}"
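
    # NOTE: each worker opens its own pdfplumber handle because the underlying
    # pdfminer objects are not safe to share across threads.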

    @staticmethod
    def process_tabular_data(file_path: str, file_type: str) -> List[Dict]:
        """Process Excel or CSV files"""
        try:
            if file_type == "csv":
                chunks = pd.read_csv(
                    file_path,
                    header=None,
                    dtype=str,
                    encoding_errors='replace',
                    on_bad_lines='skip',
                    chunksize=10000
                )
                try:
                    df = pd.concat(chunks, ignore_index=True)
                except ValueError:  # nothing to concatenate (empty file)
                    df = pd.DataFrame()
            else:
                try:
                    df = pd.read_excel(file_path, engine='openpyxl', header=None, dtype=str)
                except Exception:
                    df = pd.read_excel(file_path, engine='xlrd', header=None, dtype=str)

            return [{
                "filename": os.path.basename(file_path),
                "rows": df.where(pd.notnull(df), "").astype(str).values.tolist(),
                "type": file_type
            }]
        except Exception as e:
            logger.error(f"{file_type.upper()} processing failed: {e}")
            return [{"error": f"{file_type.upper()} processing error: {str(e)}"}]
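
    # NOTE: openpyxl reads .xlsx workbooks; xlrd (2.x) only supports the legacy
    # .xls format, so it is used purely as a fallback here.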

    @classmethod
    def handle_upload(cls, file_path: str, file_type: str) -> List[Dict]:
        """Route file processing based on type"""
        processor_map = {
            "pdf": cls.extract_pdf_content,
            "xls": lambda x: cls.process_tabular_data(x, "excel"),
            "xlsx": lambda x: cls.process_tabular_data(x, "excel"),
            "csv": lambda x: cls.process_tabular_data(x, "csv")
        }

        if file_type not in processor_map:
            return [{"error": f"Unsupported file type: {file_type}"}]

        try:
            result = processor_map[file_type](file_path)
            if file_type == "pdf":
                return [{
                    "filename": os.path.basename(file_path),
                    "content": result,
                    "type": "pdf"
                }]
            return result
        except Exception as e:
            logger.error(f"File processing failed: {e}")
            return [{"error": f"File processing error: {str(e)}"}]
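
    # NOTE: PDF uploads come back as one text blob wrapped in a record with the
    # same filename/type keys as tabular results, so downstream code can treat
    # every upload as a list of JSON-serializable records.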

class TextAnalyzer:
    """Handles text processing and analysis"""

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
        # 10 GB on-disk cache (currently not consulted elsewhere in this module).
        self.cache = Cache(DIRECTORIES["cache"], size_limit=10 * 1024 ** 3)

    def chunk_content(self, text: str, max_tokens: int = 1800) -> List[str]:
        """Split text into token-limited chunks"""
        tokens = self.tokenizer.encode(text)
        return [
            self.tokenizer.decode(tokens[i:i + max_tokens])
            for i in range(0, len(tokens), max_tokens)
        ]
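
    # NOTE: chunks are cut on fixed token windows, so a boundary can fall
    # mid-sentence; downstream prompts should tolerate truncated context.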

    def clean_output(self, text: str) -> str:
        """Clean and format the model response"""
        text = text.encode("utf-8", "ignore").decode("utf-8")
        text = re.sub(
            r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\."
            r"|Since the previous attempts.*?\.|I need to.*?medications\."
            r"|Retrieving tools.*?\.", "", text, flags=re.DOTALL
        )

        diagnoses = []
        in_section = False

        for line in text.splitlines():
            line = line.strip()
            if not line:
                continue
            if re.match(r"###\s*Missed Diagnoses", line):
                in_section = True
                continue
            if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
                in_section = False
                continue
            if in_section and re.match(r"-\s*.+", line):
                diagnosis = re.sub(r"^-\s*", "", line).strip()
                if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
                    diagnoses.append(diagnosis)

        return " ".join(diagnoses) if diagnoses else ""

    def generate_summary(self, analysis: str) -> str:
        """Create a concise clinical summary"""
        findings = []
        # Split on the per-chunk headers written by analyze_records ("--- Analysis Part N ---").
        for chunk in analysis.split("--- Analysis Part"):
            chunk = chunk.strip()
            if not chunk or "No oversights identified" in chunk:
                continue

            in_section = False
            for line in chunk.splitlines():
                line = line.strip()
                if not line:
                    continue
                if re.match(r"###\s*Missed Diagnoses", line):
                    in_section = True
                    continue
                if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
                    in_section = False
                    continue
                if in_section and re.match(r"-\s*.+", line):
                    finding = re.sub(r"^-\s*", "", line).strip()
                    if finding and not re.match(r"No issues identified", finding, re.IGNORECASE):
                        findings.append(finding)

        unique_findings = list(dict.fromkeys(findings))

        if not unique_findings:
            return "No clinical concerns identified in the provided records."

        if len(unique_findings) > 1:
            summary = "Potential concerns include: " + ", ".join(unique_findings[:-1])
            summary += f", and {unique_findings[-1]}"
        else:
            summary = "Potential concern identified: " + unique_findings[0]

        return summary + ". Recommend urgent clinical review."
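
    # NOTE: generate_summary re-parses the "### Missed Diagnoses" sections on its
    # own rather than reusing clean_output, so the two extraction rules need to be
    # kept in sync if the report format changes.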

class ClinicalAgent:
    """Main application controller"""

    def __init__(self):
        self.agent = self._init_agent()
        self.file_processor = FileProcessor()
        self.text_analyzer = TextAnalyzer()

    def _init_agent(self) -> Any:
        """Initialize the AI agent"""
        logger.info("Initializing clinical agent...")
        self._log_system_status("pre-init")

        tool_path = DIRECTORIES["tools"] / "new_tool.json"
        if not tool_path.exists():
            default_tools = Path("data/new_tool.json")
            if default_tools.exists():
                shutil.copy(default_tools, tool_path)

        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={"new_tool": str(tool_path)},
            force_finish=True,
            enable_checker=False,
            step_rag_num=4,
            seed=100,
            additional_default_tools=[],
        )
        agent.init_model()

        self._log_system_status("post-init")
        logger.info("Clinical agent ready")
        return agent

    def _log_system_status(self, phase: str) -> None:
        """Log system resource utilization"""
        try:
            cpu = psutil.cpu_percent(interval=1)
            mem = psutil.virtual_memory()
            logger.info(f"[{phase}] CPU: {cpu:.1f}% | RAM: {mem.used//(1024**2)}MB/{mem.total//(1024**2)}MB")

            gpu_info = subprocess.run(
                ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu",
                 "--format=csv,nounits,noheader"],
                capture_output=True, text=True
            )
            if gpu_info.returncode == 0:
                used, total, util = gpu_info.stdout.strip().split(", ")
                logger.info(f"[{phase}] GPU: {used}MB/{total}MB | Util: {util}%")
        except Exception as e:
            logger.error(f"Resource monitoring failed: {e}")
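
    # NOTE: psutil.cpu_percent(interval=1) blocks for about one second, so each
    # call to this method adds roughly that much latency to the phase it logs.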

    def process_stream(self, prompt: str, history: List[Dict]) -> Generator[Dict, None, None]:
        """Stream the agent's responses"""
        # The history argument is accepted for interface symmetry but not forwarded;
        # each prompt starts a fresh conversation with the agent.
        full_response = ""
        for chunk in self.agent.run_gradio_chat(prompt, [], 0.2, 512, 2048, False, []):
            if not chunk:
                continue

            if isinstance(chunk, list):
                for msg in chunk:
                    if hasattr(msg, 'content') and msg.content:
                        cleaned = self.text_analyzer.clean_output(msg.content)
                        if cleaned:
                            full_response += cleaned + " "
                            yield {"role": "assistant", "content": full_response}
            elif isinstance(chunk, str) and chunk.strip():
                cleaned = self.text_analyzer.clean_output(chunk)
                if cleaned:
                    full_response += cleaned + " "
                    yield {"role": "assistant", "content": full_response}

    def analyze_records(self, message: str, history: List[Dict], files: List) -> Generator[tuple, None, None]:
        """Main analysis workflow"""
        outputs = {
            "chatbot": history.copy(),
            "download_output": None,
            "final_summary": "",
            "progress": {"value": "Initializing...", "visible": True}
        }
        yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

        try:
            history.append({"role": "user", "content": message})
            outputs["chatbot"] = history
            yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

            extracted = []
            file_hash = ""

            if files:
                with ThreadPoolExecutor(max_workers=4) as executor:
                    futures = []
                    for f in files:
                        file_type = Path(f.name).suffix[1:].lower()
                        futures.append(executor.submit(
                            self.file_processor.handle_upload,
                            f.name,
                            file_type
                        ))

                    for i, future in enumerate(as_completed(futures), 1):
                        try:
                            extracted.extend(future.result())
                            outputs["progress"] = self._format_progress(i, len(files), "Processing files")
                            yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])
                        except Exception as e:
                            logger.error(f"File processing failed: {e}")
                            extracted.append({"error": str(e)})

                if os.path.exists(files[0].name):
                    with open(files[0].name, "rb") as fh:
                        file_hash = hashlib.md5(fh.read()).hexdigest()

                history.append({"role": "assistant", "content": "✅ Files processed successfully"})
                outputs.update({
                    "chatbot": history,
                    "progress": self._format_progress(len(files), len(files), "Files processed")
                })
                yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

            text_content = "\n".join(json.dumps(item) for item in extracted)
            chunks = self.text_analyzer.chunk_content(text_content)
            full_analysis = ""

            for idx, chunk in enumerate(chunks, 1):
                prompt = f"""
Analyze this clinical documentation for potential missed diagnoses. Provide:
1. Specific clinical findings with references (e.g., "Elevated BP (160/95) on page 3")
2. Their clinical significance
3. Urgency of review
Use concise, continuous prose without bullet points. If no concerns, state "No missed diagnoses identified."

Document Excerpt (Part {idx}/{len(chunks)}):
{chunk[:1750]}
"""
                history.append({"role": "assistant", "content": ""})
                outputs.update({
                    "chatbot": history,
                    "progress": self._format_progress(idx, len(chunks), "Analyzing")
                })
                yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

                chunk_response = ""
                for update in self.process_stream(prompt, history):
                    history[-1] = update
                    chunk_response = update["content"]
                    outputs.update({
                        "chatbot": history,
                        "progress": self._format_progress(idx, len(chunks), "Analyzing")
                    })
                    yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

                full_analysis += f"--- Analysis Part {idx} ---\n{chunk_response}\n"
                # Free GPU memory between chunks (no-op on CPU-only hosts).
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()

            summary = self.text_analyzer.generate_summary(full_analysis)
            report_path = DIRECTORIES["reports"] / f"{file_hash}_report.txt" if file_hash else None

            if report_path:
                with open(report_path, "w", encoding="utf-8") as f:
                    f.write(full_analysis + "\n\nSUMMARY:\n" + summary)

            outputs.update({
                "download_output": str(report_path) if report_path and report_path.exists() else None,
                "final_summary": summary,
                "progress": {"visible": False}
            })
            yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])

        except Exception as e:
            logger.error(f"Analysis failed: {e}")
            history.append({"role": "assistant", "content": f"❌ Analysis error: {str(e)}"})
            outputs.update({
                "chatbot": history,
                "final_summary": f"Error: {str(e)}",
                "progress": {"visible": False}
            })
            yield (outputs["chatbot"], outputs["download_output"], outputs["final_summary"], outputs["progress"])
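
    # NOTE: every yield above is a 4-tuple in the order (chatbot, download_output,
    # final_summary, progress), matching the outputs list wired up in create_interface.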

    def _format_progress(self, current: int, total: int, stage: str = "") -> Dict[str, Any]:
        """Format a progress update for the UI"""
        status = f"{stage} - {current}/{total}" if stage else f"{current}/{total}"
        return {"value": status, "visible": True, "label": f"Progress: {status}"}
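
    # NOTE: the returned dict follows Gradio's update-dict convention
    # (value/visible/label) and is consumed by the progress_tracker textbox.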

    def create_interface(self) -> gr.Blocks:
        """Build the Gradio interface"""
        css = """
        /* ==================== BASE STYLES ==================== */
        :root {
            --primary-color: #4f46e5;
            --primary-dark: #4338ca;
            --border-radius: 8px;
            --transition: all 0.3s ease;
            --shadow: 0 4px 12px rgba(0,0,0,0.1);
            --font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            --background: #ffffff;
            --text-color: #1e293b;
            --chat-bg: #f8fafc;
            --message-bg: #e2e8f0;
            --panel-bg: rgba(248, 250, 252, 0.9);
            --panel-dark-bg: rgba(30, 41, 59, 0.9);
        }

        [data-theme="dark"] {
            --background: #1e2a44;
            --text-color: #f1f5f9;
            --chat-bg: #2d3b55;
            --message-bg: #475569;
            --panel-bg: var(--panel-dark-bg);
        }

        body, .gradio-container {
            font-family: var(--font-family);
            background: var(--background);
            color: var(--text-color);
            margin: 0;
            padding: 0;
            transition: var(--transition);
        }

        /* ==================== LAYOUT ==================== */
        .gradio-container {
            max-width: 1200px;
            margin: 0 auto;
            padding: 1.5rem;
            display: flex;
            flex-direction: column;
            gap: 1.5rem;
        }

        .chat-container {
            background: var(--chat-bg);
            border-radius: var(--border-radius);
            border: 1px solid #e2e8f0;
            padding: 1.5rem;
            min-height: 50vh;
            max-height: 80vh;
            overflow-y: auto;
            box-shadow: var(--shadow);
            margin-bottom: 4rem;
        }

        .summary-panel {
            background: var(--panel-bg);
            border-left: 4px solid var(--primary-color);
            padding: 1rem;
            border-radius: var(--border-radius);
            margin-bottom: 1rem;
            box-shadow: var(--shadow);
            backdrop-filter: blur(8px);
        }

        .upload-area {
            border: 2px dashed #cbd5e1;
            border-radius: var(--border-radius);
            padding: 1.5rem;
            margin: 0.75rem 0;
            transition: var(--transition);
        }

        .upload-area:hover {
            border-color: var(--primary-color);
            background: rgba(79, 70, 229, 0.05);
        }

        /* ==================== COMPONENTS ==================== */
        .chat__message {
            margin: 0.75rem 0;
            padding: 0.75rem 1rem;
            border-radius: var(--border-radius);
            max-width: 85%;
            transition: var(--transition);
            background: var(--message-bg);
            border: 1px solid rgba(0,0,0,0.05);
            animation: messageFade 0.3s ease;
        }

        .chat__message:hover {
            transform: translateY(-2px);
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }

        .chat__message.user {
            background: linear-gradient(135deg, var(--primary-color), var(--primary-dark));
            color: white;
            margin-left: auto;
        }

        .chat__message.assistant {
            background: var(--message-bg);
            color: var(--text-color);
        }

        .input-container {
            display: flex;
            align-items: center;
            gap: 0.75rem;
            background: var(--chat-bg);
            padding: 0.75rem 1rem;
            border-radius: 1.5rem;
            box-shadow: var(--shadow);
            position: sticky;
            bottom: 1rem;
            z-index: 10;
        }

        .input__textbox {
            flex-grow: 1;
            border: none;
            background: transparent;
            color: var(--text-color);
            outline: none;
            font-size: 1rem;
        }

        .input__textbox:focus {
            border-bottom: 2px solid var(--primary-color);
        }

        .submit-btn {
            background: linear-gradient(135deg, var(--primary-color), var(--primary-dark));
            color: white;
            border: none;
            border-radius: 1rem;
            padding: 0.5rem 1.25rem;
            font-size: 0.9rem;
            transition: var(--transition);
        }

        .submit-btn:hover {
            transform: scale(1.05);
        }

        .submit-btn:active {
            animation: glow 0.3s ease;
        }

        .tooltip {
            position: relative;
        }

        .tooltip:hover::after {
            content: attr(data-tip);
            position: absolute;
            top: -2.5rem;
            left: 50%;
            transform: translateX(-50%);
            background: #1e293b;
            color: white;
            padding: 0.4rem 0.8rem;
            border-radius: 0.4rem;
            font-size: 0.85rem;
            max-width: 200px;
            white-space: normal;
            text-align: center;
            z-index: 1000;
            animation: fadeIn 0.3s ease;
        }

        .progress-tracker {
            position: relative;
            padding: 0.5rem;
            background: var(--message-bg);
            border-radius: var(--border-radius);
            margin-top: 0.75rem;
            overflow: hidden;
        }

        .progress-tracker::before {
            content: '';
            position: absolute;
            top: 0;
            left: 0;
            height: 100%;
            width: 0;
            background: linear-gradient(to right, var(--primary-color), var(--primary-dark));
            opacity: 0.3;
            animation: progress 2s ease-in-out infinite;
        }

        /* ==================== ANIMATIONS ==================== */
        @keyframes glow {
            0%, 100% { transform: scale(1); opacity: 1; }
            50% { transform: scale(1.1); opacity: 0.8; }
        }

        @keyframes fadeIn {
            from { opacity: 0; }
            to { opacity: 1; }
        }

        @keyframes messageFade {
            from { opacity: 0; transform: translateY(10px) scale(0.95); }
            to { opacity: 1; transform: translateY(0) scale(1); }
        }

        @keyframes progress {
            0% { width: 0; }
            50% { width: 60%; }
            100% { width: 0; }
        }

        /* ==================== THEMES ==================== */
        [data-theme="dark"] .chat-container {
            border-color: #475569;
        }

        [data-theme="dark"] .upload-area {
            border-color: #64748b;
        }

        [data-theme="dark"] .upload-area:hover {
            background: rgba(79, 70, 229, 0.1);
        }

        [data-theme="dark"] .summary-panel {
            border-left-color: #818cf8;
        }

        /* ==================== MEDIA QUERIES ==================== */
        @media (max-width: 768px) {
            .gradio-container {
                padding: 1rem;
            }

            .chat-container {
                min-height: 40vh;
                max-height: 70vh;
                margin-bottom: 3.5rem;
            }

            .summary-panel {
                padding: 0.75rem;
            }

            .upload-area {
                padding: 1rem;
            }

            .input-container {
                gap: 0.5rem;
                padding: 0.5rem;
            }

            .submit-btn {
                padding: 0.4rem 1rem;
            }
        }

        @media (max-width: 480px) {
            .chat-container {
                padding: 1rem;
                margin-bottom: 3rem;
            }

            .input-container {
                flex-direction: column;
                padding: 0.5rem;
            }

            .input__textbox {
                font-size: 0.9rem;
            }

            .submit-btn {
                width: 100%;
                padding: 0.5rem;
                font-size: 0.85rem;
            }

            .chat__message {
                max-width: 90%;
                padding: 0.5rem 0.75rem;
            }

            .tooltip:hover::after {
                top: auto;
                bottom: -2.5rem;
                max-width: 80vw;
            }
        }
        """

        js = """
        function applyTheme(theme) {
            document.documentElement.setAttribute('data-theme', theme);
            localStorage.setItem('theme', theme);
        }

        document.addEventListener('DOMContentLoaded', () => {
            const savedTheme = localStorage.getItem('theme') || 'light';
            applyTheme(savedTheme);
        });
        """
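
        # NOTE: the JS snippet above restores the last-used theme from localStorage;
        # the [data-theme="dark"] rules in the CSS then take effect without a reload.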

        with gr.Blocks(
            theme=gr.themes.Soft(
                primary_hue="indigo",
                secondary_hue="blue",
                neutral_hue="slate"
            ),
            title="Clinical Oversight Assistant",
            css=css,
            js=js
        ) as app:

            gr.Markdown("""
            <div style='text-align: center; margin-bottom: 24px;'>
                <h1 style='color: var(--primary-color); margin-bottom: 8px;'>🩺 Clinical Oversight Assistant</h1>
                <p style='color: #64748b;'>
                    AI-powered analysis for identifying potential missed diagnoses in patient records
                </p>
            </div>
            """)

            with gr.Row(equal_height=False):
                with gr.Column(scale=3):
                    gr.Markdown(
                        "<div class='tooltip' data-tip='View conversation history'>**Clinical Analysis Conversation**</div>"
                    )
                    chatbot = gr.Chatbot(
                        label="",
                        height=650,
                        show_copy_button=True,
                        avatar_images=(
                            "assets/user.png",
                            "assets/assistant.png"
                        ) if Path("assets/user.png").exists() else None,
                        bubble_full_width=False,
                        type="messages",
                        elem_classes=["chat-container"]
                    )

                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown(
                            "<div class='tooltip' data-tip='Summary of findings'>**Clinical Summary**</div>"
                        )
                        final_summary = gr.Markdown(
                            "<div class='tooltip' data-tip='Analysis results'>Analysis results will appear here...</div>",
                            elem_classes=["summary-panel"]
                        )

                    with gr.Group():
                        gr.Markdown(
                            "<div class='tooltip' data-tip='Download report'>**Report Export**</div>"
                        )
                        download_output = gr.File(
                            label="Download Full Analysis",
                            visible=False,
                            interactive=False
                        )

            with gr.Row():
                file_upload = gr.File(
                    file_types=[".pdf", ".csv", ".xls", ".xlsx"],
                    file_count="multiple",
                    label="Upload Patient Records",
                    elem_classes=["upload-area"],
                    elem_id="file-upload"
                )

            with gr.Row(elem_classes=["input-container"]):
                user_input = gr.Textbox(
                    placeholder="Enter your clinical query or analysis request...",
                    show_label=False,
                    container=False,
                    scale=7,
                    autofocus=True,
                    elem_classes=["input__textbox"],
                    elem_id="user-input"
                )
                submit_btn = gr.Button(
                    "Analyze",
                    variant="primary",
                    scale=1,
                    min_width=120,
                    elem_classes=["submit-btn"],
                    elem_id="submit-btn"
                )

            progress_tracker = gr.Textbox(
                label="Analysis Progress",
                visible=False,
                interactive=False,
                elem_classes=["progress-tracker"],
                elem_id="progress-tracker"
            )

            submit_btn.click(
                self.analyze_records,
                inputs=[user_input, chatbot, file_upload],
                outputs=[chatbot, download_output, final_summary, progress_tracker],
                show_progress="hidden"
            )

            user_input.submit(
                self.analyze_records,
                inputs=[user_input, chatbot, file_upload],
                outputs=[chatbot, download_output, final_summary, progress_tracker],
                show_progress="hidden"
            )

            app.load(
                lambda: [[], None, "<div class='tooltip' data-tip='Analysis results'>Analysis results will appear here...</div>", "", None, {"visible": False}],
                outputs=[chatbot, download_output, final_summary, user_input, file_upload, progress_tracker],
                queue=False
            )

        return app

if __name__ == "__main__":
    try:
        logger.info("Launching Clinical Oversight Assistant...")
        clinical_app = ClinicalAgent()
        interface = clinical_app.create_interface()

        interface.queue(
            api_open=False,
            max_size=20
        ).launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            allowed_paths=[str(DIRECTORIES["reports"])],
            share=False
        )
    except Exception as e:
        logger.error(f"Application failed to start: {e}")
        raise
    finally:
        # Tear down any torch.distributed state (possibly initialized by the model backend).
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()