CPS-Test-Mobile / app.py
Ali2206's picture
Update app.py
782e103 verified
raw
history blame
7.09 kB
# ✅ Fully optimized app.py for Hugging Face Space with persistent 150GB storage
import hashlib
import json
import os
import re
import shutil
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock, Thread

import gradio as gr
import pandas as pd
import pdfplumber
# All persistent state lives on the Space's /data volume.
base_dir = "/data"
model_cache_dir = os.path.join(base_dir, "txagent_models")  # exported as HF_HOME / TRANSFORMERS_CACHE below
tool_cache_dir = os.path.join(base_dir, "tool_cache")  # receives the copied new_tool.json
file_cache_dir = os.path.join(base_dir, "cache")  # "<hash>.json" / "<hash>_full.json" upload conversions
report_dir = os.path.join(base_dir, "reports")
vllm_cache_dir = os.path.join(base_dir, "vllm_cache")  # exported as VLLM_CACHE_DIR below

# Create every directory up front; exist_ok makes restarts harmless.
for d in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir):
    os.makedirs(d, exist_ok=True)
# Route the Hugging Face / transformers / vLLM caches onto the persistent volume.
# NOTE(review): gradio is imported before this point and may already have imported
# huggingface_hub, which reads HF_HOME when it is first imported — confirm the override
# is actually honoured.
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
# Turn off the tokenizers library's internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous. It is usually
# a debugging aid that costs GPU throughput — confirm it is still wanted.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Models are referenced by local directory inside the persistent cache (not by Hub repo
# id); the weights are expected to be present there already.
_MODEL_ORG_DIR = os.path.join(model_cache_dir, "mims-harvard")
LOCAL_TXAGENT_PATH = os.path.join(_MODEL_ORG_DIR, "TxAgent-T1-Llama-3.1-8B")
LOCAL_RAG_PATH = os.path.join(_MODEL_ORG_DIR, "ToolRAG-T1-GTE-Qwen2-1.5B")
# Put the bundled src/ directory (next to this file) at the front of the import path
# so its txagent package is the one that gets imported.
_SRC_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))
sys.path.insert(0, _SRC_DIR)
from txagent.txagent import TxAgent  # noqa: E402 - depends on the sys.path entry above
def file_hash(path, *, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at *path*.

    The digest is used as a content-addressed cache key, not for security. The file
    is streamed in ``chunk_size``-byte pieces so large uploads never need to fit in
    memory, and the handle is closed deterministically.
    """
    digest = hashlib.md5()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
def sanitize_utf8(text):
    """Return *text* minus any character that cannot be encoded as UTF-8.

    Such characters (for example unpaired surrogates) are dropped silently; everything
    else is returned unchanged.
    """
    encoded = text.encode("utf-8", errors="ignore")
    return encoded.decode("utf-8")
# Words whose presence marks a PDF page (beyond the first three) as worth extracting.
MEDICAL_KEYWORDS = {"diagnosis", "assessment", "plan", "results", "medications", "summary", "findings"}
# One precompiled, word-bounded alternation of all keywords. Matched against lower-cased
# page text, so "PLAN" matches but "explanation" does not.
MEDICAL_KEYWORD_PATTERN = re.compile(
    r"\b(?:" + "|".join(re.escape(kw) for kw in sorted(MEDICAL_KEYWORDS)) + r")\b"
)


def extract_priority_pages(file_path, max_pages=20):
    """Return a quick text preview of the most relevant pages of a PDF.

    Pages 1-3 are always included. Pages 4..max_pages are included only if their text
    contains one of MEDICAL_KEYWORDS as a whole word (case-insensitive). Each page is
    prefixed with an "=== Page N ===" header (1-based), and pages are separated by a
    blank line.

    Args:
        file_path: path of the PDF to read.
        max_pages: pages past this index are never inspected.

    Returns:
        The joined page text, or "PDF processing error: ..." if reading fails.
    """
    try:
        with pdfplumber.open(file_path) as pdf:
            pages = []
            for i, page in enumerate(pdf.pages[:3], start=1):
                pages.append(f"=== Page {i} ===\n{(page.extract_text() or '').strip()}")
            for i, page in enumerate(pdf.pages[3:max_pages], start=4):
                text = page.extract_text() or ""
                if MEDICAL_KEYWORD_PATTERN.search(text.lower()):
                    pages.append(f"=== Page {i} ===\n{text.strip()}")
            return "\n\n".join(pages)
    except Exception as e:
        return f"PDF processing error: {str(e)}"
def convert_file_to_json(file_path, file_type):
    """Convert an uploaded file to a JSON string, cached by content hash.

    Args:
        file_path: path of the uploaded file.
        file_type: lower-case extension without the dot ("pdf", "csv", "xls", "xlsx").

    Returns:
        A JSON string:
        - PDFs: {"filename", "content", "status": "initial"}, with the content built by
          extract_priority_pages. A background thread (full_pdf_processing) extracts
          every page into its own cache entry.
        - Spreadsheets: {"filename", "rows"}, with every cell as a string.
        - On an unsupported type or any exception: {"error": message}.
        Successful results are written to "<hash>.json" in file_cache_dir and served
        from there next time; error results are never cached.
    """
    try:
        h = file_hash(file_path)
        cache_path = os.path.join(file_cache_dir, f"{h}.json")
        if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()
        filename = os.path.basename(file_path)
        if file_type == "pdf":
            text = extract_priority_pages(file_path)
            result = json.dumps({"filename": filename, "content": text, "status": "initial"})
            Thread(target=full_pdf_processing, args=(file_path, h)).start()
        elif file_type in ("csv", "xls", "xlsx"):
            if file_type == "csv":
                df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str)
            else:
                # openpyxl only reads OOXML workbooks (.xlsx). For legacy .xls, let pandas
                # choose its default reader instead of failing unconditionally.
                engine = "openpyxl" if file_type == "xlsx" else None
                df = pd.read_excel(file_path, engine=engine, header=None, dtype=str)
            result = json.dumps({"filename": filename, "rows": df.fillna('').astype(str).values.tolist()})
        else:
            return json.dumps({"error": f"Unsupported file type: {file_type}"})
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        return result
    except Exception as e:
        return json.dumps({"error": str(e)})
def full_pdf_processing(file_path, h):
    """Extract the text of every page of a PDF into the "<h>_full.json" cache entry.

    This is best-effort: convert_file_to_json starts it on a background Thread, so
    failures are reported on stderr instead of raised. It does nothing if the entry
    already exists.

    Args:
        file_path: path of the PDF to read.
        h: content hash of the file, used as the cache key.
    """
    import tempfile

    tmp_path = None
    try:
        cache_path = os.path.join(file_cache_dir, f"{h}_full.json")
        if os.path.exists(cache_path):
            return
        with pdfplumber.open(file_path) as pdf:
            full_text = "\n".join(
                f"=== Page {i} ===\n{(p.extract_text() or '').strip()}"
                for i, p in enumerate(pdf.pages, start=1)
            )
        # Write to a unique temp file, then rename it into place. Otherwise the
        # existence check above would treat a partially written file (after a crash,
        # or from a concurrent run for the same hash) as a finished entry forever.
        fd, tmp_path = tempfile.mkstemp(dir=file_cache_dir, suffix=".tmp")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(json.dumps({"content": full_text}))
        os.replace(tmp_path, cache_path)
        tmp_path = None
    except Exception as e:
        print(f"full_pdf_processing failed for {file_path}: {e}", file=sys.stderr)
    finally:
        # Remove a leftover temp file if anything failed before the rename.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
def init_agent():
    """Create and initialise the TxAgent used for every analysis.

    On first run the bundled tool definition data/new_tool.json is copied into the
    persistent tool cache; later runs reuse that copy. The agent is built from the
    local model directories LOCAL_TXAGENT_PATH and LOCAL_RAG_PATH.

    Returns:
        The TxAgent instance, after ``init_model()`` has been called on it.
    """
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        # Resolve the bundled file relative to this script (the same anchor used to put
        # src/ on sys.path), so startup does not depend on the working directory.
        source_tool_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "new_tool.json")
        shutil.copy(source_tool_path, target_tool_path)
    agent = TxAgent(
        model_name=LOCAL_TXAGENT_PATH,
        rag_model_name=LOCAL_RAG_PATH,
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=8,
        seed=100,
    )
    agent.init_model()
    return agent
# Holds the lazily created process-wide agent under the "agent" key (None until first use).
agent_container = {"agent": None}
# Serialises first-time creation in get_agent().
_agent_lock = Lock()


def get_agent():
    """Return the shared TxAgent, creating it with init_agent() on first use.

    The Analyze button and the textbox submit are separate event handlers and may run
    concurrently. Double-checked locking ensures only one of them loads the model; the
    other waits and reuses it. After initialisation the fast path takes no lock.
    """
    agent = agent_container["agent"]
    if agent is None:
        with _agent_lock:
            agent = agent_container["agent"]
            if agent is None:
                agent = init_agent()
                agent_container["agent"] = agent
    return agent
def create_ui():
    """Build the Gradio interface for the Clinical Oversight Assistant.

    The UI has a chat transcript, a question box, a multi-file upload (PDF/CSV/XLS/XLSX)
    and an Analyze button. Clicking the button and submitting the textbox both run the
    same ``analyze`` handler.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` app.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""<h1 style='text-align:center;'>🩺 Clinical Oversight Assistant</h1>""")
        chatbot = gr.Chatbot(label="Analysis", height=600)
        msg_input = gr.Textbox(placeholder="Ask a question about the patient...")
        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
        send_btn = gr.Button("Analyze", variant="primary")
        state = gr.State([])
        # analyze() always returns None for its second output. A single hidden component
        # fills that slot; creating gr.File() inline in each outputs= list would render
        # two empty file widgets.
        file_output = gr.File(visible=False)

        def analyze(message, history, conversation, files):
            """Review the uploaded records with TxAgent and append the answer to the chat.

            Returns ``(history, None)`` to match ``[chatbot, file_output]``. Any exception
            is shown as an error message in the chat rather than raised.
            """
            history = history or []
            try:
                extracted = ""
                if files:
                    # Uploads may be path strings or objects exposing the path as `.name`.
                    paths = [getattr(f, "name", f) for f in files]
                    file_types = [p.split(".")[-1].lower() for p in paths]
                    with ThreadPoolExecutor(max_workers=3) as pool:
                        # map() yields results in upload order, so which files survive
                        # the 15,000-character cut below does not depend on thread timing.
                        results = pool.map(convert_file_to_json, paths, file_types)
                        extracted = "\n".join(sanitize_utf8(r) for r in results)
                prompt = f"""Review these medical records and identify exactly what might have been missed:
1. Missed diagnoses
2. Medication conflicts
3. Incomplete assessments
4. Abnormal results needing follow-up
Medical Records:\n{extracted[:15000]}
"""
                question = (message or "").strip()
                if question:
                    # Forward what the clinician typed so the question actually reaches the model.
                    prompt += f"\nClinician question: {question}\n"
                parts = []
                for chunk in get_agent().run_gradio_chat(
                    prompt,
                    history=[],
                    temperature=0.2,
                    max_new_tokens=1024,
                    max_token=4096,
                    call_agent=False,
                    conversation=conversation,
                ):
                    # The agent yields either plain strings or lists of message objects.
                    if isinstance(chunk, str):
                        parts.append(chunk)
                    elif isinstance(chunk, list):
                        parts.extend(c.content for c in chunk if hasattr(c, "content"))
                cleaned = "".join(parts).replace("[TOOL_CALLS]", "").strip()
                return history + [[message, cleaned]], None
            except Exception as e:
                return history + [[message, f"❌ Error: {str(e)}"]], None

        send_btn.click(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, file_output])
        msg_input.submit(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, file_output])
    return demo
if __name__ == "__main__":
    ui = create_ui()
    # Requests go through Gradio's queue; api_open=False keeps the queue's API closed
    # to direct calls.
    ui.queue(api_open=False).launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        show_error=True,
        # Only the persistent reports directory may be served as files. Reuse report_dir
        # so this path cannot drift from where reports are stored.
        allowed_paths=[report_dir],
        share=False,
    )