# ✅ Fully updated app.py for TxAgent with strict tool validation to prevent runtime errors
"""Gradio front end that analyzes clinical Excel records with TxAgent.

The workbook is flattened to text, split into token-bounded chunks, each
chunk is sent to the agent, and the per-chunk answers are summarized into a
Markdown report that the user can download.
"""

import sys
import os
import pandas as pd
import json
import gradio as gr
from typing import List, Tuple, Union, Generator, Dict, Any
import re
from datetime import datetime
import atexit
import torch.distributed as dist
import logging

# Logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("app")


# Cleanup
def cleanup():
    """Tear down the PyTorch distributed process group at interpreter exit."""
    if dist.is_initialized():
        logger.info("Cleaning up PyTorch distributed process group")
        dist.destroy_process_group()


atexit.register(cleanup)

# Directories
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(d, exist_ok=True)

# The cache locations are exported before txagent is imported below, so the
# import already sees them.
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir

# Import TxAgent
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
from txagent.txagent import TxAgent

MAX_MODEL_TOKENS = 32768
MAX_CHUNK_TOKENS = 8192
MAX_NEW_TOKENS = 2048
PROMPT_OVERHEAD = 500


def clean_response(text: str) -> str:
    """Strip bracketed spans, literal ``None`` words and stray symbols.

    Runs of three or more newlines are collapsed to a single blank line, and
    any character outside word characters, whitespace and ``# - * . , : ( )``
    is removed. The result is returned with surrounding whitespace trimmed.
    """
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
    return text.strip()


def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text* (about 3.5 characters per token).

    The value is converted to ``int`` because floor-dividing by a float
    yields a float, which contradicts the declared return type.
    """
    return int(len(text) // 3.5) + 1


def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
    """Flatten every sheet of an Excel workbook into ``[sheet] a | b | c`` lines.

    Args:
        file_obj: A filesystem path, a dict carrying the path under ``"name"``,
            or an object exposing the path as a string ``name`` attribute.

    Returns:
        One line per non-empty row, joined with newlines. Sheets that fail to
        parse are logged and skipped.

    Raises:
        ValueError: If no path can be derived from *file_obj*.
        FileNotFoundError: If the derived path does not exist.
    """
    if isinstance(file_obj, dict) and 'name' in file_obj:
        file_path = file_obj['name']
    elif isinstance(file_obj, str):
        file_path = file_obj
    elif isinstance(getattr(file_obj, "name", None), str):
        # File-like wrappers that carry their path in ``.name``.
        file_path = file_obj.name
    else:
        raise ValueError("Unsupported file input type")

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    all_text = []
    # The context manager closes the workbook handle even if parsing fails.
    with pd.ExcelFile(file_path) as xls:
        for sheet in xls.sheet_names:
            try:
                # Fill blanks BEFORE stringifying; the reverse order turns
                # NaN into the literal text "nan", which fillna cannot remove.
                df = xls.parse(sheet).fillna("").astype(str)
                for row in df.itertuples(index=False, name=None):
                    cells = [cell for cell in row if cell.strip()]
                    if cells:
                        all_text.append(f"[{sheet}] {' | '.join(cells)}")
            except Exception as e:
                # Best effort: one unreadable sheet must not abort the rest.
                logger.warning("Failed to parse %s: %s", sheet, e)
    return "\n".join(all_text)


def split_text_into_chunks(text: str) -> List[str]:
    """Split *text* on line boundaries into chunks that fit the token budget.

    The budget is ``MAX_CHUNK_TOKENS - PROMPT_OVERHEAD`` estimated tokens.
    Lines are never split, so a single line larger than the budget becomes a
    chunk of its own.
    """
    max_tokens = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
    chunks, current, current_tokens = [], [], 0
    for line in text.split("\n"):
        t = estimate_tokens(line)
        if current_tokens + t > max_tokens:
            # Guard against emitting an empty chunk when the very first line
            # is already over budget.
            if current:
                chunks.append("\n".join(current))
            current, current_tokens = [line], t
        else:
            current.append(line)
            current_tokens += t
    if current:
        chunks.append("\n".join(current))
    return chunks


def build_prompt_from_text(chunk: str) -> str:
    """Wrap one chunk of record text in the clinical-analysis instructions."""
    return f"""
### Clinical Records Analysis

Please analyze these clinical notes and provide:
- Key diagnostic indicators
- Current medications and potential issues
- Recommended follow-up actions
- Any inconsistencies or concerns

---

{chunk}

---

