CPS-Test-Mobile / app.py
Ali2206's picture
Update app.py
73810ec verified
raw
history blame
11.1 kB
import sys
import os
import pandas as pd
import json
import gradio as gr
from typing import List, Tuple, Union, Generator, BinaryIO, Dict, Any
import re
from datetime import datetime
import atexit
import torch.distributed as dist
import logging
# Root logging configuration: INFO level with timestamp, logger name and level in each record.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)
@atexit.register
def cleanup() -> None:
    """Destroy the default PyTorch process group at interpreter exit, if one is active.

    Registered with ``atexit`` (the decorator returns the function unchanged). Calling it
    again after a successful teardown is a no-op, because ``is_initialized()`` is then False.
    """
    # torch.distributed only exposes is_initialized()/destroy_process_group() on builds
    # that ship distributed support; probe is_available() first so exit never crashes.
    if dist.is_available() and dist.is_initialized():
        logger.info("Cleaning up PyTorch distributed process group")
        dist.destroy_process_group()
# Persistent storage layout: all caches and generated reports live under one root directory.
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)
model_cache_dir, tool_cache_dir, file_cache_dir, report_dir = (
    os.path.join(persistent_dir, subdir)
    for subdir in ("txagent_models", "tool_cache", "cache", "reports")
)
for cache_subdir in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(cache_subdir, exist_ok=True)
# Both Hugging Face cache variables point at the model cache directory; they are set here,
# before the TxAgent import below.
os.environ.update(HF_HOME=model_cache_dir, TRANSFORMERS_CACHE=model_cache_dir)
# Put the bundled `src/` tree first on the import path so `txagent` resolves from it.
_src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))
sys.path.insert(0, _src_dir)
from txagent.txagent import TxAgent
# Token budgets for agent calls. Chunk sizes are measured with the estimate_tokens() heuristic.
MAX_MODEL_TOKENS = 32_768  # passed to agent.run_gradio_chat as `max_token`
MAX_CHUNK_TOKENS = 8_192  # per-chunk budget before PROMPT_OVERHEAD is reserved
MAX_NEW_TOKENS = 2_048  # passed to agent.run_gradio_chat as `max_new_tokens`
PROMPT_OVERHEAD = 500  # held back from each chunk to leave room for the prompt template
def clean_response(text: str) -> str:
    """Tidy raw model output before it is shown in the report.

    Removes ``[...]`` spans (which may span several lines) and bare ``None`` words,
    drops characters outside an allow-list, folds runs of blank lines into a single
    blank line, and trims surrounding whitespace.

    Args:
        text: Raw text streamed from the agent. Empty (or ``None``) input yields ``""``.

    Returns:
        The cleaned text.
    """
    if not text:
        return ""
    # Non-greedy and DOTALL: each "[" is matched with the nearest "]", even across newlines.
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    # Keep punctuation that clinical notation depends on (e.g. "120/80", "500 mg/day",
    # "7.5%", "eGFR <60", "38.5°C"); the previous allow-list silently turned "120/80"
    # into "12080".
    text = re.sub(r"[^\n#\-\*\w\s\.,:;\(\)/%+'\"?!=<>°]+", "", text)
    # Collapse blank lines last, so lines emptied by the filters above are folded too.
    # Whitespace-only lines count as blank.
    text = re.sub(r"\n(?:[ \t]*\n){2,}", "\n\n", text)
    return text.strip()
