fullstuckdev committed
Commit: f6b6cd4 · Parent: e7ceaff

path swagger

Files changed (1): app.py (+97, -25)
app.py CHANGED
@@ -1,9 +1,11 @@
 import os
 from fastapi import FastAPI, HTTPException, BackgroundTasks
 from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
+from typing import List, Optional

 # Setup logging
 logging.basicConfig(level=logging.INFO)
@@ -13,7 +15,30 @@ logger = logging.getLogger(__name__)
 os.makedirs("/app/cache", exist_ok=True)
 os.environ['TRANSFORMERS_CACHE'] = "/app/cache"

-app = FastAPI(title="Medical LLaMA API")
+# Pydantic models for request/response
+class GenerateRequest(BaseModel):
+    text: str
+    max_length: Optional[int] = 512
+    temperature: Optional[float] = 0.7
+    num_return_sequences: Optional[int] = 1
+
+class GenerateResponse(BaseModel):
+    generated_text: List[str]
+
+class HealthResponse(BaseModel):
+    status: str
+    model_loaded: bool
+    gpu_available: bool
+    device: str
+
+# Initialize FastAPI app
+app = FastAPI(
+    title="Medical LLaMA API",
+    description="API for medical text generation using fine-tuned LLaMA model",
+    version="1.0.0",
+    docs_url="/docs",
+    redoc_url="/redoc"
+)

 # Add CORS middleware
 app.add_middleware(
@@ -24,34 +49,81 @@ app.add_middleware(
     allow_headers=["*"],
 )

-# Check GPU availability
-def check_gpu():
-    if torch.cuda.is_available():
-        logger.info(f"GPU available: {torch.cuda.get_device_name(0)}")
-        return True
-    logger.warning("No GPU available, using CPU")
-    return False
+# Global variables for model and tokenizer
+model = None
+tokenizer = None
+
+@app.get("/", response_model=HealthResponse, tags=["Health"])
+async def root():
+    """
+    Root endpoint to check API health and model status
+    """
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    return HealthResponse(
+        status="online",
+        model_loaded=model is not None,
+        gpu_available=torch.cuda.is_available(),
+        device=device
+    )

-# Initialize model with proper device
-def init_model():
+@app.post("/generate", response_model=GenerateResponse, tags=["Generation"])
+async def generate_text(request: GenerateRequest):
+    """
+    Generate medical text based on input prompt
+
+    Parameters:
+    - text: Input text prompt
+    - max_length: Maximum length of generated text
+    - temperature: Sampling temperature (0.0 to 1.0)
+    - num_return_sequences: Number of sequences to generate
+
+    Returns:
+    - List of generated text sequences
+    """
     try:
-        device = "cuda" if check_gpu() else "cpu"
-        model_path = os.getenv("MODEL_PATH", "./model/medical_llama_3b")
-
-        logger.info(f"Loading model from {model_path}")
-        tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="/app/cache")
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-            device_map="auto",
-            cache_dir="/app/cache"
-        )
-        return tokenizer, model
+        if model is None or tokenizer is None:
+            raise HTTPException(status_code=500, detail="Model not loaded")
+
+        inputs = tokenizer(
+            request.text,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=request.max_length
+        ).to(model.device)
+
+        with torch.no_grad():
+            generated_ids = model.generate(
+                inputs.input_ids,
+                max_length=request.max_length,
+                num_return_sequences=request.num_return_sequences,
+                temperature=request.temperature,
+                pad_token_id=tokenizer.pad_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+            )
+
+        generated_texts = [
+            tokenizer.decode(g, skip_special_tokens=True)
+            for g in generated_ids
+        ]
+
+        return GenerateResponse(generated_text=generated_texts)
+
     except Exception as e:
-        logger.error(f"Error loading model: {str(e)}")
-        raise
+        logger.error(f"Generation error: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))

-# Rest of your existing code...
+@app.get("/health", tags=["Health"])
+async def health_check():
+    """
+    Check the health status of the API and model
+    """
+    return {
+        "status": "healthy",
+        "model_loaded": model is not None,
+        "gpu_available": torch.cuda.is_available(),
+        "device": "cuda" if torch.cuda.is_available() else "cpu"
+    }

 @app.on_event("startup")
 async def startup_event():
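For a quick sanity check of the endpoints this commit adds, here is a minimal client sketch (not part of the commit). It assumes the `requests` package is installed and that the API is served at http://localhost:7860; the base URL and port are assumptions, not something the diff specifies. With the new docs_url setting, the Swagger UI is also browsable at /docs on the same host.

# Hypothetical client for the /health and /generate endpoints added in this commit.
import requests

# Assumed base URL; adjust to wherever the container or Space actually listens.
BASE_URL = "http://localhost:7860"

# GET /health returns the dict built in health_check():
# status, model_loaded, gpu_available, device.
print(requests.get(f"{BASE_URL}/health", timeout=30).json())

# POST /generate takes a JSON body shaped like GenerateRequest.
payload = {
    "text": "Patient presents with fever and a persistent cough.",
    "max_length": 256,
    "temperature": 0.7,
    "num_return_sequences": 1,
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=300)
resp.raise_for_status()
# GenerateResponse.generated_text is a list of strings.
for seq in resp.json()["generated_text"]:
    print(seq)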