import sys
import os
import pandas as pd
import pdfplumber
import gradio as gr
from typing import List, Dict
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import logging
import torch
import gc
from diskcache import Cache
from transformers import AutoTokenizer
from functools import lru_cache
from difflib import SequenceMatcher

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
MAX_TOKENS = 1800
BATCH_SIZE = 1
MAX_WORKERS = 2
CHUNK_SIZE = 5
MODEL_MAX_TOKENS = 131072
MAX_TEXT_LENGTH = 500000
MAX_ROWS_TO_PROCESS = 1000  # Limit for Excel/CSV rows

# Persistent directory setup
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
os.makedirs(report_dir, exist_ok=True)

os.environ.update({
    "HF_HOME": model_cache_dir,
    "TOKENIZERS_PARALLELISM": "false",
})

current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent

# Initialize cache
cache = Cache(file_cache_dir, size_limit=10 * 1024**3)


@lru_cache(maxsize=1)
def get_tokenizer():
    """Load the TxAgent tokenizer once and reuse it across calls."""
    return AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")


def sanitize_utf8(text: str) -> str:
    """Drop any byte sequences that cannot be represented as UTF-8."""
    return text.encode("utf-8", "ignore").decode("utf-8")


def file_hash(path: str) -> str:
    """Return the MD5 digest of a file, read in 4 KB chunks."""
    hash_md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def extract_pdf_page(page, tokenizer, max_tokens=MAX_TOKENS) -> List[str]:
    """Extract text from a single PDF page and split it into token-bounded chunks."""
    try:
        text = page.extract_text() or ""
        text = sanitize_utf8(text)
        if len(text) > MAX_TEXT_LENGTH // 10:
            text = text[:MAX_TEXT_LENGTH // 10]

        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) > max_tokens:
            chunks = []
            current_chunk = []
            current_length = 0
            for token in tokens:
                if current_length + 1 > max_tokens:
                    chunks.append(tokenizer.decode(current_chunk))
                    current_chunk = [token]
                    current_length = 1
                else:
                    current_chunk.append(token)
                    current_length += 1
            if current_chunk:
                chunks.append(tokenizer.decode(current_chunk))
            return chunks
        return [text]
    except Exception as e:
        logger.warning(f"Error extracting page {page.page_number}: {str(e)}")
        return []


def extract_all_pages(file_path: str) -> List[str]:
    """Extract text from the first 50 pages of a PDF, one or more chunks per page."""
    try:
        tokenizer = get_tokenizer()
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
            if total_pages == 0:
                return ["PDF appears to be empty"]

            results = []
            for i in range(0, min(total_pages, 50)):  # Limit to first 50 pages
                try:
                    page = pdf.pages[i]
                    chunks = extract_pdf_page(page, tokenizer)
                    for chunk in chunks:
                        results.append(f"=== Page {i+1} ===\n{chunk}")
                except Exception as e:
                    logger.warning(f"Error processing page {i+1}: {str(e)}")
                    continue

            return results if results else ["Could not extract text from PDF"]
    except Exception as e:
        logger.error(f"PDF processing error: {e}")
        return [f"PDF processing error: {str(e)}"]


def excel_to_json(file_path: str) -> List[Dict]:
    """Read up to three sheets of an Excel workbook into row lists, trying several engines."""
    engines = ['openpyxl', 'xlrd']
    for engine in engines:
        try:
            with pd.ExcelFile(file_path, engine=engine) as excel_file:
                sheets = excel_file.sheet_names
                if not sheets:
                    return [{"error": "No sheets found"}]

                results = []
                for sheet_name in sheets[:3]:  # Limit to first 3 sheets
                    try:
                        df = pd.read_excel(
                            excel_file,
                            sheet_name=sheet_name,
                            header=None,
                            dtype=str,
                            na_filter=False,
                            nrows=MAX_ROWS_TO_PROCESS  # Limit rows
                        )
                        if not df.empty:
                            rows = df.head(MAX_ROWS_TO_PROCESS).values.tolist()
                            results.append({
                                "filename": os.path.basename(file_path),
                                "sheet": sheet_name,
                                "rows": rows,
                                "type": "excel"
                            })
                    except Exception as e:
                        logger.warning(f"Error processing sheet {sheet_name}: {str(e)}")
                        continue

                return results if results else [{"error": "No readable data found"}]
        except Exception as e:
            logger.warning(f"Excel engine {engine} failed: {str(e)}")
            continue

    return [{"error": "Could not process Excel file with any engine"}]


def csv_to_json(file_path: str) -> List[Dict]:
    """Read a CSV file into a single row-list record, skipping malformed lines."""
    try:
        df = pd.read_csv(
            file_path,
            header=None,
            dtype=str,
            encoding_errors='replace',
            on_bad_lines='skip',
            nrows=MAX_ROWS_TO_PROCESS  # Limit rows
        )
        if df.empty:
            return [{"error": "CSV file is empty"}]

        return [{
            "filename": os.path.basename(file_path),
            "rows": df.values.tolist(),
            "type": "csv"
        }]
    except Exception as e:
        logger.error(f"CSV processing error: {e}")
        return [{"error": f"CSV processing error: {str(e)}"}]


def process_file_cached(file_path: str, file_type: str) -> List[Dict]:
    """Dispatch a file to the matching parser.

    NOTE: despite the name, results are recomputed on every call; the diskcache
    instance created above is not consulted here.
    """
    try:
        logger.info(f"Processing {file_type} file: {os.path.basename(file_path)}")
        if file_type == "pdf":
            chunks = extract_all_pages(file_path)
            return [{
                "filename": os.path.basename(file_path),
                "content": chunk,
                "type": "pdf"
            } for chunk in chunks]
        elif file_type in ["xls", "xlsx"]:
            return excel_to_json(file_path)
        elif file_type == "csv":
            return csv_to_json(file_path)
        return [{"error": f"Unsupported file type: {file_type}"}]
    except Exception as e:
        logger.error(f"Error processing file: {e}")
        return [{"error": f"Error processing file: {str(e)}"}]


def clean_response(text: str) -> str:
    """Strip bracketed fragments, stray 'None' tokens, and excess whitespace from model output."""
    if not text:
        return ""
    patterns = [
        (re.compile(r"\[.*?\]|\bNone\b", re.IGNORECASE), ""),
        (re.compile(r"\s+"), " "),
    ]
    for pattern, repl in patterns:
        text = pattern.sub(repl, text)
    return text.strip()


@lru_cache(maxsize=1)
def init_agent():
    """Initialize the TxAgent model once and reuse the instance."""
    logger.info("Initializing model...")
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": os.path.join(tool_cache_dir, "new_tool.json")},
        force_finish=True,
        enable_checker=False,
        step_rag_num=4,
        seed=100,
    )
    agent.init_model()
    logger.info("Agent Ready")
    return agent


def create_ui(agent):
    PROMPT_TEMPLATE = """
Analyze this patient record excerpt for missed diagnoses (limit response to 500 tokens):
{chunk}
"""

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🩺 Clinical Oversight Assistant")

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Analysis", height=500, type="messages")
                msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
                send_btn = gr.Button("Analyze", variant="primary")
                file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="single")
            with gr.Column(scale=1):
                final_summary = gr.Markdown("## Summary")
                status = gr.Textbox(label="Status", interactive=False)

        def analyze(message: str, history: List[Dict], file_obj) -> tuple:
            """Run the agent over the uploaded file and append the findings to the chat history."""
            try:
                if not file_obj:
                    return history, "Please upload a file first", "No file uploaded"

                file_path = file_obj.name
                file_type = os.path.splitext(file_path)[-1].lower().replace(".", "")
                history.append({"role": "user", "content": message})

                # Process file
                processed = process_file_cached(file_path, file_type)
                if "error" in processed[0]:
                    history.append({"role": "assistant", "content": processed[0]["error"]})
                    return history, processed[0]["error"], "File processing failed"

                # Prepare chunks
                chunks = []
                for item in processed:
                    if "content" in item:
                        chunks.append(item["content"])
                    elif "rows" in item:
                        rows_text = "\n".join([", ".join(map(str, row)) for row in item["rows"][:100]])
                        chunks.append(f"=== {item.get('sheet', 'Data')} ===\n{rows_text}")

                if not chunks:
                    history.append({"role": "assistant", "content": "No processable content found."})
                    return history, "No processable content found", "Content extraction failed"

                # Analyze each chunk
                responses = []
                for i, chunk in enumerate(chunks[:5]):
                    try:
                        prompt = PROMPT_TEMPLATE.format(chunk=chunk[:5000])
                        response = agent.run_quick_summary(prompt, 0.2, 256, 500)
                        cleaned = clean_response(response)
                        if cleaned:
                            responses.append(f"Analysis {i+1}:\n{cleaned}")
                    except Exception as e:
                        logger.warning(f"Error analyzing chunk {i+1}: {str(e)}")
                        continue

                if not responses:
                    history.append({"role": "assistant", "content": "No valid analysis generated."})
                    return history, "No valid analysis generated", "Analysis failed"

                summary = "\n\n".join(responses)
                history.append({"role": "assistant", "content": summary})
                return history, "Analysis completed", "Success"

            except Exception as e:
                logger.error(f"Analysis error: {e}")
                history.append({"role": "assistant", "content": f"Error: {str(e)}"})
                return history, f"Error: {str(e)}", "Failed"
            finally:
                torch.cuda.empty_cache()
                gc.collect()

        send_btn.click(
            analyze,
            inputs=[msg_input, chatbot, file_upload],
            outputs=[chatbot, final_summary, status]
        )
        msg_input.submit(
            analyze,
            inputs=[msg_input, chatbot, file_upload],
            outputs=[chatbot, final_summary, status]
        )

    return demo


if __name__ == "__main__":
    try:
        agent = init_agent()
        demo = create_ui(agent)
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        logger.error(f"Fatal error: {e}")
        raise