Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import json | |
| import re | |
| import logging | |
| from fastapi import FastAPI, HTTPException, UploadFile, File | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from typing import List, Dict, Optional | |
| from datetime import datetime | |
| from pydantic import BaseModel | |
| import markdown | |
| import PyPDF2 | |
# Configure root logging: INFO threshold, timestamped records that carry
# the logger name and severity.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Named logger shared by the startup hook and the request handlers below.
logger = logging.getLogger("TxAgentAPI")
# Put the "src" directory that sits next to this file at the front of the
# import search path.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# TxAgent is imported only after the path tweak above.
try:
    from txagent.txagent import TxAgent
except ImportError as e:
    # Log the reason, then re-raise: the module cannot work without TxAgent.
    logger.error(f"Failed to import TxAgent: {str(e)}")
    raise
# FastAPI application object with the metadata shown in the generated docs.
app = FastAPI(
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
    version="2.0.0",
)

# CORS: every origin, HTTP method and request header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination - confirm it is intended for this deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request models
class ChatRequest(BaseModel):
    # Body accepted by the chat handler. Documented with comments rather than
    # a docstring so the generated schema description is left unchanged.

    # User message forwarded to the agent.
    message: str
    # Sampling temperature forwarded to the agent.
    temperature: float = 0.7
    # Generation length limit forwarded to the agent.
    max_new_tokens: int = 512
    # Earlier conversation turns, forwarded to the agent as-is (may be None).
    history: Optional[List[Dict]] = None
    # Output format selector. Options: raw, clean, structured, html
    format: Optional[str] = "clean"
# Response cleaning functions
def clean_text_response(text: str) -> str:
    """Normalise whitespace and strip bold markers from a response.

    Runs of blank (or whitespace-only) lines collapse to a single blank line,
    repeated spaces collapse to one space, the ``**`` and ``__`` markers are
    removed, and the result is trimmed at both ends.
    """
    tidied = re.sub(r'\n\s*\n', '\n\n', text)
    tidied = re.sub(r'[ ]+', ' ', tidied)
    for marker in ("**", "__"):
        tidied = tidied.replace(marker, "")
    return tidied.strip()
def structure_medical_response(text: str) -> Dict:
    """Structure medical content into an overview and per-type line lists.

    Everything before the first ``"Type 1 Diabetes:"`` heading becomes the
    overview. From that heading on, each ``"Type <digit> Diabetes:"`` heading
    opens a section whose non-blank, stripped lines are stored under the
    lower-cased heading without its colon (e.g. ``"type 1 diabetes"``). If a
    heading occurs more than once, the last occurrence wins.

    Args:
        text: Raw response text.

    Returns:
        A dict with the keys ``overview`` (cleaned text), ``types`` (heading
        to list of lines), and ``symptoms`` / ``notes``, which this function
        leaves at their empty defaults.
    """
    result = {"overview": "", "symptoms": [], "types": {}, "notes": ""}

    overview_end = text.find("Type 1 Diabetes:")
    if overview_end == -1:
        # str.find signals "not found" with -1. Slicing with it would chop the
        # last character off the overview and parse that character as section
        # text, so the whole input is treated as overview instead.
        result["overview"] = clean_text_response(text)
        return result

    result["overview"] = clean_text_response(text[:overview_end])

    # With one capturing group, re.split yields
    # [prefix, heading, body, heading, body, ...]; pairing by position keeps a
    # body that happens to start with "Type " from being mistaken for a heading.
    parts = re.split(r'(Type \d Diabetes:)', text[overview_end:])
    for heading, body in zip(parts[1::2], parts[2::2]):
        type_key = heading.replace(":", "").strip().lower()
        result["types"][type_key] = [
            line.strip() for line in body.split('\n') if line.strip()
        ]
    return result
# Module-wide agent handle; stays None until startup_event succeeds.
agent = None


