Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Gemini TTS API Server | |
A FastAPI-based REST API for Google's Gemini Text-to-Speech service | |
with concurrent request handling and audio format conversion. | |
""" | |
import os | |
import asyncio | |
import json | |
import base64 | |
import uuid | |
from datetime import datetime | |
from typing import Optional, List, Dict, Any | |
from io import BytesIO | |
import aiohttp | |
import aiofiles | |
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request | |
from fastapi.responses import FileResponse, JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel, Field | |
from pydub import AudioSegment | |
import uvicorn | |
# Pydantic models for request/response | |
class VoiceConfig(BaseModel):
    # Prebuilt Gemini voice for one speaker; forwarded as prebuilt_voice_config.voice_name.
    voice_name: str = Field("Zephyr", description="Voice name (e.g., Zephyr, Puck)")
class SpeakerConfig(BaseModel):
    # One entry of a multi-speaker request: a speaker label and the voice it uses.
    speaker: str = Field(..., description="Speaker identifier")  # required
    voice_config: VoiceConfig
class TTSRequest(BaseModel):
    # Body of POST /tts. Either `speakers` (multi-speaker) or `voice_name` (single voice) is used.
    # Empty text is rejected up front instead of being retried against every API key.
    text: str = Field(description="Text to convert to speech", min_length=1)
    speakers: Optional[List[SpeakerConfig]] = Field(
        default=None,
        description="Multi-speaker configuration (optional)"
    )
    voice_name: Optional[str] = Field(
        default="Zephyr",
        description="Single voice name (used if speakers not provided)"
    )
    output_format: str = Field(default="wav", description="Output format: wav or mp3")
    # Must be > 0: the speed change rescales the sample rate, which cannot be zero or negative.
    speed_factor: float = Field(default=1.0, gt=0, description="Speed adjustment factor")
    temperature: float = Field(default=1.0, description="Generation temperature")
class TTSResponse(BaseModel):
    # Returned when a TTS task is created; audio_url and metadata default to None.
    task_id: str
    status: str
    message: str
    audio_url: Optional[str] = Field(default=None)
    metadata: Optional[Dict[str, Any]] = Field(default=None)
class TaskStatus(BaseModel):
    # Snapshot of a task as reported by the status endpoint; optional fields default to None.
    task_id: str
    status: str
    progress: Optional[str] = Field(default=None)
    error: Optional[str] = Field(default=None)
    result: Optional[Dict[str, Any]] = Field(default=None)
# In-memory registry of TTS tasks keyed by task_id.
# It lives only in this process; use Redis or a database in production / multi-worker setups.
tasks: Dict[str, Dict[str, Any]] = dict()

# FastAPI application instance
app = FastAPI(
    version="1.0.0",
    title="Gemini TTS API",
    description="Text-to-Speech API using Google's Gemini model with concurrent request handling",
)

# CORS: every origin, method and header is allowed, with credentials.
# Configure appropriately for production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Configuration | |
def get_api_keys() -> List[str]:
    """
    Collect Gemini API keys from environment variables.

    Sources, in order:
      * GEMINI_API_KEY            - a single key (backward compatibility)
      * GEMINI_API_KEYS           - comma-separated list of keys
      * GEMINI_API_KEY_1, _2, ... - numbered keys; scanning stops at the first
                                    variable that is unset or empty

    Every key is stripped of surrounding whitespace. Values that are empty after
    stripping are ignored (previously they produced an "" key). Duplicates are
    removed, keeping the first occurrence's position.

    Returns:
        The de-duplicated list of keys (possibly empty).
    """
    candidates: List[str] = [os.getenv("GEMINI_API_KEY") or ""]
    candidates.extend((os.getenv("GEMINI_API_KEYS") or "").split(","))

    index = 1
    while True:
        numbered = os.getenv(f"GEMINI_API_KEY_{index}")
        if not numbered:
            break
        candidates.append(numbered)
        index += 1

    # dict.fromkeys de-duplicates while preserving insertion order; the filter drops blanks.
    stripped = (raw.strip() for raw in candidates)
    return [key for key in dict.fromkeys(stripped) if key]
# Keys are read once at import time; nothing re-reads them, so a restart is needed to pick up changes.
GEMINI_API_KEYS: List[str] = get_api_keys()
# Key-rotation / retry policy
MAX_RETRIES_PER_KEY = 2
RATE_LIMIT_RETRY_DELAY = 60  # seconds a key stays benched after a rate-limit (or auth) error

# Local storage and concurrency limits
MAX_CONCURRENT_REQUESTS = 10
OUTPUT_DIR = "/tmp/audio_files"

# Gemini model and REST method used for generation
GENERATE_CONTENT_API = "streamGenerateContent"
MODEL_ID = "gemini-2.5-flash-preview-tts"
# API key management | |
class APIKeyManager:
    """
    Rotate through Gemini API keys and bench keys that recently hit a rate limit.

    Keys are handed out round-robin. A key marked rate-limited is skipped for
    RATE_LIMIT_RETRY_DELAY seconds. If every key is benched, the key whose
    rate limit is oldest is returned anyway.
    """

    def __init__(self, api_keys: List[str]):
        self.api_keys = api_keys
        # Position of the key to try first on the next get_next_key() call.
        self.current_key_index = 0
        # Per-key counters; last_rate_limit is a wall-clock timestamp (seconds) or None.
        self.key_stats = {
            key: {"requests": 0, "failures": 0, "last_rate_limit": None} for key in api_keys
        }
        self.lock = asyncio.Lock()

    async def get_next_key(self) -> Optional[str]:
        """
        Return the next usable API key, or None when no keys are configured.

        The cursor always moves past the key it returns. Consecutive calls, and
        retries after non-rate-limit failures, therefore get different keys.
        """
        async with self.lock:
            if not self.api_keys:
                return None
            key_count = len(self.api_keys)
            now = datetime.now().timestamp()
            for _ in range(key_count):
                key = self.api_keys[self.current_key_index]
                # Advance first so the next call starts after this key, usable or not.
                self.current_key_index = (self.current_key_index + 1) % key_count
                stats = self.key_stats[key]
                last_limit = stats["last_rate_limit"]
                if last_limit:
                    if now - last_limit < RATE_LIMIT_RETRY_DELAY:
                        continue  # still cooling down
                    stats["last_rate_limit"] = None  # cooldown elapsed; clear it
                stats["requests"] += 1
                return key
            # Every key is cooling down: fall back to the one limited longest ago.
            oldest_key = min(
                self.api_keys,
                key=lambda k: self.key_stats[k]["last_rate_limit"] or 0
            )
            self.key_stats[oldest_key]["requests"] += 1
            return oldest_key

    async def mark_rate_limited(self, api_key: str):
        """Bench `api_key` for RATE_LIMIT_RETRY_DELAY seconds and count a failure."""
        async with self.lock:
            if api_key in self.key_stats:
                self.key_stats[api_key]["last_rate_limit"] = datetime.now().timestamp()
                self.key_stats[api_key]["failures"] += 1

    async def mark_success(self, api_key: str):
        """Record a successful call by decrementing the key's failure count (floored at 0)."""
        async with self.lock:
            if api_key in self.key_stats:
                self.key_stats[api_key]["failures"] = max(0, self.key_stats[api_key]["failures"] - 1)

    def get_stats(self) -> dict:
        """Return a snapshot of per-key statistics that callers may mutate freely."""
        return {
            "total_keys": len(self.api_keys),
            # Copy the inner dicts too so callers cannot alter live counters.
            "key_stats": {key: dict(stats) for key, stats in self.key_stats.items()},
        }
