File size: 5,454 Bytes
9aeb1dd
dec1312
0b9e159
 
dcb29df
410d25f
66e2fa0
81ad366
79fb3cd
 
dc06321
 
81ad366
dc06321
81ad366
dec1312
dc06321
dec1312
dcb29df
 
 
 
 
 
 
 
 
81ad366
 
 
dcb29df
0b9e159
 
 
 
dec1312
0b9e159
 
 
 
 
dec1312
0b9e159
 
dec1312
dc06321
dcb29df
 
 
 
81ad366
 
dc06321
81ad366
dc06321
 
90e4214
0b9e159
dc06321
0b9e159
dcb29df
0b9e159
dcb29df
 
 
 
 
 
0b9e159
90e4214
0b9e159
 
 
 
 
 
 
dcb29df
 
 
dc06321
dcb29df
0b9e159
dcb29df
0b9e159
dcb29df
0b9e159
dc06321
0b9e159
dcb29df
81ad366
0b9e159
 
 
81ad366
dcb29df
 
0b9e159
dcb29df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b9e159
 
 
dcb29df
0b9e159
dcb29df
0b9e159
dcb29df
0b9e159
dcb29df
 
 
0b9e159
dcb29df
0b9e159
 
 
 
81ad366
0b9e159
 
 
 
 
 
 
90e4214
0b9e159
 
 
 
 
 
 
81ad366
70839bb
0b9e159
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
import os
import logging
from txagent import TxAgent
from tooluniverse import ToolUniverse

# Configure root logging at INFO level and create this module's logger,
# which the rest of the file uses for status and error messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

class TxAgentApp:
    """Gradio-facing wrapper that owns a single TxAgent instance."""

    def __init__(self):
        # The agent (including its model) is built eagerly; if that fails,
        # the exception propagates out of the constructor.
        self.agent = self._initialize_agent()

    def _initialize_agent(self):
        """Initialize the TxAgent with proper parameters.

        Builds a ``TxAgent`` with a fixed model, RAG model and tool-file
        mapping, then calls ``init_model()`` on it.

        Returns:
            The initialized ``TxAgent``.

        Raises:
            Exception: anything raised by ``TxAgent(...)`` or
                ``init_model()`` is logged (with traceback) and re-raised.
        """
        try:
            logger.info("Initializing TxAgent...")

            # Tool definition files, handed to TxAgent as ``tool_files_dict``.
            tool_files = {
                "opentarget": "opentarget_tools.json",
                "fda_drug_label": "fda_drug_labeling_tools.json",
                "special_tools": "special_tools.json",
                "monarch": "monarch_tools.json"
            }

            agent = TxAgent(
                model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
                rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
                tool_files_dict=tool_files,
                enable_finish=True,
                enable_rag=True,
                enable_summary=False,
                init_rag_num=0,
                step_rag_num=10,
                summary_mode='step',
                summary_skip_last_k=0,
                summary_context_length=None,
                force_finish=True,
                avoid_repeat=True,
                seed=42,
                enable_checker=True,
                enable_chat=False,
                additional_default_tools=["DirectResponse", "RequireClarification"]
            )

            # Explicitly initialize the model
            agent.init_model()

            logger.info("Model loading complete")
            return agent
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Initialization failed: %s", e)
            raise

    def respond(self, message, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_round):
        """Handle streaming responses with Gradio.

        This is a generator. Every yielded value is the complete chatbot
        history to display, as ``(user_text, bot_text)`` pairs. Notices and
        errors are shown as ``("", text)``.

        Args:
            message: The user's query. Anything that is not a ``str``, or that
                is 10 characters or fewer after stripping, is rejected with a
                notice.
            chat_history: The current chatbot value. ``None`` is treated as
                empty.
            temperature: Forwarded to ``run_gradio_chat``.
            max_new_tokens: Forwarded to ``run_gradio_chat``.
            max_tokens: Forwarded as ``max_token``.
            multi_agent: Forwarded as ``call_agent``.
            conversation_state: Forwarded as ``conversation``.
            max_round: Forwarded to ``run_gradio_chat``.

        Yields:
            list: The prior history plus the current turn (or notice).
        """
        # Copy so we never mutate the component's value in place, and
        # tolerate None so the except-branch below can always concatenate.
        chat_history = list(chat_history or [])
        try:
            if not isinstance(message, str) or len(message.strip()) <= 10:
                # A ``return <value>`` inside a generator is discarded by the
                # consumer, so the notice must be yielded to reach the UI.
                yield chat_history + [("", "Please provide a valid message longer than 10 characters.")]
                return

            # Convert chat history to list of tuples if needed.
            # NOTE(review): confirm which history shape run_gradio_chat
            # expects; dict entries become (role, content) tuples here.
            if chat_history and isinstance(chat_history[0], dict):
                chat_history = [(h["role"], h["content"]) for h in chat_history if "role" in h and "content" in h]

            response = ""
            for chunk in self.agent.run_gradio_chat(
                message=message.strip(),
                history=chat_history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation_state,
                max_round=max_round,
                seed=42
            ):
                if isinstance(chunk, dict):
                    response += chunk.get("content", "")
                elif isinstance(chunk, str):
                    response += chunk
                else:
                    response += str(chunk)

                # One (user, assistant) pair per turn. This matches the
                # ("", text) pairs used for notices, instead of emitting
                # role names as if they were message text.
                yield chat_history + [(message, response)]

        except Exception as e:
            logger.exception("Error in respond function: %s", e)
            yield chat_history + [("", f"⚠️ Error: {str(e)}")]

