diff --git a/.env.template b/.env.template new file mode 100644 index 0000000000000000000000000000000000000000..286a87fc53c79069419c068e05f18a7b1de36c15 --- /dev/null +++ b/.env.template @@ -0,0 +1,41 @@ +HUGGINGFACE_TOKEN=TOKEN_GOES_HERE +VOICE_NAME=af_nicole +SPEED=1.2 + +# LLM settings +LM_STUDIO_URL=http://localhost:1234/v1 +OLLAMA_URL=http://localhost:11434/api/chat +DEFAULT_SYSTEM_PROMPT=You are a friendly, helpful, and intelligent assistant. Begin your responses with phrases like 'Umm,' 'So,' or similar. Focus on the user query and reply directly to the user in the first person ('I'), responding promptly and naturally. Do not include any additional information or context in your responses. +MAX_TOKENS=512 +NUM_THREADS=2 +LLM_TEMPERATURE=0.9 +LLM_STREAM=true +LLM_RETRY_DELAY=0.5 +MAX_RETRIES=3 + +# Model names +VAD_MODEL=pyannote/segmentation-3.0 +WHISPER_MODEL=openai/whisper-tiny.en +LLM_MODEL=qwen2.5:0.5b-instruct-q8_0 +TTS_MODEL=kokoro.pth + +# VAD settings +VAD_MIN_DURATION_ON=0.1 +VAD_MIN_DURATION_OFF=0.1 + +# Audio settings +CHUNK=256 +FORMAT=pyaudio.paFloat32 +CHANNELS=1 +RATE=16000 +OUTPUT_SAMPLE_RATE=24000 +RECORD_DURATION=5 +SILENCE_THRESHOLD=0.01 +INTERRUPTION_THRESHOLD=0.01 +MAX_SILENCE_DURATION=1 +SPEECH_CHECK_TIMEOUT=0.1 +SPEECH_CHECK_THRESHOLD=0.02 +ROLLING_BUFFER_TIME=0.5 +TARGET_SIZE=25 +PLAYBACK_DELAY=0.001 +FIRST_SENTENCE_SIZE=2 \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ac3e57d2cb5070e60542d6223675a5435a03d787 --- /dev/null +++ b/.gitignore @@ -0,0 +1,166 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +output/ +test/ +test.py +data/logs/ +examples/ +generated_audio/ +# C extensions +*.so +.vscode/ + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..bb2740cbd36b782a7d8494ab69d29c24e0893c2d --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Abdullah Al Asif + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 7be5fc7f47d5db027d120b8024982df93db95b74..919b107390bafd869877df08b9bbf2a56b95ce49 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,101 @@ ---- -license: mit ---- +# On Device Speech to Speech Conversational AI +![Kokoro-Conversational](assets/system_architecture.svg) + +This is a realtime, on-device speech-to-speech conversational AI system. It uses a series of tools to achieve that: a combination of voice activity detection, speech recognition, a language model, and text-to-speech synthesis, creating a seamless and responsive conversational experience. The system is designed to run on-device, ensuring low latency and minimal data usage. + +

## HOW TO RUN IT

+ +1. **Prerequisites:** + - Install Python 3.8+ (tested with 3.12) + - Install [eSpeak NG](https://github.com/espeak-ng/espeak-ng/releases/tag/1.52.0) (required for voice synthesis) + - Install Ollama from https://ollama.ai/ + +2. **Setup:** + - Clone the repository: `git clone https://github.com/asiff00/On-Device-Speech-to-Speech-Conversational-AI.git` + - Run `git lfs pull` to download the models and voices + - Copy `.env.template` to `.env` + - Add your HuggingFace token to `.env` + - Tune other parameters there if needed [Optional] + - Install requirements: `pip install -r requirements.txt` + - Install any packages that are still missing: `pip install <package-name>` + + +3. **Run Ollama:** + - Start the Ollama service + - Run: `ollama run qwen2.5:0.5b-instruct-q8_0` or any other model of your choice + +4. **Start Application:** + - Run: `python speech_to_speech.py` + - Wait for initialization (models loading) + - Start talking when you see "Voice Chat Bot Ready" + - Press and hold `Ctrl+C` to stop the application + + + +We put a few models together in a multi-threaded architecture, where each component operates independently but is integrated through a queue management system to keep the pipeline performant and responsive. + +## The flow works as follows: Loop (VAD -> Whisper -> LLM -> TextChunker -> TTS) +To achieve this we use: +- **Voice Activity Detection**: Pyannote:pyannote/segmentation-3.0 +- **Speech Recognition**: Whisper:whisper-tiny.en (OpenAI) +- **Language Model**: LM Studio/Ollama with qwen2.5:0.5b-instruct-q8_0 +- **Voice Synthesis**: Kokoro:hexgrad/Kokoro-82M (Version 0.19, 16bit) + +We use custom text processing and queues to manage data, with separate queues for text and audio. This setup lets the system handle heavy tasks without slowing down playback. We also use an interrupt mechanism that allows the user to cut the AI off at any time, which makes the conversation feel more natural and responsive than a generic TTS pipeline. + +## Demo Video: +A demo video is available. Click the thumbnail or follow the YouTube link: [https://youtu.be/x92FLnwf-nA](https://youtu.be/x92FLnwf-nA). + +[![On Device Speech to Speech AI Demo](https://img.youtube.com/vi/x92FLnwf-nA/0.jpg)](https://youtu.be/x92FLnwf-nA) + +## Performance: +![Timing Chart](assets/timing_chart.png) + +I ran this test on an AMD Ryzen 5600G with 16 GB RAM, an SSD, and no GPU, achieving a consistent ~2 s end-to-end latency. On average, the system responds to a user query about 1.5 s after the user says the last word. I haven't tested on a GPU, but I expect one would significantly improve performance and responsiveness. + +## How do we reduce latency? +### Priority based text chunking +We capitalize on the streaming output of the language model to reduce latency. Instead of waiting for the entire response to be generated, we process each chunk of text as soon as it becomes available, form phrases, and send them to the TTS engine queue. The audio is played as soon as it is ready, so the user gets a very fast first response while the rest of the reply is still being generated. + +Our custom `TextChunker` analyzes incoming text streams from the language model and splits them into chunks suitable for the voice synthesizer.
It uses a combination of sentence breaks (like periods, question marks, and exclamation points) and semantic breaks (like "and", "but", and "however") to determine the best places to split the text, ensuring natural-sounding speech output. + +The `TextChunker` maintains a set of break points: +- **Sentence breaks**: `.`, `!`, `?` (highest priority) +- **Semantic breaks** with priority levels: + - Level 4: `however`, `therefore`, `furthermore`, `moreover`, `nevertheless` + - Level 3: `while`, `although`, `unless`, `since` + - Level 2: `and`, `but`, `because`, `then` +- **Punctuation breaks**: `;` (4), `:` (4), `,` (3), `-` (2) + +When processing text, the `TextChunker` uses a priority-based system: +1. Looks for sentence-ending punctuation first (highest priority, 5) +2. Checks for semantic break words with their associated priority levels +3. Falls back to punctuation marks with lower priorities +4. Splits at the target word count if no natural breaks are found + +This chunking significantly reduces perceived latency because the first chunk of text is processed and delivered as soon as it is available. If the model produces a response of N words at a rate of R words per second, waiting for the complete response introduces a delay of N/R seconds before any audio is produced. With chunking, the system can start synthesizing the first M words as soon as they are ready, after only M/R seconds, while the remaining words continue to be generated. The user hears the start of the response in M/R seconds and the rest streams in naturally. + +### Leading filler word LLM prompting +We use another little trick in the LLM prompt to speed up the system's first response. We ask the LLM to start its reply with filler words like "umm," "so," or "well." These words have a special role in language: they create natural pauses and breaks. Since they are single words, they take only milliseconds to convert to audio. When our chunking rules are applied, the response is split at the filler word (e.g., "umm,") and that tiny chunk is sent to the TTS engine, so the bot can play "umm" almost instantly. The filler words act as natural bridges that mask processing delays: even a short "umm" gives the impression of a fluid conversation while the system generates the rest of the response in the background. Longer chunks after the filler word may take more time to process, but the initial pause feels intentional and human-like. + +We also have a fallback for cases where the LLM does not start its response with a filler: we force hard breaks at 2 to 5 words. This costs a bit of choppiness at the beginning, but it is less painful than waiting a long time for the first response. + +**In practice,** this approach can reduce perceived latency by up to 50-70%, depending on the length of the response and the speed of the language model. For example, in a typical conversation where responses average 15-20 words, these techniques bring the initial response time down from 1.5-2 seconds to about 0.5-0.7 seconds, making the interaction feel much more natural and immediate.
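To make the priority scheme concrete, here is a minimal, self-contained sketch of the splitting logic described above. It is illustrative only: the break-point tables mirror the lists in this README, but the names (`pick_break_point`, `next_chunk`, `TARGET_WORDS`) are not the repo's actual `TextChunker` API.

```python
from typing import Optional, Tuple

# Break-point tables mirroring the priorities listed above (sketch, not the repo's code).
SENTENCE_BREAKS = {".": 5, "!": 5, "?": 5}        # highest priority
PUNCT_BREAKS = {";": 4, ":": 4, ",": 3, "-": 2}
SEMANTIC_BREAKS = {w: 4 for w in ("however", "therefore", "furthermore", "moreover", "nevertheless")}
SEMANTIC_BREAKS.update({w: 3 for w in ("while", "although", "unless", "since")})
SEMANTIC_BREAKS.update({w: 2 for w in ("and", "but", "because", "then")})
TARGET_WORDS = 25  # hard-split fallback, mirroring TARGET_SIZE in .env.template


def pick_break_point(text: str) -> Optional[int]:
    """Return the index just after the best break point, or None if there is none."""
    best_pos, best_priority = None, 0
    for i, ch in enumerate(text):
        priority = SENTENCE_BREAKS.get(ch, 0) or PUNCT_BREAKS.get(ch, 0)
        if priority > best_priority:
            best_pos, best_priority = i + 1, priority
    for word, priority in SEMANTIC_BREAKS.items():
        idx = text.lower().find(f" {word} ")
        if idx != -1 and priority > best_priority:
            best_pos, best_priority = idx + 1, priority  # split just before the connective
    return best_pos


def next_chunk(buffer: str) -> Optional[Tuple[str, str]]:
    """Split off one speakable chunk, or return None to wait for more streamed text."""
    pos = pick_break_point(buffer)
    if pos is not None:
        return buffer[:pos].strip(), buffer[pos:]
    words = buffer.split()
    if len(words) >= TARGET_WORDS:  # no natural break found: hard split at the target size
        return " ".join(words[:TARGET_WORDS]), " ".join(words[TARGET_WORDS:])
    return None
```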
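The filler-word trick composes directly with that sketch. Below, a hand-written list of tokens stands in for a streamed LLM reply that obeys the system prompt from `.env.template` ("Begin your responses with phrases like 'Umm,' 'So,' ..."); the first comma lets the chunker emit a tiny first chunk for immediate synthesis. The streamed content is made up for illustration.

```python
# Simulated LLM stream (hypothetical content); in the real loop these arrive from Ollama.
stream = ["Umm,", " sure!", " The Ryzen", " 5600G has", " six cores", " and twelve threads."]

buffer, spoken = "", []
for token in stream:
    buffer += token
    result = next_chunk(buffer)      # from the sketch above
    if result:
        chunk, buffer = result
        spoken.append(chunk)         # in the real system: push the chunk to the TTS queue
if buffer.strip():
    spoken.append(buffer.strip())    # flush whatever is left when the stream ends

print(spoken)
# ['Umm,', 'sure!', 'The Ryzen 5600G has six cores and twelve threads.']
```

The leading "Umm," becomes a complete chunk after the very first token, which is what lets playback begin while the rest of the reply is still streaming.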
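Plugging hypothetical numbers into the N/R argument above shows where the savings come from. The figures below are illustrative, not measurements; the 2-word first chunk mirrors `FIRST_SENTENCE_SIZE` in `.env.template`. The gap to the 50-70% quoted above is explained by the VAD, transcription, and TTS time this toy calculation ignores.

```python
# Hypothetical numbers, not measurements from this repo.
N, R, M = 20, 10.0, 2           # words in the reply, LLM words/sec, words in the first chunk
full_wait = N / R               # wait if we synthesize only after the full reply: 2.0 s
first_chunk_wait = M / R        # wait until the first speakable chunk exists: 0.2 s
print(f"full response: {full_wait:.1f}s, first chunk: {first_chunk_wait:.1f}s "
      f"({1 - first_chunk_wait / full_wait:.0%} less waiting before audio can start)")
# full response: 2.0s, first chunk: 0.2s (90% less waiting before audio can start)
```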
+ + + +## Resources +This project utilizes the following resources: +* **Text-to-Speech Model:** [Kokoro](https://huggingface.co/hexgrad/Kokoro-82M) +* **Speech-to-Text Model:** [Whisper](https://huggingface.co/openai/whisper-tiny.en) +* **Voice Activity Detection Model:** [Pyannote](https://huggingface.co/pyannote/segmentation-3.0) +* **Large Language Model Server:** [Ollama](https://ollama.ai/) +* **Fallback Text-to-Speech Engine:** [eSpeak NG](https://github.com/espeak-ng/espeak-ng/releases/tag/1.52.0) + +## Acknowledgements +This project draws inspiration and guidance from the following articles and repositories, among others: +* [Realtime speech to speech conversation with MiniCPM-o](https://github.com/OpenBMB/MiniCPM-o) +* [A Comparative Guide to OpenAI and Ollama APIs](https://medium.com/@zakkyang/a-comparative-guide-to-openai-and-ollama-apis-with-cheathsheet-5aae6e515953) +* [Building Production-Ready TTS with Kokoro-82M](https://medium.com/@simeon.emanuilov/kokoro-82m-building-production-ready-tts-with-82m-parameters-unfoldai-98e36ff286b9) +* [Kokoro-82M: The Best TTS Model in Just 82 Million Parameters](https://medium.com/data-science-in-your-pocket/kokoro-82m-the-best-tts-model-in-just-82-million-parameters-512b4ba4f94c) +* [StyleTTS2 Model Implementation](https://github.com/yl4579/StyleTTS2/blob/main/models.py) diff --git a/assets/system_architecture.svg b/assets/system_architecture.svg new file mode 100644 index 0000000000000000000000000000000000000000..ba82ab9a49320935680d2deeb7754696383e886f --- /dev/null +++ b/assets/system_architecture.svg @@ -0,0 +1,4 @@ + + + +
[assets/system_architecture.svg — diagram text: "System Architecture"; 🎤 User Voice → VAD → ASR → LLM → Chunker → Text Queue → TTS → Audio Queue → 🔊 AI Voice, with an Interrupt path from the user back into the loop]
\ No newline at end of file diff --git a/assets/timing_chart.png b/assets/timing_chart.png new file mode 100644 index 0000000000000000000000000000000000000000..0ef7d8ab8d7af3915676754c0d019b3fd2018671 Binary files /dev/null and b/assets/timing_chart.png differ diff --git a/assets/video_demo.mov b/assets/video_demo.mov new file mode 100644 index 0000000000000000000000000000000000000000..3c05f97aa18d5bbdedbade19be4892ca45b7a614 --- /dev/null +++ b/assets/video_demo.mov @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4aa16650f035a094e65d759ac07e9050ccf22204f77816776b957cea203caf9c +size 11758861 diff --git a/data/models/kokoro.pth b/data/models/kokoro.pth new file mode 100644 index 0000000000000000000000000000000000000000..ca57d2b70255aaeaec33f4cc264b0ff395ef5f56 --- /dev/null +++ b/data/models/kokoro.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70cbf37f84610967f2ca72dadb95456fdd8b6c72cdd6dc7372c50f525889ff0c +size 163731194 diff --git a/data/voices/af.pt b/data/voices/af.pt new file mode 100644 index 0000000000000000000000000000000000000000..2e4bdbbe54f437d17c668b5f64faa16746759b88 --- /dev/null +++ b/data/voices/af.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fad4192fd8a840f925b0e3fc2be54e20531f91a9ac816a485b7992ca0bd83ebf +size 524355 diff --git a/data/voices/af_alloy.pt b/data/voices/af_alloy.pt new file mode 100644 index 0000000000000000000000000000000000000000..c0f72034c4e1369e4d17990442ffe35aeceb21f3 --- /dev/null +++ b/data/voices/af_alloy.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d877149dd8b348fbad12e5845b7e43d975390e9f3b68a811d1d86168bef5aa3 +size 523425 diff --git a/data/voices/af_aoede.pt b/data/voices/af_aoede.pt new file mode 100644 index 0000000000000000000000000000000000000000..047c8e54f2fac7b9175dd5a2f85ac45f1813a4e3 --- /dev/null +++ b/data/voices/af_aoede.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c03bd1a4c3716c2d8eaa3d50022f62d5c31cfbd6e15933a00b17fefe13841cc4 +size 523425 diff --git a/data/voices/af_bella.pt b/data/voices/af_bella.pt new file mode 100644 index 0000000000000000000000000000000000000000..0894c4dfa49b492b88026353659b9658e61c9218 --- /dev/null +++ b/data/voices/af_bella.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2828c6c2f94275ef3441a2edfcf48293298ee0f9b56ce70fb2e344345487b922 +size 524449 diff --git a/data/voices/af_bella_nicole.pt b/data/voices/af_bella_nicole.pt new file mode 100644 index 0000000000000000000000000000000000000000..11eaf0bd263ede34bfcea10a37a4172b70787bbb --- /dev/null +++ b/data/voices/af_bella_nicole.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d41525cea0e607c8c775adad8a81faa015d5ddafcbc66d9454c5c6aaef12137a +size 524623 diff --git a/data/voices/af_heart.pt b/data/voices/af_heart.pt new file mode 100644 index 0000000000000000000000000000000000000000..23a296174457c31b22b694f6e07e4e1b558122bf --- /dev/null +++ b/data/voices/af_heart.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ab5709b8ffab19bfd849cd11d98f75b60af7733253ad0d67b12382a102cb4ff +size 523425 diff --git a/data/voices/af_jessica.pt b/data/voices/af_jessica.pt new file mode 100644 index 0000000000000000000000000000000000000000..9740b43fd7ce43473fb5e673a16173136876ba04 --- /dev/null +++ b/data/voices/af_jessica.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:cdfdccb8cc975aa34ee6b89642963b0064237675de0e41a30ae64cc958dd4e87 +size 523435 diff --git a/data/voices/af_kore.pt b/data/voices/af_kore.pt new file mode 100644 index 0000000000000000000000000000000000000000..e5532ee773d500ac62cb2fdcc92e8847a867e8f8 --- /dev/null +++ b/data/voices/af_kore.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8bfbc512321c3db49dff984ac675fa5ac7eaed5a96cc31104d3a9080e179d69d +size 523420 diff --git a/data/voices/af_nicole.pt b/data/voices/af_nicole.pt new file mode 100644 index 0000000000000000000000000000000000000000..e77e41c12c543ca0ae131801cfd194f45006032e --- /dev/null +++ b/data/voices/af_nicole.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9401802fb0b7080c324dec1a75d60f31d977ced600a99160e095dbc5a1172692 +size 524454 diff --git a/data/voices/af_nicole_sky.pt b/data/voices/af_nicole_sky.pt new file mode 100644 index 0000000000000000000000000000000000000000..105680a1b2bbcb92c1b2f193163c304e479b8860 --- /dev/null +++ b/data/voices/af_nicole_sky.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:587f36a3a2d9f295cd5538a923747be2fe398bbd81598896bac07bbdb7ff25b0 +size 524623 diff --git a/data/voices/af_nova.pt b/data/voices/af_nova.pt new file mode 100644 index 0000000000000000000000000000000000000000..4f781bca452e06e7acd8c5a5f80d73bf5abca5f7 --- /dev/null +++ b/data/voices/af_nova.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0233676ddc21908c37a1f102f6b88a59e4e5c1bd764983616eb9eda629dbcd2 +size 523420 diff --git a/data/voices/af_river.pt b/data/voices/af_river.pt new file mode 100644 index 0000000000000000000000000000000000000000..78a4f7b3aba66fa78e82916f73028a3f24d894b8 --- /dev/null +++ b/data/voices/af_river.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e149459bd9c084416b74756b9bd3418256a8b839088abb07d463730c369dab8f +size 523425 diff --git a/data/voices/af_sarah.pt b/data/voices/af_sarah.pt new file mode 100644 index 0000000000000000000000000000000000000000..ce3327f86256b40ab606671354391c77df21aba5 --- /dev/null +++ b/data/voices/af_sarah.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba7918c4ace6ace4221e7e01eb3a6d16596cba9729850551c758cd2ad3a4cd08 +size 524449 diff --git a/data/voices/af_sarah_nicole.pt b/data/voices/af_sarah_nicole.pt new file mode 100644 index 0000000000000000000000000000000000000000..c427a1278de0cfccc2b25144925c060888aa7589 --- /dev/null +++ b/data/voices/af_sarah_nicole.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa529793c4853a4107bb9857023a0ceb542466c664340ba0aeeb7c8570b2c51c +size 524623 diff --git a/data/voices/af_sky.pt b/data/voices/af_sky.pt new file mode 100644 index 0000000000000000000000000000000000000000..bf178060e7eaf614f93eee00534c20a0f948fbea --- /dev/null +++ b/data/voices/af_sky.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f16f1bb778de36a177ae4b0b6f1e59783d5f4d3bcecf752c3e1ee98299b335e +size 524375 diff --git a/data/voices/af_sky_adam.pt b/data/voices/af_sky_adam.pt new file mode 100644 index 0000000000000000000000000000000000000000..ae3b652eff0514d54c6fef62ddff311354814fd5 --- /dev/null +++ b/data/voices/af_sky_adam.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2fa5978fab741ccd0d2a4992e34c85a7498f61062a665257a9d9b315dca327c3 +size 524464 diff --git a/data/voices/af_sky_emma.pt b/data/voices/af_sky_emma.pt new file mode 100644 index 
0000000000000000000000000000000000000000..1e2b90b21ab79a4f07e5c8388d83da318c467fab --- /dev/null +++ b/data/voices/af_sky_emma.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cfb3af5b8a0cbdd07d76fd201b572437ba2b048c03b65f2535a1f2810d01a99f +size 524464 diff --git a/data/voices/af_sky_emma_isabella.pt b/data/voices/af_sky_emma_isabella.pt new file mode 100644 index 0000000000000000000000000000000000000000..8ad84495bdbc8762099cfc01c650ec1b7610fdb5 --- /dev/null +++ b/data/voices/af_sky_emma_isabella.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12852daf302220b828a49a1d9089def6ff2b81fdab0a9ee500c66b0f37a2052f +size 524509 diff --git a/data/voices/am_adam.pt b/data/voices/am_adam.pt new file mode 100644 index 0000000000000000000000000000000000000000..1e812869b73ec2ebb4fec9893aeac470cf5e2e5a --- /dev/null +++ b/data/voices/am_adam.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1921528b400a553f66528c27899d95780918fe33b1ac7e2a871f6a0de475f176 +size 524444 diff --git a/data/voices/am_michael.pt b/data/voices/am_michael.pt new file mode 100644 index 0000000000000000000000000000000000000000..7acd7effc21d420948789cf186dee9ffb58ec557 --- /dev/null +++ b/data/voices/am_michael.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a255c9562c363103adc56c09b7daf837139d3bdaa8bd4dd74847ab1e3e8c28be +size 524459 diff --git a/data/voices/bf_alice.pt b/data/voices/bf_alice.pt new file mode 100644 index 0000000000000000000000000000000000000000..ae3a0a1db073ef34170583f2a55e8034bfa15932 --- /dev/null +++ b/data/voices/bf_alice.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d292651b6af6c0d81705c2580dcb4463fccc0ff7b8d618a471dbb4e45655b3f3 +size 523425 diff --git a/data/voices/bf_emma.pt b/data/voices/bf_emma.pt new file mode 100644 index 0000000000000000000000000000000000000000..012bef4d3425276e0e68a44c2ee86af7635b9eaa --- /dev/null +++ b/data/voices/bf_emma.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:992e6d8491b8926ef4a16205250e51a21d9924405a5d37e2db6e94adfd965c3b +size 524365 diff --git a/data/voices/bf_isabella.pt b/data/voices/bf_isabella.pt new file mode 100644 index 0000000000000000000000000000000000000000..67f0df78f0ab2dd1d9fd295284036d122d09b691 --- /dev/null +++ b/data/voices/bf_isabella.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0865a03931230100167f7a81d394b143c072efe2d7e4c4a87b5c54d6283f580 +size 524365 diff --git a/data/voices/bm_george.pt b/data/voices/bm_george.pt new file mode 100644 index 0000000000000000000000000000000000000000..5e058933596c8a0e6d125a08c3d26db547929d0c --- /dev/null +++ b/data/voices/bm_george.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d763dfe13e934357f4d8322b718787d79e32f2181e29ca0cf6aa637d8092b96 +size 524464 diff --git a/data/voices/bm_lewis.pt b/data/voices/bm_lewis.pt new file mode 100644 index 0000000000000000000000000000000000000000..ab94dac15af9ee3a67ef63629205a3b364629f75 --- /dev/null +++ b/data/voices/bm_lewis.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f70d9ea4d65f522f224628f06d86ea74279faae23bd7e765848a374aba916b76 +size 524449 diff --git a/data/voices/ef_dora.pt b/data/voices/ef_dora.pt new file mode 100644 index 0000000000000000000000000000000000000000..44cf5fc201c1f21711c24e46cc1053323959d9b5 --- /dev/null +++ b/data/voices/ef_dora.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:d9d69b0f8a2b87a345f269d89639f89dfbd1a6c9da0c498ae36dd34afcf35530 +size 523420 diff --git a/data/voices/if_sara.pt b/data/voices/if_sara.pt new file mode 100644 index 0000000000000000000000000000000000000000..b7a52593153d6711f7d6f245b9df7d4ad7eaf15d --- /dev/null +++ b/data/voices/if_sara.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c0b253b955fe32f1a1a86006aebe83d050ea95afd0e7be15182f087deedbf55 +size 523425 diff --git a/data/voices/jf_alpha.pt b/data/voices/jf_alpha.pt new file mode 100644 index 0000000000000000000000000000000000000000..90da6c92332481e222ddc5fa67c226f9ed4a7fdf --- /dev/null +++ b/data/voices/jf_alpha.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1bf4c9dc69e45ee46183b071f4db766349aac5592acbcfeaf051018048a5d787 +size 523425 diff --git a/data/voices/jf_gongitsune.pt b/data/voices/jf_gongitsune.pt new file mode 100644 index 0000000000000000000000000000000000000000..b806c5c2688b32b805ec266e29356f2d009ec7bf --- /dev/null +++ b/data/voices/jf_gongitsune.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b171917f18f351e65f2bf9657700cd6bfec4e65589c297525b9cf3c20105770 +size 523351 diff --git a/data/voices/pf_dora.pt b/data/voices/pf_dora.pt new file mode 100644 index 0000000000000000000000000000000000000000..887042c91adf283cb1a10b01f8ed9c50731e6cbb --- /dev/null +++ b/data/voices/pf_dora.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07e4ff987c5d5a8c3995efd15cc4f0db7c4c15e881b198d8ab7f67ecf51f5eb7 +size 523425 diff --git a/data/voices/zf_xiaoxiao.pt b/data/voices/zf_xiaoxiao.pt new file mode 100644 index 0000000000000000000000000000000000000000..009434acf1f3b1f6430a404af1612ba480b8344c --- /dev/null +++ b/data/voices/zf_xiaoxiao.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cfaf6f2ded1ee56f1ff94fcd2b0e6cdf32e5b794bdc05b44e7439d44aef5887c +size 523440 diff --git a/data/voices/zf_xiaoyi.pt b/data/voices/zf_xiaoyi.pt new file mode 100644 index 0000000000000000000000000000000000000000..4eab21f6b047099939d8dcab0684409d76dcffbe --- /dev/null +++ b/data/voices/zf_xiaoyi.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5235dbaeef85a4c613bf78af9a88ff63c25bac5f26ba77e36186d8b7ebf05e2 +size 523430 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4ede041b989f47ee4b6da72715e7c4876cef799b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +phonemizer +torch +transformers +scipy +munch +sounddevice +python-multipart +soundfile +pydantic +requests +python-dotenv +numpy +pyaudio +pyannote.audio +torch_audiomentations +pydantic_settings \ No newline at end of file diff --git a/speech_to_speech.py b/speech_to_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..74cf63164cd769cdfb4c5aab7ab5dcb15547285d --- /dev/null +++ b/speech_to_speech.py @@ -0,0 +1,334 @@ +import msvcrt +import traceback +import time +import requests +import time +from transformers import WhisperProcessor, WhisperForConditionalGeneration +from src.utils.config import settings +from src.utils import ( + VoiceGenerator, + get_ai_response, + play_audio_with_interrupt, + init_vad_pipeline, + detect_speech_segments, + record_continuous_audio, + check_for_speech, + transcribe_audio, +) +from src.utils.audio_queue import AudioGenerationQueue +from src.utils.llm import parse_stream_chunk +import threading +from src.utils.text_chunker import TextChunker + 
+settings.setup_directories() +timing_info = { + "vad_start": None, + "transcription_start": None, + "llm_first_token": None, + "audio_queued": None, + "first_audio_play": None, + "playback_start": None, + "end": None, + "transcription_duration": None, +} + + +def process_input( + session: requests.Session, + user_input: str, + messages: list, + generator: VoiceGenerator, + speed: float, +) -> tuple[bool, None]: + """Processes user input, generates a response, and handles audio output. + + Args: + session (requests.Session): The requests session to use. + user_input (str): The user's input text. + messages (list): The list of messages to send to the LLM. + generator (VoiceGenerator): The voice generator object. + speed (float): The playback speed. + + Returns: + tuple[bool, None]: A tuple containing a boolean indicating if the process was interrupted and None. + """ + global timing_info + timing_info = {k: None for k in timing_info} + timing_info["vad_start"] = time.perf_counter() + + messages.append({"role": "user", "content": user_input}) + print("\nThinking...") + start_time = time.time() + try: + response_stream = get_ai_response( + session=session, + messages=messages, + llm_model=settings.LLM_MODEL, + llm_url=settings.OLLAMA_URL, + max_tokens=settings.MAX_TOKENS, + stream=True, + ) + + if not response_stream: + print("Failed to get AI response stream.") + return False, None + + audio_queue = AudioGenerationQueue(generator, speed) + audio_queue.start() + chunker = TextChunker() + complete_response = [] + + playback_thread = threading.Thread( + target=lambda: audio_playback_worker(audio_queue) + ) + playback_thread.daemon = True + playback_thread.start() + + for chunk in response_stream: + data = parse_stream_chunk(chunk) + if not data or "choices" not in data: + continue + + choice = data["choices"][0] + if "delta" in choice and "content" in choice["delta"]: + content = choice["delta"]["content"] + if content: + if not timing_info["llm_first_token"]: + timing_info["llm_first_token"] = time.perf_counter() + print(content, end="", flush=True) + chunker.current_text.append(content) + + text = "".join(chunker.current_text) + if chunker.should_process(text): + if not timing_info["audio_queued"]: + timing_info["audio_queued"] = time.perf_counter() + remaining = chunker.process(text, audio_queue) + chunker.current_text = [remaining] + complete_response.append(text[: len(text) - len(remaining)]) + + if choice.get("finish_reason") == "stop": + final_text = "".join(chunker.current_text).strip() + if final_text: + chunker.process(final_text, audio_queue) + complete_response.append(final_text) + break + + messages.append({"role": "assistant", "content": " ".join(complete_response)}) + print() + + time.sleep(0.1) + audio_queue.stop() + playback_thread.join() + + def playback_wrapper(): + timing_info["playback_start"] = time.perf_counter() + result = audio_playback_worker(audio_queue) + return result + + playback_thread = threading.Thread(target=playback_wrapper) + + timing_info["end"] = time.perf_counter() + print_timing_chart(timing_info) + return False, None + + except Exception as e: + print(f"\nError during streaming: {str(e)}") + if "audio_queue" in locals(): + audio_queue.stop() + return False, None + + +def audio_playback_worker(audio_queue) -> tuple[bool, None]: + """Manages audio playback in a separate thread, handling interruptions. + + Args: + audio_queue (AudioGenerationQueue): The audio queue object. 
+ + Returns: + tuple[bool, None]: A tuple containing a boolean indicating if the playback was interrupted and the interrupt audio data. + """ + global timing_info + was_interrupted = False + interrupt_audio = None + + try: + while True: + speech_detected, audio_data = check_for_speech() + if speech_detected: + was_interrupted = True + interrupt_audio = audio_data + break + + audio_data, _ = audio_queue.get_next_audio() + if audio_data is not None: + if not timing_info["first_audio_play"]: + timing_info["first_audio_play"] = time.perf_counter() + + was_interrupted, interrupt_data = play_audio_with_interrupt(audio_data) + if was_interrupted: + interrupt_audio = interrupt_data + break + else: + time.sleep(settings.PLAYBACK_DELAY) + + if ( + not audio_queue.is_running + and audio_queue.sentence_queue.empty() + and audio_queue.audio_queue.empty() + ): + break + + except Exception as e: + print(f"Error in audio playback: {str(e)}") + + return was_interrupted, interrupt_audio + + +def main(): + """Main function to run the voice chat bot.""" + with requests.Session() as session: + try: + session = requests.Session() + generator = VoiceGenerator(settings.MODELS_DIR, settings.VOICES_DIR) + messages = [{"role": "system", "content": settings.DEFAULT_SYSTEM_PROMPT}] + print("\nInitializing Whisper model...") + whisper_processor = WhisperProcessor.from_pretrained(settings.WHISPER_MODEL) + whisper_model = WhisperForConditionalGeneration.from_pretrained( + settings.WHISPER_MODEL + ) + print("\nInitializing Voice Activity Detection...") + vad_pipeline = init_vad_pipeline(settings.HUGGINGFACE_TOKEN) + print("\n=== Voice Chat Bot Initializing ===") + print("Device being used:", generator.device) + print("\nInitializing voice generator...") + result = generator.initialize(settings.TTS_MODEL, settings.VOICE_NAME) + print(result) + speed = settings.SPEED + try: + print("\nWarming up the LLM model...") + health = session.get("http://localhost:11434", timeout=3) + if health.status_code != 200: + print("Ollama not running! 
Start it first.") + return + response_stream = get_ai_response( + session=session, + messages=[ + {"role": "system", "content": settings.DEFAULT_SYSTEM_PROMPT}, + {"role": "user", "content": "Hi!"}, + ], + llm_model=settings.LLM_MODEL, + llm_url=settings.OLLAMA_URL, + max_tokens=settings.MAX_TOKENS, + stream=False, + ) + if not response_stream: + print("Failed to initialized the AI model!") + return + except requests.RequestException as e: + print(f"Warmup failed: {str(e)}") + + print("\n\n=== Voice Chat Bot Ready ===") + print("The bot is now listening for speech.") + print("Just start speaking, and I'll respond automatically!") + print("You can interrupt me anytime by starting to speak.") + while True: + try: + if msvcrt.kbhit(): + user_input = input("\nYou (text): ").strip() + + if user_input.lower() == "quit": + print("Goodbye!") + break + + audio_data = record_continuous_audio() + if audio_data is not None: + speech_segments = detect_speech_segments( + vad_pipeline, audio_data + ) + + if speech_segments is not None: + print("\nTranscribing detected speech...") + timing_info["transcription_start"] = time.perf_counter() + + user_input = transcribe_audio( + whisper_processor, whisper_model, speech_segments + ) + + timing_info["transcription_duration"] = ( + time.perf_counter() - timing_info["transcription_start"] + ) + if user_input.strip(): + print(f"You (voice): {user_input}") + was_interrupted, speech_data = process_input( + session, user_input, messages, generator, speed + ) + if was_interrupted and speech_data is not None: + speech_segments = detect_speech_segments( + vad_pipeline, speech_data + ) + if speech_segments is not None: + print("\nTranscribing interrupted speech...") + user_input = transcribe_audio( + whisper_processor, + whisper_model, + speech_segments, + ) + if user_input.strip(): + print(f"You (voice): {user_input}") + process_input( + session, + user_input, + messages, + generator, + speed, + ) + else: + print("No clear speech detected, please try again.") + if session is not None: + session.headers.update({"Connection": "keep-alive"}) + if hasattr(session, "connection_pool"): + session.connection_pool.clear() + + except KeyboardInterrupt: + print("\nStopping...") + break + except Exception as e: + print(f"Error: {str(e)}") + continue + + except Exception as e: + print(f"Error: {str(e)}") + print("\nFull traceback:") + traceback.print_exc() + + +def print_timing_chart(metrics): + """Prints timing chart from global metrics""" + base_time = metrics["vad_start"] + events = [ + ("User stopped speaking", metrics["vad_start"]), + ("VAD started", metrics["vad_start"]), + ("Transcription started", metrics["transcription_start"]), + ("LLM first token", metrics["llm_first_token"]), + ("Audio queued", metrics["audio_queued"]), + ("First audio played", metrics["first_audio_play"]), + ("Playback started", metrics["playback_start"]), + ("End-to-end response", metrics["end"]), + ] + + print("\nTiming Chart:") + print(f"{'Event':<25} | {'Time (s)':>9} | {'Δ+':>6}") + print("-" * 45) + + prev_time = base_time + for name, t in events: + if t is None: + continue + elapsed = t - base_time + delta = t - prev_time + print(f"{name:<25} | {elapsed:9.2f} | {delta:6.2f}") + prev_time = t + + +if __name__ == "__main__": + main() diff --git a/src/config/config.json b/src/config/config.json new file mode 100644 index 0000000000000000000000000000000000000000..29e12f5e6f19d8b27dcdb2cd37e8b12fd89590c5 --- /dev/null +++ b/src/config/config.json @@ -0,0 +1,26 @@ +{ + "decoder": { + "type": 
"istftnet", + "upsample_kernel_sizes": [20, 12], + "upsample_rates": [10, 6], + "gen_istft_hop_size": 5, + "gen_istft_n_fft": 20, + "resblock_dilation_sizes": [ + [1, 3, 5], + [1, 3, 5], + [1, 3, 5] + ], + "resblock_kernel_sizes": [3, 7, 11], + "upsample_initial_channel": 512 + }, + "dim_in": 64, + "dropout": 0.2, + "hidden_dim": 512, + "max_conv_dim": 512, + "max_dur": 50, + "multispeaker": true, + "n_layer": 3, + "n_mels": 80, + "n_token": 178, + "style_dim": 128 +} \ No newline at end of file diff --git a/src/core/kokoro.py b/src/core/kokoro.py new file mode 100644 index 0000000000000000000000000000000000000000..ced375003a9cf65e138112b5c933051d8e957edb --- /dev/null +++ b/src/core/kokoro.py @@ -0,0 +1,156 @@ +import phonemizer +import os +import re +import torch +from dotenv import load_dotenv +load_dotenv() + +"""Initialize eSpeak environment variables. Must be called before any other imports.""" +os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll" +os.environ["PHONEMIZER_ESPEAK_PATH"] = r"C:\Program Files\eSpeak NG\espeak-ng.exe" + +def split_num(num): + num = num.group() + if '.' in num: + return num + elif ':' in num: + h, m = [int(n) for n in num.split(':')] + if m == 0: + return f"{h} o'clock" + elif m < 10: + return f'{h} oh {m}' + return f'{h} {m}' + year = int(num[:4]) + if year < 1100 or year % 1000 < 10: + return num + left, right = num[:2], int(num[2:4]) + s = 's' if num.endswith('s') else '' + if 100 <= year % 1000 <= 999: + if right == 0: + return f'{left} hundred{s}' + elif right < 10: + return f'{left} oh {right}{s}' + return f'{left} {right}{s}' + +def flip_money(m): + m = m.group() + bill = 'dollar' if m[0] == '$' else 'pound' + if m[-1].isalpha(): + return f'{m[1:]} {bill}s' + elif '.' not in m: + s = '' if m[1:] == '1' else 's' + return f'{m[1:]} {bill}{s}' + b, c = m[1:].split('.') + s = '' if b == '1' else 's' + c = int(c.ljust(2, '0')) + coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence') + return f'{b} {bill}{s} and {c} {coins}' + +def point_num(num): + a, b = num.group().split('.') + return ' point '.join([a, ' '.join(b)]) + +def normalize_text(text): + text = text.replace(chr(8216), "'").replace(chr(8217), "'") + text = text.replace('«', chr(8220)).replace('»', chr(8221)) + text = text.replace(chr(8220), '"').replace(chr(8221), '"') + text = text.replace('(', '«').replace(')', '»') + for a, b in zip('、。!,:;?', ',.!,:;?'): + text = text.replace(a, b+' ') + text = re.sub(r'[^\S \n]', ' ', text) + text = re.sub(r' +', ' ', text) + text = re.sub(r'(?<=\n) +(?=\n)', '', text) + text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text) + text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text) + text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text) + text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text) + text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text) + text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text) + text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(? 
510: + tokens = tokens[:510] + print('Truncated to 510 tokens') + ref_s = voicepack[len(tokens)] + out = forward(model, tokens, ref_s, speed) + ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens) + return out, ps diff --git a/src/models/istftnet.py b/src/models/istftnet.py new file mode 100644 index 0000000000000000000000000000000000000000..da29481368de41ce2a3ff9816c9bd3f11f3ab15e --- /dev/null +++ b/src/models/istftnet.py @@ -0,0 +1,523 @@ +# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py +from scipy.signal import get_window +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + +LRELU_SLOPE = 0.1 + +class AdaIN1d(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm1d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class AdaINResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64): + super(AdaINResBlock1, self).__init__() + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + self.adain1 = nn.ModuleList([ + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + ]) + + self.adain2 = nn.ModuleList([ + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + AdaIN1d(style_dim, channels), + ]) + + self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))]) + self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))]) + + + def forward(self, x, s): + for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2): + xt = n1(x, s) + xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D + xt = c1(xt) + xt = n2(xt, s) + xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + +class TorchSTFT(torch.nn.Module): + def __init__(self, filter_length=800, 
hop_length=200, win_length=800, window='hann'): + super().__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32)) + + def transform(self, input_data): + forward_transform = torch.stft( + input_data, + self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device), + return_complex=True) + + return torch.abs(forward_transform), torch.angle(forward_transform) + + def inverse(self, magnitude, phase): + inverse_transform = torch.istft( + magnitude * torch.exp(phase * 1j), + self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device)) + + return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, upsample_scale, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.upsample_scale = upsample_scale + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \ + device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: +# # for normal case + +# # To prevent torch.cumsum numerical overflow, +# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. +# # Buffer tmp_over_one_idx indicates the time step to add -1. 
+# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi +# tmp_over_one = torch.cumsum(rad_values, 1) % 1 +# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 +# cumsum_shift = torch.zeros_like(rad_values) +# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + +# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi + rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2), + scale_factor=1/self.upsample_scale, + mode="linear").transpose(1, 2) + +# tmp_over_one = torch.cumsum(rad_values, 1) % 1 +# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 +# cumsum_shift = torch.zeros_like(rad_values) +# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi + phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale, + scale_factor=self.upsample_scale, mode="linear").transpose(1, 2) + sines = torch.sin(phase) + + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, + device=f0.device) + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv +def padDiff(x): + return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0) + + +class Generator(torch.nn.Module): + def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size): + super(Generator, self).__init__() + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + resblock = AdaINResBlock1 + + self.m_source = SourceModuleHnNSF( + sampling_rate=24000, + upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size, + harmonic_num=8, voiced_threshod=10) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size) + self.noise_convs = nn.ModuleList() + self.noise_res = nn.ModuleList() + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append(weight_norm( + ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), + k, u, padding=(k-u)//2))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel//(2**(i+1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d, style_dim)) + + c_cur = upsample_initial_channel // (2 ** (i + 1)) + + if i + 1 < len(upsample_rates): # + stride_f0 = np.prod(upsample_rates[i + 1:]) + self.noise_convs.append(Conv1d( + gen_istft_n_fft + 2, 
c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2)) + self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim)) + else: + self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1)) + self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim)) + + + self.post_n_fft = gen_istft_n_fft + self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = torch.nn.ReflectionPad1d((1, 0)) + self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft) + + + def forward(self, x, s, f0): + with torch.no_grad(): + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + + har_source, noi_source, uv = self.m_source(f0) + har_source = har_source.transpose(1, 2).squeeze(1) + har_spec, har_phase = self.stft.transform(har_source) + har = torch.cat([har_spec, har_phase], dim=1) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x_source = self.noise_convs[i](har) + x_source = self.noise_res[i](x_source, s) + + x = self.ups[i](x) + if i == self.num_upsamples - 1: + x = self.reflection_pad(x) + + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x, s) + else: + xs += self.resblocks[i*self.num_kernels+j](x, s) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]) + phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :]) + return self.stft.inverse(spec, phase) + + def fw_phase(self, x, s): + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x, s) + else: + xs += self.resblocks[i*self.num_kernels+j](x, s) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.reflection_pad(x) + x = self.conv_post(x) + spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]) + phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :]) + return spec, phase + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class AdainResBlk1d(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), + upsample='none', dropout_p=0.0): + super().__init__() + self.actv = actv + self.upsample_type = upsample + self.upsample = UpSample1d(upsample) + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out, style_dim) + self.dropout = nn.Dropout(dropout_p) + + if upsample == 'none': + self.pool = nn.Identity() + else: + self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1)) + + + def _build_weights(self, dim_in, dim_out, style_dim): + self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)) + self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1)) + self.norm1 = AdaIN1d(style_dim, dim_in) + self.norm2 = AdaIN1d(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False)) + + def _shortcut(self, x): + x = self.upsample(x) + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + x = 
self.pool(x) + x = self.conv1(self.dropout(x)) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(self.dropout(x)) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / np.sqrt(2) + return out + +class UpSample1d(nn.Module): + def __init__(self, layer_type): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == 'none': + return x + else: + return F.interpolate(x, scale_factor=2, mode='nearest') + +class Decoder(nn.Module): + def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80, + resblock_kernel_sizes = [3,7,11], + upsample_rates = [10, 6], + upsample_initial_channel=512, + resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]], + upsample_kernel_sizes=[20, 12], + gen_istft_n_fft=20, gen_istft_hop_size=5): + super().__init__() + + self.decode = nn.ModuleList() + + self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim) + + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim)) + self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True)) + + self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)) + + self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)) + + self.asr_res = nn.Sequential( + weight_norm(nn.Conv1d(512, 64, kernel_size=1)), + ) + + + self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, + upsample_initial_channel, resblock_dilation_sizes, + upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size) + + def forward(self, asr, F0_curve, N, s): + F0 = self.F0_conv(F0_curve.unsqueeze(1)) + N = self.N_conv(N.unsqueeze(1)) + + x = torch.cat([asr, F0, N], axis=1) + x = self.encode(x, s) + + asr_res = self.asr_res(asr) + + res = True + for block in self.decode: + if res: + x = torch.cat([x, asr_res, F0, N], axis=1) + x = block(x, s) + if block.upsample_type != "none": + res = False + + x = self.generator(x, s, F0_curve) + return x diff --git a/src/models/models.py b/src/models/models.py new file mode 100644 index 0000000000000000000000000000000000000000..068c61d7acdbadca35595fa4a8a0146e3c835e93 --- /dev/null +++ b/src/models/models.py @@ -0,0 +1,372 @@ +# https://github.com/yl4579/StyleTTS2/blob/main/models.py +from .istftnet import AdaIN1d, Decoder +from munch import Munch +from pathlib import Path +from .plbert import load_plbert +from torch.nn.utils import weight_norm, spectral_norm +import json +import numpy as np +import os +import os.path as osp +import torch +import torch.nn as nn +import torch.nn.functional as F + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, + gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + +class 
TextEncoder(nn.Module): + def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)): + super().__init__() + self.embedding = nn.Embedding(n_symbols, channels) + + padding = (kernel_size - 1) // 2 + self.cnn = nn.ModuleList() + for _ in range(depth): + self.cnn.append(nn.Sequential( + weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)), + LayerNorm(channels), + actv, + nn.Dropout(0.2), + )) + # self.cnn = nn.Sequential(*self.cnn) + + self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True) + + def forward(self, x, input_lengths, m): + x = self.embedding(x) # [B, T, emb] + x = x.transpose(1, 2) # [B, emb, T] + m = m.to(input_lengths.device).unsqueeze(1) + x.masked_fill_(m, 0.0) + + for c in self.cnn: + x = c(x) + x.masked_fill_(m, 0.0) + + x = x.transpose(1, 2) # [B, T, chn] + + input_lengths = input_lengths.cpu().numpy() + x = nn.utils.rnn.pack_padded_sequence( + x, input_lengths, batch_first=True, enforce_sorted=False) + + self.lstm.flatten_parameters() + x, _ = self.lstm(x) + x, _ = nn.utils.rnn.pad_packed_sequence( + x, batch_first=True) + + x = x.transpose(-1, -2) + x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]]) + + x_pad[:, :, :x.shape[-1]] = x + x = x_pad.to(x.device) + + x.masked_fill_(m, 0.0) + + return x + + def inference(self, x): + x = self.embedding(x) + x = x.transpose(1, 2) + x = self.cnn(x) + x = x.transpose(1, 2) + self.lstm.flatten_parameters() + x, _ = self.lstm(x) + return x + + def length_to_mask(self, lengths): + mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) + mask = torch.gt(mask+1, lengths.unsqueeze(1)) + return mask + + +class UpSample1d(nn.Module): + def __init__(self, layer_type): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == 'none': + return x + else: + return F.interpolate(x, scale_factor=2, mode='nearest') + +class AdainResBlk1d(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), + upsample='none', dropout_p=0.0): + super().__init__() + self.actv = actv + self.upsample_type = upsample + self.upsample = UpSample1d(upsample) + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out, style_dim) + self.dropout = nn.Dropout(dropout_p) + + if upsample == 'none': + self.pool = nn.Identity() + else: + self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1)) + + + def _build_weights(self, dim_in, dim_out, style_dim): + self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1)) + self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1)) + self.norm1 = AdaIN1d(style_dim, dim_in) + self.norm2 = AdaIN1d(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False)) + + def _shortcut(self, x): + x = self.upsample(x) + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + x = self.pool(x) + x = self.conv1(self.dropout(x)) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(self.dropout(x)) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / np.sqrt(2) + return out + +class AdaLayerNorm(nn.Module): + def __init__(self, style_dim, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.fc = nn.Linear(style_dim, channels*2) 
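+        # Note: fc maps the style embedding s to 2*channels values which forward()
+        # splits into per-channel gamma/beta, i.e. a style-conditioned affine
+        # transform applied on top of the parameter-free layer norm.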
+ + def forward(self, x, s): + x = x.transpose(-1, -2) + x = x.transpose(1, -1) + + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1) + + + x = F.layer_norm(x, (self.channels,), eps=self.eps) + x = (1 + gamma) * x + beta + return x.transpose(1, -1).transpose(-1, -2) + +class ProsodyPredictor(nn.Module): + + def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1): + super().__init__() + + self.text_encoder = DurationEncoder(sty_dim=style_dim, + d_model=d_hid, + nlayers=nlayers, + dropout=dropout) + + self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True) + self.duration_proj = LinearNorm(d_hid, max_dur) + + self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True) + self.F0 = nn.ModuleList() + self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) + self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout)) + self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)) + + self.N = nn.ModuleList() + self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout)) + self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout)) + self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)) + + self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) + self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0) + + + def forward(self, texts, style, text_lengths, alignment, m): + d = self.text_encoder(texts, style, text_lengths, m) + + batch_size = d.shape[0] + text_size = d.shape[1] + + # predict duration + input_lengths = text_lengths.cpu().numpy() + x = nn.utils.rnn.pack_padded_sequence( + d, input_lengths, batch_first=True, enforce_sorted=False) + + m = m.to(text_lengths.device).unsqueeze(1) + + self.lstm.flatten_parameters() + x, _ = self.lstm(x) + x, _ = nn.utils.rnn.pad_packed_sequence( + x, batch_first=True) + + x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]]) + + x_pad[:, :x.shape[1], :] = x + x = x_pad.to(x.device) + + duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training)) + + en = (d.transpose(-1, -2) @ alignment) + + return duration.squeeze(-1), en + + def F0Ntrain(self, x, s): + x, _ = self.shared(x.transpose(-1, -2)) + + F0 = x.transpose(-1, -2) + for block in self.F0: + F0 = block(F0, s) + F0 = self.F0_proj(F0) + + N = x.transpose(-1, -2) + for block in self.N: + N = block(N, s) + N = self.N_proj(N) + + return F0.squeeze(1), N.squeeze(1) + + def length_to_mask(self, lengths): + mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) + mask = torch.gt(mask+1, lengths.unsqueeze(1)) + return mask + +class DurationEncoder(nn.Module): + + def __init__(self, sty_dim, d_model, nlayers, dropout=0.1): + super().__init__() + self.lstms = nn.ModuleList() + for _ in range(nlayers): + self.lstms.append(nn.LSTM(d_model + sty_dim, + d_model // 2, + num_layers=1, + batch_first=True, + bidirectional=True, + dropout=dropout)) + self.lstms.append(AdaLayerNorm(sty_dim, d_model)) + + + self.dropout = dropout + self.d_model = d_model + self.sty_dim = sty_dim + + def forward(self, x, style, text_lengths, m): + masks = m.to(text_lengths.device) + + x = x.permute(2, 0, 1) + s = style.expand(x.shape[0], x.shape[1], -1) + x = torch.cat([x, s], axis=-1) + x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0) + + x 
= x.transpose(0, 1) + input_lengths = text_lengths.cpu().numpy() + x = x.transpose(-1, -2) + + for block in self.lstms: + if isinstance(block, AdaLayerNorm): + x = block(x.transpose(-1, -2), style).transpose(-1, -2) + x = torch.cat([x, s.permute(1, -1, 0)], axis=1) + x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0) + else: + x = x.transpose(-1, -2) + x = nn.utils.rnn.pack_padded_sequence( + x, input_lengths, batch_first=True, enforce_sorted=False) + block.flatten_parameters() + x, _ = block(x) + x, _ = nn.utils.rnn.pad_packed_sequence( + x, batch_first=True) + x = F.dropout(x, p=self.dropout, training=self.training) + x = x.transpose(-1, -2) + + x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]]) + + x_pad[:, :, :x.shape[-1]] = x + x = x_pad.to(x.device) + + return x.transpose(-1, -2) + + def inference(self, x, style): + x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model) + style = style.expand(x.shape[0], x.shape[1], -1) + x = torch.cat([x, style], axis=-1) + src = self.pos_encoder(x) + output = self.transformer_encoder(src).transpose(0, 1) + return output + + def length_to_mask(self, lengths): + mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) + mask = torch.gt(mask+1, lengths.unsqueeze(1)) + return mask + +# https://github.com/yl4579/StyleTTS2/blob/main/utils.py +def recursive_munch(d): + if isinstance(d, dict): + return Munch((k, recursive_munch(v)) for k, v in d.items()) + elif isinstance(d, list): + return [recursive_munch(v) for v in d] + else: + return d + +def build_model(path, device): + config = Path(__file__).parent.parent / 'config' / 'config.json' + assert config.exists(), f'Config path incorrect: config.json not found at {config}' + with open(config, 'r') as r: + args = recursive_munch(json.load(r)) + assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}' + decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels, + resblock_kernel_sizes = args.decoder.resblock_kernel_sizes, + upsample_rates = args.decoder.upsample_rates, + upsample_initial_channel=args.decoder.upsample_initial_channel, + resblock_dilation_sizes=args.decoder.resblock_dilation_sizes, + upsample_kernel_sizes=args.decoder.upsample_kernel_sizes, + gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size) + text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token) + predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout) + bert = load_plbert() + bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim) + for parent in [bert, bert_encoder, predictor, decoder, text_encoder]: + for child in parent.children(): + if isinstance(child, nn.RNNBase): + child.flatten_parameters() + model = Munch( + bert=bert.to(device).eval(), + bert_encoder=bert_encoder.to(device).eval(), + predictor=predictor.to(device).eval(), + decoder=decoder.to(device).eval(), + text_encoder=text_encoder.to(device).eval(), + ) + for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items(): + assert key in model, key + try: + model[key].load_state_dict(state_dict) + except: + state_dict = {k[7:]: v for k, v in state_dict.items()} + model[key].load_state_dict(state_dict, strict=False) + return model diff --git a/src/models/plbert.py b/src/models/plbert.py new file mode 100644 index 
0000000000000000000000000000000000000000..ef54f57bb8405abebfcd052bcb2be1249ce510bc --- /dev/null +++ b/src/models/plbert.py @@ -0,0 +1,15 @@ +# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py +from transformers import AlbertConfig, AlbertModel + +class CustomAlbert(AlbertModel): + def forward(self, *args, **kwargs): + # Call the original forward method + outputs = super().forward(*args, **kwargs) + # Only return the last_hidden_state + return outputs.last_hidden_state + +def load_plbert(): + plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1} + albert_base_configuration = AlbertConfig(**plbert_config) + bert = CustomAlbert(albert_base_configuration) + return bert diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..98c576df9a33d3e03616b23fefefef40f381734c --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,35 @@ +from .audio import play_audio +from .voice import load_voice, quick_mix_voice, split_into_sentences +from .generator import VoiceGenerator +from .llm import filter_response, get_ai_response +from .audio_utils import save_audio_file, generate_and_play_sentences +from .commands import handle_commands +from .speech import ( + init_vad_pipeline, detect_speech_segments, record_audio, + record_continuous_audio, check_for_speech, play_audio_with_interrupt, + transcribe_audio +) +from .config import settings +from .text_chunker import TextChunker + +__all__ = [ + 'play_audio', + 'load_voice', + 'quick_mix_voice', + 'split_into_sentences', + 'VoiceGenerator', + 'filter_response', + 'get_ai_response', + 'save_audio_file', + 'generate_and_play_sentences', + 'handle_commands', + 'init_vad_pipeline', + 'detect_speech_segments', + 'record_audio', + 'record_continuous_audio', + 'check_for_speech', + 'play_audio_with_interrupt', + 'transcribe_audio', + 'settings', + 'TextChunker', +] \ No newline at end of file diff --git a/src/utils/audio.py b/src/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..588a748236e5cea6a0729606a3bc50a7b777bbbb --- /dev/null +++ b/src/utils/audio.py @@ -0,0 +1,42 @@ +import numpy as np +import sounddevice as sd +import time + + +def play_audio(audio_data: np.ndarray, sample_rate: int = 24000): + """ + Play audio directly using sounddevice. + + Args: + audio_data (np.ndarray): The audio data to play. + sample_rate (int, optional): The sample rate of the audio data. Defaults to 24000. + """ + try: + sd.play(audio_data, sample_rate) + sd.wait() + except Exception as e: + print(f"Error playing audio: {str(e)}") + + +def stream_audio_chunks( + audio_chunks: list, sample_rate: int = 24000, pause_duration: float = 0.2 +): + """ + Stream audio chunks one after another with a small pause between them. + + Args: + audio_chunks (list): A list of audio chunks to play. + sample_rate (int, optional): The sample rate of the audio data. Defaults to 24000. + pause_duration (float, optional): The duration of the pause between chunks in seconds. Defaults to 0.2. 
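+
+    Example (illustrative; assumes `chunks` is a list of mono float32 numpy arrays
+    sampled at 24 kHz):
+        stream_audio_chunks(chunks, sample_rate=24000, pause_duration=0.1)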
+ """ + try: + for chunk in audio_chunks: + if len(chunk) == 0: + continue + sd.play(chunk, sample_rate) + sd.wait() + time.sleep(pause_duration) + except Exception as e: + print(f"Error streaming audio chunks: {str(e)}") + finally: + sd.stop() diff --git a/src/utils/audio_io.py b/src/utils/audio_io.py new file mode 100644 index 0000000000000000000000000000000000000000..2b6bcd56896bb31e7531f9738e22a629df2e48ca --- /dev/null +++ b/src/utils/audio_io.py @@ -0,0 +1,48 @@ +import numpy as np +import soundfile as sf +import sounddevice as sd +from datetime import datetime +from pathlib import Path +from typing import Tuple, Optional + + +def save_audio_file( + audio_data: np.ndarray, output_dir: Path, sample_rate: int = 24000 +) -> Path: + """ + Save audio data to a WAV file with a timestamp in the filename. + + Args: + audio_data (np.ndarray): The audio data to save. Can be a single array or a list of arrays. + output_dir (Path): The directory to save the audio file in. + sample_rate (int, optional): The sample rate of the audio data. Defaults to 24000. + + Returns: + Path: The path to the saved audio file. + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = output_dir / f"output_{timestamp}.wav" + + if isinstance(audio_data, list): + audio_data = np.concatenate(audio_data) + + sf.write(str(output_path), audio_data, sample_rate) + return output_path + + +def play_audio( + audio_data: np.ndarray, sample_rate: int = 24000 +) -> Tuple[bool, Optional[np.ndarray]]: + """ + Play audio data using sounddevice. + + Args: + audio_data (np.ndarray): The audio data to play. + sample_rate (int, optional): The sample rate of the audio data. Defaults to 24000. + + Returns: + Tuple[bool, Optional[np.ndarray]]: A tuple containing a boolean indicating if the playback was interrupted (always False here) and an optional numpy array representing the interrupted audio (always None here). + """ + sd.play(audio_data, sample_rate) + sd.wait() + return False, None diff --git a/src/utils/audio_queue.py b/src/utils/audio_queue.py new file mode 100644 index 0000000000000000000000000000000000000000..8b2fca44926b5f59532068530735a93e85e376f2 --- /dev/null +++ b/src/utils/audio_queue.py @@ -0,0 +1,180 @@ +from queue import Queue +import threading +import time +from typing import Optional, Tuple, List +import numpy as np +from pathlib import Path +import logging +from datetime import datetime +from .audio_io import save_audio_file + +logging.getLogger("phonemizer").setLevel(logging.ERROR) +logging.getLogger("speechbrain.utils.quirks").setLevel(logging.ERROR) +logging.basicConfig(format="%(message)s", level=logging.INFO) + + +class AudioGenerationQueue: + """ + A queue system for managing asynchronous audio generation from text input. + + This class implements a threaded queue system that handles text-to-audio generation + in a background thread. It provides functionality for adding sentences to be processed, + retrieving generated audio, and monitoring the generation process. 
+ + Attributes: + generator: Audio generator instance used for text-to-speech conversion + speed (float): Speed multiplier for audio generation + output_dir (Path): Directory where generated audio files are saved + sentences_processed (int): Count of processed sentences + audio_generated (int): Count of successfully generated audio files + failed_sentences (list): List of tuples containing failed sentences and error messages + """ + + def __init__( + self, generator, speed: float = 1.0, output_dir: Optional[Path] = None + ): + """ + Initialize the audio generation queue system. + + Args: + generator: Audio generator instance for text-to-speech conversion + speed: Speed multiplier for audio generation (default: 1.0) + output_dir: Directory path for saving generated audio files (default: "generated_audio") + """ + self.generator = generator + self.speed = speed + self.lock = threading.Lock() + self.output_dir = output_dir or Path("generated_audio") + self.output_dir.mkdir(exist_ok=True) + self.sentence_queue = Queue() + self.audio_queue = Queue() + self.is_running = False + self.generation_thread = None + self.sentences_processed = 0 + self.audio_generated = 0 + self.failed_sentences = [] + + def start(self): + """ + Start the audio generation thread if not already running. + The thread will process sentences from the queue until stopped. + """ + if not self.is_running: + self.is_running = True + self.generation_thread = threading.Thread(target=self._generation_worker) + self.generation_thread.daemon = True + self.generation_thread.start() + + def stop(self): + """ + Stop the audio generation thread gracefully. + Waits for the current queue to be processed before stopping. + Outputs final processing statistics. + """ + if self.generation_thread: + while not self.sentence_queue.empty(): + time.sleep(0.1) + + time.sleep(0.5) + + self.is_running = False + self.generation_thread.join() + self.generation_thread = None + + logging.info( + f"\nAudio Generation Complete - Processed: {self.sentences_processed}, Generated: {self.audio_generated}, Failed: {len(self.failed_sentences)}" + ) + + def add_sentences(self, sentences: List[str]): + """ + Add a list of sentences to the generation queue. + + Args: + sentences: List of text strings to be converted to audio + """ + added_count = 0 + for sentence in sentences: + sentence = sentence.strip() + if sentence: + self.sentence_queue.put(sentence) + added_count += 1 + + if not self.is_running: + self.start() + + def get_next_audio(self) -> Tuple[Optional[np.ndarray], Optional[Path]]: + """ + Retrieve the next generated audio segment from the queue. + + Returns: + Tuple containing: + - numpy array of audio data (or None if queue is empty) + - Path object for the saved audio file (or None if queue is empty) + """ + try: + audio_data, output_path = self.audio_queue.get_nowait() + return audio_data, output_path + except: + return None, None + + def clear_queues(self): + """ + Clear both sentence and audio queues, removing all pending items. + Returns immediately without waiting for queue processing. + """ + sentences_cleared = 0 + audio_cleared = 0 + + while not self.sentence_queue.empty(): + try: + self.sentence_queue.get_nowait() + sentences_cleared += 1 + except: + pass + + while not self.audio_queue.empty(): + try: + self.audio_queue.get_nowait() + audio_cleared += 1 + except: + pass + + def _generation_worker(self): + """ + Internal worker method that runs in a separate thread. 
+ Continuously processes sentences from the queue, generating audio + and handling any errors that occur during generation. + """ + while self.is_running or not self.sentence_queue.empty(): + try: + try: + sentence = self.sentence_queue.get_nowait() + self.sentences_processed += 1 + except: + if not self.is_running and self.sentence_queue.empty(): + break + time.sleep(0.01) + continue + + try: + audio_data, phonemes = self.generator.generate( + sentence, speed=self.speed + ) + + if audio_data is None or len(audio_data) == 0: + raise ValueError("Generated audio data is empty") + + output_path = save_audio_file(audio_data, self.output_dir) + self.audio_generated += 1 + + self.audio_queue.put((audio_data, output_path)) + + except Exception as e: + error_msg = str(e) + self.failed_sentences.append((sentence, error_msg)) + continue + + except Exception as e: + if not self.is_running and self.sentence_queue.empty(): + break + time.sleep(0.1) diff --git a/src/utils/audio_utils.py b/src/utils/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3250df42430ca2df14ec632bf653ca85809eb644 --- /dev/null +++ b/src/utils/audio_utils.py @@ -0,0 +1,83 @@ +import time +from pathlib import Path +from typing import List, Optional, Tuple, Callable +import numpy as np +from .audio_io import save_audio_file, play_audio +from .audio_queue import AudioGenerationQueue + + +def generate_and_play_sentences( + sentences: List[str], + generator, + speed: float = 1.0, + play_function: Callable = play_audio, + check_interrupt: Optional[Callable] = None, + output_dir: Optional[Path] = None, + sample_rate: Optional[int] = None, +) -> Tuple[bool, Optional[np.ndarray], List[Path]]: + """ + Generates and plays audio for each sentence with optional interruption checking. + + Args: + sentences (List[str]): A list of sentences to generate audio for. + generator: The audio generator object. + speed (float, optional): The speed of audio generation. Defaults to 1.0. + play_function (Callable, optional): The function to use for playing audio. Defaults to play_audio. + check_interrupt (Callable, optional): An optional function to check for interruptions. Defaults to None. + output_dir (Path, optional): The directory to save generated audio files. Defaults to None. + sample_rate (int, optional): The sample rate of the audio. Defaults to None. + + Returns: + Tuple[bool, Optional[np.ndarray], List[Path]]: A tuple containing: + - A boolean indicating if the process was interrupted. + - Optional audio data if the process was interrupted. + - A list of paths to the generated audio files. 
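+
+    Example (illustrative sketch; assumes `generator` was prepared with
+    VoiceGenerator.initialize):
+        interrupted, leftover_audio, files = generate_and_play_sentences(
+            ["Hello there.", "How can I help you today?"],
+            generator,
+            speed=1.2,
+            output_dir=Path("generated_audio"),
+        )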
+ """ + audio_queue = AudioGenerationQueue(generator, speed, output_dir) + audio_queue.start() + audio_queue.add_sentences(sentences) + + audio_files = [] + was_interrupted = False + interrupt_audio = None + + try: + while True: + if check_interrupt: + interrupted, audio_data = check_interrupt() + if interrupted: + was_interrupted = True + interrupt_audio = audio_data + break + + audio_data, output_path = audio_queue.get_next_audio() + + if audio_data is not None: + if output_path: + audio_files.append(output_path) + + if play_function: + try: + was_interrupted, interrupt_data = ( + play_function(audio_data, sample_rate) + if sample_rate + else play_function(audio_data) + ) + if was_interrupted: + interrupt_audio = interrupt_data + break + except Exception as e: + print(f"Error playing audio: {str(e)}") + continue + + if audio_queue.sentence_queue.empty() and audio_queue.audio_queue.empty(): + break + + time.sleep(0.01) + + except Exception as e: + print(f"Error in generate_and_play_sentences: {str(e)}") + finally: + audio_queue.stop() + + return was_interrupted, interrupt_audio, audio_files diff --git a/src/utils/commands.py b/src/utils/commands.py new file mode 100644 index 0000000000000000000000000000000000000000..3d375fc9126a57927b1e1c866c2df6161b07af44 --- /dev/null +++ b/src/utils/commands.py @@ -0,0 +1,97 @@ +import torch +from datetime import datetime +from pathlib import Path +from .voice import quick_mix_voice + + +def handle_commands(user_input, generator, speed, model_path=None): + """ + Handles bot commands to control the voice generator. + + Args: + user_input (str): The command input from the user. + generator: The voice generator object. + speed (float): The current speed of the generator. + model_path (str, optional): The path to the model. Defaults to None. + + Returns: + bool: True if a command was handled, False otherwise. + """ + if user_input.lower() == "quit": + print("Goodbye!") + return True + + if user_input.lower() == "voices": + voices = generator.list_available_voices() + print("\nAvailable voices:") + for voice in voices: + print(f"- {voice}") + return True + + if user_input.startswith("speed="): + try: + new_speed = float(user_input.split("=")[1]) + print(f"Speed set to {new_speed}") + return True + except: + print("Invalid speed value. Use format: speed=1.2") + return True + + if user_input.startswith("voice="): + try: + voice = user_input.split("=")[1] + if voice in generator.list_available_voices(): + generator.initialize(model_path or generator.model_path, voice) + print(f"Switched to voice: {voice}") + else: + print("Voice not found. Use 'voices' to list available voices.") + except Exception as e: + print(f"Error changing voice: {str(e)}") + return True + + if user_input.startswith("mix="): + try: + mix_input = user_input.split("=")[1] + voices_weights = mix_input.split(":") + voices = [v.strip() for v in voices_weights[0].split(",")] + + if len(voices_weights) > 1: + weights = [float(w.strip()) for w in voices_weights[1].split(",")] + else: + weights = [0.5, 0.5] + + if len(voices) != 2 or len(weights) != 2: + print( + "Mix command requires exactly two voices. Format: mix=voice1,voice2[:weight1,weight2]" + ) + return True + + available_voices = generator.list_available_voices() + if not all(voice in available_voices for voice in voices): + print( + "One or more voices not found. Use 'voices' to list available voices." 
+ ) + return True + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_name = f"af_mixed_voice_{timestamp}" + + voice_tensors = [] + for voice_name in voices: + voice_path = Path(generator.voices_dir) / f"{voice_name}.pt" + voice = torch.load(voice_path, weights_only=True) + voice_tensors.append(voice) + + mixed = quick_mix_voice( + output_name, generator.voices_dir, *voice_tensors, weights=weights + ) + + generator.initialize(model_path or generator.model_path, output_name) + print( + f"Mixed voices: {voices[0]} ({weights[0]:.1f}) and {voices[1]} ({weights[1]:.1f})" + ) + except Exception as e: + print(f"Error mixing voices: {str(e)}") + return True + + return False diff --git a/src/utils/config.py b/src/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..89792add511ba9ceda769d919d9e47299d2a7c2c --- /dev/null +++ b/src/utils/config.py @@ -0,0 +1,111 @@ +from pathlib import Path +import os +from dotenv import load_dotenv + +load_dotenv() + + +def init_espeak(): + """Initialize eSpeak environment variables. Must be called before any other imports.""" + os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = ( + r"C:\Program Files\eSpeak NG\libespeak-ng.dll" + ) + os.environ["PHONEMIZER_ESPEAK_PATH"] = r"C:\Program Files\eSpeak NG\espeak-ng.exe" + + +init_espeak() + +from pydantic_settings import BaseSettings +from pydantic import Field +from typing import Optional + + +class Settings(BaseSettings): + """Settings class to manage application configurations.""" + + BASE_DIR: Path = Path(__file__).parent.parent.parent + MODELS_DIR: Path = BASE_DIR / "data" / "models" + VOICES_DIR: Path = BASE_DIR / "data" / "voices" + OUTPUT_DIR: Path = BASE_DIR / "output" + RECORDINGS_DIR: Path = BASE_DIR / "recordings" + + ESPEAK_LIBRARY_PATH: str = r"C:\Program Files\eSpeak NG\libespeak-ng.dll" + ESPEAK_PATH: str = r"C:\Program Files\eSpeak NG\espeak-ng.exe" + + TTS_MODEL: str = Field(..., env="TTS_MODEL") + VOICE_NAME: str = Field(..., env="VOICE_NAME") + SPEED: float = Field(default=1.0, env="SPEED") + HUGGINGFACE_TOKEN: str = Field(..., env="HUGGINGFACE_TOKEN") + + LM_STUDIO_URL: str = Field(..., env="LM_STUDIO_URL") + OLLAMA_URL: str = Field(..., env="OLLAMA_URL") + DEFAULT_SYSTEM_PROMPT: str = Field(..., env="DEFAULT_SYSTEM_PROMPT") + LLM_MODEL: str = Field(..., env="LLM_MODEL") + NUM_THREADS: int = Field(default=2, env="NUM_THREADS") + MAX_TOKENS: int = Field(default=512, env="MAX_TOKENS") + LLM_TEMPERATURE: float = Field(default=0.7, env="LMM_TEMPERATURE") + LLM_STREAM: bool = Field(default=False, env="LLM_STREAM") + LLM_RETRY_DELAY: float = Field(default=0.5, env="LLM_RETRY_DELAY") + MAX_RETRIES: int = Field(default=3, env="MAX_RETRIES") + + WHISPER_MODEL: str = Field(default="openai/whisper-tiny.en", env="WHISPER_MODEL") + + VAD_MODEL: str = Field(default="pyannote/segmentation-3.0", env="VAD_MODEL") + VAD_MIN_DURATION_ON: float = Field(default=0.1, env="VAD_MIN_DURATION_ON") + VAD_MIN_DURATION_OFF: float = Field(default=0.1, env="VAD_MIN_DURATION_OFF") + + CHUNK: int = Field(default=1024, env="CHUNK") + FORMAT: str = Field(default="pyaudio.paFloat32", env="FORMAT") + CHANNELS: int = Field(default=1, env="CHANNELS") + RATE: int = Field(default=16000, env="RATE") + OUTPUT_SAMPLE_RATE: int = Field(default=24000, env="OUTPUT_SAMPLE_RATE") + RECORD_DURATION: int = Field(default=5, env="RECORD_DURATION") + SILENCE_THRESHOLD: float = Field(default=0.01, env="SILENCE_THRESHOLD") + INTERRUPTION_THRESHOLD: float = Field(default=0.02, env="INTERRUPTION_THRESHOLD") + 
MAX_SILENCE_DURATION: int = Field(default=1, env="MAX_SILENCE_DURATION") + SPEECH_CHECK_TIMEOUT: float = Field(default=0.1, env="SPEECH_CHECK_TIMEOUT") + SPEECH_CHECK_THRESHOLD: float = Field(default=0.02, env="SPEECH_CHECK_THRESHOLD") + ROLLING_BUFFER_TIME: float = Field(default=0.5, env="ROLLING_BUFFER_TIME") + TARGET_SIZE: int = Field(default=15, env="TARGET_SIZE") + FIRST_SENTENCE_SIZE: int = Field(default=3, env="FIRST_SENTENCE_SIZE") + PLAYBACK_DELAY: float = Field(default=0.005, env="PLAYBACK_DELAY") + + def setup_directories(self): + """Create necessary directories if they don't exist""" + self.MODELS_DIR.mkdir(parents=True, exist_ok=True) + self.VOICES_DIR.mkdir(parents=True, exist_ok=True) + self.OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + self.RECORDINGS_DIR.mkdir(parents=True, exist_ok=True) + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + + +settings = Settings() + + +def configure_logging(): + """Configure logging to suppress all logs""" + import logging + import warnings + + warnings.filterwarnings("ignore") + + logging.getLogger().setLevel(logging.ERROR) + + logging.getLogger("urllib3").setLevel(logging.ERROR) + logging.getLogger("PIL").setLevel(logging.ERROR) + logging.getLogger("matplotlib").setLevel(logging.ERROR) + logging.getLogger("torch").setLevel(logging.ERROR) + logging.getLogger("tensorflow").setLevel(logging.ERROR) + logging.getLogger("whisper").setLevel(logging.ERROR) + logging.getLogger("transformers").setLevel(logging.ERROR) + logging.getLogger("pyannote").setLevel(logging.ERROR) + logging.getLogger("sounddevice").setLevel(logging.ERROR) + logging.getLogger("soundfile").setLevel(logging.ERROR) + logging.getLogger("uvicorn").setLevel(logging.ERROR) + logging.getLogger("fastapi").setLevel(logging.ERROR) + + +configure_logging() diff --git a/src/utils/generator.py b/src/utils/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..bef95d9831333bb93f2aad86b3c08013a2239e45 --- /dev/null +++ b/src/utils/generator.py @@ -0,0 +1,186 @@ +import torch +import numpy as np +from pathlib import Path +from src.models.models import build_model +from src.core.kokoro import generate +from .voice import split_into_sentences + + +class VoiceGenerator: + """ + A class to manage voice generation using a pre-trained model. + """ + + def __init__(self, models_dir, voices_dir): + """ + Initializes the VoiceGenerator with model and voice directories. + + Args: + models_dir (Path): Path to the directory containing model files. + voices_dir (Path): Path to the directory containing voice pack files. + """ + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = None + self.voicepack = None + self.voice_name = None + self.models_dir = models_dir + self.voices_dir = voices_dir + self._initialized = False + + def initialize(self, model_path, voice_name): + """ + Initializes the model and voice pack for audio generation. + + Args: + model_path (str): The filename of the model. + voice_name (str): The name of the voice pack. + + Returns: + str: A message indicating the voice has been loaded. + + Raises: + FileNotFoundError: If the model or voice pack file is not found. + """ + model_file = self.models_dir / model_path + if not model_file.exists(): + raise FileNotFoundError( + f"Model file not found at {model_file}. Please place the model file in the 'models' directory." 
+ ) + + self.model = build_model(str(model_file), self.device) + self.voice_name = voice_name + + voice_path = self.voices_dir / f"{voice_name}.pt" + if not voice_path.exists(): + raise FileNotFoundError( + f"Voice pack not found at {voice_path}. Please place voice files in the 'data/voices' directory." + ) + + self.voicepack = torch.load(voice_path, weights_only=True).to(self.device) + self._initialized = True + return f"Loaded voice: {voice_name}" + + def list_available_voices(self): + """ + Lists all available voice packs in the voices directory. + + Returns: + list: A list of voice pack names (without the .pt extension). + """ + if not self.voices_dir.exists(): + return [] + return [f.stem for f in self.voices_dir.glob("*.pt")] + + def is_initialized(self): + """ + Checks if the generator is properly initialized. + + Returns: + bool: True if the model and voice pack are loaded, False otherwise. + """ + return ( + self._initialized and self.model is not None and self.voicepack is not None + ) + + def generate( + self, + text, + lang=None, + speed=1.0, + pause_duration=4000, + short_text_limit=200, + return_chunks=False, + ): + """ + Generates speech from the given text. + + Handles both short and long-form text by splitting long text into sentences. + + Args: + text (str): The text to generate speech from. + lang (str, optional): The language of the text. Defaults to None. + speed (float, optional): The speed of speech generation. Defaults to 1.0. + pause_duration (int, optional): The duration of pause between sentences in milliseconds. Defaults to 4000. + short_text_limit (int, optional): The character limit for considering text as short. Defaults to 200. + return_chunks (bool, optional): If True, returns a list of audio chunks instead of concatenated audio. Defaults to False. + + Returns: + tuple: A tuple containing the generated audio (numpy array or list of numpy arrays) and a list of phonemes. + + Raises: + RuntimeError: If the model is not initialized. + ValueError: If there is an error during audio generation. + """ + if not self.is_initialized(): + raise RuntimeError("Model not initialized. Call initialize() first.") + + if lang is None: + lang = self.voice_name[0] + + text = text.strip() + if not text: + return (None, []) if not return_chunks else ([], []) + + try: + if len(text) < short_text_limit: + try: + audio, phonemes = generate( + self.model, text, self.voicepack, lang=lang, speed=speed + ) + if audio is None or len(audio) == 0: + raise ValueError(f"Failed to generate audio for text: {text}") + return ( + (audio, phonemes) if not return_chunks else ([audio], phonemes) + ) + except Exception as e: + raise ValueError( + f"Error generating audio for text: {text}. 
Error: {str(e)}" + ) + + sentences = split_into_sentences(text) + if not sentences: + return (None, []) if not return_chunks else ([], []) + + audio_segments = [] + phonemes_list = [] + failed_sentences = [] + + for i, sentence in enumerate(sentences): + if not sentence.strip(): + continue + + try: + if audio_segments and not return_chunks: + audio_segments.append(np.zeros(pause_duration)) + + audio, phonemes = generate( + self.model, sentence, self.voicepack, lang=lang, speed=speed + ) + if audio is not None and len(audio) > 0: + audio_segments.append(audio) + phonemes_list.extend(phonemes) + else: + failed_sentences.append( + (i, sentence, "Generated audio is empty") + ) + except Exception as e: + failed_sentences.append((i, sentence, str(e))) + continue + + if failed_sentences: + error_msg = "\n".join( + [f"Sentence {i+1}: '{s}' - {e}" for i, s, e in failed_sentences] + ) + raise ValueError( + f"Failed to generate audio for some sentences:\n{error_msg}" + ) + + if not audio_segments: + return (None, []) if not return_chunks else ([], []) + + if return_chunks: + return audio_segments, phonemes_list + return np.concatenate(audio_segments), phonemes_list + + except Exception as e: + raise ValueError(f"Error in audio generation: {str(e)}") diff --git a/src/utils/llm.py b/src/utils/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..77471961e7d84bd5a413655b936b2d2a859c66b1 --- /dev/null +++ b/src/utils/llm.py @@ -0,0 +1,146 @@ +import re +import requests +import json +import time +from src.utils.config import settings + + +def filter_response(response: str) -> str: + """Removes markdown formatting and unicode characters from a string. + + Args: + response (str): The string to filter. + + Returns: + str: The filtered string. + """ + response = re.sub(r"\*\*|__|~~|`", "", response) + response = re.sub(r"[\U00010000-\U0010ffff]", "", response, flags=re.UNICODE) + return response + + +def warmup_llm(session: requests.Session, llm_model: str, llm_url: str): + """Sends a warmup request to the LLM server. + + Args: + session (requests.Session): The requests session to use. + llm_model (str): The name of the LLM model. + llm_url (str): The URL of the LLM server. + """ + try: + health = session.get("http://localhost:11434", timeout=3) + if health.status_code != 200: + print("Ollama not running! Start it first.") + return + + session.post( + llm_url, + json={ + "model": llm_model, + "messages": [{"role": "user", "content": "."}], + "context": [], + "options": {"num_ctx": 64}, + }, + timeout=5, + ) + + except requests.RequestException as e: + print(f"Warmup failed: {str(e)}") + return + + +def get_ai_response( + session: requests.Session, + messages: list, + llm_model: str, + llm_url: str, + max_tokens: int, + temperature: float = 0.7, + stream: bool = False, +): + """Sends a request to the LLM and returns a streaming iterator. + + Args: + session (requests.Session): The requests session to use. + messages (list): The list of messages to send to the LLM. + llm_model (str): The name of the LLM model. + llm_url (str): The URL of the LLM server. + max_tokens (int): The maximum number of tokens to generate. + temperature (float, optional): The temperature to use for generation. Defaults to 0.7. + stream (bool, optional): Whether to stream the response. Defaults to False. + + Returns: + iterator: An iterator over the streaming response. 
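+
+    Example (illustrative; the yielded chunks are raw bytes intended for
+    parse_stream_chunk defined later in this module):
+        for chunk in get_ai_response(session, messages, llm_model, llm_url, max_tokens, stream=True):
+            parsed = parse_stream_chunk(chunk)
+            if parsed and "choices" in parsed:
+                print(parsed["choices"][0].get("delta", {}).get("content", ""), end="")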
+ """ + try: + response = session.post( + llm_url, + json={ + "model": llm_model, + "messages": messages, + "options": { + "num_ctx": settings.MAX_TOKENS * 2, + "num_thread": settings.NUM_THREADS, + }, + "stream": stream, + }, + timeout=3600, + stream=stream, + ) + response.raise_for_status() + + def streaming_iterator(): + """Iterates over the streaming response.""" + try: + for chunk in response.iter_content(chunk_size=512): + if chunk: + yield chunk + else: + yield b"\x00\x00" + except Exception as e: + print(f"\nError: {str(e)}") + yield b"\x00\x00" + + return streaming_iterator() + + except Exception as e: + print(f"\nError: {str(e)}") + + +def parse_stream_chunk(chunk: bytes) -> dict: + """Parses a chunk of data from the LLM stream. + + Args: + chunk (bytes): The chunk of data to parse. + + Returns: + dict: A dictionary containing the parsed data. + """ + if not chunk: + return {"keep_alive": True} + + try: + text = chunk.decode("utf-8").strip() + if text.startswith("data: "): + text = text[6:] + if text == "[DONE]": + return {"choices": [{"finish_reason": "stop", "delta": {}}]} + if text.startswith("{"): + data = json.loads(text) + content = "" + if "message" in data: + content = data["message"].get("content", "") + elif "choices" in data and data["choices"]: + choice = data["choices"][0] + content = choice.get("delta", {}).get("content", "") or choice.get( + "message", {} + ).get("content", "") + + if content: + return {"choices": [{"delta": {"content": filter_response(content)}}]} + return None + + except Exception as e: + if str(e) != "Expecting value: line 1 column 2 (char 1)": + print(f"Error parsing stream chunk: {str(e)}") + return None diff --git a/src/utils/speech.py b/src/utils/speech.py new file mode 100644 index 0000000000000000000000000000000000000000..b8d7848b028067e37c295a394b1c806dbcc3f5c9 --- /dev/null +++ b/src/utils/speech.py @@ -0,0 +1,315 @@ +import pyaudio +import numpy as np +import torch +from torch.nn.functional import pad +import time +from queue import Queue +import sounddevice as sd +from .config import settings + +CHUNK = settings.CHUNK +FORMAT = pyaudio.paFloat32 +CHANNELS = settings.CHANNELS +RATE = settings.RATE +SILENCE_THRESHOLD = settings.SILENCE_THRESHOLD +SPEECH_CHECK_THRESHOLD = settings.SPEECH_CHECK_THRESHOLD +MAX_SILENCE_DURATION = settings.MAX_SILENCE_DURATION + + +def init_vad_pipeline(hf_token): + """Initializes the Voice Activity Detection pipeline. + + Args: + hf_token (str): Hugging Face API token. + + Returns: + pyannote.audio.pipelines.VoiceActivityDetection: VAD pipeline. + """ + from pyannote.audio import Model + from pyannote.audio.pipelines import VoiceActivityDetection + + model = Model.from_pretrained(settings.VAD_MODEL, use_auth_token=hf_token) + + pipeline = VoiceActivityDetection(segmentation=model) + + HYPER_PARAMETERS = { + "min_duration_on": settings.VAD_MIN_DURATION_ON, + "min_duration_off": settings.VAD_MIN_DURATION_OFF, + } + pipeline.instantiate(HYPER_PARAMETERS) + + return pipeline + + +def detect_speech_segments(pipeline, audio_data, sample_rate=None): + """Detects speech segments in audio using pyannote VAD. + + Args: + pipeline (pyannote.audio.pipelines.VoiceActivityDetection): VAD pipeline. + audio_data (np.ndarray or torch.Tensor): Audio data. + sample_rate (int, optional): Sample rate of the audio. Defaults to settings.RATE. + + Returns: + torch.Tensor or None: Concatenated speech segments as a torch tensor, or None if no speech is detected. 
+ """ + if sample_rate is None: + sample_rate = settings.RATE + + if len(audio_data.shape) == 1: + audio_data = audio_data.reshape(1, -1) + + if not isinstance(audio_data, torch.Tensor): + audio_data = torch.from_numpy(audio_data) + + if audio_data.shape[1] < sample_rate: + padding_size = sample_rate - audio_data.shape[1] + audio_data = pad(audio_data, (0, padding_size)) + + vad = pipeline({"waveform": audio_data, "sample_rate": sample_rate}) + + speech_segments = [] + for speech in vad.get_timeline().support(): + start_sample = int(speech.start * sample_rate) + end_sample = int(speech.end * sample_rate) + if start_sample < audio_data.shape[1]: + end_sample = min(end_sample, audio_data.shape[1]) + segment = audio_data[0, start_sample:end_sample] + speech_segments.append(segment) + + if speech_segments: + return torch.cat(speech_segments) + return None + + +def record_audio(duration=None): + """Records audio for a specified duration. + + Args: + duration (int, optional): Recording duration in seconds. Defaults to settings.RECORD_DURATION. + + Returns: + np.ndarray: Recorded audio data as a numpy array. + """ + if duration is None: + duration = settings.RECORD_DURATION + + p = pyaudio.PyAudio() + + stream = p.open( + format=settings.FORMAT, + channels=settings.CHANNELS, + rate=settings.RATE, + input=True, + frames_per_buffer=settings.CHUNK, + ) + + print("\nRecording...") + frames = [] + + for i in range(0, int(settings.RATE / settings.CHUNK * duration)): + data = stream.read(settings.CHUNK) + frames.append(np.frombuffer(data, dtype=np.float32)) + + print("Done recording") + + stream.stop_stream() + stream.close() + p.terminate() + + audio_data = np.concatenate(frames, axis=0) + return audio_data + + +def record_continuous_audio(): + """Continuously monitors audio and detects speech segments. + + Returns: + np.ndarray or None: Recorded audio data as a numpy array, or None if no speech is detected. + """ + p = pyaudio.PyAudio() + + stream = p.open( + format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK + ) + + print("\nListening... (Press Ctrl+C to stop)") + frames = [] + buffer_frames = [] + buffer_size = int(RATE * 0.5 / CHUNK) + silence_frames = 0 + max_silence_frames = int(RATE / CHUNK * 1) + recording = False + + try: + while True: + data = stream.read(CHUNK, exception_on_overflow=False) + audio_chunk = np.frombuffer(data, dtype=np.float32) + + buffer_frames.append(audio_chunk) + if len(buffer_frames) > buffer_size: + buffer_frames.pop(0) + + audio_level = np.abs(np.concatenate(buffer_frames)).mean() + + if audio_level > SILENCE_THRESHOLD: + if not recording: + print("\nPotential speech detected...") + recording = True + frames.extend(buffer_frames) + frames.append(audio_chunk) + silence_frames = 0 + elif recording: + frames.append(audio_chunk) + silence_frames += 1 + + if silence_frames >= max_silence_frames: + print("Processing speech segment...") + break + + time.sleep(0.001) + + except KeyboardInterrupt: + pass + finally: + stream.stop_stream() + stream.close() + p.terminate() + + if frames: + return np.concatenate(frames) + return None + + +def check_for_speech(timeout=0.1): + """Checks if speech is detected in a non-blocking way. + + Args: + timeout (float, optional): Duration to check for speech in seconds. Defaults to 0.1. + + Returns: + tuple: A tuple containing a boolean indicating if speech was detected and the audio data as a numpy array, or (False, None) if no speech is detected. 
+ """ + p = pyaudio.PyAudio() + + frames = [] + is_speech = False + + try: + stream = p.open( + format=FORMAT, + channels=CHANNELS, + rate=RATE, + input=True, + frames_per_buffer=CHUNK, + ) + + for _ in range(int(RATE * timeout / CHUNK)): + data = stream.read(CHUNK, exception_on_overflow=False) + audio_chunk = np.frombuffer(data, dtype=np.float32) + frames.append(audio_chunk) + + audio_level = np.abs(audio_chunk).mean() + if audio_level > SPEECH_CHECK_THRESHOLD: + is_speech = True + break + + finally: + stream.stop_stream() + stream.close() + p.terminate() + + if is_speech and frames: + return True, np.concatenate(frames) + return False, None + + +def play_audio_with_interrupt(audio_data, sample_rate=24000): + """Plays audio while monitoring for speech interruption. + + Args: + audio_data (np.ndarray): Audio data to play. + sample_rate (int, optional): Sample rate for playback. Defaults to 24000. + + Returns: + tuple: A tuple containing a boolean indicating if playback was interrupted and None, or (False, None) if playback completes without interruption. + """ + interrupt_queue = Queue() + + def input_callback(indata, frames, time, status): + """Callback for monitoring input audio.""" + if status: + print(f"Input status: {status}") + return + + audio_level = np.abs(indata[:, 0]).mean() + if audio_level > settings.INTERRUPTION_THRESHOLD: + interrupt_queue.put(True) + + def output_callback(outdata, frames, time, status): + """Callback for output audio.""" + if status: + print(f"Output status: {status}") + return + + if not interrupt_queue.empty(): + raise sd.CallbackStop() + + remaining = len(audio_data) - output_callback.position + if remaining == 0: + raise sd.CallbackStop() + valid_frames = min(remaining, frames) + outdata[:valid_frames, 0] = audio_data[ + output_callback.position : output_callback.position + valid_frames + ] + if valid_frames < frames: + outdata[valid_frames:] = 0 + output_callback.position += valid_frames + + output_callback.position = 0 + + try: + with sd.InputStream( + channels=1, callback=input_callback, samplerate=settings.RATE + ): + with sd.OutputStream( + channels=1, callback=output_callback, samplerate=sample_rate + ): + while output_callback.position < len(audio_data): + sd.sleep(100) + if not interrupt_queue.empty(): + return True, None + return False, None + except sd.CallbackStop: + return True, None + except Exception as e: + print(f"Error during playback: {str(e)}") + return False, None + + +def transcribe_audio(processor, model, audio_data, sampling_rate=None): + """Transcribes audio using Whisper. + + Args: + processor (transformers.WhisperProcessor): Whisper processor. + model (transformers.WhisperForConditionalGeneration): Whisper model. + audio_data (np.ndarray or torch.Tensor): Audio data to transcribe. + sampling_rate (int, optional): Sample rate of the audio. Defaults to settings.RATE. + + Returns: + str: Transcribed text. 
+ """ + if sampling_rate is None: + sampling_rate = settings.RATE + + if audio_data is None: + return "" + + if isinstance(audio_data, torch.Tensor): + audio_data = audio_data.numpy() + + input_features = processor( + audio_data, sampling_rate=sampling_rate, return_tensors="pt" + ).input_features + predicted_ids = model.generate(input_features) + transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) + return transcription[0] diff --git a/src/utils/text_chunker.py b/src/utils/text_chunker.py new file mode 100644 index 0000000000000000000000000000000000000000..55311cbfd52adc428e4acdaa38d43e47ae8bebd1 --- /dev/null +++ b/src/utils/text_chunker.py @@ -0,0 +1,120 @@ +from .config import settings + + +class TextChunker: + """A class to handle intelligent text chunking for voice generation.""" + + def __init__(self): + """Initialize the TextChunker with break points and priorities.""" + self.current_text = [] + self.found_first_sentence = False + self.semantic_breaks = { + "however": 4, + "therefore": 4, + "furthermore": 4, + "moreover": 4, + "nevertheless": 4, + "while": 3, + "although": 3, + "unless": 3, + "since": 3, + "and": 2, + "but": 2, + "because": 2, + "then": 2, + } + self.punctuation_priorities = { + ".": 5, + "!": 5, + "?": 5, + ";": 4, + ":": 4, + ",": 3, + "-": 2, + } + + def should_process(self, text: str) -> bool: + """Determines if text should be processed based on length or punctuation. + + Args: + text (str): The text to check. + + Returns: + bool: True if the text should be processed, False otherwise. + """ + if any(text.endswith(p) for p in self.punctuation_priorities): + return True + + words = text.split() + target = ( + settings.FIRST_SENTENCE_SIZE + if not self.found_first_sentence + else settings.TARGET_SIZE + ) + return len(words) >= target + + def find_break_point(self, words: list, target_size: int) -> int: + """Finds optimal break point in text. + + Args: + words (list): The list of words to find a break point in. + target_size (int): The target size of the chunk. + + Returns: + int: The index of the break point. + """ + if len(words) <= target_size: + return len(words) + + break_points = [] + + for i, word in enumerate(words[: target_size + 3]): + word_lower = word.lower() + + priority = self.semantic_breaks.get(word_lower, 0) + for punct, punct_priority in self.punctuation_priorities.items(): + if word.endswith(punct): + priority = max(priority, punct_priority) + + if priority > 0: + break_points.append((i, priority, -abs(i - target_size))) + + if not break_points: + return target_size + + break_points.sort(key=lambda x: (x[1], x[2]), reverse=True) + return break_points[0][0] + 1 + + def process(self, text: str, audio_queue) -> str: + """Process text chunk and return remaining text. + + Args: + text (str): The text to process. + audio_queue: The audio queue to add sentences to. + + Returns: + str: The remaining text after processing. 
+ """ + if not text: + return "" + + words = text.split() + if not words: + return "" + + target_size = ( + settings.FIRST_SENTENCE_SIZE + if not self.found_first_sentence + else settings.TARGET_SIZE + ) + split_point = self.find_break_point(words, target_size) + + if split_point: + chunk = " ".join(words[:split_point]).strip() + if chunk and any(c.isalnum() for c in chunk): + chunk = chunk.rstrip(",") + audio_queue.add_sentences([chunk]) + self.found_first_sentence = True + return " ".join(words[split_point:]) if split_point < len(words) else "" + + return "" diff --git a/src/utils/voice.py b/src/utils/voice.py new file mode 100644 index 0000000000000000000000000000000000000000..11edb7db08f9983f631f01608e23915695c150aa --- /dev/null +++ b/src/utils/voice.py @@ -0,0 +1,229 @@ +import torch +from pathlib import Path +import json +import os + + +def load_config(): + """Loads configuration from config.json. + + Returns: + dict: The configuration loaded from the JSON file. + + Raises: + FileNotFoundError: If the config.json file is not found. + """ + config_path = Path(__file__).parent.parent / "config" / "config.json" + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found at {config_path}") + with open(config_path) as f: + return json.load(f) + + +def get_available_voices(voices_dir): + """Gets a list of available voice names without the .pt extension. + + Args: + voices_dir (str): The path to the directory containing voice files. + + Returns: + list: A list of voice names (strings). + """ + voices_dir = Path(voices_dir) + if not voices_dir.exists(): + return [] + return [f.stem for f in voices_dir.glob("*.pt")] + + +def validate_voice_name(voice_name, voices_dir): + """Validates that a voice name exists in the voices directory. + + Args: + voice_name (str): The name of the voice to validate. + voices_dir (str): The path to the directory containing voice files. + + Returns: + bool: True if the voice name is valid. + + Raises: + ValueError: If the voice name is not found in the voices directory. + """ + available_voices = get_available_voices(voices_dir) + if voice_name not in available_voices: + raise ValueError( + f"Voice '{voice_name}' not found. Available voices: {', '.join(available_voices)}" + ) + return True + + +def load_voice(voice_name, voices_dir): + """Loads a voice from the voices directory. + + Args: + voice_name (str): The name of the voice to load. + voices_dir (str): The path to the directory containing voice files. + + Returns: + torch.Tensor: The loaded voice as a torch tensor. + + Raises: + AssertionError: If the voices directory or voice file does not exist, or if the voice path is not a file. + RuntimeError: If there is an error loading the voice file or converting it to a tensor. 
+ """ + voices_dir = Path(voices_dir) + assert voices_dir.exists(), f"Voices directory does not exist: {voices_dir}" + assert voices_dir.is_dir(), f"Voices path is not a directory: {voices_dir}" + + validate_voice_name(voice_name, voices_dir) + + voice_path = voices_dir / f"{voice_name}.pt" + assert voice_path.exists(), f"Voice file not found: {voice_path}" + assert voice_path.is_file(), f"Voice path is not a file: {voice_path}" + + try: + voice = torch.load(voice_path, weights_only=True) + except Exception as e: + raise RuntimeError(f"Error loading voice file {voice_path}: {str(e)}") + + if not isinstance(voice, torch.Tensor): + try: + voice = torch.tensor(voice) + except Exception as e: + raise RuntimeError(f"Could not convert voice to tensor: {str(e)}") + + return voice + + +def quick_mix_voice(output_name, voices_dir, *voices, weights=None): + """Mixes and saves voices with specified weights. + + Args: + output_name (str): The name of the output mixed voice file (without extension). + voices_dir (str): The path to the directory containing voice files. + *voices (torch.Tensor): Variable number of voice tensors to mix. + weights (list, optional): List of weights for each voice. Defaults to equal weights if None. + + Returns: + torch.Tensor: The mixed voice as a torch tensor. + + Raises: + ValueError: If no voices are provided, if the number of weights does not match the number of voices, or if the sum of weights is not positive. + AssertionError: If the voices directory does not exist or is not a directory. + """ + voices_dir = Path(voices_dir) + assert voices_dir.exists(), f"Voices directory does not exist: {voices_dir}" + assert voices_dir.is_dir(), f"Voices path is not a directory: {voices_dir}" + + if not voices: + raise ValueError("Must provide at least one voice") + + base_shape = voices[0].shape + for i, voice in enumerate(voices): + if not isinstance(voice, torch.Tensor): + raise ValueError(f"Voice {i} is not a tensor") + if voice.shape != base_shape: + raise ValueError( + f"Voice {i} has shape {voice.shape}, but expected {base_shape} (same as first voice)" + ) + + if weights is None: + weights = [1.0 / len(voices)] * len(voices) + else: + if len(weights) != len(voices): + raise ValueError( + f"Number of weights ({len(weights)}) must match number of voices ({len(voices)})" + ) + weights_sum = sum(weights) + if weights_sum <= 0: + raise ValueError("Sum of weights must be positive") + weights = [w / weights_sum for w in weights] + + device = voices[0].device + voices = [v.to(device) for v in voices] + + stacked = torch.stack(voices) + weights = torch.tensor(weights, device=device) + + mixed = torch.zeros_like(voices[0]) + for i, weight in enumerate(weights): + mixed += stacked[i] * weight + + output_path = voices_dir / f"{output_name}.pt" + torch.save(mixed, output_path) + print(f"Created mixed voice: {output_name}.pt") + return mixed + + +def split_into_sentences(text): + """Splits text into sentences using more robust rules. + + Args: + text (str): The input text to split. + + Returns: + list: A list of sentences (strings). 
+ """ + import re + + text = text.strip() + if not text: + return [] + + abbreviations = { + "Mr.": "Mr", + "Mrs.": "Mrs", + "Dr.": "Dr", + "Ms.": "Ms", + "Prof.": "Prof", + "Sr.": "Sr", + "Jr.": "Jr", + "vs.": "vs", + "etc.": "etc", + "i.e.": "ie", + "e.g.": "eg", + "a.m.": "am", + "p.m.": "pm", + } + + for abbr, repl in abbreviations.items(): + text = text.replace(abbr, repl) + + sentences = [] + current = [] + + words = re.findall(r"\S+|\s+", text) + + for word in words: + current.append(word) + + if re.search(r"[.!?]+$", word): + if not re.match(r"^[A-Z][a-z]{1,2}$", word[:-1]): + sentence = "".join(current).strip() + if sentence: + sentences.append(sentence) + current = [] + continue + + if current: + sentence = "".join(current).strip() + if sentence: + sentences.append(sentence) + + for abbr, repl in abbreviations.items(): + sentences = [s.replace(repl, abbr) for s in sentences] + + sentences = [s.strip() for s in sentences if s.strip()] + + final_sentences = [] + for s in sentences: + if len(s) > 200: + parts = s.split(",") + parts = [p.strip() for p in parts if p.strip()] + if len(parts) > 1: + final_sentences.extend(parts) + else: + final_sentences.append(s) + else: + final_sentences.append(s) + + return final_sentences