File size: 7,321 Bytes
1bb8be7
dae38a2
 
 
 
1bb8be7
9c0d5a4
dae38a2
 
 
 
9c0d5a4
 
dae38a2
9c0d5a4
 
dae38a2
d2cced3
9b25f67
dae38a2
 
 
 
 
 
9c0d5a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dae38a2
 
 
 
 
9c0d5a4
 
dae38a2
9c0d5a4
dae38a2
 
 
 
 
9b25f67
dae38a2
 
 
 
 
 
 
 
 
9c0d5a4
dae38a2
 
 
 
 
 
 
 
 
 
9c0d5a4
dae38a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c0d5a4
 
 
 
 
 
 
dae38a2
 
9c0d5a4
dae38a2
 
 
 
 
 
 
9c0d5a4
 
dae38a2
 
 
9c0d5a4
dae38a2
 
9c0d5a4
dae38a2
 
 
9c0d5a4
 
 
 
 
 
 
 
 
 
 
 
dae38a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bb8be7
dae38a2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import sys
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil

# Put the "src" directory that sits one level above this file's directory at
# the front of sys.path so packages located there can be imported.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))

# Create a cache directory under /data and point both TRANSFORMERS_CACHE and
# HF_HOME at it.
# NOTE(review): /data is presumably the Hugging Face Space's persistent
# storage mount — confirm it exists and is writable in the deployment.
model_cache_dir = "/data/txagent_models"
os.makedirs(model_cache_dir, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["HF_HOME"] = model_cache_dir

# Deliberately imported after the path and environment setup above.
# NOTE(review): presumably txagent lives under src/ and the cache variables
# must be set before it is imported — verify before reordering.
from txagent.txagent import TxAgent

def sanitize_utf8(text: str) -> str:
    """Return *text* with every character that cannot be encoded as UTF-8 removed.

    The string is round-tripped through UTF-8; characters the encoder rejects
    (for example lone surrogates) are silently dropped rather than raising.
    """
    encoded = text.encode("utf-8", errors="ignore")
    return encoded.decode("utf-8")

def clean_final_response(text: str) -> str:
    """Format raw agent output as one or more HTML panels.

    Every ``[TOOL_CALLS]`` token is removed and the result is stripped. If the
    text contains no ``[Final Analysis]`` marker, the whole text is wrapped in
    a single card. Otherwise the text before the first marker is discarded and
    each piece that follows a marker becomes a numbered "Final Analysis" panel,
    with newlines in that piece turned into ``<br>``.

    The text is inserted into the markup as-is (no HTML escaping).

    Args:
        text: Raw response text.

    Returns:
        An HTML fragment.
    """
    marker = "[Final Analysis]"
    body = text.replace("[TOOL_CALLS]", "").strip()

    # No marker: a single plain card holding the whole cleaned text.
    if marker not in body:
        return f"<div style='padding:1em;border:1px solid #ccc;border-radius:12px;color:#fff;background:#353F54;'><p>{body}</p></div>"

    # Drop whatever precedes the first marker; number the remaining sections from 1.
    _, *sections = body.split(marker)
    return "".join(
        "<div style='background:#2B2B2B;color:#E0E0E0;border-radius:12px;margin-bottom:1em;border:1px solid #888;'>"
        f"<div style='font-size:1.1em;font-weight:bold;padding:0.75em;background:#3A3A3A;color:#fff;border-radius:12px 12px 0 0;'>🧠 Final Analysis #{number}</div>"
        f"<div style='padding:1em;line-height:1.6;'>{section.strip().replace(chr(10), '<br>')}</div>"
        "</div>"
        for number, section in enumerate(sections, 1)
    )

def file_hash(path):
    """Return the hexadecimal MD5 digest of the file at *path*.

    The file is read in fixed-size chunks so that hashing a large upload does
    not require holding the whole file in memory. The digest is identical to
    hashing the full contents at once.

    Args:
        path: Path of the file to hash.

    Returns:
        The 32-character lowercase hex digest.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops at the first empty read (end of file).
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest()

def convert_file_to_json(file_path: str, file_type: str) -> str:
    """Convert an uploaded file into a JSON string, caching the result on disk.

    Results are cached under ``/data/cache`` in a file named after the MD5 of
    the input file's contents (via ``file_hash``). Because the key is the
    content hash only, a cache hit returns the JSON produced for the first
    upload of those bytes, including the ``filename`` recorded at that time.

    Supported ``file_type`` values:
        * ``"csv"``: read with pandas, no header row, all cells as strings.
        * ``"xls"`` / ``"xlsx"``: read with pandas using ``openpyxl``, falling
          back to ``xlrd`` if that fails.
        * ``"pdf"``: text of every page extracted with pdfplumber and joined
          with newlines.

    Args:
        file_path: Path to the file on disk.
        file_type: Lower-case extension without the dot.

    Returns:
        A JSON string with one of these shapes:
        ``{"filename", "rows"}`` for tabular files, ``{"filename", "content"}``
        for PDFs, ``{"warning"}`` when a tabular file yields no data, or
        ``{"error"}`` for unsupported types and for any exception raised while
        processing (this function does not raise).
    """
    try:
        cache_dir = "/data/cache"
        os.makedirs(cache_dir, exist_ok=True)
        h = file_hash(file_path)
        cache_path = os.path.join(cache_dir, f"{h}.json")

        if os.path.exists(cache_path):
            # Context manager so the handle is closed deterministically.
            with open(cache_path, "r", encoding="utf-8") as cached:
                return cached.read()

        if file_type == "csv":
            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
        elif file_type in ["xls", "xlsx"]:
            try:
                df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            except Exception:
                # Retry with the other engine; KeyboardInterrupt/SystemExit are
                # no longer intercepted here.
                df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
        elif file_type == "pdf":
            with pdfplumber.open(file_path) as pdf:
                # extract_text() may return None for a page; treat that as empty.
                text = "\n".join(page.extract_text() or "" for page in pdf.pages)
            result = json.dumps({"filename": os.path.basename(file_path), "content": text.strip()})
            with open(cache_path, "w", encoding="utf-8") as cache_file:
                cache_file.write(result)
            return result
        else:
            return json.dumps({"error": f"Unsupported file type: {file_type}"})

        # Empty results are reported but intentionally not cached.
        if df is None or df.empty:
            return json.dumps({"warning": f"No data extracted from: {file_path}"})

        df = df.fillna("")
        content = df.astype(str).values.tolist()
        result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
        with open(cache_path, "w", encoding="utf-8") as cache_file:
            cache_file.write(result)
        return result
    except Exception as e:
        # Boundary handler: callers always receive JSON, never an exception.
        return json.dumps({"error": f"Error reading {os.path.basename(file_path)}: {str(e)}"})

def create_ui(agent: TxAgent):
    """Build the Gradio Blocks interface for the CPS assistant.

    The layout is a chatbot, a multi-file upload, a text box and a send
    button. Both the button click and the text box submit run the nested
    ``handle_chat`` generator, whose yielded history updates the chatbot.

    Args:
        agent: Object whose ``run_gradio_chat`` method is iterated to obtain
            the assistant's reply.

    Returns:
        The constructed ``gr.Blocks`` object; it is not launched here.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>📋 CPS: Clinical Patient Support System</h1>")

        # type="messages": history entries are {"role": ..., "content": ...} dicts.
        chatbot = gr.Chatbot(label="CPS Assistant", height=600, type="messages")
        # NOTE(review): the upload accepts .txt/.docx/.jpg/.png, but
        # convert_file_to_json only handles csv/xls/xlsx/pdf; other types come
        # back as an "Unsupported file type" error JSON that is still placed
        # in the prompt.
        file_upload = gr.File(
            label="Upload Medical File",
            file_types=[".pdf", ".txt", ".docx", ".jpg", ".png", ".csv", ".xls", ".xlsx"],
            file_count="multiple"
        )
        message_input = gr.Textbox(placeholder="Ask a biomedical question or just upload the files...", show_label=False)
        send_button = gr.Button("Send", variant="primary")
        # Passed to the agent as `conversation`; it is not among the handler's
        # outputs, so this handler never writes a new value back to it.
        conversation_state = gr.State([])

        def handle_chat(message: str, history: list, conversation: list, uploaded_files: list, progress=gr.Progress()):
            """Stream the assistant's reply for one user turn.

            Appends the user message and a placeholder assistant message to
            ``history``, converts any uploaded files to JSON text, asks the
            agent for an analysis of that text, and yields ``history`` after
            every update so the chatbot refreshes as the reply grows.

            ``progress`` is not referenced in the body.

            Yields:
                The (mutated in place) ``history`` list.
            """
            try:
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": "⏳ Processing your request..."})
                yield history

                # Concatenate the JSON representation of every uploaded file,
                # one per line.
                extracted_text = ""
                if uploaded_files and isinstance(uploaded_files, list):
                    for file in uploaded_files:
                        # Skip entries that do not expose a path via `.name`.
                        if not hasattr(file, 'name'):
                            continue
                        path = file.name
                        # Text after the last "." is used as the file type.
                        ext = path.split(".")[-1].lower()
                        json_text = convert_file_to_json(path, ext)
                        extracted_text += sanitize_utf8(json_text) + "\n"

                context = (
                    "You are an expert clinical AI assistant. Review this patient's history, medications, and notes, and ONLY provide a final answer summarizing what the doctor might have missed."
                )
                # NOTE(review): the prompt is built from the fixed context and
                # the extracted file text only; `message` is shown in the chat
                # but is not part of what is sent to the agent — confirm this
                # is intended.
                chunked_prompt = f"{context}\n\n--- Patient Record ---\n{extracted_text}\n\n[Final Analysis]"

                # The agent is always called with an empty history; prior chat
                # turns reach it only through `conversation`, if at all.
                generator = agent.run_gradio_chat(
                    message=chunked_prompt,
                    history=[],
                    temperature=0.3,
                    max_new_tokens=1024,
                    max_token=8192,
                    call_agent=False,
                    conversation=conversation,
                    uploaded_files=uploaded_files,
                    max_round=30
                )

                # Accumulate text from each update: `.content` of list items
                # that have it, or the update itself when it is a string.
                # Updates of any other type add nothing.
                # NOTE(review): if the agent yields cumulative snapshots rather
                # than deltas, this accumulation would repeat text — verify
                # against run_gradio_chat.
                final_response = ""
                for update in generator:
                    if not update:
                        continue
                    if isinstance(update, list):
                        for msg in update:
                            if hasattr(msg, "content"):
                                final_response += msg.content
                    elif isinstance(update, str):
                        final_response += update

                    # Overwrite the placeholder / previous partial reply.
                    history[-1] = {"role": "assistant", "content": final_response.strip()}
                    yield history

                # Final pass: drop tool-call markers and fall back to a fixed
                # message when nothing was produced.
                cleaned = final_response.strip().replace("[TOOL_CALLS]", "").strip()
                history[-1] = {"role": "assistant", "content": cleaned or "❌ No response."}
                yield history

            except Exception as chat_error:
                # Any failure is printed and replaced by a generic message in
                # the last history entry.
                print(f"Chat handling error: {chat_error}")
                history[-1] = {"role": "assistant", "content": "❌ An error occurred while processing your request."}
                yield history

        # Same handler for the button and for pressing Enter in the text box.
        inputs = [message_input, chatbot, conversation_state, file_upload]
        send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)

        # Example prompts that populate the text box.
        gr.Examples([
            ["Upload your medical form and ask what the doctor might've missed."],
            ["This patient was treated with antibiotics for UTI. What else should we check?"],
            ["Is there anything abnormal in the attached blood work report?"]
        ], inputs=message_input)

    return demo