def estimate_tokens(text: str, chars_per_token: float = 3.5) -> int:
    """Roughly estimate how many model tokens ``text`` occupies.

    Uses a characters-per-token heuristic instead of a real tokenizer. The ratio is
    rounded down and one is added, so even empty text costs one token.

    Args:
        text: Text to measure.
        chars_per_token: Average characters per token assumed by the heuristic
            (3.5 by default); must be positive.

    Returns:
        ``floor(len(text) / chars_per_token) + 1`` as an ``int``.

    Raises:
        ValueError: If ``chars_per_token`` is not positive.
    """
    if chars_per_token <= 0:
        raise ValueError("chars_per_token must be positive")
    # `//` with a float divisor produces a float; cast so the `-> int` contract holds.
    return int(len(text) // chars_per_token) + 1
def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
    """Flatten every sheet of an uploaded Excel workbook into pipe-separated text lines.

    Accepts any of the following:
    - a path string;
    - an ``os.PathLike``;
    - a dict carrying the path under ``'name'``;
    - any object whose ``name`` attribute is the path string.

    Each data row becomes ``"[<sheet>] cell | cell | ..."``. Empty cells are dropped, and
    rows with no non-empty cell are skipped. A sheet that fails to parse is logged and
    skipped as a whole.

    Args:
        file_obj: The uploaded workbook to read.

    Returns:
        All row lines joined with newlines; ``""`` if no row has content.

    Raises:
        ValueError: For an unsupported input type, a missing file, or an unreadable
            workbook. The underlying error is chained as ``__cause__``.
    """
    all_text: List[str] = []
    try:
        if isinstance(file_obj, dict) and 'name' in file_obj:
            file_path = file_obj['name']
        elif isinstance(file_obj, str):
            file_path = file_obj
        elif isinstance(file_obj, os.PathLike):
            # Checked before `.name`: pathlib.Path.name is only the basename.
            file_path = os.fspath(file_obj)
        elif isinstance(getattr(file_obj, 'name', None), str):
            # Objects exposing the path as `.name` (presumably tempfile wrappers returned
            # by some Gradio versions -- confirm) were previously rejected.
            file_path = file_obj.name
        else:
            raise ValueError("Unsupported file input type")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Temporary upload file not found at: {file_path}")
        # The context manager closes the workbook handle once every sheet has been read.
        with pd.ExcelFile(file_path) as xls:
            for sheet_name in xls.sheet_names:
                try:
                    # fillna must run before astype(str): astype renders NaN/NaT as the
                    # literal text "nan"/"NaT", which fillna can no longer replace.
                    df = xls.parse(sheet_name).fillna("").astype(str)
                    sheet_lines = []
                    # Iterate rows directly: df.apply(axis=1) on a header-only sheet returns
                    # an empty DataFrame, and iterating that yields the column names.
                    for row in df.itertuples(index=False, name=None):
                        line = " | ".join(cell for cell in row if cell.strip())
                        if line:
                            sheet_lines.append(f"[{sheet_name}] {line}")
                    # Extend only after the whole sheet succeeded, so a failing sheet adds nothing.
                    all_text.extend(sheet_lines)
                except Exception as e:
                    logger.warning(f"Could not parse sheet {sheet_name}: {e}")
                    continue
        return "\n".join(all_text)
    except Exception as e:
        raise ValueError(f"โŒ Error processing Excel file: {str(e)}") from e
def _split_oversized_line(line: str, budget: int) -> List[str]:
    """Cut one line into consecutive pieces whose ``estimate_tokens`` cost each fits ``budget``."""
    pieces = []
    while estimate_tokens(line) > budget:
        cost = estimate_tokens(line)
        # Proportional cut. Scaling by (budget - 1) absorbs the estimator's "+1" so the
        # piece fits; cutting at least one character guarantees progress.
        cut = max(1, int(len(line) * (budget - 1) / cost))
        pieces.append(line[:cut])
        line = line[cut:]
    pieces.append(line)
    return pieces


def split_text_into_chunks(text: str, max_tokens: Union[int, None] = None) -> List[str]:
    """Greedily pack the lines of ``text`` into newline-joined chunks within a token budget.

    Lines keep their original order. The per-line ``estimate_tokens`` costs of each chunk
    sum to at most the budget. A single line that alone exceeds the budget is hard-split
    into several pieces, possibly mid-word.

    Args:
        text: Text to split; lines are separated by ``"\\n"``.
        max_tokens: Per-chunk budget. Defaults to ``MAX_CHUNK_TOKENS - PROMPT_OVERHEAD``,
            leaving room for the prompt template around each chunk.

    Returns:
        The chunks in order, or ``[]`` when ``text`` has no non-whitespace content.

    Raises:
        ValueError: If the budget is smaller than 1.
    """
    budget = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD if max_tokens is None else max_tokens
    if budget < 1:
        raise ValueError("token budget must be at least 1")
    # Blank input previously produced [""], which still triggered a model call.
    if not text.strip():
        return []
    chunks: List[str] = []
    current: List[str] = []
    used = 0
    for line in text.split("\n"):
        for piece in _split_oversized_line(line, budget):
            cost = estimate_tokens(piece)
            # Flush the open chunk when this piece would push it over budget.
            if current and used + cost > budget:
                chunks.append("\n".join(current))
                current, used = [], 0
            current.append(piece)
            used += cost
    if current:
        chunks.append("\n".join(current))
    return chunks
def build_prompt_from_text(chunk: str) -> str:
    """Wrap one chunk of clinical record text in the analysis instructions for the agent.

    Args:
        chunk: Record text, embedded between two ``---`` delimiter lines.

    Returns:
        The prompt: a leading newline, a markdown heading, the four requested analysis
        points, the delimited chunk, and a closing instruction line ending in a newline.
    """
    instructions = "\n".join((
        "",
        "### Clinical Records Analysis",
        "Please analyze these clinical notes and provide:",
        "- Key diagnostic indicators",
        "- Current medications and potential issues",
        "- Recommended follow-up actions",
        "- Any inconsistencies or concerns",
        "---",
    ))
    closing = "---\nProvide a structured response with clear medical reasoning.\n"
    return f"{instructions}\n{chunk}\n{closing}"
def _require_named_entries(tool_name: str, entries: List[Any], noun: str) -> None:
    """Raise ValueError unless every entry is a dict carrying a ``'name'`` key."""
    for entry in entries:
        if not isinstance(entry, dict) or 'name' not in entry:
            raise ValueError(f"Invalid tool format in {tool_name}: each {noun} must be a dict with a 'name' key, got {entry}")


def validate_tool_file(tool_name: str, tool_path: str) -> None:
    """Validate the structure of a tool JSON file.

    Accepted shapes:
    - a list of dicts, each with a ``'name'`` key;
    - a dict whose ``'tools'`` field is such a list;
    - a dict with a ``'name'`` key and no ``'tools'`` field.

    Args:
        tool_name: Label used in log and error messages.
        tool_path: Path of the JSON file to check.

    Raises:
        FileNotFoundError: If ``tool_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If the JSON does not have one of the accepted shapes.
        Every failure is logged at ERROR level before being re-raised.
    """
    try:
        if not os.path.exists(tool_path):
            raise FileNotFoundError(f"Tool file not found: {tool_path}")
        # JSON is UTF-8 by specification; don't depend on the platform locale encoding.
        with open(tool_path, 'r', encoding='utf-8') as f:
            tool_data = json.load(f)
        # Tool files can be large: log the full dump only at DEBUG, using lazy %-args
        # so the repr is never built when DEBUG is disabled.
        logger.debug("Contents of %s (%s): %s", tool_name, tool_path, tool_data)
        if isinstance(tool_data, list):
            _require_named_entries(tool_name, tool_data, "item")
        elif isinstance(tool_data, dict):
            if 'tools' in tool_data:
                if not isinstance(tool_data['tools'], list):
                    raise ValueError(f"'tools' field in {tool_name} must be a list, got {type(tool_data['tools'])}")
                _require_named_entries(tool_name, tool_data['tools'], "tool")
            elif 'name' not in tool_data:
                raise ValueError(f"Invalid tool format in {tool_name}: dict must have a 'name' key or 'tools' field, got {tool_data}")
        else:
            raise ValueError(f"Invalid tool file {tool_name}: must be a list or dict, got {type(tool_data)}")
    except Exception as e:
        logger.error(f"Error validating tool file {tool_name} ({tool_path}): {str(e)}")
        raise
def _tooluniverse_data_dir() -> str:
    """Return the ``data`` directory inside the installed ``tooluniverse`` package.

    Falls back to the previously hard-coded pyenv 3.10.17 site-packages location when
    the package cannot be located, so behaviour in that environment is unchanged.
    """
    import importlib.util  # local import: needed only once, at agent start-up

    fallback = '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data'
    try:
        # For a top-level name, find_spec locates the package without importing it.
        spec = importlib.util.find_spec("tooluniverse")
    except (ImportError, ValueError):
        spec = None
    locations = list(spec.submodule_search_locations or []) if spec is not None else []
    if not locations:
        return fallback
    return os.path.join(locations[0], "data")


def init_agent() -> TxAgent:
    """Build and initialise the TxAgent used by the UI.

    Steps:
    1. Ensure ``new_tool.json`` exists in ``tool_cache_dir``, writing a default
       single-tool definition if it is missing.
    2. Validate every tool file with ``validate_tool_file``.
    3. Construct the agent and call ``init_model()`` on it.

    Returns:
        The initialised agent.

    Raises:
        Whatever ``validate_tool_file`` raises for a missing or malformed tool file, or
        whatever TxAgent construction / ``init_model()`` raises. The latter is logged
        with a traceback before being re-raised.
    """
    new_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    logger.info(f"Checking for tool file at: {new_tool_path}")
    # Create default tool file if it doesn't exist
    if not os.path.exists(new_tool_path):
        default_tool = {
            "name": "new_tool",
            "description": "Default tool configuration",
            "version": "1.0",
            "tools": [
                {"name": "dummy_tool", "description": "Dummy tool for testing", "version": "1.0"}
            ]
        }
        logger.info(f"Creating default tool file at: {new_tool_path}")
        with open(new_tool_path, 'w', encoding='utf-8') as f:
            json.dump(default_tool, f)
    # Resolve the bundled tooluniverse definitions from the installed package rather than
    # from a path pinned to one specific interpreter version.
    data_dir = _tooluniverse_data_dir()
    tool_files_dict = {
        'opentarget': os.path.join(data_dir, 'opentarget_tools.json'),
        'fda_drug_label': os.path.join(data_dir, 'fda_drug_labeling_tools.json'),
        'special_tools': os.path.join(data_dir, 'special_tools.json'),
        'monarch': os.path.join(data_dir, 'monarch_tools.json'),
        'new_tool': new_tool_path,
    }
    # Fail fast on a missing or malformed tool file before the agent is constructed.
    # A distinct loop variable avoids shadowing new_tool_path.
    for tool_name, file_path in tool_files_dict.items():
        validate_tool_file(tool_name, file_path)
    # Initialize TxAgent
    try:
        logger.info(f"Initializing TxAgent with tool_files_dict: {tool_files_dict}")
        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict=tool_files_dict,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100
        )
        logger.info("TxAgent initialized, calling init_model")
        agent.init_model()
        logger.info("TxAgent model initialized successfully")
        return agent
    except Exception as e:
        logger.error(f"Error initializing TxAgent: {str(e)}", exc_info=True)
        raise
