# CPS-Test-Mobile / app.py
import sys
import os
import pandas as pd
import pdfplumber
import json
import gradio as gr
from typing import List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import psutil
import subprocess
from datetime import datetime
try:
    import tiktoken
except ImportError:
    # Install tiktoken on first run if it is missing, then retry the import.
    subprocess.run([sys.executable, "-m", "pip", "install", "tiktoken"], check=True)
    import tiktoken
# Persistent directory setup
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(directory, exist_ok=True)
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
from txagent.txagent import TxAgent
# Constants
MEDICAL_KEYWORDS = {
    'diagnosis', 'assessment', 'plan', 'results', 'medications',
    'allergies', 'summary', 'impression', 'findings', 'recommendations',
    'conclusion', 'history', 'examination', 'progress', 'discharge'
}
TOKENIZER = "cl100k_base"
MAX_MODEL_LEN = 2048
TARGET_CHUNK_TOKENS = 1000 # Reduced from 1200 to be more conservative
PROMPT_RESERVE = 400 # Increased buffer for prompt + response
MEDICAL_SECTION_HEADER = "=== MEDICAL SECTION ==="
def log_system_usage(tag=""):
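    """Log CPU/RAM usage and, when nvidia-smi is available, GPU memory and utilization."""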
    try:
        cpu = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory()
        print(f"[{tag}] CPU: {cpu}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
            capture_output=True, text=True
        )
        if result.returncode == 0:
            used, total, util = result.stdout.strip().split(", ")
            print(f"[{tag}] GPU: {used}MB / {total}MB | Utilization: {util}%")
    except Exception as e:
        print(f"[{tag}] GPU/CPU monitor failed: {e}")
def sanitize_utf8(text: str) -> str:
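    """Drop any byte sequences that cannot round-trip through UTF-8."""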
    return text.encode("utf-8", "ignore").decode("utf-8")
def file_hash(path: str) -> str:
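    """Return the MD5 hex digest of a file's contents, used as a cache key."""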
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()
def count_tokens(text: str) -> int:
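    """Count tokens using the cl100k_base tiktoken encoding."""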
    encoding = tiktoken.get_encoding(TOKENIZER)
    return len(encoding.encode(text))
def extract_all_pages_with_token_count(file_path: str) -> Tuple[str, int, int]:
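    """Extract text from every page of a PDF, flagging pages that contain medical
    keywords, and return (text, total_pages, total_tokens)."""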
    try:
        text_chunks = []
        total_pages = 0
        total_tokens = 0
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
            for i, page in enumerate(pdf.pages):
                page_text = page.extract_text() or ""
                lower_text = page_text.lower()
                if any(re.search(rf'\b{kw}\b', lower_text) for kw in MEDICAL_KEYWORDS):
                    chunk = f"\n{MEDICAL_SECTION_HEADER} (Page {i+1})\n{page_text.strip()}"
                else:
                    chunk = f"\n=== Page {i+1} ===\n{page_text.strip()}"
                text_chunks.append(chunk)
                # Count the tokens of the full appended chunk (header + page text),
                # so the reported total matches what is actually returned.
                total_tokens += count_tokens(chunk)
        return "\n".join(text_chunks), total_pages, total_tokens
    except Exception as e:
        return f"PDF processing error: {str(e)}", 0, 0
def convert_file_to_json(file_path: str, file_type: str) -> str:
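    """Convert a PDF, CSV, or Excel file into a JSON string, caching results by file hash."""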
    try:
        h = file_hash(file_path)
        cache_path = os.path.join(file_cache_dir, f"{h}.json")
        if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()
        if file_type == "pdf":
            text, total_pages, total_tokens = extract_all_pages_with_token_count(file_path)
            result = json.dumps({
                "filename": os.path.basename(file_path),
                "content": text,
                "total_pages": total_pages,
                "total_tokens": total_tokens,
                "status": "complete"
            })
        elif file_type == "csv":
            chunks = []
            for chunk in pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
                                     skip_blank_lines=False, on_bad_lines="skip", chunksize=1000):
                chunks.append(chunk.fillna("").astype(str).values.tolist())
            content = [item for sublist in chunks for item in sublist]
            result = json.dumps({
                "filename": os.path.basename(file_path),
                "rows": content,
                "total_tokens": count_tokens(str(content))
            })
        elif file_type in ["xls", "xlsx"]:
            try:
                df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            except Exception:
                df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
            content = df.fillna("").astype(str).values.tolist()
            result = json.dumps({
                "filename": os.path.basename(file_path),
                "rows": content,
                "total_tokens": count_tokens(str(content))
            })
        else:
            result = json.dumps({"error": f"Unsupported file type: {file_type}"})
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        return result
    except Exception as e:
        return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
def clean_response(text: str) -> str:
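    """Strip tool-call artifacts and boilerplate preambles from model output."""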
    text = sanitize_utf8(text)
    text = re.sub(r"\[TOOL_CALLS\].*", "", text, flags=re.DOTALL)
    text = re.sub(r"\['get_[^\]]+\']\n?", "", text)
    text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
    text = re.sub(r"To analyze the medical records for clinical oversights.*?begin by reviewing.*?\n", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text).strip()
    return text
def format_final_report(analysis_results: List[str], filename: str) -> str:
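    """Merge per-chunk analysis results into one sectioned report, listing critical findings first."""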
    report = []
    report.append("COMPREHENSIVE CLINICAL OVERSIGHT ANALYSIS")
    report.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    report.append(f"File: {filename}")
    report.append("=" * 80)
    sections = {
        "CRITICAL FINDINGS": [],
        "MISSED DIAGNOSES": [],
        "MEDICATION ISSUES": [],
        "ASSESSMENT GAPS": [],
        "FOLLOW-UP RECOMMENDATIONS": []
    }
    for result in analysis_results:
        for section in sections:
            section_match = re.search(
                rf"{re.escape(section)}:?\s*\n([^*]+?)(?=\n\*|\n\n|$)",
                result,
                re.IGNORECASE | re.DOTALL
            )
            if section_match:
                content = section_match.group(1).strip()
                if content and content not in sections[section]:
                    sections[section].append(content)
    if sections["CRITICAL FINDINGS"]:
        report.append("\n🚨 **CRITICAL FINDINGS** 🚨")
        for content in sections["CRITICAL FINDINGS"]:
            report.append(f"\n{content}")
    for section, contents in sections.items():
        if section != "CRITICAL FINDINGS" and contents:
            report.append(f"\n**{section.upper()}**")
            for content in contents:
                report.append(f"\n{content}")
    if not any(sections.values()):
        report.append("\nNo significant clinical oversights identified.")
    report.append("\n" + "=" * 80)
    report.append("END OF REPORT")
    return "\n".join(report)
def split_content_by_tokens(content: str, max_tokens: int = TARGET_CHUNK_TOKENS) -> List[str]:
"""More conservative splitting that ensures we stay well under token limits"""
paragraphs = re.split(r"\n\s*\n", content)
chunks = []
current_chunk = []
current_tokens = 0
for para in paragraphs:
para_tokens = count_tokens(para)
# If paragraph is too big, split into sentences
if para_tokens > max_tokens * 0.8: # Don't allow paragraphs that take up most of the chunk
sentences = re.split(r'(?<=[.!?])\s+', para)
for sent in sentences:
sent_tokens = count_tokens(sent)
if current_tokens + sent_tokens > max_tokens * 0.9: # Leave 10% buffer
if current_chunk: # Only add if we have content
chunks.append("\n\n".join(current_chunk))
current_chunk = []
current_tokens = 0
# If single sentence is too long, split into words
if sent_tokens > max_tokens * 0.8:
words = sent.split()
for word in words:
word_tokens = count_tokens(word)
if current_tokens + word_tokens > max_tokens * 0.9:
if current_chunk:
chunks.append("\n\n".join(current_chunk))
current_chunk = []
current_tokens = 0
current_chunk.append(word)
current_tokens += word_tokens
else:
current_chunk.append(sent)
current_tokens += sent_tokens
else:
current_chunk.append(sent)
current_tokens += sent_tokens
elif current_tokens + para_tokens > max_tokens * 0.9:
chunks.append("\n\n".join(current_chunk))
current_chunk = [para]
current_tokens = para_tokens
else:
current_chunk.append(para)
current_tokens += para_tokens
if current_chunk:
chunks.append("\n\n".join(current_chunk))
return chunks
def init_agent():
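    """Stage the tool definition file into the cache and initialize the TxAgent model."""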
print("๐Ÿ” Initializing model...")
log_system_usage("Before Load")
default_tool_path = os.path.abspath("data/new_tool.json")
target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
if not os.path.exists(target_tool_path):
shutil.copy(default_tool_path, target_tool_path)
agent = TxAgent(
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
tool_files_dict={"new_tool": target_tool_path},
force_finish=True,
enable_checker=True,
step_rag_num=2,
seed=100,
additional_default_tools=[],
)
agent.init_model()
log_system_usage("After Load")
print("โœ… Agent Ready")
return agent
def analyze_complete_document(content: str, filename: str, agent: TxAgent, temperature: float = 0.3) -> str:
"""Analyze complete document with strict token management"""
chunks = split_content_by_tokens(content)
analysis_results = []
for i, chunk in enumerate(chunks):
try:
# Minimal prompt template
base_prompt = """Analyze this medical content for:
1. Critical findings needing immediate attention
2. Potential missed diagnoses
3. Medication issues
4. Assessment gaps
5. Follow-up recommendations
Content:\n"""
# Calculate available space
prompt_tokens = count_tokens(base_prompt)
max_content_tokens = MAX_MODEL_LEN - prompt_tokens - 300 # 300 tokens for response
# Ensure chunk fits
chunk_tokens = count_tokens(chunk)
if chunk_tokens > max_content_tokens:
# If still too big after splitting, truncate
encoding = tiktoken.get_encoding(TOKENIZER)
tokens = encoding.encode(chunk)
chunk = encoding.decode(tokens[:max_content_tokens])
print(f"Warning: Truncated chunk {i} from {chunk_tokens} to {max_content_tokens} tokens")
prompt = base_prompt + chunk
# Final verification
total_tokens = count_tokens(prompt)
if total_tokens > MAX_MODEL_LEN - 200:
encoding = tiktoken.get_encoding(TOKENIZER)
tokens = encoding.encode(prompt)
prompt = encoding.decode(tokens[:MAX_MODEL_LEN - 200])
print(f"Warning: Truncated final prompt from {total_tokens} tokens")
response = ""
for output in agent.run_gradio_chat(
message=prompt,
history=[],
temperature=temperature,
max_new_tokens=200, # Conservative response length
max_token=MAX_MODEL_LEN,
call_agent=False,
conversation=[],
):
if output:
if isinstance(output, list):
for m in output:
if hasattr(m, 'content'):
response += clean_response(m.content)
elif isinstance(output, str):
response += clean_response(output)
if response:
analysis_results.append(response)
except Exception as e:
print(f"Error processing chunk {i}: {str(e)}")
continue
return format_final_report(analysis_results, filename)
def create_ui(agent):
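    """Build the Gradio interface around the provided TxAgent instance."""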
    with gr.Blocks(
        theme=gr.themes.Soft(
            primary_hue="indigo",
            secondary_hue="blue",
            neutral_hue="slate",
            spacing_size="md",
            radius_size="md"
        ),
        title="Clinical Oversight Assistant",
        css="""
        .report-box {
            border: 1px solid #e0e0e0;
            border-radius: 8px;
            padding: 16px;
            background-color: #f9f9f9;
        }
        .file-upload {
            background-color: #f5f7fa;
            padding: 16px;
            border-radius: 8px;
        }
        .analysis-btn {
            width: 100%;
        }
        .critical-finding {
            color: #d32f2f;
            font-weight: bold;
        }
        .dataframe-container {
            height: 600px;
            overflow-y: auto;
        }
        """
    ) as demo:
        gr.Markdown("""
        <div style='text-align: center; margin-bottom: 20px;'>
            <h1 style='color: #2b3a67; margin-bottom: 8px;'>🩺 Clinical Oversight Assistant</h1>
            <p style='color: #5a6a8a; font-size: 16px;'>
                Analyze medical records for potential oversights and generate comprehensive reports
            </p>
        </div>
        """)
        with gr.Row(equal_height=False):
            with gr.Column(scale=1, min_width=400):
                with gr.Group(elem_classes="file-upload"):
                    file_upload = gr.File(
                        file_types=[".pdf", ".csv", ".xls", ".xlsx"],
                        file_count="multiple",
                        label="Upload Medical Records",
                        elem_id="file-upload"
                    )
                    with gr.Row():
                        clear_btn = gr.Button("Clear All", size="sm")
                        send_btn = gr.Button(
                            "Analyze Documents",
                            variant="primary",
                            elem_classes="analysis-btn"
                        )
                with gr.Accordion("Additional Options", open=False):
                    msg_input = gr.Textbox(
                        placeholder="Enter specific focus areas or questions...",
                        label="Analysis Focus",
                        lines=3
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.3,
                        step=0.1,
                        label="Analysis Strictness"
                    )
                status = gr.Textbox(
                    label="Processing Status",
                    interactive=False,
                    visible=True
                )
            with gr.Column(scale=2, min_width=600):
                with gr.Tabs():
                    with gr.TabItem("Analysis Report", id="report"):
                        report_output = gr.Textbox(
                            label="Clinical Oversight Findings",
                            lines=25,
                            max_lines=50,
                            interactive=False,
                            elem_classes="report-box"
                        )
                    with gr.TabItem("Raw Data Preview", id="preview"):
                        with gr.Column(elem_classes="dataframe-container"):
                            data_preview = gr.Dataframe(
                                headers=["File", "Content"],
                                datatype=["str", "str"],
                                interactive=False
                            )
                with gr.Row():
                    download_output = gr.File(
                        label="Download Full Report",
                        visible=True,
                        interactive=False
                    )
                    gr.Button("Save to EHR", visible=False)
        def analyze(files: List, message: str, temp: float):
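            """Convert uploads to JSON, run the chunked analysis, and stream status/report updates to the UI."""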
            if not files:
                # In a generator, `return value` is swallowed by Gradio, so the
                # warning must be yielded before returning.
                yield (
                    gr.update(value="", visible=True),
                    None,
                    gr.update(value="⚠️ Please upload at least one file to analyze.", visible=True),
                    gr.update(value=None, visible=True)
                )
                return
            yield (
                gr.update(value="", visible=True),
                None,
                gr.update(value="⏳ Processing documents...", visible=True),
                gr.update(value=None, visible=True)
            )
            filenames = []
            preview_data = []
            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = []
                for f in files:
                    file_path = f.name
                    futures.append(executor.submit(
                        convert_file_to_json,
                        file_path,
                        os.path.splitext(file_path)[1][1:].lower()
                    ))
                    filenames.append(os.path.basename(file_path))
                results = []
                for future in as_completed(futures):
                    result = sanitize_utf8(future.result())
                    try:
                        data = json.loads(result)
                        results.append(data)
                        if "content" in data:
                            preview_data.append([data["filename"], data["content"][:500] + "..."])
                    except Exception as e:
                        print(f"Error processing result: {e}")
                        continue
            yield (
                gr.update(value="", visible=True),
                None,
                gr.update(value=f"🔍 Analyzing {len(files)} documents...", visible=True),
                gr.update(value=preview_data[:20], visible=True)
            )
            try:
                combined_content = "\n".join([
                    item.get("content", "") if isinstance(item, dict) and "content" in item
                    else str(item.get("rows", "")) if isinstance(item, dict)
                    else str(item)
                    for item in results
                ])
                full_report = analyze_complete_document(
                    combined_content,
                    " + ".join(filenames),
                    agent,
                    temperature=temp
                )
                file_hash_value = hashlib.md5(combined_content.encode()).hexdigest()
                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
                with open(report_path, "w", encoding="utf-8") as f:
                    f.write(full_report)
                yield (
                    gr.update(value=full_report, visible=True),
                    report_path if os.path.exists(report_path) else None,
                    gr.update(value="✅ Analysis complete!", visible=True),
                    gr.update(value=preview_data[:20], visible=True)
                )
            except Exception as e:
                error_msg = f"❌ Error during analysis: {str(e)}"
                print(error_msg)
                yield (
                    gr.update(value="", visible=True),
                    None,
                    gr.update(value=error_msg, visible=True),
                    gr.update(value=None, visible=True)
                )
        send_btn.click(
            fn=analyze,
            inputs=[file_upload, msg_input, temperature],
            outputs=[report_output, download_output, status, data_preview],
            api_name="analyze"
        )
        clear_btn.click(
            fn=lambda: (None, None, "", None, 0.3, ""),
            inputs=None,
            outputs=[file_upload, download_output, status, data_preview, temperature, msg_input]
        )
    return demo
if __name__ == "__main__":
print("๐Ÿš€ Launching app...")
try:
import tiktoken
except ImportError:
print("Installing tiktoken...")
subprocess.run([sys.executable, "-m", "pip", "install", "tiktoken"])
agent = init_agent()
demo = create_ui(agent)
demo.queue(
api_open=False,
max_size=20
).launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True,
allowed_paths=[report_dir],
share=False
)