| import sys | |
| import os | |
| import pandas as pd | |
| import json | |
| import gradio as gr | |
| from typing import List, Tuple | |
| import hashlib | |
| import shutil | |
| import re | |
| from datetime import datetime | |
| import time | |
# --- Persistent cache layout -------------------------------------------------
# Root for all caches; /data is presumably a mounted persistent volume
# (e.g. on Hugging Face Spaces) — TODO confirm the mount point.
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)
# Sub-directories: model weights, tool definitions, misc file cache, reports.
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(directory, exist_ok=True)
# Point Hugging Face downloads at the persistent volume so model weights
# survive restarts.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME; setting both is harmless but confirm the
# installed transformers version still honors it.
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
# Make the vendored ./src package importable BEFORE importing TxAgent below.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
from txagent.txagent import TxAgent
def file_hash(path: str) -> str:
    """Return the MD5 hex digest of the file at *path*.

    Used as a stable content-based key for naming report files, not for
    security. Reads in fixed-size chunks so large uploads are never
    loaded into memory at once (the original slurped the whole file).

    Args:
        path: Filesystem path of the file to hash.

    Returns:
        32-character lowercase hexadecimal MD5 digest.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter(callable, sentinel) yields 1 MiB chunks until EOF (b"").
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
def clean_response(text: str) -> str:
    """Sanitize model output for display.

    Repairs undecodable Unicode, removes bracketed spans and bare "None"
    tokens, collapses runs of blank lines, and strips characters outside
    a markdown-friendly whitelist. Returns the cleaned, stripped text.
    """
    # Round-trip through UTF-8 to surface lone surrogates; if strict
    # decoding still fails, degrade gracefully with replacement chars.
    try:
        cleaned = text.encode('utf-8', 'surrogatepass').decode('utf-8')
    except UnicodeError:
        cleaned = text.encode('utf-8', 'replace').decode('utf-8')
    # (pattern, replacement, flags) applied in order.
    passes = (
        (r"\[.*?\]|\bNone\b", "", re.DOTALL),    # bracketed spans / "None"
        (r"\n{3,}", "\n\n", 0),                  # collapse blank-line runs
        (r"[^\n#\-\*\w\s\.,:\(\)]+", "", 0),     # markdown-safe whitelist
    )
    for pattern, replacement, flags in passes:
        cleaned = re.sub(pattern, replacement, cleaned, flags=flags)
    return cleaned.strip()
def parse_excel_as_whole_prompt(file_path: str) -> str:
    """Render the first sheet of an Excel workbook into one analysis prompt.

    Each row is formatted as a bullet (form name, item, response, date,
    interviewer, description), sanitized via clean_response, and embedded
    into a fixed instruction template with markdown answer headings.

    Args:
        file_path: Path to the uploaded .xlsx workbook.

    Returns:
        The complete prompt string for the agent.
    """
    workbook = pd.ExcelFile(file_path)
    # Only the first sheet is consulted; NaN cells become empty strings.
    sheet = workbook.parse(workbook.sheet_names[0], header=0).fillna("")
    records = [
        clean_response(
            f"- {row['Form Name']}: {row['Form Item']} = {row['Item Response']} ({row['Interview Date']} by {row['Interviewer']})\n{row['Description']}"
        )
        for _, row in sheet.iterrows()
    ]
    record_text = "\n".join(records)
    prompt = f"""
