Curinha commited on
Commit
f26b7a5
1 Parent(s): b40961b

Implement GPU quota management and user registration for sound generation

Browse files
Files changed (2) hide show
  1. app.py +182 -17
  2. sound_generator.py +7 -4
app.py CHANGED
@@ -1,8 +1,14 @@
 
 
1
  import os
 
 
 
 
2
  import uvicorn
3
 
4
  from sound_generator import generate_sound, generate_music
5
- from fastapi import FastAPI, HTTPException, Request
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from fastapi.templating import Jinja2Templates
8
  from fastapi.responses import FileResponse, HTMLResponse
@@ -17,7 +23,8 @@ app = FastAPI(
17
  redoc_url="/redoc",
18
  )
19
 
20
- # Cargar las plantillas desde la carpeta "templates"
 
21
  templates = Jinja2Templates(directory="templates")
22
 
23
  # Configuraci贸n de CORS
@@ -29,11 +36,145 @@ app.add_middleware(
29
  allow_headers=["*"],
30
  )
31
 
32
-
33
- # Define a Pydantic model to handle the input prompt
34
  class AudioRequest(BaseModel):
35
  prompt: str
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Prueba para verificar si la API funciona - la dejamos por ahora para debugging
39
  @app.get("/health")
@@ -43,10 +184,15 @@ def health_check():
43
 
44
 
45
  @app.post("/generate-sound/")
46
- async def generate_sound_endpoint(request: AudioRequest):
47
  try:
48
- # Llamada a la funci贸n para generar el sonido
49
- audio_file_path = generate_sound(request.prompt)
 
 
 
 
 
50
 
51
  # Verifica si el archivo se ha generado correctamente
52
  if not os.path.exists(audio_file_path):
@@ -59,35 +205,54 @@ async def generate_sound_endpoint(request: AudioRequest):
59
  audio_file_path, media_type="audio/wav", filename="generated_audio.wav"
60
  )
61
 
 
 
 
62
  except Exception as e:
63
  raise HTTPException(status_code=500, detail=str(e))
64
 
65
-
66
  @app.post("/generate-music/")
67
- async def generate_music_endpoint(request: AudioRequest):
68
  try:
69
- # Call the synchronous generate_music function
70
- audio_file_path = generate_music(request.prompt)
 
 
 
 
 
71
 
72
- # Verifies if the file has been generated correctly
73
  if not os.path.exists(audio_file_path):
74
  raise HTTPException(
75
  status_code=404, detail="Archivo de audio no encontrado."
76
  )
77
 
78
- # Return the generated file as a download response
79
  return FileResponse(
80
  audio_file_path, media_type="audio/wav", filename="generated_audio.wav"
81
  )
82
 
 
 
 
83
  except Exception as e:
84
  raise HTTPException(status_code=500, detail=str(e))
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- @app.get("/", response_class=HTMLResponse)
88
- def home(request: Request):
89
- """P谩gina de inicio con informaci贸n de la API"""
90
- return templates.TemplateResponse("home.html", {"request": request})
91
 
92
 
93
  if __name__ == "__main__":
 
1
+ import asyncio
2
+ from datetime import datetime
3
  import os
4
+ import random
5
+ import time
6
+ from typing import Dict, List
7
+ import torch
8
  import uvicorn
9
 
10
  from sound_generator import generate_sound, generate_music
11
+ from fastapi import Depends, FastAPI, HTTPException, Request
12
  from fastapi.middleware.cors import CORSMiddleware
13
  from fastapi.templating import Jinja2Templates
14
  from fastapi.responses import FileResponse, HTMLResponse
 
23
  redoc_url="/redoc",
24
  )
25
 
26
+
27
+ # Configuraci贸n de templates
28
  templates = Jinja2Templates(directory="templates")
29
 
30
  # Configuraci贸n de CORS
 
36
  allow_headers=["*"],
37
  )
38
 
 
 
39
  class AudioRequest(BaseModel):
40
  prompt: str
