File size: 3,719 Bytes
9aeb1dd
511fb62
 
81ad366
 
410d25f
66e2fa0
81ad366
79fb3cd
 
dc06321
 
81ad366
dc06321
81ad366
 
dc06321
81ad366
 
 
 
 
 
 
 
dc06321
81ad366
 
dc06321
81ad366
dc06321
 
81ad366
 
dc06321
 
 
81ad366
 
 
 
dc06321
81ad366
dc06321
81ad366
 
dc06321
81ad366
dc06321
81ad366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc06321
81ad366
 
 
 
 
dc06321
81ad366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70839bb
81ad366
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
from txagent import TxAgent
from tooluniverse import ToolUniverse
import os
import logging

# Configure logging
# basicConfig sets the root logger to INFO so both this module's messages and
# any propagating library logs are emitted; the module-level logger follows
# the stdlib one-logger-per-module convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TxAgentApp:
    """Wraps a TxAgent model and adapts it to Gradio's chat-callback protocol."""

    def __init__(self):
        # Load the model eagerly so startup fails fast if weights are unavailable.
        self.agent = self._initialize_agent()

    def _initialize_agent(self):
        """Initialize the TxAgent with A100 optimizations.

        Returns:
            The loaded TxAgent instance, ready for inference.

        Raises:
            Exception: any failure during model loading is logged and re-raised
                so the process aborts instead of serving a half-initialized app.
        """
        try:
            logger.info("Initializing TxAgent with A100 optimizations...")
            agent = TxAgent(
                model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
                rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
                device_map="auto",
                torch_dtype="auto",
                enable_xformers=True,
                max_model_len=8192  # Optimized for A100 80GB
            )
            logger.info("Model loading complete")
            return agent
        except Exception as e:
            logger.error(f"Initialization failed: {str(e)}")
            raise

    def respond(self, message, history):
        """Stream a reply for a Chatbot-output event handler.

        Args:
            message: the user's new query string.
            history: the current Chatbot transcript (assumed to be a list of
                [user, assistant] pairs -- tuple-format Chatbot; confirm).

        Yields:
            The full updated transcript on every chunk. Each yield *replaces*
            the Chatbot component's value, so we must accumulate the partial
            reply and re-yield the whole history; yielding bare chunk strings
            (the previous behavior) overwrote the transcript with a fragment.
        """
        try:
            partial = ""
            stream = self.agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=0.3,
                max_new_tokens=2048,
                stream=True
            )
            for chunk in stream:
                if isinstance(chunk, dict):
                    partial += chunk.get("content", "")
                elif isinstance(chunk, str):
                    partial += chunk
                else:
                    continue  # skip unknown chunk types rather than crash
                yield history + [(message, partial)]
        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            yield history + [(message, f"⚠️ Error: {str(e)}")]

# Initialize the app (loads the model eagerly at import time).
app = TxAgentApp()

# Gradio 5.23 interface
with gr.Blocks(
    title="TxAgent Medical AI",
    theme=gr.themes.Soft(spacing_size="sm", radius_size="none")
) as demo:
    gr.Markdown("""<h1 style='text-align: center'>TxAgent Biomedical Assistant</h1>""")

    with gr.Row(equal_height=False):
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                height=650,
                # NOTE(review): bubble_full_width is deprecated in Gradio 5.x
                # (ignored with a warning) -- confirm the targeted version.
                bubble_full_width=False,
                avatar_images=(
                    "https://example.com/user.png",  # Replace with actual avatars
                    "https://example.com/bot.png"
                )
            )
        with gr.Column(scale=1):
            with gr.Accordion("⚙️ Parameters", open=False):
                # NOTE(review): none of these controls are passed to
                # app.respond, which hard-codes temperature/max tokens and
                # always uses RAG -- wire them through or remove them.
                temperature = gr.Slider(0, 1, value=0.3, label="Creativity")
                max_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max Response Length")
                rag_toggle = gr.Checkbox(value=True, label="Enable RAG")

            msg = gr.Textbox(
                label="Your medical query",
                placeholder="Enter your biomedical question...",
                lines=5,
                max_lines=10
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear History")

    # Chat interface: the Enter key and the Submit button trigger the same
    # handler. Each event needs a distinct api_name -- registering "chat"
    # twice collides in the auto-generated API namespace.
    msg.submit(
        app.respond,
        [msg, chatbot],
        chatbot,
        api_name="chat"
    ).then(
        lambda: "", None, msg  # clear the textbox after sending
    )

    submit_btn.click(
        app.respond,
        [msg, chatbot],
        chatbot,
        api_name="chat_click"  # was "chat", duplicating the submit event's name
    ).then(
        lambda: "", None, msg  # clear the textbox after sending
    )

    # Reset the transcript to an empty conversation.
    clear_btn.click(
        lambda: [], None, chatbot
    )

# Launch configuration
if __name__ == "__main__":
    # Gradio 4+ renamed queue(concurrency_count=...) to
    # default_concurrency_limit; on the Gradio 5.x this file targets, the old
    # keyword raises TypeError at startup.
    demo.queue(
        default_concurrency_limit=5,
        max_size=20  # reject requests beyond this queue depth
    ).launch(
        server_name="0.0.0.0",   # listen on all interfaces (container-friendly)
        server_port=7860,
        share=False,
        favicon_path="icon.png"  # Add favicon
    )