# CPS-Test-Mobile / app.py
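"""Clinical Oversight Assistant.

Gradio Space that extracts text from uploaded medical records (PDF, CSV, Excel)
and streams a TxAgent review flagging missed diagnoses, medication conflicts,
incomplete assessments, and abnormal results needing follow-up.
"""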
import sys
import os
import gradio as gr
import hashlib
import time
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import pandas as pd
import pdfplumber

# Set up environment
os.environ.update({
    "HF_HOME": "/data/hf_cache",
    "TOKENIZERS_PARALLELISM": "false"
})
# Create cache directories
os.makedirs("/data/hf_cache", exist_ok=True)
os.makedirs("/data/file_cache", exist_ok=True)
os.makedirs("/data/reports", exist_ok=True)

# Import TxAgent after setting up the environment
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
from txagent.txagent import TxAgent

# Initialize agent with error handling
try:
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": "/data/tool_cache/new_tool.json"},
        force_finish=True,
        enable_checker=True,
        step_rag_num=8,
        seed=100
    )
    agent.init_model()
except Exception as e:
    print(f"Failed to initialize agent: {e}")
    agent = None
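
# If initialization fails, `agent` stays None and analyze() below reports the
# failure in the chat instead of crashing the Space.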


def file_hash(path: str) -> str:
    """Return the MD5 hash of a file's contents, used as the cache key."""
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()


def extract_text_from_pdf(file_path: str, max_pages: int = 10) -> str:
    """Extract text from the first `max_pages` pages of a PDF using pdfplumber."""
    try:
        with pdfplumber.open(file_path) as pdf:
            return "\n".join(
                f"Page {i + 1}:\n{(page.extract_text() or '').strip()}\n"
                for i, page in enumerate(pdf.pages[:max_pages])
            )
    except Exception as e:
        return f"PDF error: {e}"


def process_file(file_path: str, file_type: str) -> str:
    """Extract a file's content as a JSON string, caching results by content hash."""
    try:
        cache_path = f"/data/file_cache/{file_hash(file_path)}.json"
        if os.path.exists(cache_path):
            with open(cache_path, "r") as f:
                return f.read()

        if file_type == "pdf":
            content = extract_text_from_pdf(file_path)
        elif file_type == "csv":
            df = pd.read_csv(file_path, header=None, dtype=str, on_bad_lines="skip")
            content = df.fillna("").to_string()
        elif file_type in ["xls", "xlsx"]:
            df = pd.read_excel(file_path, header=None, dtype=str)
            content = df.fillna("").to_string()
        else:
            return json.dumps({"error": "Unsupported file type"})

        result = json.dumps({"filename": os.path.basename(file_path), "content": content})
        with open(cache_path, "w") as f:
            f.write(result)
        return result
    except Exception as e:
        return json.dumps({"error": str(e)})


def format_response(response: str) -> str:
    """Strip tool-call markers and convert the numbered issue headings to Markdown sections."""
    response = response.replace("[TOOL_CALLS]", "").strip()
    sections = {
        "1. **Missed Diagnoses**:": "🔍 Missed Diagnoses",
        "2. **Medication Conflicts**:": "💊 Medication Conflicts",
        "3. **Incomplete Assessments**:": "📋 Incomplete Assessments",
        "4. **Abnormal Results Needing Follow-up**:": "⚠️ Abnormal Results"
    }
    for old, new in sections.items():
        response = response.replace(old, f"\n### {new}\n")
    return response


def analyze(message: str, history: list, files: list):
    """Extract any uploaded records, prompt the agent, and stream the formatted analysis."""
    if agent is None:
        yield history + [(message, "Agent initialization failed. Please try again later.")], None
        return

    history.append((message, None))
    yield history, None

    try:
        extracted_data = ""
        if files:
            # Process uploaded files in parallel; lower-case the extension so
            # e.g. ".PDF" still matches the types process_file() handles
            with ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(process_file, f.name, f.name.split(".")[-1].lower())
                    for f in files if hasattr(f, "name")
                ]
                extracted_data = "\n".join(f.result() for f in as_completed(futures))

        prompt = f"""Review these medical records:
{extracted_data[:10000]}
Identify potential issues:
1. Missed diagnoses
2. Medication conflicts
3. Incomplete assessments
4. Abnormal results needing follow-up
Analysis:"""

        response = ""
        for chunk in agent.run_gradio_chat(
            message=prompt,
            history=[],
            temperature=0.2,
            max_new_tokens=800
        ):
            # Chunks may arrive as plain strings or as lists of message objects
            if isinstance(chunk, str):
                response += chunk
            elif isinstance(chunk, list):
                response += "".join(getattr(c, "content", "") for c in chunk)
            history[-1] = (message, format_response(response))
            yield history, None

        history[-1] = (message, format_response(response))
        yield history, None
    except Exception as e:
        history[-1] = (message, f"❌ Error: {str(e)}")
        yield history, None


# Create the interface
with gr.Blocks(
    title="Clinical Oversight Assistant",
    css="""
    .gradio-container {
        max-width: 1000px;
        margin: auto;
    }
    .chatbot {
        min-height: 500px;
    }
    """
) as demo:
    gr.Markdown("# 🩺 Clinical Oversight Assistant")

    with gr.Row():
        with gr.Column(scale=1):
            files = gr.File(
                label="Upload Medical Records",
                # .xls added so the picker matches the types process_file() handles
                file_types=[".pdf", ".csv", ".xls", ".xlsx"],
                file_count="multiple"
            )
            query = gr.Textbox(
                label="Your Query",
                placeholder="Ask about potential oversights..."
            )
            submit = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Analysis Results",
                show_copy_button=True
            )

    submit.click(
        analyze,
        inputs=[query, chatbot, files],
        outputs=[chatbot, gr.File(visible=False)]
    )
    query.submit(
        analyze,
        inputs=[query, chatbot, files],
        outputs=[chatbot, gr.File(visible=False)]
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )