from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration
from typing import Optional, Dict, Any, ClassVar
import logging
import os
import sys
import traceback
# Initialize FastAPI
app = FastAPI()

# Set up logging with more detailed formatting
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Get HF token
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    logger.warning("No HF_TOKEN found in environment variables")
MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2"
}

DEFAULT_GENERATION_CONFIGS = {
    "nidra-v1": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.55,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 4.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    },
    "nidra-v2": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.4,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 3.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    }
}
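
# These defaults are the lowest-precedence layer: /predict overlays each model's
# saved generation_config on top of them, then any per-request "parameters".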
class ModelManager:
    """Caches loaded (model, tokenizer) pairs so each model is loaded only once."""

    _instances: ClassVar[Dict[str, tuple]] = {}

    @classmethod
    def get_model_and_tokenizer(cls, model_name: str):
        if model_name not in cls._instances:
            try:
                model_path = MODELS[model_name]
                logger.info(f"Loading tokenizer for {model_name}")
                tokenizer = T5Tokenizer.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    local_files_only=False
                )
                logger.info(f"Loading model {model_name}")
                model = T5ForConditionalGeneration.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    local_files_only=False,
                    device_map="auto"  # Places the model on GPU if available (requires accelerate)
                )
                cls._instances[model_name] = (model, tokenizer)
                logger.info(f"Successfully loaded {model_name}")
            except Exception as e:
                logger.error(f"Error loading {model_name}: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to load model {model_name}: {str(e)}"
                )
        return cls._instances[model_name]
class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"
    parameters: Optional[Dict[str, Any]] = None

class PredictionResponse(BaseModel):
    generated_text: str
    model_used: str
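
# Illustrative request/response shapes for /predict (the dream text and the
# parameter value here are made-up examples, not required values):
#   POST /predict
#   {"inputs": "I was walking through a forest", "model": "nidra-v2",
#    "parameters": {"max_length": 200}}
# returns:
#   {"generated_text": "...", "model_used": "nidra-v2"}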
@app.get("/version")
async def version():
return {
"python_version": sys.version,
"models_available": list(MODELS.keys())
}
@app.get("/health")
async def health():
# More comprehensive health check
try:
# Try to load at least one model to verify functionality
ModelManager.get_model_and_tokenizer("nidra-v1")
return {
"status": "healthy",
"loaded_models": list(ModelManager._instances.keys())
}
except Exception as e:
logger.error(f"Health check failed: {str(e)}")
return {
"status": "unhealthy",
"error": str(e)
}
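
# Quick check from a shell (assumes the default port set at the bottom of this file):
#   curl http://localhost:7860/health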
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
try:
# Validate model
if request.model not in MODELS:
raise HTTPException(
status_code=400,
detail=f"Invalid model. Available models: {list(MODELS.keys())}"
)
# Get cached model and tokenizer
model, tokenizer = ModelManager.get_model_and_tokenizer(request.model)
# Get generation parameters
generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
# Try to load model's saved generation config
try:
model_generation_config = model.generation_config
generation_params.update({
k: v for k, v in model_generation_config.to_dict().items()
if v is not None
})
except Exception as config_load_error:
logger.warning(f"Using default generation config: {config_load_error}")
# Override with request-specific parameters
if request.parameters:
generation_params.update(request.parameters)
logger.debug(f"Final generation parameters: {generation_params}")
# Prepare input
full_input = "Interpret this dream: " + request.inputs
inputs = tokenizer(
full_input,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
).to(model.device) # Ensure inputs are on same device as model
# Generate
outputs = model.generate(
**inputs,
**{k: v for k, v in generation_params.items() if k in [
'max_length', 'min_length', 'do_sample', 'temperature',
'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
'repetition_penalty', 'early_stopping'
]}
)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
return PredictionResponse(
generated_text=result,
model_used=request.model
)
except Exception as e:
error_msg = f"Error during prediction: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
raise HTTPException(status_code=500, detail=error_msg)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860) |