Provide a structured response with clear medical reasoning.
"""


def clean_and_rewrite_tool_file(original_path: str, cleaned_path: str) -> bool:
    """Validate a tool definition file and rewrite it as a flat JSON list.

    Accepted input shapes are ``{"tools": [...]}``, a bare list, or a single
    tool dict with a ``"name"`` key. Every tool must be a dict with a
    ``"name"`` key.

    Returns:
        ``True`` if *cleaned_path* was written, ``False`` if the file is
        missing, unreadable or has an unsupported shape.
    """
    try:
        with open(original_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if isinstance(data, dict) and "tools" in data:
            tools = data["tools"]
        elif isinstance(data, list):
            tools = data
        elif isinstance(data, dict) and "name" in data:
            tools = [data]
        else:
            return False

        if not all(isinstance(t, dict) and "name" in t for t in tools):
            return False

        with open(cleaned_path, "w", encoding="utf-8") as out:
            json.dump(tools, out)
        return True
    except Exception as e:
        logger.error("Failed to clean tool %s: %s", original_path, e)
        return False


def init_agent() -> TxAgent:
    """Build and initialize a TxAgent using only tool files that validate.

    Raises:
        ValueError: If none of the configured tool files passes validation.
    """
    new_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(new_tool_path):
        # Seed a placeholder so the 'new_tool' entry always has a file.
        with open(new_tool_path, 'w', encoding="utf-8") as f:
            json.dump([{"name": "dummy_tool", "description": "test", "version": "1.0"}], f)

    # NOTE(review): these paths are tied to one specific pyenv install;
    # files missing on other machines are skipped by the validation below.
    tooluniverse_data = (
        "/home/user/.pyenv/versions/3.10.17/lib/python3.10/"
        "site-packages/tooluniverse/data"
    )
    raw_tool_files = {
        'opentarget': f"{tooluniverse_data}/opentarget_tools.json",
        'fda_drug_label': f"{tooluniverse_data}/fda_drug_labeling_tools.json",
        'special_tools': f"{tooluniverse_data}/special_tools.json",
        'monarch': f"{tooluniverse_data}/monarch_tools.json",
        'new_tool': new_tool_path,
    }

    validated_paths = {}
    for name, original_path in raw_tool_files.items():
        cleaned_path = os.path.join(tool_cache_dir, f"{name}_cleaned.json")
        if clean_and_rewrite_tool_file(original_path, cleaned_path):
            validated_paths[name] = cleaned_path

    if not validated_paths:
        raise ValueError("No valid tools found after sanitizing.")

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=validated_paths,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
    )
    agent.init_model()
    return agent


def _run_agent(agent: TxAgent, prompt: str) -> str:
    """Send one prompt to the agent and concatenate everything it streams."""
    pieces = []
    for out in agent.run_gradio_chat(
        message=prompt,
        history=[],
        temperature=0.2,
        max_new_tokens=MAX_NEW_TOKENS,
        max_token=MAX_MODEL_TOKENS,
        call_agent=False,
        conversation=[],
    ):
        # Streamed items are either plain strings or objects with `.content`.
        pieces.append(out if isinstance(out, str) else out.content)
    return "".join(pieces)


def stream_report(agent: TxAgent, input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
    """Analyze an uploaded workbook chunk by chunk and stream progress.

    Args:
        agent: Initialized TxAgent used for every model call.
        input_file: Uploaded file reference understood by
            ``extract_text_from_excel``; ``None`` when nothing was uploaded.
        full_output: Current value of the UI state slot (not read here).

    Yields:
        ``(markdown, report_path, summary)`` tuples. ``report_path`` and
        ``summary`` are only populated on the final yield; on failure a
        single error message is yielded instead of raising.
    """
    accumulated = ""
    try:
        if input_file is None:
            yield "❌ Upload a valid Excel file.", None, ""
            return

        text = extract_text_from_excel(input_file)
        chunks = split_text_into_chunks(text)

        for i, chunk in enumerate(chunks):
            prompt = build_prompt_from_text(chunk)
            cleaned = clean_response(_run_agent(agent, prompt))
            accumulated += f"\n\n📄 Part {i+1}:\n{cleaned}"
            yield accumulated, None, ""

        summary_prompt = f"Summarize this analysis:\n\n{accumulated}"
        final = clean_response(_run_agent(agent, summary_prompt))

        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
        with open(report_path, 'w', encoding="utf-8") as f:
            f.write(f"# Clinical Report\n\n{final}")

        yield f"{accumulated}\n\n📊 Final Summary:\n{final}", report_path, final

    except Exception as e:
        logger.error("Stream error: %s", e, exc_info=True)
        yield f"❌ Error: {str(e)}", None, ""


def create_ui(agent: TxAgent) -> gr.Blocks:
    """Build the Gradio interface and bind it to *agent*."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏥 Clinical Records Analyzer")
        with gr.Row():
            file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Row():
            with gr.Column(scale=2):
                report_output = gr.Markdown()
            with gr.Column(scale=1):
                report_file = gr.File(label="Download", visible=False)
        full_output = gr.State()

        def run_analysis(input_file, state):
            """Adapt stream_report to the two UI inputs.

            The click event only supplies the file and the state, so the
            agent is bound here. The download component starts hidden and is
            revealed once a report path is available.
            """
            for markdown, report_path, summary in stream_report(agent, input_file, state):
                download = gr.update(value=report_path, visible=report_path is not None)
                yield markdown, download, summary

        analyze_btn.click(
            fn=run_analysis,
            inputs=[file_upload, full_output],
            outputs=[report_output, report_file, full_output],
        )
    return demo


if __name__ == "__main__":
    try:
        agent = init_agent()
        demo = create_ui(agent)
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    except Exception as e:
        logger.error("App error: %s", e, exc_info=True)
        print(f"❌ Application error: {e}", file=sys.stderr)
        sys.exit(1)