# Source: googletts/wrapper.py
# Uploaded by daswer123 ("Upload 3 files", commit 56064c3, verified).
import base64
import mimetypes
import os
import struct
from google import genai
from google.genai import types
import ffmpy
import datetime
class GeminiTTSWrapper:
    """Convenience wrapper around the Gemini text-to-speech models.

    Generated audio is written under the ``output/`` directory (relative to the
    current working directory). By default it is also converted to MP3 with
    ffmpeg via ``ffmpy``.
    """

    def __init__(self, api_key=None):
        """Initialize the wrapper and, if a key is given, its API client.

        Args:
            api_key (str | None): Gemini API key. If omitted, call
                ``set_api_key()`` before ``generate_speech()``.
        """
        self.api_key = api_key
        self.client = None
        # Create output directory if it doesn't exist
        os.makedirs("output", exist_ok=True)
        if api_key:
            self.set_api_key(api_key)

    def set_api_key(self, api_key):
        """Set or update the API key and (re)create the ``genai.Client``.

        Returns:
            GeminiTTSWrapper: ``self``, so calls can be chained.
        """
        self.api_key = api_key
        self.client = genai.Client(api_key=api_key)
        return self

    def generate_speech(self, text, model="gemini-2.5-pro-preview-tts", voice="Laomedeia",
                        instructions="", temperature=1.0, output_file=None,
                        convert_to_mp3=True):
        """Generate speech from text using a Gemini TTS model.

        All audio chunks from the response stream are concatenated and saved as
        a single file, so long outputs are not truncated to the last chunk.

        Args:
            text (str): The text to convert to speech.
            model (str): Model to use (gemini-2.5-pro-preview-tts or
                gemini-2.5-flash-preview-tts).
            voice (str): Prebuilt voice name to use.
            instructions (str): Optional instructions (style, tone, accent, ...)
                prepended to the text as ``"<instructions>:\\n<text>"``.
            temperature (float): Sampling temperature passed to the model.
            output_file (str | None): Output filename without extension. Names
                not starting with ``"output/"`` are placed in ``output/`` and
                suffixed with a timestamp; ``None`` yields a timestamped default.
            convert_to_mp3 (bool): Whether to also convert the result to MP3.

        Returns:
            str | None: Path to the saved audio file. This is the MP3 when
            conversion succeeds, otherwise the WAV/native file. Returns ``None``
            when the stream produced no audio.

        Raises:
            ValueError: If no API key / client has been configured.
        """
        if not self.client:
            raise ValueError("API key not set. Call set_api_key() first.")

        output_file = self._resolve_output_base(output_file)

        # Prepare the content with instructions if provided
        content_text = f"{instructions}:\n{text}" if instructions else text
        contents = [
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=content_text)],
            ),
        ]
        generate_content_config = types.GenerateContentConfig(
            temperature=temperature,
            response_modalities=["audio"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(
                        voice_name=voice
                    )
                )
            ),
        )

        audio_chunks = []
        mime_type = None
        stream = self.client.models.generate_content_stream(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
        for chunk in stream:
            # Skip chunks without usable content; empty lists would otherwise
            # raise IndexError on the [0] lookups below.
            candidate = chunk.candidates[0] if chunk.candidates else None
            if (
                candidate is None
                or candidate.content is None
                or not candidate.content.parts
            ):
                continue
            inline_data = candidate.content.parts[0].inline_data
            if inline_data and inline_data.data:
                # NOTE(review): assumes every audio chunk shares the first
                # chunk's MIME type; confirm against the API.
                if mime_type is None:
                    mime_type = inline_data.mime_type
                audio_chunks.append(inline_data.data)
            elif chunk.text:
                print(chunk.text)

        if not audio_chunks:
            return None

        data_buffer, file_extension = self._assemble_audio(audio_chunks, mime_type)
        audio_file_path = f"{output_file}{file_extension}"
        self._save_binary_file(audio_file_path, data_buffer)

        # Convert to MP3 if requested. Skip when the audio already is MP3, since
        # ffmpeg would be asked to overwrite its own input.
        if convert_to_mp3 and file_extension != ".mp3":
            mp3_file_path = f"{output_file}.mp3"
            # _convert_to_mp3 returns the input path when conversion fails, so
            # the caller never receives a path to a file that was not written.
            return self._convert_to_mp3(audio_file_path, mp3_file_path)
        return audio_file_path

    def _resolve_output_base(self, output_file):
        """Return the output path (without extension) for a generation call.

        ``None`` becomes ``output/gemini_tts_<timestamp>``. A name not already
        under ``output/`` becomes ``output/<name>_<timestamp>``. A name starting
        with ``output/`` is used unchanged.
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        if output_file is None:
            return f"output/gemini_tts_{timestamp}"
        if not output_file.startswith("output/"):
            return f"output/{output_file}_{timestamp}"
        return output_file

    def _assemble_audio(self, audio_chunks, mime_type):
        """Join streamed audio chunks into one file payload.

        Args:
            audio_chunks (list[bytes]): Audio payloads in stream order.
            mime_type (str | None): MIME type reported for the audio.

        Returns:
            tuple[bytes, str]: The bytes to write and the file extension to use.
        """
        audio_data = b"".join(audio_chunks)
        file_extension = mimetypes.guess_extension(mime_type) if mime_type else None
        if file_extension is None:
            # Unrecognized type (e.g. raw "audio/L16;rate=24000" PCM): wrap all
            # samples in a single WAV header sized for the combined data.
            return self._convert_to_wav(audio_data, mime_type), ".wav"
        return audio_data, file_extension

    def _save_binary_file(self, file_name, data):
        """Write ``data`` to ``file_name``, creating parent directories.

        Returns:
            str: ``file_name``.
        """
        directory = os.path.dirname(file_name)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(file_name, "wb") as f:
            f.write(data)
        return file_name

    def _convert_to_wav(self, audio_data, mime_type):
        """Prefix raw PCM ``audio_data`` with a 44-byte mono WAV header.

        The sample width and rate come from ``mime_type``; see
        ``_parse_audio_mime_type`` for the defaults.
        """
        parameters = self._parse_audio_mime_type(mime_type)
        bits_per_sample = parameters["bits_per_sample"]
        sample_rate = parameters["rate"]
        num_channels = 1
        data_size = len(audio_data)
        bytes_per_sample = bits_per_sample // 8
        block_align = num_channels * bytes_per_sample
        byte_rate = sample_rate * block_align
        chunk_size = 36 + data_size  # 36 bytes for header fields before data chunk size
        # http://soundfile.sapp.org/doc/WaveFormat/
        header = struct.pack(
            "<4sI4s4sIHHIIHH4sI",
            b"RIFF",          # ChunkID
            chunk_size,       # ChunkSize (total file size - 8 bytes)
            b"WAVE",          # Format
            b"fmt ",          # Subchunk1ID
            16,               # Subchunk1Size (16 for PCM)
            1,                # AudioFormat (1 for PCM)
            num_channels,     # NumChannels
            sample_rate,      # SampleRate
            byte_rate,        # ByteRate
            block_align,      # BlockAlign
            bits_per_sample,  # BitsPerSample
            b"data",          # Subchunk2ID
            data_size,        # Subchunk2Size (size of audio data)
        )
        return header + audio_data

    def _convert_to_mp3(self, input_file, output_file):
        """Convert ``input_file`` to MP3 at ``output_file`` using ffmpeg.

        Conversion is best-effort: on any error the message is printed and
        ``input_file`` is returned instead.

        Returns:
            str: ``output_file`` on success, otherwise ``input_file``.
        """
        try:
            converter = ffmpy.FFmpeg(
                # -y: overwrite an existing target instead of prompting on
                # stdin, which would hang or fail a non-interactive run.
                global_options="-y",
                inputs={input_file: None},
                outputs={output_file: None},
            )
            converter.run()
            return output_file
        except Exception as e:
            print(f"Error converting to MP3: {str(e)}")
            return input_file

    def _parse_audio_mime_type(self, mime_type):
        """Parse sample width and rate from a MIME type like ``audio/L16;rate=24000``.

        Matching is case-insensitive. Values that are missing or unparsable fall
        back to 16 bits per sample and 24000 Hz.

        Returns:
            dict: ``{"bits_per_sample": int, "rate": int}``.
        """
        bits_per_sample = 16
        rate = 24000
        for param in (mime_type or "").split(";"):
            param = param.strip()
            lowered = param.lower()
            if lowered.startswith("rate="):
                try:
                    rate = int(param.split("=", 1)[1])
                except (ValueError, IndexError):
                    pass  # Keep rate as default
            elif lowered.startswith("audio/l"):
                try:
                    bits_per_sample = int(param[len("audio/l"):])
                except ValueError:
                    pass  # Keep bits_per_sample as default if conversion fails
        return {"bits_per_sample": bits_per_sample, "rate": rate}

    def list_available_voices(self):
        """Return the list of prebuilt voice names known to this wrapper."""
        return [
            "Zephyr", "Puck", "Charon", "Kore", "Fenrir", "Leda", "Orus", "Aoede",
            "Callirhoe", "Autonoe", "Enceladus", "Iapetus", "Umbriel", "Algieba",
            "Despina", "Erinome", "Algenib", "Rasalgethi", "Laomedeia", "Achernar",
            "Alnilam", "Schedar", "Gacrux", "Pulcherrima", "Achird", "Zubenelgenubi",
            "Vindemiatrix", "Sadachbia", "Sadalthager", "Sulafat"
        ]