File size: 5,094 Bytes
1777737
3a20a5b
728def5
 
3a20a5b
dfe34bb
3a20a5b
dfe34bb
0e7a2f6
dfe34bb
728def5
3a20a5b
dfe34bb
3a20a5b
 
 
 
 
 
 
 
dfe34bb
 
3a20a5b
dfe34bb
728def5
3a20a5b
dfe34bb
 
 
3a20a5b
 
dfe34bb
 
 
 
 
3a20a5b
 
dfe34bb
 
 
 
728def5
dfe34bb
3492c23
3ae42d2
 
3a20a5b
 
 
 
 
 
 
 
774fd26
3492c23
3a20a5b
dfe34bb
4e4aafc
 
 
 
 
 
dfe34bb
4a6ed35
3a20a5b
dfe34bb
3a20a5b
 
 
dfe34bb
3a20a5b
 
dfe34bb
3a20a5b
0e7a2f6
 
3a20a5b
 
0e7a2f6
3ae42d2
4a6ed35
3492c23
 
 
 
 
 
 
 
3a20a5b
3492c23
 
 
88317c7
 
3a20a5b
 
 
88317c7
3a20a5b
3ae42d2
 
 
3a20a5b
3492c23
0e7a2f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import sys
import os
import pandas as pd
import pdfplumber
import gradio as gr

# ✅ Add src to Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from txagent.txagent import TxAgent


def extract_all_text_from_csv_or_excel(file_path, progress=None, index=0, total=1):
    """Render a CSV or Excel spreadsheet as plain text for the LLM prompt.

    Args:
        file_path: Path to the uploaded file. The extension is matched
            case-insensitively, so ``LABS.CSV`` is handled like ``labs.csv``.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``
            (e.g. a ``gr.Progress`` instance) once the file has been parsed.
        index: Zero-based position of this file among the uploaded files.
        total: Total number of uploaded files; scales the progress fraction.

    Returns:
        The whole table rendered with ``DataFrame.to_string(index=False)``, or a
        human-readable message for unsupported formats and parse failures.
        Errors are returned rather than raised so that one bad upload does not
        abort the whole chat request.
    """
    try:
        lowered = file_path.lower()
        if lowered.endswith(".csv"):
            # low_memory=False parses in one pass so column dtypes are not
            # inferred chunk-by-chunk (which can yield mixed-type columns).
            df = pd.read_csv(file_path, low_memory=False)
        elif lowered.endswith((".xls", ".xlsx")):
            # pandas' default sheet_name=0: only the first sheet is read.
            df = pd.read_excel(file_path)
        else:
            return f"Unsupported spreadsheet format: {file_path}"
        # Compare with None explicitly: truth-testing a tracker object falls back
        # to its __len__, which may be 0 (skipping the update) or even raise.
        if progress is not None:
            progress((index + 1) / total, desc=f"Processed table: {os.path.basename(file_path)}")
        return df.to_string(index=False)
    except Exception as e:
        return f"Error parsing file: {e}"


def extract_all_text_from_pdf(file_path, progress=None, index=0, total=1):
    """Extract the table rows of a PDF as tab-separated text.

    NOTE(review): despite the name, only content returned by
    ``page.extract_tables()`` is collected; page text outside detected tables is
    not included, so a text-only PDF yields an empty string. Confirm whether a
    ``page.extract_text()`` fallback is wanted.

    Args:
        file_path: Path to the PDF file.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``
            after each page is processed (e.g. a ``gr.Progress`` instance).
        index: Zero-based position of this file among the uploaded files.
        total: Total number of uploaded files; scales the progress fraction.

    Returns:
        One line per table row, with cells joined by tabs (``None`` cells become
        ``""``). Rows whose cells are all empty/``None`` are skipped. On failure,
        an ``"Error parsing PDF: ..."`` message is returned instead of raising.
    """
    rows = []
    try:
        with pdfplumber.open(file_path) as pdf:
            num_pages = len(pdf.pages)
            for page_no, page in enumerate(pdf.pages, start=1):
                for table in page.extract_tables():
                    for row in table:
                        if any(row):
                            rows.append("\t".join(cell or "" for cell in row))
                # Explicit None check: truth-testing a tracker may call its __len__.
                if progress is not None:
                    # page_no / num_pages reaches 1.0 on the last page, so a
                    # finished file reports (index + 1) / total, matching the
                    # spreadsheet helper.
                    progress(
                        (index + page_no / num_pages) / total,
                        desc=f"Parsing PDF: {os.path.basename(file_path)} ({page_no}/{num_pages})",
                    )
        return "\n".join(rows)
    except Exception as e:
        return f"Error parsing PDF: {e}"


