pdarleyjr committed on
Commit
7d96c7a
·
verified ·
1 Parent(s): fd8c861

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +287 -0
app.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, status
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse
4
+ from pydantic import BaseModel
5
+ from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoConfig
6
+ import torch
7
+ import os
8
+ import sys
9
+ import traceback
10
+ from typing import Optional, Dict, Any
11
+ from accelerate import Accelerator
12
+ import time
13
+ import psutil
14
+ from loguru import logger
15
+
16
# Production logging: a single stderr sink at INFO level with a fixed layout
# (timestamp | level | module:function:line - message).
_LOG_FORMAT = (
    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
    "<level>{level: <8}</level> | "
    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
    "<level>{message}</level>"
)

logger.remove()  # drop the default handler before installing our own
logger.add(sys.stderr, format=_LOG_FORMAT, level="INFO")
23
+
24
# Application object. Interactive documentation is served at /documentation
# (Swagger UI) and /redoc (ReDoc).
app = FastAPI(
    title="Clinical Report Generator API",
    description="Production API for generating clinical report summaries using Flan-T5",
    version="1.0.0",
    docs_url="/documentation",
    redoc_url="/redoc",
)

# Cross-origin policy: only the GitHub Pages front end may call this API, and
# only with the two HTTP methods the endpoints actually expose.
_CORS_SETTINGS = {
    "allow_origins": ["https://pdarleyjr.github.io"],  # GitHub Pages domain
    "allow_credentials": True,
    "allow_methods": ["POST", "GET"],
    "allow_headers": ["*"],
    "max_age": 3600,  # cache preflight requests
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
42
+
43
class ModelManager:
    """Own the tokenizer/model pair and coordinate loading them.

    Attributes are ``None`` until a load succeeds. ``load_lock`` is a plain
    boolean flag that is set for the duration of a load so that a second
    caller arriving meanwhile is turned away instead of starting another load.
    """

    # Repository the tokenizer and model weights are read from.
    MODEL_REPO = "pdarleyjr/iplc-t5-clinical"
    # Repository the model configuration is read from.
    # NOTE(review): this differs from MODEL_REPO; confirm the base config is
    # really meant to override the checkpoint's own configuration.
    CONFIG_REPO = "google/flan-t5-base"
    _BYTES_PER_GB = 1024 ** 3

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.accelerator = Accelerator()
        self.last_load_time = None
        self.load_lock = False

    async def load_model(self) -> bool:
        """Load the tokenizer and model without blocking the event loop.

        The blocking work runs in the default executor; while it is in
        flight ``load_lock`` is set, so concurrent callers return ``False``
        immediately.

        Returns:
            ``True`` when the load succeeded. ``False`` when another load is
            already in progress or when loading raised; in the latter case
            the error is logged and ``model``/``tokenizer`` are reset to
            ``None``.
        """
        # Function-scope import keeps this block self-contained.
        import asyncio

        if self.load_lock:
            logger.warning("Model load already in progress")
            return False

        # No await sits between the check above and this assignment, so the
        # test-and-set cannot be interleaved with another coroutine.
        self.load_lock = True
        try:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._load_blocking)
            return True
        except Exception:
            logger.exception("Error loading model")
            self.model = None
            self.tokenizer = None
            return False
        finally:
            self.load_lock = False

    def _load_blocking(self) -> None:
        """Do the actual (blocking) load and publish the result on success."""
        logger.info("Starting model and tokenizer loading process...")
        self._log_memory("System memory", "CUDA memory")

        logger.info("Initializing Flan-T5-base tokenizer...")
        # NOTE(review): use_fast is normally an AutoTokenizer option; confirm
        # it has any effect when calling T5Tokenizer directly.
        tokenizer = T5Tokenizer.from_pretrained(
            self.MODEL_REPO,
            use_fast=True,
            model_max_length=512,
        )
        logger.success("Flan-T5-base tokenizer loaded successfully")

        logger.info("Fetching model configuration...")
        config = AutoConfig.from_pretrained(
            self.CONFIG_REPO,
            trust_remote_code=False,
        )
        logger.success("Model configuration loaded successfully")

        logger.info("Loading Flan-T5-base model (this may take a few minutes)...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        # Half precision only on GPU; CPU keeps float32.
        model = T5ForConditionalGeneration.from_pretrained(
            self.MODEL_REPO,
            config=config,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
        ).to(device)
        logger.success("Model loaded successfully")

        model = self.accelerator.prepare_model(model)
        logger.success("Model prepared with accelerator")

        self._log_memory("Final memory usage", "Final CUDA memory")

        # Publish only once everything is ready, so is_loaded() never reports
        # a half-initialised tokenizer/model pair.
        self.tokenizer = tokenizer
        self.model = model
        self.last_load_time = time.time()

    def _log_memory(self, system_label: str, cuda_label: str) -> None:
        """Log system RAM usage and, when CUDA is available, allocated GPU memory."""
        memory = psutil.virtual_memory()
        logger.info(
            f"{system_label}: {memory.percent}% used, "
            f"{memory.available / self._BYTES_PER_GB:.2f}GB available"
        )
        if torch.cuda.is_available():
            logger.info(
                f"{cuda_label}: "
                f"{torch.cuda.memory_allocated() / self._BYTES_PER_GB:.2f}GB allocated"
            )

    def is_loaded(self) -> bool:
        """Return ``True`` when both the model and the tokenizer are set."""
        return self.model is not None and self.tokenizer is not None

    def get_load_time(self) -> Optional[float]:
        """Return the ``time.time()`` stamp of the last successful load, or ``None``."""
        return self.last_load_time
127
# Single shared manager instance used by the endpoints and lifecycle hooks.
model_manager: ModelManager = ModelManager()
129
+
130
class PredictRequest(BaseModel):
    """Body accepted by the prediction endpoint.

    Attributes:
        text: The text to be summarized.
    """

    text: str

    class Config:
        # Example payload attached to the model's schema.
        # NOTE(review): `schema_extra` is the pydantic v1 spelling; confirm
        # which pydantic major version this service pins.
        schema_extra = dict(
            example=dict(
                text="evaluation type: initial. primary diagnosis: F84.0. severity: mild. primary language: english",
            ),
        )
140
+
141
def _error_response(status_code: int, message: str) -> JSONResponse:
    """Build the ``{"success": False, "error": ...}`` envelope for a failed call."""
    return JSONResponse(
        status_code=status_code,
        content={"success": False, "error": message},
    )


def _generate_summary(text: str) -> str:
    """Tokenize *text* with the ``summarize: `` prefix, run beam search, decode."""
    tokenizer = model_manager.tokenizer
    model = model_manager.model

    input_ids = tokenizer.encode(
        "summarize: " + text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )
    # Put the inputs on whatever device the model's parameters live on.
    input_ids = input_ids.to(next(model.parameters()).device)

    with torch.no_grad(), model_manager.accelerator.autocast():
        # `temperature` is intentionally not passed: without do_sample=True it
        # is a sampling-only flag and merely triggers a warning during beam
        # search (assumes the checkpoint's generation config leaves do_sample
        # at its default - TODO confirm).
        outputs = model.generate(
            input_ids,
            max_length=512,
            num_beams=5,
            no_repeat_ngram_size=3,
            length_penalty=2.0,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post(
    "/predict",
    response_model=Dict[str, Any],
    status_code=status.HTTP_200_OK,
    responses={
        400: {"description": "Empty input text"},
        500: {"description": "Internal server error"},
        503: {"description": "Service unavailable - model loading"},
    },
)
async def predict(request: PredictRequest) -> JSONResponse:
    """Generate a clinical report summary.

    Args:
        request: Payload whose ``text`` field is summarized.

    Returns:
        On success, a 200 response with ``success``, ``data`` (the summary),
        ``error`` (``None``) and ``metrics.process_time`` in seconds.
        On failure, a ``{"success": False, "error": ...}`` body with status
        400 (blank input), 503 (model not available or CUDA out of memory)
        or 500 (any other error). Exceptions are logged, not raised.
    """
    start_time = time.time()

    try:
        # A blank prompt would only make the model invent a report.
        if not request.text.strip():
            return _error_response(
                status.HTTP_400_BAD_REQUEST,
                "Input text must not be empty.",
            )

        # Load lazily if the model is not in memory yet.
        if not model_manager.is_loaded():
            logger.warning("Model not loaded, attempting to load...")
            if not await model_manager.load_model():
                return _error_response(
                    status.HTTP_503_SERVICE_UNAVAILABLE,
                    "Model is initializing. Please try again in a few moments.",
                )

        # NOTE(review): generation runs synchronously inside this coroutine,
        # so other requests wait until it finishes.
        summary = _generate_summary(request.text)

        process_time = time.time() - start_time
        logger.info(f"Summary generated in {process_time:.2f} seconds")

        return JSONResponse(
            content={
                "success": True,
                "data": summary,
                "error": None,
                "metrics": {"process_time": process_time},
            }
        )

    except torch.cuda.OutOfMemoryError:
        # The only recovery attempted is releasing cached blocks.
        logger.error("CUDA out of memory error - clearing cache")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info(f"CUDA memory after cleanup: {torch.cuda.memory_allocated() / (1024*1024*1024):.2f}GB allocated")
        return _error_response(
            status.HTTP_503_SERVICE_UNAVAILABLE,
            "Server is currently overloaded. Please try again later.",
        )

    except Exception:
        logger.exception("Error in predict endpoint")
        return _error_response(
            status.HTTP_500_INTERNAL_SERVER_ERROR,
            "An unexpected error occurred. Please try again later.",
        )
233
+
234
@app.get(
    "/health",
    response_model=Dict[str, Any],
    status_code=status.HTTP_200_OK,
)
async def health_check() -> JSONResponse:
    """Report API and model health.

    Returns:
        A 200 response with ``status``, ``model_loaded``, ``last_load_time``,
        ``version``, ``gpu_available`` and ``gpu_name``; or a 500 response
        with ``status: "unhealthy"`` and the error text if gathering that
        information raised.
    """
    try:
        gpu_available = torch.cuda.is_available()
        payload = {
            "status": "healthy",
            "model_loaded": model_manager.is_loaded(),
            "last_load_time": model_manager.get_load_time(),
            # Read from the app so the version is declared in one place only.
            "version": app.version,
            "gpu_available": gpu_available,
            "gpu_name": torch.cuda.get_device_name(0) if gpu_available else None,
        }
        return JSONResponse(content=payload)
    except Exception as e:
        # logger.exception keeps the traceback, matching the other handlers.
        logger.exception(f"Error in health check: {str(e)}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                "status": "unhealthy",
                "error": str(e),
            },
        )
262
+
263
@app.on_event("startup")
async def startup_event() -> None:
    """Log host resource usage and load the model when the app starts."""
    logger.info("Starting application in production mode...")

    cpu_pct = psutil.cpu_percent()
    mem_pct = psutil.virtual_memory().percent
    logger.info(f"System resources - CPU: {cpu_pct}%, Memory: {mem_pct}%")

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        logger.info(f"CUDA device: {gpu_name}")

    # The result is not inspected here; /predict retries the load on demand.
    await model_manager.load_model()
271
+
272
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Release the CUDA cache (when present) and log final resource usage."""
    logger.info("Initiating graceful shutdown...")

    if torch.cuda.is_available():
        allocated_gb = torch.cuda.memory_allocated() / (1024 ** 3)
        logger.info(f"Final CUDA memory before cleanup: {allocated_gb:.2f}GB")
        torch.cuda.empty_cache()
        logger.info("CUDA cache cleared")

    cpu_pct = psutil.cpu_percent()
    mem_pct = psutil.virtual_memory().percent
    logger.info(f"Final system stats - CPU: {cpu_pct}%, Memory: {mem_pct}%")
    logger.success("Application shutdown complete")
283
+
284
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")