# Runtime singletons shared by request handlers and background tasks.
os.makedirs(OUTPUT_DIR, exist_ok=True)  # generated audio files are written here
api_key_manager = APIKeyManager(GEMINI_API_KEYS)  # key rotation / rate-limit tracking
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)  # caps concurrent TTS generations
async def convert_and_adjust_audio(
    audio_data: bytes,
    output_format: str = "wav",
    speed_factor: float = 1.0,
    sample_rate: int = 24000,
) -> tuple[bytes, str]:
    """
    Convert raw PCM audio to WAV or MP3, optionally changing playback speed.

    The CPU-bound pydub work runs in the default thread-pool executor so the
    event loop is not blocked.

    Args:
        audio_data: Raw 16-bit mono PCM samples.
        output_format: "mp3" (case-insensitive) for MP3 at 128 kbps; anything
            else produces WAV.
        speed_factor: Playback-speed multiplier. The samples are re-interpreted
            at `sample_rate * speed_factor` and then resampled back, so pitch
            shifts together with tempo.
        sample_rate: Sample rate of `audio_data` in Hz (defaults to 24 kHz).

    Returns:
        (encoded_bytes, file_extension), where the extension is "mp3" or "wav".

    Raises:
        ValueError: If speed_factor would give a frame rate below 1 Hz
            (e.g. zero or negative).
    """
    adjusted_rate = int(sample_rate * speed_factor)
    if adjusted_rate < 1:
        raise ValueError(f"speed_factor must be positive, got {speed_factor}")

    def _convert() -> tuple[bytes, str]:
        audio = AudioSegment(
            data=audio_data,
            sample_width=2,  # 16-bit = 2 bytes
            frame_rate=sample_rate,
            channels=1  # mono
        )
        if speed_factor != 1.0:
            # Same samples played at a different rate, then resampled to the original rate.
            sped = audio._spawn(audio.raw_data, overrides={"frame_rate": adjusted_rate})
            audio = sped.set_frame_rate(sample_rate)

        file_ext = "mp3" if output_format.lower() == "mp3" else "wav"
        buffer = BytesIO()
        if file_ext == "mp3":
            audio.export(buffer, format="mp3", bitrate="128k")
        else:
            audio.export(buffer, format="wav")
        return buffer.getvalue(), file_ext

    # get_running_loop() is the supported call inside a coroutine (get_event_loop is deprecated here).
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _convert)
def _build_tts_request_data(
    text: str,
    speakers: Optional[List[SpeakerConfig]],
    voice_name: Optional[str],
    temperature: float,
) -> Dict[str, Any]:
    """Build the JSON body for a Gemini TTS request (single- or multi-speaker)."""
    if speakers:
        # Multi-speaker configuration: one prebuilt voice per named speaker
        speech_config = {
            "multi_speaker_voice_config": {
                "speaker_voice_configs": [
                    {
                        "speaker": speaker.speaker,
                        "voice_config": {
                            "prebuilt_voice_config": {
                                "voice_name": speaker.voice_config.voice_name
                            }
                        }
                    }
                    for speaker in speakers
                ]
            }
        }
    else:
        # Single voice configuration; fall back to the default voice if the client sent null
        speech_config = {
            "voice_config": {
                "prebuilt_voice_config": {
                    "voice_name": voice_name or "Zephyr"
                }
            }
        }
    return {
        "contents": [
            {
                "role": "user",
                "parts": [{"text": text}]
            }
        ],
        "generationConfig": {
            "responseModalities": ["audio"],
            "temperature": temperature,
            "speech_config": speech_config
        }
    }