41
 
42
+ class GPUQuotaConfig:
43
+ MAX_REQUEST_DURATION = 20 # segundos m谩ximos por solicitud
44
+ DAILY_QUOTA = 300 # 5 minutos en total (300 segundos)
45
+
46
+ class QuotaTracker:
47
+ def __init__(self):
48
+ self.users_quota: Dict[str, int] = {}
49
+ self.user_reset_times: Dict[str, datetime] = {}
50
+ self.current_user_index = 0
51
+ self.registered_users: List[str] = []
52
+
53
+ def register_user(self, user_id: str):
54
+ if user_id not in self.registered_users:
55
+ self.registered_users.append(user_id)
56
+ self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
57
+ self.user_reset_times[user_id] = datetime.now() + datetime.timedelta(days=1)
58
+
59
+ def get_next_available_user(self):
60
+ # Verificar resets
61
+ for user_id in list(self.user_reset_times.keys()):
62
+ if datetime.now() > self.user_reset_times[user_id]:
63
+ self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
64
+ self.user_reset_times[user_id] = datetime.now() + datetime.timedelta(days=1)
65
+
66
+ # Encontrar usuario con cuota
67
+ attempts = 0
68
+ while attempts < len(self.registered_users):
69
+ self.current_user_index = (self.current_user_index + 1) % max(1, len(self.registered_users))
70
+ current_user = self.registered_users[self.current_user_index]
71
+ if self.users_quota.get(current_user, 0) >= GPUQuotaConfig.MAX_REQUEST_DURATION:
72
+ return current_user
73
+ attempts += 1
74
+
75
+ return None
76
+
77
+ def consume_quota(self, user_id: str, seconds: int):
78
+ if user_id in self.users_quota:
79
+ self.users_quota[user_id] = max(0, self.users_quota[user_id] - seconds)
80
+ return True
81
+ return False
82
+
83
+ def get_remaining_quota(self, user_id: str):
84
+ if user_id in self.users_quota:
85
+ # Verificar si se debe resetear
86
+ if datetime.now() > self.user_reset_times.get(user_id, datetime.max):
87
+ self.users_quota[user_id] = GPUQuotaConfig.DAILY_QUOTA
88
+ self.user_reset_times[user_id] = datetime.now() + datetime.timedelta(days=1)
89
+ return self.users_quota[user_id]
90
+ return 0
91
+
92
+ def get_system_status(self):
93
+ return {
94
+ "registered_users": len(self.registered_users),
95
+ "users_with_quota": sum(1 for q in self.users_quota.values() if q >= GPUQuotaConfig.MAX_REQUEST_DURATION),
96
+ "total_available_seconds": sum(self.users_quota.values())
97
+ }
98
+
99
+ # Inicializar sistema
100
+ quota_tracker = QuotaTracker()
101
+
102
+ # Registrar usuarios virtuales
103
+ for i in range(5):
104
+ quota_tracker.register_user(f"virtual_user_{i}")
105
+
106
+ # Sem谩foro para controlar acceso a GPU - solo una tarea a la vez
107
+ gpu_semaphore = asyncio.Semaphore(1)
108
+
109
+ # Middleware para asignar user_id
110
+ @app.middleware("http")
111
+ async def assign_user_id(request: Request, call_next):
112
+ if "user-id" not in request.headers:
113
+ request.state.user_id = f"anonymous_{random.randint(1000, 9999)}"
114
+ quota_tracker.register_user(request.state.user_id)
115
+ else:
116
+ request.state.user_id = request.headers["user-id"]
117
+ quota_tracker.register_user(request.state.user_id)
118
+
119
+ response = await call_next(request)
120
+ return response
121
+
122
+ async def get_user_id(request: Request):
123
+ return request.state.user_id
124
+
125
+ # Funci贸n para manejar la generaci贸n con control de GPU
126
+ async def process_with_gpu(generation_func, prompt, process_id):
127
+ start_time = time.time()
128
+ print(f"[{process_id}] Iniciando procesamiento GPU")
129
+
130
+ # Buscar usuario con cuota disponible
131
+ user_id = quota_tracker.get_next_available_user()
132
+ if not user_id:
133
+ raise HTTPException(status_code=429, detail="No hay cuota GPU disponible en el sistema")
134
+
135
+ quota_available = quota_tracker.get_remaining_quota(user_id)
136
+ print(f"[{process_id}] Usando cuota de usuario {user_id}: {quota_available}s disponibles")
137
+
138
+ # Verificar si hay suficiente cuota
139
+ if quota_available < GPUQuotaConfig.MAX_REQUEST_DURATION:
140
+ raise HTTPException(status_code=429, detail=f"Cuota GPU insuficiente ({quota_available}s disponibles)")
141
+
142
+ # Verificar que los modelos usen GPU si est谩 disponible
143
+ use_gpu = torch.cuda.is_available()
144
+ device = 'cuda' if use_gpu else 'cpu'
145
+ print(f"[{process_id}] Usando dispositivo: {device}")
146
+
147
+ try:
148
+ # Llamar a la funci贸n de generaci贸n con l铆mite de tiempo
149
+ audio_file_path = await asyncio.to_thread(
150
+ generation_func, prompt, device, user_id
151
+ )
152
+
153
+ # Liberar memoria GPU si se utiliz贸
154
+ if use_gpu:
155
+ torch.cuda.empty_cache()
156
+
157
+ # Calcular tiempo real usado
158
+ elapsed_time = min(GPUQuotaConfig.MAX_REQUEST_DURATION, int(time.time() - start_time))
159
+
160
+ # Consumir cuota
161
+ quota_tracker.consume_quota(user_id, elapsed_time)
162
+ print(f"[{process_id}] Procesamiento completado en {elapsed_time}s, cuota restante: {quota_tracker.get_remaining_quota(user_id)}s")
163
+
164
+ return audio_file_path
165
+
166
+ except Exception as e:
167
+ # Asegurar que liberamos memoria en caso de error
168
+ if use_gpu:
169
+ torch.cuda.empty_cache()
170
+ print(f"[{process_id}] Error: {str(e)}")
171
+ raise e
172
+
173
+
174
+ # Home page with API information
175
+ @app.get("/", response_class=HTMLResponse)
176
+ def home(request: Request):
177
+ return templates.TemplateResponse("home.html", {"request": request})
178
 
