# CPS-Test-Mobile / app.py
import sys, os, json, gradio as gr, pandas as pd, pdfplumber, hashlib, shutil, re, time
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Thread
# Setup
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.join(current_dir, "src")
sys.path.insert(0, src_path)
base_dir = "/data"
model_cache_dir = os.path.join(base_dir, "txagent_models")
tool_cache_dir = os.path.join(base_dir, "tool_cache")
file_cache_dir = os.path.join(base_dir, "cache")
report_dir = os.path.join(base_dir, "reports")
vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(d, exist_ok=True)
# Hugging Face & Transformers cache
os.environ.update({
    "HF_HOME": model_cache_dir,
    "TRANSFORMERS_CACHE": model_cache_dir,
    "VLLM_CACHE_DIR": vllm_cache_dir,
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_LAUNCH_BLOCKING": "1"
})
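# The cache locations above are exported before TxAgent (and, through it,
# transformers and vLLM) is imported, so those libraries resolve their caches
# under /data rather than the default home directory.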
from txagent.txagent import TxAgent
MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
                    'allergies', 'summary', 'impression', 'findings', 'recommendations'}
def sanitize_utf8(text): return text.encode("utf-8", "ignore").decode("utf-8")
def file_hash(path):
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()
def extract_priority_pages(file_path, max_pages=20):
    try:
        with pdfplumber.open(file_path) as pdf:
            pages = []
            # Always include the first three pages in full.
            for i, page in enumerate(pdf.pages[:3]):
                pages.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
            # From page 4 up to max_pages, keep only pages containing medical keywords.
            for i, page in enumerate(pdf.pages[3:max_pages], start=4):
                text = page.extract_text() or ""
                if any(re.search(rf'\b{kw}\b', text.lower()) for kw in MEDICAL_KEYWORDS):
                    pages.append(f"=== Page {i} ===\n{text.strip()}")
            return "\n\n".join(pages)
    except Exception as e:
        return f"PDF processing error: {str(e)}"
def convert_file_to_json(file_path, file_type):
    try:
        h = file_hash(file_path)
        cache_path = os.path.join(file_cache_dir, f"{h}.json")
        # Serve a previously converted file straight from the cache.
        if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()
        if file_type == "pdf":
            # Return the priority pages immediately; extract the full document in the background.
            text = extract_priority_pages(file_path)
            result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
            Thread(target=full_pdf_processing, args=(file_path, h)).start()
        elif file_type == "csv":
            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").astype(str).values.tolist()})
        elif file_type in ["xls", "xlsx"]:
            try:
                df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            except Exception:
                df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").astype(str).values.tolist()})
        else:
            return json.dumps({"error": f"Unsupported file type: {file_type}"})
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        return result
    except Exception as e:
        return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
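# For reference, the cached JSON written above takes one of two shapes
# (field values here are illustrative only):
#   PDF:          {"filename": "report.pdf", "content": "=== Page 1 ===\n...", "status": "initial"}
#   CSV/XLS/XLSX: {"filename": "labs.csv", "rows": [["Na", "140"], ["K", "4.1"]]}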
def full_pdf_processing(file_path, file_hash_value):
    try:
        cache_path = os.path.join(file_cache_dir, f"{file_hash_value}_full.json")
        if os.path.exists(cache_path):
            return
        with pdfplumber.open(file_path) as pdf:
            full_text = "\n".join([f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}" for i, page in enumerate(pdf.pages)])
        result = json.dumps({"filename": os.path.basename(file_path), "content": full_text, "status": "complete"})
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        with open(os.path.join(report_dir, f"{file_hash_value}_report.txt"), "w", encoding="utf-8") as out:
            out.write(full_text)
    except Exception as e:
        print("PDF processing error:", e)
def init_agent():
    default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=8,
        seed=100
    )
    agent.init_model()
    return agent
# Lazy load agent only on first use
agent_container = {"agent": None}
def get_agent():
    if agent_container["agent"] is None:
        agent_container["agent"] = init_agent()
    return agent_container["agent"]
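# Lazy-loading sketch (not part of the app): the first get_agent() call runs
# init_agent() and loads the models; subsequent calls reuse the cached instance.
#
#   a1 = get_agent()   # slow: models are initialised here
#   a2 = get_agent()   # fast: a1 is a2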
def create_ui(get_agent_func):
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1><h3 style='text-align: center;'>Identify potential oversights in patient care</h3>")
        chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
        msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
        send_btn = gr.Button("Analyze", variant="primary")
        state = gr.State([])
        download_output = gr.File(label="Download Report")

        def analyze(message, history, conversation, files):
            try:
                extracted_data, file_hash_value = "", ""
                if files:
                    # Convert the uploaded files in parallel and concatenate the JSON payloads.
                    with ThreadPoolExecutor(max_workers=4) as pool:
                        futures = [pool.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
                        extracted_data = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
                    file_hash_value = file_hash(files[0].name)
                prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
1. List potential missed diagnoses
2. Flag any medication conflicts
3. Note incomplete assessments
4. Highlight abnormal results needing follow-up
Medical Records:\n{extracted_data[:15000]}
### Potential Oversights:\n"""
                final_response = ""
                for chunk in get_agent_func().run_gradio_chat(
                    message=prompt,
                    history=[],
                    temperature=0.2,
                    max_new_tokens=1024,
                    max_token=4096,
                    call_agent=False,
                    conversation=conversation
                ):
                    if isinstance(chunk, str):
                        final_response += chunk
                    elif isinstance(chunk, list):
                        final_response += "".join([c.content for c in chunk if hasattr(c, "content")])
                cleaned = final_response.replace("[TOOL_CALLS]", "").strip()
                if not cleaned:
                    cleaned = "No oversights found. Consider further review."
                updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": cleaned}]
                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value and os.path.exists(os.path.join(report_dir, f"{file_hash_value}_report.txt")) else None
                yield updated_history, report_path
            except Exception as e:
                updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": f"❌ Error: {str(e)}"}]
                yield updated_history, None

        send_btn.click(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
        msg_input.submit(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
    return demo
if __name__ == "__main__":
    print("Launching interface...")
    ui = create_ui(get_agent)
    ui.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        allowed_paths=["/data/reports"],
        share=False
    )
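# Note: allowed_paths includes report_dir ("/data/reports") so Gradio is permitted
# to serve the generated report files through the "Download Report" component.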