|
import sys |
|
import os |
|
import pandas as pd |
|
import json |
|
import gradio as gr |
|
from typing import List, Tuple, Dict, Any |
|
import hashlib |
|
import shutil |
|
import re |
|
from datetime import datetime |
|
import time |
|
import markdown |
|
from collections import defaultdict |
|
|
|
|
|
# Root of the persistent volume; every cache and report directory lives below it.
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

# Sub-directories for model weights, tool definitions, cached files and reports.
model_cache_dir, tool_cache_dir, file_cache_dir, report_dir = (
    os.path.join(persistent_dir, leaf)
    for leaf in ("txagent_models", "tool_cache", "cache", "reports")
)

for directory in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(directory, exist_ok=True)

# Both cache variables are pointed at the same persistent model directory.
os.environ.update(HF_HOME=model_cache_dir, TRANSFORMERS_CACHE=model_cache_dir)

# Put the sibling ``src`` directory first on the import path so the
# ``txagent`` import that follows resolves against it.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
|
|
|
from txagent.txagent import TxAgent |
|
|
|
def file_hash(path: str) -> str:
    """Return the hexadecimal MD5 digest of a file's contents.

    The file is read in fixed-size chunks so that a large upload is hashed
    without loading it into memory in one piece.  The digest serves as a
    content fingerprint only; MD5 is not suitable for security purposes.

    Args:
        path: Path of the file to hash.

    Returns:
        The 32-character lowercase hexadecimal MD5 digest.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps calling read() until it returns b"".
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
def clean_response(text: str) -> str:
    """Normalize raw model output into plain display text.

    Applied in order:

    1. Characters that cannot be encoded as UTF-8 (lone surrogates) are
       replaced with ``?``.
    2. ``[...]`` spans (matched non-greedily, across newlines) and the bare
       word ``None`` are removed.
    3. Runs of three or more newlines collapse to exactly two.
    4. Every character that is not a word character, whitespace, or one of
       ``# - * . , : ( )`` is removed.
    5. Leading and trailing whitespace is trimmed.
    """
    # A strict encode leaves encodable text untouched; only code points that
    # UTF-8 cannot represent are substituted.
    normalized = text.encode("utf-8", "replace").decode("utf-8")

    # (pattern, replacement, flags) applied sequentially; order matters.
    passes = (
        (r"\[.*?\]|\bNone\b", "", re.DOTALL),
        (r"\n{3,}", "\n\n", 0),
        (r"[^\n#\-\*\w\s\.,:\(\)]+", "", 0),
    )
    for pattern, replacement, flags in passes:
        normalized = re.sub(pattern, replacement, normalized, flags=flags)

    return normalized.strip()
|
|
|
def extract_medical_data(df: pd.DataFrame) -> Dict[str, Any]:
    """Group the rows of a records sheet by booking number.

    Every row becomes a dict with the keys ``form_name``, ``form_item``,
    ``response``, ``date``, ``interviewer`` and ``description``.  A column
    that is absent from ``df`` yields ``''`` for its key.  The
    ``Booking Number`` column is mandatory; indexing a row without it raises
    ``KeyError``.

    Returns:
        A ``defaultdict(list)`` mapping each booking number to that
        booking's records, in row order.
    """
    # Record key -> spreadsheet column it is read from.
    column_for_key = {
        'form_name': 'Form Name',
        'form_item': 'Form Item',
        'response': 'Item Response',
        'date': 'Interview Date',
        'interviewer': 'Interviewer',
        'description': 'Description',
    }

    grouped = defaultdict(list)
    for _, row in df.iterrows():
        entry = {key: row.get(column, '') for key, column in column_for_key.items()}
        grouped[row['Booking Number']].append(entry)

    return grouped
|
|
|
def identify_red_flags(records: List[Dict[str, Any]]) -> Dict[str, Dict[str, List[str]]]:
    """Scan one booking's records for simple keyword/threshold red flags.

    Each record is routed to at most one category by the first matching rule:

    * symptoms    - item mentions "pain" or form mentions "symptom", and the
      response contains "severe" or "chronic";
    * medications - form mentions "medication" or "drug", and the response
      contains "interaction" or "allergy";
    * diagnoses   - form mentions "diagnosis", and the response contains
      "rule out" or "possible";
    * vitals      - form mentions "vital", and the first number in the
      response breaches a threshold (blood pressure > 140, heart rate < 50
      or > 100, temperature > 38);
    * labs        - form mentions "lab" or "test", and the response contains
      "abnormal", "high" or "low".

    All comparisons are case-insensitive.  Values are coerced to text first
    because spreadsheet cells frequently arrive as numbers rather than
    strings; missing keys and ``None`` are treated as empty text.

    Args:
        records: Record dicts with ``form_name``, ``form_item`` and
            ``response`` entries.

    Returns:
        A dict with the keys ``symptoms``, ``medications``, ``diagnoses``,
        ``vitals`` and ``labs``; each value is a ``defaultdict(list)`` mapping
        the lower-cased item name to the lower-cased flagged responses.
    """
    def as_text(value: Any) -> str:
        # Numeric cells (e.g. a heart rate read as 120) have no .lower().
        return '' if value is None else str(value).lower()

    red_flags = {
        category: defaultdict(list)
        for category in ('symptoms', 'medications', 'diagnoses', 'vitals', 'labs')
    }

    for record in records:
        form_name = as_text(record.get('form_name'))
        item = as_text(record.get('form_item'))
        response = as_text(record.get('response'))

        if 'pain' in item or 'symptom' in form_name:
            if 'severe' in response or 'chronic' in response:
                red_flags['symptoms'][item].append(response)

        elif 'medication' in form_name or 'drug' in form_name:
            if 'interaction' in response or 'allergy' in response:
                red_flags['medications'][item].append(response)

        elif 'diagnosis' in form_name:
            if 'rule out' in response or 'possible' in response:
                red_flags['diagnoses'][item].append(response)

        elif 'vital' in form_name:
            # Only the first number is inspected (e.g. the systolic value
            # of "150/95"); responses without a number are skipped.
            number = re.search(r'\d+\.?\d*', response)
            if number is None:
                continue
            value = float(number.group())
            if (
                ('blood pressure' in item and value > 140)
                or ('heart rate' in item and (value < 50 or value > 100))
                or ('temperature' in item and value > 38)
            ):
                red_flags['vitals'][item].append(response)

        elif 'lab' in form_name or 'test' in form_name:
            if 'abnormal' in response or 'high' in response or 'low' in response:
                red_flags['labs'][item].append(response)

    return red_flags
|
|
|
def generate_analysis_prompt(booking: str, records: List[Dict[str, Any]], red_flags: Dict[str, Any]) -> str:
    """Build the structured analysis prompt for one booking.

    The prompt contains the booking number, one bullet per record, the
    pre-computed red flags grouped by category (or a fixed "none detected"
    line), the analysis instructions and the required output format.

    Args:
        booking: Patient booking identifier shown at the top of the prompt.
        records: Record dicts providing ``form_name``, ``form_item``,
            ``response``, ``date``, ``interviewer`` and ``description``.
        red_flags: Mapping of category name to a mapping of item name to a
            list of flagged responses.  Empty categories are omitted.

    Returns:
        The complete prompt text.
    """
    records_text = "\n".join(
        f"- {r['form_name']}: {r['form_item']} = {r['response']} ({r['date']} by {r['interviewer']})\n {r['description']}"
        for r in records
    )

    red_flags_text = "\n".join(
        f"### {category.capitalize()} Red Flags\n" + "\n".join(
            # str() keeps the join safe if a response is not already text.
            f"- {item}: {', '.join(str(response) for response in responses)}"
            for item, responses in items.items()
        )
        for category, items in red_flags.items() if items
    )
    red_flags_section = red_flags_text if red_flags_text else "No obvious red flags detected"

    prompt = f"""
