|
import sys |
|
import os |
|
import gradio as gr |
|
import hashlib |
|
import time |
|
import json |
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
import pandas as pd |
|
import pdfplumber |
|
|
|
|
|
# Route Hugging Face caches onto the persistent /data volume and disable
# tokenizer parallelism. This must run before any HF library reads the
# environment, which is why it sits above the TxAgent import.
os.environ.update({
    "HF_HOME": "/data/hf_cache",
    "TOKENIZERS_PARALLELISM": "false"
})

# Make sure the cache and report directories exist before anything writes
# to them (idempotent thanks to exist_ok).
os.makedirs("/data/hf_cache", exist_ok=True)
os.makedirs("/data/file_cache", exist_ok=True)
os.makedirs("/data/reports", exist_ok=True)

# Put the app-local "src" directory on the import path so the vendored
# txagent package below can be imported.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
from txagent.txagent import TxAgent
|
|
|
|
|
# Initialize the TxAgent model pair (main LLM + tool-retrieval RAG model)
# at import time. On any failure (missing weights, no GPU, bad tool file)
# fall back to agent = None so the UI can still start; analyze() reports
# the failure per request instead of the whole app crashing.
try:
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        # Extra tool definitions loaded from the persistent data volume.
        # NOTE(review): /data/tool_cache is never created by this file's
        # makedirs calls above — confirm it exists in the deployment image.
        tool_files_dict={"new_tool": "/data/tool_cache/new_tool.json"},
        force_finish=True,
        enable_checker=True,
        step_rag_num=8,
        seed=100
    )
    agent.init_model()
except Exception as e:
    print(f"Failed to initialize agent: {str(e)}")
    agent = None
|
|
|
def file_hash(path: str) -> str:
    """Return the hex MD5 digest of the file at *path*.

    The digest is used only as a cache key (see process_file), not for
    security, so MD5 is acceptable here. The file is hashed in fixed-size
    chunks so large uploads are not read into memory all at once.

    Args:
        path: Filesystem path of the file to hash.

    Returns:
        Lowercase hexadecimal MD5 digest string.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter(callable, sentinel) yields chunks until read() returns b"".
        for chunk in iter(lambda: f.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
def extract_text_from_pdf(file_path: str, max_pages: int = 10) -> str:
    """Extract plain text from the first *max_pages* pages of a PDF.

    Each page is rendered as "Page N:" followed by its extracted text.
    On any failure the error is returned as a "PDF error: ..." string
    rather than raised, so callers can embed it in their output.
    """
    try:
        rendered_pages = []
        with pdfplumber.open(file_path) as pdf:
            for index, page in enumerate(pdf.pages[:max_pages]):
                text = (page.extract_text() or '').strip()
                rendered_pages.append(f"Page {index+1}:\n{text}\n")
        return "\n".join(rendered_pages)
    except Exception as e:
        return f"PDF error: {str(e)}"
|
|
|
def process_file(file_path: str, file_type: str) -> str:
    """Parse an uploaded record file and return its content as JSON text.

    Results are cached on disk under /data/file_cache, keyed by the file's
    MD5 hash, so re-uploading the same file skips re-parsing.

    Args:
        file_path: Path to the uploaded file.
        file_type: File extension, matched case-insensitively
            ("pdf", "csv", "xls", "xlsx").

    Returns:
        A JSON string with "filename" and "content" keys on success, or
        an "error" key on failure / unsupported type. Never raises.
    """
    try:
        cache_path = f"/data/file_cache/{file_hash(file_path)}.json"
        if os.path.exists(cache_path):
            with open(cache_path, "r") as f:
                return f.read()

        # Normalize the extension so "report.PDF" or "data.Csv" uploads
        # are handled instead of falling through to "Unsupported file type".
        file_type = file_type.lower()
        if file_type == "pdf":
            content = extract_text_from_pdf(file_path)
        elif file_type == "csv":
            # header=None + dtype=str: treat everything as raw text rows.
            df = pd.read_csv(file_path, header=None, dtype=str, on_bad_lines="skip")
            content = df.fillna("").to_string()
        elif file_type in ("xls", "xlsx"):
            df = pd.read_excel(file_path, header=None, dtype=str)
            content = df.fillna("").to_string()
        else:
            return json.dumps({"error": "Unsupported file type"})

        result = json.dumps({"filename": os.path.basename(file_path), "content": content})
        with open(cache_path, "w") as f:
            f.write(result)
        return result
    except Exception as e:
        return json.dumps({"error": str(e)})
|
|
|
def format_response(response: str) -> str:
    """Clean up raw model output for display in the chatbot.

    Removes "[TOOL_CALLS]" markers, trims surrounding whitespace, and
    promotes the four numbered issue headings to markdown "###" sections.
    """
    text = response.replace("[TOOL_CALLS]", "").strip()
    heading_map = (
        ("1. **Missed Diagnoses**:", "π Missed Diagnoses"),
        ("2. **Medication Conflicts**:", "π Medication Conflicts"),
        ("3. **Incomplete Assessments**:", "π Incomplete Assessments"),
        ("4. **Abnormal Results Needing Follow-up**:", "β οΈ Abnormal Results"),
    )
    for marker, title in heading_map:
        text = text.replace(marker, f"\n### {title}\n")
    return text
|
|
|
def analyze(message: str, history: list, files: list):
    """Gradio streaming handler: review uploaded records for oversights.

    Extracts text from the uploaded files in parallel, builds a review
    prompt, and streams the agent's analysis into the chatbot.

    Args:
        message: The user's query text from the textbox.
        history: Chatbot history as a list of (user, assistant) tuples;
            mutated in place and re-yielded for streaming updates.
        files: Uploaded file objects (expected to have a .name path).

    Yields:
        (history, None) tuples; the None matches the hidden gr.File
        output component, which this handler never populates.
    """
    if agent is None:
        yield history + [(message, "Agent initialization failed. Please try again later.")], None
        return

    # Echo the user message immediately with a pending (None) answer.
    history.append((message, None))
    yield history, None

    try:
        extracted_data = ""
        if files:
            # Parse the uploads concurrently; each worker returns a JSON
            # string from process_file. Completion order is not stable, so
            # the concatenation order of results may vary between runs.
            with ThreadPoolExecutor() as executor:
                futures = [executor.submit(process_file, f.name, f.name.split(".")[-1])
                           for f in files if hasattr(f, 'name')]
                extracted_data = "\n".join(f.result() for f in as_completed(futures))

        # Truncate to 10k characters to bound the prompt size.
        prompt = f"""Review these medical records:

