|
import sys |
|
import os |
|
import polars as pl |
|
import json |
|
import gradio as gr |
|
from typing import List, Tuple |
|
import hashlib |
|
import shutil |
|
import re |
|
from datetime import datetime |
|
import time |
|
import asyncio |
|
import aiofiles |
|
import cachetools |
|
import logging |
|
import markdown |
|
|
|
|
|
# Root logging configuration for the whole app: INFO and above, timestamped.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module logger used by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Root for all persistent state (model files, tool definitions, file cache,
# generated reports). Created eagerly at import time.
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

# One sub-directory per kind of state, all under persistent_dir.
model_cache_dir, tool_cache_dir, file_cache_dir, report_dir = (
    os.path.join(persistent_dir, sub_name)
    for sub_name in ("txagent_models", "tool_cache", "cache", "reports")
)

for directory in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(directory, exist_ok=True)

# Point the Hugging Face cache locations at the persistent model directory.
os.environ.update(HF_HOME=model_cache_dir, TRANSFORMERS_CACHE=model_cache_dir)
|
|
|
# Directory containing this script; the bundled ``src`` tree sits beside it.
current_dir = os.path.split(os.path.abspath(__file__))[0]
src_path = os.path.abspath(os.path.join(current_dir, "src"))
# Prepend so the local ``txagent`` package wins over any installed copy.
sys.path.insert(0, src_path)
|
|
|
from txagent.txagent import TxAgent |
|
|
|
|
|
# Bounded LRU cache holding at most 100 entries.
# NOTE(review): not referenced by any code visible in this file — confirm it is still needed.
cache = cachetools.LRUCache(100)
|
|
|
def file_hash(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of the file at *path*.

    Used in this app as a content fingerprint (it names the generated report
    file). The file is read in ``chunk_size``-byte pieces so large uploads are
    never loaded into memory all at once.

    Args:
        path: path of the file to hash.
        chunk_size: number of bytes read per iteration; must be positive.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
        OSError: if the file cannot be opened or read.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
# Bracketed spans such as "[TOOL_CALLS]" (possibly spanning lines) and
# stand-alone "None" words are removed.
# NOTE(review): this also erases legitimate "None" answers (e.g. "Allergies: None").
_BRACKET_OR_NONE_RE = re.compile(r"\[.*?\]|\bNone\b", flags=re.DOTALL)
# Everything outside this whitelist is deleted. "/ % = + '" are allowed so
# clinical values ("120/80", "95%") and the "item = response" records built
# for the prompt survive cleaning.
_DISALLOWED_CHARS_RE = re.compile(r"[^\n#\-\*\w\s\.,:;\(\)/%=+'?!]+")
_EXCESS_BLANK_LINES_RE = re.compile(r"\n{3,}")


def clean_response(text: str) -> str:
    """Normalise record or model text before display or prompting.

    Steps, in order:
      1. drop characters that cannot be encoded as UTF-8 (lone surrogates);
      2. remove ``[...]`` spans and stand-alone ``None`` words;
      3. remove characters outside a whitelist of word characters, whitespace
         and light punctuation (``# - * . , : ; ( ) / % = + ' ? !``);
      4. collapse runs of three or more newlines into a single blank line;
      5. strip leading/trailing whitespace.

    Empty input (or ``None``) returns ``""``.
    """
    if not text:
        return ""
    # Lone surrogates cannot be encoded as UTF-8; drop them instead of raising.
    text = text.encode("utf-8", "ignore").decode("utf-8")
    text = _BRACKET_OR_NONE_RE.sub("", text)
    text = _DISALLOWED_CHARS_RE.sub("", text)
    # Collapse blank lines last so runs created by the character strip are caught too.
    text = _EXCESS_BLANK_LINES_RE.sub("\n\n", text)
    return text.strip()
|
|
|
# Columns the rest of the pipeline reads from the uploaded sheet.
_EXCEL_COLUMNS = (
    "Booking Number", "Form Name", "Form Item", "Item Response",
    "Interviewer", "Interview Date", "Description",
)


async def load_and_clean_data(file_path: str) -> pl.DataFrame:
    """Load patient records from an Excel file and normalise their text columns.

    Each column in ``_EXCEL_COLUMNS`` is cast to string, stripped of surrounding
    whitespace and has nulls replaced by ``""``. Only rows whose
    ``Booking Number`` starts with ``"BKG"`` are kept.

    The synchronous polars read runs in a worker thread, so the event loop
    serving the UI is not blocked while the workbook is parsed.

    Raises:
        ValueError: if any required column is missing from the sheet.
        Exception: any error raised while reading is logged and re-raised.
    """
    def _load() -> pl.DataFrame:
        raw = pl.read_excel(file_path)
        missing = [col for col in _EXCEL_COLUMNS if col not in raw.columns]
        if missing:
            raise ValueError(
                f"Excel file is missing required column(s): {', '.join(missing)}"
            )
        return raw.with_columns([
            # Cast first: read_excel may infer non-string dtypes (e.g. dates or
            # numbers), and the .str namespace only works on string columns.
            pl.col(col).cast(pl.Utf8).str.strip_chars().fill_null("").alias(col)
            for col in _EXCEL_COLUMNS
        ]).filter(pl.col("Booking Number").str.starts_with("BKG"))

    try:
        logger.info(f"Loading Excel file: {file_path}")
        df = await asyncio.to_thread(_load)
        logger.info(f"Loaded {len(df)} records")
        return df
    except Exception as e:
        logger.error(f"Error loading data: {str(e)}")
        raise
|
|
|
# (lower-case keyword, label) pairs screened in each Description. Order matches
# the original checks, so the resulting dict keeps the same key order.
_SYMPTOM_KEYWORDS = (
    ("chest discomfort", "Chest Discomfort"),
    ("headache", "Headaches"),  # substring also matches "headaches"
    ("weight loss", "Weight Loss"),
    # NOTE(review): any "back pain" mention is labelled chronic — confirm intended.
    ("back pain", "Chronic Back Pain"),
    ("cough", "Persistent Cough"),
)
_YEAR_RE = re.compile(r"\b(?:19|20)\d{2}\b")


def _interview_year_phrase(df) -> str:
    """Describe the four-digit years found in the ``Interview Date`` column.

    Returns ``" in YYYY"`` for a single year, ``" from YYYY to YYYY"`` for a
    range, or ``""`` when the column is absent or holds no recognisable year.
    """
    if "Interview Date" not in df.columns:
        return ""
    years = sorted({
        int(match)
        for value in df["Interview Date"]
        if value
        for match in _YEAR_RE.findall(str(value))
    })
    if not years:
        return ""
    if years[0] == years[-1]:
        return f" in {years[0]}"
    return f" from {years[0]} to {years[-1]}"


def generate_summary(df: pl.DataFrame) -> tuple[str, dict]:
    """Build the markdown summary section and per-symptom record counts.

    Each ``Description`` is lower-cased and checked against ``_SYMPTOM_KEYWORDS``.
    A record counts at most once per symptom. The year phrase is derived from
    the ``Interview Date`` values rather than hard-coded.

    Returns:
        ``(summary_markdown, symptom_counts)``. ``symptom_counts`` maps a symptom
        label to its number of matching records and contains only labels that
        matched at least once.
    """
    symptom_counts: dict = {}
    for desc in df["Description"]:
        text = (desc or "").lower()
        for keyword, label in _SYMPTOM_KEYWORDS:
            if keyword in text:
                symptom_counts[label] = symptom_counts.get(label, 0) + 1

    total_records = len(df)
    unique_bookings = df["Booking Number"].n_unique()
    chest_count = symptom_counts.get("Chest Discomfort", 0)

    if chest_count:
        # NOTE(review): the referral/follow-up wording is fixed text, not derived from the data.
        interesting_fact = (
            f"Chest discomfort was reported in {chest_count} records, "
            "frequently leading to ECG/lab referrals. Inconsistent follow-up documentation raises "
            "concerns about potential missed cardiovascular diagnoses."
        )
        key_findings = (
            f"Key findings include a high prevalence of chest discomfort ({chest_count} instances), "
            "suggesting possible underdiagnosis of cardiovascular issues."
        )
    else:
        # Do not claim a "high prevalence" that the data does not show.
        key_findings = "No records mentioned chest discomfort."
        if symptom_counts:
            top_label, top_count = max(symptom_counts.items(), key=lambda kv: kv[1])
            interesting_fact = (
                f"The most frequently mentioned screened symptom was "
                f"{top_label.lower()} ({top_count} records)."
            )
        else:
            interesting_fact = "No screened symptoms were mentioned in the record descriptions."

    summary = (
        "## Summary\n\n"
        f"Analyzed {total_records:,} patient records from {unique_bookings:,} unique bookings"
        f"{_interview_year_phrase(df)}. "
        f"{key_findings}\n\n"
        f"### Interesting Fact\n{interesting_fact}\n"
    )
    return summary, symptom_counts
|
|
|
def prepare_aggregate_prompt(df: pl.DataFrame) -> str:
    """Build a single LLM prompt containing every record in *df*, grouped by booking.

    Rows are grouped by ``Booking Number`` in first-appearance order. Each
    booking starts with a ``Booking <id>:`` line, followed by one entry per row:
    ``- <Form Name>: <Form Item> = <Item Response> (<Interview Date> by
    <Interviewer>)`` and then the row's ``Description``. Every line is passed
    through ``clean_response``. The prompt ends with the four markdown headings
    the model is asked to fill in.

    NOTE(review): the whole dataset is embedded without truncation; large
    files may exceed the model's context window.
    """
    # maintain_order=True keeps bookings in first-appearance order; without it
    # polars' group_by order is not guaranteed, so the prompt would vary per run.
    groups = df.group_by("Booking Number", maintain_order=True).agg([
        pl.col("Form Name"), pl.col("Form Item"),
        pl.col("Item Response"), pl.col("Interviewer"),
        pl.col("Interview Date"), pl.col("Description"),
    ])

    records = []
    for booking in groups.iter_rows(named=True):
        # Label each booking so the model can tell which rows belong to the same patient.
        records.append(clean_response(f"Booking {booking['Booking Number']}:"))
        for form, item, response, interviewer, date, description in zip(
            booking["Form Name"], booking["Form Item"], booking["Item Response"],
            booking["Interviewer"], booking["Interview Date"], booking["Description"],
        ):
            record = (
                f"- {form}: {item} = {response} "
                f"({date} by {interviewer})\n{description}"
            )
            records.append(clean_response(record))

    record_text = "\n".join(records)
    prompt = f"""
Patient Medical History Analysis

Instructions:
Analyze the following aggregated patient data from all bookings to identify potential missed diagnoses, medication conflicts, incomplete assessments, and urgent follow-up needs across the entire dataset. Provide a comprehensive summary under the specified markdown headings. Focus on patterns and recurring issues across multiple patients.

Data:
{record_text}

### Missed Diagnoses
- ...

### Medication Conflicts
- ...

### Incomplete Assessments
- ...

### Urgent Follow-up
- ...
"""
    return prompt
|
|
|
def init_agent():
    """Create and initialise the TxAgent used for report generation.

    If ``tool_cache_dir`` has no ``new_tool.json`` yet, the bundled
    ``data/new_tool.json`` is copied there first. The agent is then built
    against that cached copy and ``init_model()`` is called on it.

    Returns:
        The initialised ``TxAgent``.

    Raises:
        FileNotFoundError: if no cached copy exists and the bundled tool file
            cannot be found.
        Exception: any error from constructing or initialising the agent is
            logged and re-raised.
    """
    tool_filename = "new_tool.json"
    target_tool_path = os.path.join(tool_cache_dir, tool_filename)

    if not os.path.exists(target_tool_path):
        # Look next to this script first, so start-up does not depend on the
        # current working directory. Keep the old CWD-relative path as a fallback.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        candidates = (
            os.path.join(script_dir, "data", tool_filename),
            os.path.abspath(os.path.join("data", tool_filename)),
        )
        default_tool_path = next((p for p in candidates if os.path.exists(p)), None)
        if default_tool_path is None:
            message = f"Tool definition not found; looked in: {', '.join(candidates)}"
            logger.error(message)
            raise FileNotFoundError(message)
        shutil.copy(default_tool_path, target_tool_path)

    try:
        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={"new_tool": target_tool_path},
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100,
            additional_default_tools=[],
        )
        agent.init_model()
        return agent
    except Exception as e:
        logger.error(f"Failed to initialize TxAgent: {str(e)}")
        raise
