File size: 9,428 Bytes
f26b7a5
 
7814ee2
f26b7a5
 
 
 
7814ee2
 
 
f26b7a5
7814ee2
fa56145
 
7814ee2
 
149c25a
fa56145
 
 
 
 
 
 
 
f26b7a5
 
fa56145
7814ee2
149c25a
7814ee2
 
149c25a
7814ee2
149c25a
 
7814ee2
 
 
 
 
f26b7a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75b5466
149c25a
137db14
 
 
 
7814ee2
75b5466
7814ee2
f26b7a5
7814ee2
f26b7a5
 
 
 
 
 
 
7814ee2
 
 
 
 
 
 
 
 
 
 
 
f26b7a5
 
 
7814ee2
 
 
 
f26b7a5
7814ee2
f26b7a5
 
 
 
 
 
 
7814ee2
f26b7a5
7814ee2
 
 
 
 
f26b7a5
7814ee2
 
 
 
f26b7a5
 
 
7814ee2
 
 
f26b7a5
 
 
 
 
 
 
 
 
 
 
 
 
75b5466
fa56145
75b5466
383458e
b40961b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import asyncio
import os
import random
import time
from datetime import datetime
from datetime import timedelta
from typing import Dict, List

import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel

from sound_generator import generate_sound, generate_music

# Create the FastAPI app. The interactive docs are served at /docs (Swagger UI)
# and /redoc (ReDoc); both URLs are set explicitly here.
app = FastAPI(
    title="API de Sonidos Generativos",
    # The accented character was mis-encoded ("m煤sica"); it is shown verbatim
    # in the generated OpenAPI documentation, so it must be valid text.
    description="API para generar sonidos y música basados en prompts",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)


# Template configuration: HTML templates are loaded from the "templates" directory
templates = Jinja2Templates(directory="templates")

# CORS configuration: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive combination - confirm it is intended before exposing this service
# beyond a demo environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class AudioRequest(BaseModel):
    # Request body accepted by the /generate-sound/ and /generate-music/ endpoints.
    prompt: str  # free-text prompt passed on to the generation function

class GPUQuotaConfig:
    """Static limits, in seconds, used for GPU time accounting."""

    MAX_REQUEST_DURATION = 20  # most seconds a single request can be charged
    DAILY_QUOTA = 300  # allowance granted to a user on each refill (5 minutes)


class QuotaTracker:
    """In-memory, per-user GPU time budget with round-robin user selection.

    Each registered user starts with ``GPUQuotaConfig.DAILY_QUOTA`` seconds.
    A user's budget is refilled once their reset time (one day after
    registration, or one day after the previous refill) has passed. All state
    lives in this object's attributes; nothing is persisted.
    """

    def __init__(self):
        self.users_quota: Dict[str, int] = {}  # user id -> seconds left
        self.user_reset_times: Dict[str, datetime] = {}  # user id -> next refill
        self.current_user_index = 0  # round-robin cursor into registered_users
        self.registered_users: List[str] = []  # ids in registration order

    @staticmethod
    def _next_reset_time() -> datetime:
        """Return the moment, one day from now, at which a budget is refilled."""
        # `datetime` is the class (imported via `from datetime import datetime`),
        # so the interval has to come from `timedelta` itself; the class has no
        # `timedelta` attribute and `datetime.timedelta(...)` raises AttributeError.
        return datetime.now() + timedelta(days=1)

    def _refill_if_expired(self, user_id: str) -> None:
        """Restore the full quota of *user_id* when its reset time has passed."""
        reset_at = self.user_reset_times.get(user_id)
        if reset_at is not None and datetime.now() > reset_at:
            self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
            self.user_reset_times[user_id] = self._next_reset_time()

    def register_user(self, user_id: str) -> None:
        """Register *user_id* with a full quota; a known id is left untouched."""
        # Membership is tested on the dict (O(1)) rather than on the list,
        # because this runs on every incoming request.
        if user_id in self.users_quota:
            return
        self.registered_users.append(user_id)
        self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
        self.user_reset_times[user_id] = self._next_reset_time()

    def get_next_available_user(self):
        """Return the next user, in round-robin order, able to pay for a request.

        Expired budgets are refilled first. The cursor is advanced before each
        candidate is examined, so the search starts at the user after the one
        selected previously.

        Returns:
            The id of a user holding at least
            ``GPUQuotaConfig.MAX_REQUEST_DURATION`` seconds, or ``None`` when no
            registered user does (including when nobody is registered).
        """
        for user_id in self.registered_users:
            self._refill_if_expired(user_id)

        total_users = len(self.registered_users)
        # At most one full lap over the registered users.
        for _ in range(total_users):
            self.current_user_index = (self.current_user_index + 1) % total_users
            candidate = self.registered_users[self.current_user_index]
            if self.users_quota.get(candidate, 0) >= GPUQuotaConfig.MAX_REQUEST_DURATION:
                return candidate

        return None

    def consume_quota(self, user_id: str, seconds: int) -> bool:
        """Subtract *seconds* from the user's quota, never going below zero.

        Returns:
            ``True`` when the user is known, ``False`` otherwise.
        """
        if user_id not in self.users_quota:
            return False
        self.users_quota[user_id] = max(0, self.users_quota[user_id] - seconds)
        return True

    def get_remaining_quota(self, user_id: str) -> int:
        """Return the seconds left for *user_id* (0 for an unknown user).

        An expired budget is refilled before the value is read.
        """
        if user_id not in self.users_quota:
            return 0
        self._refill_if_expired(user_id)
        return self.users_quota[user_id]

    def get_system_status(self) -> Dict[str, int]:
        """Summarise the tracker: user count, users able to run, total seconds."""
        return {
            "registered_users": len(self.registered_users),
            "users_with_quota": sum(
                1
                for quota in self.users_quota.values()
                if quota >= GPUQuotaConfig.MAX_REQUEST_DURATION
            ),
            "total_available_seconds": sum(self.users_quota.values()),
        }