async def _call_gemini_with_key_rotation(
    task: Dict[str, Any],
    request_data: Dict[str, Any],
) -> Any:
    """
    POST `request_data` to Gemini, rotating API keys until a call returns HTTP 200.

    Writes per-attempt progress into task["progress"] and returns the decoded
    JSON body of the first successful response.

    Raises:
        HTTPException: If no key is available, or once all attempts are used up.
    """
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{MODEL_ID}:{GENERATE_CONTENT_API}"
    last_error = None
    max_total_attempts = len(GEMINI_API_KEYS) * MAX_RETRIES_PER_KEY if GEMINI_API_KEYS else 1

    # One session (and connection pool) is shared by every attempt for this task.
    async with aiohttp.ClientSession() as session:
        for attempts in range(1, max_total_attempts + 1):
            current_api_key = await api_key_manager.get_next_key()
            if not current_api_key:
                raise HTTPException(status_code=500, detail="No API keys available")
            task["progress"] = f"Attempting API call (attempt {attempts}/{max_total_attempts})"
            try:
                async with session.post(
                    url,
                    headers={"Content-Type": "application/json"},
                    params={"key": current_api_key},
                    json=request_data,
                    timeout=aiohttp.ClientTimeout(total=120)  # 2 minute timeout
                ) as response:
                    if response.status == 200:
                        await api_key_manager.mark_success(current_api_key)
                        return await response.json()
                    error_text = await response.text()
                    if response.status == 429:  # Rate limit exceeded
                        await api_key_manager.mark_rate_limited(current_api_key)
                        last_error = f"Rate limit exceeded for API key: {error_text}"
                        print(f"Rate limit hit for key ending in ...{current_api_key[-4:]}, trying next key")
                    elif response.status in (403, 401):  # Auth errors
                        # Reuse the rate-limit cooldown to temporarily disable this key
                        await api_key_manager.mark_rate_limited(current_api_key)
                        last_error = f"Authentication error: {error_text}"
                        print(f"Auth error for key ending in ...{current_api_key[-4:]}: {error_text}")
                    else:
                        # Other HTTP errors: don't penalize the key, but move on to the next attempt
                        last_error = f"HTTP {response.status}: {error_text}"
            except asyncio.TimeoutError:
                last_error = "Request timeout"
            except aiohttp.ClientError as e:
                last_error = f"Client error: {str(e)}"
            except Exception as e:
                last_error = f"Unexpected error: {str(e)}"

    # All attempts failed
    raise HTTPException(
        status_code=500,
        detail=f"All API keys exhausted. Last error: {last_error}"
    )


