import sys
import os
import re
import json
import atexit
import logging
from datetime import datetime
from typing import List, Tuple, Union, Generator, Dict, Any

import pandas as pd
import gradio as gr
import torch.distributed as dist

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# Cleanup for PyTorch distributed
def cleanup():
    if dist.is_initialized():
        logger.info("Cleaning up PyTorch distributed process group")
        dist.destroy_process_group()


atexit.register(cleanup)

# Setup directories
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(d, exist_ok=True)

os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))

from txagent.txagent import TxAgent

# Token budgets (rough heuristics, not exact tokenizer counts)
MAX_MODEL_TOKENS = 32768
MAX_CHUNK_TOKENS = 8192
MAX_NEW_TOKENS = 2048
PROMPT_OVERHEAD = 500


def clean_response(text: str) -> str:
    """Strip bracketed artifacts, stray 'None' tokens and unwanted characters from model output."""
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
    return text.strip()


def estimate_tokens(text: str) -> int:
    # Rough estimate: ~3.5 characters per token. Use int() so the function
    # returns an int as annotated ("//" with a float operand yields a float).
    return int(len(text) / 3.5) + 1


def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
    """Handle a Gradio file upload, which may be a dict with a 'name' key or a plain path string."""
    all_text = []
    try:
        if isinstance(file_obj, dict) and 'name' in file_obj:
            file_path = file_obj['name']
        elif isinstance(file_obj, str):
            file_path = file_obj
        else:
            raise ValueError("Unsupported file input type")

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Temporary upload file not found at: {file_path}")

        xls = pd.ExcelFile(file_path)
        for sheet_name in xls.sheet_names:
            try:
                # Fill missing cells before casting so NaN does not become the literal string "nan"
                df = xls.parse(sheet_name).fillna("").astype(str)
                rows = df.apply(lambda row: " | ".join([cell for cell in row if cell.strip()]), axis=1)
                sheet_text = [f"[{sheet_name}] {line}" for line in rows if line.strip()]
                all_text.extend(sheet_text)
            except Exception as e:
                logger.warning(f"Could not parse sheet {sheet_name}: {e}")
                continue

        return "\n".join(all_text)
    except Exception as e:
        raise ValueError(f"āŒ Error processing Excel file: {str(e)}")


def split_text_into_chunks(text: str) -> List[str]:
    """Split text line-by-line into chunks that fit within the per-chunk token budget."""
    effective_max = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
    lines, chunks, curr_chunk, curr_tokens = text.split("\n"), [], [], 0
    for line in lines:
        t = estimate_tokens(line)
        if curr_tokens + t > effective_max:
            if curr_chunk:
                chunks.append("\n".join(curr_chunk))
            curr_chunk, curr_tokens = [line], t
        else:
            curr_chunk.append(line)
            curr_tokens += t
    if curr_chunk:
        chunks.append("\n".join(curr_chunk))
    return chunks


def build_prompt_from_text(chunk: str) -> str:
    return f"""
### Clinical Records Analysis

Please analyze these clinical notes and provide:
- Key diagnostic indicators
- Current medications and potential issues
- Recommended follow-up actions
- Any inconsistencies or concerns

---

{chunk}

---

Provide a structured response with clear medical reasoning.
""" def validate_tool_file(tool_name: str, tool_path: str) -> None: """Validate the structure of a tool JSON file.""" try: if not os.path.exists(tool_path): raise FileNotFoundError(f"Tool file not found: {tool_path}") with open(tool_path, 'r') as f: tool_data = json.load(f) logger.info(f"Contents of {tool_name} ({tool_path}): {tool_data}") if isinstance(tool_data, list): for item in tool_data: if not isinstance(item, dict) or 'name' not in item: raise ValueError(f"Invalid tool format in {tool_name}: each item must be a dict with a 'name' key, got {item}") elif isinstance(tool_data, dict): if 'tools' in tool_data: if not isinstance(tool_data['tools'], list): raise ValueError(f"'tools' field in {tool_name} must be a list, got {type(tool_data['tools'])}") for item in tool_data['tools']: if not isinstance(item, dict) or 'name' not in item: raise ValueError(f"Invalid tool format in {tool_name}: each tool must be a dict with a 'name' key, got {item}") else: if 'name' not in tool_data: raise ValueError(f"Invalid tool format in {tool_name}: dict must have a 'name' key or 'tools' field, got {tool_data}") else: raise ValueError(f"Invalid tool file {tool_name}: must be a list or dict, got {type(tool_data)}") except Exception as e: logger.error(f"Error validating tool file {tool_name} ({tool_path}): {str(e)}") raise def init_agent() -> TxAgent: tool_path = os.path.join(tool_cache_dir, "new_tool.json") logger.info(f"Checking for tool file at: {tool_path}") # Create default tool file if it doesn't exist if not os.path.exists(tool_path): default_tool = { "name": "new_tool", "description": "Default tool configuration", "version": "1.0", "tools": [ {"name": "dummy_tool", "description": "Dummy tool for testing", "version": "1.0"} ] } logger.info(f"Creating default tool file at: {tool_path}") with open(tool_path, 'w') as f: json.dump(default_tool, f) # Define tool files tool_files_dict = { 'opentarget': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/opentarget_tools.json', 'fda_drug_label': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/fda_drug_labeling_tools.json', 'special_tools': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/special_tools.json', 'monarch': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/monarch_tools.json', 'new_tool': tool_path } # Validate all tool files for tool_name, tool_path in tool_files_dict.items(): validate_tool_file(tool_name, tool_path) # Initialize TxAgent try: logger.info(f"Initializing TxAgent with tool_files_dict: {tool_files_dict}") agent = TxAgent( model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B", rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B", tool_files_dict=tool_files_dict, force_finish=True, enable_checker=True, step_rag_num=4, seed=100 ) logger.info("TxAgent initialized, calling init_model") agent.init_model() logger.info("TxAgent model initialized successfully") return agent except Exception as e: logger.error(f"Error initializing TxAgent: {str(e)}", exc_info=True) raise def stream_report(agent: TxAgent, input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]: accumulated_text = "" try: if input_file is None: yield "āŒ Please upload a valid Excel file.", None, "" return try: text = extract_text_from_excel(input_file) chunks = split_text_into_chunks(text) except Exception as e: yield f"āŒ {str(e)}", None, "" return for i, chunk in enumerate(chunks): 
            prompt = build_prompt_from_text(chunk)
            partial = ""
            for res in agent.run_gradio_chat(
                message=prompt,
                history=[],
                temperature=0.2,
                max_new_tokens=MAX_NEW_TOKENS,
                max_token=MAX_MODEL_TOKENS,
                call_agent=False,
                conversation=[]
            ):
                partial += res if isinstance(res, str) else res.content
            cleaned = clean_response(partial)
            accumulated_text += f"\n\nšŸ“„ Analysis Part {i+1}:\n{cleaned}"
            yield accumulated_text, None, ""

        summary_prompt = f"Please summarize this analysis:\n\n{accumulated_text}"
        final_report = ""
        for res in agent.run_gradio_chat(
            message=summary_prompt,
            history=[],
            temperature=0.2,
            max_new_tokens=MAX_NEW_TOKENS,
            max_token=MAX_MODEL_TOKENS,
            call_agent=False,
            conversation=[]
        ):
            final_report += res if isinstance(res, str) else res.content

        cleaned = clean_response(final_report)
        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
        with open(report_path, 'w') as f:
            f.write(f"# Clinical Analysis Report\n\n{cleaned}")

        yield f"{accumulated_text}\n\nšŸ“Š Final Summary:\n{cleaned}", report_path, cleaned
    except Exception as e:
        logger.error(f"Processing error in stream_report: {str(e)}", exc_info=True)
        yield f"āŒ Processing error: {str(e)}", None, ""


def create_ui(agent: TxAgent) -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 900px !important}") as demo:
        gr.Markdown("""# Clinical Records Analyzer""")
        with gr.Row():
            file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Row():
            with gr.Column(scale=2):
                report_output = gr.Markdown()
            with gr.Column(scale=1):
                report_file = gr.File(label="Download Report", visible=False)

        full_output = gr.State()

        def analyze(file_obj, prev_output):
            # Bind the already-initialized agent here: Gradio only supplies the
            # values of the declared inputs, so stream_report cannot receive the
            # agent as its first positional argument directly.
            yield from stream_report(agent, file_obj, prev_output)

        analyze_btn.click(
            fn=analyze,
            inputs=[file_upload, full_output],
            outputs=[report_output, report_file, full_output]
        )

    return demo


if __name__ == "__main__":
    try:
        agent = init_agent()
        demo = create_ui(agent)
        logger.info("Launching Gradio UI")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        logger.error(f"Application error: {str(e)}", exc_info=True)
        print(f"Application error: {str(e)}", file=sys.stderr)
        sys.exit(1)
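
# ------------------------------------------------------------------------------
# Optional, hypothetical smoke test (kept commented out so it never affects the
# running app): a minimal sketch of exercising the Excel-to-chunks pipeline on
# its own, without loading the TxAgent model. The file name "sample.xlsx" is
# illustrative only and is not part of this project.
#
#   text = extract_text_from_excel("sample.xlsx")
#   for i, chunk in enumerate(split_text_into_chunks(text)):
#       print(f"chunk {i}: ~{estimate_tokens(chunk)} estimated tokens")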