# Spaces: Running on Zero  (page-header text captured with the file; not program code)
import asyncio | |
from datetime import datetime | |
import os | |
import random | |
import time | |
from typing import Dict, List | |
import torch | |
import uvicorn | |
from sound_generator import generate_sound, generate_music | |
from fastapi import Depends, FastAPI, HTTPException, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.templating import Jinja2Templates | |
from fastapi.responses import FileResponse, HTMLResponse | |
from pydantic import BaseModel | |
# Create the FastAPI application; the interactive docs URLs are set explicitly.
app = FastAPI(
    title="API de Sonidos Generativos",
    # The accented character is restored here: the description previously
    # contained mis-encoded bytes and rendered garbled in the generated docs.
    description="API para generar sonidos y música basados en prompts",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Template configuration: HTML templates are loaded from ./templates.
templates = Jinja2Templates(directory="templates")

# CORS configuration: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# combination the FastAPI CORS documentation warns against; confirm whether
# credentialed cross-origin requests are really needed before tightening it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body shared by the generation endpoints: a single text prompt.
# (Documented with a comment rather than a docstring so the generated schema
# is left untouched.)
class AudioRequest(BaseModel):
    prompt: str
class GPUQuotaConfig:
    """GPU usage limits, expressed in seconds."""

    # Maximum number of seconds a single request may use.
    MAX_REQUEST_DURATION = 20
    # Total allowance per user: five minutes (300 seconds).
    DAILY_QUOTA = 5 * 60
class QuotaTracker:
    """Track per-user GPU quota (in seconds) and rotate between users.

    Every registered user starts with ``GPUQuotaConfig.DAILY_QUOTA`` seconds
    and is refilled to that amount once its reset time has passed.
    """

    def __init__(self):
        # Remaining quota in seconds, keyed by user id.
        self.users_quota: Dict[str, int] = {}
        # Moment at which each user's quota is refilled.
        self.user_reset_times: Dict[str, datetime] = {}
        # Position in ``registered_users`` of the user handed out last.
        self.current_user_index = 0
        # User ids in registration order; drives the round-robin rotation.
        self.registered_users: List[str] = []

    @staticmethod
    def _next_reset_time() -> datetime:
        """Return the moment one day from now, when a quota is refilled."""
        # Imported locally because the module only imports the ``datetime``
        # class, and that class has no ``timedelta`` attribute (the previous
        # ``datetime.timedelta(...)`` call raised AttributeError).
        from datetime import timedelta

        return datetime.now() + timedelta(days=1)

    def _refill_if_due(self, user_id: str) -> None:
        """Restore the full quota of ``user_id`` if its reset time has passed."""
        reset_at = self.user_reset_times.get(user_id)
        # A user without a recorded reset time is never refilled.
        if reset_at is not None and datetime.now() > reset_at:
            self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
            self.user_reset_times[user_id] = self._next_reset_time()

    def register_user(self, user_id: str):
        """Start tracking ``user_id`` with a full quota; no-op if already known."""
        # The quota dict is filled together with ``registered_users``, so it
        # serves as an O(1) membership test instead of scanning the list.
        if user_id in self.users_quota:
            return
        self.registered_users.append(user_id)
        self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
        self.user_reset_times[user_id] = self._next_reset_time()

    def get_next_available_user(self):
        """Return the next user, in rotation, able to pay for one request.

        Returns:
            The id of a user whose remaining quota is at least
            ``GPUQuotaConfig.MAX_REQUEST_DURATION``, or ``None`` when no
            registered user qualifies.
        """
        # Apply pending refills first, iterating over a snapshot of the keys.
        for user_id in list(self.user_reset_times):
            self._refill_if_due(user_id)

        total = len(self.registered_users)
        # Visit each registered user at most once. The index is advanced
        # before the check, so the search starts after the last user returned.
        for _ in range(total):
            self.current_user_index = (self.current_user_index + 1) % total
            candidate = self.registered_users[self.current_user_index]
            if self.users_quota.get(candidate, 0) >= GPUQuotaConfig.MAX_REQUEST_DURATION:
                return candidate
        return None

    def consume_quota(self, user_id: str, seconds: int):
        """Subtract ``seconds`` from the user's quota, never going below zero.

        Returns:
            ``True`` if the user is tracked, ``False`` otherwise.
        """
        if user_id not in self.users_quota:
            return False
        self.users_quota[user_id] = max(0, self.users_quota[user_id] - seconds)
        return True

    def get_remaining_quota(self, user_id: str):
        """Return the seconds left for ``user_id`` (0 for unknown users).

        A refill that has become due is applied before the value is read.
        """
        if user_id not in self.users_quota:
            return 0
        self._refill_if_due(user_id)
        return self.users_quota[user_id]

    def get_system_status(self):
        """Summarise the tracker: user count, users able to run, total seconds."""
        return {
            "registered_users": len(self.registered_users),
            "users_with_quota": sum(
                1
                for quota in self.users_quota.values()
                if quota >= GPUQuotaConfig.MAX_REQUEST_DURATION
            ),
            "total_available_seconds": sum(self.users_quota.values()),
        }
# Initialise the system: one tracker shared by the middleware and the endpoints.
quota_tracker = QuotaTracker()

# Register five virtual users so a quota pool exists from the start.
for i in range(5):
    quota_tracker.register_user(f"virtual_user_{i}")

# Semaphore guarding GPU access - only one task at a time.
gpu_semaphore = asyncio.Semaphore(1)
# Middleware that assigns a user_id to each request
async def assign_user_id(request: Request, call_next):
    """Attach a user id to the request and register it with the quota tracker.

    The id is taken from the ``user-id`` header when the request carries one;
    otherwise an ``anonymous_<4 digits>`` id is generated for this request.

    NOTE(review): no middleware registration is visible in this excerpt;
    ``get_user_id`` relies on this function having run for the request.

    Args:
        request: Incoming request; ``request.state.user_id`` is set on it.
        call_next: Callable that forwards the request down the chain.

    Returns:
        The response produced by ``call_next``.
    """
    if "user-id" in request.headers:
        assigned_id = request.headers["user-id"]
    else:
        assigned_id = f"anonymous_{random.randint(1000, 9999)}"

    request.state.user_id = assigned_id
    quota_tracker.register_user(assigned_id)

    return await call_next(request)
async def get_user_id(request: Request):
    """Return the user id previously stored on ``request.state``."""
    user_id = request.state.user_id
    return user_id
# Helper that runs the generation under GPU quota control
async def process_with_gpu(generation_func, prompt, process_id):
    """Run a generation function in a worker thread and charge GPU quota.

    Args:
        generation_func: Blocking callable invoked as
            ``generation_func(prompt, device, user_id)``.
        prompt: Text prompt forwarded to ``generation_func``.
        process_id: Label used to prefix the progress messages.

    Returns:
        Whatever ``generation_func`` returns; callers in this file treat it
        as the path of the generated audio file.

    Raises:
        HTTPException: 429 when no tracked user has enough quota left.
        Exception: Anything raised by ``generation_func``, re-raised after
            being logged. No quota is charged in that case.
    """
    start_time = time.time()
    print(f"[{process_id}] Iniciando procesamiento GPU")

    # Pick the next tracked user that can still pay for a full request.
    user_id = quota_tracker.get_next_available_user()
    if not user_id:
        raise HTTPException(status_code=429, detail="No hay cuota GPU disponible en el sistema")

    quota_available = quota_tracker.get_remaining_quota(user_id)
    print(f"[{process_id}] Usando cuota de usuario {user_id}: {quota_available}s disponibles")

    # Defensive re-check of the amount that was just read.
    if quota_available < GPUQuotaConfig.MAX_REQUEST_DURATION:
        raise HTTPException(status_code=429, detail=f"Cuota GPU insuficiente ({quota_available}s disponibles)")

    # Use the GPU when one is available, otherwise fall back to the CPU.
    use_gpu = torch.cuda.is_available()
    device = 'cuda' if use_gpu else 'cpu'
    print(f"[{process_id}] Usando dispositivo: {device}")

    try:
        # The generation call is blocking, so it runs in a worker thread.
        # No timeout is enforced here.
        audio_file_path = await asyncio.to_thread(
            generation_func, prompt, device, user_id
        )
    except Exception as exc:
        print(f"[{process_id}] Error: {exc}")
        raise
    finally:
        # Release cached GPU memory on success and on failure alike.
        if use_gpu:
            torch.cuda.empty_cache()

    # Charge whole elapsed seconds, capped at the per-request maximum.
    # int() truncates, so a run shorter than one second is charged 0.
    elapsed_time = min(GPUQuotaConfig.MAX_REQUEST_DURATION, int(time.time() - start_time))
    quota_tracker.consume_quota(user_id, elapsed_time)
    print(f"[{process_id}] Procesamiento completado en {elapsed_time}s, cuota restante: {quota_tracker.get_remaining_quota(user_id)}s")
    return audio_file_path
# Home page with API information | |
def home(request: Request):
    """Render the ``home.html`` template for the incoming request."""
    context = {"request": request}
    return templates.TemplateResponse("home.html", context)
# Health probe to verify that the API works - kept for now for debugging
def health_check():
    """Endpoint to verify that the service is running correctly.

    Returns:
        A dict with a fixed ``status`` of ``"ok"`` and the service name.
    """
    payload = {
        "status": "ok",
        "service": "Sound Generation API",
    }
    return payload
async def generate_sound_endpoint(request: AudioRequest, user_id: str = Depends(get_user_id)):
    """Generate a sound from the request prompt and return it as a WAV download.

    Args:
        request: Body carrying the text ``prompt``.
        user_id: Resolved through ``get_user_id``; not used in the body.

    Returns:
        A ``FileResponse`` serving the generated file as
        ``generated_audio.wav``.

    Raises:
        HTTPException: 404 when the generated file does not exist; any
            ``HTTPException`` from the generation step is passed through
            unchanged; every other failure is reported as 500.
    """
    process_id = f"sound_{random.randint(1000, 9999)}"
    try:
        # The semaphore gives this request exclusive access to the GPU.
        async with gpu_semaphore:
            audio_file_path = await process_with_gpu(
                generate_sound, request.prompt, process_id
            )

        # Make sure the file was actually produced.
        if not os.path.exists(audio_file_path):
            raise HTTPException(
                status_code=404, detail="Archivo de audio no encontrado."
            )

        # Return the generated file as a download.
        return FileResponse(
            audio_file_path, media_type="audio/wav", filename="generated_audio.wav"
        )
    except HTTPException:
        # Pass HTTP errors through untouched.
        raise
    except Exception as exc:
        # Keep the original error as the cause of the 500 response.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
async def generate_music_endpoint(request: AudioRequest, user_id: str = Depends(get_user_id)):
    """Generate music from the request prompt and return it as a WAV download.

    Args:
        request: Body carrying the text ``prompt``.
        user_id: Resolved through ``get_user_id``; not used in the body.

    Returns:
        A ``FileResponse`` serving the generated file as
        ``generated_audio.wav``.

    Raises:
        HTTPException: 404 when the generated file does not exist; any
            ``HTTPException`` from the generation step is passed through
            unchanged; every other failure is reported as 500.
    """
    process_id = f"music_{random.randint(1000, 9999)}"
    try:
        # The semaphore gives this request exclusive access to the GPU.
        async with gpu_semaphore:
            audio_file_path = await process_with_gpu(
                generate_music, request.prompt, process_id
            )

        # Make sure the file was actually produced.
        if not os.path.exists(audio_file_path):
            raise HTTPException(
                status_code=404, detail="Archivo de audio no encontrado."
            )

        # Return the generated file as a download.
        return FileResponse(
            audio_file_path, media_type="audio/wav", filename="generated_audio.wav"
        )
    except HTTPException:
        # Pass HTTP errors through untouched.
        raise
    except Exception as exc:
        # Keep the original error as the cause of the 500 response.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
async def quota_status_endpoint(user_id: str = Depends(get_user_id)):
    """Report the caller's remaining quota together with system and GPU status.

    Args:
        user_id: Resolved through ``get_user_id``.

    Returns:
        A dict with the user id, its remaining quota and reset time, the
        tracker's system summary, and whether a CUDA device is available
        (with its name, or ``"CPU"`` when none is).
    """
    remaining = quota_tracker.get_remaining_quota(user_id)
    system_status = quota_tracker.get_system_status()
    reset_time = quota_tracker.user_reset_times.get(user_id, None)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        device_info = torch.cuda.get_device_name(0)
    else:
        device_info = "CPU"

    return {
        "user_id": user_id,
        "quota_remaining": remaining,
        "reset_time": reset_time,
        "system_status": system_status,
        "gpu_available": cuda_ready,
        "device_info": device_info,
    }
# When executed as a script, serve the app on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )