# ✅ Fully optimized app.py for Hugging Face Space with persistent 150GB storage
import os
import sys
import re
import json
import shutil
import hashlib
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Thread

import pandas as pd
import pdfplumber
import gradio as gr
# Use /data for persistent HF storage
base_dir = "/data"
model_cache_dir = os.path.join(base_dir, "txagent_models")
tool_cache_dir = os.path.join(base_dir, "tool_cache")
file_cache_dir = os.path.join(base_dir, "cache")
report_dir = os.path.join(base_dir, "reports")
vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
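# One directory per concern on the persistent 150GB volume: model weights,
# tool manifests, parsed-file cache, generated reports, vLLM's compile cache.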
for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(d, exist_ok=True)
# Set persistent HF + VLLM cache
os.environ.update({
    "HF_HOME": model_cache_dir,
    "TRANSFORMERS_CACHE": model_cache_dir,
    "VLLM_CACHE_DIR": vllm_cache_dir,
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_LAUNCH_BLOCKING": "1"
})
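# TRANSFORMERS_CACHE is redundant with HF_HOME on recent transformers releases
# but is kept for older ones. CUDA_LAUNCH_BLOCKING=1 makes CUDA errors surface
# at the offending call (easier debugging, at some throughput cost).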
# Force local loading only
LOCAL_TXAGENT_PATH = os.path.join(model_cache_dir, "mims-harvard", "TxAgent-T1-Llama-3.1-8B")
LOCAL_RAG_PATH = os.path.join(model_cache_dir, "mims-harvard", "ToolRAG-T1-GTE-Qwen2-1.5B")
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
from txagent.txagent import TxAgent
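# Small helpers: file_hash keys every on-disk cache entry by content, and
# sanitize_utf8 drops byte sequences that would break JSON or Gradio streaming.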
def file_hash(path):
    # md5 is fine here: it is a cache key, not a security boundary.
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()
def sanitize_utf8(text):
    return text.encode("utf-8", "ignore").decode("utf-8")
MEDICAL_KEYWORDS = {"diagnosis", "assessment", "plan", "results", "medications", "summary", "findings"}
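# Page triage: always keep the opening pages of a record, then keep later
# pages only when they mention a section heading a clinician would care about.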
def extract_priority_pages(file_path, max_pages=20):
    try:
        with pdfplumber.open(file_path) as pdf:
            pages = []
            # First three pages are always included in full.
            for i, page in enumerate(pdf.pages[:3]):
                pages.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
            # Later pages must match a keyword on a word boundary
            # (the original's '\\b' matched a literal backslash + 'b').
            for i, page in enumerate(pdf.pages[3:max_pages], start=4):
                text = page.extract_text() or ""
                if any(re.search(rf'\b{kw}\b', text.lower()) for kw in MEDICAL_KEYWORDS):
                    pages.append(f"=== Page {i} ===\n{text.strip()}")
            return "\n\n".join(pages)
    except Exception as e:
        return f"PDF processing error: {str(e)}"
def convert_file_to_json(file_path, file_type):
    try:
        h = file_hash(file_path)
        cache_path = os.path.join(file_cache_dir, f"{h}.json")
        if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()
        if file_type == "pdf":
            text = extract_priority_pages(file_path)
            result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
            # daemon=True so a half-finished extraction never blocks shutdown.
            Thread(target=full_pdf_processing, args=(file_path, h), daemon=True).start()
        elif file_type == "csv":
            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str)
            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna('').astype(str).values.tolist()})
        elif file_type in ["xls", "xlsx"]:
            df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna('').astype(str).values.tolist()})
        else:
            return json.dumps({"error": f"Unsupported file type: {file_type}"})
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        return result
    except Exception as e:
        return json.dumps({"error": str(e)})
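# Background worker started by convert_file_to_json: dumps the full text of
# every PDF page into a sibling "<hash>_full.json" cache entry.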
def full_pdf_processing(file_path, h):
    try:
        cache_path = os.path.join(file_cache_dir, f"{h}_full.json")
        if os.path.exists(cache_path):
            return
        with pdfplumber.open(file_path) as pdf:
            full_text = "\n".join(
                f"=== Page {i+1} ===\n{(p.extract_text() or '').strip()}" for i, p in enumerate(pdf.pages)
            )
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(json.dumps({"content": full_text}))
    except Exception:
        # Best-effort job: a failure only means no full-text cache entry.
        pass
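# Builds the TxAgent once: the tool manifest is copied into the persistent
# cache, and both model paths point at local snapshots so nothing is
# re-downloaded on restart.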
def init_agent():
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(os.path.abspath("data/new_tool.json"), target_tool_path)
    agent = TxAgent(
        model_name=LOCAL_TXAGENT_PATH,
        rag_model_name=LOCAL_RAG_PATH,
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=8,
        seed=100
    )
    agent.init_model()
    return agent
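# Lazy singleton: defer the expensive model load to the first request instead
# of import time, keeping a single agent instance alive for all sessions.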
agent_container = {"agent": None}
def get_agent():
    if agent_container["agent"] is None:
        agent_container["agent"] = init_agent()
    return agent_container["agent"]
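# Gradio UI: a chatbot plus multi-file upload; analyze() fans file conversion
# out to a small thread pool, then collects the agent's streamed findings.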
def create_ui():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""<h1 style='text-align:center;'>🩺 Clinical Oversight Assistant</h1>""")
        chatbot = gr.Chatbot(label="Analysis", height=600)
        msg_input = gr.Textbox(placeholder="Ask a question about the patient...")
        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
        send_btn = gr.Button("Analyze", variant="primary")
        # One shared output slot for both triggers; the original built a fresh
        # gr.File() inside each event handler, creating stray components.
        download_output = gr.File()
        state = gr.State([])
        def analyze(message, history, conversation, files):
            try:
                extracted = ""
                if files:
                    # Convert uploads in parallel; result order doesn't matter.
                    with ThreadPoolExecutor(max_workers=3) as pool:
                        futures = [pool.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
                        extracted = "\n".join(sanitize_utf8(f.result()) for f in as_completed(futures))
                # The triple-quoted string is sent verbatim to the model, so
                # its lines stay unindented.
                prompt = f"""Review these medical records and identify exactly what might have been missed:
1. Missed diagnoses
2. Medication conflicts
3. Incomplete assessments
4. Abnormal results needing follow-up
Medical Records:
{extracted[:15000]}
"""
                final_response = ""
                for chunk in get_agent().run_gradio_chat(
                    prompt, history=[], temperature=0.2, max_new_tokens=1024,
                    max_token=4096, call_agent=False, conversation=conversation
                ):
                    if isinstance(chunk, str):
                        final_response += chunk
                    elif isinstance(chunk, list):
                        final_response += "".join(c.content for c in chunk if hasattr(c, 'content'))
                cleaned = final_response.replace("[TOOL_CALLS]", "").strip()
                return history + [[message, cleaned]], None
            except Exception as e:
                return history + [[message, f"❌ Error: {str(e)}"]], None
        send_btn.click(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
        msg_input.submit(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
    return demo
if __name__ == "__main__":
    ui = create_ui()
    ui.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        allowed_paths=["/data/reports"],
        share=False
    )