import asyncio
import logging
import os
import queue
import tempfile
import wave

import gradio as gr
import numpy as np
from google import genai
from google.genai import types

from config import settings

logger = logging.getLogger(__name__)

client = genai.Client(
    api_key=settings.gemini_api_key.get_secret_value(),
    http_options={'api_version': 'v1alpha'},
)

# Raw PCM chunks from Lyria RealTime are 16-bit stereo at 48 kHz.
SAMPLE_RATE = 48000

# Bounded hand-off between the async receiver and the Gradio streaming generator.
audio_queue = queue.Queue(maxsize=1)

# One live music session per Gradio session hash.
sessions = {}


async def generate_music(request: gr.Request, music_tone: str, receive_audio):
    async with (
        client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
        asyncio.TaskGroup() as tg,
    ):
        # Set up a task to receive server messages.
        tg.create_task(receive_audio(session))

        # Send the initial prompt and generation config.
        await session.set_weighted_prompts(
            prompts=[
                types.WeightedPrompt(text=music_tone, weight=1.0),
            ]
        )
        await session.set_music_generation_config(
            config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
        )
        await session.play()
        logger.info(
            f"Started music generation for session {request.session_hash}, "
            f"music tone: {music_tone}"
        )
        sessions[request.session_hash] = session


async def change_music_tone(request: gr.Request, new_tone):
    """Swap the active prompt for the caller's live session."""
    logger.info(f"Changing music tone to {new_tone}")
    session = sessions.get(request.session_hash)
    if not session:
        logger.error(f"No session found for request {request.session_hash}")
        return
    # Reset the generation context before applying the new prompt.
    await session.reset_context()
    await session.set_weighted_prompts(
        prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
    )


async def receive_audio(session):
    """Forward incoming audio chunks from the music generation to the queue."""
    while True:
        try:
            async for message in session.receive():
                if message.server_content and message.server_content.audio_chunks:
                    audio_data = message.server_content.audio_chunks[0].data
                    # put() blocks until update_audio() has consumed the
                    # previous chunk (the queue holds at most one item).
                    audio_queue.put(audio_data)
                # Yield control back to the event loop.
                await asyncio.sleep(10**-12)
        except Exception as e:
            logger.error(f"Error in receive_audio: {e}")
            await asyncio.sleep(1)


async def start_music_generation(request: gr.Request, music_tone: str):
    """Open a live session for this Gradio session and start generating music."""
    await generate_music(request, music_tone, receive_audio)


async def cleanup_music_session(request: gr.Request):
    """Stop and discard the caller's live music session."""
    if request.session_hash in sessions:
        logger.info(f"Cleaning up music session for session {request.session_hash}")
        await sessions[request.session_hash].stop()
        del sessions[request.session_hash]


# Path of the most recently written temp WAV file, so it can be deleted
# once the next chunk replaces it.
current_audio_file = None


def update_audio():
    """Continuously stream audio from the queue as temporary WAV files."""
    global current_audio_file
    while True:
        audio_data = audio_queue.get()
        if isinstance(audio_data, bytes):
            audio_array = np.frombuffer(audio_data, dtype=np.int16)
        else:
            audio_array = np.array(audio_data, dtype=np.int16)

        temp_fd, temp_path = tempfile.mkstemp(suffix='.wav')
        os.close(temp_fd)

        # Write the raw 16-bit PCM chunk to a WAV file.
        with wave.open(temp_path, 'wb') as wav_file:
            wav_file.setnchannels(2)    # Stereo
            wav_file.setsampwidth(2)    # 16-bit samples
            wav_file.setframerate(SAMPLE_RATE)
            wav_file.writeframes(audio_array.tobytes())

        # Drop the previous temp file before handing the new one to Gradio.
        if current_audio_file:
            os.remove(current_audio_file)
        current_audio_file = temp_path
        yield temp_path
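

# A minimal wiring sketch, not part of the original Space: it shows how these
# handlers could plug into a Gradio Blocks UI. The component names (tone_box,
# start_btn, change_btn, audio_out) are illustrative assumptions, and the
# event wiring assumes a Gradio version with Blocks.unload and streaming
# gr.Audio outputs; adapt it to the app's actual layout.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        tone_box = gr.Textbox(label="Music tone", value="chill lo-fi")
        start_btn = gr.Button("Start music")
        change_btn = gr.Button("Change tone")
        audio_out = gr.Audio(label="Live music", streaming=True, autoplay=True)

        # Start the live session, then begin streaming WAV chunks to the player.
        start_btn.click(start_music_generation, inputs=tone_box)
        start_btn.click(update_audio, outputs=audio_out)

        # Update the prompt on the caller's existing session.
        change_btn.click(change_music_tone, inputs=tone_box)

        # Tear the session down when the browser tab closes.
        demo.unload(cleanup_music_session)

    demo.launch()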