import sys, os, json, gradio as gr, pandas as pd, pdfplumber, hashlib, shutil, re
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Thread
# Setup
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.join(current_dir, "src")
sys.path.insert(0, src_path)
base_dir = "/data"
model_cache_dir = os.path.join(base_dir, "txagent_models")
tool_cache_dir = os.path.join(base_dir, "tool_cache")
file_cache_dir = os.path.join(base_dir, "cache")
report_dir = os.path.join(base_dir, "reports")
vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(d, exist_ok=True)
# Hugging Face & Transformers cache
os.environ.update({
"HF_HOME": model_cache_dir,
"TRANSFORMERS_CACHE": model_cache_dir,
"VLLM_CACHE_DIR": vllm_cache_dir,
"TOKENIZERS_PARALLELISM": "false",
"CUDA_LAUNCH_BLOCKING": "1"
})
from txagent.txagent import TxAgent
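# Section headings that mark a PDF page as clinically relevant during the fast first pass.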
MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
                    'allergies', 'summary', 'impression', 'findings', 'recommendations'}
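# Small helpers: strip invalid UTF-8 sequences and fingerprint a file's contents for cache keys.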
def sanitize_utf8(text): return text.encode("utf-8", "ignore").decode("utf-8")
def file_hash(path):
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()
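# Fast first pass over a PDF: always keep the first three pages, then keep
# later pages (up to max_pages) only if they mention a medical keyword.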
def extract_priority_pages(file_path, max_pages=20):
    try:
        with pdfplumber.open(file_path) as pdf:
            pages = []
            for i, page in enumerate(pdf.pages[:3]):
                pages.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
            for i, page in enumerate(pdf.pages[3:max_pages], start=4):
                text = page.extract_text() or ""
                if any(re.search(rf'\b{kw}\b', text.lower()) for kw in MEDICAL_KEYWORDS):
                    pages.append(f"=== Page {i} ===\n{text.strip()}")
            return "\n\n".join(pages)
    except Exception as e:
        return f"PDF processing error: {str(e)}"
def convert_file_to_json(file_path, file_type):
    try:
        h = file_hash(file_path)
        cache_path = os.path.join(file_cache_dir, f"{h}.json")
        if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()
        if file_type == "pdf":
            text = extract_priority_pages(file_path)
            result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
            # Kick off the full-text extraction without blocking the response.
            Thread(target=full_pdf_processing, args=(file_path, h)).start()
        elif file_type == "csv":
            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").astype(str).values.tolist()})
        elif file_type in ["xls", "xlsx"]:
            try:
                df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            except Exception:
                df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").astype(str).values.tolist()})
        else:
            return json.dumps({"error": f"Unsupported file type: {file_type}"})
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        return result
    except Exception as e:
        return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
def full_pdf_processing(file_path, file_hash_value):
    try:
        cache_path = os.path.join(file_cache_dir, f"{file_hash_value}_full.json")
        if os.path.exists(cache_path):
            return
        with pdfplumber.open(file_path) as pdf:
            full_text = "\n".join(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}" for i, page in enumerate(pdf.pages))
        result = json.dumps({"filename": os.path.basename(file_path), "content": full_text, "status": "complete"})
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        with open(os.path.join(report_dir, f"{file_hash_value}_report.txt"), "w", encoding="utf-8") as out:
            out.write(full_text)
    except Exception as e:
        print("PDF processing error:", e)
def init_agent():
    default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=8,
        seed=100
    )
    agent.init_model()
    return agent
# Lazy load agent only on first use
agent_container = {"agent": None}
def get_agent():
    if agent_container["agent"] is None:
        agent_container["agent"] = init_agent()
    return agent_container["agent"]
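# Assemble the Gradio interface; the agent factory is injected so the model
# is only loaded once a query actually arrives.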
def create_ui(get_agent_func):
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1><h3 style='text-align: center;'>Identify potential oversights in patient care</h3>")
        chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
        msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
        send_btn = gr.Button("Analyze", variant="primary")
        state = gr.State([])
        download_output = gr.File(label="Download Report")
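        # Stream the agent's review of the uploaded records into the chat;
        # yields the updated message history plus a report path when one exists.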
        def analyze(message, history, conversation, files):
            try:
                extracted_data, file_hash_value = "", ""
                if files:
                    with ThreadPoolExecutor(max_workers=4) as pool:
                        futures = [pool.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
                        extracted_data = "\n".join(sanitize_utf8(f.result()) for f in as_completed(futures))
                    file_hash_value = file_hash(files[0].name)
prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
1. List potential missed diagnoses
2. Flag any medication conflicts
3. Note incomplete assessments
4. Highlight abnormal results needing follow-up
Medical Records:\n{extracted_data[:15000]}
### Potential Oversights:\n"""
final_response = ""
for chunk in get_agent_func().run_gradio_chat(
message=prompt,
history=[],
temperature=0.2,
max_new_tokens=1024,
max_token=4096,
call_agent=False,
conversation=conversation
):
if isinstance(chunk, str):
final_response += chunk
elif isinstance(chunk, list):
final_response += "".join([c.content for c in chunk if hasattr(c, "content")])
cleaned = final_response.replace("[TOOL_CALLS]", "").strip()
if not cleaned:
cleaned = "No oversights found. Consider further review."
updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": cleaned}]
report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value and os.path.exists(os.path.join(report_dir, f"{file_hash_value}_report.txt")) else None
yield updated_history, report_path
except Exception as e:
updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": f"❌ Error: {str(e)}"}]
yield updated_history, None
send_btn.click(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
msg_input.submit(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
return demo
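# Entry point: the model loads lazily on first query, so the UI comes up fast.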
if __name__ == "__main__":
print("Launching interface...")
ui = create_ui(get_agent)
ui.queue(api_open=False).launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True,
allowed_paths=["/data/reports"],
share=False
)