def _collect_chat_text(agent: TxAgent, prompt: str) -> str:
    """Run one agent chat for ``prompt`` (empty history, call_agent=False) and join its streamed output."""
    collected = ""
    for res in agent.run_gradio_chat(
        message=prompt, history=[], temperature=0.2,
        max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
        call_agent=False, conversation=[]
    ):
        # NOTE(review): assumes each streamed item is a str delta or an object with a str
        # `.content`, and that items are not cumulative -- confirm against
        # TxAgent.run_gradio_chat.
        collected += res if isinstance(res, str) else res.content
    return collected


def stream_report(agent: TxAgent, input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
    """Analyse an uploaded workbook chunk by chunk, streaming progress to the UI.

    Yields ``(markdown, report_path, summary)`` triples:
    - after each chunk: the accumulated per-chunk analyses, ``None``, ``""``;
    - once at the end: the analyses plus a final summary, the path of the written markdown
      report in ``report_dir``, and the cleaned summary text.
    On any failure it yields a single error message, ``None``, ``""`` and stops.

    Args:
        agent: Initialised agent used for every model call.
        input_file: Upload accepted by ``extract_text_from_excel``; ``None`` if nothing was uploaded.
        full_output: Current UI state value; not read here.
    """
    accumulated_text = ""
    try:
        if input_file is None:
            yield "โŒ Please upload a valid Excel file.", None, ""
            return
        try:
            text = extract_text_from_excel(input_file)
            chunks = split_text_into_chunks(text)
        except Exception as e:
            yield f"โŒ {str(e)}", None, ""
            return
        # Without this guard an empty workbook is still sent to the model as an empty
        # chunk and/or an empty summary request.
        if not text.strip() or not chunks:
            yield "โŒ No readable rows were found in the uploaded Excel file.", None, ""
            return
        for part_no, chunk in enumerate(chunks, start=1):
            cleaned = clean_response(_collect_chat_text(agent, build_prompt_from_text(chunk)))
            accumulated_text += f"\n\n๐Ÿ“„ Analysis Part {part_no}:\n{cleaned}"
            yield accumulated_text, None, ""
        # NOTE(review): accumulated_text grows by up to MAX_NEW_TOKENS per chunk, so with
        # many chunks this summary prompt may exceed MAX_MODEL_TOKENS.
        summary_prompt = f"Please summarize this analysis:\n\n{accumulated_text}"
        cleaned = clean_response(_collect_chat_text(agent, summary_prompt))
        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
        # Explicit UTF-8, so non-ASCII model output cannot fail on a non-UTF-8 locale.
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(f"# Clinical Analysis Report\n\n{cleaned}")
        yield f"{accumulated_text}\n\n๐Ÿ“Š Final Summary:\n{cleaned}", report_path, cleaned
    except Exception as e:
        logger.error(f"Processing error in stream_report: {str(e)}", exc_info=True)
        yield f"โŒ Processing error: {str(e)}", None, ""
def create_ui(agent: TxAgent) -> gr.Blocks:
    """Build the Gradio interface wired to ``agent``.

    Layout:
    - an Excel upload box and an Analyze button;
    - a markdown pane that streams the analysis;
    - a download slot that stays hidden until a report file has been written.

    Args:
        agent: Initialised agent that every analysis request is sent to.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    def run_analysis(input_file, state):
        # The click handler receives only the two listed inputs. Binding `agent` here fixes
        # the previous wiring, which passed stream_report directly and left its third
        # (agent) parameter unfilled.
        for markdown, report_path, summary in stream_report(agent, input_file, state):
            # A value sent to a hidden component does not reveal it, so toggle visibility
            # explicitly once a report path exists.
            yield markdown, gr.update(value=report_path, visible=report_path is not None), summary

    with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 900px !important}") as demo:
        gr.Markdown("""# Clinical Records Analyzer""")
        with gr.Row():
            file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Row():
            with gr.Column(scale=2):
                report_output = gr.Markdown()
            with gr.Column(scale=1):
                report_file = gr.File(label="Download Report", visible=False)
        full_output = gr.State()
        analyze_btn.click(
            fn=run_analysis,
            inputs=[file_upload, full_output],
            outputs=[report_output, report_file, full_output]
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the agent and UI, then serve on all interfaces at port 7860.
    # Any start-up failure is logged with its traceback, echoed to stderr, and turned into
    # exit status 1.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    try:
        agent = init_agent()
        demo = create_ui(agent)
        logger.info("Launching Gradio UI")
        demo.launch(**launch_options)
    except Exception as e:
        message = f"Application error: {e}"
        logger.exception(message)
        print(message, file=sys.stderr)
        raise SystemExit(1)