"""FastAPI service that generates sounds and music from text prompts.

Generation requests are serialized through a single-slot semaphore and are
charged against a pool of per-user daily quotas tracked in memory.
"""

import asyncio
import os
import random
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional

import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel

from sound_generator import generate_music, generate_sound

# Create the FastAPI app with custom docs URLs.
app = FastAPI(
    title="API de Sonidos Generativos",
    description="API para generar sonidos y música basados en prompts",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Template configuration.
templates = Jinja2Templates(directory="templates")

# CORS configuration.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AudioRequest(BaseModel):
    """Request body accepted by the generation endpoints."""

    prompt: str


class GPUQuotaConfig:
    """Quota limits, expressed in seconds."""

    MAX_REQUEST_DURATION = 20  # maximum seconds charged per request
    DAILY_QUOTA = 300  # 5 minutes in total (300 seconds)


class QuotaTracker:
    """In-memory tracker of per-user quota with round-robin selection.

    Each registered user starts with ``GPUQuotaConfig.DAILY_QUOTA`` seconds.
    A user's quota is restored to that value once its reset time has passed.
    """

    def __init__(self):
        self.users_quota: Dict[str, int] = {}
        self.user_reset_times: Dict[str, datetime] = {}
        self.current_user_index = 0
        self.registered_users: List[str] = []

    @staticmethod
    def _next_reset_time() -> datetime:
        """Return the moment, one day from now, at which a quota resets."""
        # ``timedelta`` lives in the ``datetime`` *module*; the original code
        # looked it up on the ``datetime`` class, which raises AttributeError.
        return datetime.now() + timedelta(days=1)

    def _reset_if_expired(self, user_id: str) -> None:
        """Restore ``user_id``'s quota if its reset time has passed."""
        # A missing reset time compares as "never expires".
        if datetime.now() > self.user_reset_times.get(user_id, datetime.max):
            self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
            self.user_reset_times[user_id] = self._next_reset_time()

    def register_user(self, user_id: str):
        """Register ``user_id`` with a full quota; no-op if already known."""
        # ``users_quota`` is populated together with ``registered_users``, so
        # the dict gives the same answer as scanning the list, in O(1).
        if user_id not in self.users_quota:
            self.registered_users.append(user_id)
            self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
            self.user_reset_times[user_id] = self._next_reset_time()

    def get_next_available_user(self) -> Optional[str]:
        """Return the next user, round-robin, able to cover one request.

        Expired quotas are reset first. Returns ``None`` when no registered
        user has at least ``GPUQuotaConfig.MAX_REQUEST_DURATION`` seconds left.
        """
        # Apply pending resets.
        for user_id in list(self.user_reset_times):
            self._reset_if_expired(user_id)

        # Visit every user at most once, starting after the last one chosen.
        user_count = len(self.registered_users)
        for _ in range(user_count):
            self.current_user_index = (self.current_user_index + 1) % user_count
            current_user = self.registered_users[self.current_user_index]
            remaining = self.users_quota.get(current_user, 0)
            if remaining >= GPUQuotaConfig.MAX_REQUEST_DURATION:
                return current_user
        return None

    def consume_quota(self, user_id: str, seconds: int) -> bool:
        """Subtract ``seconds`` from the user's quota, clamping at zero.

        Returns ``True`` if the user is known, ``False`` otherwise.
        """
        if user_id not in self.users_quota:
            return False
        self.users_quota[user_id] = max(0, self.users_quota[user_id] - seconds)
        return True

    def get_remaining_quota(self, user_id: str) -> int:
        """Return the user's remaining seconds (0 for unknown users).

        The quota is reset first if its reset time has passed.
        """
        if user_id not in self.users_quota:
            return 0
        self._reset_if_expired(user_id)
        return self.users_quota[user_id]

    def get_system_status(self):
        """Return aggregate counters describing the quota pool."""
        quotas = self.users_quota.values()
        return {
            "registered_users": len(self.registered_users),
            "users_with_quota": sum(
                1 for q in quotas if q >= GPUQuotaConfig.MAX_REQUEST_DURATION
            ),
            "total_available_seconds": sum(quotas),
        }


# Initialize the quota system.
quota_tracker = QuotaTracker()

# Register the virtual users.
for i in range(5):
    quota_tracker.register_user(f"virtual_user_{i}")

# Semaphore controlling GPU access - only one task at a time.
gpu_semaphore = asyncio.Semaphore(1)


# Middleware that assigns a user_id to every request.
@app.middleware("http")
async def assign_user_id(request: Request, call_next):
    """Attach a user id to ``request.state`` and register it for quota.

    The id comes from the ``user-id`` header when present; otherwise a random
    ``anonymous_NNNN`` id is generated.
    """
    # NOTE(review): every previously unseen id (including header-supplied
    # ones) is registered with a full daily quota and is never removed, so
    # the registry and the shared quota pool can grow with traffic.
    user_id = request.headers.get("user-id")
    if user_id is None:
        user_id = f"anonymous_{random.randint(1000, 9999)}"
    request.state.user_id = user_id
    quota_tracker.register_user(user_id)
    return await call_next(request)


async def get_user_id(request: Request):
    """Dependency returning the user id set by ``assign_user_id``."""
    return request.state.user_id


# Runs a generation function under GPU quota control.
async def process_with_gpu(generation_func, prompt, process_id):
    """Run ``generation_func`` in a worker thread and charge the quota.

    Args:
        generation_func: Callable invoked as
            ``generation_func(prompt, device, user_id)``.
        prompt: Text prompt forwarded to ``generation_func``.
        process_id: Label used to prefix log lines.

    Returns:
        Whatever ``generation_func`` returns (the endpoints treat it as the
        path of the generated audio file).

    Raises:
        HTTPException: 429 when no user in the pool has enough quota.
        Exception: Any error raised by ``generation_func`` is logged and
            re-raised unchanged.
    """
    # Monotonic clock: elapsed time must not be affected by clock changes.
    start_time = time.monotonic()
    print(f"[{process_id}] Iniciando procesamiento GPU")

    # Find a user with available quota.
    user_id = quota_tracker.get_next_available_user()
    if not user_id:
        raise HTTPException(
            status_code=429,
            detail="No hay cuota GPU disponible en el sistema",
        )

    quota_available = quota_tracker.get_remaining_quota(user_id)
    print(
        f"[{process_id}] Usando cuota de usuario {user_id}: "
        f"{quota_available}s disponibles"
    )

    # Check that there is enough quota.
    if quota_available < GPUQuotaConfig.MAX_REQUEST_DURATION:
        raise HTTPException(
            status_code=429,
            detail=f"Cuota GPU insuficiente ({quota_available}s disponibles)",
        )

    # Use the GPU when one is available.
    use_gpu = torch.cuda.is_available()
    device = "cuda" if use_gpu else "cpu"
    print(f"[{process_id}] Usando dispositivo: {device}")

    try:
        # No timeout is enforced here; only the *charged* time is capped below.
        audio_file_path = await asyncio.to_thread(
            generation_func, prompt, device, user_id
        )
    except Exception as e:
        print(f"[{process_id}] Error: {e}")
        raise
    finally:
        # Release cached GPU memory on both the success and the error path.
        if use_gpu:
            torch.cuda.empty_cache()

    # Charge the real elapsed time, capped at the per-request maximum.
    elapsed_time = min(
        GPUQuotaConfig.MAX_REQUEST_DURATION,
        int(time.monotonic() - start_time),
    )
    quota_tracker.consume_quota(user_id, elapsed_time)
    print(
        f"[{process_id}] Procesamiento completado en {elapsed_time}s, "
        f"cuota restante: {quota_tracker.get_remaining_quota(user_id)}s"
    )
    return audio_file_path


async def _generate_audio_response(generation_func, prompt: str, prefix: str):
    """Generate audio with ``generation_func`` and return it as a download.

    Shared implementation of the sound and music endpoints. ``prefix`` is
    used to build the process id that labels log lines.

    Raises:
        HTTPException: 404 if the generated file does not exist, 500 for any
            unexpected error; HTTP errors raised downstream pass through.
    """
    try:
        process_id = f"{prefix}_{random.randint(1000, 9999)}"

        # Hold the semaphore so only one generation uses the GPU at a time.
        async with gpu_semaphore:
            audio_file_path = await process_with_gpu(
                generation_func, prompt, process_id
            )

        # Check that the file was actually generated.
        if not os.path.exists(audio_file_path):
            raise HTTPException(
                status_code=404,
                detail="Archivo de audio no encontrado.",
            )

        # Return the generated file as a download response.
        return FileResponse(
            audio_file_path,
            media_type="audio/wav",
            filename="generated_audio.wav",
        )
    except HTTPException:
        # Forward HTTP errors unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# Home page with API information.
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    """Render the ``home.html`` template."""
    return templates.TemplateResponse("home.html", {"request": request})


# Check that the API is working - kept for now for debugging.
@app.get("/health")
def health_check():
    """Endpoint to verify that the service is running correctly."""
    return {"status": "ok", "service": "Sound Generation API"}


@app.post("/generate-sound/")
async def generate_sound_endpoint(
    request: AudioRequest, user_id: str = Depends(get_user_id)
):
    """Generate a sound from the prompt and return it as a WAV download."""
    return await _generate_audio_response(generate_sound, request.prompt, "sound")


@app.post("/generate-music/")
async def generate_music_endpoint(
    request: AudioRequest, user_id: str = Depends(get_user_id)
):
    """Generate music from the prompt and return it as a WAV download."""
    return await _generate_audio_response(generate_music, request.prompt, "music")


@app.get("/quota-status")
async def quota_status_endpoint(user_id: str = Depends(get_user_id)):
    """Report the caller's remaining quota plus pool and device status."""
    user_quota = quota_tracker.get_remaining_quota(user_id)
    system_status = quota_tracker.get_system_status()
    gpu_available = torch.cuda.is_available()
    return {
        "user_id": user_id,
        "quota_remaining": user_quota,
        "reset_time": quota_tracker.user_reset_times.get(user_id),
        "system_status": system_status,
        "gpu_available": gpu_available,
        "device_info": torch.cuda.get_device_name(0) if gpu_available else "CPU",
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)