def _extract_pcm_audio(response_data: Any) -> bytes:
    """
    Decode and concatenate every inline audio part in a Gemini response.

    streamGenerateContent returns a JSON array of chunks, and long audio can be
    split across several of them, so all chunks are read, not just the first.
    A single (non-array) response object is also accepted.

    Raises:
        HTTPException: If no chunk has candidates, or no audio data is present.
    """
    chunks = response_data if isinstance(response_data, list) else [response_data]
    audio = bytearray()
    saw_candidates = False
    for chunk in chunks:
        if not isinstance(chunk, dict):
            continue
        candidates = chunk.get("candidates") or []
        if not candidates:
            continue
        saw_candidates = True
        for part in candidates[0].get("content", {}).get("parts", []):
            inline = part.get("inlineData")
            if inline and inline.get("data"):
                audio.extend(base64.b64decode(inline["data"]))
    if not saw_candidates:
        raise HTTPException(status_code=500, detail="No candidates in response")
    if not audio:
        raise HTTPException(status_code=500, detail="No audio data found in response")
    return bytes(audio)


async def generate_tts_audio(
    task_id: str,
    text: str,
    speakers: Optional[List[SpeakerConfig]] = None,
    voice_name: str = "Zephyr",
    output_format: str = "wav",
    speed_factor: float = 1.0,
    temperature: float = 1.0
):
    """
    Background job: synthesize `text` with Gemini TTS and save it to OUTPUT_DIR.

    Progress, the final result (file name/path/URL and sizes) or the error
    message are written into the task's entry in `tasks`. Failures never
    propagate; they mark the task "failed".
    """
    async with semaphore:  # Limit concurrent requests
        # Look the task up only after acquiring a slot: it may have been deleted while queued.
        task = tasks.get(task_id)
        if task is None:
            return
        try:
            task["status"] = "processing"
            task["progress"] = "Preparing request"
            request_data = _build_tts_request_data(text, speakers, voice_name, temperature)

            task["progress"] = "Calling Gemini API"
            response_data = await _call_gemini_with_key_rotation(task, request_data)

            task["progress"] = "Processing audio data"
            audio_data = _extract_pcm_audio(response_data)

            task["progress"] = "Converting audio format"
            converted_audio, file_ext = await convert_and_adjust_audio(
                audio_data, output_format, speed_factor
            )

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"gemini_audio_{task_id}_{timestamp}.{file_ext}"
            filepath = os.path.join(OUTPUT_DIR, filename)
            async with aiofiles.open(filepath, "wb") as f:
                await f.write(converted_audio)

            task.update({
                "status": "completed",
                "progress": "Completed",
                "result": {
                    "filename": filename,
                    "filepath": filepath,
                    # Report the format actually produced (unknown formats fall back to WAV)
                    "format": file_ext.upper(),
                    "speed_factor": speed_factor,
                    "original_size": len(audio_data),
                    "converted_size": len(converted_audio),
                    "audio_url": f"/audio/{filename}"
                }
            })
        except Exception as e:
            task.update({
                "status": "failed",
                "error": str(e)
            })
