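"""TxAgent medical-AI chat demo.

Wraps the mims-harvard TxAgent model (with optional ToolRAG retrieval)
in a Gradio chat interface. A CUDA-capable GPU is required; the model
IDs are set in the configuration constants below.

Launched directly, the UI is served on http://0.0.0.0:7860.
"""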
import os
import logging
import torch
import gradio as gr
from txagent import TxAgent

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Configuration
MODEL_NAME = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
RAG_MODEL_NAME = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
TOOL_FILE = "data/new_tool.json"

# Environment setup
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence HF tokenizers fork warnings
os.environ["CUDA_MODULE_LOADING"] = "LAZY"  # defer CUDA module loading until first use
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"  # vLLM workers use spawn instead of fork

class TxAgentSystem:
    def __init__(self):
        self.agent = None
        self.is_initialized = False
        self.examples = [
            ["A 68-year-old with CKD prescribed metformin. Safe for renal clearance?"],
            ["30-year-old on Prozac diagnosed with WHIM. Safe to take Xolremdi?"]
        ]
        
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available - GPU required")

        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

        self._initialize_system()

    def _initialize_system(self):
        """Create the TxAgent instance (preferring RAG) and load the model weights."""
        try:
            os.makedirs("data", exist_ok=True)
            if not os.path.exists(TOOL_FILE):
                with open(TOOL_FILE, "w") as f:
                    f.write("[]")

            logger.info("Initializing TxAgent...")
            
            # Try initializing with RAG enabled; fall back to RAG-disabled on failure
            try:
                self.agent = TxAgent(
                    model_name=MODEL_NAME,
                    rag_model_name=RAG_MODEL_NAME,
                    tool_files_dict={"new_tool": TOOL_FILE},
                    force_finish=True,
                    enable_checker=True,
                    step_rag_num=10,
                    seed=100,
                    enable_rag=True
                )
            except Exception as e:
                logger.warning(f"Failed to initialize with RAG: {str(e)}")
                logger.info("Retrying without RAG...")
                self.agent = TxAgent(
                    model_name=MODEL_NAME,
                    rag_model_name=None,
                    tool_files_dict={"new_tool": TOOL_FILE},
                    force_finish=True,
                    enable_checker=True,
                    step_rag_num=0,
                    seed=100,
                    enable_rag=False
                )

            logger.info("Loading main model...")
            self.agent.init_model()

            self.is_initialized = True
            logger.info("System initialization completed successfully")

        except Exception as e:
            logger.error(f"System initialization failed: {str(e)}")
            self.is_initialized = False
            raise

    def chat_fn(self, message, history, temperature, max_tokens, rag_depth):
        """Run one chat turn; returns (cleared input box, updated history)."""
        if not self.is_initialized:
            return "", history + [(message, "System initialization failed. Please check logs.")]

        try:
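            # NOTE: these keyword arguments mirror how this app calls
            # TxAgent.run_gradio_chat; if the installed txagent version
            # exposes a different signature, a TypeError here is the
            # first thing to check.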
            response = self.agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=max_tokens,
                max_total_tokens=16384,
                enable_multi_agent=False,
                conv_history=history,
                max_steps=rag_depth,
                seed=100
            )
            new_history = history + [(message, response)]
            return "", new_history

        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            return "", history + [(message, "⚠️ GPU memory overflow. Please try a shorter query.")]

        except Exception as e:
            logger.error(f"Chat error: {str(e)}")
            return "", history + [(message, f"🚨 Error: {str(e)}")]

    def launch_ui(self):
        """Build the Gradio Blocks UI and serve it on port 7860."""
        with gr.Blocks(theme=gr.themes.Soft(), title="TxAgent Medical AI") as demo:
            gr.Markdown("## 🧠 TxAgent (A100/H100 Optimized)")

            status = gr.Textbox(
                value="✅ System ready" if self.is_initialized else "❌ Initialization failed",
                label="System Status",
                interactive=False
            )

            with gr.Row():
                with gr.Column(scale=3):
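                    # History is kept as (user, assistant) tuple pairs;
                    # newer Gradio releases may expect type="messages" instead.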
                    chatbot = gr.Chatbot(height=600, label="Conversation History")
                    msg = gr.Textbox(label="Enter Medical Query", placeholder="Type your question here...")
                with gr.Column(scale=1):
                    temp = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
                    max_tokens = gr.Slider(128, 8192, value=2048, label="Max Response Tokens")
                    rag_depth = gr.Slider(1, 20, value=10, label="RAG Depth")
                    clear_btn = gr.Button("Clear History")

            gr.Examples(
                examples=self.examples,
                inputs=msg,
                label="Example Queries"
            )

            msg.submit(
                self.chat_fn,
                inputs=[msg, chatbot, temp, max_tokens, rag_depth],
                outputs=[msg, chatbot]
            )
            clear_btn.click(lambda: None, None, chatbot, queue=False)  # returning None clears the Chatbot

        demo.launch(
            server_name="0.0.0.0",
            server_port=7860
        )

if __name__ == "__main__":
    try:
        system = TxAgentSystem()
        system.launch_ui()
    except Exception as e:
        logger.critical(f"Fatal error: {str(e)}")
        raise