import sys
import os
import pandas as pd
import pdfplumber
import gradio as gr
import re
from typing import List, Dict, Optional

# ✅ Add the project's src directory to the Python path so txagent can be imported
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))

from txagent.txagent import TxAgent

def sanitize_utf8(text: str) -> str:
    """Clean text of problematic Unicode characters"""
    return text.encode('utf-8', 'ignore').decode('utf-8')

def clean_final_response(response: str) -> str:
    """Remove tool calls and other artifacts from final response"""
    # Split on TOOL_CALLS if present
    if '[TOOL_CALLS]' in response:
        response = response.split('[TOOL_CALLS]')[0]
    # Remove any remaining special tokens
    response = re.sub(r'\[[A-Z_]+\]', '', response)
    return response.strip()
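
# Illustrative example (assumed behavior):
#   clean_final_response("Summary of findings.[TOOL_CALLS]{...}") -> "Summary of findings."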

def chunk_text(text: str, max_tokens: int = 8000) -> List[str]:
    """Split text into chunks based on token count estimate"""
    words = text.split()
    chunks = []
    current_chunk = []
    current_tokens = 0
    
    for word in words:
        # Estimate tokens (roughly 1 token per 4 characters)
        word_tokens = len(word) // 4 + 1
        if current_tokens + word_tokens > max_tokens and current_chunk:
            chunks.append(' '.join(current_chunk))
            current_chunk = [word]
            current_tokens = word_tokens
        else:
            current_chunk.append(word)
            current_tokens += word_tokens
    
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    
    return chunks
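
# Illustrative example of the token-estimate heuristic above
# (roughly 1 token per 4 characters, plus 1 per word):
#   chunk_text("alpha beta gamma delta", max_tokens=5)
#   -> ['alpha beta', 'gamma delta']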

def extract_all_text_from_csv_or_excel(file_path: str, progress=None, index=0, total=1) -> str:
    """Extract text from spreadsheet files with error handling"""
    try:
        if not os.path.exists(file_path):
            return f"File not found: {file_path}"

        if progress:
            progress((index + 1) / total, desc=f"Reading spreadsheet: {os.path.basename(file_path)}")

        if file_path.endswith(".csv"):
            df = pd.read_csv(file_path, encoding="utf-8", encoding_errors="replace", low_memory=False)
        elif file_path.endswith((".xls", ".xlsx")):
            df = pd.read_excel(file_path, engine="openpyxl")
        else:
            return f"Unsupported spreadsheet format: {file_path}"

        lines = []
        for _, row in df.iterrows():
            line = " | ".join(str(cell) for cell in row if pd.notna(cell))
            if line:
                lines.append(line)
        return f"πŸ“„ {os.path.basename(file_path)}\n\n" + "\n".join(lines)

    except Exception as e:
        return f"[Error reading {os.path.basename(file_path)}]: {str(e)}"

def extract_all_text_from_pdf(file_path: str, progress=None, index=0, total=1) -> str:
    """Extract text from PDF files with error handling"""
    try:
        if not os.path.exists(file_path):
            return f"PDF not found: {file_path}"

        extracted = []
        with pdfplumber.open(file_path) as pdf:
            num_pages = len(pdf.pages)
            for i, page in enumerate(pdf.pages):
                try:
                    text = page.extract_text() or ""
                    extracted.append(text.strip())
                    if progress:
                        progress((index + (i / num_pages)) / total, 
                               desc=f"Reading PDF: {os.path.basename(file_path)} ({i+1}/{num_pages})")
                except Exception as e:
                    extracted.append(f"[Error reading page {i+1}]: {str(e)}")
        return f"πŸ“„ {os.path.basename(file_path)}\n\n" + "\n\n".join(extracted)

    except Exception as e:
        return f"[Error reading PDF {os.path.basename(file_path)}]: {str(e)}"

def create_ui(agent: TxAgent):
    with gr.Blocks(theme=gr.themes.Soft(), title="Clinical Patient Support System") as demo:
        gr.Markdown("<h1 style='text-align: center;'>πŸ“‹ CPS: Clinical Patient Support System</h1>")
        
        # The 'messages' format requires dict history entries ({"role": ..., "content": ...})
        chatbot = gr.Chatbot(label="CPS Assistant", height=600, type="messages")
        
        file_upload = gr.File(
            label="Upload Medical File",
            file_types=[".pdf", ".txt", ".docx", ".jpg", ".png", ".csv", ".xls", ".xlsx"],
            file_count="multiple"
        )
        message_input = gr.Textbox(
            placeholder="Ask a biomedical question or just upload the files...", 
            show_label=False
        )
        send_button = gr.Button("Send", variant="primary")
        conversation_state = gr.State([])

        def handle_chat(message: str, history: list, conversation: list, uploaded_files: list, progress=gr.Progress()):
            context = (
                "You are an expert clinical AI assistant reviewing medical form or interview data. "
                "Your job is to analyze this data and reason about any information or red flags that a human doctor might have overlooked. "
                "Provide a **detailed and structured response**, including examples, supporting evidence from the form, and clinical rationale for why these items matter. "
                "Ensure the output is informative and helpful for improving patient care. "
                "Do not hallucinate. Base the response only on the provided form content. "
                "End with a section labeled '🧠 Final Analysis' where you summarize key findings the doctor may have missed."
            )

            try:
                # Echo the user message and a processing placeholder immediately
                # (dict entries, as required by gr.Chatbot(type="messages"))
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": "⏳ Processing your request..."})
                yield history

                extracted_text = ""
                if uploaded_files and isinstance(uploaded_files, list):
                    total_files = len(uploaded_files)
                    for index, file in enumerate(uploaded_files):
                        if not hasattr(file, 'name'):
                            continue
                        path = file.name
                        try:
                            if path.endswith((".csv", ".xls", ".xlsx")):
                                extracted_text += extract_all_text_from_csv_or_excel(path, progress, index, total_files) + "\n"
                            elif path.endswith(".pdf"):
                                extracted_text += extract_all_text_from_pdf(path, progress, index, total_files) + "\n"
                            else:
                                extracted_text += f"(Uploaded file: {os.path.basename(path)})\n"
                        except Exception as file_error:
                            extracted_text += f"[Error processing {os.path.basename(path)}]: {str(file_error)}\n"

                sanitized = sanitize_utf8(extracted_text.strip())
                chunks = chunk_text(sanitized)

                full_response = ""
                for i, chunk in enumerate(chunks):
                    chunked_prompt = (
                        f"{context}\n\n--- Uploaded File Content (Chunk {i+1}/{len(chunks)}) ---\n\n{chunk}\n\n"
                        f"--- End of Chunk ---\n\nNow begin your analysis:"
                    )

                    generator = agent.run_gradio_chat(
                        message=chunked_prompt,
                        history=[],
                        temperature=0.3,
                        max_new_tokens=1024,
                        max_token=8192,
                        call_agent=False,
                        conversation=conversation,
                        uploaded_files=uploaded_files,
                        max_round=30
                    )

                    # Collect all updates from the generator
                    chunk_response = ""
                    for update in generator:
                        if isinstance(update, str):
                            chunk_response += update
                        elif isinstance(update, list):
                            # Handle list of messages
                            for msg in update:
                                if hasattr(msg, 'content'):
                                    chunk_response += msg.content

                    full_response += chunk_response + "\n\n"

                # Clean up the final response
                full_response = clean_final_response(full_response.strip())
                
                # Replace the processing placeholder with the final response
                history[-1] = {"role": "assistant", "content": full_response}
                yield history

            except Exception as chat_error:
                print(f"Chat handling error: {chat_error}")
                error_msg = "An error occurred while processing your request. Please try again."
                if history and isinstance(history[-1], dict) and str(history[-1].get("content", "")).startswith("⏳"):
                    history[-1]["content"] = error_msg
                else:
                    history.append({"role": "assistant", "content": error_msg})
                yield history

        inputs = [message_input, chatbot, conversation_state, file_upload]
        send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)

        gr.Examples([
            ["Upload your medical form and ask what the doctor might've missed."],
            ["This patient was treated with antibiotics for UTI. What else should we check?"],
            ["Is there anything abnormal in the attached blood work report?"]
        ], inputs=message_input)

    return demo
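
# Minimal launch sketch (illustrative): assumes TxAgent can be constructed with
# default arguments, which may not match the real signature in
# src/txagent/txagent.py; adjust the construction call before running.
if __name__ == "__main__":
    agent = TxAgent()      # hypothetical default construction
    ui = create_ui(agent)
    ui.queue().launch()    # queue() enables the streaming generator used by handle_chat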