import os

# Environment setup. These variables should be set before importing libraries
# that load MKL or the HF tokenizers (torch, transformers, numpy), otherwise
# they may be read too late to take effect.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import logging
import gradio as gr
from importlib.resources import files
from txagent import TxAgent
from tooluniverse import ToolUniverse

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

current_dir = os.path.dirname(os.path.abspath(__file__))

# Configuration
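# Model identifiers and tool-definition files: the opentarget/fda/special/monarch
# JSON files ship inside the tooluniverse package; new_tool.json is generated
# locally by prepare_tool_files() on first run.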
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "tool_files": {
        "opentarget": str(files('tooluniverse.data').joinpath('opentarget_tools.json')),
        "fda_drug_label": str(files('tooluniverse.data').joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(files('tooluniverse.data').joinpath('special_tools.json')),
        "monarch": str(files('tooluniverse.data').joinpath('monarch_tools.json')),
        "new_tool": os.path.join(current_dir, 'data', 'new_tool.json')
    }
}

class TxAgentApp:
    def __init__(self):
        self.agent = None
        self.initialize_agent()

    def initialize_agent(self):
        """Initialize the TxAgent with proper error handling"""
        try:
            self.prepare_tool_files()
            logger.info("Initializing TxAgent...")
            
            self.agent = TxAgent(
                model_name=CONFIG["model_name"],
                rag_model_name=CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=42,
                additional_default_tools=["DirectResponse", "RequireClarification"]
            )
            
            logger.info("Initializing model...")
            self.agent.init_model()
            logger.info("Agent initialization complete")
            
        except Exception as e:
            logger.error(f"Failed to initialize agent: {e}")
            raise

    def prepare_tool_files(self):
        """Prepare the tool files directory"""
        try:
            os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
            if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
                logger.info("Creating new_tool.json...")
                tu = ToolUniverse()
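                # Export the full tool registry; the accessor differs across
                # tooluniverse versions, so fall back to the .tools attribute
                # when get_all_tools() is unavailable.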
                tools = tu.get_all_tools() if hasattr(tu, "get_all_tools") else getattr(tu, "tools", [])
                with open(CONFIG["tool_files"]["new_tool"], "w") as f:
                    json.dump(tools, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to prepare tool files: {e}")
            raise

    def respond(self, msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
        """Handle user message and generate response"""
        try:
            if not isinstance(msg, str) or len(msg.strip()) <= 10:
                return chat_history + [{"role": "assistant", "content": "Please provide a valid message longer than 10 characters."}]

            message = msg.strip()
            chat_history.append({"role": "user", "content": message})
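            # run_gradio_chat is assumed to expect history as (role, content)
            # tuples, so convert from the role/content dicts kept in chat_history.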
            formatted_history = [(m["role"], m["content"]) for m in chat_history if "role" in m and "content" in m]

            logger.info(f"Processing message: {message[:100]}...")
            
            response_generator = self.agent.run_gradio_chat(
                message=message,
                history=formatted_history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation,
                max_round=max_round,
                seed=42
            )

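            # The agent streams its reply; accumulate chunks (dicts carrying a
            # "content" field, plain strings, or anything else printable) into one string.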
            collected = ""
            for chunk in response_generator:
                if isinstance(chunk, dict) and "content" in chunk:
                    collected += chunk["content"]
                elif isinstance(chunk, str):
                    collected += chunk
                elif chunk is not None:
                    collected += str(chunk)

            chat_history.append({"role": "assistant", "content": collected or "No response generated."})
            return chat_history

        except Exception as e:
            logger.error(f"Error in respond function: {e}")
            chat_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
            return chat_history

    def create_demo(self):
        """Create and return the Gradio interface"""
        with gr.Blocks(title="TxAgent", css=".gr-button { font-size: 18px !important; }") as demo:
            gr.Markdown("# TxAgent - Biomedical AI Assistant")
            
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        label="Conversation",
                        height=600,
                        type="messages"  # respond() appends openai-style role/content dicts
                    )
                    msg = gr.Textbox(
                        label="Your question",
                        placeholder="Ask a biomedical question...",
                        lines=3
                    )
                    submit = gr.Button("Ask", variant="primary")
                    
                with gr.Column(scale=1):
                    temp = gr.Slider(0, 1, value=0.3, label="Temperature")
                    max_new_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max New Tokens")
                    max_tokens = gr.Slider(128, 81920, value=81920, step=1024, label="Max Total Tokens")
                    max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
                    multi_agent = gr.Checkbox(label="Multi-Agent Mode", value=False)
                    clear_btn = gr.Button("Clear Chat")

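            # Wire the Ask button to respond(); a fresh gr.State list is passed
            # through to run_gradio_chat as its conversation argument.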
            submit.click(
                self.respond,
                inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
                outputs=[chatbot]
            )
            clear_btn.click(lambda: [], None, chatbot, queue=False)
            
            # Add a dummy event to ensure the app stays alive
            demo.load(lambda: None, None, None)
            
        return demo

def main():
    """Main entry point for the application"""
    try:
        logger.info("Starting TxAgent application...")
        app = TxAgentApp()
        demo = app.create_demo()
        
        logger.info("Launching Gradio interface...")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
            show_error=True
        )
    except Exception as e:
        logger.error(f"Application failed to start: {e}")
        raise

if __name__ == "__main__":
    main()