# CPS-Test-Mobile / app.py
import sys
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import psutil
import subprocess
from datetime import datetime
try:
    import tiktoken
except ImportError:
    # tiktoken is required for token counting; install it on first run if missing.
    subprocess.run([sys.executable, "-m", "pip", "install", "tiktoken"])
    import tiktoken
# Persistent directory setup
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
os.makedirs(directory, exist_ok=True)
# Environment variables
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Add src to path
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
from txagent.txagent import TxAgent
# Constants
MEDICAL_KEYWORDS = {
'diagnosis', 'assessment', 'plan', 'results', 'medications',
'allergies', 'summary', 'impression', 'findings', 'recommendations',
'conclusion', 'history', 'examination', 'progress', 'discharge'
}
TOKENIZER = "cl100k_base"
# Increase max model length to support larger contexts
MAX_MODEL_LEN = 4096
# Default chunk target tokens (currently unused; chunk size is derived from MAX_MODEL_LEN)
TARGET_CHUNK_TOKENS = 1200
PROMPT_RESERVE = 100
MEDICAL_SECTION_HEADER = "=== MEDICAL SECTION ==="
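# Note: each analysis chunk receives MAX_MODEL_LEN minus the prompt's own tokens
# and PROMPT_RESERVE, leaving roughly 3,900+ tokens of document content per chunk.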
def log_system_usage(tag=""):
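    """Print CPU/RAM usage and, when nvidia-smi is available, GPU memory and utilization."""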
try:
cpu = psutil.cpu_percent(interval=1)
mem = psutil.virtual_memory()
print(f"[{tag}] CPU: {cpu}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
result = subprocess.run(
["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
capture_output=True, text=True
)
if result.returncode == 0:
used, total, util = result.stdout.strip().split(", ")
print(f"[{tag}] GPU: {used}MB / {total}MB | Utilization: {util}%")
except Exception as e:
print(f"[{tag}] GPU/CPU monitor failed: {e}")
def sanitize_utf8(text: str) -> str:
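    """Remove characters that cannot survive a UTF-8 encode/decode round trip."""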
return text.encode("utf-8", "ignore").decode("utf-8")
def file_hash(path: str) -> str:
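    """Return the MD5 hex digest of a file's contents (used as a cache key, not for security)."""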
with open(path, "rb") as f:
return hashlib.md5(f.read()).hexdigest()
def count_tokens(text: str) -> int:
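    """Return the number of tokens in `text` under the cl100k_base encoding."""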
encoding = tiktoken.get_encoding(TOKENIZER)
return len(encoding.encode(text))
def extract_all_pages_with_token_count(file_path: str) -> Tuple[str, int, int]:
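    """Extract text from every page of a PDF, flagging pages that mention medical keywords.

    Returns (text, total_pages, total_tokens); on failure, the error message is
    returned in the text slot with zero pages and tokens.
    """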
try:
text_chunks = []
total_pages = 0
total_tokens = 0
with pdfplumber.open(file_path) as pdf:
total_pages = len(pdf.pages)
for i, page in enumerate(pdf.pages):
page_text = page.extract_text() or ""
lower_text = page_text.lower()
header = f"\n{MEDICAL_SECTION_HEADER} (Page {i+1})\n" if any(
re.search(rf'\b{kw}\b', lower_text) for kw in MEDICAL_KEYWORDS
) else f"\n=== Page {i+1} ===\n"
text_chunks.append(header + page_text.strip())
total_tokens += count_tokens(header) + count_tokens(page_text)
return "\n".join(text_chunks), total_pages, total_tokens
except Exception as e:
return f"PDF processing error: {str(e)}", 0, 0
def convert_file_to_json(file_path: str, file_type: str) -> str:
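    """Serialize a PDF, CSV, or Excel file to JSON, caching the result by content hash."""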
try:
h = file_hash(file_path)
cache_path = os.path.join(file_cache_dir, f"{h}.json")
if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()
if file_type == "pdf":
text, total_pages, total_tokens = extract_all_pages_with_token_count(file_path)
result = json.dumps({
"filename": os.path.basename(file_path),
"content": text,
"total_pages": total_pages,
"total_tokens": total_tokens,
"status": "complete"
})
elif file_type == "csv":
chunks = []
for chunk in pd.read_csv(
file_path, encoding_errors="replace", header=None, dtype=str,
skip_blank_lines=False, on_bad_lines="skip", chunksize=1000
):
chunks.append(chunk.fillna("").astype(str).values.tolist())
content = [item for sub in chunks for item in sub]
result = json.dumps({
"filename": os.path.basename(file_path),
"rows": content,
"total_tokens": count_tokens(str(content))
})
elif file_type in ["xls", "xlsx"]:
try:
df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            except Exception:
df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
            content = df.fillna("").astype(str).values.tolist()
result = json.dumps({
"filename": os.path.basename(file_path),
"rows": content,
"total_tokens": count_tokens(str(content))
})
else:
result = json.dumps({"error": f"Unsupported file type: {file_type}"})
with open(cache_path, "w", encoding="utf-8") as f:
f.write(result)
return result
except Exception as e:
return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
def clean_response(text: str) -> str:
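    """Strip tool-call artifacts and boilerplate phrases from a model response."""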
text = sanitize_utf8(text)
patterns = [
r"\[TOOL_CALLS\].*",
r"\['get_[^\]]+\']\n?",
r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?",
r"To analyze the medical records for clinical oversights.*?\n"
]
for pat in patterns:
text = re.sub(pat, "", text, flags=re.DOTALL)
return re.sub(r"\n{3,}", "\n\n", text).strip()
def format_final_report(analysis_results: List[str], filename: str) -> str:
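    """Collect the named sections from per-chunk results and assemble the final report."""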
report = [
"COMPREHENSIVE CLINICAL OVERSIGHT ANALYSIS",
f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
f"File: {filename}",
"=" * 80
]
sections = {s: [] for s in [
"CRITICAL FINDINGS", "MISSED DIAGNOSES", "MEDICATION ISSUES",
"ASSESSMENT GAPS", "FOLLOW-UP RECOMMENDATIONS"
]}
for res in analysis_results:
for sec in sections:
            m = re.search(
                rf"{re.escape(sec)}:?\s*\n(.+?)(?=\n\*|\n|$)",
                res, re.IGNORECASE | re.DOTALL
            )
if m:
content = m.group(1).strip()
if content and content not in sections[sec]:
sections[sec].append(content)
if sections["CRITICAL FINDINGS"]:
report.append("\n๐Ÿšจ **CRITICAL FINDINGS** ๐Ÿšจ")
report.extend(f"\n{c}" for c in sections["CRITICAL FINDINGS"])
for sec, conts in sections.items():
if sec != "CRITICAL FINDINGS" and conts:
report.append(f"\n**{sec}**")
report.extend(f"\n{c}" for c in conts)
if not any(sections.values()):
report.append("\nNo significant clinical oversights identified.")
report.append("\n" + "="*80)
report.append("END OF REPORT")
return "\n".join(report)
def split_content_by_tokens(content: str, max_tokens: int) -> List[str]:
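    """Split text into chunks of at most max_tokens, preferring paragraph boundaries and falling back to sentence splits."""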
paragraphs = re.split(r"\n\s*\n", content)
chunks, current, curr_toks = [], [], 0
for para in paragraphs:
toks = count_tokens(para)
        if toks > max_tokens:
            # Paragraph alone exceeds the budget; fall back to sentence-level splitting.
            for sent in re.split(r'(?<=[.!?])\s+', para):
                sent_toks = count_tokens(sent)
                if curr_toks + sent_toks > max_tokens:
                    if current:
                        chunks.append("\n\n".join(current))
                    current, curr_toks = [sent], sent_toks
                else:
                    current.append(sent)
                    curr_toks += sent_toks
elif curr_toks + toks > max_tokens:
chunks.append("\n\n".join(current))
current, curr_toks = [para], toks
else:
current.append(para)
curr_toks += toks
if current:
chunks.append("\n\n".join(current))
return chunks
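# Example (hypothetical sizes): a 10,000-token record with max_tokens around 3,900
# would come back as ~3 chunks, each ending on a paragraph or sentence boundary.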
def init_agent():
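    """Initialize the TxAgent model, copying the default tool file into the cache first."""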
print("๐Ÿ” Initializing model...")
log_system_usage("Before Load")
default_tool_path = os.path.abspath("data/new_tool.json")
target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
if not os.path.exists(target_tool_path):
shutil.copy(default_tool_path, target_tool_path)
agent = TxAgent(
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
tool_files_dict={"new_tool": target_tool_path},
force_finish=True,
enable_checker=True,
step_rag_num=2,
seed=100,
additional_default_tools=[]
)
agent.init_model()
log_system_usage("After Load")
print("โœ… Agent Ready")
return agent
def analyze_complete_document(content: str, filename: str, agent: TxAgent, temperature: float = 0.3) -> str:
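    """Chunk the document to fit the model context, analyze each chunk, and format the combined report."""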
base_prompt = (
"Analyze for:\n1. Critical\n2. Missed DX\n3. Med issues\n4. Gaps\n5. Follow-up\n\nContent:\n"
)
prompt_toks = count_tokens(base_prompt)
max_chunk_toks = MAX_MODEL_LEN - prompt_toks - PROMPT_RESERVE
chunks = split_content_by_tokens(content, max_chunk_toks)
results = []
for i, chunk in enumerate(chunks):
try:
prompt = base_prompt + chunk
response = ""
for out in agent.run_gradio_chat(
message=prompt,
history=[],
temperature=temperature,
max_new_tokens=300,
max_token=MAX_MODEL_LEN,
call_agent=False,
conversation=[]
):
if out:
if isinstance(out, list):
for m in out:
response += clean_response(m.content if hasattr(m, 'content') else str(m))
else:
response += clean_response(str(out))
if response:
results.append(response)
except Exception as e:
print(f"Error processing chunk {i}: {e}")
return format_final_report(results, filename)
def create_ui(agent):
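    """Build the Gradio interface: uploads and controls on the left, report and previews on the right."""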
with gr.Blocks(title="Clinical Oversight Assistant") as demo:
gr.Markdown("""
# ๐Ÿฉบ Clinical Oversight Assistant
Analyze medical records for potential oversights and generate comprehensive reports
""")
with gr.Row():
with gr.Column():
file_upload = gr.File(label="Upload Medical Records", file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
msg_input = gr.Textbox(label="Analysis Focus (optional)")
                temperature = gr.Slider(0.1, 1.0, value=0.3, label="Temperature (lower = more conservative)")
send_btn = gr.Button("Analyze Documents", variant="primary")
clear_btn = gr.Button("Clear All")
status = gr.Textbox(label="Status", interactive=False)
with gr.Column():
report_output = gr.Textbox(label="Report", lines=20, interactive=False)
data_preview = gr.Dataframe(headers=["File", "Snippet"], interactive=False)
download_output = gr.File(label="Download Report")
def analyze(files, msg, temp):
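            """Parse uploaded files, run the analysis, and stream status/preview updates; `msg` (analysis focus) is currently unused."""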
if not files:
yield "", None, "โš ๏ธ Please upload files.", None
return
yield "", None, "โณ Processing...", None
previews = []
contents = []
for f in files:
res = json.loads(sanitize_utf8(convert_file_to_json(f.name, os.path.splitext(f.name)[1][1:].lower())))
if "content" in res:
previews.append([res["filename"], res["content"][:200] + "..."])
contents.append(res["content"])
yield "", None, f"๐Ÿ” Analyzing {len(contents)} docs...", previews
combined = "\n".join(contents)
report = analyze_complete_document(combined, "+".join([os.path.basename(f.name) for f in files]), agent, temp)
file_hash_val = hashlib.md5(combined.encode()).hexdigest()
path = os.path.join(report_dir, f"{file_hash_val}_report.txt")
with open(path, "w", encoding="utf-8") as rd:
rd.write(report)
            yield report, path, "✅ Analysis complete!", previews
send_btn.click(analyze, [file_upload, msg_input, temperature], [report_output, download_output, status, data_preview])
clear_btn.click(lambda: (None, None, "", None), None, [report_output, download_output, status, data_preview])
return demo
if __name__ == "__main__":
print("๐Ÿš€ Launching app...")
agent = init_agent()
demo = create_ui(agent)
demo.queue(api_open=False, max_size=20).launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True,
share=False,
allowed_paths=[report_dir]
)