def create_demo():
    """Create and configure the Gradio interface.

    Builds a ``TxAgentApp`` (which initializes the agent and its model) and
    wires a chat UI around its ``respond`` generator.

    Returns:
        gr.Blocks: The configured demo. It has not been launched yet.
    """
    app = TxAgentApp()

    with gr.Blocks(title="TxAgent Medical AI") as demo:
        gr.Markdown("# TxAgent Biomedical Assistant")

        chatbot = gr.Chatbot(
            label="Conversation",
            height=600,
            bubble_full_width=False
        )

        msg = gr.Textbox(
            label="Your medical query",
            placeholder="Enter your biomedical question...",
            lines=3
        )

        with gr.Row():
            temp = gr.Slider(0, 1, value=0.3, label="Temperature")
            # step=1 keeps token and round counts integral. Without an
            # explicit step, Gradio derives one from the slider range, which
            # can be fractional for small ranges such as Max Rounds.
            max_new_tokens = gr.Slider(128, 4096, value=1024, step=1, label="Max New Tokens")
            max_tokens = gr.Slider(128, 81920, value=81920, step=1, label="Max Total Tokens")
            max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
            multi_agent = gr.Checkbox(label="Multi-Agent Mode")

        submit = gr.Button("Submit")
        clear = gr.Button("Clear")

        conversation_state = gr.State([])

        # The button and pressing Enter in the textbox run the same handler
        # with the same inputs; the order must match respond's signature.
        respond_inputs = [msg, chatbot, temp, max_new_tokens, max_tokens,
                          multi_agent, conversation_state, max_rounds]

        submit.click(app.respond, respond_inputs, chatbot)

        # Reset both the visible transcript and the per-session conversation
        # state, so a cleared chat does not carry over earlier context.
        clear.click(lambda: ([], []), None, [chatbot, conversation_state])

        msg.submit(app.respond, respond_inputs, chatbot)

    return demo

def main():
    """Main entry point: build the demo and serve it on 0.0.0.0:7860.

    Raises:
        Exception: Any startup failure is logged (with traceback) and
            re-raised.
    """
    try:
        logger.info("Starting TxAgent application...")
        demo = create_demo()

        logger.info("Launching Gradio interface...")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Application failed to start: %s", e)
        raise


if __name__ == "__main__":
    main()