179
  # Prueba para verificar si la API funciona - la dejamos por ahora para debugging
180
  @app.get("/health")
 
184
 
185
 
186
  @app.post("/generate-sound/")
187
+ async def generate_sound_endpoint(request: AudioRequest, user_id: str = Depends(get_user_id)):
188
  try:
189
+ process_id = f"sound_{random.randint(1000, 9999)}"
190
+
191
+ # Usar sem谩foro para asegurar acceso exclusivo a GPU
192
+ async with gpu_semaphore:
193
+ audio_file_path = await process_with_gpu(
194
+ generate_sound, request.prompt, process_id
195
+ )
196
 
197
  # Verifica si el archivo se ha generado correctamente
198
  if not os.path.exists(audio_file_path):
 
205
  audio_file_path, media_type="audio/wav", filename="generated_audio.wav"
206
  )
207
 
208
+ except HTTPException as e:
209
+ # Reenviar excepciones HTTP
210
+ raise e
211
  except Exception as e:
212
  raise HTTPException(status_code=500, detail=str(e))
213
 
 
214
  @app.post("/generate-music/")
215
+ async def generate_music_endpoint(request: AudioRequest, user_id: str = Depends(get_user_id)):
216
  try:
217
+ process_id = f"music_{random.randint(1000, 9999)}"
218
+
219
+ # Usar sem谩foro para asegurar acceso exclusivo a GPU
220
+ async with gpu_semaphore:
221
+ audio_file_path = await process_with_gpu(
222
+ generate_music, request.prompt, process_id
223
+ )
224
 