**Patient Booking Number**: {booking}

**Medical Records Summary**:
{records_text}

**Identified Red Flags**:
{red_flags_section}

**Comprehensive Analysis Instructions**:
1. Review all medical data and red flags above
2. Identify any potential missed diagnoses based on symptoms, labs, and clinical findings
3. Check for medication conflicts or inappropriate prescriptions
4. Note any incomplete assessments or missing diagnostic workups
5. Flag any urgent follow-up needs or critical findings
6. Provide recommendations in clear, actionable terms

**Required Output Format**:
### Missed Diagnoses
- [List any conditions that may have been overlooked based on the data]

### Medication Issues
- [List any medication conflicts, inappropriate prescriptions, or missing medications]

### Assessment Gaps
- [List any incomplete assessments or missing diagnostic tests]

### Urgent Follow-up
- [List any findings requiring immediate attention]

### Clinical Recommendations
- [Provide specific recommendations for next steps]
"""
    return prompt
|
|
|
def parse_excel_to_prompts(file_path: str) -> List[Tuple[str, str]]:
    """Turn the first sheet of an Excel workbook into per-booking prompts.

    The first sheet is read with its first row as the header, empty cells
    are replaced by ``""``, rows are grouped by booking, red flags are
    computed per booking and one analysis prompt is generated for each.

    Args:
        file_path: Path of the Excel workbook.

    Returns:
        A list of ``(booking, prompt)`` tuples, one per booking.

    Raises:
        ValueError: If the workbook cannot be read or processed; the
            underlying exception is attached as ``__cause__``.
    """
    try:
        # The context manager closes the workbook handle even if parsing fails.
        with pd.ExcelFile(file_path) as workbook:
            df = workbook.parse(workbook.sheet_names[0], header=0).fillna("")

        medical_data = extract_medical_data(df)

        prompts = []
        for booking, records in medical_data.items():
            red_flags = identify_red_flags(records)
            prompts.append((booking, generate_analysis_prompt(booking, records, red_flags)))
        return prompts
    except Exception as e:
        raise ValueError(f"Error parsing Excel file: {str(e)}") from e
|
|
|
def init_agent():
    """Create the TxAgent instance and load its model.

    The tool definition file ``data/new_tool.json`` (resolved against the
    current working directory) is copied into the tool cache directory the
    first time it is needed; an existing cached copy is left untouched.

    Returns:
        The ``TxAgent`` instance after ``init_model()`` has been called.
    """
    bundled_tool_file = os.path.abspath("data/new_tool.json")
    cached_tool_file = os.path.join(tool_cache_dir, "new_tool.json")

    already_cached = os.path.exists(cached_tool_file)
    if not already_cached:
        shutil.copy(bundled_tool_file, cached_tool_file)

    agent_settings = dict(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": cached_tool_file},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
        additional_default_tools=[],
    )
    tx_agent = TxAgent(**agent_settings)
    tx_agent.init_model()
    return tx_agent
|
|
|
def format_markdown(text: str) -> str:
    """Render Markdown source as HTML.

    The ``fenced_code`` and ``tables`` extensions are enabled.
    """
    enabled_extensions = ['fenced_code', 'tables']
    rendered_html = markdown.markdown(text, extensions=enabled_extensions)
    return rendered_html
|
|
|
def create_ui(agent):
    """Build the Gradio interface for the clinical oversight assistant.

    The UI has an "Analysis" tab (file upload, optional instructions, chat
    output, report download) and a static "Instructions" tab.  Both the
    "Analyze" button and submitting the instructions box run the streaming
    ``analyze`` generator; "Clear" resets the chat and the download slot.

    Args:
        agent: Object exposing ``run_gradio_chat(...)``, used to analyze
            each booking's prompt.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(theme=gr.themes.Soft(), title="Clinical Oversight Assistant") as demo:
        gr.Markdown("# 🏥 Clinical Oversight Assistant (Missed Diagnosis Detection)")

        with gr.Tabs():
            with gr.TabItem("Analysis"):
                with gr.Row():

                    with gr.Column(scale=1):
                        file_upload = gr.File(
                            label="Upload Excel File",
                            file_types=[".xlsx"],
                            file_count="single",
                            interactive=True
                        )
                        msg_input = gr.Textbox(
                            label="Additional Instructions",
                            placeholder="Add any specific analysis requests...",
                            lines=3
                        )
                        with gr.Row():
                            clear_btn = gr.Button("Clear", variant="secondary")
                            send_btn = gr.Button("Analyze", variant="primary")

                    with gr.Column(scale=2):
                        chatbot = gr.Chatbot(
                            label="Analysis Results",
                            height=600,
                            bubble_full_width=False,
                            show_copy_button=True,
                            render_markdown=True
                        )
                        download_output = gr.File(
                            label="Download Full Report",
                            interactive=False
                        )

            with gr.TabItem("Instructions"):
                gr.Markdown("""
                ## How to Use This Tool

                1. **Upload Excel File**: Select your patient records Excel file
                2. **Add Instructions** (Optional): Provide any specific analysis requests
                3. **Click Analyze**: The system will process each patient record
                4. **Review Results**: Analysis appears in the chat window
                5. **Download Report**: Get a full text report of all findings

                ### Excel File Requirements
                Your Excel file must contain these columns:
                - Booking Number (patient identifier)
                - Form Name (type of medical form)
                - Form Item (specific field name)
                - Item Response (patient response or value)
                - Interview Date (date of recording)
                - Interviewer (who recorded the data)
                - Description (additional notes)

                ### Analysis Includes
                - **Missed diagnoses**: Potential conditions not identified
                - **Medication issues**: Conflicts, side effects, inappropriate prescriptions
                - **Assessment gaps**: Missing tests or incomplete evaluations
                - **Urgent follow-up**: Critical findings needing immediate attention
                - **Clinical recommendations**: Actionable next steps
                """)

        def format_message(role: str, content: str) -> Tuple[str, str]:
            """Return a chatbot ``(user, bot)`` pair with ``content`` on the side of ``role``."""
            if role == "user":
                return (content, None)
            return (None, content)

        def analyze(message: str, chat_history: List[Tuple[str, str]], file) -> Tuple[List[Tuple[str, str]], str]:
            """Stream the analysis of every booking in the uploaded workbook.

            This is a generator: each yield is ``(chat_history, report_path)``
            where ``report_path`` is ``None`` until the final yield, which
            carries the path of the written Markdown report.

            Raises:
                gr.Error: If no file was uploaded, or if the analysis fails
                    (after the error has also been shown in the chat).
            """
            if not file:
                raise gr.Error("Please upload an Excel file first")

            # Built before the try block so the error handler can always use it.
            new_history = list(chat_history or [])

            try:
                new_history.append(format_message("user", message))
                new_history.append(format_message("assistant", "⏳ Processing Excel data..."))
                yield new_history, None

                prompts = parse_excel_to_prompts(file.name)
                extra_instructions = (message or "").strip()
                report_sections = []

                # True while the last history entry is a placeholder or the
                # current booking's in-progress stream, i.e. may be overwritten.
                # Finished results are frozen so the next booking gets its own
                # entry instead of clobbering or duplicating one.
                last_is_scratch = True

                for booking, prompt in prompts:
                    if extra_instructions:
                        # Forward the user's request to the model; without this
                        # the "Additional Instructions" box has no effect.
                        prompt = f"{prompt}\n**Additional Instructions**:\n{extra_instructions}\n"

                    chunk_output = ""
                    try:
                        for result in agent.run_gradio_chat(
                            message=prompt,
                            history=[],
                            temperature=0.2,
                            max_new_tokens=1024,
                            max_token=4096,
                            call_agent=False,
                            conversation=[],
                        ):
                            if isinstance(result, list):
                                for r in result:
                                    if hasattr(r, 'content') and r.content:
                                        chunk_output += clean_response(r.content) + "\n"
                            elif isinstance(result, str):
                                chunk_output += clean_response(result) + "\n"

                            if chunk_output:
                                output = f"## Patient Booking: {booking}\n{chunk_output.strip()}\n"
                                entry = format_message("assistant", output)
                                if last_is_scratch:
                                    new_history[-1] = entry
                                else:
                                    new_history.append(entry)
                                    last_is_scratch = True
                                yield new_history, None

                    except Exception as e:
                        error_entry = format_message(
                            "assistant", f"⚠️ Error processing booking {booking}: {str(e)}"
                        )
                        if last_is_scratch and not chunk_output:
                            # Nothing was streamed for this booking: reuse the placeholder.
                            new_history[-1] = error_entry
                        else:
                            # Keep any partial output visible and add the error after it.
                            new_history.append(error_entry)
                        last_is_scratch = False
                        yield new_history, None
                        continue

                    if chunk_output:
                        # The scratch entry already shows the final text; freeze it.
                        output = f"## Patient Booking: {booking}\n{chunk_output.strip()}\n"
                        report_sections.append(output)
                        last_is_scratch = False

                if last_is_scratch:
                    # No booking produced output; do not leave the spinner text behind.
                    new_history[-1] = format_message("assistant", "No analysis results were produced.")

                full_output = "".join(section + "\n" for section in report_sections)

                file_hash_value = file_hash(file.name)
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                report_path = os.path.join(report_dir, f"{file_hash_value}_{timestamp}_report.md")

                with open(report_path, "w", encoding="utf-8") as f:
                    f.write("# Clinical Oversight Analysis Report\n\n")
                    f.write(f"**Generated on**: {timestamp}\n\n")
                    f.write(f"**Source file**: {file.name}\n\n")
                    f.write(full_output)

                yield new_history, report_path if os.path.exists(report_path) else None

            except Exception as e:
                new_history.append(format_message("assistant", f"❌ Error: {str(e)}"))
                yield new_history, None
                raise gr.Error(f"Analysis failed: {str(e)}") from e

        def clear_chat():
            """Reset the chat history and the report download slot."""
            return [], None

        send_btn.click(
            analyze,
            inputs=[msg_input, chatbot, file_upload],
            outputs=[chatbot, download_output],
            api_name="analyze"
        )

        msg_input.submit(
            analyze,
            inputs=[msg_input, chatbot, file_upload],
            outputs=[chatbot, download_output]
        )

        clear_btn.click(
            clear_chat,
            inputs=[],
            outputs=[chatbot, download_output]
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Script entry point: load the agent, build the UI and serve it.  Any
    # failure is reported on stdout and the process exits with status 1.
    try:
        agent = init_agent()
        demo = create_ui(agent)

        queue_options = {"api_open": False, "max_size": 20}
        launch_options = {
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "show_error": True,
            # Lets the generated report files be served for download.
            "allowed_paths": [report_dir],
            "share": False,
        }
        demo.queue(**queue_options).launch(**launch_options)
    except Exception as e:
        print(f"Failed to launch application: {str(e)}")
        sys.exit(1)