# LLMGameHub / src/audio/audio_generator.py
# Commit c55ccab (gsavin): fix: add proper loading screen and fix initial music generation
import asyncio
from google import genai
from google.genai import types
from config import settings
import os
import tempfile
import wave
import numpy as np
import queue
import logging
import gradio as gr
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Shared GenAI client, created at import time with the API key from settings
# and the 'v1alpha' API version.
client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
# Hand-off between receive_audio (producer) and update_audio (consumer).
# maxsize=1 means it holds at most one pending audio chunk.
audio_queue = queue.Queue(maxsize=1)
async def generate_music(request: gr.Request, music_tone: str, receive_audio):
    """Open a live music session, start playback and register it for the request.

    Connects to the 'models/lyria-realtime-exp' model, schedules
    ``receive_audio(session)`` as a background task, sends ``music_tone`` as
    the single weighted prompt (weight 1.0), applies the generation config
    (bpm=90, temperature=1.0) and starts playback. The session is then stored
    in ``sessions`` under ``request.session_hash``.

    Because the receiver task lives in an ``asyncio.TaskGroup``, this coroutine
    does not return until that task finishes or the group is cancelled.

    Args:
        request: Gradio request; its ``session_hash`` is the registry key.
        music_tone: Prompt text describing the music to generate.
        receive_audio: Callable taking the session and returning a coroutine
            that consumes the server messages.
    """
    session_key = request.session_hash
    async with client.aio.live.music.connect(model='models/lyria-realtime-exp') as session:
        try:
            async with asyncio.TaskGroup() as tg:
                # Set up task to receive server messages.
                tg.create_task(receive_audio(session))
                # Send initial prompts and config
                await session.set_weighted_prompts(
                    prompts=[
                        types.WeightedPrompt(text=music_tone, weight=1.0),
                    ]
                )
                await session.set_music_generation_config(
                    config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
                )
                await session.play()
                logger.info(f"Started music generation for session {request.session_hash}, music tone: {music_tone}")
                sessions[session_key] = session
        finally:
            # The connection is about to be closed by the outer context
            # manager; drop it from the registry so nobody keeps using a dead
            # session. Only remove the entry if it is still ours, in case a
            # newer session was registered under the same key meanwhile.
            if sessions.get(session_key) is session:
                del sessions[session_key]
async def change_music_tone(request: gr.Request, new_tone):
    """Replace the prompt of the caller's running music session.

    Looks up the session registered under ``request.session_hash``. When one
    exists, its context is reset and ``new_tone`` becomes the only weighted
    prompt (weight 1.0). When none exists, an error is logged and nothing
    else happens.

    Args:
        request: Gradio request whose ``session_hash`` identifies the session.
        new_tone: Prompt text for the new music tone.
    """
    logger.info(f"Changing music tone to {new_tone}")
    active_session = sessions.get(request.session_hash)
    if active_session:
        await active_session.reset_context()
        replacement_prompt = types.WeightedPrompt(text=new_tone, weight=1.0)
        await active_session.set_weighted_prompts(prompts=[replacement_prompt])
    else:
        logger.error(f"No session found for request {request.session_hash}")
# Frame rate written into the header of every WAV file produced by update_audio.
SAMPLE_RATE = 48000
async def _enqueue_audio_chunk(audio_data, retry_delay=0.01):
    """Put ``audio_data`` on ``audio_queue`` without blocking the event loop.

    ``audio_queue`` is a thread queue, so a plain ``put()`` on a full queue
    would freeze every coroutine until a consumer takes the pending item.
    Instead, retry ``put_nowait`` and await between attempts; this keeps the
    same back-pressure while remaining cancellable.
    """
    while True:
        try:
            audio_queue.put_nowait(audio_data)
            return
        except queue.Full:
            await asyncio.sleep(retry_delay)


