# CPS-Test-Mobile / app.py
# Last commit 1a611b9 ("Update app.py", verified) by Ali2206 — 8.37 kB
import sys
import os
import pandas as pd
import json
import gradio as gr
from typing import List, Tuple, Union, Generator, Dict, Any
import re
from datetime import datetime
import atexit
import torch.distributed as dist
import logging
# Logging: configure the root logger once (INFO and above, timestamped lines)
# and expose a module-level logger for the rest of this file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# PyTorch cleanup
def cleanup() -> None:
    """Destroy the torch.distributed default process group, if one exists.

    Registered with ``atexit`` so it runs at interpreter shutdown. It does
    nothing when torch was built without distributed support or when no
    process group was initialized in this process.
    """
    # torch.distributed exposes no other APIs when is_available() is False,
    # so check it first instead of failing with AttributeError at exit.
    if not (dist.is_available() and dist.is_initialized()):
        return
    logger.info("Cleaning up PyTorch distributed process group")
    dist.destroy_process_group()


atexit.register(cleanup)
# Directories
# Persistent storage root; this directory and the subdirectories below are
# created at import time.
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)
model_cache_dir = os.path.join(persistent_dir, "txagent_models")  # exported as HF_HOME / TRANSFORMERS_CACHE below
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")  # init_agent writes new_tool.json here
file_cache_dir = os.path.join(persistent_dir, "cache")  # NOTE(review): not referenced elsewhere in the visible code
report_dir = os.path.join(persistent_dir, "reports")  # stream_report saves generated .md reports here
for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(d, exist_ok=True)
# Point the Hugging Face caches at persistent storage. These are set before the
# TxAgent import below, presumably so the model libraries it pulls in see them
# when first imported; keep this ordering.
# NOTE(review): recent transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME; confirm it is still needed.
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
# Put the bundled ./src directory first on sys.path so `txagent` is imported
# from this repository rather than from any installed copy.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
from txagent.txagent import TxAgent
# Token budgets. Chunk sizes are measured with estimate_tokens' rough estimate.
MAX_MODEL_TOKENS: int = 32_768  # passed to run_gradio_chat as max_token
MAX_CHUNK_TOKENS: int = 8_192  # per-chunk budget, before PROMPT_OVERHEAD is subtracted
MAX_NEW_TOKENS: int = 2_048  # passed to run_gradio_chat as max_new_tokens
PROMPT_OVERHEAD: int = 500  # room reserved for the instructions wrapped around each chunk
def clean_response(text: str) -> str:
    """Strip model artifacts from generated text while keeping clinical notation.

    The following are removed:
    - Llama-style special tokens such as ``<|eot_id|>``.
    - Bracketed spans ``[...]``, including ones that span lines.
    - The bare word ``None``.
    - Any character outside an allow-list: word characters, whitespace,
      Markdown markers and common clinical punctuation.

    Runs of three or more newlines are then collapsed to a single blank line.

    Args:
        text: Raw model output.

    Returns:
        The cleaned text with leading/trailing whitespace removed.
    """
    # Remove special tokens first; otherwise the character filter below would
    # leave their inner word characters (e.g. "eot_id") behind.
    text = re.sub(r"<\|[^|>]*\|>", "", text)
    # NOTE(review): dropping the word "None" also erases answers such as
    # "Allergies: None"; kept for compatibility. Confirm this is intended.
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    # Keep "/", "%", "<", ">", "+", "=", "|" and similar so values like
    # "BP 120/80", "5 mg/kg", "HbA1c 7%", "<5" and Markdown tables survive.
    # The previous allow-list mangled them (e.g. "120/80" -> "12080").
    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)/%<>+=;'\"?!&~|]+", "", text)
    # Collapse blank-line runs last, since the removals above can create new ones.
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()
def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*, assuming ~3.5 characters per token.

    The result is always at least 1, so even empty lines count toward chunk
    budgets.
    """
    # floor(len / 3.5) == floor(2 * len / 7). Integer arithmetic keeps the
    # result an int; `len(text) // 3.5` would produce a float.
    return len(text) * 2 // 7 + 1
def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
    """Flatten every sheet of an Excel workbook into pipe-separated text lines.

    Each row with at least one non-blank cell becomes
    ``"[<sheet name>] cell | cell | ..."``. Blank cells and fully blank rows
    are skipped. Sheets that fail to parse are logged and skipped.

    Args:
        file_obj: One of:
            - a filesystem path (``str``, or any ``os.PathLike``);
            - a dict whose ``"name"`` key holds the path;
            - an object exposing the path as a string ``.name`` attribute,
              such as the temp-file wrapper some Gradio versions pass.

    Returns:
        All extracted lines joined with newlines ("" if there are none).

    Raises:
        ValueError: If *file_obj* is none of the supported shapes.
        FileNotFoundError: If the resolved path does not exist.
    """
    if isinstance(file_obj, str):
        file_path = file_obj
    elif isinstance(file_obj, dict) and 'name' in file_obj:
        file_path = file_obj['name']
    elif isinstance(file_obj, os.PathLike):
        # Checked before `.name`: pathlib.Path.name is only the basename.
        file_path = os.fspath(file_obj)
    elif isinstance(getattr(file_obj, "name", None), str):
        file_path = file_obj.name
    else:
        raise ValueError("Unsupported file input type")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    all_text = []
    # The context manager releases the workbook handle once parsing is done.
    with pd.ExcelFile(file_path) as xls:
        for sheet in xls.sheet_names:
            try:
                # fillna before astype(str). In the reverse order NaN becomes
                # the literal string "nan", which then leaks into the text.
                df = xls.parse(sheet).fillna("").astype(str)
            except Exception as e:
                logger.warning(f"Failed to parse {sheet}: {e}")
                continue
            # Iterate rows as plain tuples. DataFrame.apply(axis=1) returns a
            # DataFrame for sheets without data rows, and iterating that
            # yields column labels instead of rows.
            for row in df.itertuples(index=False, name=None):
                line = " | ".join(cell for cell in row if cell.strip())
                if line:
                    all_text.append(f"[{sheet}] {line}")
    return "\n".join(all_text)
def split_text_into_chunks(text: str) -> List[str]:
    """Group newline-separated lines into chunks that fit the per-prompt budget.

    Lines are packed greedily in order. A chunk is closed when adding the next
    line would push its estimated size (``estimate_tokens``) past
    ``MAX_CHUNK_TOKENS - PROMPT_OVERHEAD``. Lines are never split, so a single
    line longer than the budget becomes its own oversized chunk.

    Returns:
        The non-blank chunks in input order; an empty list for blank input.
    """
    budget = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
    chunks: List[str] = []
    current: List[str] = []
    current_tokens = 0
    for line in text.split("\n"):
        t = estimate_tokens(line)
        # Flush only when something is buffered. Otherwise an oversized first
        # line would emit an empty chunk, i.e. an empty prompt downstream.
        if current and current_tokens + t > budget:
            chunks.append("\n".join(current))
            current, current_tokens = [], 0
        current.append(line)
        current_tokens += t
    if current:
        chunks.append("\n".join(current))
    # Drop whitespace-only chunks (blank input, or blank lines that happened to
    # be flushed ahead of an oversized line): they carry nothing to analyse.
    return [chunk for chunk in chunks if chunk.strip()]
def build_prompt_from_text(chunk: str) -> str:
    """Wrap one chunk of spreadsheet text in the clinical-analysis instructions.

    The chunk sits between two ``---`` rules. It follows a fixed list of the
    points the model should cover and precedes a closing instruction. The
    prompt starts and ends with a newline.
    """
    instructions = "\n".join([
        "### Clinical Records Analysis",
        "Please analyze these clinical notes and provide:",
        "- Key diagnostic indicators",
        "- Current medications and potential issues",
        "- Recommended follow-up actions",
        "- Any inconsistencies or concerns",
    ])
    closing = "Provide a structured response with clear medical reasoning."
    return f"\n{instructions}\n---\n{chunk}\n---\n{closing}\n"
def validate_tool_file(tool_name: str, tool_path: str) -> bool:
    """Check that *tool_path* holds JSON shaped like a tool definition file.

    Accepted shapes:
      * a list of dicts, each with a ``"name"`` key;
      * a dict whose ``"tools"`` value is a list of such dicts;
      * a dict without ``"tools"`` that has a ``"name"`` key.

    Every rejection is logged with *tool_name*. Missing files, I/O errors and
    JSON/encoding errors are reported as ``False`` rather than raised.

    Args:
        tool_name: Label used in log messages.
        tool_path: Path to the JSON file.

    Returns:
        True if the file exists, decodes as JSON and has an accepted shape.
    """
    def _all_named(items: Any) -> bool:
        # A list whose entries are all dicts carrying a "name" key.
        return isinstance(items, list) and all(
            isinstance(item, dict) and 'name' in item for item in items
        )

    if not os.path.exists(tool_path):
        logger.error(f"Missing tool file: {tool_path}")
        return False
    try:
        # JSON is UTF-8; do not depend on the platform's default encoding.
        with open(tool_path, 'r', encoding="utf-8") as f:
            tool_data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        logger.error(f"Error in {tool_name}: {e}")
        return False

    if isinstance(tool_data, list):
        valid = _all_named(tool_data)
    elif isinstance(tool_data, dict) and 'tools' in tool_data:
        valid = _all_named(tool_data['tools'])
    elif isinstance(tool_data, dict):
        valid = 'name' in tool_data
    else:
        valid = False
    if not valid:
        logger.error(f"Invalid format in tool: {tool_name}")
    return valid
def _tooluniverse_data_dir() -> str:
    """Return the ``data`` directory of the installed ``tooluniverse`` package.

    Falls back to the path this app previously hard-coded when the package
    cannot be located through ``importlib``.
    """
    import importlib.util  # local import: only needed while building the agent

    fallback = '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data'
    try:
        # find_spec locates the package without importing it.
        spec = importlib.util.find_spec("tooluniverse")
    except (ImportError, ValueError):
        spec = None
    locations = list(spec.submodule_search_locations or []) if spec is not None else []
    if not locations:
        return fallback
    return os.path.join(locations[0], "data")


def init_agent() -> TxAgent:
    """Create and initialise the TxAgent used by the UI.

    Writes a default ``new_tool.json`` into ``tool_cache_dir`` if it is absent.
    It then validates the candidate tool files, passes only the valid ones to
    TxAgent, and loads the model.

    Returns:
        The initialised agent.

    Raises:
        ValueError: If none of the tool files passes ``validate_tool_file``.
    """
    new_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(new_tool_path):
        with open(new_tool_path, 'w', encoding="utf-8") as f:
            json.dump({
                "name": "new_tool",
                "description": "Default tool",
                "tools": [{"name": "dummy_tool", "description": "test", "version": "1.0"}]
            }, f)
    # Resolve tooluniverse's bundled tool definitions from the installed
    # package, not from a path tied to one specific pyenv Python build.
    data_dir = _tooluniverse_data_dir()
    tool_files = {
        'opentarget': os.path.join(data_dir, 'opentarget_tools.json'),
        'fda_drug_label': os.path.join(data_dir, 'fda_drug_labeling_tools.json'),
        'special_tools': os.path.join(data_dir, 'special_tools.json'),
        'monarch': os.path.join(data_dir, 'monarch_tools.json'),
        'new_tool': new_tool_path
    }
    valid_tools = {name: path for name, path in tool_files.items() if validate_tool_file(name, path)}
    if not valid_tools:
        raise ValueError("No valid tool files")
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=valid_tools,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100
    )
    agent.init_model()
    return agent
def _run_chat(agent: TxAgent, prompt: str) -> str:
    """Run one single-turn, tool-free chat through *agent* and return its text.

    Streamed pieces are concatenated in order. Strings are used as-is; any
    other piece contributes its ``.content``. ``None`` and empty pieces are
    skipped.
    """
    # NOTE(review): this assumes run_gradio_chat yields incremental pieces. If
    # it yields cumulative text or history lists, concatenation is wrong.
    # Confirm against TxAgent.run_gradio_chat.
    pieces = []
    for out in agent.run_gradio_chat(
        message=prompt, history=[], temperature=0.2,
        max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
        call_agent=False, conversation=[]
    ):
        if out is None:
            continue
        piece = out if isinstance(out, str) else out.content
        if piece:
            pieces.append(piece)
    return "".join(pieces)


def stream_report(agent: TxAgent, input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
    """Analyse an uploaded workbook chunk by chunk, streaming partial results.

    Yields ``(markdown, report_path, summary)`` tuples:
    - After each chunk: ``(all parts so far, None, "")``.
    - At the end: the parts plus the final summary, the path of the saved
      Markdown report in ``report_dir``, and the cleaned summary.
    - On error: ``(message, None, "")``, after which the generator stops.

    ``full_output`` receives the Gradio State value but is not used.
    """
    if input_file is None:
        yield "❌ Upload an Excel file.", None, ""
        return
    try:
        text = extract_text_from_excel(input_file)
        if not text.strip():
            # Nothing to analyse; don't send an empty prompt to the model.
            yield "❌ Error: no readable data found in the uploaded file.", None, ""
            return
        chunks = split_text_into_chunks(text)
    except Exception as e:
        yield f"❌ Error: {str(e)}", None, ""
        return

    accumulated = ""
    try:
        for i, chunk in enumerate(chunks):
            cleaned = clean_response(_run_chat(agent, build_prompt_from_text(chunk)))
            accumulated += f"\n\n📄 Part {i+1}:\n{cleaned}"
            yield accumulated, None, ""
        # NOTE(review): `accumulated` grows with the number of chunks and is
        # never checked against MAX_MODEL_TOKENS. Large workbooks may overflow
        # the context here; consider summarising in stages.
        final = clean_response(_run_chat(agent, f"Summarize this analysis:\n\n{accumulated}"))
    except Exception as e:
        # Keep the parts already produced visible instead of aborting the stream.
        logger.exception("Model inference failed")
        yield f"{accumulated}\n\n❌ Error during analysis: {e}", None, ""
        return

    report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
    with open(report_path, 'w', encoding="utf-8") as f:
        f.write(f"# Clinical Report\n\n{final}")
    yield f"{accumulated}\n\n📊 Final Summary:\n{final}", report_path, final
def create_ui(agent: TxAgent) -> gr.Blocks:
    """Build the Gradio interface bound to *agent*.

    The layout is an Excel upload plus an Analyze button. Below them sit a
    Markdown pane for streamed results and a Download file for the saved report.
    """
    def run_analysis(input_file, full_output):
        # The click handler passes only the two component values listed in
        # `inputs`, which don't line up with stream_report's
        # (agent, input_file, full_output) parameters, so bind the agent here.
        # Kept a generator function so Gradio streams each partial result.
        yield from stream_report(agent, input_file, full_output)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏥 Clinical Records Analyzer")
        with gr.Row():
            file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Row():
            with gr.Column(scale=2):
                report_output = gr.Markdown()
            with gr.Column(scale=1):
                # Visible from the start: yielding a file path updates only the
                # value, so a hidden File would never show the download.
                report_file = gr.File(label="Download", visible=True)
        full_output = gr.State()
        analyze_btn.click(fn=run_analysis, inputs=[file_upload, full_output], outputs=[report_output, report_file, full_output])
    return demo
if __name__ == "__main__":
    # Script entry point: build the agent and UI, then serve on all interfaces
    # at port 7860. Any startup/serving error is logged with its traceback and
    # ends the process with exit status 1.
    try:
        agent = init_agent()
        demo = create_ui(agent)
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    except Exception as e:
        logger.exception(f"App error: {e}")
        print(f"❌ Application error: {e}", file=sys.stderr)
        sys.exit(1)