|
|
|
# Report sections requested in the prompt, in the order they are rendered.
_REPORT_SECTIONS = (
    "Missed Diagnoses",
    "Medication Conflicts",
    "Incomplete Assessments",
    "Urgent Follow-up",
)


def _extract_findings(text: str) -> str:
    """Collect the bullet points written under the known report sections.

    A heading line (any number of leading ``#``) whose text starts with one of
    ``_REPORT_SECTIONS`` (case-insensitive) opens that section. Any other
    heading closes it, so bullets under unrelated headings are not
    misattributed. Lines starting with ``-`` inside an open section are kept,
    except the prompt's ``- ...`` placeholder and dash-only rules. Duplicate
    bullets are dropped (first occurrence wins), because the accumulated model
    output may contain the same text more than once.

    Returns the found sections in ``_REPORT_SECTIONS`` order as markdown
    (``### Section`` followed by its bullets), or ``""`` if none were found.
    """
    found: dict = {}  # section -> {stripped bullet: original line}, insertion-ordered
    current = None
    for line in text.split("\n"):
        stripped = line.strip()
        if stripped.startswith("#"):
            heading = stripped.lstrip("#").strip().lower()
            current = next(
                (s for s in _REPORT_SECTIONS if heading.startswith(s.lower())),
                None,
            )
            if current is not None:
                found.setdefault(current, {})
            continue
        if current is None or not stripped.startswith("-"):
            continue
        if stripped == "- ..." or set(stripped) == {"-"}:
            continue
        # De-duplicate on the stripped text but keep the line's own indentation.
        found[current].setdefault(stripped, line.rstrip())

    return "\n\n".join(
        "\n".join([f"### {section}", *found[section].values()])
        for section in _REPORT_SECTIONS
        if section in found
    )


def _build_report(summary: str, findings: str) -> str:
    """Assemble the final markdown report from the summary and extracted findings."""
    report = summary + "## Clinical Findings\n\n"
    if not findings:
        # Avoid following "no findings" with a conclusion that asserts gaps in care.
        return report + (
            "No significant clinical findings identified.\n\n"
            "## Conclusion\n\n"
            "No findings could be extracted from the model output under the expected "
            "section headings; manual review of the records is recommended."
        )
    # NOTE(review): this conclusion is fixed text, not derived from `findings`;
    # confirm it is appropriate for every dataset.
    return report + findings + "\n\n" + (
        "## Conclusion\n\n"
        "The analysis reveals significant gaps in patient care, including potential missed cardiovascular diagnoses "
        "due to inconsistent follow-up on chest discomfort and elevated vitals. Low medication adherence is a recurring "
        "issue, likely impacting treatment efficacy. Incomplete assessments, particularly missing vital signs, hinder "
        "comprehensive care. Urgent follow-up is recommended for patients with chest discomfort and elevated vitals to "
        "prevent adverse outcomes."
    )


async def generate_report(agent, df: pl.DataFrame, file_hash_value: str) -> tuple[str, str]:
    """Stream the model's analysis of *df*, then write and yield the final report.

    This is an async generator:

    * while the model runs it yields ``(partial_output, None)``, where
      ``partial_output`` is the cleaned text accumulated so far;
    * afterwards it writes the assembled markdown report to
      ``<report_dir>/<file_hash_value>_report.md`` and yields
      ``(final_report, report_path)``.

    Errors raised while running the model or writing the file are logged and
    yielded as ``("Error: ...", None)`` instead of propagating.
    """
    logger.info("Generating comprehensive report...")
    report_path = os.path.join(report_dir, f"{file_hash_value}_report.md")

    summary, _symptom_counts = generate_summary(df)
    prompt = prepare_aggregate_prompt(df)
    full_output = ""

    try:
        # NOTE(review): run_gradio_chat is iterated synchronously, so the event
        # loop is blocked while the model produces each item.
        chunk_output = ""
        for result in agent.run_gradio_chat(
            message=prompt,
            history=[],
            temperature=0.2,
            max_new_tokens=2048,
            max_token=8192,
            call_agent=False,
            conversation=[],
        ):
            if isinstance(result, list):
                # NOTE(review): if each list is the full history so far, its
                # messages are re-appended on every yield; _extract_findings
                # de-duplicates bullets to compensate.
                for message in result:
                    content = getattr(message, "content", None)
                    if content:
                        chunk_output += clean_response(content) + "\n"
            elif isinstance(result, str):
                chunk_output += clean_response(result) + "\n"
            full_output = chunk_output.strip()
            yield full_output, None

        final_output = _build_report(summary, _extract_findings(full_output))

        # Write UTF-8 explicitly: the cleaned text may contain non-ASCII letters.
        async with aiofiles.open(report_path, "w", encoding="utf-8") as f:
            await f.write(final_output)

        logger.info(f"Report saved to {report_path}")
        yield final_output, report_path

    except Exception as e:
        logger.exception(f"Error generating report: {str(e)}")
        yield f"Error: {str(e)}", None