{extracted_data[:10000]}

Identify potential issues:
1. Missed diagnoses
2. Medication conflicts
3. Incomplete assessments
4. Abnormal results needing follow-up

Analysis:"""

        response = ""
        # Stream output from the agent. Chunks may arrive as plain strings
        # or as lists of message objects carrying a .content attribute.
        for chunk in agent.run_gradio_chat(
            message=prompt,
            history=[],
            temperature=0.2,
            max_new_tokens=800
        ):
            if isinstance(chunk, str):
                response += chunk
            elif isinstance(chunk, list):
                response += "".join(getattr(c, 'content', '') for c in chunk)

            # Re-render the partial response after every chunk.
            history[-1] = (message, format_response(response))
            yield history, None

        # Final yield with the complete formatted response.
        history[-1] = (message, format_response(response))
        yield history, None

    except Exception as e:
        history[-1] = (message, f"β Error: {str(e)}")
        yield history, None
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(
    title="Clinical Oversight Assistant",
    css="""
    .gradio-container {
        max-width: 1000px;
        margin: auto;
    }
    .chatbot {
        min-height: 500px;
    }
    """
) as demo:
    gr.Markdown("# π©Ί Clinical Oversight Assistant")

    with gr.Row():
        # Left column: file upload, query box, and submit button.
        with gr.Column(scale=1):
            files = gr.File(
                label="Upload Medical Records",
                file_types=[".pdf", ".csv", ".xlsx"],
                file_count="multiple"
            )
            query = gr.Textbox(
                label="Your Query",
                placeholder="Ask about potential oversights..."
            )
            submit = gr.Button("Analyze", variant="primary")

        # Right column: streamed analysis results.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Analysis Results",
                show_copy_button=True
            )

    # Both clicking "Analyze" and pressing Enter in the textbox run the
    # same streaming handler. NOTE(review): the hidden gr.File output is a
    # placeholder that analyze() always fills with None.
    submit.click(
        analyze,
        inputs=[query, chatbot, files],
        outputs=[chatbot, gr.File(visible=False)]
    )
    query.submit(
        analyze,
        inputs=[query, chatbot, files],
        outputs=[chatbot, gr.File(visible=False)]
    )
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 so the app is reachable when
    # run inside a container; show_error surfaces tracebacks in the UI.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )