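"""Clinical Oversight Assistant.

Gradio app that extracts text from an uploaded patient record (PDF, Excel,
or CSV) and asks a TxAgent model to flag potentially missed diagnoses.
"""
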
import gc
import hashlib
import logging
import os
import re
import shutil
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from difflib import SequenceMatcher
from functools import lru_cache
from typing import Dict, List

import gradio as gr
import pandas as pd
import pdfplumber
import torch
from diskcache import Cache
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
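
# Tunable limits. MAX_TOKENS caps tokens per PDF-page chunk, MAX_TEXT_LENGTH
# bounds the raw text taken from a single page/file, and MAX_ROWS_TO_PROCESS
# bounds spreadsheet/CSV rows. BATCH_SIZE, MAX_WORKERS, CHUNK_SIZE and
# MODEL_MAX_TOKENS are not referenced elsewhere in this module.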
MAX_TOKENS = 1800
BATCH_SIZE = 1
MAX_WORKERS = 2
CHUNK_SIZE = 5
MODEL_MAX_TOKENS = 131072
MAX_TEXT_LENGTH = 500000
MAX_ROWS_TO_PROCESS = 1000

persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
os.makedirs(report_dir, exist_ok=True)
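
# Note: huggingface_hub/transformers typically read HF_HOME when they are
# first imported, so this assignment may not redirect the model cache for the
# current process; set HF_HOME in the environment before launch if the
# persistent cache path matters.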
os.environ.update({
    "HF_HOME": model_cache_dir,
    "TOKENIZERS_PARALLELISM": "false",
})

# Make the bundled src/ package importable before importing TxAgent.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent

cache = Cache(file_cache_dir, size_limit=10 * 1024**3)
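
# get_tokenizer() is wrapped in lru_cache so the tokenizer is loaded once per
# process and reused by every PDF extraction call.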


@lru_cache(maxsize=1)
def get_tokenizer():
    return AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")


def sanitize_utf8(text: str) -> str:
    return text.encode("utf-8", "ignore").decode("utf-8")


def file_hash(path: str) -> str:
    """Return the MD5 hex digest of a file, read in 4 KB chunks."""
    hash_md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
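
# PDF pages are chunked by token count rather than by characters so each chunk
# fits the prompt budget; chunk boundaries may fall mid-sentence.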


def extract_pdf_page(page, tokenizer, max_tokens=MAX_TOKENS) -> List[str]:
    """Extract text from one PDF page and split it into <= max_tokens chunks."""
    try:
        text = page.extract_text() or ""
        text = sanitize_utf8(text)
        if len(text) > MAX_TEXT_LENGTH // 10:
            text = text[:MAX_TEXT_LENGTH // 10]

        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) > max_tokens:
            # Split into fixed-size token windows and decode each back to text.
            return [
                tokenizer.decode(tokens[i:i + max_tokens])
                for i in range(0, len(tokens), max_tokens)
            ]
        return [text]
    except Exception as e:
        logger.warning(f"Error extracting page {page.page_number}: {str(e)}")
        return []


def extract_all_pages(file_path: str) -> List[str]:
    """Extract text from the first 50 pages of a PDF as labeled chunks."""
    try:
        tokenizer = get_tokenizer()
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
            if total_pages == 0:
                return ["PDF appears to be empty"]

            results = []
            for i in range(min(total_pages, 50)):
                try:
                    page = pdf.pages[i]
                    chunks = extract_pdf_page(page, tokenizer)
                    for chunk in chunks:
                        results.append(f"=== Page {i+1} ===\n{chunk}")
                except Exception as e:
                    logger.warning(f"Error processing page {i+1}: {str(e)}")
                    continue

            return results if results else ["Could not extract text from PDF"]
    except Exception as e:
        logger.error(f"PDF processing error: {e}")
        return [f"PDF processing error: {str(e)}"]


def excel_to_json(file_path: str) -> List[Dict]:
    """Read up to 3 sheets (MAX_ROWS_TO_PROCESS rows each), trying openpyxl then xlrd."""
    engines = ['openpyxl', 'xlrd']
    for engine in engines:
        try:
            with pd.ExcelFile(file_path, engine=engine) as excel_file:
                sheets = excel_file.sheet_names
                if not sheets:
                    return [{"error": "No sheets found"}]

                results = []
                for sheet_name in sheets[:3]:
                    try:
                        df = pd.read_excel(
                            excel_file,
                            sheet_name=sheet_name,
                            header=None,
                            dtype=str,
                            na_filter=False,
                            nrows=MAX_ROWS_TO_PROCESS
                        )
                        if not df.empty:
                            rows = df.head(MAX_ROWS_TO_PROCESS).values.tolist()
                            results.append({
                                "filename": os.path.basename(file_path),
                                "sheet": sheet_name,
                                "rows": rows,
                                "type": "excel"
                            })
                    except Exception as e:
                        logger.warning(f"Error processing sheet {sheet_name}: {str(e)}")
                        continue

                return results if results else [{"error": "No readable data found"}]
        except Exception as e:
            logger.warning(f"Excel engine {engine} failed: {str(e)}")
            continue

    return [{"error": "Could not process Excel file with any engine"}]


def csv_to_json(file_path: str) -> List[Dict]:
    try:
        df = pd.read_csv(
            file_path,
            header=None,
            dtype=str,
            encoding_errors='replace',
            on_bad_lines='skip',
            nrows=MAX_ROWS_TO_PROCESS
        )
        if df.empty:
            return [{"error": "CSV file is empty"}]

        return [{
            "filename": os.path.basename(file_path),
            "rows": df.values.tolist(),
            "type": "csv"
        }]
    except Exception as e:
        logger.error(f"CSV processing error: {e}")
        return [{"error": f"CSV processing error: {str(e)}"}]
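
# The diskcache `cache` and `file_hash()` above are not wired in yet; if
# repeated uploads of the same file should skip re-extraction, one minimal
# option (an assumption, not part of the original flow) is diskcache's
# memoize decorator, e.g.:
#
#     @cache.memoize(expire=24 * 3600)
#     def process_file_cached(file_path: str, file_type: str) -> List[Dict]:
#         ...
#
# which keys results on the arguments; including file_hash(file_path) in the
# key would also catch files whose contents change under the same path.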


def process_file_cached(file_path: str, file_type: str) -> List[Dict]:
    """Dispatch a file to the matching extractor based on its extension."""
    try:
        logger.info(f"Processing {file_type} file: {os.path.basename(file_path)}")

        if file_type == "pdf":
            chunks = extract_all_pages(file_path)
            return [{
                "filename": os.path.basename(file_path),
                "content": chunk,
                "type": "pdf"
            } for chunk in chunks]

        elif file_type in ["xls", "xlsx"]:
            return excel_to_json(file_path)

        elif file_type == "csv":
            return csv_to_json(file_path)

        return [{"error": f"Unsupported file type: {file_type}"}]
    except Exception as e:
        logger.error(f"Error processing file: {e}")
        return [{"error": f"Error processing file: {str(e)}"}]
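
# clean_response strips bracketed tool/trace fragments and bare "None" tokens,
# then collapses whitespace, e.g. "[TOOL] None  Apparent anemia" -> "Apparent anemia".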


def clean_response(text: str) -> str:
    if not text:
        return ""

    patterns = [
        (re.compile(r"\[.*?\]|\bNone\b", re.IGNORECASE), ""),
        (re.compile(r"\s+"), " "),
    ]

    for pattern, repl in patterns:
        text = pattern.sub(repl, text)

    return text.strip()


@lru_cache(maxsize=1)
def init_agent():
    """Construct and initialize the TxAgent once per process (memoized by lru_cache)."""
    logger.info("Initializing model...")

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": os.path.join(tool_cache_dir, "new_tool.json")},
        force_finish=True,
        enable_checker=False,
        step_rag_num=4,
        seed=100,
    )
    agent.init_model()
    logger.info("Agent Ready")
    return agent


def create_ui(agent):
    PROMPT_TEMPLATE = """
Analyze this patient record excerpt for missed diagnoses (limit response to 500 tokens):
{chunk}
"""

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Analysis", height=500, type="messages")
                msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
                send_btn = gr.Button("Analyze", variant="primary")
                file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="single")

            with gr.Column(scale=1):
                final_summary = gr.Markdown("## Summary")
                status = gr.Textbox(label="Status", interactive=False)
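
        # analyze(): extract text from the uploaded file, build up to 5 chunk
        # prompts, run the agent on each, and return the combined summary plus
        # a status string for the side panel.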

        def analyze(message: str, history: List[Dict], file_obj) -> tuple:
            try:
                if not file_obj:
                    return history, "Please upload a file first", "No file uploaded"

                # gr.File may hand back a tempfile-like object or a plain path
                # string depending on the Gradio version; handle both.
                file_path = file_obj.name if hasattr(file_obj, "name") else str(file_obj)
                file_type = os.path.splitext(file_path)[-1].lower().replace(".", "")
                history.append({"role": "user", "content": message})

                processed = process_file_cached(file_path, file_type)
                if "error" in processed[0]:
                    history.append({"role": "assistant", "content": processed[0]["error"]})
                    return history, processed[0]["error"], "File processing failed"

                # Collect text chunks from PDFs ("content") and tables ("rows").
                chunks = []
                for item in processed:
                    if "content" in item:
                        chunks.append(item["content"])
                    elif "rows" in item:
                        rows_text = "\n".join([", ".join(map(str, row)) for row in item["rows"][:100]])
                        chunks.append(f"=== {item.get('sheet', 'Data')} ===\n{rows_text}")

                if not chunks:
                    history.append({"role": "assistant", "content": "No processable content found."})
                    return history, "No processable content found", "Content extraction failed"

                # Analyze at most 5 chunks, truncating each to 5,000 characters.
                responses = []
                for i, chunk in enumerate(chunks[:5]):
                    try:
                        prompt = PROMPT_TEMPLATE.format(chunk=chunk[:5000])
                        response = agent.run_quick_summary(prompt, 0.2, 256, 500)
                        cleaned = clean_response(response)
                        if cleaned:
                            responses.append(f"Analysis {i+1}:\n{cleaned}")
                    except Exception as e:
                        logger.warning(f"Error analyzing chunk {i+1}: {str(e)}")
                        continue

                if not responses:
                    history.append({"role": "assistant", "content": "No valid analysis generated."})
                    return history, "No valid analysis generated", "Analysis failed"

                summary = "\n\n".join(responses)
                history.append({"role": "assistant", "content": summary})
                return history, "Analysis completed", "Success"

            except Exception as e:
                logger.error(f"Analysis error: {e}")
                history.append({"role": "assistant", "content": f"Error: {str(e)}"})
                return history, f"Error: {str(e)}", "Failed"
            finally:
                # Free GPU memory between requests.
                torch.cuda.empty_cache()
                gc.collect()

        send_btn.click(
            analyze,
            inputs=[msg_input, chatbot, file_upload],
            outputs=[chatbot, final_summary, status]
        )
        msg_input.submit(
            analyze,
            inputs=[msg_input, chatbot, file_upload],
            outputs=[chatbot, final_summary, status]
        )

    return demo
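
# 0.0.0.0:7860 is Gradio's default port and the one Hugging Face Spaces
# expects; adjust server_port if running alongside another local app.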


if __name__ == "__main__":
    try:
        agent = init_agent()
        demo = create_ui(agent)
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        logger.error(f"Fatal error: {e}")
        raise