File size: 3,650 Bytes
86b351a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c55ccab
86b351a
 
 
 
 
 
 
 
 
 
c55ccab
86b351a
 
 
 
 
 
c55ccab
86b351a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c55ccab
86b351a
c55ccab
86b351a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import asyncio
from google import genai
from google.genai import types
from config import settings
import os
import tempfile
import wave
import numpy as np
import queue
import logging
import gradio as gr

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Shared Gemini client, built from the API key in `settings`; the 'v1alpha'
# API version is selected through http_options. generate_music() uses it to
# open the live music connection.
client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
# Hand-off buffer between receive_audio() (producer) and update_audio()
# (consumer). It holds at most one audio chunk at a time.
audio_queue = queue.Queue(maxsize=1)

async def generate_music(request: gr.Request, music_tone: str, receive_audio):
    """Open a live music session, start playback and keep it registered.

    The session is stored in the module-level ``sessions`` dict under
    ``request.session_hash`` once playback has started, and is removed again
    when the connection context exits, so that other handlers never pick up
    a session whose connection has already been closed.

    Note that the ``async with`` block only exits once the task created from
    ``receive_audio`` has finished, so this coroutine keeps running for as
    long as that task does.

    Args:
        request: Gradio request; its ``session_hash`` is the registry key.
        music_tone: Text of the single weighted prompt (weight 1.0).
        receive_audio: Coroutine function called with the open session and
            scheduled as a task in the task group.
    """
    session_hash = request.session_hash
    registered = None
    try:
        async with (
            client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
            asyncio.TaskGroup() as tg,
        ):
            # Set up task to receive server messages.
            tg.create_task(receive_audio(session))

            # Send initial prompts and config
            await session.set_weighted_prompts(
                prompts=[
                    types.WeightedPrompt(text=music_tone, weight=1.0),
                ]
            )
            await session.set_music_generation_config(
                config=types.LiveMusicGenerationConfig(bpm=90, temperature=1.0)
            )
            await session.play()
            logger.info(f"Started music generation for session {request.session_hash}, music tone: {music_tone}")
            sessions[session_hash] = session
            registered = session
    finally:
        # Drop the registry entry only if it is still the session opened
        # here; a newer session stored under the same hash must be kept.
        if registered is not None and sessions.get(session_hash) is registered:
            del sessions[session_hash]
        
async def change_music_tone(request: gr.Request, new_tone):
    """Replace the prompt of the caller's running music session.

    Looks up the session registered under ``request.session_hash``. When one
    is found, its context is reset and a single weighted prompt (weight 1.0)
    built from ``new_tone`` is sent. When none is found, an error is logged
    and nothing else happens.
    """
    logger.info(f"Changing music tone to {new_tone}")
    active = sessions.get(request.session_hash)
    if active:
        await active.reset_context()
        replacement = [types.WeightedPrompt(text=new_tone, weight=1.0)]
        await active.set_weighted_prompts(prompts=replacement)
        return
    logger.error(f"No session found for request {request.session_hash}")
        

# Frame rate (Hz) written into the WAV header by update_audio().
SAMPLE_RATE = 48000

async def _enqueue_audio(audio_data, poll_interval=0.01):
    """Put *audio_data* on ``audio_queue`` without blocking the event loop.

    ``audio_queue`` is a thread-style ``queue.Queue``; its blocking ``put()``
    would stall every coroutine on the loop while the queue is full. This
    retries ``put_nowait()`` instead, sleeping cooperatively between
    attempts so the wait stays cancellable.
    """
    while True:
        try:
            audio_queue.put_nowait(audio_data)
            return
        except queue.Full:
            # The consumer has not taken the previous chunk yet.
            await asyncio.sleep(poll_interval)


async def receive_audio(session):
    """Process incoming audio from the music generation.

    Iterates over ``session.receive()`` and forwards the data of the first
    audio chunk of each message to ``audio_queue``. Any ``Exception`` raised
    while receiving is logged and the receive loop is restarted after one
    second; the coroutine only ends when it is cancelled.
    """
    while True:
        try:
            async for message in session.receive():
                if message.server_content and message.server_content.audio_chunks:
                    audio_data = message.server_content.audio_chunks[0].data
                    await _enqueue_audio(audio_data)
                # Yield control so other tasks on the loop can run.
                await asyncio.sleep(10**-12)
        except Exception as e:
            logger.error(f"Error in receive_audio: {e}")
            await asyncio.sleep(1)

# Registry of live music sessions, keyed by the Gradio request's session_hash.
# Written by generate_music(), read by change_music_tone(), and cleared by
# cleanup_music_session().
sessions = {}

async def start_music_generation(request: gr.Request, music_tone: str):
    """Start music generation for this request by awaiting generate_music().

    Passes the module's receive_audio coroutine function as the receiver.
    No thread is created here: the call runs on the current event loop and
    returns only when generate_music() returns.
    """
    await generate_music(request, music_tone, receive_audio)
    
async def cleanup_music_session(request: gr.Request):
    """Stop and unregister the music session belonging to this request.

    The entry is removed from ``sessions`` before ``stop()`` is awaited, so
    the registry never keeps a stale entry when ``stop()`` raises, and a
    concurrent cleanup for the same session hash cannot hit a ``KeyError``.
    Does nothing when no session is registered for the request. Exceptions
    raised by ``stop()`` propagate to the caller.
    """
    session = sessions.pop(request.session_hash, None)
    if session is None:
        return
    logger.info(f"Cleaning up music session for session {request.session_hash}")
    await session.stop()
    
# Path of the temp WAV file most recently produced by update_audio(), or None
# before the first chunk. update_audio() deletes this file once the next
# chunk has been written.
current_audio_file = None

def _remove_quietly(path):
    """Delete *path*, ignoring the case where it no longer exists."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def update_audio():
    """Continuously stream audio from the queue.

    Blocks on ``audio_queue.get()`` for each chunk, writes it to a new
    temporary WAV file (2 channels, 16-bit samples, ``SAMPLE_RATE`` frames
    per second) and yields that file's path. The file yielded on the
    previous iteration is deleted once the next one has been written; a
    previous file that has already disappeared is not treated as an error.

    Yields:
        str: Path of the WAV file holding the latest audio chunk.
    """
    global current_audio_file
    while True:
        audio_data = audio_queue.get()
        if isinstance(audio_data, bytes):
            audio_array = np.frombuffer(audio_data, dtype=np.int16)
        else:
            audio_array = np.array(audio_data, dtype=np.int16)

        temp_fd, temp_path = tempfile.mkstemp(suffix='.wav')
        # mkstemp returns an open descriptor; wave reopens the file by path.
        os.close(temp_fd)
        try:
            # Write to WAV file
            with wave.open(temp_path, 'wb') as wav_file:
                wav_file.setnchannels(2)  # Stereo
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(SAMPLE_RATE)
                wav_file.writeframes(audio_array.tobytes())
        except Exception:
            # Do not leave a half-written temp file behind on failure.
            _remove_quietly(temp_path)
            raise

        if current_audio_file:
            _remove_quietly(current_audio_file)

        current_audio_file = temp_path
        yield temp_path