Patient Complete History:
Instructions:
Based on the complete patient record below, identify any potential missed diagnoses, medication conflicts, incomplete assessments, and urgent follow-up needs. Provide a clinical summary under the markdown headings.
Patient History:
{record_text}
### Missed Diagnoses
- ...
### Medication Conflicts
- ...
### Incomplete Assessments
- ...
### Urgent Follow-up
- ...
"""
    return prompt
def init_agent():
    """Construct and warm up the TxAgent.

    Seeds the tool-definition cache from the repo's bundled
    data/new_tool.json on first run, then builds the agent with fixed
    model/RAG settings and loads its model weights.

    Returns:
        A fully initialized TxAgent instance.
    """
    bundled_tool_file = os.path.abspath("data/new_tool.json")
    cached_tool_file = os.path.join(tool_cache_dir, "new_tool.json")
    # First launch only: copy the bundled tool definitions into the cache.
    if not os.path.exists(cached_tool_file):
        shutil.copy(bundled_tool_file, cached_tool_file)
    agent_config = dict(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": cached_tool_file},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
        additional_default_tools=[],
    )
    agent = TxAgent(**agent_config)
    agent.init_model()
    return agent
def create_ui(agent):
    """Build the Gradio Blocks UI: Excel upload + optional message in,
    streamed chat summary and downloadable report file out.

    Args:
        agent: An initialized TxAgent used to run the analysis.

    Returns:
        The constructed gr.Blocks app.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # NOTE(review): "\ud83c\udfe5" spells the 🏥 emoji as two separate
        # surrogate escapes; in a Python str literal this produces lone
        # surrogates rather than the emoji — confirm it renders, or it
        # should be written as "\U0001F3E5".
        gr.Markdown("<h1 style='text-align: center;'>\ud83c\udfe5 Full Medical History Analyzer</h1>")
        chatbot = gr.Chatbot(label="Summary Output", height=600)
        file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"], file_count="single")
        msg_input = gr.Textbox(label="Optional Message", placeholder="Add context or instructions...", lines=2)
        send_btn = gr.Button("Analyze")
        download_output = gr.File(label="Download Report")
        def analyze(message: str, chat_history: List[Tuple[str, str]], file) -> Tuple[List[Tuple[str, str]], str]:
            """Generator event handler shared by the button and textbox.

            Parses the uploaded workbook into one prompt, runs the agent,
            and yields (chat_history, report_path_or_None) tuples so
            Gradio updates the chatbot and download slot incrementally.
            """
            if not file:
                raise gr.Error("Please upload an Excel file.")
            # Echo the user's message, then show a provisional status row
            # that is replaced once the analysis completes.
            new_history = chat_history + [(message, None)]
            new_history.append((None, "⏳ Analyzing full patient history..."))
            yield new_history, None
            try:
                prompt = parse_excel_as_whole_prompt(file.name)
                full_output = ""
                for result in agent.run_gradio_chat(
                    message=prompt,
                    history=[],
                    temperature=0.2,
                    max_new_tokens=2048,
                    max_token=4096,
                    call_agent=False,
                    conversation=[],
                ):
                    # The agent may yield either a list of message objects
                    # (with a .content attribute) or a plain string chunk.
                    if isinstance(result, list):
                        for r in result:
                            if hasattr(r, 'content') and r.content:
                                full_output += clean_response(r.content) + "\n"
                    elif isinstance(result, str):
                        full_output += clean_response(result) + "\n"
                # Replace the provisional status row with the final summary.
                new_history[-1] = (None, full_output.strip())
                # Name the report by the upload's content hash so re-running
                # the same file overwrites a single report.
                report_path = os.path.join(report_dir, f"{file_hash(file.name)}_final_report.txt")
                with open(report_path, "w", encoding="utf-8") as f:
                    f.write(full_output.strip())
                yield new_history, report_path
            except Exception as e:
                # Surface the failure in the chat instead of crashing the UI.
                new_history.append((None, f"❌ Error during analysis: {str(e)}"))
                yield new_history, None
        # Same handler fires on button click and on textbox Enter.
        send_btn.click(analyze, inputs=[msg_input, chatbot, file_upload], outputs=[chatbot, download_output])
        msg_input.submit(analyze, inputs=[msg_input, chatbot, file_upload], outputs=[chatbot, download_output])
    return demo
if __name__ == "__main__":
    # Build the agent once, wire it into the UI, and serve on all interfaces.
    ui = create_ui(init_agent())
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        allowed_paths=[report_dir],
        share=False,
    )
    ui.queue(api_open=False).launch(**launch_options)