import sys
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import psutil
import subprocess
import logging
import torch
import gc
from diskcache import Cache
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Persistent directories for model weights, tool definitions, extraction cache, and reports
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")

for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(directory, exist_ok=True)

os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent

# Initialize cache with 10GB limit
cache = Cache(file_cache_dir, size_limit=10 * 1024**3)


def sanitize_utf8(text: str) -> str:
    return text.encode("utf-8", "ignore").decode("utf-8")


def file_hash(path: str) -> str:
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()


def extract_all_pages(file_path: str, progress_callback=None) -> str:
    try:
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
            if total_pages == 0:
                logger.error("No pages found in PDF")
                return ""

        # Split the document into page batches and extract them in parallel
        batch_size = 10
        batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
        text_chunks = [""] * total_pages
        processed_pages = 0

        def extract_batch(start: int, end: int) -> List[tuple]:
            # Each worker reopens the PDF so pdfplumber state is not shared across threads
            results = []
            with pdfplumber.open(file_path) as pdf:
                for idx, page in enumerate(pdf.pages[start:end], start=start):
                    page_text = page.extract_text() or ""
                    results.append((idx, f"=== Page {idx + 1} ===\n{page_text.strip()}"))
                    logger.debug("Extracted page %d, text length: %d chars", idx + 1, len(page_text))
            return results

        with ThreadPoolExecutor(max_workers=6) as executor:
            futures = [executor.submit(extract_batch, start, end) for start, end in batches]
            for future in as_completed(futures):
                for page_num, text in future.result():
                    if page_num < len(text_chunks):
                        text_chunks[page_num] = text
                    else:
                        logger.warning("Page number %d out of range for text_chunks (size %d)",
                                       page_num, len(text_chunks))
                processed_pages += batch_size
                if progress_callback:
                    progress_callback(min(processed_pages, total_pages), total_pages)
                logger.info("Processed %d/%d pages", min(processed_pages, total_pages), total_pages)

        extracted_text = "\n\n".join(filter(None, text_chunks))
        logger.info("Extracted %d pages, total length: %d chars", total_pages, len(extracted_text))
        return extracted_text
    except Exception as e:
        logger.error("PDF processing error: %s", e)
        return f"PDF processing error: {str(e)}"


def convert_file_to_json(file_path: str, file_type: str, progress_callback=None) -> str:
    try:
        # Reuse a previous extraction if the same file (keyed by MD5 and type) was already processed
        file_h = file_hash(file_path)
        cache_key = f"{file_h}_{file_type}"
        if cache_key in cache:
            logger.info("Using cached extraction for %s", file_path)
            return cache[cache_key]

        if file_type == "pdf":
            text = extract_all_pages(file_path, progress_callback)
            result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
        elif file_type == "csv":
            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
                             skip_blank_lines=False, on_bad_lines="skip")
            content = df.fillna("").astype(str).values.tolist()
            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
        elif file_type in ["xls", "xlsx"]:
            try:
                df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            except Exception:
                df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
            content = df.fillna("").astype(str).values.tolist()
            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
        else:
            result = json.dumps({"error": f"Unsupported file type: {file_type}"})

        cache[cache_key] = result
        logger.info("Cached extraction for %s, size: %d bytes", file_path, len(result))
        return result
    except Exception as e:
        logger.error("Error processing %s: %s", os.path.basename(file_path), e)
        return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})


def log_system_usage(tag=""):
    try:
        cpu = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory()
        logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB",
                    tag, cpu, mem.used // (1024**2), mem.total // (1024**2))
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu",
             "--format=csv,nounits,noheader"],
            capture_output=True, text=True
        )
        if result.returncode == 0:
            used, total, util = result.stdout.strip().split(", ")
            logger.info("[%s] GPU: %sMB / %sMB | Utilization: %s%%", tag, used, total, util)
    except Exception as e:
        logger.error("[%s] GPU/CPU monitor failed: %s", tag, e)


def clean_response(text: str) -> str:
    text = sanitize_utf8(text)
    text = re.sub(
        r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.",
        "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    text = re.sub(r"[^\n#\-\*\w\s\.\,\:\(\)]+", "", text)

    sections = {}
    current_section = None
    lines = text.splitlines()
    for line in lines:
        line = line.strip()
        if not line:
            continue
        section_match = re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line)
        if section_match:
            current_section = section_match.group(1)
            if current_section not in sections:
                sections[current_section] = []
            continue
        finding_match = re.match(r"-\s*.+", line)
        if finding_match and current_section and not re.match(r"-\s*No issues identified", line):
            sections[current_section].append(line)

    cleaned = []
    for heading, findings in sections.items():
        if findings:
            cleaned.append(f"### {heading}\n" + "\n".join(findings))

    text = "\n\n".join(cleaned).strip()
    logger.debug("Cleaned response length: %d chars", len(text))
    return text if text else ""


def summarize_findings(combined_response: str) -> str:
    if not combined_response or all("No oversights identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
        logger.info("No clinical oversights identified in analysis")
        return "### Summary of Clinical Oversights\nNo critical oversights identified in the provided records."
    sections = {}
    lines = combined_response.splitlines()
    current_section = None
    for line in lines:
        line = line.strip()
        if not line:
            continue
        section_match = re.match(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line)
        if section_match:
            current_section = section_match.group(1)
            if current_section not in sections:
                sections[current_section] = []
            continue
        finding_match = re.match(r"-\s*(.+)", line)
        if finding_match and current_section:
            sections[current_section].append(finding_match.group(1))

    summary_lines = []
    for heading, findings in sections.items():
        if findings:
            summary = f"- **{heading}**: {'; '.join(findings[:2])}. Risks: {heading.lower()} may lead to adverse outcomes. Recommend: urgent review and specialist referral."
            summary_lines.append(summary)

    if not summary_lines:
        logger.info("No clinical oversights identified after summarization")
        return "### Summary of Clinical Oversights\nNo critical oversights identified."

    summary = "### Summary of Clinical Oversights\n" + "\n".join(summary_lines)
    logger.info("Summarized findings: %s", summary[:100])
    return summary


def init_agent():
    logger.info("Initializing model...")
    log_system_usage("Before Load")
    default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=False,
        step_rag_num=4,
        seed=100,
        additional_default_tools=[],
    )
    agent.init_model()
    log_system_usage("After Load")
    logger.info("Agent Ready")
    return agent


def create_ui(agent):
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("