tts-api / test.py
Avinyaa
test
6a83fff
raw
history blame
4.71 kB
import os
import subprocess

# Configure CPU-only execution *before* importing torch / TTS. These variables
# are consumed by libraries while they load or first initialize (CUDA device
# discovery, Numba's JIT configuration), so setting them after the imports is
# fragile.
os.environ['COQUI_TOS_AGREED'] = '1'  # presumably pre-accepts Coqui's model licence prompt
os.environ['NUMBA_DISABLE_JIT'] = '1'
os.environ['FORCE_CPU'] = 'true'  # not read in this script; presumably consumed by other project code
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # hide every GPU so torch runs on CPU

import torch
import torch.serialization
import torchaudio

# Fix PyTorch weights_only issue for XTTS: allow-list XttsConfig so that
# torch.load can unpickle it when weights-only loading is enforced.
from TTS.tts.configs.xtts_config import XttsConfig

torch.serialization.add_safe_globals([XttsConfig])

# Imported after the safe-globals registration, matching the original order.
# TTS and get_user_data_dir are not used in the visible script; kept in case
# other code relies on them.
from TTS.api import TTS
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir
import sys

print("Testing XTTS C3PO voice cloning...")

# C3PO model path: a local clone of the fine-tuned XTTS-v2 checkpoint repository.
model_path = "XTTS-v2_C3PO/"
config_path = "XTTS-v2_C3PO/config.json"
MODEL_REPO_URL = "https://huggingface.co/Borcherding/XTTS-v2_C3PO"

# Check if model files exist; if not, clone the repository.
# NOTE(review): Hugging Face repos usually store weights with Git LFS; without
# git-lfs installed the clone may contain pointer files instead of the real
# model.pth — verify on the target machine.
if not os.path.exists(config_path):
    print("C3PO model not found locally, downloading...")
    try:
        subprocess.run(["git", "clone", MODEL_REPO_URL, "XTTS-v2_C3PO"], check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # CalledProcessError: git exited non-zero (e.g. network failure);
        # FileNotFoundError: the git executable itself is not installed.
        print(f"Failed to download C3PO model: {e}", file=sys.stderr)
        # sys.exit instead of the site-module `exit`, which is meant for the REPL only.
        sys.exit(1)
    # A clone that succeeds but lacks config.json would otherwise fail later
    # with a much less obvious error inside config loading.
    if not os.path.exists(config_path):
        print(f"Download finished but {config_path} is missing", file=sys.stderr)
        sys.exit(1)
    print("C3PO model downloaded successfully")
# Build the XTTS model from the C3PO config and load its fine-tuned weights.
# The device string is only reported in the log line below; CPU execution
# relies on CUDA_VISIBLE_DEVICES being set to '' earlier in the script.
device = "cpu"  # Force CPU usage
checkpoint_file = os.path.join(model_path, "model.pth")
vocab_file = os.path.join(model_path, "vocab.json")

config = XttsConfig()
config.load_json(config_path)

model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_path=checkpoint_file,
    vocab_path=vocab_file,
    eval=True,
)
print(f"C3PO model loaded on {device} (forced CPU mode)")
# Text to convert to speech (main English test sentence).
text = "Hello there! I am C-3PO, human-cyborg relations. How may I assist you today?"

# Audio formats accepted as a voice-cloning reference, matched case-insensitively.
REFERENCE_AUDIO_EXTENSIONS = ('.wav', '.mp3', '.m4a')

# Look for reference audio in the C3PO model directory. os.listdir order is
# arbitrary, so sort to pick the same file on every run when several exist.
reference_audio_path = None
for file in sorted(os.listdir(model_path)):
    if file.lower().endswith(REFERENCE_AUDIO_EXTENSIONS):
        reference_audio_path = os.path.join(model_path, file)
        print(f"Found C3PO reference audio: {file}")
        break

