Spaces:
Sleeping
Sleeping
# Import configuration first to setup environment | |
import app_config | |
import os | |
import sys | |
import io | |
import subprocess | |
import uuid | |
import time | |
import torch | |
import torchaudio | |
import tempfile | |
import logging | |
from typing import Optional | |
# Fix PyTorch weights_only issue for XTTS
import torch.serialization
from TTS.tts.configs.xtts_config import XttsConfig

# Newer torch releases default torch.load to weights_only=True, which refuses to
# unpickle arbitrary classes such as XttsConfig stored in the checkpoint, so the
# class is allow-listed here. add_safe_globals does not exist on older torch
# (which still loads full pickles by default), so skip the call there instead of
# crashing at import time.
# NOTE(review): depending on the TTS/torch versions, other classes pickled in the
# checkpoint (e.g. XttsAudioConfig, XttsArgs, BaseDatasetConfig) may also need
# allow-listing — confirm against the installed versions.
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([XttsConfig])
# Set environment variables.
# NOTE(review): COQUI_TOS_AGREED=1 presumably pre-accepts Coqui's model licence so
# TTS never prompts interactively; NUMBA_DISABLE_JIT=1 turns off numba's JIT
# compilation — confirm why that is needed in this deployment.
os.environ.update({"COQUI_TOS_AGREED": "1", "NUMBA_DISABLE_JIT": "1"})

# FORCE_CPU=true (case-insensitive) hides every CUDA device from this process.
# NOTE(review): torch is imported earlier in the file, so this only takes effect if
# CUDA has not been initialised before this point — confirm nothing touches CUDA first.
if os.getenv("FORCE_CPU", "false").lower() == "true":
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
from fastapi import FastAPI, HTTPException, UploadFile, File, Form | |
from fastapi.responses import FileResponse | |
from pydantic import BaseModel | |
import langid | |
from scipy.io.wavfile import write | |
from pydub import AudioSegment | |
from TTS.api import TTS | |
from TTS.tts.configs.xtts_config import XttsConfig | |
from TTS.tts.models.xtts import Xtts | |
from TTS.utils.generic_utils import get_user_data_dir | |
# Configure root logging at INFO (a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the service class and the endpoint handlers.
logger = logging.getLogger(__name__)

# FastAPI application; title/description/version feed the generated OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="XTTS C3PO API",
    description="Text-to-Speech API using XTTS-v2 C3PO model",
)
class TTSRequest(BaseModel):
    # Request schema for the JSON text-to-speech endpoint (text_to_speech_json).
    # Length limits on `text` are enforced later, in XTTSService.generate_speech.
    text: str
    # Language code; generate_speech rejects codes not in the model's supported list.
    language: str = "en"
    # Forwarded to XTTSService._cleanup_audio, which currently returns the clip unchanged.
    voice_cleanup: bool = False
    # True skips the langid check that only logs a warning on a language mismatch.
    no_lang_auto_detect: bool = False