async def startup_event():
    """Create the TxAgent instance and load its model into the global ``agent``.

    Any failure is logged and leaves ``agent`` as None, so the application can
    still start; the handlers check for None before using the agent.
    """
    # NOTE(review): no registration decorator is visible on this coroutine in
    # this view - confirm it is actually hooked into application startup.
    global agent
    try:
        logger.info("Initializing TxAgent...")

        # Forward a Hugging Face token only when one is set in the environment.
        hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if hf_token:
            model_kwargs = {"token": hf_token}
        else:
            model_kwargs = {}

        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={},
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100,
            **model_kwargs,
        )
        agent.init_model()
        logger.info("TxAgent initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize agent: {str(e)}")
        # Allow app to start, but endpoints will return errors if agent is None
        agent = None
async def chat_endpoint(request: ChatRequest):
    """Handle chat conversations with formatting options.

    Sends the message to the agent and returns the reply in the requested
    format. An unknown or missing format falls back to "clean", and the
    ``format`` field of the response names the format that was actually
    applied.

    Raises:
        HTTPException: 503 when the agent is not initialised; 500 when the
            agent call or the formatting step fails.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view - confirm it is registered with the app.
    if agent is None:
        logger.error("TxAgent not initialized")
        raise HTTPException(status_code=503, detail="TxAgent service unavailable")

    # One formatter per supported format; only the selected one is executed,
    # so no work is spent on representations that are never returned.
    formatters = {
        "raw": lambda text: text,
        "clean": clean_text_response,
        "structured": structure_medical_response,
        "html": markdown.markdown,
    }

    try:
        logger.info(f"Chat request received (format: {request.format})")
        raw_response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
        )

        # Report the format really used instead of echoing an unsupported one.
        applied_format = request.format if request.format in formatters else "clean"
        response_content = formatters[applied_format](raw_response)

        return JSONResponse({
            "status": "success",
            "format": applied_format,
            "response": response_content,
            "timestamp": datetime.now().isoformat(),
            "available_formats": list(formatters.keys()),
        })
    except Exception as e:
        logger.error(f"Chat error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def upload_file(file: UploadFile = File(...)):
    """Handle file uploads and process with TxAgent.

    A file whose name ends in ".pdf" (any letter case) has its text extracted
    page by page with PyPDF2. Any other upload is decoded as UTF-8, with
    undecodable bytes ignored. The first 10000 characters are sent to the
    agent and the cleaned reply is returned.

    Raises:
        HTTPException: 503 when the agent is not initialised; 500 when
            reading, text extraction or the agent call fails.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view - confirm it is registered with the app.
    if agent is None:
        logger.error("TxAgent not initialized")
        raise HTTPException(status_code=503, detail="TxAgent service unavailable")

    try:
        logger.info(f"File upload received: {file.filename}")

        # The filename can be absent and extensions come in any case (".PDF"),
        # so normalise before testing the suffix.
        is_pdf = (file.filename or "").lower().endswith(".pdf")
        if is_pdf:
            pdf_reader = PyPDF2.PdfReader(file.file)
            # extract_text() may yield nothing for a page; treat that as "".
            content = "".join(page.extract_text() or "" for page in pdf_reader.pages)
        else:
            raw_bytes = await file.read()
            content = raw_bytes.decode('utf-8', errors='ignore')

        # Only the first 10000 characters are included in the prompt.
        message = f"Analyze the following medical document content:\n\n{content[:10000]}"
        raw_response = agent.chat(
            message=message,
            history=[],
            temperature=0.7,
            max_new_tokens=512,
        )

        # Uploads always answer in the "clean" format, so only that
        # representation is computed.
        return JSONResponse({
            "status": "success",
            "format": "clean",
            "response": clean_text_response(raw_response),
            "timestamp": datetime.now().isoformat(),
            "available_formats": ["raw", "clean", "structured", "html"],
        })
    except Exception as e:
        logger.error(f"File upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release the underlying upload file object on every path.
        file.file.close()
async def service_status():
    """Report whether the agent is loaded, plus version and format info."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view - confirm it is registered with the app.
    if agent:
        state = "running"
        model = agent.model_name
    else:
        state = "failed"
        model = "not loaded"

    return {
        "status": state,
        "version": "2.0.0",
        "model": model,
        "formats_available": ["raw", "clean", "structured", "html"],
        "timestamp": datetime.now().isoformat(),
    }