File size: 5,430 Bytes
9aeb1dd
dec1312
0b9e159
 
410d25f
66e2fa0
81ad366
79fb3cd
 
dc06321
 
81ad366
dc06321
81ad366
dec1312
dc06321
dec1312
81ad366
 
 
0b9e159
 
 
 
dec1312
0b9e159
 
 
 
 
dec1312
0b9e159
 
dec1312
dc06321
81ad366
 
dc06321
81ad366
dc06321
 
90e4214
0b9e159
dc06321
0b9e159
 
 
 
dc06321
0b9e159
90e4214
0b9e159
 
 
 
 
 
 
dc06321
0b9e159
 
dc06321
0b9e159
 
dc06321
0b9e159
 
 
 
 
 
 
dc06321
0b9e159
81ad366
 
0b9e159
 
 
81ad366
0b9e159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81ad366
0b9e159
 
 
 
 
 
 
90e4214
0b9e159
 
 
 
 
 
 
81ad366
70839bb
0b9e159
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import gradio as gr
import os
import logging
from txagent import TxAgent

# Module-level logger for this app; the root configuration below sends
# INFO-and-above records to the default handler.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level=logging.INFO)

class TxAgentApp:
    def __init__(self):
        self.agent = self._initialize_agent()

    def _initialize_agent(self):
        """Initialize the TxAgent with proper parameters"""
        try:
            logger.info("Initializing TxAgent...")
            agent = TxAgent(
                model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
                rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
                enable_finish=True,
                enable_rag=True,
                enable_summary=False,
                init_rag_num=0,
                step_rag_num=10,
                summary_mode='step',
                summary_skip_last_k=0,
                summary_context_length=None,
                force_finish=True,
                avoid_repeat=True,
                seed=42,
                enable_checker=True,
                enable_chat=False,
                additional_default_tools=["DirectResponse", "RequireClarification"]
            )
            logger.info("Model loading complete")
            return agent
        except Exception as e:
            logger.error(f"Initialization failed: {str(e)}")
            raise

    def respond(self, message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
        """Handle streaming responses with Gradio"""
        try:
            if not isinstance(message, str) or len(message.strip()) <= 10:
                yield "Please provide a valid message longer than 10 characters."
                return

            response_generator = self.agent.run_gradio_chat(
                message=message.strip(),
                history=chat_history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation_state,
                max_round=max_round,
                seed=42
            )

            full_response = ""
            for chunk in response_generator:
                if isinstance(chunk, dict) and "content" in chunk:
                    content = chunk["content"]
                elif isinstance(chunk, str):
                    content = chunk
                else:
                    content = str(chunk)
                
                full_response += content
                yield full_response

        except Exception as e:
            logger.error(f"Error in respond function: {str(e)}")
            yield f"⚠️ Error: {str(e)}"

def create_demo():
    """Create and configure the Gradio interface.

    Constructs a ``TxAgentApp``, which initializes the agent. It then lays
    out a chat column and a collapsible parameter panel, and wires the
    events:

    - Enter in the textbox or the Submit button runs ``TxAgentApp.respond``
      into the chatbot, then clears the textbox once the handler finishes.
    - The Clear button resets the chatbot and the conversation state to
      empty lists.

    Returns:
        The configured ``gr.Blocks`` instance, not yet launched.
    """
    app = TxAgentApp()

    with gr.Blocks(
        title="TxAgent Medical AI",
        theme=gr.themes.Soft(spacing_size="sm", radius_size="none")
    ) as demo:
        gr.Markdown("""<h1 style='text-align: center'>TxAgent Biomedical Assistant</h1>""")

        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    height=650,
                    bubble_full_width=False,
                    render_markdown=True
                )
                msg = gr.Textbox(
                    label="Your medical query",
                    placeholder="Enter your biomedical question...",
                    lines=5,
                    max_lines=10
                )

            with gr.Column(scale=1):
                with gr.Accordion("⚙️ Parameters", open=False):
                    temperature = gr.Slider(0, 1, value=0.3, label="Creativity")
                    max_new_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max Response Length")
                    # The step must put the 81920 default on the grid min + k*step.
                    # With min=128 and step=1024 the largest selectable value was 81024.
                    max_tokens = gr.Slider(128, 81920, value=81920, step=128, label="Max Total Tokens")
                    max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
                    multi_agent = gr.Checkbox(value=False, label="Multi-Agent Mode")

                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear History")

        conversation_state = gr.State([])

        # Chat interface: both triggers share one input list, and the order
        # must match TxAgentApp.respond's positional parameters.
        respond_inputs = [
            msg, chatbot, temperature, max_new_tokens, max_tokens,
            multi_agent, conversation_state, max_rounds,
        ]
        for trigger in (msg.submit, submit_btn.click):
            trigger(app.respond, respond_inputs, chatbot).then(lambda: "", None, msg)

        clear_btn.click(
            lambda: ([], []),
            None,
            [chatbot, conversation_state]
        )

    return demo

def main():
    """Build the Gradio UI and serve it.

    The bind address and port default to ``0.0.0.0`` and ``7860``. They can
    be overridden with the ``GRADIO_SERVER_NAME`` and ``GRADIO_SERVER_PORT``
    environment variables. Public sharing stays disabled.

    Raises:
        Exception: any startup failure, re-raised after it is logged with its
        traceback. This includes a non-integer ``GRADIO_SERVER_PORT``, which
        raises ``ValueError``.
    """
    try:
        logger.info("Starting TxAgent application...")
        demo = create_demo()

        logger.info("Launching Gradio interface...")
        demo.launch(
            server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
            server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
            share=False
        )
    except Exception as e:
        logger.exception("Application failed to start: %s", e)
        raise


if __name__ == "__main__":
    main()