import os
import logging

import torch
import gradio as gr

from txagent import TxAgent

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
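
# Hugging Face Hub model IDs: the TxAgent reasoning model and the ToolRAG
# retriever it uses to select tools at each step.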
MODEL_NAME = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
RAG_MODEL_NAME = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
TOOL_FILE = "data/new_tool.json"
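
# Runtime environment knobs: silence the HF tokenizers fork-parallelism
# warning, let CUDA load modules lazily, and make vLLM workers use "spawn"
# (forking a process after CUDA initialization is unsafe).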
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


class TxAgentSystem:
    """Wraps TxAgent model setup and a Gradio chat UI."""

    def __init__(self):
        self.agent = None
        self.is_initialized = False
        self.examples = [
            ["A 68-year-old with CKD prescribed metformin. Safe for renal clearance?"],
            ["30-year-old on Prozac diagnosed with WHIM. Safe to take Xolremdi?"]
        ]

        # The model is GPU-only; fail fast rather than silently falling back to CPU.
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available - GPU required")

        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

        # Load models eagerly so a broken setup surfaces before the UI starts.
        self._initialize_system()

    def _initialize_system(self):
        try:
            # Seed an empty tool registry if none exists yet.
            os.makedirs("data", exist_ok=True)
            if not os.path.exists(TOOL_FILE):
                with open(TOOL_FILE, "w") as f:
                    f.write("[]")

            logger.info("Initializing TxAgent...")

            try:
                self.agent = TxAgent(
                    model_name=MODEL_NAME,
                    rag_model_name=RAG_MODEL_NAME,
                    tool_files_dict={"new_tool": TOOL_FILE},
                    force_finish=True,
                    enable_checker=True,
                    step_rag_num=10,
                    seed=100,
                    enable_rag=True
                )
            except Exception as e:
                # Degrade gracefully: run without tool retrieval rather than
                # refusing to start at all.
                logger.warning(f"Failed to initialize with RAG: {str(e)}")
                logger.info("Retrying without RAG...")
                self.agent = TxAgent(
                    model_name=MODEL_NAME,
                    rag_model_name=None,
                    tool_files_dict={"new_tool": TOOL_FILE},
                    force_finish=True,
                    enable_checker=True,
                    step_rag_num=0,
                    seed=100,
                    enable_rag=False
                )

            logger.info("Loading main model...")
            self.agent.init_model()

            self.is_initialized = True
            logger.info("System initialization completed successfully")

        except Exception as e:
            logger.error(f"System initialization failed: {str(e)}")
            self.is_initialized = False
            raise
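
    # Gradio submit callback: returns ("", updated_history) so a single event
    # both clears the input box and refreshes the chatbot pane.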
    def chat_fn(self, message, history, temperature, max_tokens, rag_depth):
        if not self.is_initialized:
            return "", history + [(message, "System initialization failed. Please check logs.")]

        try:
            # NOTE: these kwargs follow the original call; verify that names
            # like max_total_tokens and max_steps match the run_gradio_chat
            # signature of the installed txagent version.
            response = self.agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=int(max_tokens),
                max_total_tokens=16384,
                enable_multi_agent=False,
                conv_history=history,
                max_steps=int(rag_depth),
                seed=100
            )
            # Assumption: some txagent builds implement run_gradio_chat as a
            # generator that streams intermediate output; if so, drain it and
            # keep the final chunk as the answer.
            if not isinstance(response, str) and hasattr(response, "__iter__"):
                *_, response = response
            new_history = history + [(message, response)]
            return "", new_history

        except torch.cuda.OutOfMemoryError:
            # Free cached blocks so the next request has a chance to succeed.
            torch.cuda.empty_cache()
            return "", history + [(message, "⚠️ GPU memory overflow. Please try a shorter query.")]

        except Exception as e:
            logger.error(f"Chat error: {str(e)}")
            return "", history + [(message, f"🚨 Error: {str(e)}")]
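
    # UI assembly: chat pane (left) and generation controls (right); the
    # sliders feed chat_fn through msg.submit below.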
    def launch_ui(self):
        with gr.Blocks(theme=gr.themes.Soft(), title="TxAgent Medical AI") as demo:
            gr.Markdown("## 🧠 TxAgent (A100/H100 Optimized)")

            status = gr.Textbox(
                value="✅ System ready" if self.is_initialized else "❌ Initialization failed",
                label="System Status",
                interactive=False
            )

            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(height=600, label="Conversation History")
                    msg = gr.Textbox(label="Enter Medical Query", placeholder="Type your question here...")
                with gr.Column(scale=1):
                    temp = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
                    # step=1 keeps token and depth values integral for the model call.
                    max_tokens = gr.Slider(128, 8192, value=2048, step=1, label="Max Response Tokens")
                    rag_depth = gr.Slider(1, 20, value=10, step=1, label="RAG Depth")
                    clear_btn = gr.Button("Clear History")

            gr.Examples(
                examples=self.examples,
                inputs=msg,
                label="Example Queries"
            )

            msg.submit(
                self.chat_fn,
                inputs=[msg, chatbot, temp, max_tokens, rag_depth],
                outputs=[msg, chatbot]
            )
            clear_btn.click(lambda: None, None, chatbot, queue=False)

        demo.launch(
            server_name="0.0.0.0",
            server_port=7860
        )


if __name__ == "__main__":
    try:
        system = TxAgentSystem()
        system.launch_ui()
    except Exception as e:
        logger.critical(f"Fatal error: {str(e)}")
        raise
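
# Usage sketch (assumptions: a CUDA GPU and the txagent, gradio, and torch
# packages installed per the mims-harvard/TxAgent instructions):
#   $ python app.py          # hypothetical filename for this script
# The UI is then served on http://0.0.0.0:7860.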