def create_ui(agent: TxAgent):
    """Build the CPS Gradio interface around *agent*.

    The page holds a chatbot, a multi-file upload widget, a question text box
    and a Send button. Clicking Send or pressing Enter streams the agent's
    updates into the chatbot.

    Args:
        agent: The ``TxAgent`` whose ``run_gradio_chat`` generator produces the
            chatbot updates.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>📋 CPS: Clinical Patient Support System</h1>")
        chatbot = gr.Chatbot(label="CPS Assistant", height=600, type="messages")

        file_upload = gr.File(
            label="Upload Medical File",
            file_types=[".pdf", ".txt", ".docx", ".jpg", ".png", ".csv", ".xls", ".xlsx"],
            file_count="multiple"
        )
        message_input = gr.Textbox(placeholder="Ask a biomedical question or just upload the files...", show_label=False)
        send_button = gr.Button("Send", variant="primary")
        # NOTE(review): this state is an input only (never listed in outputs);
        # any conversation updates rely on run_gradio_chat mutating the list in
        # place — confirm.
        conversation_state = gr.State([])

        def handle_chat(message, history, conversation, uploaded_files, progress=gr.Progress()):
            """Build the prompt (with any uploaded file contents) and stream the agent's reply.

            Spreadsheets (.csv/.xls/.xlsx) and PDFs are converted to text; any
            other upload is only mentioned by file name. When files are attached,
            the instruction preamble and extracted text replace the raw message,
            and the typed question (if any) is appended so it is not lost.

            Yields:
                Each update produced by ``agent.run_gradio_chat``, which Gradio
                streams to the chatbot.
            """
            context = (
                "You are an expert clinical AI assistant reviewing medical form or interview data. "
                "Your job is to analyze this data and reason about any information or red flags that a human doctor might have overlooked. "
                "Provide a **detailed and structured response**, including examples, supporting evidence from the form, and clinical rationale for why these items matter. "
                "Ensure the output is informative and helpful for improving patient care. "
                "Do not hallucinate. Base the response only on the provided form content. "
                "End with a section labeled '🧠 Final Analysis' where you summarize key findings the doctor may have missed."
            )

            if uploaded_files:
                total_files = len(uploaded_files)
                sections = []

                for index, file in enumerate(uploaded_files):
                    # Uploads may arrive as path strings or as objects exposing
                    # the path via ``.name``; accept both.
                    path = getattr(file, "name", file)
                    ext = os.path.splitext(path)[1].lower()
                    if ext in (".csv", ".xls", ".xlsx"):
                        sections.append(extract_all_text_from_csv_or_excel(path, progress, index, total_files))
                    elif ext == ".pdf":
                        sections.append(extract_all_text_from_pdf(path, progress, index, total_files))
                    else:
                        sections.append(f"(Uploaded file: {os.path.basename(path)})")
                        # Explicit None check: truth-testing a tracker may call its __len__.
                        if progress is not None:
                            progress((index + 1) / total_files, desc=f"Skipping unsupported file: {os.path.basename(path)}")

                extracted_text = "\n".join(sections).strip()
                question = (message or "").strip()
                message = f"{context}\n\n---\n{extracted_text}\n---\n\n"
                if question:
                    # Carry the typed question along so it is not discarded
                    # when files are attached.
                    message += f"Question: {question}\n\n"
                message += "Begin your reasoning."

            yield from agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=0.3,
                max_new_tokens=1024,
                max_token=8192,
                call_agent=False,
                conversation=conversation,
                uploaded_files=uploaded_files,
                max_round=30
            )

        inputs = [message_input, chatbot, conversation_state, file_upload]
        send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)

        gr.Examples([
            ["Upload your medical form and ask what the doctor might’ve missed."],
            ["This patient was treated with antibiotics for UTI. What else should we check?"],
            ["Is there anything abnormal in the attached blood work report?"]
        ], inputs=message_input)

    return demo