"""
CPS: Clinical Patient Support System.

Gradio front-end around TxAgent that parses uploaded medical files
(PDF/CSV/Excel) into JSON and asks the model to flag anything the
treating clinician may have missed.
"""

import sys
import os
import json
import hashlib
import shutil
import time
from functools import lru_cache
from typing import Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
import pdfplumber
import gradio as gr
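
# Make the bundled "src" directory importable so the local txagent package
# resolves regardless of the current working directory.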
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
print(f"Adding to path: {src_path}")
sys.path.insert(0, src_path)
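
# Persistent storage layout: model weights, tool specs, and parsed-file
# caches all live under /data (e.g. a Hugging Face Spaces persistent disk).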
base_dir = "/data"
model_cache_dir = os.path.join(base_dir, "txagent_models")
tool_cache_dir = os.path.join(base_dir, "tool_cache")
file_cache_dir = os.path.join(base_dir, "cache")

os.makedirs(model_cache_dir, exist_ok=True)
os.makedirs(tool_cache_dir, exist_ok=True)
os.makedirs(file_cache_dir, exist_ok=True)

os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["HF_HOME"] = model_cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # surface CUDA errors at the failing call
from txagent.txagent import TxAgent


def sanitize_utf8(text: str) -> str:
    """Strip any byte sequences that cannot round-trip through UTF-8."""
    return text.encode("utf-8", "ignore").decode("utf-8")


def file_hash(path: str) -> str:
    """Return the MD5 hex digest of a file's contents, used as a cache key."""
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()
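

def file_hash_chunked(path: str, chunk_size: int = 1 << 20) -> str:
    """Memory-friendly variant of file_hash (hypothetical helper, not yet used).

    Hashes the file in 1 MiB chunks instead of reading it whole; a drop-in
    replacement for file_hash if very large uploads become a concern.
    """
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()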


@lru_cache(maxsize=100)
def get_cached_response(prompt: str, digest: str) -> Optional[str]:
    """Placeholder for an in-memory response cache keyed on (prompt, file digest).

    Always misses for now; have it return a stored response to short-circuit
    repeated questions about the same file.
    """
    return None


def convert_file_to_json(file_path: str, file_type: str) -> str:
    """Parse a CSV/Excel/PDF file into a JSON string, cached by content hash."""
    try:
        h = file_hash(file_path)
        cache_path = os.path.join(file_cache_dir, f"{h}.json")

        # Serve a previously parsed copy of identical file content if available.
        if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()

        if file_type == "csv":
            df = pd.read_csv(file_path, encoding_errors="replace", header=None,
                             dtype=str, skip_blank_lines=False, on_bad_lines="skip")
        elif file_type in ["xls", "xlsx"]:
            try:
                df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
            except Exception:
                # openpyxl only handles .xlsx; fall back to xlrd for legacy .xls.
                df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
        elif file_type == "pdf":
            with pdfplumber.open(file_path) as pdf:
                text = "\n".join(page.extract_text() or "" for page in pdf.pages)
            result = json.dumps({"filename": os.path.basename(file_path), "content": text.strip()})
            with open(cache_path, "w", encoding="utf-8") as f:
                f.write(result)
            return result
        else:
            return json.dumps({"error": f"Unsupported file type: {file_type}"})

        if df is None or df.empty:
            return json.dumps({"warning": f"No data extracted from: {file_path}"})

        df = df.fillna("")
        content = df.astype(str).values.tolist()
        result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        return result
    except Exception as e:
        return json.dumps({"error": f"Error reading {os.path.basename(file_path)}: {str(e)}"})


def convert_files_to_json_parallel(uploaded_files: list) -> str:
    """Convert several uploaded files concurrently and join the JSON payloads."""
    extracted_text = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = []
        for file in uploaded_files:
            if not hasattr(file, "name"):
                continue
            path = file.name
            ext = path.split(".")[-1].lower()
            futures.append(executor.submit(convert_file_to_json, path, ext))

        # Collect results as they finish; order is not preserved, which is
        # fine because each payload carries its own filename.
        for future in as_completed(futures):
            extracted_text.append(sanitize_utf8(future.result()))
    return "\n".join(extracted_text)


def init_agent():
    """Copy the default tool file into persistent storage and start TxAgent."""
    # Resolve the bundled tool file relative to this script so startup does
    # not depend on the current working directory.
    default_tool_path = os.path.join(current_dir, "data", "new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)

    model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
    rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"

    agent = TxAgent(
        model_name=model_name,
        rag_model_name=rag_model_name,
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=8,
        seed=100,
        additional_default_tools=[]
    )
    agent.init_model()
    return agent
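
# Note: TxAgent and ToolRAG weights are fetched into model_cache_dir the
# first time init_agent() runs, so the initial startup can take a while.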


def create_ui(agent: TxAgent):
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>📋 CPS: Clinical Patient Support System</h1>")

        chatbot = gr.Chatbot(label="CPS Assistant", height=600, type="messages")
        # Note: only .pdf, .csv, .xls, and .xlsx are actually parsed by
        # convert_file_to_json; the other accepted extensions will come back
        # as "Unsupported file type".
        file_upload = gr.File(
            label="Upload Medical File",
            file_types=[".pdf", ".txt", ".docx", ".jpg", ".png", ".csv", ".xls", ".xlsx"],
            file_count="multiple"
        )
        message_input = gr.Textbox(placeholder="Ask a biomedical question or just upload the files...", show_label=False)
        send_button = gr.Button("Send", variant="primary")
        conversation_state = gr.State([])
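
        # handle_chat is a generator: every `yield history` streams an
        # incremental update to the Chatbot component.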
        def handle_chat(message: str, history: list, conversation: list, uploaded_files: list, progress=gr.Progress()):
            start_time = time.time()
            try:
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": "⏳ Processing your request..."})
                yield history

                file_process_time = time.time()
                extracted_text = ""
                if uploaded_files and isinstance(uploaded_files, list):
                    extracted_text = convert_files_to_json_parallel(uploaded_files)
                print(f"File processing took: {time.time() - file_process_time:.2f}s")

                context = (
                    "You are an expert clinical AI assistant. Review this patient's history, "
                    "medications, and notes, and ONLY provide a final answer summarizing "
                    "what the doctor might have missed."
                )
                chunked_prompt = f"{context}\n\n--- Patient Record ---\n{extracted_text}\n\n[Final Analysis]"
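
                # Run TxAgent's multi-step tool-calling loop (bounded by
                # max_round) and stream partial output as it arrives.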
                model_start = time.time()
                generator = agent.run_gradio_chat(
                    message=chunked_prompt,
                    history=[],
                    temperature=0.3,
                    max_new_tokens=768,
                    max_token=4096,
                    call_agent=False,
                    conversation=conversation,
                    uploaded_files=uploaded_files,
                    max_round=10
                )

                final_response = []
                for update in generator:
                    if not update:
                        continue
                    if isinstance(update, str):
                        final_response.append(update)
                    elif isinstance(update, list):
                        final_response.extend(msg.content for msg in update if hasattr(msg, "content"))

                    # Refresh the UI every third chunk so streaming stays
                    # responsive without flooding the frontend; the emptiness
                    # check keeps an empty buffer from wiping the placeholder.
                    if final_response and len(final_response) % 3 == 0:
                        history[-1] = {"role": "assistant", "content": "".join(final_response).strip()}
                        yield history

                history[-1] = {"role": "assistant", "content": "".join(final_response).strip() or "❌ No response."}
                print(f"Model processing took: {time.time() - model_start:.2f}s")
                yield history

            except Exception as chat_error:
                print(f"Chat handling error: {chat_error}")
                history[-1] = {"role": "assistant", "content": "❌ An error occurred while processing your request."}
                yield history
            finally:
                print(f"Total request time: {time.time() - start_time:.2f}s")

        inputs = [message_input, chatbot, conversation_state, file_upload]
        send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
        message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)

        gr.Examples([
            ["Upload your medical form and ask what the doctor might've missed."],
            ["This patient was treated with antibiotics for UTI. What else should we check?"],
            ["Is there anything abnormal in the attached blood work report?"]
        ], inputs=message_input)

    return demo


if __name__ == "__main__":
    print("Initializing agent...")
    agent = init_agent()

    # Warm-up pass so the first real request does not pay the full model
    # load/compile latency; failures here are non-fatal.
    print("Performing warm-up call...")
    try:
        warm_up = agent.run_gradio_chat(
            message="Warm up",
            history=[],
            temperature=0.1,
            max_new_tokens=10,
            max_token=100,
            call_agent=False,
            conversation=[]
        )
        for _ in warm_up:
            pass
    except Exception as warm_up_error:
        print(f"Warm-up failed (continuing): {warm_up_error}")

    print("Launching interface...")
    demo = create_ui(agent)
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=True
    )