import os
import logging
import torch
import gradio as gr
from txagent import TxAgent
# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Configuration
MODEL_NAME = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
RAG_MODEL_NAME = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
TOOL_FILE = "data/new_tool.json"
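# NOTE: both model checkpoints above are fetched from the Hugging Face Hub on
# first run, so the initial launch may take a while to download weights.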
# Environment setup
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence HF tokenizers fork warnings
os.environ["CUDA_MODULE_LOADING"] = "LAZY"  # defer CUDA module loading until first use
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"  # vLLM workers must spawn, not fork, under CUDA

class TxAgentSystem:
    """Wraps a TxAgent instance and serves it through a Gradio chat UI."""

    def __init__(self):
        self.agent = None
        self.is_initialized = False
        self.examples = [
            ["A 68-year-old with CKD prescribed metformin. Safe for renal clearance?"],
            ["30-year-old on Prozac diagnosed with WHIM. Safe to take Xolremdi?"]
        ]
        # Fail fast if no GPU is present; the models require CUDA
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available - GPU required")
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
        self._initialize_system()

    def _initialize_system(self):
        try:
            # Make sure the tool registry exists (an empty JSON list) before TxAgent reads it
            os.makedirs("data", exist_ok=True)
            if not os.path.exists(TOOL_FILE):
                with open(TOOL_FILE, "w") as f:
                    f.write("[]")
logger.info("Initializing TxAgent...")
# Initialize with RAG disabled first
try:
self.agent = TxAgent(
model_name=MODEL_NAME,
rag_model_name=RAG_MODEL_NAME,
tool_files_dict={"new_tool": TOOL_FILE},
force_finish=True,
enable_checker=True,
step_rag_num=10,
seed=100,
enable_rag=True
)
            except Exception as e:
                logger.warning(f"Failed to initialize with RAG: {str(e)}")
                logger.info("Retrying without RAG...")
                self.agent = TxAgent(
                    model_name=MODEL_NAME,
                    rag_model_name=None,
                    tool_files_dict={"new_tool": TOOL_FILE},
                    force_finish=True,
                    enable_checker=True,
                    step_rag_num=0,
                    seed=100,
                    enable_rag=False
                )
logger.info("Loading main model...")
self.agent.init_model()
self.is_initialized = True
logger.info("System initialization completed successfully")
except Exception as e:
logger.error(f"System initialization failed: {str(e)}")
self.is_initialized = False
raise

    def chat_fn(self, message, history, temperature, max_tokens, rag_depth):
        if not self.is_initialized:
            return "", history + [(message, "System initialization failed. Please check logs.")]
        try:
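            # NOTE: this assumes run_gradio_chat returns the complete answer as a
            # string; some TxAgent versions yield streaming updates instead, in
            # which case the call would need to be consumed as a generator.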
            response = self.agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=max_tokens,
                max_total_tokens=16384,
                enable_multi_agent=False,
                conv_history=history,
                max_steps=rag_depth,
                seed=100
            )
            new_history = history + [(message, response)]
            return "", new_history
        except torch.cuda.OutOfMemoryError:
            # Release cached GPU memory so a shorter follow-up query can succeed
            torch.cuda.empty_cache()
            return "", history + [(message, "⚠️ GPU memory overflow. Please try a shorter query.")]
        except Exception as e:
            logger.error(f"Chat error: {str(e)}")
            return "", history + [(message, f"🚨 Error: {str(e)}")]

    def launch_ui(self):
        with gr.Blocks(theme=gr.themes.Soft(), title="TxAgent Medical AI") as demo:
            gr.Markdown("## 🧠 TxAgent (A100/H100 Optimized)")
            status = gr.Textbox(
                value="✅ System ready" if self.is_initialized else "❌ Initialization failed",
                label="System Status",
                interactive=False
            )
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(height=600, label="Conversation History")
                    msg = gr.Textbox(label="Enter Medical Query", placeholder="Type your question here...")
                with gr.Column(scale=1):
                    temp = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
                    max_tokens = gr.Slider(128, 8192, value=2048, label="Max Response Tokens")
                    rag_depth = gr.Slider(1, 20, value=10, label="RAG Depth")
                    clear_btn = gr.Button("Clear History")
            gr.Examples(
                examples=self.examples,
                inputs=msg,
                label="Example Queries"
            )
            msg.submit(
                self.chat_fn,
                inputs=[msg, chatbot, temp, max_tokens, rag_depth],
                outputs=[msg, chatbot]
            )
            # Returning None clears the Chatbot component
            clear_btn.click(lambda: None, None, chatbot, queue=False)
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860
        )

if __name__ == "__main__":
    try:
        system = TxAgentSystem()
        system.launch_ui()
    except Exception as e:
        logger.critical(f"Fatal error: {str(e)}")
        raise
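
# Minimal usage sketch (assumes this file is saved as app.py and that txagent,
# gradio, and a CUDA-enabled torch build are installed):
#   $ python app.py
# then open http://localhost:7860 in a browser.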