# ✅ Fully optimized app.py for Hugging Face Space with persistent 150GB storage
import os
import re
import sys
import json
import shutil
import hashlib

from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Thread

import gradio as gr
import pandas as pd
import pdfplumber
# Use /data for persistent HF storage
base_dir = "/data"
model_cache_dir = os.path.join(base_dir, "txagent_models")
tool_cache_dir = os.path.join(base_dir, "tool_cache")
file_cache_dir = os.path.join(base_dir, "cache")
report_dir = os.path.join(base_dir, "reports")
vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(d, exist_ok=True)
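# /data is the Space's persistent storage mount, so model weights, the tool file, the
# parsed-file cache, generated reports and the vLLM cache all survive restarts and rebuilds.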
# Set persistent HF + VLLM cache
os.environ.update({
    "HF_HOME": model_cache_dir,
    "TRANSFORMERS_CACHE": model_cache_dir,
    "VLLM_CACHE_DIR": vllm_cache_dir,
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_LAUNCH_BLOCKING": "1"
})
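# These variables are set before TxAgent (and with it transformers/vLLM) is imported, so both
# libraries cache into the persistent directories above rather than the ephemeral container
# filesystem. CUDA_LAUNCH_BLOCKING=1 makes GPU errors surface synchronously, which is easier
# to debug at the cost of some throughput.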
# Force local loading only
LOCAL_TXAGENT_PATH = os.path.join(model_cache_dir, "mims-harvard", "TxAgent-T1-Llama-3.1-8B")
LOCAL_RAG_PATH = os.path.join(model_cache_dir, "mims-harvard", "ToolRAG-T1-GTE-Qwen2-1.5B")
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
from txagent.txagent import TxAgent
def file_hash(path): return hashlib.md5(open(path, "rb").read()).hexdigest()
def sanitize_utf8(text): return text.encode("utf-8", "ignore").decode("utf-8")
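# file_hash() doubles as the cache key for converted uploads; sanitize_utf8() scrubs invalid
# byte sequences from extracted text before it is spliced into the LLM prompt.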
MEDICAL_KEYWORDS = {"diagnosis", "assessment", "plan", "results", "medications", "summary", "findings"}
def extract_priority_pages(file_path, max_pages=20):
    """Return the first 3 pages plus any later page (up to max_pages) that mentions a medical keyword."""
    try:
        with pdfplumber.open(file_path) as pdf:
            pages = []
            for i, page in enumerate(pdf.pages[:3]):
                pages.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
            for i, page in enumerate(pdf.pages[3:max_pages], start=4):
                text = page.extract_text() or ""
                if any(re.search(rf'\b{kw}\b', text.lower()) for kw in MEDICAL_KEYWORDS):
                    pages.append(f"=== Page {i} ===\n{text.strip()}")
            return "\n\n".join(pages)
    except Exception as e:
        return f"PDF processing error: {str(e)}"
def convert_file_to_json(file_path, file_type):
    try:
        h = file_hash(file_path)
        cache_path = os.path.join(file_cache_dir, f"{h}.json")
        if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()
        if file_type == "pdf":
            # Return the priority pages immediately; extract the full document in the background.
            text = extract_priority_pages(file_path)
            result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
            Thread(target=full_pdf_processing, args=(file_path, h)).start()
        elif file_type == "csv":
            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str)
            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna('').astype(str).values.tolist()})
        elif file_type in ["xls", "xlsx"]:
            # Note: the openpyxl engine only reads .xlsx; legacy .xls files would need a different
            # engine (e.g. xlrd) to load successfully.
            df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna('').astype(str).values.tolist()})
        else:
            return json.dumps({"error": f"Unsupported file type: {file_type}"})
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        return result
    except Exception as e:
        return json.dumps({"error": str(e)})
def full_pdf_processing(file_path, h):
    """Background pass: extract every page of the PDF and cache it alongside the priority-page result."""
    try:
        cache_path = os.path.join(file_cache_dir, f"{h}_full.json")
        if os.path.exists(cache_path):
            return
        with pdfplumber.open(file_path) as pdf:
            full_text = "\n".join(f"=== Page {i+1} ===\n{(p.extract_text() or '').strip()}" for i, p in enumerate(pdf.pages))
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(json.dumps({"content": full_text}))
    except Exception:
        # Best-effort background work: failures here should never surface to the request thread.
        pass
def init_agent():
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(os.path.abspath("data/new_tool.json"), target_tool_path)
    agent = TxAgent(
        model_name=LOCAL_TXAGENT_PATH,
        rag_model_name=LOCAL_RAG_PATH,
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=8,
        seed=100
    )
    agent.init_model()
    return agent
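# The agent is held in a module-level container and constructed lazily on the first request (see
# get_agent below), so the Gradio UI can come up before the TxAgent and ToolRAG weights are loaded.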
agent_container = {"agent": None}
def get_agent():
    if agent_container["agent"] is None:
        agent_container["agent"] = init_agent()
    return agent_container["agent"]
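# Caveat (observation, not part of the original design): this memoization is not guarded by a
# lock, so two concurrent first requests could each trigger init_agent(). A hypothetical variant
# would serialize initialization, e.g.:
#
#     from threading import Lock
#     _agent_lock = Lock()
#     def get_agent():
#         with _agent_lock:
#             if agent_container["agent"] is None:
#                 agent_container["agent"] = init_agent()
#         return agent_container["agent"]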
def create_ui():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""<h1 style='text-align:center;'>🩺 Clinical Oversight Assistant</h1>""")
        chatbot = gr.Chatbot(label="Analysis", height=600)
        msg_input = gr.Textbox(placeholder="Ask a question about the patient...")
        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
        send_btn = gr.Button("Analyze", variant="primary")
        state = gr.State([])
        def analyze(message, history, conversation, files):
            try:
                extracted, hval = "", ""
                if files:
                    # Convert uploads in parallel; each converter returns a JSON string.
                    with ThreadPoolExecutor(max_workers=3) as pool:
                        futures = [pool.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
                        extracted = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
                    hval = file_hash(files[0].name)  # hash of the first upload (currently unused downstream)
                prompt = f"""Review these medical records and identify exactly what might have been missed:
1. Missed diagnoses
2. Medication conflicts
3. Incomplete assessments
4. Abnormal results needing follow-up
Medical Records:\n{extracted[:15000]}
"""
                final_response = ""
                for chunk in get_agent().run_gradio_chat(prompt, history=[], temperature=0.2, max_new_tokens=1024, max_token=4096, call_agent=False, conversation=conversation):
                    if isinstance(chunk, str):
                        final_response += chunk
                    elif isinstance(chunk, list):
                        final_response += "".join([c.content for c in chunk if hasattr(c, 'content')])
                cleaned = final_response.replace("[TOOL_CALLS]", "").strip()
                updated_history = history + [[message, cleaned]]
                return updated_history, None
            except Exception as e:
                return history + [[message, f"❌ Error: {str(e)}"]], None
        send_btn.click(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, gr.File()])
        msg_input.submit(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, gr.File()])
    return demo
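# Note: the second output slot is a gr.File component created inline in the event wiring above;
# analyze() always returns None for it, so it currently stays empty, presumably as a placeholder
# for report downloads from /data/reports (which launch() whitelists via allowed_paths below).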
if __name__ == "__main__":
    ui = create_ui()
    ui.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        allowed_paths=["/data/reports"],
        share=False
    )