tts-api / test_kokoro_install.py
Avinyaa
u
1b567fa
raw
history blame
2.48 kB
#!/usr/bin/env python3
"""
Simple test script to verify Kokoro TTS installation and functionality.
"""
import os
# Turn off Numba's JIT at import time, i.e. before any check below imports
# kokoro. NOTE(review): the motivation is not stated in this file; presumably
# it sidesteps JIT-compilation problems in a Numba-using dependency — confirm.
os.environ.update(NUMBA_DISABLE_JIT='1')
def test_kokoro_import():
    """Check that Kokoro and the packages this script relies on import cleanly.

    Imports ``kokoro.KPipeline``, ``soundfile`` and ``torch``.

    Returns:
        True if every import succeeds; False, after printing the error,
        if any of them raises.
    """
    try:
        # The names are intentionally unused: importing them is the test.
        from kokoro import KPipeline  # noqa: F401
        import soundfile as sf  # noqa: F401
        import torch  # noqa: F401
        print("βœ… All required packages imported successfully!")
        return True
    except Exception as e:
        # Not only ImportError: a broken install can fail at import time with
        # other exceptions (e.g. soundfile raises OSError when its native
        # libsndfile is missing). Report those as a failed check instead of
        # letting them crash the runner before the summary is printed.
        print(f"❌ Import error: {e}")
        return False
def test_kokoro_pipeline():
    """Try to construct a Kokoro ``KPipeline`` with ``lang_code='a'``.

    Returns:
        True when the pipeline is created; False, after printing the
        exception, when importing or constructing it raises.
    """
    try:
        from kokoro import KPipeline
        _pipeline = KPipeline(lang_code='a')  # built only to prove it works
        print("βœ… Kokoro pipeline initialized successfully!")
    except Exception as exc:
        print(f"❌ Pipeline initialization error: {exc}")
        return False
    else:
        return True
def test_kokoro_generation():
    """Synthesize a short sentence with Kokoro and save the first segment.

    Builds ``KPipeline(lang_code='a')``, runs it with voice ``'af_heart'``
    and writes the first yielded audio segment to ``test_kokoro.wav`` with
    a sample rate of 24000.

    Returns:
        True if at least one segment was generated and written; False,
        after printing the reason, if the pipeline yields nothing or any
        step raises.
    """
    try:
        from kokoro import KPipeline
        import soundfile as sf

        pipeline = KPipeline(lang_code='a')
        text = "Hello, this is a test of Kokoro TTS."
        generator = pipeline(text, voice='af_heart')

        # Only the first segment is needed to show that synthesis works.
        first_segment = next(iter(generator), None)
        if first_segment is None:
            # An empty generator means nothing was synthesized; treat that as
            # a failure rather than silently reporting a pass.
            print("❌ Speech generation error: no audio segments were generated")
            return False

        gs, ps, audio = first_segment
        print(f"βœ… Generated audio segment 0: gs={gs}, ps={ps}")
        # Save test audio
        sf.write('test_kokoro.wav', audio, 24000)
        print("βœ… Test audio saved as 'test_kokoro.wav'")
        return True
    except Exception as e:
        print(f"❌ Speech generation error: {e}")
        return False
def main():
    """Run every Kokoro installation check and print a summary.

    Each check is called in turn and counts as passed when it returns a
    truthy value. An exception escaping a check is reported and counted as
    a failure, so the remaining checks and the summary still run.

    Returns:
        0 if all checks passed, 1 otherwise (usable as a process exit status).
    """
    print("🎀 Testing Kokoro TTS Installation")
    print("=" * 40)

    tests = [
        ("Import Test", test_kokoro_import),
        ("Pipeline Test", test_kokoro_pipeline),
        ("Generation Test", test_kokoro_generation),
    ]

    passed = 0
    total = len(tests)
    for test_name, test_func in tests:
        print(f"\nπŸ” Running {test_name}...")
        try:
            ok = test_func()
        except Exception as e:
            # Top-level runner boundary: report the error and keep going.
            print(f"❌ {test_name} raised: {e}")
            ok = False
        if ok:
            passed += 1
        else:
            print(f"❌ {test_name} failed!")

    print(f"\nπŸ“Š Results: {passed}/{total} tests passed")
    if passed == total:
        print("πŸŽ‰ All tests passed! Kokoro TTS is ready to use.")
        return 0
    print("⚠️ Some tests failed. Please check the installation.")
    return 1


if __name__ == "__main__":
    # Propagate the result so shells and CI can detect a broken installation.
    raise SystemExit(main())