# If no reference audio is found, create a placeholder reference.
# NOTE: a pure sine tone carries no voice characteristics; it only lets the
# pipeline run end to end.
if reference_audio_path is None:
    print("No reference audio found in C3PO model, creating test reference...")
    reference_audio_path = "test_reference.wav"
    import numpy as np

    sample_rate = 24000
    duration = 3  # seconds
    frequency = 440  # Hz
    # Exactly sample_rate * duration samples spaced 1/sample_rate apart;
    # np.linspace's default endpoint=True would space them duration/(N-1) apart.
    t = np.arange(int(sample_rate * duration)) / sample_rate
    # Cast to float32: NumPy math yields float64, which is not one of
    # torchaudio.save's documented input dtypes.
    audio_data = (0.3 * np.sin(2 * np.pi * frequency * t)).astype(np.float32)
    # unsqueeze(0) adds the channel axis -> (channels=1, frames) for torchaudio.save.
    torchaudio.save(reference_audio_path, torch.from_numpy(audio_data).unsqueeze(0), sample_rate)
    print(f"Test reference audio created: {reference_audio_path}")
import sys
import traceback

# Sample rate of XTTS output audio, taken from the loaded config; falls back to
# the 24000 Hz value this script previously hard-coded if the field is absent.
output_sample_rate = getattr(getattr(config, "audio", None), "output_sample_rate", 24000)

multilingual_tests = [
    ("es", "Hola, soy C-3PO. Domino más de seis millones de formas de comunicación."),
    ("fr", "Bonjour, je suis C-3PO. Je maîtrise plus de six millions de formes de communication."),
    ("de", "Hallo, ich bin C-3PO. Ich beherrsche über sechs Millionen Kommunikationsformen."),
]


def _synthesize_to_wav(sentence, language, wav_path, cond_latent, speaker_emb):
    """Synthesize *sentence* in *language* with the cloned voice and save it as a WAV at *wav_path*."""
    out = model.inference(
        sentence,
        language,
        cond_latent,
        speaker_emb,
        repetition_penalty=5.0,
        temperature=0.75,
    )
    # as_tensor accepts either an array or a tensor without a needless copy;
    # unsqueeze(0) adds the channel axis torchaudio.save expects.
    torchaudio.save(wav_path, torch.as_tensor(out["wav"]).unsqueeze(0), output_sample_rate)


generation_failed = False
try:
    # Generate conditioning latents from the reference clip.
    print("Processing reference audio...")
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=reference_audio_path,
        gpt_cond_len=30,
        gpt_cond_chunk_len=4,
        max_ref_length=60,
    )

    print("Generating C3PO speech...")
    output_path = "c3po_test_output.wav"
    _synthesize_to_wav(text, "en", output_path, gpt_cond_latent, speaker_embedding)
    print(f"C3PO speech generated successfully! Saved as: {output_path}")

    # Test multilingual capabilities with the same cloned voice.
    print("\nTesting multilingual C3PO...")
    for lang, test_text in multilingual_tests:
        print(f"Generating {lang.upper()} speech...")
        output_path = f"c3po_test_{lang}.wav"
        _synthesize_to_wav(test_text, lang, output_path, gpt_cond_latent, speaker_embedding)
        print(f"C3PO {lang.upper()} speech saved as: {output_path}")
except Exception as e:
    # Top-level boundary of the test: report everything, but remember the failure
    # so the process exit status reflects it.
    generation_failed = True
    print(f"Error during speech generation: {e}", file=sys.stderr)
    traceback.print_exc()

if generation_failed:
    print("XTTS C3PO test completed with errors!")
else:
    print("XTTS C3PO test completed!")

# NOTE: this lists every matching file in the working directory, which can
# include outputs left over from earlier runs.
print("\nGenerated files:")
for file in sorted(os.listdir(".")):
    if file.startswith("c3po_test") and file.endswith(".wav"):
        print(f" - {file}")

# Non-zero exit status so shells/CI notice a failed generation.
if generation_failed:
    sys.exit(1)