pdarleyjr committed on
Commit
6c293ab
·
verified ·
1 Parent(s): c1e71cb

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -24,7 +24,7 @@ logger.add(
24
  # Initialize FastAPI app with metadata
25
  app = FastAPI(
26
  title="Clinical Report Generator API",
27
- description="Production API for generating clinical report summaries using Flan-T5",
28
  version="1.0.0",
29
  docs_url="/documentation", # Swagger UI
30
  redoc_url="/redoc" # ReDoc
@@ -40,6 +40,9 @@ app.add_middleware(
40
  max_age=3600, # Cache preflight requests
41
  )
42
 
 
 
 
43
  class ModelManager:
44
  def __init__(self):
45
  self.model = None
@@ -64,30 +67,30 @@ class ModelManager:
64
  if torch.cuda.is_available():
65
  logger.info(f"CUDA memory: {torch.cuda.memory_allocated() / (1024*1024*1024):.2f}GB allocated")
66
 
67
- # Load tokenizer for Flan-T5-base
68
- logger.info("Initializing Flan-T5-base tokenizer...")
69
  self.tokenizer = T5Tokenizer.from_pretrained(
70
- "pdarleyjr/iplc-t5-clinical",
71
- use_fast=True, # Use fast tokenizer
72
  model_max_length=512
73
  )
74
- logger.success("Flan-T5-base tokenizer loaded successfully")
75
 
76
  # Load model configuration
77
  logger.info("Fetching model configuration...")
78
  config = AutoConfig.from_pretrained(
79
- "pdarleyjr/iplc-t5-clinical",
80
  trust_remote_code=False
81
  )
82
  logger.success("Model configuration loaded successfully")
83
 
84
- # Load the Flan-T5-base model
85
- logger.info("Loading Flan-T5-base model (this may take a few minutes)...")
86
  device = "cuda" if torch.cuda.is_available() else "cpu"
87
  logger.info(f"Using device: {device}")
88
 
89
  self.model = T5ForConditionalGeneration.from_pretrained(
90
- "pdarleyjr/iplc-t5-clinical",
91
  config=config,
92
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
93
  low_cpu_mem_usage=True
@@ -181,8 +184,8 @@ async def predict(request: PredictRequest) -> JSONResponse:
181
  with torch.no_grad(), model_manager.accelerator.autocast():
182
  outputs = model_manager.model.generate(
183
  input_ids,
184
- max_length=512, # Increased from 256 to allow for longer summaries
185
- num_beams=5, # Increased from 4 for more robust beam search
186
  no_repeat_ngram_size=3,
187
  length_penalty=2.0,
188
  early_stopping=True,
 
24
  # Initialize FastAPI app with metadata
25
  app = FastAPI(
26
  title="Clinical Report Generator API",
27
+ description="Production API for generating clinical report summaries using T5",
28
  version="1.0.0",
29
  docs_url="/documentation", # Swagger UI
30
  redoc_url="/redoc" # ReDoc
 
40
  max_age=3600, # Cache preflight requests
41
  )
42
 
43
+ # Model configuration
44
+ MODEL_ID = "pdarleyjr/iplc-t5-clinical"
45
+
46
  class ModelManager:
47
  def __init__(self):
48
  self.model = None
 
67
  if torch.cuda.is_available():
68
  logger.info(f"CUDA memory: {torch.cuda.memory_allocated() / (1024*1024*1024):.2f}GB allocated")
69
 
70
+ # Load tokenizer
71
+ logger.info("Initializing tokenizer...")
72
  self.tokenizer = T5Tokenizer.from_pretrained(
73
+ MODEL_ID,
74
+ use_fast=True,
75
  model_max_length=512
76
  )
77
+ logger.success("Tokenizer loaded successfully")
78
 
79
  # Load model configuration
80
  logger.info("Fetching model configuration...")
81
  config = AutoConfig.from_pretrained(
82
+ MODEL_ID,
83
  trust_remote_code=False
84
  )
85
  logger.success("Model configuration loaded successfully")
86
 
87
+ # Load the model
88
+ logger.info("Loading model (this may take a few minutes)...")
89
  device = "cuda" if torch.cuda.is_available() else "cpu"
90
  logger.info(f"Using device: {device}")
91
 
92
  self.model = T5ForConditionalGeneration.from_pretrained(
93
+ MODEL_ID,
94
  config=config,
95
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
96
  low_cpu_mem_usage=True
 
184
  with torch.no_grad(), model_manager.accelerator.autocast():
185
  outputs = model_manager.model.generate(
186
  input_ids,
187
+ max_length=512, # Increased for longer summaries
188
+ num_beams=5, # Increased for better coherence
189
  no_repeat_ngram_size=3,
190
  length_penalty=2.0,
191
  early_stopping=True,