|
|
|
def create_ui(agent):
    """Build the Gradio interface for clinical oversight analysis.

    The "Analysis" tab takes an Excel upload plus optional text, streams the
    analysis into a chatbot, and offers the saved markdown report for download.
    The "Instructions" tab documents usage.

    Args:
        agent: initialised TxAgent, passed through to ``generate_report``.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Clinical Oversight Assistant",
        css="""
        .gradio-container {max-width: 1000px; margin: auto; font-family: Arial, sans-serif;}
        #chatbot {border: 1px solid #e5e7eb; border-radius: 8px; padding: 10px; background: #f9fafb;}
        .markdown {white-space: pre-wrap;}
        """,
    ) as demo:
        gr.Markdown("# 🏥 Clinical Oversight Assistant (Excel Optimized)")

        with gr.Tabs():
            with gr.TabItem("Analysis"):
                with gr.Row():
                    # Left column: inputs and controls.
                    with gr.Column(scale=1):
                        file_upload = gr.File(
                            label="Upload Excel File",
                            file_types=[".xlsx"],
                            file_count="single",
                            interactive=True,
                        )
                        msg_input = gr.Textbox(
                            label="Additional Instructions",
                            placeholder="Add any specific analysis requests...",
                            lines=3,
                        )
                        with gr.Row():
                            clear_btn = gr.Button("Clear", variant="secondary")
                            send_btn = gr.Button("Analyze", variant="primary")

                    # Right column: streamed results and the downloadable report.
                    with gr.Column(scale=2):
                        chatbot = gr.Chatbot(
                            label="Analysis Results",
                            height=600,
                            bubble_full_width=False,
                            show_copy_button=True,
                            elem_id="chatbot",
                        )
                        download_output = gr.File(
                            label="Download Full Report",
                            interactive=False,
                        )

            with gr.TabItem("Instructions"):
                gr.Markdown("""
                ## How to Use This Tool

                1. **Upload Excel File**: Select your patient records Excel file
                2. **Add Instructions** (Optional): Provide any specific analysis requests
                3. **Click Analyze**: The system will process all patient records and generate a comprehensive report
                4. **Review Results**: Analysis appears in the chat window
                5. **Download Report**: Get a full markdown report of all findings

                ### Excel File Requirements
                Your Excel file must contain these columns:
                - Booking Number
                - Form Name
                - Form Item
                - Item Response
                - Interview Date
                - Interviewer
                - Description

                ### Analysis Includes
                - Missed diagnoses
                - Medication conflicts
                - Incomplete assessments
                - Urgent follow-up needs
                """)

        def format_message(role: str, content: str) -> Tuple[str, str]:
            """Return a (user, bot) chatbot pair with *content* on the side given by *role*."""
            if role == "user":
                return (content, None)
            return (None, content)

        async def analyze(message: str, chat_history: List[Tuple[str, str]], file) -> Tuple[List[Tuple[str, str]], str]:
            """Analyse the uploaded file, streaming progress into the chatbot.

            Async generator yielding ``(chat_history, report_path_or_None)`` pairs.
            Raises ``gr.Error`` when no file is uploaded or when the analysis fails
            (in the failure case, after yielding the error into the chat).
            """
            if not file:
                raise gr.Error("Please upload an Excel file first")

            # NOTE(review): `message` is only echoed into the chat; it is not
            # passed to generate_report, so "Additional Instructions" do not
            # currently influence the analysis.
            # Bound before the try so the error handler can always append to it.
            new_history = list(chat_history or [])
            try:
                new_history.append(format_message("user", message))
                new_history.append(format_message("assistant", "⏳ Processing Excel data..."))
                yield new_history, None

                # Gradio may pass a plain path string or an object exposing `.name`.
                file_path = getattr(file, "name", file)
                df = await load_and_clean_data(file_path)
                file_hash_value = file_hash(file_path)

                async for output, report_path in generate_report(agent, df, file_hash_value):
                    # Replace the last assistant bubble with the newest text;
                    # keep the previous text when an update is empty.
                    if output:
                        new_history[-1] = format_message("assistant", output)
                    yield new_history, report_path

            except Exception as e:
                logger.exception(f"Analysis failed: {str(e)}")
                new_history.append(format_message("assistant", f"❌ Error: {str(e)}"))
                yield new_history, None
                raise gr.Error(f"Analysis failed: {str(e)}") from e

        def clear_chat():
            """Clear chat history and download output."""
            return [], None

        send_btn.click(
            analyze,
            inputs=[msg_input, chatbot, file_upload],
            outputs=[chatbot, download_output],
            api_name="analyze",
            queue=True,
        )

        msg_input.submit(
            analyze,
            inputs=[msg_input, chatbot, file_upload],
            outputs=[chatbot, download_output],
            queue=True,
        )

        clear_btn.click(
            clear_chat,
            inputs=[],
            outputs=[chatbot, download_output],
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Entry point: load the model, build the UI and serve it on all interfaces.
    try:
        agent = init_agent()
        demo = create_ui(agent)

        demo.queue(
            api_open=False,
            max_size=20,
        ).launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            # Lets Gradio serve the generated report files for download.
            allowed_paths=[report_dir],
            share=False,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Failed to launch application: {str(e)}")
        print(f"Failed to launch application: {str(e)}")
        sys.exit(1)