# API Endpoints | |
@app.get("/")
async def root():
    """Root endpoint: return API name, version and a summary of the main endpoints."""
    # The handler needs a route decorator to be served at all.
    return {
        "message": "Gemini TTS API Server",
        "version": "1.0.0",
        "endpoints": {
            "POST /tts": "Generate TTS audio",
            "GET /status/{task_id}": "Get task status",
            "GET /audio/{filename}": "Download audio file",
            "GET /tasks": "List all tasks"
        }
    }
@app.post("/tts", response_model=TTSResponse)
async def create_tts_task(
    request: TTSRequest,
    background_tasks: BackgroundTasks
):
    """
    Create a TTS generation task and schedule it in the background.

    Returns immediately with status "queued". Poll GET /status/{task_id} for
    progress and the resulting audio URL.

    Raises:
        HTTPException(500): If no Gemini API keys are configured.
    """
    if not GEMINI_API_KEYS:
        raise HTTPException(status_code=500, detail="No GEMINI_API_KEYs configured. Please set GEMINI_API_KEY, GEMINI_API_KEYS, or GEMINI_API_KEY_1, GEMINI_API_KEY_2, etc.")

    task_id = str(uuid.uuid4())
    # Pydantic v2 renamed .dict() to .model_dump(); support both major versions.
    request_payload = request.model_dump() if hasattr(request, "model_dump") else request.dict()
    tasks[task_id] = {
        "task_id": task_id,
        "status": "queued",
        "created_at": datetime.now().isoformat(),
        "request": request_payload
    }

    background_tasks.add_task(
        generate_tts_audio,
        task_id,
        request.text,
        request.speakers,
        request.voice_name,
        request.output_format,
        request.speed_factor,
        request.temperature
    )

    return TTSResponse(
        task_id=task_id,
        status="queued",
        message="TTS generation task created successfully"
    )
@app.get("/status/{task_id}", response_model=TaskStatus)
async def get_task_status(task_id: str):
    """
    Return the status, progress, error and result of a TTS task.

    Raises:
        HTTPException(404): If the task id is unknown.
    """
    task = tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return TaskStatus(
        task_id=task_id,
        status=task["status"],
        progress=task.get("progress"),
        error=task.get("error"),
        result=task.get("result")
    )
@app.get("/audio/{filename}")
async def download_audio(filename: str):
    """
    Serve a generated audio file from OUTPUT_DIR as an attachment.

    `filename` comes from the URL and is untrusted. Only a bare file name that
    resolves to a regular file directly inside OUTPUT_DIR is served. Path
    separators, "..", absolute paths and symlinks escaping the directory all
    yield 404.

    Raises:
        HTTPException(404): If the name is unsafe or the file does not exist.
    """
    safe_name = os.path.basename(filename)
    output_root = os.path.realpath(OUTPUT_DIR)
    filepath = os.path.realpath(os.path.join(output_root, safe_name))
    if (
        safe_name != filename
        or os.path.dirname(filepath) != output_root
        or not os.path.isfile(filepath)
    ):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(
        filepath,
        media_type="application/octet-stream",
        filename=filename
    )
@app.get("/tasks")
async def list_tasks():
    """Return every stored task record (status, timestamps, request and result/error)."""
    return {"tasks": list(tasks.values())}
async def delete_task(task_id: str):
    """
    Delete a task record and its generated audio file, if any.

    NOTE(review): no route decorator is visible for this handler, so FastAPI
    does not serve it. It is presumably meant to be DELETE /tasks/{task_id};
    confirm and register it.

    Raises:
        HTTPException(404): If the task id is unknown.
    """
    task = tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    filepath = (task.get("result") or {}).get("filepath")
    if filepath:
        # Remove directly instead of exists()+remove(), which could race with another deleter.
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass  # already gone; nothing to clean up

    tasks.pop(task_id, None)
    return {"message": "Task deleted successfully"}
async def health_check():
    """
    Report service health, task counts and API-key availability.

    A key counts as rate-limited while fewer than RATE_LIMIT_RETRY_DELAY
    seconds have passed since it was last marked. All keys are judged against
    one timestamp, so available + rate_limited always equals total_configured.

    NOTE(review): no route decorator is visible for this handler; confirm the
    intended path (e.g. GET /health) and register it.
    """
    api_stats = api_key_manager.get_stats()
    now = datetime.now().timestamp()
    rate_limited = sum(
        1
        for stats in api_stats["key_stats"].values()
        if stats["last_rate_limit"] and now - stats["last_rate_limit"] <= RATE_LIMIT_RETRY_DELAY
    )
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "active_tasks": sum(1 for t in tasks.values() if t["status"] in ("queued", "processing")),
        "total_tasks": len(tasks),
        "api_keys": {
            "total_configured": api_stats["total_keys"],
            "available_keys": api_stats["total_keys"] - rate_limited,
            "rate_limited_keys": rate_limited
        }
    }
async def get_api_key_stats():
    """
    Return per-key statistics with the keys themselves masked.

    Each key is shown as "***" plus its last 4 characters ("***" alone for keys
    of 4 characters or fewer). When two keys would get the same mask, later
    ones get a "#2", "#3", ... suffix so no entry is silently overwritten.

    NOTE(review): no route decorator is visible for this handler; confirm the
    intended path and register it.
    """
    stats = api_key_manager.get_stats()
    now = datetime.now().timestamp()
    masked_stats: Dict[str, Dict[str, Any]] = {}
    for key, data in stats["key_stats"].items():
        masked_key = f"***{key[-4:]}" if len(key) > 4 else "***"
        if masked_key in masked_stats:
            suffix = 2
            while f"{masked_key}#{suffix}" in masked_stats:
                suffix += 1
            masked_key = f"{masked_key}#{suffix}"

        last_limit = data["last_rate_limit"]
        elapsed = now - last_limit if last_limit else None
        masked_stats[masked_key] = {
            **data,
            "is_rate_limited": elapsed is not None and elapsed <= RATE_LIMIT_RETRY_DELAY,
            "time_until_available": (
                max(0, RATE_LIMIT_RETRY_DELAY - elapsed) if elapsed is not None else 0
            )
        }

    return {
        "total_keys": stats["total_keys"],
        "key_statistics": masked_stats,
        "rate_limit_settings": {
            "retry_delay_seconds": RATE_LIMIT_RETRY_DELAY,
            "max_retries_per_key": MAX_RETRIES_PER_KEY
        }
    }
# Cleanup of old audio files and finished tasks. Nothing in this module schedules it;
# call it periodically (e.g. from APScheduler) to enable cleanup.
async def cleanup_old_files(max_age_seconds: float = 3600.0) -> Dict[str, int]:
    """
    Delete generated audio files and finished task records older than `max_age_seconds`.

    Only files in OUTPUT_DIR whose names start with "gemini_audio_" (the prefix
    used by generate_tts_audio) are considered, judged by modification time.
    Only tasks in "completed" or "failed" state are pruned, judged by their
    "created_at" ISO timestamp. Queued and processing tasks are never touched.

    Returns:
        {"removed_files": <count>, "removed_tasks": <count>}
    """
    cutoff = datetime.now().timestamp() - max_age_seconds

    removed_files = 0
    try:
        entries = list(os.scandir(OUTPUT_DIR))
    except FileNotFoundError:
        entries = []
    for entry in entries:
        if not (entry.name.startswith("gemini_audio_") and entry.is_file()):
            continue
        try:
            if entry.stat().st_mtime < cutoff:
                os.remove(entry.path)
                removed_files += 1
        except FileNotFoundError:
            continue  # removed concurrently

    stale_task_ids = []
    for task_id, task in tasks.items():
        if task.get("status") not in ("completed", "failed"):
            continue
        try:
            created = datetime.fromisoformat(task["created_at"]).timestamp()
        except (KeyError, TypeError, ValueError):
            continue  # no usable timestamp; leave the task alone
        if created < cutoff:
            stale_task_ids.append(task_id)
    # Remove after iterating so the dict is not mutated mid-iteration.
    for task_id in stale_task_ids:
        tasks.pop(task_id, None)

    return {"removed_files": removed_files, "removed_tasks": len(stale_task_ids)}
if __name__ == "__main__":
    # uvicorn needs a "module:attribute" import string when reload=True. Derive the
    # module from this file so the server still starts if the file is renamed (e.g. app.py).
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(
        f"{module_name}:app",
        host="0.0.0.0",
        port=8000,
        reload=True,  # Set to False in production
        workers=1  # Use multiple workers in production with proper task storage
    )