# Semaphore controlling GPU access - a single permit, so one task at a time
gpu_semaphore = asyncio.Semaphore(1)

# Initialise the shared quota state
quota_tracker = QuotaTracker()

# Seed the tracker with five virtual users
for i in range(5):
    virtual_user_name = f"virtual_user_{i}"
    quota_tracker.register_user(virtual_user_name)

# Middleware that attaches a user id to every request
@app.middleware("http")
async def assign_user_id(request: Request, call_next):
    """Store a user id on ``request.state`` and register it with the tracker.

    The id is taken from the ``user-id`` request header when that header is
    present; otherwise a fresh ``anonymous_<4 digits>`` id is generated for
    this request. The request is then passed on to the next handler and its
    response is returned as is.
    """
    resolved_id = request.headers.get("user-id")
    if resolved_id is None:
        # Header absent: fall back to a randomly numbered anonymous id
        resolved_id = f"anonymous_{random.randint(1000, 9999)}"

    request.state.user_id = resolved_id
    quota_tracker.register_user(resolved_id)

    return await call_next(request)

async def get_user_id(request: Request):
    """Dependency that returns the ``user_id`` stored on ``request.state``."""
    request_state = request.state
    return request_state.user_id

# Runs one generation job and charges its runtime to a pooled user quota
async def process_with_gpu(generation_func, prompt, process_id):
    """Run *generation_func* in a worker thread and account for the time used.

    A user with enough remaining quota is picked from the quota tracker, the
    generation function is executed off the event loop, and the elapsed
    wall-clock seconds (capped at ``GPUQuotaConfig.MAX_REQUEST_DURATION``) are
    deducted from that user's quota.

    Args:
        generation_func: Blocking callable invoked as
            ``generation_func(prompt, device, user_id)``.
        prompt: Text prompt forwarded to ``generation_func``.
        process_id: Label used only to tag the log lines of this job.

    Returns:
        Whatever ``generation_func`` returns (callers use it as the path of the
        generated audio file).

    Raises:
        HTTPException: 429 when no registered user has enough quota left.
        Exception: Any error raised during generation or accounting is logged
            and re-raised unchanged.
    """
    start_time = time.time()
    print(f"[{process_id}] Iniciando procesamiento GPU")

    # Look for a user that still has quota available.
    # Compare against None explicitly: an empty-string user id is a valid
    # registered id and must not be mistaken for "no user available".
    user_id = quota_tracker.get_next_available_user()
    if user_id is None:
        raise HTTPException(status_code=429, detail="No hay cuota GPU disponible en el sistema")

    quota_available = quota_tracker.get_remaining_quota(user_id)
    print(f"[{process_id}] Usando cuota de usuario {user_id}: {quota_available}s disponibles")

    # Defensive re-check that the selected user can pay for a full request
    if quota_available < GPUQuotaConfig.MAX_REQUEST_DURATION:
        raise HTTPException(status_code=429, detail=f"Cuota GPU insuficiente ({quota_available}s disponibles)")

    # Use the GPU when CUDA is available, otherwise fall back to the CPU
    use_gpu = torch.cuda.is_available()
    device = 'cuda' if use_gpu else 'cpu'
    print(f"[{process_id}] Usando dispositivo: {device}")

    try:
        # Run the blocking generation call in a worker thread so the event
        # loop stays responsive. No time limit is enforced on the call itself;
        # only the number of seconds charged afterwards is capped.
        audio_file_path = await asyncio.to_thread(
            generation_func, prompt, device, user_id
        )

        # Seconds to charge: whole elapsed seconds, capped per request
        elapsed_time = min(GPUQuotaConfig.MAX_REQUEST_DURATION, int(time.time() - start_time))

        # Deduct the used time from the selected user's quota
        quota_tracker.consume_quota(user_id, elapsed_time)
        print(f"[{process_id}] Procesamiento completado en {elapsed_time}s, cuota restante: {quota_tracker.get_remaining_quota(user_id)}s")

        return audio_file_path

    except Exception as e:
        print(f"[{process_id}] Error: {str(e)}")
        # Bare raise keeps the original exception and traceback intact
        raise

    finally:
        # Release cached GPU memory on every exit path - success, error, and
        # task cancellation (CancelledError is not an Exception subclass, so an
        # `except Exception` cleanup alone would miss it).
        if use_gpu:
            torch.cuda.empty_cache()


