import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
from typing import Optional, Dict, Any, ClassVar
import logging
import os
import sys
import traceback
from functools import lru_cache
import gc
import asyncio
import psutil
# Initialize FastAPI
app = FastAPI()

# Debugging logs
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Get HF token
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    logger.warning("No HF_TOKEN found in environment variables")
MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2"
}

DEFAULT_GENERATION_CONFIGS = {
    "nidra-v1": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.55,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 4.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    },
    "nidra-v2": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.4,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 3.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    }
}
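
# Note: at request time, per-request "parameters" take precedence over the
# model's own generation_config, which in turn takes precedence over these
# hard-coded defaults (see predict() below).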

class ModelManager:
    # Simple in-process cache: each model/tokenizer pair is loaded once and reused
    _instances: ClassVar[Dict[str, tuple]] = {}

    @classmethod
    async def get_model_and_tokenizer(cls, model_name: str):
        if model_name not in cls._instances:
            try:
                model_path = MODELS[model_name]
                logger.debug(f"Loading tokenizer and model from {model_path}")
                tokenizer = T5Tokenizer.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    use_fast=True
                )
                model = T5ForConditionalGeneration.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True,
                    device_map='auto'
                )
                model.eval()
                torch.set_num_threads(6)  # Number of CPU threads used for inference
                cls._instances[model_name] = (model, tokenizer)
            except Exception as e:
                logger.error(f"Error loading {model_name}: {str(e)}")
                raise
        return cls._instances[model_name]

class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"
    parameters: Optional[Dict[str, Any]] = None


class PredictionResponse(BaseModel):
    generated_text: str
    selected_model: str  # Changed from model_used to avoid pydantic namespace conflict

# Memory debug endpoint (decorator restored; route path is assumed)
@app.get("/memory")
async def memory_usage():
    process = psutil.Process()
    memory_info = process.memory_info()
    return {
        "memory_used_mb": memory_info.rss / 1024 / 1024,
        "memory_percent": process.memory_percent(),
        "cpu_percent": process.cpu_percent()
    }

# Version check endpoint (decorator restored; route path is assumed)
@app.get("/version")
async def version():
    return {
        "python_version": sys.version,
        "models_available": list(MODELS.keys())
    }

# Healthcheck endpoint (decorator restored; route path is assumed)
@app.get("/health")
async def health():
    try:
        logger.debug("Health check started")
        logger.debug(f"HF_TOKEN present: {bool(HF_TOKEN)}")
        logger.debug(f"Available models: {MODELS}")
        result = await ModelManager.get_model_and_tokenizer("nidra-v1")
        logger.debug("Model and tokenizer loaded successfully")
        return {
            "status": "healthy",
            "loaded_models": list(ModelManager._instances.keys())
        }
    except Exception as e:
        error_msg = f"Health check failed: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return {
            "status": "unhealthy",
            "error": str(e)
        }

# Prediction endpoint (decorator restored; route path is assumed)
@app.post("/predict")
async def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
    try:
        if request.model not in MODELS:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Available models: {list(MODELS.keys())}"
            )

        model, tokenizer = await ModelManager.get_model_and_tokenizer(request.model)

        # Add immediate cleanup of memory before generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
        try:
            model_generation_config = model.generation_config
            generation_params.update({
                k: v for k, v in model_generation_config.to_dict().items()
                if v is not None
            })
        except Exception as config_load_error:
            logger.warning(f"Using default generation config: {config_load_error}")

        if request.parameters:
            generation_params.update(request.parameters)

        logger.debug(f"Final generation parameters: {generation_params}")

        full_input = "Interpret this dream: " + request.inputs
        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
            return_attention_mask=True
        )

        def generate():
            # Runs in a worker thread (see below), so inference_mode is entered
            # here rather than around the await in the main thread
            try:
                with torch.inference_mode():
                    return model.generate(
                        **inputs,
                        **{k: v for k, v in generation_params.items() if k in [
                            'max_length', 'min_length', 'do_sample', 'temperature',
                            'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
                            'repetition_penalty', 'early_stopping'
                        ]}
                    )
            finally:
                # Ensure cleanup happens even if generation fails
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        # Off-load the blocking generate() call to a worker thread so the event
        # loop stays responsive and the timeout can actually fire
        outputs = await asyncio.wait_for(asyncio.to_thread(generate), timeout=45.0)
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)

        background_tasks.add_task(cleanup_memory)

        return PredictionResponse(
            generated_text=result,
            selected_model=request.model
        )
    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 400 above) instead of
        # converting them into a 500 below
        raise
    except asyncio.TimeoutError:
        logger.error("Generation timed out")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise HTTPException(status_code=504, detail="Generation timed out")
    except Exception as e:
        error_msg = f"Error during prediction: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise HTTPException(status_code=500, detail=error_msg)

def cleanup_memory():
    try:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Force Python garbage collection of the oldest generation
        gc.collect(generation=2)
    except Exception as e:
        logger.error(f"Error in cleanup: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
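
# Example request (a sketch; assumes the predict route above is mounted at
# /predict and the server is running locally on port 7860):
#
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": "I was flying over a city at night", "model": "nidra-v2", "parameters": {"max_length": 200}}'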