# Scraped repository page header (not part of the program):
#   Curinha
#   Implement GPU quota management and user registration for sound generation
#   commit f26b7a5, file size 9.43 kB
import asyncio
import os
import random
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional

import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel

from sound_generator import generate_sound, generate_music
# Create the FastAPI app with custom docs URL
app = FastAPI(
    title="API de Sonidos Generativos",
    # The accented character was mis-encoded in the source ("m煤sica").
    description="API para generar sonidos y música basados en prompts",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Template configuration
templates = Jinja2Templates(directory="templates")

# CORS configuration
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; confirm this is intended before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class AudioRequest(BaseModel):
    """Request body accepted by the audio generation endpoints."""

    # Text prompt forwarded to the generation function.
    prompt: str
class GPUQuotaConfig:
    """Static limits, in seconds, used by the GPU quota bookkeeping."""

    # Maximum number of seconds charged to a single request.
    MAX_REQUEST_DURATION = 20
    # Seconds granted to each user per reset period (5 minutes in total).
    DAILY_QUOTA = 5 * 60
class QuotaTracker:
    """In-memory tracker of per-user GPU seconds with a rolling one-day reset.

    Users are kept in registration order and handed out round-robin by
    ``get_next_available_user``. All state lives in plain dicts/lists on the
    instance; nothing is persisted.

    Args:
        daily_quota: Seconds granted to a user on registration and on each
            reset. Defaults to ``GPUQuotaConfig.DAILY_QUOTA``.
        max_request_duration: Minimum remaining seconds a user needs to be
            considered available. Defaults to
            ``GPUQuotaConfig.MAX_REQUEST_DURATION``.
    """

    def __init__(
        self,
        daily_quota: Optional[int] = None,
        max_request_duration: Optional[int] = None,
    ):
        # ``None`` means "read the value from GPUQuotaConfig at use time".
        self._daily_quota = daily_quota
        self._max_request_duration = max_request_duration
        self.users_quota: Dict[str, int] = {}
        self.user_reset_times: Dict[str, datetime] = {}
        self.current_user_index = 0
        self.registered_users: List[str] = []

    @property
    def daily_quota(self) -> int:
        """Seconds granted per reset period."""
        if self._daily_quota is not None:
            return self._daily_quota
        return GPUQuotaConfig.DAILY_QUOTA

    @property
    def max_request_duration(self) -> int:
        """Remaining seconds a user must have to be selectable."""
        if self._max_request_duration is not None:
            return self._max_request_duration
        return GPUQuotaConfig.MAX_REQUEST_DURATION

    def _reset_if_due(self, user_id: str, now: datetime) -> None:
        """Refill ``user_id``'s quota when its reset time has passed."""
        reset_at = self.user_reset_times.get(user_id)
        if reset_at is not None and now > reset_at:
            self.users_quota[user_id] = self.daily_quota
            self.user_reset_times[user_id] = now + timedelta(days=1)

    def register_user(self, user_id: str) -> None:
        """Register ``user_id`` with a full quota; no-op if already known."""
        # users_quota and registered_users are always updated together, so
        # the dict gives the same answer as scanning the list, in O(1).
        if user_id in self.users_quota:
            return
        self.registered_users.append(user_id)
        self.users_quota[user_id] = self.daily_quota
        self.user_reset_times[user_id] = datetime.now() + timedelta(days=1)

    def get_next_available_user(self) -> Optional[str]:
        """Return the next user, round-robin, that has enough quota left.

        Expired quotas are refilled first. The cursor is advanced before each
        candidate is inspected, so the first call on a fresh tracker starts
        from the second registered user.

        Returns:
            A user id, or ``None`` when no registered user has at least
            ``max_request_duration`` seconds left.
        """
        now = datetime.now()
        # Only existing keys are reassigned, so iterating the dict is safe.
        for user_id in self.user_reset_times:
            self._reset_if_due(user_id, now)

        total = len(self.registered_users)
        for _ in range(total):
            self.current_user_index = (self.current_user_index + 1) % total
            candidate = self.registered_users[self.current_user_index]
            if self.users_quota.get(candidate, 0) >= self.max_request_duration:
                return candidate
        return None

    def consume_quota(self, user_id: str, seconds: int) -> bool:
        """Subtract ``seconds`` from a user's quota, flooring at zero.

        Returns:
            ``True`` if the user is known, ``False`` otherwise.
        """
        if user_id not in self.users_quota:
            return False
        self.users_quota[user_id] = max(0, self.users_quota[user_id] - seconds)
        return True

    def get_remaining_quota(self, user_id: str) -> int:
        """Return the seconds left for ``user_id`` (0 for unknown users).

        The quota is refilled first if the user's reset time has passed.
        """
        if user_id not in self.users_quota:
            return 0
        self._reset_if_due(user_id, datetime.now())
        return self.users_quota[user_id]

    def get_system_status(self) -> Dict[str, int]:
        """Summarise registered users and remaining seconds across all users."""
        threshold = self.max_request_duration
        balances = list(self.users_quota.values())
        return {
            "registered_users": len(self.registered_users),
            "users_with_quota": sum(1 for left in balances if left >= threshold),
            "total_available_seconds": sum(balances),
        }
# Initialise the quota system
quota_tracker = QuotaTracker()

# Register the virtual users
for virtual_index in range(5):
    quota_tracker.register_user(f"virtual_user_{virtual_index}")

# Semaphore guarding GPU access - only one task at a time
gpu_semaphore = asyncio.Semaphore(1)
# Middleware that assigns a user_id to every request
@app.middleware("http")
async def assign_user_id(request: Request, call_next):
    """Store a user id on ``request.state`` and register it for quota.

    The id is the ``user-id`` header when present; otherwise an
    ``anonymous_<4 digits>`` id is generated for this request.
    """
    header_value = request.headers.get("user-id")
    if header_value is None:
        header_value = f"anonymous_{random.randint(1000, 9999)}"
    request.state.user_id = header_value
    quota_tracker.register_user(request.state.user_id)
    return await call_next(request)
async def get_user_id(request: Request):
    """Dependency returning the user id stored on ``request.state``."""
    current_user_id = request.state.user_id
    return current_user_id
# Run a generation function under GPU quota control
async def process_with_gpu(generation_func, prompt, process_id):
    """Run ``generation_func`` in a worker thread and charge the GPU quota.

    The quota is charged to the next user returned by
    ``quota_tracker.get_next_available_user()``. No timeout is enforced on the
    generation itself; only the charged time is capped at
    ``GPUQuotaConfig.MAX_REQUEST_DURATION``.

    Args:
        generation_func: Callable invoked as
            ``generation_func(prompt, device, user_id)``.
        prompt: Text prompt forwarded to ``generation_func``.
        process_id: Label used to prefix log lines.

    Returns:
        Whatever ``generation_func`` returns (callers treat it as a file path).

    Raises:
        HTTPException: 429 when no user has enough quota left.
    """
    # Monotonic clock: a wall-clock adjustment must not yield a negative
    # duration, which would otherwise add quota instead of consuming it.
    started = time.monotonic()
    print(f"[{process_id}] Iniciando procesamiento GPU")

    # Find a user with quota available. Compare against None so that an
    # empty-string user id is not mistaken for "nobody available".
    user_id = quota_tracker.get_next_available_user()
    if user_id is None:
        raise HTTPException(status_code=429, detail="No hay cuota GPU disponible en el sistema")

    quota_available = quota_tracker.get_remaining_quota(user_id)
    print(f"[{process_id}] Usando cuota de usuario {user_id}: {quota_available}s disponibles")

    # Check there is enough quota for a full-length request
    if quota_available < GPUQuotaConfig.MAX_REQUEST_DURATION:
        raise HTTPException(status_code=429, detail=f"Cuota GPU insuficiente ({quota_available}s disponibles)")

    # Use the GPU when one is available
    use_gpu = torch.cuda.is_available()
    device = "cuda" if use_gpu else "cpu"
    print(f"[{process_id}] Usando dispositivo: {device}")

    try:
        # Run the blocking generation off the event loop
        audio_file_path = await asyncio.to_thread(
            generation_func, prompt, device, user_id
        )
    except Exception as e:
        print(f"[{process_id}] Error: {str(e)}")
        raise
    finally:
        # Release cached GPU memory on success and on failure alike
        if use_gpu:
            torch.cuda.empty_cache()

    # Charge whole elapsed seconds, capped at the per-request maximum
    elapsed_time = min(GPUQuotaConfig.MAX_REQUEST_DURATION, int(time.monotonic() - started))
    quota_tracker.consume_quota(user_id, elapsed_time)
    print(f"[{process_id}] Procesamiento completado en {elapsed_time}s, cuota restante: {quota_tracker.get_remaining_quota(user_id)}s")
    return audio_file_path
# Home page with API information
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    """Render the ``home.html`` template for the landing page."""
    context = {"request": request}
    return templates.TemplateResponse("home.html", context)
# Liveness probe for the API - kept for now for debugging
@app.get("/health")
def health_check():
    """Endpoint to verify that the service is running correctly."""
    payload = {"status": "ok", "service": "Sound Generation API"}
    return payload
async def _generate_audio_response(generation_func, prompt: str, kind: str):
    """Run one generation under the GPU semaphore and return it as a download.

    Shared by the sound and music endpoints, which differ only in the
    generation function and the process-id prefix.

    Args:
        generation_func: Generator passed through to ``process_with_gpu``.
        prompt: Text prompt from the request body.
        kind: Prefix for the process id used in log lines.

    Raises:
        HTTPException: Re-raised unchanged when raised downstream; 404 when
            the returned path does not exist; 500 wrapping any other error.
    """
    process_id = f"{kind}_{random.randint(1000, 9999)}"
    try:
        # The semaphore gives this task exclusive access to the GPU
        async with gpu_semaphore:
            audio_file_path = await process_with_gpu(
                generation_func, prompt, process_id
            )

        # Check that the file was actually produced
        if not os.path.exists(audio_file_path):
            raise HTTPException(
                status_code=404, detail="Archivo de audio no encontrado."
            )

        # Return the generated file as a download response
        return FileResponse(
            audio_file_path, media_type="audio/wav", filename="generated_audio.wav"
        )
    except HTTPException:
        # Forward HTTP errors untouched
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/generate-sound/")
async def generate_sound_endpoint(request: AudioRequest, user_id: str = Depends(get_user_id)):
    """Generate a sound from the request prompt and return it as a WAV file."""
    return await _generate_audio_response(generate_sound, request.prompt, "sound")


@app.post("/generate-music/")
async def generate_music_endpoint(request: AudioRequest, user_id: str = Depends(get_user_id)):
    """Generate music from the request prompt and return it as a WAV file."""
    return await _generate_audio_response(generate_music, request.prompt, "music")
@app.get("/quota-status")
async def quota_status_endpoint(user_id: str = Depends(get_user_id)):
    """Report the caller's remaining quota plus system and device status."""
    has_cuda = torch.cuda.is_available()
    device_info = torch.cuda.get_device_name(0) if has_cuda else "CPU"
    return {
        "user_id": user_id,
        # Read the quota first: it may refresh the reset time reported below.
        "quota_remaining": quota_tracker.get_remaining_quota(user_id),
        "reset_time": quota_tracker.user_reset_times.get(user_id),
        "system_status": quota_tracker.get_system_status(),
        "gpu_available": has_cuda,
        "device_info": device_info,
    }
if __name__ == "__main__":
    # Serve the app on all interfaces when run as a script.
    uvicorn.run(app, port=7860, host="0.0.0.0")