class XTTSService:
    """Load the XTTS-v2 C3PO checkpoint once and synthesize speech with it.

    The model lives in ``XTTS-v2_C3PO/`` relative to the working directory and
    is git-cloned from Hugging Face when its ``config.json`` is missing.
    """

    # Maximum number of characters generate_speech accepts.
    MAX_TEXT_LENGTH = 500
    # Sample rate written into the output WAV.
    # NOTE(review): assumed to match the XTTS-v2 decoder output (24 kHz) — confirm against config.audio.
    OUTPUT_SAMPLE_RATE = 24000
    # File extensions (matched case-insensitively) accepted as a fallback reference clip.
    REFERENCE_EXTENSIONS = (".wav", ".mp3", ".m4a")

    def __init__(self):
        """Pick a device, ensure the model is on disk, load it, and find a default voice.

        Raises:
            RuntimeError: if the model has to be downloaded and the download fails.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Use the C3PO model path
        self.model_path = "XTTS-v2_C3PO/"
        self.config_path = "XTTS-v2_C3PO/config.json"

        # Download the model on first start if it is not present locally.
        if not os.path.exists(self.config_path):
            logger.info("C3PO model not found locally, downloading...")
            self._download_c3po_model()

        config = XttsConfig()
        config.load_json(self.config_path)
        self.model = self._load_model(config)
        self.supported_languages = config.languages
        logger.info(f"XTTS C3PO model loaded successfully. Supported languages: {self.supported_languages}")

        # Reference clip used when a request does not upload its own speaker audio.
        self.default_reference = self._find_default_reference()
        if self.default_reference:
            logger.info(f"Default C3PO reference audio: {self.default_reference}")
        else:
            logger.warning("No default reference audio found in C3PO model directory")

    def _load_model(self, config):
        """Build an Xtts model from *config*, load its weights, and move it to ``self.device``."""
        model = Xtts.init_from_config(config)
        model.load_checkpoint(
            config,
            checkpoint_path=os.path.join(self.model_path, "model.pth"),
            vocab_path=os.path.join(self.model_path, "vocab.json"),
            eval=True,
        )
        if self.device == "cuda":
            model.cuda()
        return model

    def _find_default_reference(self) -> Optional[str]:
        """Return the reference clip to use when a request supplies none.

        Prefers ``reference.wav`` in the model directory. Otherwise returns the
        first file, in sorted name order so the choice is deterministic, whose
        extension is in ``REFERENCE_EXTENSIONS`` (case-insensitive). Returns
        ``None`` when no such file exists.
        """
        preferred = os.path.join(self.model_path, "reference.wav")
        if os.path.exists(preferred):
            return preferred
        for name in sorted(os.listdir(self.model_path)):
            if name.lower().endswith(self.REFERENCE_EXTENSIONS):
                return os.path.join(self.model_path, name)
        return None

    def _download_c3po_model(self):
        """Clone the C3PO model repository from Hugging Face into ``XTTS-v2_C3PO``.

        Raises:
            RuntimeError: if ``git`` cannot be run or the clone fails. This runs
                during service construction, outside any HTTP request, so an
                HTTPException would be misleading here.
        """
        # NOTE(review): model weights on Hugging Face are typically stored with Git LFS;
        # without git-lfs installed the clone may contain pointer files instead of
        # weights — confirm the deployment image provides git-lfs.
        logger.info("Downloading C3PO model from Hugging Face...")
        try:
            subprocess.run(
                ["git", "clone", "https://huggingface.co/Borcherding/XTTS-v2_C3PO", "XTTS-v2_C3PO"],
                check=True,
            )
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers a missing/unrunnable git binary.
            logger.error(f"Failed to download C3PO model: {e}")
            raise RuntimeError("Failed to download C3PO model") from e
        logger.info("C3PO model downloaded successfully")

    def generate_speech(self, text: str, speaker_wav_path: Optional[str] = None, language: str = "en",
                        voice_cleanup: bool = False, no_lang_auto_detect: bool = False) -> str:
        """Synthesize *text* in the voice of a reference clip and save it as a WAV file.

        Args:
            text: Text to speak; must be 2..MAX_TEXT_LENGTH characters long.
            speaker_wav_path: Reference audio to clone. ``None`` uses the default
                C3PO reference found at start-up.
            language: Language code; must be in ``self.supported_languages``.
            voice_cleanup: Run the reference through ``_cleanup_audio`` first.
            no_lang_auto_detect: Skip the langid check, which only logs a warning
                when the detected language differs from *language*.

        Returns:
            Path of a new WAV file in the system temp directory. The caller owns
            the file and is responsible for deleting it.

        Raises:
            HTTPException: 400 for invalid input or an unusable reference clip,
                500 for any other failure during synthesis.
        """
        try:
            # Use default C3PO voice if no speaker file provided
            if speaker_wav_path is None:
                if self.default_reference is None:
                    raise HTTPException(status_code=400, detail="No reference audio available. Please upload a speaker file.")
                speaker_wav_path = self.default_reference
                logger.info("Using default C3PO voice")

            if language not in self.supported_languages:
                raise HTTPException(status_code=400, detail=f"Language '{language}' not supported. Supported: {self.supported_languages}")

            # Validate length before spending time on language detection.
            if len(text) < 2:
                raise HTTPException(status_code=400, detail="Text too short, please provide longer text")
            if len(text) > self.MAX_TEXT_LENGTH:
                raise HTTPException(status_code=400, detail=f"Text too long, maximum {self.MAX_TEXT_LENGTH} characters")

            # Advisory only: a mismatch is logged but the requested language is still used.
            if len(text) > 15 and not no_lang_auto_detect:
                language_predicted = langid.classify(text)[0].strip()
                if language_predicted == "zh":
                    # Map langid's "zh" onto the "zh-cn" code used for the comparison below.
                    language_predicted = "zh-cn"
                if language_predicted != language:
                    logger.warning(f"Detected language: {language_predicted}, chosen: {language}")

            processed_speaker_wav = self._cleanup_audio(speaker_wav_path) if voice_cleanup else speaker_wav_path

            # Failures here are almost always a bad/unsupported reference file -> client error.
            try:
                gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
                    audio_path=processed_speaker_wav,
                    gpt_cond_len=30,
                    gpt_cond_chunk_len=4,
                    max_ref_length=60,
                )
            except Exception as e:
                logger.error(f"Speaker encoding error: {e}")
                raise HTTPException(status_code=400, detail="Error processing reference audio. Please check the audio file.") from e

            logger.info("Generating speech...")
            start_time = time.perf_counter()
            out = self.model.inference(
                text,
                language,
                gpt_cond_latent,
                speaker_embedding,
                repetition_penalty=5.0,
                temperature=0.75,
            )
            logger.info(f"Speech generation completed in {time.perf_counter() - start_time:.2f} seconds")

            output_path = os.path.join(tempfile.gettempdir(), f"xtts_c3po_output_{uuid.uuid4().hex}.wav")
            # unsqueeze(0) adds the channel dimension torchaudio.save expects (channels, samples).
            torchaudio.save(output_path, torch.tensor(out["wav"]).unsqueeze(0), self.OUTPUT_SAMPLE_RATE)
            return output_path
        except HTTPException:
            # Already carries the intended status code and message.
            raise
        except Exception as e:
            logger.exception(f"Error generating speech: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}") from e

    def _cleanup_audio(self, audio_path: str) -> str:
        """Return the path of the reference clip to use after cleanup.

        Cleanup is not implemented yet, so *audio_path* is returned unchanged.
        """
        # TODO: apply real filtering (denoise, trim silence, ...) and return the
        # path of the processed file instead.
        return audio_path
# Initialize XTTS service.
# NOTE: this runs at import time. XTTSService() loads the model (git-cloning it
# first when XTTS-v2_C3PO/config.json is missing), so importing this module is
# slow, and any loading failure propagates out of the import.
logger.info("Initializing XTTS C3PO service...")
tts_service = XTTSService()
async def root():
    """Liveness probe: confirm the API process is up and name the loaded model."""
    # NOTE(review): no route decorator is visible for this handler in this copy of the file.
    return dict(message="XTTS C3PO API is running", status="healthy", model="C3PO")
async def health_check():
    """Report service health together with the device, model and languages in use."""
    # NOTE(review): no route decorator is visible for this handler in this copy of the file.
    service = tts_service
    has_default_voice = bool(service.default_reference)
    return dict(
        status="healthy",
        device=service.device,
        model="XTTS-v2 C3PO",
        supported_languages=service.supported_languages,
        default_voice="C3PO" if has_default_voice else "None",
    )
async def get_languages():
    """Return the language codes the loaded XTTS model supports."""
    # NOTE(review): no route decorator is visible for this handler in this copy of the file.
    languages = tts_service.supported_languages
    return dict(languages=languages)
async def text_to_speech(
    text: str = Form(...),
    language: str = Form("en"),
    voice_cleanup: bool = Form(False),
    no_lang_auto_detect: bool = Form(False),
    speaker_file: UploadFile = File(None)
):
    """
    Convert text to speech using XTTS C3PO voice cloning
    - **text**: The text to convert to speech (max 500 characters)
    - **language**: Language code (default: "en")
    - **voice_cleanup**: Apply audio cleanup to reference voice
    - **no_lang_auto_detect**: Disable automatic language detection
    - **speaker_file**: Reference speaker audio file (optional, uses C3PO voice if not provided)

    The generated WAV is returned as an attachment and deleted from disk once
    the response has been sent.
    """
    # NOTE(review): no route decorator is visible for this handler in this copy of the file.
    # Local import: only needed to delete the generated file after it is streamed.
    from fastapi import BackgroundTasks

    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    speaker_temp_path = None
    try:
        if speaker_file is not None:
            # content_type can be missing on a multipart part; treat that as "not audio"
            # rather than crashing with AttributeError (which surfaced as a 500).
            if not (speaker_file.content_type or "").startswith('audio/'):
                raise HTTPException(status_code=400, detail="Speaker file must be an audio file")
            # Save uploaded speaker file temporarily
            speaker_temp_path = os.path.join(tempfile.gettempdir(), f"speaker_{uuid.uuid4().hex}.wav")
            content = await speaker_file.read()
            with open(speaker_temp_path, "wb") as buffer:
                buffer.write(content)

        # Generate speech (will use C3PO voice if no speaker file provided)
        output_path = tts_service.generate_speech(
            text,
            speaker_temp_path,
            language,
            voice_cleanup,
            no_lang_auto_detect
        )
    except Exception as e:
        logger.error(f"Error in TTS endpoint: {e}")
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The uploaded reference is only needed during generation; remove it on every path.
        if speaker_temp_path is not None:
            try:
                os.remove(speaker_temp_path)
            except FileNotFoundError:
                pass  # never written (e.g. the upload read failed)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary speaker file {speaker_temp_path}: {cleanup_error}")

    # Delete the generated file after the response body has been sent; otherwise
    # every request leaves a WAV behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    voice_type = "custom" if speaker_file is not None else "c3po"
    # No explicit Content-Disposition header: FileResponse builds
    # 'attachment; filename="..."' from `filename`, and an explicit header would
    # take precedence and drop the filename.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"xtts_{voice_type}_output_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )
async def text_to_speech_json(
    request: TTSRequest,
    speaker_file: UploadFile = File(None)
):
    """
    Convert text to speech using JSON request body
    - **request**: TTSRequest containing text, language, and options
    - **speaker_file**: Reference speaker audio file (optional, uses C3PO voice if not provided)

    The generated WAV is returned as an attachment and deleted from disk once
    the response has been sent.
    """
    # NOTE(review): FastAPI documents that File/Form fields cannot be combined with a
    # JSON Body in one request (the body must then be multipart/form-data), so
    # `request` will probably not parse as JSON while `speaker_file` is declared —
    # confirm how clients call this endpoint.
    # NOTE(review): no route decorator is visible for this handler in this copy of the file.
    # Local import: only needed to delete the generated file after it is streamed.
    from fastapi import BackgroundTasks

    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    speaker_temp_path = None
    try:
        if speaker_file is not None:
            # content_type can be missing on a multipart part; treat that as "not audio"
            # rather than crashing with AttributeError (which surfaced as a 500).
            if not (speaker_file.content_type or "").startswith('audio/'):
                raise HTTPException(status_code=400, detail="Speaker file must be an audio file")
            # Save uploaded speaker file temporarily
            speaker_temp_path = os.path.join(tempfile.gettempdir(), f"speaker_{uuid.uuid4().hex}.wav")
            content = await speaker_file.read()
            with open(speaker_temp_path, "wb") as buffer:
                buffer.write(content)

        output_path = tts_service.generate_speech(
            request.text,
            speaker_temp_path,
            request.language,
            request.voice_cleanup,
            request.no_lang_auto_detect
        )
    except Exception as e:
        logger.error(f"Error in TTS JSON endpoint: {e}")
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The uploaded reference is only needed during generation; remove it on every path.
        if speaker_temp_path is not None:
            try:
                os.remove(speaker_temp_path)
            except FileNotFoundError:
                pass  # never written (e.g. the upload read failed)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary speaker file {speaker_temp_path}: {cleanup_error}")

    # Delete the generated file after the response body has been sent; otherwise
    # every request leaves a WAV behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    voice_type = "custom" if speaker_file is not None else "c3po"
    # No explicit Content-Disposition header: FileResponse builds
    # 'attachment; filename="..."' from `filename`, and an explicit header would
    # take precedence and drop the filename.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"xtts_{voice_type}_{request.language}_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )
async def text_to_speech_c3po_only(
    text: str = Form(...),
    language: str = Form("en"),
    no_lang_auto_detect: bool = Form(False)
):
    """
    Convert text to speech using C3PO voice only (no file upload needed)
    - **text**: The text to convert to speech (max 500 characters)
    - **language**: Language code (default: "en")
    - **no_lang_auto_detect**: Disable automatic language detection

    The generated WAV is returned as an attachment and deleted from disk once
    the response has been sent.
    """
    # NOTE(review): no route decorator is visible for this handler in this copy of the file.
    # Local import: only needed to delete the generated file after it is streamed.
    from fastapi import BackgroundTasks

    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    try:
        output_path = tts_service.generate_speech(
            text,
            None,  # Use default C3PO voice
            language,
            False,  # No voice cleanup needed for default voice
            no_lang_auto_detect
        )
    except Exception as e:
        logger.error(f"Error in C3PO TTS endpoint: {e}")
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Delete the generated file after the response body has been sent; otherwise
    # every request leaves a WAV behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # No explicit Content-Disposition header: FileResponse builds
    # 'attachment; filename="..."' from `filename`, and an explicit header would
    # take precedence and drop the filename.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"c3po_voice_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )