|
import sys |
|
import os |
|
import pandas as pd |
|
import pdfplumber |
|
import json |
|
import gradio as gr |
|
from typing import List, Dict, Optional, Generator |
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
import hashlib |
|
import shutil |
|
import re |
|
import psutil |
|
import subprocess |
|
import logging |
|
import torch |
|
import gc |
|
from diskcache import Cache |
|
import time |
|
from transformers import AutoTokenizer |
|
import pyarrow as pa |
|
import pyarrow.csv as pc |
|
import pyarrow.parquet as pq |
|
from vllm import LLM, SamplingParams |
|
import asyncio |
|
import threading |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
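# Dedicated file logger that appends streamed model responses (and errors) to response_log.txt for auditing.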
os.makedirs("/data/hf_cache", exist_ok=True)  # the log directory must exist before a FileHandler is attached
response_log_file = os.path.join("/data/hf_cache", "response_log.txt")
|
response_logger = logging.getLogger("ResponseLogger") |
|
response_handler = logging.FileHandler(response_log_file, mode="a") |
|
response_handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s")) |
|
response_logger.addHandler(response_handler) |
|
response_logger.setLevel(logging.INFO)
response_logger.propagate = False  # keep raw model responses out of the root console logger
|
|
|
|
|
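# Persistent storage layout: model weights, tool definitions, the extracted-file cache,
# generated reports, and the vLLM compilation cache all live under /data/hf_cache.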
persistent_dir = "/data/hf_cache" |
|
os.makedirs(persistent_dir, exist_ok=True) |
|
|
|
model_cache_dir = os.path.join(persistent_dir, "txagent_models") |
|
tool_cache_dir = os.path.join(persistent_dir, "tool_cache") |
|
file_cache_dir = os.path.join(persistent_dir, "cache") |
|
report_dir = os.path.join(persistent_dir, "reports") |
|
vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache") |
|
|
|
for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]: |
|
os.makedirs(directory, exist_ok=True) |
|
|
|
os.environ["HF_HOME"] = model_cache_dir |
|
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir |
|
os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir |
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
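# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous; useful for surfacing CUDA errors, but it slows inference.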
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" |
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
src_path = os.path.abspath(os.path.join(current_dir, "src")) |
|
sys.path.insert(0, src_path) |
|
|
|
from txagent.txagent import TxAgent |
|
|
|
|
|
cache = Cache(file_cache_dir, size_limit=10 * 1024**3) |
|
|
|
|
|
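# Load the tokenizer once at import time; it matches the model served by vLLM below so token counts line up.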
tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B") |
|
|
|
def sanitize_utf8(text: str) -> str: |
|
return text.encode("utf-8", "ignore").decode("utf-8") |
|
|
|
def file_hash(path: str) -> str: |
|
with open(path, "rb") as f: |
|
return hashlib.md5(f.read()).hexdigest() |
|
|
|
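# Extract text from every PDF page in parallel batches, caching the result by file hash.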
def extract_all_pages(file_path: str, progress_callback=None) -> str: |
|
cache_key = f"pdf_{file_hash(file_path)}" |
|
if cache_key in cache: |
|
return cache[cache_key] |
|
|
|
try: |
|
with pdfplumber.open(file_path) as pdf: |
|
total_pages = len(pdf.pages) |
|
if total_pages == 0: |
|
return "" |
|
|
|
batch_size = 5 |
|
batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)] |
|
text_chunks = [""] * total_pages |
|
processed_pages = 0 |
|
|
|
def extract_batch(start: int, end: int) -> List[tuple]: |
|
results = [] |
|
with pdfplumber.open(file_path) as pdf: |
|
                    # enumerate with an offset gives the absolute page index; list.index(page)
                    # would return the wrong position for every batch after the first
                    for page_num, page in enumerate(pdf.pages[start:end], start=start):
                        page_text = page.extract_text_simple() or ""
                        results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
|
return results |
|
|
|
            # extract_batch is a local closure, which cannot be pickled for a ProcessPoolExecutor,
            # so threads are used here instead
            with ThreadPoolExecutor(max_workers=4) as executor:
|
futures = [executor.submit(extract_batch, start, end) for start, end in batches] |
|
                for future in as_completed(futures):
                    batch_results = future.result()
                    for page_num, text in batch_results:
                        text_chunks[page_num] = text
                    # count the pages actually returned rather than assuming a full batch
                    processed_pages += len(batch_results)
                    if progress_callback:
                        progress_callback(min(processed_pages, total_pages), total_pages)
|
|
|
result = "\n\n".join(filter(None, text_chunks)) |
|
cache[cache_key] = result |
|
return result |
|
except Exception as e: |
|
logger.error("PDF processing error: %s", e) |
|
return f"PDF processing error: {str(e)}" |
|
|
|
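# Convert an uploaded spreadsheet into a JSON-serializable row dump, cached by file hash.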
def excel_to_json(file_path: str) -> List[Dict]: |
|
cache_key = f"excel_{file_hash(file_path)}" |
|
if cache_key in cache: |
|
return cache[cache_key] |
|
|
|
try: |
|
        # pyarrow.parquet cannot read Excel workbooks; load the sheet with pandas instead
        df = pd.read_excel(file_path)
|
content = df.where(pd.notnull(df), "").astype(str).values.tolist() |
|
result = [{ |
|
"filename": os.path.basename(file_path), |
|
"rows": content, |
|
"type": "excel" |
|
}] |
|
cache[cache_key] = result |
|
return result |
|
except Exception as e: |
|
logger.error(f"Error processing Excel file: {e}") |
|
return [{"error": f"Error processing Excel file: {str(e)}"}] |
|
|
|
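# Same idea for CSV uploads; rows that fail to parse are skipped via invalid_row_handler.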
def csv_to_json(file_path: str) -> List[Dict]: |
|
cache_key = f"csv_{file_hash(file_path)}" |
|
if cache_key in cache: |
|
return cache[cache_key] |
|
|
|
try: |
|
table = pc.read_csv(file_path, parse_options=pc.ParseOptions(invalid_row_handler=lambda x: "skip")) |
|
df = table.to_pandas(use_threads=True, split_blocks=True) |
|
content = df.where(pd.notnull(df), "").astype(str).values.tolist() |
|
result = [{ |
|
"filename": os.path.basename(file_path), |
|
"rows": content, |
|
"type": "csv" |
|
}] |
|
cache[cache_key] = result |
|
return result |
|
except Exception as e: |
|
logger.error(f"Error processing CSV file: {e}") |
|
return [{"error": f"Error processing CSV file: {str(e)}"}] |
|
|
|
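# Dispatch a single uploaded file to the appropriate extractor based on its extension.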
def process_file(file_path: str, file_type: str) -> List[Dict]: |
|
try: |
|
if file_type == "pdf": |
|
text = extract_all_pages(file_path) |
|
return [{ |
|
"filename": os.path.basename(file_path), |
|
"content": text, |
|
"status": "initial", |
|
"type": "pdf" |
|
}] |
|
elif file_type in ["xls", "xlsx"]: |
|
return excel_to_json(file_path) |
|
elif file_type == "csv": |
|
return csv_to_json(file_path) |
|
else: |
|
return [{"error": f"Unsupported file type: {file_type}"}] |
|
except Exception as e: |
|
logger.error("Error processing %s: %s", os.path.basename(file_path), e) |
|
return [{"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"}] |
|
|
|
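# Split text into chunks of at most max_tokens model tokens so each prompt stays well within the 2048-token context configured below.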
def tokenize_and_chunk(text: str, max_tokens: int = 800) -> List[str]: |
|
cache_key = f"tokens_{hashlib.md5(text.encode()).hexdigest()}" |
|
if cache_key in cache: |
|
return cache[cache_key] |
|
|
|
tokens = tokenizer.encode(text, add_special_tokens=False) |
|
chunks = [] |
|
for i in range(0, len(tokens), max_tokens): |
|
chunk_tokens = tokens[i:i + max_tokens] |
|
chunks.append(tokenizer.decode(chunk_tokens, skip_special_tokens=True)) |
|
cache[cache_key] = chunks |
|
return chunks |
|
|
|
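# Log CPU/RAM usage via psutil and GPU memory/utilization via nvidia-smi (when available).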
def log_system_usage(tag=""): |
|
try: |
|
cpu = psutil.cpu_percent(interval=0.1) |
|
mem = psutil.virtual_memory() |
|
logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB", tag, cpu, mem.used // (1024**2), mem.total // (1024**2)) |
|
result = subprocess.run( |
|
["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"], |
|
capture_output=True, text=True |
|
) |
|
if result.returncode == 0: |
|
used, total, util = result.stdout.strip().split(", ") |
|
logger.info("[%s] GPU: %sMB / %sMB | Utilization: %s%%", tag, used, total, util) |
|
except Exception as e: |
|
logger.error("[%s] GPU/CPU monitor failed: %s", tag, e) |
|
|
|
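# Strip boilerplate from a model response and keep only bullet items found under a "### Missed Diagnoses" heading.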
def clean_response(text: str) -> str: |
|
text = sanitize_utf8(text) |
|
text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL) |
|
diagnoses = [] |
|
lines = text.splitlines() |
|
in_diagnoses_section = False |
|
for line in lines: |
|
line = line.strip() |
|
if not line: |
|
continue |
|
if re.match(r"###\s*Missed Diagnoses", line): |
|
in_diagnoses_section = True |
|
continue |
|
if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line): |
|
in_diagnoses_section = False |
|
continue |
|
if in_diagnoses_section and re.match(r"-\s*.+", line): |
|
diagnosis = re.sub(r"^\-\s*", "", line).strip() |
|
if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE): |
|
diagnoses.append(diagnosis) |
|
text = " ".join(diagnoses) |
|
text = re.sub(r"\s+", " ", text).strip() |
|
text = re.sub(r"[^\w\s\.\,\(\)\-]", "", text) |
|
return text if text else "" |
|
|
|
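# Merge the per-chunk analyses into a single deduplicated sentence covering all missed diagnoses.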
def summarize_findings(combined_response: str) -> str: |
|
chunks = combined_response.split("--- Analysis for Chunk") |
|
diagnoses = [] |
|
for chunk in chunks: |
|
chunk = chunk.strip() |
|
if not chunk or "No oversights identified" in chunk: |
|
continue |
|
lines = chunk.splitlines() |
|
in_diagnoses_section = False |
|
for line in lines: |
|
line = line.strip() |
|
if not line: |
|
continue |
|
if re.match(r"###\s*Missed Diagnoses", line): |
|
in_diagnoses_section = True |
|
continue |
|
if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line): |
|
in_diagnoses_section = False |
|
continue |
|
if in_diagnoses_section and re.match(r"-\s*.+", line): |
|
diagnosis = re.sub(r"^\-\s*", "", line).strip() |
|
if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE): |
|
diagnoses.append(diagnosis) |
|
|
|
seen = set() |
|
unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))] |
|
|
|
if not unique_diagnoses: |
|
return "No missed diagnoses were identified in the provided records." |
|
|
|
summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1]) |
|
if len(unique_diagnoses) > 1: |
|
summary += f", and {unique_diagnoses[-1]}" |
|
elif len(unique_diagnoses) == 1: |
|
summary = "Missed diagnoses include " + unique_diagnoses[0] |
|
summary += ", all of which require urgent clinical review to prevent potential adverse outcomes." |
|
|
|
return summary.strip() |
|
|
|
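# Copy the default tool definitions into the tool cache and load the model with vLLM,
# returning the engine together with default sampling parameters.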
def init_agent(): |
|
logger.info("Initializing model...") |
|
log_system_usage("Before Load") |
|
default_tool_path = os.path.abspath("data/new_tool.json") |
|
target_tool_path = os.path.join(tool_cache_dir, "new_tool.json") |
|
if not os.path.exists(target_tool_path): |
|
shutil.copy(default_tool_path, target_tool_path) |
|
|
|
llm = LLM( |
|
model="mims-harvard/TxAgent-T1-Llama-3.1-8B", |
|
gpu_memory_utilization=0.8, |
|
max_model_len=2048, |
|
tensor_parallel_size=1, |
|
) |
|
sampling_params = SamplingParams( |
|
temperature=0.2, |
|
max_tokens=256, |
|
stop=["</s>", "[INST]"], |
|
) |
|
log_system_usage("After Load") |
|
logger.info("Agent Ready") |
|
return llm, sampling_params |
|
|
|
async def create_ui(llm, sampling_params): |
|
with gr.Blocks(theme=gr.themes.Soft()) as demo: |
|
gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>") |
|
chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages") |
|
final_summary = gr.Markdown(label="Summary of Missed Diagnoses") |
|
        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
|
msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False) |
|
send_btn = gr.Button("Analyze", variant="primary") |
|
download_output = gr.File(label="Download Full Report") |
|
progress_bar = gr.Progress() |
|
|
|
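        # clean_response / summarize_findings only keep bullet lines under a "### Missed Diagnoses"
        # heading, so the prompt below asks the model for exactly that structure.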
prompt_template = """ |
|
Analyze the patient record excerpt for missed diagnoses only. Respond under a "### Missed Diagnoses" heading, listing each finding as a bullet that starts with "- ". Keep each bullet concise and evidence-based, citing the specific clinical finding (e.g., 'elevated blood pressure (160/95) on page 10'), its potential implication (e.g., 'may indicate untreated hypertension'), and a recommendation for urgent review. Do not include other oversight categories such as medication conflicts. If no missed diagnoses are found, write "- No issues identified" under the heading.
|
Patient Record Excerpt (Chunk {0} of {1}): |
|
{chunk} |
|
""" |
|
|
|
def log_response_partial(text: str): |
|
response_logger.info(text) |
|
|
|
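        # Async generator event handler: streams chat history updates, then yields the report path and the final summary.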
async def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()): |
|
history.append({"role": "user", "content": message}) |
|
yield history, None, "" |
|
|
|
extracted = [] |
|
file_hash_value = "" |
|
|
|
if files: |
|
with ProcessPoolExecutor(max_workers=4) as executor: |
|
futures = [] |
|
                    for f in files:
                        # Gradio may return plain file paths or tempfile-like objects depending on version
                        file_path = f if isinstance(f, str) else f.name
                        file_type = file_path.split(".")[-1].lower()
                        futures.append(executor.submit(process_file, file_path, file_type))
|
|
|
for future in as_completed(futures): |
|
try: |
|
extracted.extend(future.result()) |
|
except Exception as e: |
|
logger.error(f"File processing error: {e}") |
|
extracted.append({"error": f"Error processing file: {str(e)}"}) |
|
|
|
                file_hash_value = file_hash(files[0] if isinstance(files[0], str) else files[0].name)
|
history.append({"role": "assistant", "content": "✅ File processing complete"}) |
|
yield history, None, "" |
|
|
|
text_content = "\n".join(json.dumps(item) for item in extracted) |
|
chunks = tokenize_and_chunk(text_content) |
|
combined_response = "" |
|
batch_size = 1 |
|
|
|
try: |
|
for batch_idx in range(0, len(chunks), batch_size): |
|
batch_chunks = chunks[batch_idx:batch_idx + batch_size] |
|
batch_prompts = [ |
|
prompt_template.format( |
|
batch_idx + i + 1, |
|
len(chunks), |
|
chunk=chunk[:800] |
|
) |
|
for i, chunk in enumerate(batch_chunks) |
|
] |
|
|
|
progress((batch_idx) / len(chunks), |
|
desc=f"Analyzing batch {(batch_idx // batch_size) + 1}/{(len(chunks) + batch_size - 1) // batch_size}") |
|
|
|
                    with torch.no_grad():
                        for prompt in batch_prompts:
                            chunk_response = ""
                            current_response = ""
                            # Reserve a fresh assistant message so streamed updates never overwrite the user's turn
                            history.append({"role": "assistant", "content": ""})
                            # vLLM's offline LLM.generate() blocks until completion and returns a list of
                            # RequestOutput objects, so each prompt arrives as one finished response
                            outputs = llm.generate([prompt], sampling_params, use_tqdm=False)
                            for request_output in outputs:
                                new_text = request_output.outputs[0].text[len(current_response):]
                                if new_text:
                                    current_response += new_text
                                    cleaned = clean_response(current_response)
                                    if cleaned and cleaned != chunk_response:
                                        chunk_response = cleaned
                                        history[-1] = {"role": "assistant", "content": chunk_response}
                                        threading.Thread(target=log_response_partial, args=(chunk_response,)).start()
                                        yield history, None, ""
                                        await asyncio.sleep(0.01)

                            if chunk_response:
                                combined_response += f"--- Analysis for Chunk {batch_idx + 1} ---\n{chunk_response}\n"
|
|
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
|
|
summary = summarize_findings(combined_response) |
|
report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None |
|
if report_path: |
|
with open(report_path, "w", encoding="utf-8") as f: |
|
f.write(combined_response + "\n\n" + summary) |
|
threading.Thread(target=log_response_partial, args=(summary,)).start() |
|
|
|
yield history, report_path if report_path and os.path.exists(report_path) else None, summary |
|
|
|
except Exception as e: |
|
logger.error("Analysis error: %s", e) |
|
history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"}) |
|
threading.Thread(target=log_response_partial, args=(f"Error: {str(e)}",)).start() |
|
yield history, None, f"Error occurred during analysis: {str(e)}" |
|
|
|
        # Streaming is handled by the async generator together with demo.queue(); no client-side _js hook is needed
        send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
        msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
|
return demo |
|
|
|
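# Entry point: load the model once, build the UI, and serve it on port 7860; on shutdown,
# clean up any torch.distributed process group that vLLM may have initialized.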
if __name__ == "__main__": |
|
try: |
|
logger.info("Launching app...") |
|
llm, sampling_params = init_agent() |
|
demo = asyncio.run(create_ui(llm, sampling_params)) |
|
demo.queue(api_open=False).launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
show_error=True, |
|
allowed_paths=[report_dir], |
|
share=False |
|
) |
|
finally: |
|
        if torch.distributed.is_available() and torch.distributed.is_initialized():
|
torch.distributed.destroy_process_group() |