# Landing page with information about the API
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    """Render the ``home.html`` template, passing it the current request."""
    template_context = {"request": request}
    return templates.TemplateResponse("home.html", template_context)

# Smoke-test endpoint to check that the API responds - kept for now for debugging
@app.get("/health")
def health_check():
    """Endpoint used to verify that the service is running correctly."""
    status_payload = {"status": "ok", "service": "Sound Generation API"}
    return status_payload


@app.post("/generate-sound/")
async def generate_sound_endpoint(request: AudioRequest, user_id: str = Depends(get_user_id)):
    """Generate a sound from the request prompt and return it as a WAV download.

    Responds with 404 when the generated file does not exist on disk.
    ``HTTPException``s raised along the way are passed through unchanged; any
    other error becomes a 500 whose detail is the error message.
    """
    try:
        job_tag = f"sound_{random.randint(1000, 9999)}"

        # Hold the semaphore while generating so GPU access is exclusive
        async with gpu_semaphore:
            wav_path = await process_with_gpu(generate_sound, request.prompt, job_tag)

        # Return the generated file as a download when it is present on disk
        if os.path.exists(wav_path):
            return FileResponse(
                wav_path, media_type="audio/wav", filename="generated_audio.wav"
            )

        # Generation finished but produced no file at the reported path
        raise HTTPException(status_code=404, detail="Archivo de audio no encontrado.")

    except HTTPException:
        # Forward HTTP errors untouched
        raise
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

@app.post("/generate-music/")
async def generate_music_endpoint(request: AudioRequest, user_id: str = Depends(get_user_id)):
    """Generate music from the request prompt and return it as a WAV download.

    Responds with 404 when the generated file does not exist on disk.
    ``HTTPException``s raised along the way are passed through unchanged; any
    other error becomes a 500 whose detail is the error message.
    """
    try:
        job_tag = f"music_{random.randint(1000, 9999)}"

        # Hold the semaphore while generating so GPU access is exclusive
        async with gpu_semaphore:
            wav_path = await process_with_gpu(generate_music, request.prompt, job_tag)

        # Return the generated file as a download when it is present on disk
        if os.path.exists(wav_path):
            return FileResponse(
                wav_path, media_type="audio/wav", filename="generated_audio.wav"
            )

        # Generation finished but produced no file at the reported path
        raise HTTPException(status_code=404, detail="Archivo de audio no encontrado.")

    except HTTPException:
        # Forward HTTP errors untouched
        raise
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

@app.get("/quota-status")
async def quota_status_endpoint(user_id: str = Depends(get_user_id)):
    """Report the caller's remaining quota together with system and GPU status."""
    remaining_seconds = quota_tracker.get_remaining_quota(user_id)
    tracker_overview = quota_tracker.get_system_status()
    next_reset = quota_tracker.user_reset_times.get(user_id, None)

    # Name the first CUDA device when one is available, otherwise report "CPU"
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        device_label = torch.cuda.get_device_name(0)
    else:
        device_label = "CPU"

    return {
        "user_id": user_id,
        "quota_remaining": remaining_seconds,
        "reset_time": next_reset,
        "system_status": tracker_overview,
        "gpu_available": cuda_ready,
        "device_info": device_label,
    }



if __name__ == "__main__":
    # Executed directly as a script: serve the app with uvicorn, listening on
    # every network interface ("0.0.0.0") at port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)