225
+ # Verifica si el archivo se ha generado correctamente
226
  if not os.path.exists(audio_file_path):
227
  raise HTTPException(
228
  status_code=404, detail="Archivo de audio no encontrado."
229
  )
230
 
231
+ # Regresar el archivo generado como una respuesta de descarga
232
  return FileResponse(
233
  audio_file_path, media_type="audio/wav", filename="generated_audio.wav"
234
  )
235
 
236
+ except HTTPException as e:
237
+ # Reenviar excepciones HTTP
238
+ raise e
239
  except Exception as e:
240
  raise HTTPException(status_code=500, detail=str(e))
241
 
242
+ @app.get("/quota-status")
243
+ async def quota_status_endpoint(user_id: str = Depends(get_user_id)):
244
+ user_quota = quota_tracker.get_remaining_quota(user_id)
245
+ system_status = quota_tracker.get_system_status()
246
+
247
+ return {
248
+ "user_id": user_id,
249
+ "quota_remaining": user_quota,
250
+ "reset_time": quota_tracker.user_reset_times.get(user_id, None),
251
+ "system_status": system_status,
252
+ "gpu_available": torch.cuda.is_available(),
253
+ "device_info": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
254
+ }
255
 
 
 
 
 
256
 
257
 
258
  if __name__ == "__main__":
sound_generator.py CHANGED
@@ -16,8 +16,9 @@ music_model = MusicGen.get_pretrained('facebook/musicgen-small')
16
  sound_model.set_generation_params(duration=5)
17
  music_model.set_generation_params(duration=5)
18
 
19
- @spaces.GPU
20
- def generate_sound(prompt: str):
 
21
  """
22
  Generate sound using Audiocraft based on the given prompt.
23
 
@@ -27,6 +28,7 @@ def generate_sound(prompt: str):
27
  Returns:
28
  - str: The path to the saved audio file.
29
  """
 
30
  descriptions = [prompt]
31
  timestamp = str(time.time()).replace(".", "")
32
  wav = sound_model.generate(descriptions) # Generate audio
@@ -36,8 +38,8 @@ def generate_sound(prompt: str):
36
 
37
  return f"{output_path}.wav"
38
 
39
- @spaces.GPU
40
- def generate_music(prompt: str):
41
  """
42
  Generate music using Audiocraft based on the given prompt.
43
 
@@ -47,6 +49,7 @@ def generate_music(prompt: str):
47
  Returns:
48
  - str: The path to the saved audio file.
49
  """
 
50
  descriptions = [prompt]
51
  timestamp = str(time.time()).replace(".", "")
52
  wav = music_model.generate(descriptions) # Generate music
 
16
  sound_model.set_generation_params(duration=5)
17
  music_model.set_generation_params(duration=5)
18
 
19
+
20
+ @spaces.GPU(duration=20)
21
+ def generate_sound(prompt: str, user_id: str):
22
  """
23
  Generate sound using Audiocraft based on the given prompt.
24
 
 
28
  Returns:
29
  - str: The path to the saved audio file.
30
  """
31
+ print(f"Generando sonido para prompt: '{prompt}' en dispositivo {device} (usuario: {user_id})")
32
  descriptions = [prompt]
33
  timestamp = str(time.time()).replace(".", "")
34
  wav = sound_model.generate(descriptions) # Generate audio
 
38
 
39
  return f"{output_path}.wav"
40
 
41
+ @spaces.GPU(duration=20)
42
+ def generate_music(prompt: str, user_id: str):
43
  """
44
  Generate music using Audiocraft based on the given prompt.
45
 
 
49
  Returns:
50
  - str: The path to the saved audio file.
51
  """
52
+ print(f"Generando sonido para prompt: '{prompt}' en dispositivo {device} (usuario: {user_id})")
53
  descriptions = [prompt]
54
  timestamp = str(time.time()).replace(".", "")
55
  wav = music_model.generate(descriptions) # Generate music