# CPS-Test-Mobile / app.py
import sys
import os
import pandas as pd
import pdfplumber
import gradio as gr
from typing import List, Dict
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import re
import logging
import torch
import gc
from diskcache import Cache
from transformers import AutoTokenizer
from functools import lru_cache
from difflib import SequenceMatcher
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Constants
MAX_TOKENS = 1800
BATCH_SIZE = 1
MAX_WORKERS = 2
CHUNK_SIZE = 5
MODEL_MAX_TOKENS = 131072
MAX_TEXT_LENGTH = 500000
MAX_ROWS_TO_PROCESS = 1000 # Limit for Excel/CSV rows
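# Of these, MAX_TOKENS (per-chunk cap), MAX_TEXT_LENGTH (raw page-text cap) and
# MAX_ROWS_TO_PROCESS are used below; BATCH_SIZE, MAX_WORKERS, CHUNK_SIZE and
# MODEL_MAX_TOKENS (the 128K context of the Llama 3.1 backbone) are currently unreferenced in this file.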
# Persistent directory setup
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
for d in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(d, exist_ok=True)  # ensure every cache/report directory exists before it is referenced
os.environ.update({
    "HF_HOME": model_cache_dir,
    "TOKENIZERS_PARALLELISM": "false",
})
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
from txagent.txagent import TxAgent
# Initialize cache
cache = Cache(file_cache_dir, size_limit=10 * 1024**3)
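# `cache` provides a 10 GiB on-disk store, but nothing in this file reads from or
# writes to it yet; a hedged memoization sketch appears after process_file_cached below.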
@lru_cache(maxsize=1)
def get_tokenizer():
    # Lazy singleton: the tokenizer is loaded once and reused across calls.
    return AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
def sanitize_utf8(text: str) -> str:
    return text.encode("utf-8", "ignore").decode("utf-8")
def file_hash(path: str) -> str:
    hash_md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
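# file_hash is defined but never called in this file; it is the natural cache key for
# the memoization sketch further down (MD5 serves only as a content fingerprint, not for security).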
def extract_pdf_page(page, tokenizer, max_tokens=MAX_TOKENS) -> List[str]:
    """Extract text from a single PDF page and split it into chunks of at most max_tokens tokens."""
    try:
        text = page.extract_text() or ""
        text = sanitize_utf8(text)
        if len(text) > MAX_TEXT_LENGTH // 10:
            text = text[:MAX_TEXT_LENGTH // 10]
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) > max_tokens:
            chunks = []
            current_chunk = []
            current_length = 0
            for token in tokens:
                if current_length + 1 > max_tokens:
                    chunks.append(tokenizer.decode(current_chunk))
                    current_chunk = [token]
                    current_length = 1
                else:
                    current_chunk.append(token)
                    current_length += 1
            if current_chunk:
                chunks.append(tokenizer.decode(current_chunk))
            return chunks
        return [text]
    except Exception as e:
        logger.warning(f"Error extracting page {page.page_number}: {str(e)}")
        return []
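# A minimal alternative sketch of the chunking loop above: slicing the token list in
# fixed windows produces the same <= max_tokens chunks with less bookkeeping. The helper
# name `_chunk_tokens` is illustrative only and is not used by the app.
def _chunk_tokens(tokens: List[int], tokenizer, max_tokens: int = MAX_TOKENS) -> List[str]:
    # Decode each max_tokens-sized window of token ids back to text.
    return [tokenizer.decode(tokens[i:i + max_tokens]) for i in range(0, len(tokens), max_tokens)]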
def extract_all_pages(file_path: str) -> List[str]:
    try:
        tokenizer = get_tokenizer()
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
            if total_pages == 0:
                return ["PDF appears to be empty"]
            results = []
            for i in range(0, min(total_pages, 50)):  # Limit to first 50 pages
                try:
                    page = pdf.pages[i]
                    chunks = extract_pdf_page(page, tokenizer)
                    for chunk in chunks:
                        results.append(f"=== Page {i+1} ===\n{chunk}")
                except Exception as e:
                    logger.warning(f"Error processing page {i+1}: {str(e)}")
                    continue
            return results if results else ["Could not extract text from PDF"]
    except Exception as e:
        logger.error(f"PDF processing error: {e}")
        return [f"PDF processing error: {str(e)}"]
def excel_to_json(file_path: str) -> List[Dict]:
    # openpyxl reads .xlsx; the xlrd fallback only handles legacy .xls in current releases.
    engines = ['openpyxl', 'xlrd']
    for engine in engines:
        try:
            with pd.ExcelFile(file_path, engine=engine) as excel_file:
                sheets = excel_file.sheet_names
                if not sheets:
                    return [{"error": "No sheets found"}]
                results = []
                for sheet_name in sheets[:3]:  # Limit to first 3 sheets
                    try:
                        df = pd.read_excel(
                            excel_file,
                            sheet_name=sheet_name,
                            header=None,
                            dtype=str,
                            na_filter=False,
                            nrows=MAX_ROWS_TO_PROCESS  # Limit rows
                        )
                        if not df.empty:
                            rows = df.head(MAX_ROWS_TO_PROCESS).values.tolist()
                            results.append({
                                "filename": os.path.basename(file_path),
                                "sheet": sheet_name,
                                "rows": rows,
                                "type": "excel"
                            })
                    except Exception as e:
                        logger.warning(f"Error processing sheet {sheet_name}: {str(e)}")
                        continue
                return results if results else [{"error": "No readable data found"}]
        except Exception as e:
            logger.warning(f"Excel engine {engine} failed: {str(e)}")
            continue
    return [{"error": "Could not process Excel file with any engine"}]
def csv_to_json(file_path: str) -> List[Dict]:
    try:
        df = pd.read_csv(
            file_path,
            header=None,
            dtype=str,
            encoding_errors='replace',
            on_bad_lines='skip',
            nrows=MAX_ROWS_TO_PROCESS  # Limit rows
        )
        if df.empty:
            return [{"error": "CSV file is empty"}]
        return [{
            "filename": os.path.basename(file_path),
            "rows": df.values.tolist(),
            "type": "csv"
        }]
    except Exception as e:
        logger.error(f"CSV processing error: {e}")
        return [{"error": f"CSV processing error: {str(e)}"}]
def process_file_cached(file_path: str, file_type: str) -> List[Dict]:
    # Note: despite the name, this function does not consult the diskcache `cache`
    # object; see the memoization sketch below for one way that could be wired in.
    try:
        logger.info(f"Processing {file_type} file: {os.path.basename(file_path)}")
        if file_type == "pdf":
            chunks = extract_all_pages(file_path)
            return [{
                "filename": os.path.basename(file_path),
                "content": chunk,
                "type": "pdf"
            } for chunk in chunks]
        elif file_type in ["xls", "xlsx"]:
            return excel_to_json(file_path)
        elif file_type == "csv":
            return csv_to_json(file_path)
        return [{"error": f"Unsupported file type: {file_type}"}]
    except Exception as e:
        logger.error(f"Error processing file: {e}")
        return [{"error": f"Error processing file: {str(e)}"}]
def clean_response(text: str) -> str:
    if not text:
        return ""
    # Strip any [bracketed] spans and bare "None" tokens, then collapse whitespace.
    patterns = [
        (re.compile(r"\[.*?\]|\bNone\b", re.IGNORECASE), ""),
        (re.compile(r"\s+"), " "),
    ]
    for pattern, repl in patterns:
        text = pattern.sub(repl, text)
    return text.strip()
@lru_cache(maxsize=1)
def init_agent():
    logger.info("Initializing model...")
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": os.path.join(tool_cache_dir, "new_tool.json")},
        force_finish=True,
        enable_checker=False,
        step_rag_num=4,
        seed=100,
    )
    agent.init_model()
    logger.info("Agent Ready")
    return agent
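# The UI below calls agent.run_quick_summary(prompt, 0.2, 256, 500). That method and
# its positional arguments (presumably a temperature plus token limits) come from the
# bundled txagent package; this interpretation is an assumption, not documented API.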
def create_ui(agent):
    PROMPT_TEMPLATE = """
Analyze this patient record excerpt for missed diagnoses (limit response to 500 tokens):
{chunk}
"""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Analysis", height=500, type="messages")
                msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
                send_btn = gr.Button("Analyze", variant="primary")
                file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="single")
            with gr.Column(scale=1):
                final_summary = gr.Markdown("## Summary")
                status = gr.Textbox(label="Status", interactive=False)

        def analyze(message: str, history: List[Dict], file_obj) -> tuple:
            try:
                if not file_obj:
                    return history, "Please upload a file first", "No file uploaded"
                file_path = file_obj.name
                file_type = os.path.splitext(file_path)[-1].lower().replace(".", "")
                history.append({"role": "user", "content": message})

                # Process the uploaded file (always returns at least one entry)
                processed = process_file_cached(file_path, file_type)
                if "error" in processed[0]:
                    history.append({"role": "assistant", "content": processed[0]["error"]})
                    return history, processed[0]["error"], "File processing failed"

                # Prepare text chunks from PDF pages or spreadsheet rows
                chunks = []
                for item in processed:
                    if "content" in item:
                        chunks.append(item["content"])
                    elif "rows" in item:
                        rows_text = "\n".join([", ".join(map(str, row)) for row in item["rows"][:100]])
                        chunks.append(f"=== {item.get('sheet', 'Data')} ===\n{rows_text}")
                if not chunks:
                    history.append({"role": "assistant", "content": "No processable content found."})
                    return history, "No processable content found", "Content extraction failed"

                # Analyze each chunk (first 5 only)
                responses = []
                for i, chunk in enumerate(chunks[:5]):
                    try:
                        prompt = PROMPT_TEMPLATE.format(chunk=chunk[:5000])
                        # Positional arguments kept from the original code; their exact meaning
                        # (temperature / token limits) is defined by TxAgent.run_quick_summary.
                        response = agent.run_quick_summary(prompt, 0.2, 256, 500)
                        cleaned = clean_response(response)
                        if cleaned:
                            responses.append(f"Analysis {i+1}:\n{cleaned}")
                    except Exception as e:
                        logger.warning(f"Error analyzing chunk {i+1}: {str(e)}")
                        continue

                if not responses:
                    history.append({"role": "assistant", "content": "No valid analysis generated."})
                    return history, "No valid analysis generated", "Analysis failed"

                summary = "\n\n".join(responses)
                history.append({"role": "assistant", "content": summary})
                return history, "Analysis completed", "Success"
            except Exception as e:
                logger.error(f"Analysis error: {e}")
                history.append({"role": "assistant", "content": f"Error: {str(e)}"})
                return history, f"Error: {str(e)}", "Failed"
            finally:
                torch.cuda.empty_cache()
                gc.collect()

        send_btn.click(
            analyze,
            inputs=[msg_input, chatbot, file_upload],
            outputs=[chatbot, final_summary, status]
        )
        msg_input.submit(
            analyze,
            inputs=[msg_input, chatbot, file_upload],
            outputs=[chatbot, final_summary, status]
        )
    return demo
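# Optional hardening (an assumption, not in the original app): calling demo.queue()
# before launch() would serialize GPU-bound requests so concurrent users do not
# contend for VRAM while the agent is generating.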
if __name__ == "__main__":
    try:
        agent = init_agent()
        demo = create_ui(agent)
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        logger.error(f"Fatal error: {e}")
        raise