# Source listing metadata: file size 7,998 bytes, revision 56064c3.
import base64
import mimetypes
import os
import struct
from google import genai
from google.genai import types
import ffmpy
import datetime
class GeminiTTSWrapper:
    """Thin wrapper around the Gemini text-to-speech models.

    The wrapper builds a ``genai.Client`` from an API key, streams a
    speech-generation request, stores the returned audio under ``output/``
    and optionally converts it to MP3 with ffmpeg (through ``ffmpy``).
    """

    # Prebuilt voice names offered by ``list_available_voices``.
    _VOICES = (
        "Zephyr", "Puck", "Charon", "Kore", "Fenrir", "Leda", "Orus", "Aoede",
        "Callirhoe", "Autonoe", "Enceladus", "Iapetus", "Umbriel", "Algieba",
        "Despina", "Erinome", "Algenib", "Rasalgethi", "Laomedeia", "Achernar",
        "Alnilam", "Schedar", "Gacrux", "Pulcherrima", "Achird", "Zubenelgenubi",
        "Vindemiatrix", "Sadachbia", "Sadalthager", "Sulafat",
    )

    def __init__(self, api_key=None):
        """Initialize the Gemini TTS wrapper with an optional API key.

        Args:
            api_key (str | None): Gemini API key. When omitted, call
                ``set_api_key()`` before ``generate_speech()``.
        """
        self.api_key = api_key
        self.client = None
        # Create output directory if it doesn't exist
        os.makedirs("output", exist_ok=True)
        if api_key:
            self.set_api_key(api_key)

    def set_api_key(self, api_key):
        """Set or update the API key and (re)create the client.

        Args:
            api_key (str): Gemini API key.

        Returns:
            GeminiTTSWrapper: ``self``, so calls can be chained.
        """
        self.api_key = api_key
        self.client = genai.Client(api_key=api_key)
        return self

    def generate_speech(self, text, model="gemini-2.5-pro-preview-tts", voice="Laomedeia",
                        instructions="", temperature=1.0, output_file=None,
                        convert_to_mp3=True):
        """Generate speech from text using Gemini TTS models.

        Audio from every streamed chunk is accumulated and written once, so a
        response delivered in several chunks is saved in full.

        Args:
            text (str): The text to convert to speech.
            model (str): Model to use (gemini-2.5-pro-preview-tts or
                gemini-2.5-flash-preview-tts).
            voice (str): Prebuilt voice name to use.
            instructions (str): Optional instructions for controlling style,
                tone, accent, etc. They are prepended to ``text``.
            temperature (float): Sampling temperature passed to the model.
            output_file (str): Output filename (without extension). Names that
                do not start with ``output/`` are placed under ``output/`` and
                get a timestamp suffix.
            convert_to_mp3 (bool): Whether to convert the output to MP3 format.

        Returns:
            str | None: Path to the saved audio file. This is the MP3 path when
            conversion was requested and succeeded, otherwise the path of the
            file written from the API data. ``None`` when the response
            contained no audio.

        Raises:
            ValueError: If no API key has been set.
        """
        if not self.client:
            raise ValueError("API key not set. Call set_api_key() first.")

        # Generate timestamp for filename
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        if output_file is None:
            output_file = f"output/gemini_tts_{timestamp}"
        elif not output_file.startswith("output/"):
            output_file = f"output/{output_file}_{timestamp}"

        # Prepare the content with instructions if provided
        content_text = f"{instructions}:\n{text}" if instructions else text
        contents = [
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=content_text)],
            ),
        ]
        generate_content_config = types.GenerateContentConfig(
            temperature=temperature,
            response_modalities=["audio"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(
                        voice_name=voice
                    )
                )
            ),
        )

        stream = self.client.models.generate_content_stream(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
        audio_data, mime_type = self._collect_audio(stream)
        if not audio_data:
            return None

        file_extension = mimetypes.guess_extension(mime_type) if mime_type else None
        if file_extension is None:
            # Unknown MIME type: treat the payload as raw PCM and add a WAV
            # header once, over the complete audio.
            file_extension = ".wav"
            audio_data = self._convert_to_wav(audio_data, mime_type)
        audio_path = self._save_binary_file(f"{output_file}{file_extension}", audio_data)

        if not convert_to_mp3:
            return audio_path
        mp3_path = f"{output_file}.mp3"
        if mp3_path == audio_path:
            # The API already delivered an .mp3 file; nothing to convert.
            return audio_path
        # _convert_to_mp3 falls back to the source path when ffmpeg fails, so
        # the returned path always names a file that was actually written.
        return self._convert_to_mp3(audio_path, mp3_path)

    def _collect_audio(self, stream):
        """Gather the inline audio bytes from a stream of response chunks.

        Chunks without candidates, content or parts are skipped. Chunks whose
        first part carries no audio have their ``text`` printed.

        Args:
            stream: Iterable of response chunks as yielded by
                ``generate_content_stream``.

        Returns:
            tuple: ``(audio_bytes, mime_type)`` where ``mime_type`` is taken
            from the first audio chunk, or ``(b"", None)`` without audio.
        """
        pieces = []
        mime_type = None
        for chunk in stream:
            if not chunk.candidates:
                continue
            content = chunk.candidates[0].content
            if content is None or not content.parts:
                continue
            inline_data = content.parts[0].inline_data
            if inline_data and inline_data.data:
                if mime_type is None:
                    mime_type = inline_data.mime_type
                # NOTE(review): assumes all chunks share one MIME type and
                # that their payloads can be concatenated - confirm for
                # container (non raw PCM) formats.
                pieces.append(inline_data.data)
            else:
                print(chunk.text)
        return b"".join(pieces), mime_type

    def _save_binary_file(self, file_name, data):
        """Save binary data to a file, creating its directory if needed.

        Args:
            file_name (str): Destination path.
            data (bytes): Content to write.

        Returns:
            str: ``file_name``.
        """
        directory = os.path.dirname(file_name)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(file_name, "wb") as f:
            f.write(data)
        return file_name

    def _convert_to_wav(self, audio_data, mime_type):
        """Prefix raw PCM audio with a 44-byte mono WAV (RIFF) header.

        Args:
            audio_data (bytes): Raw PCM samples.
            mime_type (str): MIME type describing the samples; sample width
                and rate are read from it (see ``_parse_audio_mime_type``).

        Returns:
            bytes: WAV header followed by ``audio_data``.
        """
        parameters = self._parse_audio_mime_type(mime_type)
        bits_per_sample = parameters["bits_per_sample"]
        sample_rate = parameters["rate"]
        num_channels = 1
        data_size = len(audio_data)
        bytes_per_sample = bits_per_sample // 8
        block_align = num_channels * bytes_per_sample
        byte_rate = sample_rate * block_align
        chunk_size = 36 + data_size  # 36 bytes for header fields before data chunk size

        # http://soundfile.sapp.org/doc/WaveFormat/
        header = struct.pack(
            "<4sI4s4sIHHIIHH4sI",
            b"RIFF",          # ChunkID
            chunk_size,       # ChunkSize (total file size - 8 bytes)
            b"WAVE",          # Format
            b"fmt ",          # Subchunk1ID
            16,               # Subchunk1Size (16 for PCM)
            1,                # AudioFormat (1 for PCM)
            num_channels,     # NumChannels
            sample_rate,      # SampleRate
            byte_rate,        # ByteRate
            block_align,      # BlockAlign
            bits_per_sample,  # BitsPerSample
            b"data",          # Subchunk2ID
            data_size,        # Subchunk2Size (size of audio data)
        )
        return header + audio_data

    def _convert_to_mp3(self, input_file, output_file):
        """Convert an audio file to MP3 format using ffmpeg.

        Conversion is best effort: on any failure the error is printed and the
        path of the unconverted input is returned instead.

        Args:
            input_file (str): Path of the audio file to convert.
            output_file (str): Path of the MP3 file to create.

        Returns:
            str: ``output_file`` on success, ``input_file`` on failure.
        """
        try:
            converter = ffmpy.FFmpeg(
                # -y: overwrite an existing target instead of stopping at
                # ffmpeg's overwrite prompt, matching how the source audio
                # file itself is overwritten.
                global_options=["-y"],
                inputs={input_file: None},
                outputs={output_file: None},
            )
            converter.run()
            return output_file
        except Exception as e:
            print(f"Error converting to MP3: {str(e)}")
            return input_file

    def _parse_audio_mime_type(self, mime_type):
        """Parse audio parameters from a MIME type such as ``audio/L16;rate=24000``.

        Matching is case-insensitive. Missing or malformed values fall back to
        the defaults of 16 bits per sample and 24000 Hz.

        Args:
            mime_type (str | None): MIME type string, optionally with
                ``;``-separated parameters.

        Returns:
            dict: ``{"bits_per_sample": int, "rate": int}``.
        """
        bits_per_sample = 16
        rate = 24000
        type_prefix = "audio/l"

        for param in (mime_type or "").split(";"):
            param = param.strip()
            lowered = param.lower()
            if lowered.startswith("rate="):
                try:
                    rate = int(param.split("=", 1)[1])
                except ValueError:
                    pass  # Keep rate as default
            elif lowered.startswith(type_prefix):
                try:
                    bits_per_sample = int(param[len(type_prefix):])
                except ValueError:
                    pass  # Keep bits_per_sample as default if conversion fails
        return {"bits_per_sample": bits_per_sample, "rate": rate}

    def list_available_voices(self):
        """Return a list of available voice options.

        Returns:
            list: A fresh list of voice names; mutating it does not affect
            later calls.
        """
        return list(self._VOICES)
# End of listing.