async def receive_audio(session):
    """Process incoming audio from the music generation.

    Loops forever over ``session.receive()``. For every message that carries
    audio chunks, the data of the first chunk is placed on ``audio_queue``.
    Any exception is logged and the receive loop is restarted after a
    one-second pause; the coroutine only ends by cancellation.

    Args:
        session: Live music session exposing an async-iterable ``receive()``.
    """
    while True:
        try:
            async for message in session.receive():
                if message.server_content and message.server_content.audio_chunks:
                    audio_data = message.server_content.audio_chunks[0].data
                    await _enqueue_audio_chunk(audio_data)
                # Near-zero sleep: hands control back to the event loop
                # between messages so other tasks can run.
                await asyncio.sleep(10**-12)
        except Exception as e:
            logger.error(f"Error in receive_audio: {e}")
            # Pause before retrying so a persistent failure does not spin.
            await asyncio.sleep(1)
# Registry of live music sessions, keyed by the Gradio request's session_hash.
# Written by generate_music, read by change_music_tone and emptied per entry
# by cleanup_music_session.
sessions = {}
async def start_music_generation(request: gr.Request, music_tone: str):
    """Run music generation for this request with the module's audio receiver.

    Delegates to ``generate_music``, passing ``receive_audio`` as the receiver
    callback, and awaits it; this coroutine therefore finishes only when
    ``generate_music`` does.

    Args:
        request: Gradio request forwarded to ``generate_music``.
        music_tone: Prompt text forwarded to ``generate_music``.
    """
    await generate_music(
        request=request,
        music_tone=music_tone,
        receive_audio=receive_audio,
    )
async def cleanup_music_session(request: gr.Request):
    """Stop and unregister the music session belonging to this request.

    Does nothing when no session is registered under
    ``request.session_hash``. The entry is removed from ``sessions`` *before*
    ``stop()`` is awaited, so overlapping cleanup calls for the same session
    cannot both reach the removal (which previously raised ``KeyError``), and
    a failing ``stop()`` does not leave the entry behind.

    Args:
        request: Gradio request whose ``session_hash`` identifies the session.
    """
    session = sessions.pop(request.session_hash, None)
    if session is None:
        return
    logger.info(f"Cleaning up music session for session {request.session_hash}")
    await session.stop()
# Path of the most recent temporary WAV file yielded by update_audio, or None
# before the first chunk. update_audio removes this file when it writes the
# next one.
current_audio_file = None
def update_audio():
    """Continuously stream audio from the queue.

    Generator that blocks on ``audio_queue.get()``, writes each chunk to a
    fresh temporary ``.wav`` file (2 channels, 16-bit samples, ``SAMPLE_RATE``
    frames per second), deletes the previously yielded file and yields the
    new file's path. It never terminates on its own.

    Yields:
        str: Path of the temporary WAV file holding the latest chunk.
    """
    global current_audio_file
    while True:
        audio_data = audio_queue.get()
        if isinstance(audio_data, bytes):
            audio_array = np.frombuffer(audio_data, dtype=np.int16)
        else:
            audio_array = np.array(audio_data, dtype=np.int16)
        temp_fd, temp_path = tempfile.mkstemp(suffix='.wav')
        # mkstemp returns an open descriptor; close it so wave can reopen by path.
        os.close(temp_fd)
        # Write to WAV file
        with wave.open(temp_path, 'wb') as wav_file:
            wav_file.setnchannels(2)  # Stereo
            wav_file.setsampwidth(2)  # 16-bit
            wav_file.setframerate(SAMPLE_RATE)
            wav_file.writeframes(audio_array.tobytes())
        # Publish the new path first, then discard the old file.
        previous_file = current_audio_file
        current_audio_file = temp_path
        if previous_file:
            try:
                os.remove(previous_file)
            except FileNotFoundError:
                # The global is shared by every running generator, so another
                # one may already have removed this file. A missing stale
                # file must not end the stream.
                pass
        yield temp_path