diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..efcfce8c4da378d9d1b8e586c57c7bf121a96909
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,5 @@
+results
+data
+*.filelist
+/data_server/target
+checkpoints
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 0000000000000000000000000000000000000000..3e1f3474665193bfedb7a83cb63fb3374c20ec60
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,25 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: "[BUG]"
+labels: bug
+assignees: ''
+
+---
+
+Feel free to ask any kind of questions in the issues page, but please use English since other users may find your questions valuable.
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**To Reproduce**
+Steps to reproduce the behavior:
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Screenshots / log**
+If applicable, add screenshots / logs to help explain your problem.
+
+**Additional context**
+Add any other context about the problem here.
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 0000000000000000000000000000000000000000..3c0f91a53074afc7b04b6928de4c7d73bb5744f7
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,22 @@
+---
+name: Feature request
+about: Suggest an idea for this project
+title: "[Feature]"
+labels: enhancement
+assignees: ''
+
+---
+
+Feel free to ask any kind of questions in the issues page, but please use English since other users may find your questions valuable.
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
+
+**Describe the solution you'd like**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context**
+Add any other context or screenshots about the feature request here.
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
new file mode 100644
index 0000000000000000000000000000000000000000..91c6c22a1806d27511bd8cfc8c31cb3deb4379aa
--- /dev/null
+++ b/.github/pull_request_template.md
@@ -0,0 +1,7 @@
+**Is this PR adding new feature or fix a BUG?**
+
+Add feature / Fix BUG.
+
+**Is this pull request related to any issue? If yes, please link the issue.**
+
+#xxx
diff --git a/.github/workflows/build-windows-package.yml b/.github/workflows/build-windows-package.yml
new file mode 100644
index 0000000000000000000000000000000000000000..869c8e3393981beceb52ed770f612dd732590763
--- /dev/null
+++ b/.github/workflows/build-windows-package.yml
@@ -0,0 +1,51 @@
+name: build-windows-package
+
+on:
+ push:
+ branches:
+ - main
+
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Remove unnecessary files
+ run: |
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf /opt/ghc
+ sudo rm -rf /usr/local/lib/android
+ sudo rm -rf "/usr/local/share/boost"
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+ - uses: actions/setup-python@v5
+ with:
+ python-version: 3.12
+ - uses: actions/checkout@v4
+ with:
+ path: ./fish-speech
+ - name: Setup Hugging Face CLI
+ run: pip3 install huggingface-hub
+ - name: Download Windows Binaries
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ if [[ "${{ github.actor }}" = "Leng Yue" ]] || [[ "${{ github.actor }}" = "AnyaCoder" ]] || [[ "${{ github.actor }}" = "pre-commit-ci[bot]" ]]; then
+ ls -la
+ else
+ echo "Author is not Leng Yue nor AnyaCoder. No upload performed."
+ fi
+ - uses: actions/upload-artifact@v4
+ with:
+ name: fish-speech-main-${{ github.run_id }}
+ path: ./fish-speech
+
+ - name: Upload to Hugging Face
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ if [ "${{ github.actor }}" = "AnyaCoder" ]; then
+ echo "Author is AnyaCoder. Performing the zipping && upload."
+ zip -qr fish-speech-main-${{ github.run_id }}.zip ./fish-speech
+ huggingface-cli upload SpicyqSama007/fish-speech-packed ./fish-speech-main-${{ github.run_id }}.zip fish-speech-main-${{ github.run_id }}.zip
+ else
+ echo "Author is not AnyaCoder. No upload performed."
+ fi
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
new file mode 100644
index 0000000000000000000000000000000000000000..910f0284df724462981fa23ed369b738669b141f
--- /dev/null
+++ b/.github/workflows/docs.yml
@@ -0,0 +1,30 @@
+name: docs
+on:
+ push:
+ branches:
+ - main
+
+permissions:
+ contents: write
+
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: Configure Git Credentials
+ run: |
+ git config user.name github-actions[bot]
+ git config user.email 41898282+github-actions[bot]@users.noreply.github.com
+ - uses: actions/setup-python@v5
+ with:
+ python-version: 3.x
+ - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
+ - uses: actions/cache@v4
+ with:
+ key: mkdocs-material-${{ env.cache_id }}
+ path: .cache
+ restore-keys: |
+ mkdocs-material-
+ - run: pip install -r docs/requirements.txt
+ - run: mkdocs gh-deploy --force
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..88341de1d5bd01bb2c6d5b80646048b87c8c766c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,29 @@
+.pgx.*
+.pdm-python
+/fish_speech.egg-info
+__pycache__
+/results
+/data
+/*.test.sh
+*.filelist
+filelists
+/fish_speech/text/cmudict_cache.pickle
+/checkpoints
+/.vscode
+/data_server/target
+/*.npy
+/*.wav
+/*.mp3
+/results
+/data
+/.idea
+ffmpeg.exe
+ffprobe.exe
+asr-label*
+/.cache
+/fishenv
+/.locale
+/demo-audios
+ref_data*
+/example
+/faster_whisper
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4d28faa409fbfee789598dcadf448b666cf3242f
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,32 @@
+ci:
+ autoupdate_schedule: monthly
+
+repos:
+ - repo: https://github.com/pycqa/isort
+ rev: 5.13.2
+ hooks:
+ - id: isort
+ args: [--profile=black]
+
+ - repo: https://github.com/psf/black
+ rev: 24.4.2
+ hooks:
+ - id: black
+
+ - repo: https://github.com/codespell-project/codespell
+ rev: v2.3.0
+ hooks:
+ - id: codespell
+ files: ^.*\.(py|md|rst|yml)$
+ args: [-L=fro]
+
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.6.0
+ hooks:
+ - id: end-of-file-fixer
+ - id: check-yaml
+ - id: check-json
+ - id: mixed-line-ending
+ args: ['--fix=lf']
+ - id: check-added-large-files
+ args: ['--maxkb=5000']
diff --git a/.project-root b/.project-root
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..015eb5de8569951255b2d66c251ee20fe9153ace
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,19 @@
+# Read the Docs configuration file for MkDocs projects
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Set the version of Python and other tools you might need
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.12"
+
+mkdocs:
+ configuration: mkdocs.yml
+
+# Optionally declare the Python requirements required to build your docs
+python:
+ install:
+ - requirements: docs/requirements.txt
diff --git a/API_FLAGS.txt b/API_FLAGS.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4e8b6b72dcbff55c5dd3e886f3a81b7222d085b5
--- /dev/null
+++ b/API_FLAGS.txt
@@ -0,0 +1,6 @@
+# --infer
+# --api
+--listen 0.0.0.0:8080 \
+--llama-checkpoint-path "checkpoints/fish-speech-1.2-sft" \
+--decoder-checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+--decoder-config-name firefly_gan_vq
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..cbe5ad1670406e4402217edfb82d2c56af7e8631
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,437 @@
+Attribution-NonCommercial-ShareAlike 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
+Public License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial-ShareAlike 4.0 International Public License
+("Public License"). To the extent this Public License may be
+interpreted as a contract, You are granted the Licensed Rights in
+consideration of Your acceptance of these terms and conditions, and the
+Licensor grants You such rights in consideration of benefits the
+Licensor receives from making the Licensed Material available under
+these terms and conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. BY-NC-SA Compatible License means a license listed at
+ creativecommons.org/compatiblelicenses, approved by Creative
+ Commons as essentially the equivalent of this Public License.
+
+ d. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+
+ e. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ g. License Elements means the license attributes listed in the name
+ of a Creative Commons Public License. The License Elements of this
+ Public License are Attribution, NonCommercial, and ShareAlike.
+
+ h. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ i. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ j. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ k. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ l. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ m. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ n. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. Additional offer from the Licensor -- Adapted Material.
+ Every recipient of Adapted Material from You
+ automatically receives an offer from the Licensor to
+ exercise the Licensed Rights in the Adapted Material
+ under the conditions of the Adapter's License You apply.
+
+ c. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ b. ShareAlike.
+
+ In addition to the conditions in Section 3(a), if You Share
+ Adapted Material You produce, the following conditions also apply.
+
+ 1. The Adapter's License You apply must be a Creative Commons
+ license with the same License Elements, this version or
+ later, or a BY-NC-SA Compatible License.
+
+ 2. You must include the text of, or the URI or hyperlink to, the
+ Adapter's License You apply. You may satisfy this condition
+ in any reasonable manner based on the medium, means, and
+ context in which You Share Adapted Material.
+
+ 3. You may not offer or impose any additional or different terms
+ or conditions on, or apply any Effective Technological
+ Measures to, Adapted Material that restrict exercise of the
+ rights granted under the Adapter's License You apply.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material,
+ including for purposes of Section 3(b); and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
diff --git a/README.zh.md b/README.zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..881ada54ae31881b2812050ea9e1d1988db36bcf
--- /dev/null
+++ b/README.zh.md
@@ -0,0 +1,74 @@
+# Fish Speech
+
+
+
+此代码库及模型根据 CC-BY-NC-SA-4.0 许可证发布。请参阅 [LICENSE](LICENSE) 了解更多细节.
+
+## 免责声明
+
+我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
+
+## 在线 DEMO
+
+[Fish Audio](https://fish.audio)
+
+## 快速开始本地推理
+
+[inference.ipynb](/inference.ipynb)
+
+## 视频
+
+#### 1.2 介绍: https://www.bilibili.com/video/BV1wz421B71D
+
+#### 1.1 技术介绍: https://www.bilibili.com/video/BV1zJ4m1K7cj
+
+## 文档
+
+- [English](https://speech.fish.audio/en/)
+- [中文](https://speech.fish.audio/)
+- [日本語](https://speech.fish.audio/ja/)
+
+## 例子
+
+- [English](https://speech.fish.audio/en/samples/)
+- [中文](https://speech.fish.audio/samples/)
+- [日本語](https://speech.fish.audio/ja/samples/)
+
+## 鸣谢
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
+
+## 赞助
+
+
+
diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3054037de5fd4931b22be279d5c8d505be950519
--- /dev/null
+++ b/docker-compose.dev.yml
@@ -0,0 +1,16 @@
+version: '3.8'
+
+services:
+ fish-speech:
+ build: .
+ container_name: fish-speech
+ volumes:
+ - ./:/exp
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: all
+ capabilities: [gpu]
+ command: tail -f /dev/null
diff --git a/dockerfile b/dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..e4688420dc49ba5cfc0317c799c83d079b1ee2d1
--- /dev/null
+++ b/dockerfile
@@ -0,0 +1,24 @@
+FROM python:3.10.14-bookworm
+
+# Install system dependencies
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && apt-get install -y git curl build-essential ffmpeg libsm6 libxext6 libjpeg-dev \
+ zlib1g-dev aria2 zsh openssh-server sudo protobuf-compiler cmake libsox-dev && \
+ apt-get clean && rm -rf /var/lib/apt/lists/*
+
+# Install oh-my-zsh so your terminal looks nice
+RUN sh -c "$(curl https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" "" --unattended
+
+# Set zsh as default shell
+RUN chsh -s /usr/bin/zsh
+ENV SHELL=/usr/bin/zsh
+
+# Setup torchaudio
+RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+
+# Project Env
+WORKDIR /exp
+COPY . .
+RUN pip3 install -e .
+
+CMD /bin/zsh
diff --git a/docs/CNAME b/docs/CNAME
new file mode 100644
index 0000000000000000000000000000000000000000..d506fb8b394fa80f3d329ab8450dfc102e839bd1
--- /dev/null
+++ b/docs/CNAME
@@ -0,0 +1 @@
+speech.fish.audio
diff --git a/docs/assets/figs/VS_1.jpg b/docs/assets/figs/VS_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..41a3f69992edcbbaa85a21695bdc33ff81dc10d6
Binary files /dev/null and b/docs/assets/figs/VS_1.jpg differ
diff --git a/docs/assets/figs/diagram.png b/docs/assets/figs/diagram.png
new file mode 100644
index 0000000000000000000000000000000000000000..254b669c293428926e8d28d47471536d6eb76357
Binary files /dev/null and b/docs/assets/figs/diagram.png differ
diff --git a/docs/en/finetune.md b/docs/en/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..b76e1df24f3d6f51a812eaee477ee57f71e314c6
--- /dev/null
+++ b/docs/en/finetune.md
@@ -0,0 +1,125 @@
+# Fine-tuning
+
+Obviously, when you opened this page, you were not satisfied with the performance of the few-shot pre-trained model. You want to fine-tune a model to improve its performance on your dataset.
+
+In current version, you only need to finetune the 'LLAMA' part.
+
+## Fine-tuning LLAMA
+### 1. Prepare the dataset
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+You need to convert your dataset into the above format and place it under `data`. The audio file can have the extensions `.mp3`, `.wav`, or `.flac`, and the annotation file should have the extensions `.lab`.
+
+!!! warning
+ It's recommended to apply loudness normalization to the dataset. You can use [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) to do this.
+
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+
+### 2. Batch extraction of semantic tokens
+
+Make sure you have downloaded the VQGAN weights. If not, run the following command:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+You can then run the following command to extract semantic tokens:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+```
+
+!!! note
+ You can adjust `--num-workers` and `--batch-size` to increase extraction speed, but please make sure not to exceed your GPU memory limit.
+ For the VITS format, you can specify a file list using `--filelist xxx.list`.
+
+This command will create `.npy` files in the `data` directory, as shown below:
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. Pack the dataset into protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+After the command finishes executing, you should see the `quantized-dataset-ft.protos` file in the `data` directory.
+
+### 4. Finally, fine-tuning with LoRA
+
+Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+Finally, you can start the fine-tuning by running the following command:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ You can modify the training parameters such as `batch_size`, `gradient_accumulation_steps`, etc. to fit your GPU memory by modifying `fish_speech/configs/text2semantic_finetune.yaml`.
+
+!!! note
+ For Windows users, you can use `trainer.strategy.process_group_backend=gloo` to avoid `nccl` issues.
+
+After training is complete, you can refer to the [inference](inference.md) section, and use `--speaker SPK1` to generate speech.
+
+!!! info
+ By default, the model will only learn the speaker's speech patterns and not the timbre. You still need to use prompts to ensure timbre stability.
+ If you want to learn the timbre, you can increase the number of training steps, but this may lead to overfitting.
+
+After training, you need to convert the LoRA weights to regular weights before performing inference.
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.2-sft \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.2-sft-yth-lora/
+```
+!!! note
+ You may also try other checkpoints. We suggest using the earliest checkpoint that meets your requirements, as they often perform better on out-of-distribution (OOD) data.
diff --git a/docs/en/index.md b/docs/en/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..a720565acac59b0f6dedd63758a17bcbe75bfd7b
--- /dev/null
+++ b/docs/en/index.md
@@ -0,0 +1,128 @@
+# Introduction
+
+
+
+!!! warning
+ We assume no responsibility for any illegal use of the codebase. Please refer to the local laws regarding DMCA (Digital Millennium Copyright Act) and other relevant laws in your area.
+ This codebase is released under the `BSD-3-Clause` license, and all models are released under the CC-BY-NC-SA-4.0 license.
+
+
+
+
+
+## Requirements
+
+- GPU Memory: 4GB (for inference), 8GB (for fine-tuning)
+- System: Linux, Windows
+
+## Windows Setup
+
+Windows professional users may consider WSL2 or Docker to run the codebase.
+
+Non-professional Windows users can consider the following methods to run the codebase without a Linux environment (with model compilation capabilities aka `torch.compile`):
+
+
+ Unzip the project package.
+ Click install_env.bat
to install the environment.
+
+ You can decide whether to use a mirror site for downloads by editing the USE_MIRROR
item in install_env.bat
.
+ USE_MIRROR=false
downloads the latest stable version of torch
from the original site. USE_MIRROR=true
downloads the latest version of torch
from a mirror site. The default is true
.
+ You can decide whether to enable the compiled environment download by editing the INSTALL_TYPE
item in install_env.bat
.
+ INSTALL_TYPE=preview
downloads the preview version with the compiled environment. INSTALL_TYPE=stable
downloads the stable version without the compiled environment.
+
+
+ If step 2 has USE_MIRROR=preview
, execute this step (optional, for activating the compiled model environment):
+
+ Download the LLVM compiler using the following links:
+
+
+ Download and install the Microsoft Visual C++ Redistributable package to resolve potential .dll missing issues.
+
+
+ Download and install Visual Studio Community Edition to obtain MSVC++ build tools, resolving LLVM header file dependencies.
+
+ Visual Studio Download
+ After installing Visual Studio Installer, download Visual Studio Community 2022.
+ Click the Modify
button as shown below, find the Desktop development with C++
option, and check it for download.
+
+
+
+
+
+ Install CUDA Toolkit 12
+
+
+ Double-click start.bat
to enter the Fish-Speech training inference configuration WebUI page.
+
+
+ (Optional) Double-click run_cmd.bat
to enter the conda/python command line environment of this project.
+
+
+## Linux Setup
+
+```bash
+# Create a python 3.10 virtual environment, you can also use virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# Install pytorch
+pip3 install torch torchvision torchaudio
+
+# Install fish-speech
+pip3 install -e .
+
+# (Ubuntu / Debian User) Install sox
+apt install libsox-dev
+```
+
+## Changelog
+
+- 2024/07/02: Updated Fish-Speech to 1.2 version, remove VITS Decoder, and greatly enhanced zero-shot ability.
+- 2024/05/10: Updated Fish-Speech to 1.1 version, implement VITS decoder to reduce WER and improve timbre similarity.
+- 2024/04/22: Finished Fish-Speech 1.0 version, significantly modified VQGAN and LLAMA models.
+- 2023/12/28: Added `lora` fine-tuning support.
+- 2023/12/27: Add `gradient checkpointing`, `causual sampling`, and `flash-attn` support.
+- 2023/12/19: Updated webui and HTTP API.
+- 2023/12/18: Updated fine-tuning documentation and related examples.
+- 2023/12/17: Updated `text2semantic` model, supporting phoneme-free mode.
+- 2023/12/13: Beta version released, includes VQGAN model and a language model based on LLAMA (phoneme support only).
+
+## Acknowledgements
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/en/inference.md b/docs/en/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..6cd84833391eb099ee97d0f533e29ee4d4a58ab6
--- /dev/null
+++ b/docs/en/inference.md
@@ -0,0 +1,153 @@
+# Inference
+
+Inference support command line, HTTP API and web UI.
+
+!!! note
+ Overall, reasoning consists of several parts:
+
+ 1. Encode a given ~10 seconds of voice using VQGAN.
+ 2. Input the encoded semantic tokens and the corresponding text into the language model as an example.
+ 3. Given a new piece of text, let the model generate the corresponding semantic tokens.
+ 4. Input the generated semantic tokens into VITS / VQGAN to decode and generate the corresponding voice.
+
+## Command Line Inference
+
+Download the required `vqgan` and `llama` models from our Hugging Face repository.
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+### 1. Generate prompt from voice:
+
+!!! note
+ If you plan to let the model randomly choose a voice timbre, you can skip this step.
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+```
+
+You should get a `fake.npy` file.
+
+### 2. Generate semantic tokens from text:
+
+```bash
+python tools/llama/generate.py \
+ --text "The text you want to convert" \
+ --prompt-text "Your reference text" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft" \
+ --num-samples 2 \
+ --compile
+```
+
+This command will create a `codes_N` file in the working directory, where N is an integer starting from 0.
+
+!!! note
+ You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~500 tokens/second).
+ Correspondingly, if you do not plan to use acceleration, you can comment out the `--compile` parameter.
+
+!!! info
+ For GPUs that do not support bf16, you may need to use the `--half` parameter.
+
+### 3. Generate vocals from semantic tokens:
+
+#### VQGAN Decoder
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+```
+
+## HTTP API Inference
+
+We provide a HTTP API for inference. You can use the following command to start the server:
+
+```bash
+python -m tools.api \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.2-sft" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+If you want to speed up inference, you can add the --compile parameter.
+
+After that, you can view and test the API at http://127.0.0.1:8080/.
+
+Below is an example of sending a request using `tools/post_api.py`.
+
+```bash
+python -m tools.post_api \
+ --text "Text to be input" \
+ --reference_audio "Path to reference audio" \
+ --reference_text "Text content of the reference audio" \
+ --streaming True
+```
+
+The above command indicates synthesizing the desired audio according to the reference audio information and returning it in a streaming manner.
+
+If you need to randomly select reference audio based on `{SPEAKER}` and `{EMOTION}`, configure it according to the following steps:
+
+### 1. Create a `ref_data` folder in the root directory of the project.
+
+### 2. Create a directory structure similar to the following within the `ref_data` folder.
+
+```
+.
+├── SPEAKER1
+│ ├──EMOTION1
+│ │ ├── 21.15-26.44.lab
+│ │ ├── 21.15-26.44.wav
+│ │ ├── 27.51-29.98.lab
+│ │ ├── 27.51-29.98.wav
+│ │ ├── 30.1-32.71.lab
+│ │ └── 30.1-32.71.flac
+│ └──EMOTION2
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPEAKER2
+ └─── EMOTION3
+ ├── 30.1-32.71.lab
+ └── 30.1-32.71.mp3
+```
+
+That is, first place `{SPEAKER}` folders in `ref_data`, then place `{EMOTION}` folders under each speaker, and place any number of `audio-text pairs` under each emotion folder.
+
+### 3. Enter the following command in the virtual environment
+
+```bash
+python tools/gen_ref.py
+
+```
+
+### 4. Call the API.
+
+```bash
+python -m tools.post_api \
+ --text "Text to be input" \
+ --speaker "${SPEAKER1}" \
+ --emotion "${EMOTION1}" \
+ --streaming True
+```
+
+The above example is for testing purposes only.
+
+## WebUI Inference
+
+You can start the WebUI using the following command:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.2-sft" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+!!! note
+ You can use Gradio environment variables, such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` to configure WebUI.
+
+Enjoy!
diff --git a/docs/en/samples.md b/docs/en/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..a079c0c3e29ff7a7e1bf1b1c9143903cb3394457
--- /dev/null
+++ b/docs/en/samples.md
@@ -0,0 +1,223 @@
+# Samples
+
+v1.2 samples are available on [Bilibili](https://www.bilibili.com/video/BV1wz421B71D/).
+
+The following samples are from the v1.1 model.
+
+## Chinese Sentence 1
+```
+人间灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。
+流入她所注视的世间,也流入她如湖水般澄澈的目光。
+```
+
+
+
+
+ Speaker
+ Input Audio
+ Synthesized Audio
+
+
+
+
+ Nahida (Genshin Impact)
+
+
+
+
+ Zhongli (Genshin Impact)
+
+
+
+
+ Furina (Genshin Impact)
+
+
+
+
+ Random Speaker 1
+ -
+
+
+
+ Random Speaker 2
+ -
+
+
+
+
+
+
+## Chinese Sentence 2
+```
+你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么?
+我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊?
+你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊,
+搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗?
+一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊?
+```
+
+
+
+
+ Speaker
+ Input Audio
+ Synthesized Audio
+
+
+
+
+ Nahida (Genshin Impact)
+
+
+
+
+ Random Speaker
+ -
+
+
+
+
+
+
+## Chinese Sentence 3
+```
+大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练,
+我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。
+作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。
+```
+
+
+
+
+
+ Speaker
+ Input Audio
+ Synthesized Audio
+
+
+
+
+ Random Speaker
+ -
+
+
+
+
+
+## English Sentence 1
+
+```
+In the realm of advanced technology, the evolution of artificial intelligence stands as a
+monumental achievement. This dynamic field, constantly pushing the boundaries of what
+machines can do, has seen rapid growth and innovation. From deciphering complex data
+patterns to driving cars autonomously, AI's applications are vast and diverse.
+```
+
+
+
+
+ Speaker
+ Input Audio
+ Synthesized Audio
+
+
+
+
+ Random Speaker 1
+ -
+
+
+
+ Random Speaker 2
+ -
+
+
+
+
+
+## English Sentence 2
+```
+Hello everyone, I am an open-source text-to-speech model developed by
+Fish Audio. After training with 150,000 hours of data, I have become proficient
+in Chinese, Japanese, and English, and my language processing abilities
+are close to human level. My voice is capable of a wide range of expressions.
+As a model with only hundreds of millions of parameters, I believe community
+members can easily run and fine-tune me on their personal devices, allowing
+me to serve as your personal voice assistant.
+```
+
+
+
+
+ Speaker
+ Input Audio
+ Synthesized Audio
+
+
+
+
+ Random Speaker
+ -
+
+
+
+
+
+## Japanese Sentence 1
+
+```
+先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を
+押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か
+ら自動運転車の操縦まで、AIの応用は広範囲に及びます。
+```
+
+
+
+
+
+ Speaker
+ Input Audio
+ Synthesized Audio
+
+
+
+
+ Random Speaker 1
+ -
+
+
+
+ Random Speaker 2
+ -
+
+
+
+
+
+## Japanese Sentence 2
+```
+皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ
+キストから音声への変換モデルです。15万時間のデータトレーニングを経て、
+中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。
+声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ
+のメンバーが個人のデバイスで簡単に実行し、微調整することができると
+信じています。これにより、私を個人の音声アシスタントとして活用できます。
+```
+
+
+
+
+ Speaker
+ Input Audio
+ Synthesized Audio
+
+
+
+
+ Random Speaker
+ -
+
+
+
+
diff --git a/docs/ja/finetune.md b/docs/ja/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..1549528536decd02f0729a6898260bff04050767
--- /dev/null
+++ b/docs/ja/finetune.md
@@ -0,0 +1,125 @@
+# 微調整
+
+明らかに、このページを開いたとき、few-shot 事前トレーニングモデルのパフォーマンスに満足していなかったことでしょう。データセット上でのパフォーマンスを向上させるためにモデルを微調整したいと考えています。
+
+現在のバージョンでは、「LLAMA」部分のみを微調整する必要があります。
+
+## LLAMAの微調整
+### 1. データセットの準備
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+データセットを上記の形式に変換し、「data」ディレクトリに配置する必要があります。音声ファイルの拡張子は「.mp3」、「.wav」、または「.flac」にすることができ、注釈ファイルの拡張子は「.lab」にする必要があります。
+
+!!! warning
+ データセットにラウドネス正規化を適用することをお勧めします。これを行うには、[fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) を使用できます。
+
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+
+### 2. セマンティックトークンのバッチ抽出
+
+VQGANの重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+次に、次のコマンドを実行してセマンティックトークンを抽出できます。
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+```
+
+!!! note
+ `--num-workers` と `--batch-size` を調整して抽出速度を上げることができますが、GPUメモリの制限を超えないようにしてください。
+ VITS形式の場合、`--filelist xxx.list` を使用してファイルリストを指定できます。
+
+このコマンドは、`data`ディレクトリに`.npy`ファイルを作成します。以下のように表示されます。
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. データセットをprotobufにパックする
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+コマンドの実行が完了すると、`data`ディレクトリに`quantized-dataset-ft.protos`ファイルが表示されます。
+
+### 4. 最後に、LoRAを使用して微調整する
+
+同様に、`LLAMA`の重みをダウンロードしたことを確認してください。まだダウンロードしていない場合は、次のコマンドを実行してください。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+最後に、次のコマンドを実行して微調整を開始できます。
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ `fish_speech/configs/text2semantic_finetune.yaml` を変更して、`batch_size`、`gradient_accumulation_steps` などのトレーニングパラメータを変更し、GPUメモリに適合させることができます。
+
+!!! note
+ Windowsユーザーの場合、`trainer.strategy.process_group_backend=gloo` を使用して `nccl` の問題を回避できます。
+
+トレーニングが完了したら、[推論](inference.md)セクションを参照し、`--speaker SPK1` を使用して音声を生成します。
+
+!!! info
+ デフォルトでは、モデルは話者の発話パターンのみを学習し、音色は学習しません。音色の安定性を確保するためにプロンプトを使用する必要があります。
+ 音色を学習したい場合は、トレーニングステップ数を増やすことができますが、これにより過学習が発生する可能性があります。
+
+トレーニングが完了したら、推論を行う前にLoRAの重みを通常の重みに変換する必要があります。
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.2-sft \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.2-sft-yth-lora/
+```
+!!! note
+ 他のチェックポイントを試すこともできます。要件を満たす最も早いチェックポイントを使用することをお勧めします。これらは通常、分布外(OOD)データでより良いパフォーマンスを発揮します。
diff --git a/docs/ja/index.md b/docs/ja/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..1a04730d7d8d138ea7111cda7145642f2daf05da
--- /dev/null
+++ b/docs/ja/index.md
@@ -0,0 +1,128 @@
+# イントロダクション
+
+
+
+!!! warning
+ 私たちは、コードベースの違法な使用について一切の責任を負いません。お住まいの地域の DMCA(デジタルミレニアム著作権法)およびその他の関連法については、現地の法律を参照してください。
+ このコードベースは `BSD-3-Clause` ライセンスの下でリリースされており、すべてのモデルは CC-BY-NC-SA-4.0 ライセンスの下でリリースされています。
+
+
+
+
+
+## 要件
+
+- GPU メモリ: 4GB(推論用)、8GB(微調整用)
+- システム: Linux、Windows
+
+## Windows セットアップ
+
+Windows のプロユーザーは、コードベースを実行するために WSL2 または Docker を検討することができます。
+
+非プロの Windows ユーザーは、Linux 環境なしでコードベースを実行するために以下の方法を検討することができます(モデルコンパイル機能付き、つまり `torch.compile`):
+
+
+ プロジェクトパッケージを解凍します。
+ install_env.bat
をクリックして環境をインストールします。
+
+ install_env.bat
のUSE_MIRROR
項目を編集して、ミラーサイトを使用するかどうかを決定できます。
+ USE_MIRROR=false
は、最新の安定版torch
をオリジナルサイトからダウンロードします。USE_MIRROR=true
は、最新のtorch
をミラーサイトからダウンロードします。デフォルトはtrue
です。
+ install_env.bat
のINSTALL_TYPE
項目を編集して、コンパイル環境のダウンロードを有効にするかどうかを決定できます。
+ INSTALL_TYPE=preview
は、コンパイル環境付きのプレビュー版をダウンロードします。INSTALL_TYPE=stable
は、コンパイル環境なしの安定版をダウンロードします。
+
+
+ ステップ2でUSE_MIRROR=preview
の場合、このステップを実行します(オプション、コンパイルモデル環境を有効にするため):
+
+ 以下のリンクを使用してLLVMコンパイラをダウンロードします:
+
+
+ Microsoft Visual C++ 再頒布可能パッケージをダウンロードしてインストールし、潜在的な.dllの欠落問題を解決します。
+
+
+ Visual Studio Community Editionをダウンロードしてインストールし、MSVC++ビルドツールを取得し、LLVMのヘッダーファイル依存関係を解決します。
+
+ Visual Studio ダウンロード
+ Visual Studio Installerをインストールした後、Visual Studio Community 2022をダウンロードします。
+ 以下の図のようにModify
ボタンをクリックし、Desktop development with C++
オプションを見つけてチェックしてダウンロードします。
+
+
+
+
+
+ インストール CUDA Toolkit 12
+
+
+ start.bat
をダブルクリックして、Fish-Speechトレーニング推論設定WebUIページに入ります。
+
+
+ (オプション)run_cmd.bat
をダブルクリックして、このプロジェクトのconda/pythonコマンドライン環境に入ります。
+
+
+## Linux セットアップ
+
+```bash
+# python 3.10仮想環境を作成します。virtualenvも使用できます。
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# pytorchをインストールします。
+pip3 install torch torchvision torchaudio
+
+# fish-speechをインストールします。
+pip3 install -e .
+
+# (Ubuntu / Debianユーザー) soxをインストールします。
+apt install libsox-dev
+```
+
+## 変更履歴
+
+- 2024/07/02: Fish-Speech を 1.2 バージョンに更新し、VITS デコーダーを削除し、ゼロショット能力を大幅に強化しました。
+- 2024/05/10: Fish-Speech を 1.1 バージョンに更新し、VITS デコーダーを実装して WER を減少させ、音色の類似性を向上させました。
+- 2024/04/22: Fish-Speech 1.0 バージョンを完成させ、VQGAN および LLAMA モデルを大幅に修正しました。
+- 2023/12/28: `lora`微調整サポートを追加しました。
+- 2023/12/27: `gradient checkpointing`、`causual sampling`、および`flash-attn`サポートを追加しました。
+- 2023/12/19: webui および HTTP API を更新しました。
+- 2023/12/18: 微調整ドキュメントおよび関連例を更新しました。
+- 2023/12/17: `text2semantic`モデルを更新し、音素フリーモードをサポートしました。
+- 2023/12/13: ベータ版をリリースし、VQGAN モデルおよび LLAMA に基づく言語モデル(音素のみサポート)を含みます。
+
+## 謝辞
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/ja/inference.md b/docs/ja/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..308061979c42ab566086a7b3d7a01c47431a390b
--- /dev/null
+++ b/docs/ja/inference.md
@@ -0,0 +1,157 @@
+# 推論
+
+推論は、コマンドライン、HTTP API、および Web UI をサポートしています。
+
+!!! note
+ 全体として、推論は次のいくつかの部分で構成されています:
+
+ 1. VQGANを使用して、与えられた約10秒の音声をエンコードします。
+ 2. エンコードされたセマンティックトークンと対応するテキストを例として言語モデルに入力します。
+ 3. 新しいテキストが与えられた場合、モデルに対応するセマンティックトークンを生成させます。
+ 4. 生成されたセマンティックトークンをVITS / VQGANに入力してデコードし、対応する音声を生成します。
+
+## コマンドライン推論
+
+必要な`vqgan`および`llama`モデルを Hugging Face リポジトリからダウンロードします。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+### 1. 音声からプロンプトを生成する:
+
+!!! note
+ モデルにランダムに音声の音色を選ばせる場合、このステップをスキップできます。
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+```
+
+`fake.npy`ファイルが生成されるはずです。
+
+### 2. テキストからセマンティックトークンを生成する:
+
+```bash
+python tools/llama/generate.py \
+ --text "変換したいテキスト" \
+ --prompt-text "参照テキスト" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft" \
+ --num-samples 2 \
+ --compile
+```
+
+このコマンドは、作業ディレクトリに`codes_N`ファイルを作成します。ここで、N は 0 から始まる整数です。
+
+!!! note
+ `--compile`を使用して CUDA カーネルを融合し、より高速な推論を実現することができます(約 30 トークン/秒 -> 約 500 トークン/秒)。
+ それに対応して、加速を使用しない場合は、`--compile`パラメータをコメントアウトできます。
+
+!!! info
+ bf16 をサポートしていない GPU の場合、`--half`パラメータを使用する必要があるかもしれません。
+
+### 3. セマンティックトークンから音声を生成する:
+
+#### VQGAN デコーダー
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+```
+
+## HTTP API 推論
+
+推論のための HTTP API を提供しています。次のコマンドを使用してサーバーを起動できます:
+
+```bash
+python -m tools.api \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.2-sft" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+推論を高速化したい場合は、--compile パラメータを追加できます。
+
+その後、`http://127.0.0.1:8080/`で API を表示およびテストできます。
+
+以下は、`tools/post_api.py` を使用してリクエストを送信する例です。
+
+```bash
+python -m tools.post_api \
+ --text "入力するテキスト" \
+ --reference_audio "参照音声へのパス" \
+ --reference_text "参照音声テキスト" \
+ --streaming True
+```
+
+上記のコマンドは、参照音声の情報に基づいて必要な音声を合成し、ストリーミング方式で返すことを示しています。
+
+`{SPEAKER}`と`{EMOTION}`に基づいて参照音声をランダムに選択する必要がある場合は、以下の手順に従って設定します:
+
+### 1. プロジェクトのルートディレクトリに`ref_data`フォルダを作成します。
+
+### 2. `ref_data`フォルダ内に次のような構造のディレクトリを作成します。
+
+```
+.
+├── SPEAKER1
+│ ├──EMOTION1
+│ │ ├── 21.15-26.44.lab
+│ │ ├── 21.15-26.44.wav
+│ │ ├── 27.51-29.98.lab
+│ │ ├── 27.51-29.98.wav
+│ │ ├── 30.1-32.71.lab
+│ │ └── 30.1-32.71.flac
+│ └──EMOTION2
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPEAKER2
+ └─── EMOTION3
+ ├── 30.1-32.71.lab
+ └── 30.1-32.71.mp3
+
+```
+
+つまり、まず`ref_data`に`{SPEAKER}`フォルダを配置し、各スピーカーの下に`{EMOTION}`フォルダを配置し、各感情フォルダの下に任意の数の音声-テキストペアを配置します
+
+### 3. 仮想環境で以下のコマンドを入力します.
+
+```bash
+python tools/gen_ref.py
+
+```
+
+参照ディレクトリを生成します。
+
+### 4. API を呼び出します。
+
+```bash
+python -m tools.post_api \
+ --text "入力するテキスト" \
+ --speaker "${SPEAKER1}" \
+ --emotion "${EMOTION1}" \
+ --streaming True
+
+```
+
+上記の例はテスト目的のみです。
+
+## WebUI 推論
+
+次のコマンドを使用して WebUI を起動できます:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.2-sft" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+!!! note
+ Gradio 環境変数(`GRADIO_SHARE`、`GRADIO_SERVER_PORT`、`GRADIO_SERVER_NAME`など)を使用して WebUI を構成できます。
+
+お楽しみください!
diff --git a/docs/ja/samples.md b/docs/ja/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..f94e83d02d7f1d670cbeb75499c1abcfef71bac1
--- /dev/null
+++ b/docs/ja/samples.md
@@ -0,0 +1,223 @@
+# サンプル
+
+v1.2のサンプルは[Bilibili](https://www.bilibili.com/video/BV1wz421B71D/)で利用可能です。
+
+以下のサンプルはv1.1モデルからのものです。
+
+## 中国語の文1
+```
+人間灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。
+流入她所注视的世间,也流入她如湖水般澄澈的目光。
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ナヒーダ (原神)
+
+
+
+
+ 鍾離 (原神)
+
+
+
+
+ フリナ (原神)
+
+
+
+
+ ランダム話者1
+ -
+
+
+
+ ランダム話者2
+ -
+
+
+
+
+
+
+## 中国語の文2
+```
+你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么?
+我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊?
+你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊,
+搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗?
+一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊?
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ナヒーダ (原神)
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
+
+
+## 中国語の文3
+```
+大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练,
+我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。
+作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。
+```
+
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
+
+## 英語の文1
+
+```
+In the realm of advanced technology, the evolution of artificial intelligence stands as a
+monumental achievement. This dynamic field, constantly pushing the boundaries of what
+machines can do, has seen rapid growth and innovation. From deciphering complex data
+patterns to driving cars autonomously, AI's applications are vast and diverse.
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者1
+ -
+
+
+
+ ランダム話者2
+ -
+
+
+
+
+
+## 英語の文2
+```
+Hello everyone, I am an open-source text-to-speech model developed by
+Fish Audio. After training with 150,000 hours of data, I have become proficient
+in Chinese, Japanese, and English, and my language processing abilities
+are close to human level. My voice is capable of a wide range of expressions.
+As a model with only hundreds of millions of parameters, I believe community
+members can easily run and fine-tune me on their personal devices, allowing
+me to serve as your personal voice assistant.
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
+
+## 日本語の文1
+
+```
+先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を
+押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か
+ら自動運転車の操縦まで、AIの応用は広範囲に及びます。
+```
+
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者1
+ -
+
+
+
+ ランダム話者2
+ -
+
+
+
+
+
+## 日本語の文2
+```
+皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ
+キストから音声への変換モデルです。15万時間のデータトレーニングを経て、
+中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。
+声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ
+のメンバーが個人のデバイスで簡単に実行し、微調整することができると
+信じています。これにより、私を個人の音声アシスタントとして活用できます。
+```
+
+
+
+
+ 話者
+ 入力音声
+ 合成音声
+
+
+
+
+ ランダム話者
+ -
+
+
+
+
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6e145dbea1b9b26b2bddd7500e3f270b3eb0009
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,3 @@
+mkdocs-material
+mkdocs-static-i18n[material]
+mkdocs[i18n]
diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css
new file mode 100644
index 0000000000000000000000000000000000000000..a88af87b3cdbfd2d6b05f39877d5821bb7ebe119
--- /dev/null
+++ b/docs/stylesheets/extra.css
@@ -0,0 +1,3 @@
+.md-grid {
+ max-width: 1440px;
+}
diff --git a/docs/zh/finetune.md b/docs/zh/finetune.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d47e6022e910738b74ddcc40ba115dab44cff9b
--- /dev/null
+++ b/docs/zh/finetune.md
@@ -0,0 +1,136 @@
+# 微调
+
+显然, 当你打开这个页面的时候, 你已经对预训练模型 zero-shot 的效果不算满意. 你想要微调一个模型, 使得它在你的数据集上表现更好.
+
+在目前版本,你只需要微调'LLAMA'部分即可.
+
+## LLAMA 微调
+### 1. 准备数据集
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPK2
+ ├── 38.79-40.85.lab
+ └── 38.79-40.85.mp3
+```
+
+你需要将数据集转为以上格式, 并放到 `data` 下, 音频后缀可以为 `.mp3`, `.wav` 或 `.flac`, 标注文件后缀建议为 `.lab`.
+
+!!! warning
+ 建议先对数据集进行响度匹配, 你可以使用 [fish-audio-preprocess](https://github.com/fishaudio/audio-preprocess) 来完成这一步骤.
+ ```bash
+ fap loudness-norm data-raw data --clean
+ ```
+
+### 2. 批量提取语义 token
+
+确保你已经下载了 vqgan 权重, 如果没有, 请运行以下命令:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+对于中国大陆用户, 可使用 mirror 下载.
+
+```bash
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+随后可运行以下命令来提取语义 token:
+
+```bash
+python tools/vqgan/extract_vq.py data \
+ --num-workers 1 --batch-size 16 \
+ --config-name "firefly_gan_vq" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+```
+
+!!! note
+ 你可以调整 `--num-workers` 和 `--batch-size` 来提高提取速度, 但是请注意不要超过你的显存限制.
+
+该命令会在 `data` 目录下创建 `.npy` 文件, 如下所示:
+
+```
+.
+├── SPK1
+│ ├── 21.15-26.44.lab
+│ ├── 21.15-26.44.mp3
+│ ├── 21.15-26.44.npy
+│ ├── 27.51-29.98.lab
+│ ├── 27.51-29.98.mp3
+│ ├── 27.51-29.98.npy
+│ ├── 30.1-32.71.lab
+│ ├── 30.1-32.71.mp3
+│ └── 30.1-32.71.npy
+└── SPK2
+ ├── 38.79-40.85.lab
+ ├── 38.79-40.85.mp3
+ └── 38.79-40.85.npy
+```
+
+### 3. 打包数据集为 protobuf
+
+```bash
+python tools/llama/build_dataset.py \
+ --input "data" \
+ --output "data/protos" \
+ --text-extension .lab \
+ --num-workers 16
+```
+
+命令执行完毕后, 你应该能在 `data` 目录下看到 `protos` 文件.
+
+
+### 4. 最后, 使用 LoRA 进行微调
+
+同样的, 请确保你已经下载了 `LLAMA` 权重, 如果没有, 请运行以下命令:
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+对于中国大陆用户, 可使用 mirror 下载.
+
+```bash
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+最后, 你可以运行以下命令来启动微调:
+
+```bash
+python fish_speech/train.py --config-name text2semantic_finetune \
+ project=$project \
+ +lora@model.model.lora_config=r_8_alpha_16
+```
+
+!!! note
+ 你可以通过修改 `fish_speech/configs/text2semantic_finetune.yaml` 来修改训练参数如 `batch_size`, `gradient_accumulation_steps` 等, 来适应你的显存.
+
+!!! note
+ 对于 Windows 用户, 你可以使用 `trainer.strategy.process_group_backend=gloo` 来避免 `nccl` 的问题.
+
+训练结束后, 你可以参考 [推理](inference.md) 部分, 并携带 `--speaker SPK1` 参数来测试你的模型.
+
+!!! info
+ 默认配置下, 基本只会学到说话人的发音方式, 而不包含音色, 你依然需要使用 prompt 来保证音色的稳定性.
+ 如果你想要学到音色, 请将训练步数调大, 但这有可能会导致过拟合.
+
+训练完成后, 你需要先将 loRA 的权重转为普通权重, 然后再进行推理.
+
+```bash
+python tools/llama/merge_lora.py \
+ --lora-config r_8_alpha_16 \
+ --base-weight checkpoints/fish-speech-1.2-sft \
+ --lora-weight results/$project/checkpoints/step_000000010.ckpt \
+ --output checkpoints/fish-speech-1.2-sft-yth-lora/
+```
+
+!!! note
+ 你也可以尝试其他的 checkpoint, 我们建议你使用最早的满足你要求的 checkpoint, 他们通常在 OOD 上表现更好.
diff --git a/docs/zh/index.md b/docs/zh/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..8698a273f2feb0c22f7aaf2f2408b831294c0b4c
--- /dev/null
+++ b/docs/zh/index.md
@@ -0,0 +1,118 @@
+# 介绍
+
+
+
+!!! warning
+ 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
+ 此代码库根据 `BSD-3-Clause` 许可证发布, 所有模型根据 CC-BY-NC-SA-4.0 许可证发布.
+
+
+
+
+
+## 要求
+
+- GPU 内存: 4GB (用于推理), 8GB (用于微调)
+- 系统: Linux, Windows
+
+## Windows 配置
+
+Windows 专业用户可以考虑 WSL2 或 docker 来运行代码库。
+
+Windows 非专业用户可考虑以下为免 Linux 环境的基础运行方法(附带模型编译功能,即 `torch.compile`):
+
+
+1. 解压项目压缩包。
+2. 点击 `install_env.bat` 安装环境。
+ - 可以通过编辑 `install_env.bat` 的 `USE_MIRROR` 项来决定是否使用镜像站下载。
+ - `USE_MIRROR=false` 使用原始站下载最新稳定版 `torch` 环境。`USE_MIRROR=true` 为从镜像站下载最新 `torch` 环境。默认为 `true`。
+ - 可以通过编辑 `install_env.bat` 的 `INSTALL_TYPE` 项来决定是否启用可编译环境下载。
+ - `INSTALL_TYPE=preview` 下载开发版编译环境。`INSTALL_TYPE=stable` 下载稳定版不带编译环境。
+3. 若第2步 `INSTALL_TYPE=preview` 则执行这一步(可跳过,此步为激活编译模型环境)
+ 1. 使用如下链接下载 LLVM 编译器。
+ - [LLVM-17.0.6(原站站点下载)](https://huggingface.co/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - [LLVM-17.0.6(镜像站点下载)](https://hf-mirror.com/fishaudio/fish-speech-1/resolve/main/LLVM-17.0.6-win64.exe?download=true)
+ - 下载完 `LLVM-17.0.6-win64.exe` 后,双击进行安装,选择合适的安装位置,最重要的是勾选 `Add Path to Current User` 添加环境变量。
+ - 确认安装完成。
+ 2. 下载安装 Microsoft Visual C++ 可再发行程序包,解决潜在 .dll 丢失问题。
+ - [MSVC++ 14.40.33810.0 下载](https://aka.ms/vs/17/release/vc_redist.x64.exe)
+ 3. 下载安装 Visual Studio 社区版以获取 MSVC++ 编译工具, 解决 LLVM 的头文件依赖问题。
+ - [Visual Studio 下载](https://visualstudio.microsoft.com/zh-hans/downloads/)
+ - 安装好Visual Studio Installer之后,下载Visual Studio Community 2022
+ - 如下图点击`修改`按钮,找到`使用C++的桌面开发`项,勾选下载
+ 4. 下载安装 [CUDA Toolkit 12](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Windows&target_arch=x86_64)
+4. 双击 `start.bat` 打开训练推理WebUI管理界面. 如有需要,可照下列提示修改`API_FLAGS`.
+
+!!! info "可选"
+
+ 想启动 推理 WebUI 界面?编辑项目根目录下的 `API_FLAGS.txt`, 前三行修改成如下格式:
+ ```
+ --infer
+ # --api
+ # --listen ...
+ ...
+ ```
+
+!!! info "可选"
+
+ 想启动 API 服务器?编辑项目根目录下的 `API_FLAGS.txt`, 前三行修改成如下格式:
+ ```
+ # --infer
+ --api
+ --listen ...
+ ...
+ ```
+
+!!! info "可选"
+
+ 双击 `run_cmd.bat` 进入本项目的 conda/python 命令行环境
+
+
+## Linux 配置
+
+```bash
+# 创建一个 python 3.10 虚拟环境, 你也可以用 virtualenv
+conda create -n fish-speech python=3.10
+conda activate fish-speech
+
+# 安装 pytorch
+pip3 install torch torchvision torchaudio
+
+# 安装 fish-speech
+pip3 install -e .
+
+# (Ubuntu / Debian 用户) 安装 sox
+apt install libsox-dev
+```
+
+## 更新日志
+
+- 2024/07/02: 更新了 Fish-Speech 到 1.2 版本,移除 VITS Decoder,同时极大幅度提升 zero-shot 能力.
+- 2024/05/10: 更新了 Fish-Speech 到 1.1 版本,引入了 VITS Decoder 来降低口胡和提高音色相似度.
+- 2024/04/22: 完成了 Fish-Speech 1.0 版本, 大幅修改了 VQGAN 和 LLAMA 模型.
+- 2023/12/28: 添加了 `lora` 微调支持.
+- 2023/12/27: 添加了 `gradient checkpointing`, `causual sampling` 和 `flash-attn` 支持.
+- 2023/12/19: 更新了 Webui 和 HTTP API.
+- 2023/12/18: 更新了微调文档和相关例子.
+- 2023/12/17: 更新了 `text2semantic` 模型, 支持无音素模式.
+- 2023/12/13: 测试版发布, 包含 VQGAN 模型和一个基于 LLAMA 的语言模型 (只支持音素).
+
+## 致谢
+
+- [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2)
+- [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
+- [GPT VITS](https://github.com/innnky/gpt-vits)
+- [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)
diff --git a/docs/zh/inference.md b/docs/zh/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..80df182161fb4eb8dc0a617ee4d66f6b6ecc1c92
--- /dev/null
+++ b/docs/zh/inference.md
@@ -0,0 +1,164 @@
+# 推理
+
+推理支持命令行, http api, 以及 webui 三种方式.
+
+!!! note
+ 总的来说, 推理分为几个部分:
+
+ 1. 给定一段 ~10 秒的语音, 将它用 VQGAN 编码.
+ 2. 将编码后的语义 token 和对应文本输入语言模型作为例子.
+ 3. 给定一段新文本, 让模型生成对应的语义 token.
+ 4. 将生成的语义 token 输入 VQGAN 解码, 生成对应的语音.
+
+## 命令行推理
+
+从我们的 huggingface 仓库下载所需的 `vqgan` 和 `llama` 模型。
+
+```bash
+huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+对于中国大陆用户,可使用 mirror 下载。
+
+```bash
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft
+```
+
+### 1. 从语音生成 prompt:
+
+!!! note
+ 如果你打算让模型随机选择音色, 你可以跳过这一步.
+
+```bash
+python tools/vqgan/inference.py \
+ -i "paimon.wav" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+```
+
+你应该能得到一个 `fake.npy` 文件.
+
+### 2. 从文本生成语义 token:
+
+```bash
+python tools/llama/generate.py \
+ --text "要转换的文本" \
+ --prompt-text "你的参考文本" \
+ --prompt-tokens "fake.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft" \
+ --num-samples 2 \
+ --compile
+```
+
+该命令会在工作目录下创建 `codes_N` 文件, 其中 N 是从 0 开始的整数.
+
+!!! note
+ 您可能希望使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 个 token/秒 -> ~500 个 token/秒).
+ 对应的, 如果你不打算使用加速, 你可以注释掉 `--compile` 参数.
+
+!!! info
+ 对于不支持 bf16 的 GPU, 你可能需要使用 `--half` 参数.
+
+### 3. 从语义 token 生成人声:
+
+#### VQGAN 解码
+
+```bash
+python tools/vqgan/inference.py \
+ -i "codes_0.npy" \
+ --checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+```
+
+## HTTP API 推理
+
+运行以下命令来启动 HTTP 服务:
+
+```bash
+python -m tools.api \
+ --listen 0.0.0.0:8080 \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.2-sft" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+如果你想要加速推理,可以加上`--compile`参数。
+
+推荐中国大陆用户运行以下命令来启动 HTTP 服务:
+```bash
+HF_ENDPOINT=https://hf-mirror.com python -m ...(同上)
+```
+
+随后, 你可以在 `http://127.0.0.1:8080/` 中查看并测试 API.
+
+下面是使用`tools/post_api.py`发送请求的示例。
+
+```bash
+python -m tools.post_api \
+ --text "要输入的文本" \
+ --reference_audio "参考音频路径" \
+ --reference_text "参考音频的文本内容" \
+ --streaming True
+```
+
+上面的命令表示按照参考音频的信息,合成所需的音频并流式返回.
+
+如果需要通过`{说话人}`和`{情绪}`随机选择参考音频,那么就根据下列步骤配置:
+
+### 1. 在项目根目录创建`ref_data`文件夹.
+
+### 2. 在`ref_data`文件夹内创建类似如下结构的目录.
+
+```
+.
+├── SPEAKER1
+│ ├──EMOTION1
+│ │ ├── 21.15-26.44.lab
+│ │ ├── 21.15-26.44.wav
+│ │ ├── 27.51-29.98.lab
+│ │ ├── 27.51-29.98.wav
+│ │ ├── 30.1-32.71.lab
+│ │ └── 30.1-32.71.flac
+│ └──EMOTION2
+│ ├── 30.1-32.71.lab
+│ └── 30.1-32.71.mp3
+└── SPEAKER2
+ └─── EMOTION3
+ ├── 30.1-32.71.lab
+ └── 30.1-32.71.mp3
+```
+
+也就是`ref_data`里先放`{说话人}`文件夹, 每个说话人下再放`{情绪}`文件夹, 每个情绪文件夹下放任意个`音频-文本对`。
+
+### 3. 在虚拟环境里输入
+
+```bash
+python tools/gen_ref.py
+```
+
+生成参考目录.
+
+### 4. 调用 api.
+
+```bash
+python -m tools.post_api \
+ --text "要输入的文本" \
+ --speaker "说话人1" \
+ --emotion "情绪1" \
+ --streaming True
+```
+
+以上示例仅供测试.
+
+## WebUI 推理
+
+你可以使用以下命令来启动 WebUI:
+
+```bash
+python -m tools.webui \
+ --llama-checkpoint-path "checkpoints/fish-speech-1.2-sft" \
+ --decoder-checkpoint-path "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth" \
+ --decoder-config-name firefly_gan_vq
+```
+
+!!! note
+ 你可以使用 Gradio 环境变量, 如 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` 来配置 WebUI.
+
+祝大家玩得开心!
diff --git a/docs/zh/samples.md b/docs/zh/samples.md
new file mode 100644
index 0000000000000000000000000000000000000000..b4d0fab1d801ce6c55916e7a6f0a261ec4373849
--- /dev/null
+++ b/docs/zh/samples.md
@@ -0,0 +1,223 @@
+# 例子
+
+v1.2 的样本可以在 [Bilibili](https://www.bilibili.com/video/BV1wz421B71D/) 观看。
+
+以下样本来自 v1.1 版本的模型。
+
+## 中文句子 1
+```
+人间灯火倒映湖中,她的渴望让静水泛起涟漪。若代价只是孤独,那就让这份愿望肆意流淌。
+流入她所注视的世间,也流入她如湖水般澄澈的目光。
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 纳西妲 (原神)
+
+
+
+
+ 钟离 (原神)
+
+
+
+
+ 芙宁娜 (原神)
+
+
+
+
+ 随机说话人 1
+ -
+
+
+
+ 随机说话人 2
+ -
+
+
+
+
+
+
+## 中文句子 2
+```
+你们这个是什么群啊,你们这是害人不浅啊你们这个群!谁是群主,出来!真的太过分了。你们搞这个群干什么?
+我儿子每一科的成绩都不过那个平均分呐,他现在初二,你叫我儿子怎么办啊?他现在还不到高中啊?
+你们害死我儿子了!快点出来你这个群主!再这样我去报警了啊!我跟你们说你们这一帮人啊,一天到晚啊,
+搞这些什么游戏啊,动漫啊,会害死你们的,你们没有前途我跟你说。你们这九百多个人,好好学习不好吗?
+一天到晚在上网。有什么意思啊?麻烦你重视一下你们的生活的目标啊?有一点学习目标行不行?一天到晚上网是不是人啊?
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 纳西妲 (原神)
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
+
+
+## 中文句子 3
+```
+大家好,我是 Fish Audio 开发的开源文本转语音模型。经过十五万小时的数据训练,
+我已经能够熟练掌握中文、日语和英语,我的语言处理能力接近人类水平,声音表现形式丰富多变。
+作为一个仅有亿级参数的模型,我相信社区成员能够在个人设备上轻松运行和微调,让我成为您的私人语音助手。
+```
+
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
+
+## 英文句子 1
+
+```
+In the realm of advanced technology, the evolution of artificial intelligence stands as a
+monumental achievement. This dynamic field, constantly pushing the boundaries of what
+machines can do, has seen rapid growth and innovation. From deciphering complex data
+patterns to driving cars autonomously, AI's applications are vast and diverse.
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人 1
+ -
+
+
+
+ 随机说话人 2
+ -
+
+
+
+
+
+## 英文句子 2
+```
+Hello everyone, I am an open-source text-to-speech model developed by
+Fish Audio. After training with 150,000 hours of data, I have become proficient
+in Chinese, Japanese, and English, and my language processing abilities
+are close to human level. My voice is capable of a wide range of expressions.
+As a model with only hundreds of millions of parameters, I believe community
+members can easily run and fine-tune me on their personal devices, allowing
+me to serve as your personal voice assistant.
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
+
+## 日文句子 1
+
+```
+先進技術の領域において、人工知能の進化は画期的な成果として立っています。常に機械ができることの限界を
+押し広げているこのダイナミックな分野は、急速な成長と革新を見せています。複雑なデータパターンの解読か
+ら自動運転車の操縦まで、AIの応用は広範囲に及びます。
+```
+
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人 1
+ -
+
+
+
+ 随机说话人 2
+ -
+
+
+
+
+
+## 日文句子 2
+```
+皆さん、こんにちは。私はフィッシュオーディオによって開発されたオープンソースのテ
+キストから音声への変換モデルです。15万時間のデータトレーニングを経て、
+中国語、日本語、英語を熟知しており、言語処理能力は人間に近いレベルです。
+声の表現も多彩で豊かです。数億のパラメータを持つこのモデルは、コミュニティ
+のメンバーが個人のデバイスで簡単に実行し、微調整することができると
+信じています。これにより、私を個人の音声アシスタントとして活用できます。
+```
+
+
+
+
+ 说话人
+ 输入音频
+ 合成音频
+
+
+
+
+ 随机说话人
+ -
+
+
+
+
diff --git a/fish_speech/callbacks/__init__.py b/fish_speech/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbcf3f33656d180ca87cd14a21ede1544e5a61a3
--- /dev/null
+++ b/fish_speech/callbacks/__init__.py
@@ -0,0 +1,3 @@
+from .grad_norm import GradNormMonitor
+
+__all__ = ["GradNormMonitor"]
diff --git a/fish_speech/callbacks/grad_norm.py b/fish_speech/callbacks/grad_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc95ef2a3723323b2d976001ed1e3c79c00b21a
--- /dev/null
+++ b/fish_speech/callbacks/grad_norm.py
@@ -0,0 +1,113 @@
+from typing import Optional, Union
+
+import lightning.pytorch as pl
+import torch
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from torch import Tensor, nn
+from torch.utils._foreach_utils import (
+ _group_tensors_by_device_and_dtype,
+ _has_foreach_support,
+)
+
+
+@torch.no_grad()
+def grad_norm(
+ parameters: Union[Tensor, list[Tensor]],
+ norm_type: float = 2.0,
+) -> float:
+ """
+ Returns the norm of the gradients of the given parameters.
+
+ Args:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ norm_type (float): type of the used p-norm.
+
+ Returns:
+ Total norm of the parameter gradients (viewed as a single vector).
+ """ # noqa: E501
+
+ if isinstance(parameters, Tensor):
+ parameters = [parameters]
+
+ grads = [p.grad for p in parameters if p.grad is not None]
+ if len(grads) == 0:
+ return None
+
+ first_device = grads[0].device
+ grouped_grads: dict[
+ tuple[torch.device, torch.dtype], list[list[Tensor]]
+ ] = _group_tensors_by_device_and_dtype(
+ [[g.detach() for g in grads]]
+ ) # type: ignore[assignment]
+
+ norms = []
+ for (device, _), ([grads], _) in grouped_grads.items():
+ if _has_foreach_support(grads, device=device):
+ norms.extend(torch._foreach_norm(grads, norm_type))
+ else:
+ norms.extend([torch.norm(g, norm_type) for g in grads])
+
+ return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+
+class GradNormMonitor(Callback):
+ """
+ Callback that computes the gradient norm of the model parameters.
+ """
+
+ def __init__(
+ self,
+ norm_type: float = 2.0,
+ logging_interval: str = "step",
+ sub_module: Optional[Union[str, list[str]]] = None,
+ ) -> None:
+ """
+ Args:
+ norm_type (float): type of the used p-norm.
+ logging_interval (str): "step" or "epoch".
+ """
+ super().__init__()
+
+ self.norm_type = norm_type
+ self.logging_interval = logging_interval
+ self.sub_module = sub_module
+
+ def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
+ """
+ Computes the gradient norm of the model parameters and logs it to the logger.
+
+ Args:
+ trainer (Trainer): The trainer object
+ model (LightningModule): The current lightningModule
+ """
+
+ lightning_model = model
+
+ if self.sub_module is None:
+ return self.log_sub_module_grad_norm(lightning_model, model, "")
+
+ sub_modules = self.sub_module
+ if isinstance(sub_modules, str):
+ sub_modules = [sub_modules]
+
+ for sub_module in sub_modules:
+ self.log_sub_module_grad_norm(
+ lightning_model, getattr(model, sub_module), f"/{sub_module}"
+ )
+
+ def log_sub_module_grad_norm(
+ self, lightning_model: LightningModule, model: nn.Module, path: str
+ ) -> None:
+ grad_norm_val = grad_norm(model.parameters(), self.norm_type)
+ if grad_norm_val is None:
+ return
+
+ on_step = self.logging_interval == "step"
+ lightning_model.log(
+ f"train{path}/grad_norm",
+ grad_norm_val,
+ on_step=on_step,
+ on_epoch=not on_step,
+ )
diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..99e6dab54d3f57bce4f6d29a9129a19a523cad75
--- /dev/null
+++ b/fish_speech/configs/base.yaml
@@ -0,0 +1,87 @@
+# Base configuration for training a model
+paths:
+ run_dir: results/${project}
+ ckpt_dir: ${paths.run_dir}/checkpoints
+
+hydra:
+ run:
+ dir: ${paths.run_dir}
+
+# Lightning Trainer
+trainer:
+ _target_: lightning.pytorch.trainer.Trainer
+
+ default_root_dir: ${paths.run_dir}
+ accelerator: gpu
+ num_nodes: 1
+ devices: auto
+ strategy:
+ _target_: lightning.pytorch.strategies.DDPStrategy
+ process_group_backend: nccl # This should be override when training on windows
+
+ precision: bf16-mixed
+
+ # disable validation by epoch end
+ check_val_every_n_epoch: null
+ val_check_interval: 5000
+ max_steps: 100_000
+
+ # Use torch.backends.cudnn.benchmark to speed up training
+ benchmark: true
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
+ dirpath: ${paths.ckpt_dir}
+ filename: "step_{step:09d}"
+ save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+ save_top_k: 5 # save 5 latest checkpoints
+ monitor: step # use step to monitor checkpoints
+ mode: max # save the latest checkpoint with the highest global_step
+ every_n_epochs: null # don't save checkpoints by epoch end
+ every_n_train_steps: 5000 # save checkpoints every 5000 steps
+ auto_insert_metric_name: false
+
+ model_summary:
+ _target_: lightning.pytorch.callbacks.ModelSummary
+ max_depth: 2 # the maximum depth of layer nesting that the summary will include
+
+ learning_rate_monitor:
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
+ logging_interval: step
+ log_momentum: false
+
+ grad_norm_monitor:
+ _target_: fish_speech.callbacks.GradNormMonitor
+ norm_type: 2
+ logging_interval: step
+
+# Logger
+logger:
+ tensorboard:
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+ save_dir: "${paths.run_dir}/tensorboard/"
+ name: null
+ log_graph: false
+ default_hp_metric: true
+ prefix: ""
+
+ # wandb:
+ # _target_: lightning.pytorch.loggers.wandb.WandbLogger
+ # # name: "" # name of the run (normally generated by wandb)
+ # save_dir: "${paths.run_dir}"
+ # offline: False
+ # id: null # pass correct id to resume experiment!
+ # anonymous: null # enable anonymous logging
+ # project: "fish-speech"
+ # log_model: False # upload lightning ckpts
+ # prefix: "" # a string to put at the beginning of metric keys
+ # # entity: "" # set to name of your wandb team
+ # group: ""
+ # tags: ["vq", "hq", "finetune"]
+ # job_type: ""
+
+# Loop
+train: true
+test: false
diff --git a/fish_speech/configs/firefly_gan_vq.yaml b/fish_speech/configs/firefly_gan_vq.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7417623b03149560d3f8d6eb73c50b78ea5a84d5
--- /dev/null
+++ b/fish_speech/configs/firefly_gan_vq.yaml
@@ -0,0 +1,34 @@
+_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
+spec_transform:
+ _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
+ sample_rate: 44100
+ n_mels: 160
+ n_fft: 2048
+ hop_length: 512
+ win_length: 2048
+backbone:
+ _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
+ input_channels: 160
+ depths: [3, 3, 9, 3]
+ dims: [128, 256, 384, 512]
+ drop_path_rate: 0.2
+ kernel_size: 7
+head:
+ _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
+ hop_length: 512
+ upsample_rates: [8, 8, 2, 2, 2] # aka. strides
+ upsample_kernel_sizes: [16, 16, 4, 4, 4]
+ resblock_kernel_sizes: [3, 7, 11]
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+ num_mels: 512
+ upsample_initial_channel: 512
+ use_template: false
+ pre_conv_kernel_size: 13
+ post_conv_kernel_size: 13
+quantizer:
+ _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
+ input_dim: 512
+ n_groups: 4
+ n_codebooks: 1
+ levels: [8, 5, 5, 5]
+ downsample_factor: [2]
diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aecc4d9766a18fe31c55941e01b1f590c95e77c9
--- /dev/null
+++ b/fish_speech/configs/lora/r_8_alpha_16.yaml
@@ -0,0 +1,4 @@
+_target_: fish_speech.models.text2semantic.lora.LoraConfig
+r: 8
+lora_alpha: 16
+lora_dropout: 0.01
diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1bf8fd6b6d8f99f8c5071c12a01ca8af55ec4070
--- /dev/null
+++ b/fish_speech/configs/text2semantic_finetune.yaml
@@ -0,0 +1,83 @@
+defaults:
+ - base
+ - _self_
+
+project: text2semantic_finetune_dual_ar
+max_length: 4096
+pretrained_ckpt_path: checkpoints/fish-speech-1.2-sft
+
+# Lightning Trainer
+trainer:
+ accumulate_grad_batches: 1
+ gradient_clip_val: 1.0
+ gradient_clip_algorithm: "norm"
+ max_steps: 1000
+ precision: bf16-true
+ limit_val_batches: 10
+ val_check_interval: 100
+
+# Dataset Configuration
+tokenizer:
+ _target_: transformers.AutoTokenizer.from_pretrained
+ pretrained_model_name_or_path: ${pretrained_ckpt_path}
+
+# Dataset Configuration
+train_dataset:
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ causal: true
+ max_length: ${max_length}
+ use_speaker: false
+ interactive_prob: 0.7
+
+val_dataset:
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+ proto_files:
+ - data/protos
+ tokenizer: ${tokenizer}
+ causal: true
+ max_length: ${max_length}
+ use_speaker: false
+ interactive_prob: 0.7
+
+data:
+ _target_: fish_speech.datasets.semantic.SemanticDataModule
+ train_dataset: ${train_dataset}
+ val_dataset: ${val_dataset}
+ num_workers: 4
+ batch_size: 8
+ tokenizer: ${tokenizer}
+ max_length: ${max_length}
+
+# Model Configuration
+model:
+ _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
+ model:
+ _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
+ path: ${pretrained_ckpt_path}
+ load_weights: true
+ max_length: ${max_length}
+ lora_config: null
+
+ optimizer:
+ _target_: torch.optim.AdamW
+ _partial_: true
+ lr: 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.95]
+ eps: 1e-5
+
+ lr_scheduler:
+ _target_: torch.optim.lr_scheduler.LambdaLR
+ _partial_: true
+ lr_lambda:
+ _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
+ _partial_: true
+ num_warmup_steps: 10
+
+# Callbacks
+callbacks:
+ model_checkpoint:
+ every_n_train_steps: ${trainer.val_check_interval}
diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9ca0ef9181754eda7e6b49e01abeafbe07fb00f
--- /dev/null
+++ b/fish_speech/conversation.py
@@ -0,0 +1,2 @@
+SEMANTIC_TOKEN = "<|semantic|>"
+CODEBOOK_PAD_TOKEN_ID = 0
diff --git a/fish_speech/datasets/concat_repeat.py b/fish_speech/datasets/concat_repeat.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa596b95a572ee15c5570cbdb792c9a78e62dfa
--- /dev/null
+++ b/fish_speech/datasets/concat_repeat.py
@@ -0,0 +1,53 @@
+import bisect
+import random
+from typing import Iterable
+
+from torch.utils.data import Dataset, IterableDataset
+
+
+class ConcatRepeatDataset(Dataset):
+ datasets: list[Dataset]
+ cumulative_sizes: list[int]
+ repeats: list[int]
+
+ @staticmethod
+ def cumsum(sequence, repeats):
+ r, s = [], 0
+ for dataset, repeat in zip(sequence, repeats):
+ l = len(dataset) * repeat
+ r.append(l + s)
+ s += l
+ return r
+
+ def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
+ super().__init__()
+
+ self.datasets = list(datasets)
+ self.repeats = repeats
+
+ assert len(self.datasets) > 0, "datasets should not be an empty iterable"
+ assert len(self.datasets) == len(
+ repeats
+ ), "datasets and repeats should have the same length"
+
+ for d in self.datasets:
+ assert not isinstance(
+ d, IterableDataset
+ ), "ConcatRepeatDataset does not support IterableDataset"
+
+ self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
+
+ def __len__(self):
+ return self.cumulative_sizes[-1]
+
+ def __getitem__(self, idx):
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+
+ dataset = self.datasets[dataset_idx]
+
+ return dataset[sample_idx % len(dataset)]
diff --git a/fish_speech/datasets/protos/text-data.proto b/fish_speech/datasets/protos/text-data.proto
new file mode 100644
index 0000000000000000000000000000000000000000..5eb26d94aa3be1e21066f2bf38c90d54e85a8379
--- /dev/null
+++ b/fish_speech/datasets/protos/text-data.proto
@@ -0,0 +1,24 @@
+syntax = "proto3";
+
+package text_data;
+
+message Semantics {
+ repeated uint32 values = 1;
+}
+
+message Sentence {
+ repeated string texts = 1;
+ repeated Semantics semantics = 3;
+}
+
+message TextData {
+ string source = 1;
+ string name = 2;
+ repeated Sentence sentences = 4;
+}
+
+message SampledData {
+ string source = 1;
+ string name = 2;
+ repeated Sentence samples = 3;
+}
diff --git a/fish_speech/datasets/protos/text_data_pb2.py b/fish_speech/datasets/protos/text_data_pb2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfce0e8be59fc51e68999ef137e1fd0e4adc0d7e
--- /dev/null
+++ b/fish_speech/datasets/protos/text_data_pb2.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: text-data.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals["_SEMANTICS"]._serialized_start = 30
+ _globals["_SEMANTICS"]._serialized_end = 57
+ _globals["_SENTENCE"]._serialized_start = 59
+ _globals["_SENTENCE"]._serialized_end = 125
+ _globals["_TEXTDATA"]._serialized_start = 127
+ _globals["_TEXTDATA"]._serialized_end = 207
+ _globals["_SAMPLEDDATA"]._serialized_start = 209
+ _globals["_SAMPLEDDATA"]._serialized_end = 290
+# @@protoc_insertion_point(module_scope)
diff --git a/fish_speech/datasets/protos/text_data_stream.py b/fish_speech/datasets/protos/text_data_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec3c25bcd764e8245de47dcdf9686d6adfb5a107
--- /dev/null
+++ b/fish_speech/datasets/protos/text_data_stream.py
@@ -0,0 +1,36 @@
+import struct
+
+from .text_data_pb2 import TextData
+
+
+def read_pb_stream(f):
+ while True:
+ buf = f.read(4)
+ if len(buf) == 0:
+ break
+ size = struct.unpack("I", buf)[0]
+ buf = f.read(size)
+ text_data = TextData()
+ text_data.ParseFromString(buf)
+ yield text_data
+
+
+def write_pb_stream(f, text_data):
+ buf = text_data.SerializeToString()
+ f.write(struct.pack("I", len(buf)))
+ f.write(buf)
+
+
+def pack_pb_stream(text_data):
+ buf = text_data.SerializeToString()
+ return struct.pack("I", len(buf)) + buf
+
+
+def split_pb_stream(f):
+ while True:
+ head = f.read(4)
+ if len(head) == 0:
+ break
+ size = struct.unpack("I", head)[0]
+ buf = f.read(size)
+ yield head + buf
diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c64e01077ae253bdc4e4d9cd948f8fb50df7418
--- /dev/null
+++ b/fish_speech/datasets/semantic.py
@@ -0,0 +1,496 @@
+import random
+from dataclasses import dataclass
+from itertools import chain
+from pathlib import Path
+from random import Random
+from typing import Optional, Union
+
+import numpy as np
+import pyarrow.parquet as pq
+import torch
+import torch.nn.functional as F
+from datasets.download.streaming_download_manager import xopen
+from huggingface_hub import HfApi
+from lightning import LightningDataModule
+from torch.distributed import get_rank, get_world_size, is_initialized
+from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.datasets.protos.text_data_pb2 import SampledData
+from fish_speech.datasets.protos.text_data_stream import read_pb_stream
+from fish_speech.text.clean import clean_text
+from fish_speech.utils import RankedLogger
+from fish_speech.utils.braceexpand import braceexpand
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def split_by_rank_worker(files):
+ # We need to know the total number of devices
+ # to split the data properly
+
+ total_devices = 1
+ if is_initialized():
+ total_devices = get_world_size()
+
+ worker_info = get_worker_info()
+ if worker_info is not None:
+ total_devices *= worker_info.num_workers
+
+ if len(files) < total_devices:
+ # Repeat the files N times to match the number of devices
+ files = files * (total_devices // len(files) + 1)
+
+ # DDP
+ if is_initialized():
+ files = files[get_rank() :: get_world_size()]
+
+ # Split by worker
+ if worker_info is not None:
+ files = files[worker_info.id :: worker_info.num_workers]
+
+ return files
+
+
+class AutoTextSemanticInstructionDataset(IterableDataset):
+ """
+ Auto Augment Dataset by Speaker
+
+ 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
+ 2. Automatically normalize the text
+
+ For interactive mode, we use the following format (multiple sequences):
+ [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+
+ For non-interactive mode, we use the following format (one long sequence):
+ [INST] text [/INST] ...
+ """
+
+ def __init__(
+ self,
+ proto_files: list[str],
+ seed: int = 42,
+ interactive_prob: float = 0.5,
+ max_length: int = 1024,
+ tokenizer: AutoTokenizer = None,
+ use_speaker: bool | float = True,
+ causal: bool = True,
+ num_codebooks: Optional[int] = None,
+ skip_text_prob: float = 0.0,
+ ):
+ """
+ Args:
+ proto_files: proto buf files if using local data
+ seed: random seed
+ interactive_prob: probability to use interactive mode
+ max_length: max length of the text
+ tokenizer: tokenizer
+ use_speaker: include speaker information in the prompt
+ causal: use causal sampling when using local data, disable will lead to random sampling
+ num_codebooks: number of codebooks, if None, it will be automatically detected
+ skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
+ """
+
+ super().__init__()
+
+ assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+
+ self.seed = seed
+ self.max_length = max_length
+ self.tokenizer = tokenizer
+ self.interactive_prob = interactive_prob
+ self.use_speaker = use_speaker
+ self.proto_files = proto_files
+ self.causal = causal
+ self.num_codebooks = num_codebooks
+ self.skip_text_prob = skip_text_prob
+
+ self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
+ self.groups = None
+
+ def init_mock_data_server(self):
+ if self.groups is not None:
+ return
+
+ # Expand the proto files
+ expanded_proto_files = []
+ for filename in self.proto_files:
+ for i in braceexpand(filename):
+ i = Path(i)
+ if i.is_file():
+ expanded_proto_files.append(i)
+ elif i.is_dir():
+ expanded_proto_files.extend(i.rglob("*.proto"))
+ expanded_proto_files.extend(i.rglob("*.protos"))
+ else:
+ raise ValueError(f"{i} is not a file or directory")
+
+ expanded_proto_files = sorted(expanded_proto_files)
+ Random(self.seed).shuffle(expanded_proto_files)
+
+ self.groups = []
+ shard_proto_files = split_by_rank_worker(expanded_proto_files)
+ log.info(
+ f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
+ )
+
+ count = 0
+ for filename in shard_proto_files:
+ with open(filename, "rb") as f:
+ for text_data in read_pb_stream(f):
+ self.groups.append(text_data)
+ count += 1
+
+ log.info(f"Read total {count} groups of data")
+
+ # Shuffle the lines
+ Random(self.seed).shuffle(self.groups)
+ self.group_weights = [len(i.sentences) for i in self.groups]
+
+ def __iter__(self):
+ while True:
+ yield self.augment()
+
+ def tokenize_sentence(self, sentence: str):
+ sentence = clean_text(sentence)
+ tokens = self.tokenizer.encode(
+ f"{sentence}",
+ max_length=10**6,
+ add_special_tokens=False,
+ truncation=False,
+ )
+ return sentence, len(tokens)
+
+ def sample_data(self):
+ if self.groups is None:
+ self.init_mock_data_server()
+
+ # Shuffle unique lines, estimate that each sample is at least 20 tokens
+ num_samples = self.max_length // 20
+
+ # choice group based on their number of samples
+ group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
+
+ if self.causal:
+ # Sample in order
+ if num_samples >= len(group.sentences):
+ samples = group.sentences
+ else:
+ begin = random.randint(0, len(group.sentences) - num_samples)
+ samples = group.sentences[begin : begin + num_samples]
+ else:
+ samples = random.choices(
+ group.sentences, k=min(num_samples, len(group.sentences))
+ )
+
+ return SampledData(
+ source=group.source,
+ name=group.name,
+ samples=samples,
+ )
+
+ def augment(self):
+ final_text, final_semantic = [], []
+ response = self.sample_data()
+ if len(response.samples) == 0:
+ # Invalid group
+ return None
+
+ samples = list(response.samples)
+ idx = 0
+ use_interactive = random.random() < self.interactive_prob
+
+ if use_interactive is False:
+ # Random sample based on speaker using a truncated normal distribution
+ a = torch.tensor([0], dtype=torch.float32)
+ torch.nn.init.trunc_normal_(
+ a,
+ mean=self.max_length // 2,
+ std=self.max_length // 4,
+ a=10,
+ b=self.max_length,
+ )
+ remaining_tokens = a.long().item() - 4
+ else:
+ remaining_tokens = self.max_length
+
+ # Use speaker
+ if isinstance(self.use_speaker, float):
+ use_speaker = random.random() < self.use_speaker
+ else:
+ use_speaker = self.use_speaker
+
+ all_tokens, all_labels = [], []
+ while remaining_tokens > 0 and len(samples) > 0:
+ sentence = samples.pop(0)
+
+ text = random.choice(sentence.texts)
+ text, length = self.tokenize_sentence(text)
+ remaining_tokens -= length + len(sentence.semantics[0].values)
+
+ if use_interactive is False:
+ final_text.append(text)
+ final_semantic.append(sentence.semantics)
+ else:
+ # For interactive mode, we only apply speaker for the first sentence
+ # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+ tokens, labels = self.pack_sentences(
+ sentences=[text],
+ semantics=[sentence.semantics],
+ speaker=response.name if use_speaker else None,
+ skip_text=random.random() < self.skip_text_prob,
+ )
+
+ all_tokens.append(tokens)
+ all_labels.append(labels)
+
+ idx += 1
+
+ if use_interactive is False:
+ tokens, labels = self.pack_sentences(
+ final_text,
+ semantics=final_semantic,
+ speaker=response.name if use_speaker else None,
+ )
+ all_tokens.append(tokens)
+ all_labels.append(labels)
+
+ tokens = torch.cat(all_tokens, dim=1)
+ labels = torch.cat(all_labels, dim=1)
+
+ # Verify that the length is correct
+ assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+ data = {"tokens": tokens, "labels": labels}
+
+ return data
+
+ def pack_sentences(
+ self,
+ sentences: list[str],
+ semantics: list,
+ speaker: Optional[str] = None,
+ skip_text: bool = False,
+ ):
+ if speaker is None:
+ speaker = "assistant"
+
+ cated_sentences = " ".join(sentences)
+ if skip_text:
+ cated_sentences = "<|skip_text|>"
+
+ final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
+ final_text = final_text + f"<|im_start|>{speaker}\n"
+
+ encoded = self.tokenizer.encode(
+ final_text,
+ add_special_tokens=False,
+ truncation=False,
+ max_length=10**6,
+ )
+ semantic_length = sum([len(i[0].values) for i in semantics])
+ prompt_length = len(encoded)
+ num_codebooks = (
+ len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+ )
+
+ # Pack the tokens and semantics (add and to semantic tokens)
+ tokens = (
+ encoded
+ + [self.semantic_token_id] * semantic_length
+ + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
+ )
+
+ # Codebook bos/padding: 0, eos: 1
+ codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
+ for segment in semantics:
+ for book_idx, book in zip(range(num_codebooks), segment):
+ for j in book.values:
+ codes[book_idx].append(int(j) + 1)
+
+ for book in codes:
+ book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
+
+ tokens = [tokens] + codes
+
+ tokens = torch.tensor(tokens, dtype=torch.long)
+ labels = tokens.clone()
+
+ if skip_text:
+ # If text is not provided, the sentence is used for condition only, all labels are -100
+ torch.fill_(labels, -100)
+ return tokens, labels
+
+ # Mask out the tokens for semantic, predict semantic tokens only
+ # Since we don't mask out the input tokens, the language modeling still works
+ labels[1:, :prompt_length] = -100
+
+ tokens = tokens[:, :-1]
+ labels = labels[:, 1:]
+
+ # Verify the padding is correct, and the last token is eos
+ assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
+ assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
+
+ return tokens, labels
+
+
+@dataclass
+class TextDataCollator:
+ tokenizer: AutoTokenizer
+ max_length: int = 1024
+
+ def __call__(self, examples):
+ if "negative_tokens" in examples:
+ positive_examples = []
+ negative_examples = []
+
+ for i in examples:
+ positive_examples.append(
+ {
+ "tokens": i["tokens"],
+ "labels": i["labels"],
+ }
+ )
+ negative_examples.append(
+ {
+ "tokens": i["negative_tokens"],
+ "labels": i["negative_labels"],
+ }
+ )
+
+ examples = positive_examples + negative_examples
+
+ return self.batchify(examples)
+
+ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
+ tokens, attention_masks, labels = [], [], []
+
+ # Calculate the max length
+ max_tokens_length = 0
+ for example in examples:
+ max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
+ max_tokens_length = min(max_tokens_length, self.max_length)
+
+ for example in examples:
+ _tokens = example[tokens_key][:, :max_tokens_length]
+ _labels = example[labels_key][:, :max_tokens_length]
+ _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
+ tokens_length = _tokens.size(1)
+ _attention_mask[:tokens_length] = False
+
+ assert tokens_length == _labels.size(
+ 1
+ ), f"{tokens_length} != {_labels.size(1)}"
+
+ if tokens_length < max_tokens_length:
+ _tokens = F.pad(
+ _tokens,
+ (0, max_tokens_length - tokens_length),
+ value=self.tokenizer.eos_token_id,
+ )
+ _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
+ _labels = F.pad(
+ _labels, (0, max_tokens_length - _labels.size(1)), value=-100
+ )
+
+ tokens.append(_tokens)
+ attention_masks.append(_attention_mask)
+ labels.append(_labels)
+
+ tokens = torch.stack(tokens, dim=0)
+ attention_masks = torch.stack(attention_masks, dim=0)
+ labels = torch.stack(labels, dim=0)
+
+ return {
+ "inputs": tokens,
+ "attention_masks": attention_masks,
+ "labels": labels,
+ }
+
+
+class InterleaveDataset(IterableDataset):
+ def __init__(
+ self,
+ datasets: list[IterableDataset],
+ probabilities: list[float],
+ seed: int = 42,
+ ):
+ super().__init__()
+
+ self.datasets = datasets
+ self.probabilities = probabilities
+ self.seed = seed
+
+ def __iter__(self):
+ rng = np.random.default_rng(self.seed)
+ dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+ while True:
+ # Random choice one
+ dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
+ dataset_iterator = dataset_iterators[dataset_idx]
+
+ try:
+ yield next(dataset_iterator)
+ except StopIteration:
+ # Exhausted, create a new iterator
+ dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+ yield next(dataset_iterators[dataset_idx])
+
+
+class SemanticDataModule(LightningDataModule):
+ def __init__(
+ self,
+ train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+ val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+ batch_size: int = 32,
+ tokenizer: AutoTokenizer = None,
+ max_length: int = 1024,
+ num_workers: int = 4,
+ ):
+ super().__init__()
+
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+ self.batch_size = batch_size
+ self.tokenizer = tokenizer
+ self.max_length = max_length
+ self.num_workers = num_workers
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+ num_workers=self.num_workers,
+ persistent_workers=True,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.batch_size,
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+ num_workers=self.num_workers,
+ persistent_workers=True,
+ )
+
+
+if __name__ == "__main__":
+ from tqdm import tqdm
+
+ ds = AutoTextSemanticInstructionDataset(
+ ["data/protos"],
+ tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
+ use_speaker=False,
+ interactive_prob=1.0,
+ skip_text_prob=0.5,
+ )
+
+ for i in ds:
+ print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
+ # i["labels"][0][i["labels"][0] == -100] = 0
+ # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
+ break
diff --git a/fish_speech/datasets/vqgan.py b/fish_speech/datasets/vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..a45583d22efb0feb9dc1e823bae1ef74534b299e
--- /dev/null
+++ b/fish_speech/datasets/vqgan.py
@@ -0,0 +1,147 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import librosa
+import numpy as np
+import torch
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+
+from fish_speech.utils import RankedLogger
+
+logger = RankedLogger(__name__, rank_zero_only=False)
+
+
+class VQGANDataset(Dataset):
+ def __init__(
+ self,
+ filelist: str,
+ sample_rate: int = 32000,
+ hop_length: int = 640,
+ slice_frames: Optional[int] = None,
+ ):
+ super().__init__()
+
+ filelist = Path(filelist)
+ root = filelist.parent
+
+ self.files = [
+ root / line.strip()
+ for line in filelist.read_text(encoding="utf-8").splitlines()
+ if line.strip()
+ ]
+ self.sample_rate = sample_rate
+ self.hop_length = hop_length
+ self.slice_frames = slice_frames
+
+ def __len__(self):
+ return len(self.files)
+
+ def get_item(self, idx):
+ file = self.files[idx]
+
+ audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
+
+ # Slice audio and features
+ if (
+ self.slice_frames is not None
+ and audio.shape[0] > self.slice_frames * self.hop_length
+ ):
+ start = np.random.randint(
+ 0, audio.shape[0] - self.slice_frames * self.hop_length
+ )
+ audio = audio[start : start + self.slice_frames * self.hop_length]
+
+ if len(audio) == 0:
+ return None
+
+ max_value = np.abs(audio).max()
+ if max_value > 1.0:
+ audio = audio / max_value
+
+ return {
+ "audio": torch.from_numpy(audio),
+ }
+
+ def __getitem__(self, idx):
+ try:
+ return self.get_item(idx)
+ except Exception as e:
+ import traceback
+
+ traceback.print_exc()
+ logger.error(f"Error loading {self.files[idx]}: {e}")
+ return None
+
+
+@dataclass
+class VQGANCollator:
+ def __call__(self, batch):
+ batch = [x for x in batch if x is not None]
+
+ audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
+ audio_maxlen = audio_lengths.max()
+
+ # Rounds up to nearest multiple of 2 (audio_lengths)
+ audios = []
+ for x in batch:
+ audios.append(
+ torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
+ )
+
+ return {
+ "audios": torch.stack(audios),
+ "audio_lengths": audio_lengths,
+ }
+
+
+class VQGANDataModule(LightningDataModule):
+ def __init__(
+ self,
+ train_dataset: VQGANDataset,
+ val_dataset: VQGANDataset,
+ batch_size: int = 32,
+ num_workers: int = 4,
+ val_batch_size: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+ self.batch_size = batch_size
+ self.val_batch_size = val_batch_size or batch_size
+ self.num_workers = num_workers
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ collate_fn=VQGANCollator(),
+ num_workers=self.num_workers,
+ shuffle=True,
+ persistent_workers=True,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.val_batch_size,
+ collate_fn=VQGANCollator(),
+ num_workers=self.num_workers,
+ persistent_workers=True,
+ )
+
+
+if __name__ == "__main__":
+ dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
+ dataloader = DataLoader(
+ dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
+ )
+
+ for batch in dataloader:
+ print(batch["audios"].shape)
+ print(batch["features"].shape)
+ print(batch["audio_lengths"])
+ print(batch["feature_lengths"])
+ break
diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..700902b09db20911ef1ad678cbdce5644b84aea2
--- /dev/null
+++ b/fish_speech/i18n/README.md
@@ -0,0 +1,27 @@
+## i18n Folder Attribution
+
+The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
+
+### fish_speech/i18n/core.py
+
+**Related code from RVC:**
+[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
+
+**Initial commit:**
+add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
+
+**Initial author:**
+[@L4Ph](https://github.com/L4Ph)
+
+### fish_speech/i18n/scan.py
+
+**Related code from RVC:**
+[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
+
+**Initial commit:**
+File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
+
+**Initial author:**
+[@towzeur](https://github.com/towzeur)
+
+We appreciate the contributions of the RVC project and its authors.
diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..981dbb3b3ecf28043ec9ff5757f947182821a246
--- /dev/null
+++ b/fish_speech/i18n/__init__.py
@@ -0,0 +1,3 @@
+from .core import i18n
+
+__all__ = ["i18n"]
diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f793ec95669228f7f4e8f9a7a5fe38da85c74bd
--- /dev/null
+++ b/fish_speech/i18n/core.py
@@ -0,0 +1,40 @@
+import json
+import locale
+from pathlib import Path
+
+I18N_FILE_PATH = Path(__file__).parent / "locale"
+DEFAULT_LANGUAGE = "en_US"
+
+
+def load_language_list(language):
+ with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
+ language_list = json.load(f)
+
+ return language_list
+
+
+class I18nAuto:
+ def __init__(self):
+ i18n_file = Path(".locale")
+
+ if i18n_file.exists():
+ with open(i18n_file, "r", encoding="utf-8") as f:
+ language = f.read().strip()
+ else:
+ # getlocale can't identify the system's language ((None, None))
+ language = locale.getdefaultlocale()[0]
+
+ if (I18N_FILE_PATH / f"{language}.json").exists() is False:
+ language = DEFAULT_LANGUAGE
+
+ self.language = language
+ self.language_map = load_language_list(language)
+
+ def __call__(self, key):
+ return self.language_map.get(key, key)
+
+ def __repr__(self):
+ return "Use Language: " + self.language
+
+
+i18n = I18nAuto()
diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json
new file mode 100644
index 0000000000000000000000000000000000000000..cf6ad6ca1e5e284abb78c3f5b00418836eae4310
--- /dev/null
+++ b/fish_speech/i18n/locale/en_US.json
@@ -0,0 +1,122 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Accumulate Gradient Batches",
+ "Add to Processing Area": "Add to Processing Area",
+ "Added path successfully!": "Added path successfully!",
+ "Advanced Config": "Advanced Config",
+ "Base LLAMA Model": "Base LLAMA Model",
+ "Batch Inference": "Batch Inference",
+ "Batch Size": "Batch Size",
+ "Changing with the Model Path": "Changing with the Model Path",
+ "Chinese": "Chinese",
+ "Compile Model": "Compile Model",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
+ "Copy": "Copy",
+ "Data Preprocessing": "Data Preprocessing",
+ "Data Preprocessing Path": "Data Preprocessing Path",
+ "Data Source": "Data Source",
+ "Decoder Model Config": "Decoder Model Config",
+ "Decoder Model Path": "Decoder Model Path",
+ "Disabled": "Disabled",
+ "Enable Reference Audio": "Enable Reference Audio",
+ "English": "English",
+ "Error Message": "Error Message",
+ "File Preprocessing": "File Preprocessing",
+ "Generate": "Generate",
+ "Generated Audio": "Generated Audio",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
+ "Infer interface is closed": "Infer interface is closed",
+ "Inference Configuration": "Inference Configuration",
+ "Inference Server Configuration": "Inference Server Configuration",
+ "Inference Server Error": "Inference Server Error",
+ "Inferring interface is launched at {}": "Inferring interface is launched at {}",
+ "Initial Learning Rate": "Initial Learning Rate",
+ "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
+ "Input Text": "Input Text",
+ "Invalid path: {}": "Invalid path: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
+ "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
+ "Japanese": "Japanese",
+ "LLAMA Configuration": "LLAMA Configuration",
+ "LLAMA Model Config": "LLAMA Model Config",
+ "LLAMA Model Path": "LLAMA Model Path",
+ "Labeling Device": "Labeling Device",
+ "LoRA Model to be merged": "LoRA Model to be merged",
+ "Maximum Audio Duration": "Maximum Audio Duration",
+ "Maximum Length per Sample": "Maximum Length per Sample",
+ "Maximum Training Steps": "Maximum Training Steps",
+ "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
+ "Merge": "Merge",
+ "Merge LoRA": "Merge LoRA",
+ "Merge successfully": "Merge successfully",
+ "Minimum Audio Duration": "Minimum Audio Duration",
+ "Model Output Path": "Model Output Path",
+ "Model Size": "Model Size",
+ "Move": "Move",
+ "Move files successfully": "Move files successfully",
+ "No audio generated, please check the input text.": "No audio generated, please check the input text.",
+ "No selected options": "No selected options",
+ "Number of Workers": "Number of Workers",
+ "Open Inference Server": "Open Inference Server",
+ "Open Labeler WebUI": "Open Labeler WebUI",
+ "Open Tensorboard": "Open Tensorboard",
+ "Opened labeler in browser": "Opened labeler in browser",
+ "Optional Label Language": "Optional Label Language",
+ "Optional online ver": "Optional online ver",
+ "Output Path": "Output Path",
+ "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
+ "Precision": "Precision",
+ "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
+ "Put your text here.": "Put your text here.",
+ "Reference Audio": "Reference Audio",
+ "Reference Text": "Reference Text",
+ "Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.",
+ "Remove Selected Data": "Remove Selected Data",
+ "Removed path successfully!": "Removed path successfully!",
+ "Repetition Penalty": "Repetition Penalty",
+ "Save model every n steps": "Save model every n steps",
+ "Select LLAMA ckpt": "Select LLAMA ckpt",
+ "Select VITS ckpt": "Select VITS ckpt",
+ "Select VQGAN ckpt": "Select VQGAN ckpt",
+ "Select source file processing method": "Select source file processing method",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
+ "Selected: {}": "Selected: {}",
+ "Speaker": "Speaker",
+ "Speaker is identified by the folder name": "Speaker is identified by the folder name",
+ "Start Training": "Start Training",
+ "Streaming Audio": "Streaming Audio",
+ "Streaming Generate": "Streaming Generate",
+ "Tensorboard Host": "Tensorboard Host",
+ "Tensorboard Log Path": "Tensorboard Log Path",
+ "Tensorboard Port": "Tensorboard Port",
+ "Tensorboard interface is closed": "Tensorboard interface is closed",
+ "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
+ "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
+ "Training Configuration": "Training Configuration",
+ "Training Error": "Training Error",
+ "Training stopped": "Training stopped",
+ "Type name of the speaker": "Type name of the speaker",
+ "Type the path or select from the dropdown": "Type the path or select from the dropdown",
+ "Use LoRA": "Use LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
+ "Use filelist": "Use filelist",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
+ "VITS Configuration": "VITS Configuration",
+ "VQGAN Configuration": "VQGAN Configuration",
+ "Validation Batch Size": "Validation Batch Size",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
+ "WebUI Host": "WebUI Host",
+ "WebUI Port": "WebUI Port",
+ "Whisper Model": "Whisper Model",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
+ "latest": "latest",
+ "new": "new",
+ "Realtime Transform Text": "Realtime Transform Text",
+ "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
+ "Text Normalization": "Text Normalization"
+}
diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json
new file mode 100644
index 0000000000000000000000000000000000000000..1ea59882138f776e3c710aefe70c652906eeb4b4
--- /dev/null
+++ b/fish_speech/i18n/locale/es_ES.json
@@ -0,0 +1,122 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
+ "Accumulate Gradient Batches": "Acumular lotes de gradientes",
+ "Add to Processing Area": "Agregar al Área de Procesamiento",
+ "Added path successfully!": "¡Ruta agregada exitosamente!",
+ "Advanced Config": "Configuración Avanzada",
+ "Base LLAMA Model": "Modelo Base LLAMA",
+ "Batch Inference": "Inferencia por Lote",
+ "Batch Size": "Tamaño del Lote",
+ "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
+ "Chinese": "Chino",
+ "Compile Model": "Compilar Modelo",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
+ "Copy": "Copiar",
+ "Data Preprocessing": "Preprocesamiento de Datos",
+ "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
+ "Data Source": "Fuente de Datos",
+ "Decoder Model Config": "Configuración del modelo decodificador",
+ "Decoder Model Path": "Ruta del modelo decodificador",
+ "Disabled": "Desactivado",
+ "Enable Reference Audio": "Habilitar Audio de Referencia",
+ "English": "Inglés",
+ "Error Message": "Mensaje de Error",
+ "File Preprocessing": "Preprocesamiento de Archivos",
+ "Generate": "Generar",
+ "Generated Audio": "Audio Generado",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
+ "Infer interface is closed": "La interfaz de inferencia está cerrada",
+ "Inference Configuration": "Configuración de Inferencia",
+ "Inference Server Configuration": "Configuración del Servidor de Inferencia",
+ "Inference Server Error": "Error del Servidor de Inferencia",
+ "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
+ "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
+ "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
+ "Input Text": "Texto de Entrada",
+ "Invalid path: {}": "Ruta inválida: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
+ "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
+ "Japanese": "Japonés",
+ "LLAMA Configuration": "Configuración de LLAMA",
+ "LLAMA Model Config": "Configuración del Modelo LLAMA",
+ "LLAMA Model Path": "Ruta del Modelo LLAMA",
+ "Labeling Device": "Dispositivo de Etiquetado",
+ "LoRA Model to be merged": "Modelo LoRA a fusionar",
+ "Maximum Audio Duration": "Duración máxima de audio",
+ "Maximum Length per Sample": "Longitud Máxima por Muestra",
+ "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
+ "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
+ "Merge": "Fusionar",
+ "Merge LoRA": "Fusionar LoRA",
+ "Merge successfully": "Fusionado exitosamente",
+ "Minimum Audio Duration": "Duración mínima de audio",
+ "Model Output Path": "Ruta de Salida del Modelo",
+ "Model Size": "Tamaño del Modelo",
+ "Move": "Mover",
+ "Move files successfully": "Archivos movidos exitosamente",
+ "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
+ "No selected options": "No hay opciones seleccionadas",
+ "Number of Workers": "Número de Trabajadores",
+ "Open Inference Server": "Abrir Servidor de Inferencia",
+ "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
+ "Open Tensorboard": "Abrir Tensorboard",
+ "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
+ "Optional Label Language": "Idioma de Etiquetado Opcional",
+ "Optional online ver": "Ver en línea opcional",
+ "Output Path": "Ruta de Salida",
+ "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
+ "Precision": "Precisión",
+ "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
+ "Put your text here.": "Ponga su texto aquí.",
+ "Reference Audio": "Audio de Referencia",
+ "Reference Text": "Texto de Referencia",
+ "Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado se publica bajo la Licencia BSD-3-Clause, y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
+ "Remove Selected Data": "Eliminar Datos Seleccionados",
+ "Removed path successfully!": "¡Ruta eliminada exitosamente!",
+ "Repetition Penalty": "Penalización por Repetición",
+ "Save model every n steps": "Guardar modelo cada n pasos",
+ "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
+ "Select VITS ckpt": "Seleccionar punto de control VITS",
+ "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
+ "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
+ "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
+ "Selected: {}": "Seleccionado: {}",
+ "Speaker": "Hablante",
+ "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
+ "Start Training": "Iniciar Entrenamiento",
+ "Streaming Audio": "transmisión de audio",
+ "Streaming Generate": "síntesis en flujo",
+ "Tensorboard Host": "Host de Tensorboard",
+ "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
+ "Tensorboard Port": "Puerto de Tensorboard",
+ "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
+ "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
+ "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
+ "Training Configuration": "Configuración de Entrenamiento",
+ "Training Error": "Error de Entrenamiento",
+ "Training stopped": "Entrenamiento detenido",
+ "Type name of the speaker": "Escriba el nombre del hablante",
+ "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
+ "Use LoRA": "Usar LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
+ "Use filelist": "Usar lista de archivos",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
+ "VITS Configuration": "Configuración de VITS",
+ "VQGAN Configuration": "Configuración de VQGAN",
+ "Validation Batch Size": "Tamaño del Lote de Validación",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
+ "WebUI Host": "Host de WebUI",
+ "WebUI Port": "Puerto de WebUI",
+ "Whisper Model": "Modelo Whisper",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
+ "latest": "más reciente",
+ "new": "nuevo",
+ "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
+ "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
+ "Text Normalization": "Normalización de Texto"
+}
diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json
new file mode 100644
index 0000000000000000000000000000000000000000..e7817eb0c559676e3d67ffa3488b177c25119e11
--- /dev/null
+++ b/fish_speech/i18n/locale/ja_JP.json
@@ -0,0 +1,123 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
+ "Accumulate Gradient Batches": "勾配バッチの累積",
+ "Add to Processing Area": "処理エリアに追加",
+ "Added path successfully!": "パスの追加に成功しました!",
+ "Advanced Config": "詳細設定",
+ "Base LLAMA Model": "基本LLAMAモデル",
+ "Batch Inference": "バッチ推論",
+ "Batch Size": "バッチサイズ",
+ "Changing with the Model Path": "モデルのパスに伴って変化する",
+ "Chinese": "中国語",
+ "Compile Model": "モデルのコンパイル",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
+ "Copy": "コピー",
+ "Data Preprocessing": "データ前処理",
+ "Data Preprocessing Path": "データ前処理パス",
+ "Data Source": "データソース",
+ "Decoder Model Config": "デコーダーモデルの構成",
+ "Decoder Model Path": "デコーダーモデルのパス",
+ "Disabled": "無効",
+ "Enable Reference Audio": "リファレンスオーディオを有効にする",
+ "English": "英語",
+ "Error Message": "エラーメッセージ",
+ "File Preprocessing": "文書前处理",
+ "Generate": "生成",
+ "Generated Audio": "生成されたオーディオ",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
+ "Infer interface is closed": "推論インターフェースが閉じられています",
+ "Inference Configuration": "推論設定",
+ "Inference Server Configuration": "推論サーバー設定",
+ "Inference Server Error": "推論サーバーエラー",
+ "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
+ "Initial Learning Rate": "初期学習率",
+ "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
+ "Input Text": "入力テキスト",
+ "Invalid path: {}": "無効なパス: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
+ "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
+ "Japanese": "日本語",
+ "LLAMA Configuration": "LLAMA設定",
+ "LLAMA Model Config": "LLAMAモデル設定",
+ "LLAMA Model Path": "LLAMAモデルパス",
+ "Labeling Device": "ラベリングデバイス",
+ "LoRA Model to be merged": "マージするLoRAモデル",
+ "Maximum Audio Duration": "最大オーディオの長さ",
+ "Maximum Length per Sample": "サンプルあたりの最大長",
+ "Maximum Training Steps": "最大トレーニングステップ数",
+ "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
+ "Merge": "マージ",
+ "Merge LoRA": "LoRAのマージ",
+ "Merge successfully": "マージに成功しました",
+ "Minimum Audio Duration": "最小オーディオの長さ",
+ "Model Output Path": "モデル出力パス",
+ "Model Size": "モデルサイズ",
+ "Move": "移動",
+ "Move files successfully": "ファイルの移動に成功しました",
+ "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
+ "No selected options": "選択されたオプションはありません",
+ "Number of Workers": "ワーカー数",
+ "Open Inference Server": "推論サーバーを開く",
+ "Open Labeler WebUI": "ラベラーWebUIを開く",
+ "Open Tensorboard": "Tensorboardを開く",
+ "Opened labeler in browser": "ブラウザでラベラーを開きました",
+ "Optional Label Language": "オプションのラベル言語",
+ "Optional online ver": "オプションのオンラインバージョン",
+ "Output Path": "出力パス",
+ "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
+ "Precision": "精度",
+ "Probability of applying Speaker Condition": "話者条件を適用する確率",
+ "Put your text here.": "ここにテキストを入力してください。",
+ "Reference Audio": "リファレンスオーディオ",
+ "Reference Text": "リファレンステキスト",
+ "Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "関連コードはBSD-3-Clauseライセンスの下でリリースされ、重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
+ "Remove Selected Data": "選択したデータを削除",
+ "Removed path successfully!": "パスの削除に成功しました!",
+ "Repetition Penalty": "反復ペナルティ",
+ "Save model every n steps": "nステップごとにモデルを保存",
+ "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
+ "Select VITS ckpt": "VITS チェックポイントを選択",
+ "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
+ "Select source file processing method": "ソースファイルの処理方法を選択",
+ "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
+ "Selected: {}": "選択済み: {}",
+ "Speaker": "話者",
+ "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
+ "Start Training": "トレーニング開始",
+ "Streaming Audio": "ストリーミングオーディオ",
+ "Streaming Generate": "ストリーミング合成",
+ "Tensorboard Host": "Tensorboardホスト",
+ "Tensorboard Log Path": "Tensorboardログパス",
+ "Tensorboard Port": "Tensorboardポート",
+ "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
+ "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
+ "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
+ "Training Configuration": "トレーニング設定",
+ "Training Error": "トレーニングエラー",
+ "Training stopped": "トレーニングが停止しました",
+ "Type name of the speaker": "話者の名前を入力",
+ "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
+ "Use LoRA": "LoRAを使用",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
+ "Use filelist": "ファイルリストを使用",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
+ "VITS Configuration": "VITS の構成",
+ "VQGAN Configuration": "VQGAN の構成",
+ "Validation Batch Size": "検証バッチサイズ",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
+ "WebUI Host": "WebUIホスト",
+ "WebUI Port": "WebUIポート",
+ "Whisper Model": "Whisperモデル",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
+ "latest": "最新",
+ "new": "新規",
+ "Realtime Transform Text": "リアルタイム変換テキスト",
+ "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
+ "Text Normalization": "テキスト正規化"
+
+}
diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json
new file mode 100644
index 0000000000000000000000000000000000000000..da81eef1cf154e96a92e34b44a3e3fa89f68386a
--- /dev/null
+++ b/fish_speech/i18n/locale/zh_CN.json
@@ -0,0 +1,122 @@
+{
+ "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
+ "Accumulate Gradient Batches": "梯度累积批次",
+ "Add to Processing Area": "加入处理区",
+ "Added path successfully!": "添加路径成功!",
+ "Advanced Config": "高级参数",
+ "Base LLAMA Model": "基础 LLAMA 模型",
+ "Batch Inference": "批量推理",
+ "Batch Size": "批次大小",
+ "Changing with the Model Path": "随模型路径变化",
+ "Chinese": "中文",
+ "Compile Model": "编译模型",
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
+ "Copy": "复制",
+ "Data Preprocessing": "数据预处理",
+ "Data Preprocessing Path": "数据预处理路径",
+ "Data Source": "数据源",
+ "Decoder Model Config": "解码器模型配置",
+ "Decoder Model Path": "解码器模型路径",
+ "Disabled": "禁用",
+ "Enable Reference Audio": "启用参考音频",
+ "English": "英文",
+ "Error Message": "错误信息",
+ "File Preprocessing": "文件预处理",
+ "Generate": "生成",
+ "Generated Audio": "音频",
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
+ "Infer interface is closed": "推理界面已关闭",
+ "Inference Configuration": "推理配置",
+ "Inference Server Configuration": "推理服务器配置",
+ "Inference Server Error": "推理服务器错误",
+ "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
+ "Initial Learning Rate": "初始学习率",
+ "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
+ "Input Text": "输入文本",
+ "Invalid path: {}": "无效路径: {}",
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
+ "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
+ "Japanese": "日文",
+ "LLAMA Configuration": "LLAMA 配置",
+ "LLAMA Model Config": "LLAMA 模型配置",
+ "LLAMA Model Path": "LLAMA 模型路径",
+ "Labeling Device": "标注加速设备",
+ "LoRA Model to be merged": "要合并的 LoRA 模型",
+ "Maximum Audio Duration": "最大音频时长",
+ "Maximum Length per Sample": "每个样本的最大长度",
+ "Maximum Training Steps": "最大训练步数",
+ "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
+ "Merge": "合并",
+ "Merge LoRA": "合并 LoRA",
+ "Merge successfully": "合并成功",
+ "Minimum Audio Duration": "最小音频时长",
+ "Model Output Path": "模型输出路径",
+ "Model Size": "模型规模",
+ "Move": "移动",
+ "Move files successfully": "移动文件成功",
+ "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
+ "No selected options": "没有选择的选项",
+ "Number of Workers": "数据加载进程数",
+ "Open Inference Server": "打开推理服务器",
+ "Open Labeler WebUI": "打开标注工具",
+ "Open Tensorboard": "打开 Tensorboard",
+ "Opened labeler in browser": "在浏览器中打开标注工具",
+ "Optional Label Language": "[可选] 标注语言",
+ "Optional online ver": "[可选] 使用在线版",
+ "Output Path": "输出路径",
+ "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
+ "Precision": "精度",
+ "Probability of applying Speaker Condition": "应用说话人条件的概率",
+ "Put your text here.": "在此处输入文本.",
+ "Reference Audio": "参考音频",
+ "Reference Text": "参考文本",
+ "Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "相关代码使用 BSD-3-Clause 许可证发布,权重使用 CC BY-NC-SA 4.0 许可证发布.",
+ "Remove Selected Data": "移除选中数据",
+ "Removed path successfully!": "移除路径成功!",
+ "Repetition Penalty": "重复惩罚",
+ "Save model every n steps": "每 n 步保存模型",
+ "Select LLAMA ckpt": "选择 LLAMA 检查点",
+ "Select VITS ckpt": "选择 VITS 检查点",
+ "Select VQGAN ckpt": "选择 VQGAN 检查点",
+ "Select source file processing method": "选择源文件处理方法",
+ "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
+ "Selected: {}": "已选择: {}",
+ "Speaker": "说话人",
+ "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
+ "Start Training": "开始训练",
+ "Streaming Audio": "流式音频",
+ "Streaming Generate": "流式合成",
+ "Tensorboard Host": "Tensorboard 监听地址",
+ "Tensorboard Log Path": "Tensorboard 日志路径",
+ "Tensorboard Port": "Tensorboard 端口",
+ "Tensorboard interface is closed": "Tensorboard 界面已关闭",
+ "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
+ "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
+ "Training Configuration": "训练配置",
+ "Training Error": "训练错误",
+ "Training stopped": "训练已停止",
+ "Type name of the speaker": "输入说话人的名称",
+ "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
+ "Use LoRA": "使用 LoRA",
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
+ "Use filelist": "使用文件列表",
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
+ "VITS Configuration": "VITS 配置",
+ "VQGAN Configuration": "VQGAN 配置",
+ "Validation Batch Size": "验证批次大小",
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
+ "WebUI Host": "WebUI 监听地址",
+ "WebUI Port": "WebUI 端口",
+ "Whisper Model": "Whisper 模型",
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
+ "latest": "最近的检查点",
+ "new": "创建新的检查点",
+ "Realtime Transform Text": "实时规范化文本",
+ "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
+ "Text Normalization": "文本规范化"
+}
diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0194c0f1a31dc95309c64626d13f04751a44ba1
--- /dev/null
+++ b/fish_speech/i18n/scan.py
@@ -0,0 +1,122 @@
+import ast
+import glob
+import json
+from collections import OrderedDict
+from pathlib import Path
+
+from loguru import logger
+
+from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
+
+
+def extract_i18n_strings(node):
+ i18n_strings = []
+
+ if (
+ isinstance(node, ast.Call)
+ and isinstance(node.func, ast.Name)
+ and node.func.id == "i18n"
+ ):
+ for arg in node.args:
+ if isinstance(arg, ast.Str):
+ i18n_strings.append(arg.s)
+
+ for child_node in ast.iter_child_nodes(node):
+ i18n_strings.extend(extract_i18n_strings(child_node))
+
+ return i18n_strings
+
+
+# scan the directory for all .py files (recursively)
+# for each file, parse the code into an AST
+# for each AST, extract the i18n strings
+
+strings = []
+folders = ["fish_speech", "tools"]
+# for filename in glob.iglob("**/*.py", recursive=True):
+for folder in folders:
+ for f in Path(folder).rglob("*.py"):
+ code = f.read_text(encoding="utf-8")
+ if "i18n(" in code:
+ tree = ast.parse(code)
+ i18n_strings = extract_i18n_strings(tree)
+ logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
+ strings.extend(i18n_strings)
+
+code_keys = set(strings)
+logger.info(f"Total unique: {len(code_keys)}")
+
+
+standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
+with open(standard_file, "r", encoding="utf-8") as f:
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
+standard_keys = set(standard_data.keys())
+
+# Define the standard file name
+unused_keys = standard_keys - code_keys
+logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
+for unused_key in unused_keys:
+ logger.info(f"\t{unused_key}")
+
+missing_keys = code_keys - standard_keys
+logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
+for missing_key in missing_keys:
+ logger.info(f"\t{missing_key}")
+
+code_keys_dict = OrderedDict()
+for s in strings:
+ code_keys_dict[s] = s
+
+# write back
+with open(standard_file, "w", encoding="utf-8") as f:
+ json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
+ f.write("\n")
+
+logger.info(f"Updated {standard_file}")
+
+
+# Define the standard file name
+standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
+
+# Find all JSON files in the directory
+dir_path = I18N_FILE_PATH
+languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
+
+# Load the standard file
+with open(standard_file, "r", encoding="utf-8") as f:
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
+
+# Loop through each language file
+for lang_file in languages:
+ # Load the language file
+ with open(lang_file, "r", encoding="utf-8") as f:
+ lang_data = json.load(f, object_pairs_hook=OrderedDict)
+
+ # Find the difference between the language file and the standard file
+ diff = set(standard_data.keys()) - set(lang_data.keys())
+
+ miss = set(lang_data.keys()) - set(standard_data.keys())
+
+ # Add any missing keys to the language file
+ for key in diff:
+ lang_data[key] = "#!" + key
+ logger.info(f"Added missing key: {key} to {lang_file}")
+
+ # Del any extra keys to the language file
+ for key in miss:
+ del lang_data[key]
+ logger.info(f"Del extra key: {key} from {lang_file}")
+
+ # Sort the keys of the language file to match the order of the standard file
+ lang_data = OrderedDict(
+ sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
+ )
+
+ # Save the updated language file
+ with open(lang_file, "w", encoding="utf-8") as f:
+ json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
+ f.write("\n")
+
+ logger.info(f"Updated {lang_file}")
+
+logger.info("Done")
diff --git a/fish_speech/models/text2semantic/__init__.py b/fish_speech/models/text2semantic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..df970400f8a073be4c4166a697245fabdf6b09b0
--- /dev/null
+++ b/fish_speech/models/text2semantic/lit_module.py
@@ -0,0 +1,202 @@
+from typing import Any, Optional
+
+import lightning as L
+import torch
+import torch.nn.functional as F
+from lightning.pytorch.utilities.types import OptimizerLRScheduler
+
+import fish_speech.utils as utils
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.models.text2semantic.llama import NaiveTransformer
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+class TextToSemantic(L.LightningModule):
+ def __init__(
+ self,
+ model: NaiveTransformer,
+ optimizer: Any,
+ lr_scheduler: Any,
+ ):
+ super().__init__()
+
+ self.model = model
+ self.optimizer_builder = optimizer
+ self.lr_scheduler_builder = lr_scheduler
+
+ def forward(self, x):
+ return self.model(x)
+
+ def on_save_checkpoint(self, checkpoint):
+ # Save only LoRA parameters
+ state_dict = checkpoint["state_dict"]
+ use_lora = any("lora" in name for name in state_dict.keys())
+ if not use_lora:
+ return
+
+ for name in list(state_dict.keys()):
+ if "lora" not in name:
+ state_dict.pop(name)
+
+ def configure_optimizers(self) -> OptimizerLRScheduler:
+ # Get weight decay parameters
+ weight_decay_parameters, other_parameters = [], []
+ for name, param in self.named_parameters():
+ if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
+ other_parameters.append(param)
+ else:
+ weight_decay_parameters.append(param)
+
+ optimizer = self.optimizer_builder(
+ [
+ {"params": weight_decay_parameters},
+ {"params": other_parameters, "weight_decay": 0.0},
+ ]
+ )
+
+ # Print the parameters and their weight decay
+ for i in optimizer.param_groups:
+ log.info(
+ f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
+ )
+
+ lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": lr_scheduler,
+ "interval": "step",
+ },
+ }
+
+ # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
+ def get_batch_logps(
+ self,
+ logits: torch.FloatTensor,
+ labels: torch.LongTensor,
+ average_log_prob: bool = False,
+ ) -> torch.FloatTensor:
+ """Compute the log probabilities of the given labels under the given logits.
+
+ Args:
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+
+ Returns:
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+ """
+ assert logits.shape[:-1] == labels.shape
+
+ labels = labels.clone()
+ loss_mask = labels != -100
+
+ # dummy token; we'll ignore the losses on these tokens later
+ labels[labels == -100] = 0
+
+ per_token_logps = torch.gather(
+ logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
+ ).squeeze(-1)
+
+ if average_log_prob:
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+ else:
+ return (per_token_logps * loss_mask).sum(-1)
+
+ def _step(self, batch, batch_idx, stage: str):
+ is_train = stage == "train"
+
+ if is_train:
+ # Key part to make lora work
+ # Otherwise the parameters are merged, which lead to incorrect gradients
+ self.model.train()
+
+ # Do positive and negative samples in the same batch to speed up training
+ labels = batch["labels"]
+ outputs = self.model(
+ inp=batch["inputs"],
+ key_padding_mask=batch["attention_masks"],
+ )
+ token_logits = outputs.token_logits
+ codebook_logits = outputs.codebook_logits
+
+ # Generate labels
+ base_loss = F.cross_entropy(
+ token_logits.view(-1, token_logits.size(-1)),
+ labels[:, 0].reshape(-1),
+ ignore_index=-100,
+ )
+
+ codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
+ semantic_loss = F.cross_entropy(
+ codebook_logits.view(-1, codebook_logits.size(-1)),
+ codebook_labels.reshape(-1),
+ ignore_index=-100,
+ )
+
+ loss = base_loss + semantic_loss
+
+ self.log(
+ f"{stage}/loss",
+ loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=True,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ self.log(
+ f"{stage}/base_loss",
+ base_loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ self.log(
+ f"{stage}/semantic_loss",
+ semantic_loss,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=False,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ # Top-5 accuracy
+ accuracy = self.get_accuracy(codebook_logits, codebook_labels)
+ self.log(
+ f"{stage}/top_5_accuracy",
+ accuracy,
+ on_step=is_train,
+ on_epoch=not is_train,
+ prog_bar=True,
+ logger=True,
+ sync_dist=not is_train,
+ )
+
+ return loss
+
+ def get_accuracy(self, logits, labels):
+ mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
+ if mask.sum() == 0:
+ return torch.tensor(0.0, device=logits.device)
+
+ _, indices = logits.topk(5, dim=-1)
+ correct = indices.eq(labels.unsqueeze(-1))
+ correct[~mask] = 0
+ correct = correct.sum()
+ accuracy = correct / mask.sum()
+
+ return accuracy
+
+ def training_step(self, batch, batch_idx):
+ return self._step(batch, batch_idx, "train")
+
+ def validation_step(self, batch, batch_idx):
+ return self._step(batch, batch_idx, "val")
diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b5cd276c0c382a3334c45ca9bf74ea1c8a142d5
--- /dev/null
+++ b/fish_speech/models/text2semantic/llama.py
@@ -0,0 +1,752 @@
+import json
+import math
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from loguru import logger
+from torch import Tensor
+from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
+from torch.utils.checkpoint import checkpoint
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import SEMANTIC_TOKEN
+from fish_speech.utils import RankedLogger
+
+from .lora import LoraConfig, setup_lora
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+
+@dataclass
+class BaseModelArgs:
+ model_type: str = "base"
+
+ vocab_size: int = 32000
+ n_layer: int = 32
+ n_head: int = 32
+ dim: int = 4096
+ intermediate_size: int = None
+ n_local_heads: int = -1
+ head_dim: int = 64
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ max_seq_len: int = 2048
+ dropout: float = 0.0
+ tie_word_embeddings: bool = True
+ attention_qkv_bias: bool = False
+
+ # Codebook configs
+ codebook_size: int = 160
+ num_codebooks: int = 4
+
+ # Gradient checkpointing
+ use_gradient_checkpointing: bool = True
+
+ # Initialize the model
+ initializer_range: float = 0.02
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ self.head_dim = self.dim // self.n_head
+
+ @staticmethod
+ def from_pretrained(path: str):
+ path = Path(path)
+
+ if path.is_dir():
+ path = path / "config.json"
+
+ with open(path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ match data["model_type"]:
+ case "naive":
+ cls = NaiveModelArgs
+ case "dual_ar":
+ cls = DualARModelArgs
+ case _:
+ raise ValueError(f"Unknown model type: {data['model_type']}")
+
+ return cls(**data)
+
+ def save(self, path: str):
+ with open(path, "w") as f:
+ json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
+
+
+@dataclass
+class NaiveModelArgs(BaseModelArgs):
+ model_type: str = "naive"
+
+
+@dataclass
+class DualARModelArgs(BaseModelArgs):
+ model_type: str = "dual_ar"
+ n_fast_layer: int = 4
+
+
+class KVCache(nn.Module):
+ def __init__(
+ self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
+ ):
+ super().__init__()
+ cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+ def update(self, input_pos, k_val, v_val):
+ # input_pos: [S], k_val: [B, H, S, D]
+ assert input_pos.shape[0] == k_val.shape[2]
+
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:, :, input_pos] = k_val
+ v_out[:, :, input_pos] = v_val
+
+ return k_out, v_out
+
+
+@dataclass
+class TransformerForwardResult:
+ token_logits: Tensor
+ codebook_logits: Tensor
+
+
+@dataclass
+class BaseTransformerForwardResult:
+ logits: Tensor
+ hidden_states: Tensor
+
+
+class BaseTransformer(nn.Module):
+ def __init__(
+ self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.tokenizer = tokenizer
+
+ self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
+
+ # Slow transformer
+ self.embeddings = nn.Embedding(
+ config.vocab_size,
+ config.dim,
+ )
+ self.codebook_embeddings = nn.Embedding(
+ config.codebook_size * config.num_codebooks,
+ config.dim,
+ )
+ self.layers = nn.ModuleList(
+ TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
+ )
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+ if self.config.tie_word_embeddings is False:
+ self.output = nn.Linear(
+ config.dim,
+ config.vocab_size,
+ bias=False,
+ )
+
+ self.register_buffer(
+ "freqs_cis",
+ precompute_freqs_cis(
+ config.max_seq_len,
+ config.dim // config.n_head,
+ config.rope_base,
+ ),
+ persistent=False,
+ )
+ self.register_buffer(
+ "causal_mask",
+ torch.tril(
+ torch.ones(
+ config.max_seq_len,
+ config.max_seq_len,
+ dtype=torch.bool,
+ )
+ ),
+ persistent=False,
+ )
+
+ # For kv cache
+ self.max_batch_size = -1
+ self.max_seq_len = -1
+
+ if init_weights:
+ self.apply(self._init_weights)
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+ ):
+ if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
+ return
+
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_len = find_multiple(max_seq_len, 8)
+ self.max_seq_len = max_seq_len
+ self.max_batch_size = max_batch_size
+
+ for b in self.layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ max_seq_len,
+ self.config.n_local_heads,
+ head_dim,
+ dtype=dtype,
+ )
+
+ def embed(self, x: Tensor) -> Tensor:
+ vocab_embeds = [self.embeddings(x[:, 0])]
+ for i in range(self.config.num_codebooks):
+ emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
+ emb[x[:, 0] != self.semantic_token_id] = 0
+ vocab_embeds.append(emb)
+
+ x = torch.stack(vocab_embeds, dim=3)
+ x = x.sum(dim=3)
+
+ return x
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> BaseTransformerForwardResult:
+ seq_len = inp.size(2)
+
+ # Here we want to merge the embeddings of the codebooks
+ x = self.embed(inp)
+
+ freqs_cis = self.freqs_cis[:seq_len]
+
+ # Not that the causal mask here follows the definition of scaled_dot_product_attention
+ # That is, FALSE means masked out
+ # To maintain consistency, key_padding_mask use TRUE to mask out
+ mask = None
+ if key_padding_mask is not None:
+ mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
+ mask = mask & key_padding_mask[:, None, None, :].logical_not()
+
+ for layer in self.layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
+ else:
+ x = layer(x, freqs_cis, mask)
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+
+ if self.config.tie_word_embeddings:
+ token_logits = F.linear(slow_out, self.embeddings.weight)
+ else:
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def forward_generate(
+ self,
+ x: Tensor,
+ input_pos: Optional[Tensor] = None,
+ return_all: bool = False,
+ ) -> BaseTransformerForwardResult:
+ # This is used for generation, optimized for torch compile
+ assert (
+ self.max_seq_len != -1 and self.max_batch_size != -1
+ ), "Please call setup_caches before forward_generate"
+
+ x = self.embed(x)
+
+ mask = self.causal_mask[
+ None, None, input_pos, : self.max_seq_len
+ ] # (B, N, Q, K)
+ freqs_cis = self.freqs_cis[input_pos]
+
+ for layer in self.layers:
+ x = layer(x, freqs_cis, mask, input_pos=input_pos)
+
+ # If prefill, we only calculate the logits of last token
+ if x.size(1) > 1 and not return_all:
+ x = x[:, -1:]
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+
+ if self.config.tie_word_embeddings:
+ token_logits = F.linear(slow_out, self.embeddings.weight)
+ else:
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ @staticmethod
+ def from_pretrained(
+ path: str,
+ load_weights: bool = False,
+ max_length: int | None = None,
+ lora_config: LoraConfig | None = None,
+ rope_base: int | None = None,
+ ) -> "BaseTransformer":
+ config = BaseModelArgs.from_pretrained(str(path))
+ if max_length is not None:
+ config.max_seq_len = max_length
+ log.info(f"Override max_seq_len to {max_length}")
+
+ if rope_base is not None:
+ config.rope_base = rope_base
+ log.info(f"Override rope_base to {rope_base}")
+
+ match config.model_type:
+ case "naive":
+ model_cls = NaiveTransformer
+ case "dual_ar":
+ model_cls = DualARTransformer
+ case _:
+ raise ValueError(f"Unknown model type: {config.model_type}")
+
+ tokenizer = AutoTokenizer.from_pretrained(str(path))
+ log.info(f"Loading model from {path}, config: {config}")
+ model = model_cls(config, tokenizer=tokenizer)
+
+ if lora_config is not None:
+ setup_lora(model, lora_config)
+ log.info(f"LoRA setup: {lora_config}")
+
+ if load_weights is False:
+ log.info("Randomly initialized model")
+ else:
+
+ if "int8" in str(Path(path)):
+ logger.info("Using int8 weight-only quantization!")
+ from tools.llama.quantize import WeightOnlyInt8QuantHandler
+
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
+ model = simple_quantizer.convert_for_runtime()
+
+ if "int4" in str(Path(path)):
+ logger.info("Using int4 quantization!")
+ path_comps = path.name.split("-")
+ assert path_comps[-2].startswith("g")
+ groupsize = int(path_comps[-2][1:])
+ from tools.llama.quantize import WeightOnlyInt4QuantHandler
+
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+ model = simple_quantizer.convert_for_runtime()
+
+ weights = torch.load(
+ Path(path) / "model.pth", map_location="cpu", mmap=True
+ )
+ err = model.load_state_dict(weights, strict=False, assign=True)
+ log.info(f"Loaded weights with error: {err}")
+
+ return model
+
+ def save_pretrained(self, path: str, drop_lora: bool = False):
+ path = Path(path)
+ path.mkdir(parents=True, exist_ok=True)
+
+ self.config.save(path / "config.json")
+ state_dict = self.state_dict()
+
+ if drop_lora:
+ for key in list(state_dict.keys()):
+ if "lora" not in key:
+ continue
+
+ state_dict.pop(key)
+ log.info(f"Drop LoRA parameter: {key}")
+
+ torch.save(state_dict, path / "model.pth")
+ self.tokenizer.save_pretrained(path)
+
+
+class NaiveTransformer(BaseTransformer):
+ def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
+
+ self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.codebook_output = nn.Linear(
+ config.dim,
+ config.codebook_size * config.num_codebooks,
+ bias=False,
+ )
+
+ self.apply(self._init_weights)
+
+ def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
+ token_logits = result.logits
+ x = result.hidden_states
+
+ # Codebook
+ codebook_logits = self.codebook_output(self.codebook_norm(x))
+ codebook_logits = rearrange(
+ codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
+ )
+
+ return TransformerForwardResult(
+ token_logits=token_logits,
+ codebook_logits=codebook_logits,
+ )
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> TransformerForwardResult:
+ result = super().forward(
+ inp=inp,
+ key_padding_mask=key_padding_mask,
+ )
+ return self.decode(result)
+
+ def forward_generate(
+ self, x: Tensor, input_pos: Optional[Tensor] = None
+ ) -> TransformerForwardResult:
+ result = super().forward_generate(x, input_pos)
+ return self.decode(result)
+
+
+class DualARTransformer(BaseTransformer):
+ def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
+
+ # Fast transformer
+ self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
+
+ # The equivalent bs is so large that sdpa doesn't work
+ self.fast_layers = nn.ModuleList(
+ TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+ )
+ self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.fast_output = nn.Linear(
+ config.dim,
+ config.codebook_size,
+ bias=False,
+ )
+
+ self.apply(self._init_weights)
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+ ):
+ super().setup_caches(max_batch_size, max_seq_len, dtype)
+
+ head_dim = self.config.dim // self.config.n_head
+
+ # Fast transformer
+ # The max seq len here is the number of codebooks
+ for b in self.fast_layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ self.config.num_codebooks,
+ self.config.n_local_heads,
+ head_dim,
+ dtype=dtype,
+ )
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> TransformerForwardResult:
+ parent_result = super().forward(inp, key_padding_mask)
+ token_logits = parent_result.logits
+ x = parent_result.hidden_states
+
+ # Fast transformer
+ fast_seq_len = self.config.num_codebooks
+ fast_mask = self.causal_mask[
+ None, None, :fast_seq_len, :fast_seq_len
+ ] # (B, N, Q, K)
+ fast_freqs_cis = self.freqs_cis[:fast_seq_len]
+
+ # Drop the last token and rotate left
+ codebooks = inp[:, 1:-1, 1:]
+ codebooks = F.pad(codebooks, (0, 1), value=0)
+ codebook_embeddings = self.fast_embeddings(codebooks)
+ x = torch.cat([x[:, None], codebook_embeddings], dim=1)
+ b, s = x.size(0), x.size(2)
+ x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
+
+ # Remove padded part
+ codebooks = rearrange(codebooks, "b n s -> (b s) n")
+ codebook_mask = (codebooks == 0).all(dim=-1)
+
+ if torch.all(codebook_mask):
+ # If all codebooks are padded, we keep first 8 to make sure the model runs
+ codebook_mask[:8] = False
+
+ x_bs, x_len = x.size(0), x.size(1)
+ x = x[~codebook_mask]
+
+ for layer in self.fast_layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
+ else:
+ x = layer(x, fast_freqs_cis, fast_mask)
+
+ # unflatten the batch and num_codebooks
+ fast_out = self.fast_norm(x)
+ codebook_logits = self.fast_output(fast_out)
+
+ # Re-pad the codebook_logits
+ buffer = torch.zeros(
+ x_bs,
+ x_len,
+ codebook_logits.size(-1),
+ device=codebook_logits.device,
+ dtype=codebook_logits.dtype,
+ )
+ buffer[~codebook_mask] = codebook_logits
+ codebook_logits = buffer
+
+ assert codebook_logits.shape[1] == self.config.num_codebooks
+ codebook_logits = rearrange(
+ codebook_logits,
+ "(b s) n d -> b s n d",
+ b=b,
+ s=s,
+ n=self.config.num_codebooks,
+ )
+
+ return TransformerForwardResult(
+ token_logits=token_logits,
+ codebook_logits=codebook_logits,
+ )
+
+ def forward_generate_fast(
+ self, x: Tensor, input_pos: Optional[Tensor] = None
+ ) -> Tensor:
+ # Fast transformer
+ x = x.view(1, 1, -1)
+
+ fast_mask = self.causal_mask[
+ None, None, input_pos, : self.config.num_codebooks
+ ] # (B, N, Q, K)
+ fast_freqs_cis = self.freqs_cis[input_pos]
+
+ for layer in self.fast_layers:
+ x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
+
+ # unflatten the batch and num_codebooks
+ fast_out = self.fast_norm(x) # only take the last token
+ codebook_logits = self.fast_output(fast_out)
+
+ return codebook_logits
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
+ super().__init__()
+ self.attention = Attention(config, use_sdpa=use_sdpa)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+ self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+ def forward(
+ self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
+ ) -> Tensor:
+ h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+ out = h + self.feed_forward(self.ffn_norm(h))
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ self.wqkv = nn.Linear(
+ config.dim, total_head_dim, bias=config.attention_qkv_bias
+ )
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
+ self.kv_cache = None
+
+ self.dropout = config.dropout
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+ self.use_sdpa = use_sdpa
+ self._register_load_state_dict_pre_hook(self.load_hook)
+
+ def load_hook(self, state_dict, prefix, *args):
+ if prefix + "wq.weight" in state_dict:
+ wq = state_dict.pop(prefix + "wq.weight")
+ wk = state_dict.pop(prefix + "wk.weight")
+ wv = state_dict.pop(prefix + "wv.weight")
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+ def forward(
+ self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_local_heads * self.head_dim
+ q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+ q = apply_rotary_emb(q, freqs_cis)
+ k = apply_rotary_emb(k, freqs_cis)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+
+ if self.use_sdpa:
+ if mask is None:
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.dropout if self.training else 0.0,
+ is_causal=True,
+ # No third party attn_mask here to use flash_attention
+ )
+ else:
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+ else:
+ y = self.eq_scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+ return self.wo(y)
+
+ def eq_scaled_dot_product_attention(
+ self,
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ ) -> torch.Tensor:
+ # This is a standard scaled dot product attention
+ # It's low efficient, but it doesn't raise cuda error
+
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1))
+ attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+ return attn_weight @ value
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: BaseModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+ freqs = 1.0 / (
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+ )
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+ return cache.to(dtype=torch.bfloat16)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ],
+ -1,
+ )
+
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..647ca6fcccf038e17d2cf91a2874281dff3e0938
--- /dev/null
+++ b/fish_speech/models/text2semantic/lora.py
@@ -0,0 +1,92 @@
+from dataclasses import dataclass
+
+import loralib as lora
+
+
+@dataclass
+class LoraConfig:
+ r: int
+ lora_alpha: float
+ lora_dropout: float = 0.0
+
+
+def setup_lora(model, lora_config):
+ # Replace the embedding layer with a LoRA layer
+ model.embeddings = lora.Embedding(
+ num_embeddings=model.embeddings.num_embeddings,
+ embedding_dim=model.embeddings.embedding_dim,
+ padding_idx=model.embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ model.codebook_embeddings = lora.Embedding(
+ num_embeddings=model.codebook_embeddings.num_embeddings,
+ embedding_dim=model.codebook_embeddings.embedding_dim,
+ padding_idx=model.codebook_embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ # Replace output layer with a LoRA layer
+ linears = [(model, "output")]
+
+ # Replace all linear layers with LoRA layers
+ for layer in model.layers:
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+ linears.extend(
+ [
+ (layer.feed_forward, "w1"),
+ (layer.feed_forward, "w2"),
+ (layer.feed_forward, "w3"),
+ ]
+ )
+
+ if hasattr(model, "fast_layers"):
+ model.fast_embeddings = lora.Embedding(
+ num_embeddings=model.fast_embeddings.num_embeddings,
+ embedding_dim=model.fast_embeddings.embedding_dim,
+ padding_idx=model.fast_embeddings.padding_idx,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ )
+
+ # Dual-AR model
+ linears.append((model, "fast_output"))
+
+ for layer in model.fast_layers:
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+ linears.extend(
+ [
+ (layer.feed_forward, "w1"),
+ (layer.feed_forward, "w2"),
+ (layer.feed_forward, "w3"),
+ ]
+ )
+
+ for module, layer in linears:
+ updated_linear = lora.Linear(
+ in_features=getattr(module, layer).in_features,
+ out_features=getattr(module, layer).out_features,
+ bias=getattr(module, layer).bias,
+ r=lora_config.r,
+ lora_alpha=lora_config.lora_alpha,
+ lora_dropout=lora_config.lora_dropout,
+ )
+ setattr(module, layer, updated_linear)
+
+ # Mark only the LoRA layers as trainable
+ lora.mark_only_lora_as_trainable(model, bias="none")
+
+
+def get_merged_state_dict(model):
+ # This line will merge the state dict of the model and the LoRA parameters
+ model.eval()
+
+ # Then we need to remove the LoRA parameters from the state dict
+ state_dict = model.state_dict()
+ for name in list(state_dict.keys()):
+ if "lora" in name:
+ state_dict.pop(name)
+
+ return state_dict
diff --git a/fish_speech/models/vqgan/__init__.py b/fish_speech/models/vqgan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..401c6df468c7aa51be1ecaa71ac71513958ae055
--- /dev/null
+++ b/fish_speech/models/vqgan/__init__.py
@@ -0,0 +1,3 @@
+from .lit_module import VQGAN
+
+__all__ = ["VQGAN"]
diff --git a/fish_speech/models/vqgan/lit_module.py b/fish_speech/models/vqgan/lit_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd0733ba748ab69bb539eb6b596b36a365ac460f
--- /dev/null
+++ b/fish_speech/models/vqgan/lit_module.py
@@ -0,0 +1,442 @@
+import itertools
+import math
+from typing import Any, Callable
+
+import lightning as L
+import torch
+import torch.nn.functional as F
+import wandb
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+from matplotlib import pyplot as plt
+from torch import nn
+
+from fish_speech.models.vqgan.modules.discriminator import Discriminator
+from fish_speech.models.vqgan.modules.wavenet import WaveNet
+from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
+
+
+class VQGAN(L.LightningModule):
+ def __init__(
+ self,
+ optimizer: Callable,
+ lr_scheduler: Callable,
+ encoder: WaveNet,
+ quantizer: nn.Module,
+ decoder: WaveNet,
+ discriminator: Discriminator,
+ vocoder: nn.Module,
+ encode_mel_transform: nn.Module,
+ gt_mel_transform: nn.Module,
+ weight_adv: float = 1.0,
+ weight_vq: float = 1.0,
+ weight_mel: float = 1.0,
+ sampling_rate: int = 44100,
+ freeze_encoder: bool = False,
+ ):
+ super().__init__()
+
+ # Model parameters
+ self.optimizer_builder = optimizer
+ self.lr_scheduler_builder = lr_scheduler
+
+ # Modules
+ self.encoder = encoder
+ self.quantizer = quantizer
+ self.decoder = decoder
+ self.vocoder = vocoder
+ self.discriminator = discriminator
+ self.encode_mel_transform = encode_mel_transform
+ self.gt_mel_transform = gt_mel_transform
+
+ # A simple linear layer to project quality to condition channels
+ self.quality_projection = nn.Linear(1, 768)
+
+ # Freeze vocoder
+ for param in self.vocoder.parameters():
+ param.requires_grad = False
+
+ # Loss weights
+ self.weight_adv = weight_adv
+ self.weight_vq = weight_vq
+ self.weight_mel = weight_mel
+
+ # Other parameters
+ self.sampling_rate = sampling_rate
+
+ # Disable strict loading
+ self.strict_loading = False
+
+ # If encoder is frozen
+ if freeze_encoder:
+ for param in self.encoder.parameters():
+ param.requires_grad = False
+
+ for param in self.quantizer.parameters():
+ param.requires_grad = False
+
+ self.automatic_optimization = False
+
+ def on_save_checkpoint(self, checkpoint):
+ # Do not save vocoder
+ state_dict = checkpoint["state_dict"]
+ for name in list(state_dict.keys()):
+ if "vocoder" in name:
+ state_dict.pop(name)
+
+ def configure_optimizers(self):
+ optimizer_generator = self.optimizer_builder(
+ itertools.chain(
+ self.encoder.parameters(),
+ self.quantizer.parameters(),
+ self.decoder.parameters(),
+ self.quality_projection.parameters(),
+ )
+ )
+ optimizer_discriminator = self.optimizer_builder(
+ self.discriminator.parameters()
+ )
+
+ lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
+ lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
+
+ return (
+ {
+ "optimizer": optimizer_generator,
+ "lr_scheduler": {
+ "scheduler": lr_scheduler_generator,
+ "interval": "step",
+ "name": "optimizer/generator",
+ },
+ },
+ {
+ "optimizer": optimizer_discriminator,
+ "lr_scheduler": {
+ "scheduler": lr_scheduler_discriminator,
+ "interval": "step",
+ "name": "optimizer/discriminator",
+ },
+ },
+ )
+
+ def training_step(self, batch, batch_idx):
+ optim_g, optim_d = self.optimizers()
+
+ audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+ audios = audios.float()
+ audios = audios[:, None, :]
+
+ with torch.no_grad():
+ encoded_mels = self.encode_mel_transform(audios)
+ gt_mels = self.gt_mel_transform(audios)
+ quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
+ quality = quality.unsqueeze(-1)
+
+ mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
+ mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+ gt_mels = gt_mels * mel_masks_float_conv
+ encoded_mels = encoded_mels * mel_masks_float_conv
+
+ # Encode
+ encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
+
+ # Quantize
+ vq_result = self.quantizer(encoded_features)
+ loss_vq = getattr("vq_result", "loss", 0.0)
+ vq_recon_features = vq_result.z * mel_masks_float_conv
+ vq_recon_features = (
+ vq_recon_features + self.quality_projection(quality)[:, :, None]
+ )
+
+ # VQ Decode
+ gen_mel = (
+ self.decoder(
+ torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+ condition=vq_recon_features,
+ )
+ * mel_masks_float_conv
+ )
+
+ # Discriminator
+ real_logits = self.discriminator(gt_mels)
+ fake_logits = self.discriminator(gen_mel.detach())
+ d_mask = F.interpolate(
+ mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
+ )
+
+ loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
+ loss_fake = avg_with_mask(fake_logits**2, d_mask)
+
+ loss_d = loss_real + loss_fake
+
+ self.log(
+ "train/discriminator/loss",
+ loss_d,
+ on_step=True,
+ on_epoch=False,
+ prog_bar=True,
+ logger=True,
+ )
+
+ # Discriminator backward
+ optim_d.zero_grad()
+ self.manual_backward(loss_d)
+ self.clip_gradients(
+ optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+ )
+ optim_d.step()
+
+ # Mel Loss, applying l1, using a weighted sum
+ mel_distance = (
+ gen_mel - gt_mels
+ ).abs() # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
+ loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
+ loss_mel_mid_freq = avg_with_mask(
+ mel_distance[:, 40:70, :], mel_masks_float_conv
+ )
+ loss_mel_high_freq = avg_with_mask(
+ mel_distance[:, 70:, :], mel_masks_float_conv
+ )
+ loss_mel = (
+ loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
+ )
+
+ # Adversarial Loss
+ fake_logits = self.discriminator(gen_mel)
+ loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)
+
+ # Total loss
+ loss = (
+ self.weight_vq * loss_vq
+ + self.weight_mel * loss_mel
+ + self.weight_adv * loss_adv
+ )
+
+ # Log losses
+ self.log(
+ "train/generator/loss",
+ loss,
+ on_step=True,
+ on_epoch=False,
+ prog_bar=True,
+ logger=True,
+ )
+ self.log(
+ "train/generator/loss_vq",
+ loss_vq,
+ on_step=True,
+ on_epoch=False,
+ prog_bar=False,
+ logger=True,
+ )
+ self.log(
+ "train/generator/loss_mel",
+ loss_mel,
+ on_step=True,
+ on_epoch=False,
+ prog_bar=False,
+ logger=True,
+ )
+ self.log(
+ "train/generator/loss_adv",
+ loss_adv,
+ on_step=True,
+ on_epoch=False,
+ prog_bar=False,
+ logger=True,
+ )
+
+ # Generator backward
+ optim_g.zero_grad()
+ self.manual_backward(loss)
+ self.clip_gradients(
+ optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+ )
+ optim_g.step()
+
+ scheduler_g, scheduler_d = self.lr_schedulers()
+ scheduler_g.step()
+ scheduler_d.step()
+
+ def validation_step(self, batch: Any, batch_idx: int):
+ audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+ audios = audios.float()
+ audios = audios[:, None, :]
+
+ encoded_mels = self.encode_mel_transform(audios)
+ gt_mels = self.gt_mel_transform(audios)
+
+ mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
+ mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+ gt_mels = gt_mels * mel_masks_float_conv
+ encoded_mels = encoded_mels * mel_masks_float_conv
+
+ # Encode
+ encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
+
+ # Quantize
+ vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
+ vq_recon_features = (
+ vq_recon_features
+ + self.quality_projection(
+ torch.ones(
+ vq_recon_features.shape[0], 1, device=vq_recon_features.device
+ )
+ * 2
+ )[:, :, None]
+ )
+
+ # VQ Decode
+ gen_aux_mels = (
+ self.decoder(
+ torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+ condition=vq_recon_features,
+ )
+ * mel_masks_float_conv
+ )
+ loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
+
+ self.log(
+ "val/loss_mel",
+ loss_mel,
+ on_step=False,
+ on_epoch=True,
+ prog_bar=False,
+ logger=True,
+ sync_dist=True,
+ )
+
+ recon_audios = self.vocoder(gt_mels)
+ gen_aux_audios = self.vocoder(gen_aux_mels)
+
+ # only log the first batch
+ if batch_idx != 0:
+ return
+
+ for idx, (
+ gt_mel,
+ gen_aux_mel,
+ audio,
+ gen_aux_audio,
+ recon_audio,
+ audio_len,
+ ) in enumerate(
+ zip(
+ gt_mels,
+ gen_aux_mels,
+ audios.cpu().float(),
+ gen_aux_audios.cpu().float(),
+ recon_audios.cpu().float(),
+ audio_lengths,
+ )
+ ):
+ if idx > 4:
+ break
+
+ mel_len = audio_len // self.gt_mel_transform.hop_length
+
+ image_mels = plot_mel(
+ [
+ gt_mel[:, :mel_len],
+ gen_aux_mel[:, :mel_len],
+ ],
+ [
+ "Ground-Truth",
+ "Auxiliary",
+ ],
+ )
+
+ if isinstance(self.logger, WandbLogger):
+ self.logger.experiment.log(
+ {
+ "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
+ "wavs": [
+ wandb.Audio(
+ audio[0, :audio_len],
+ sample_rate=self.sampling_rate,
+ caption="gt",
+ ),
+ wandb.Audio(
+ gen_aux_audio[0, :audio_len],
+ sample_rate=self.sampling_rate,
+ caption="aux",
+ ),
+ wandb.Audio(
+ recon_audio[0, :audio_len],
+ sample_rate=self.sampling_rate,
+ caption="recon",
+ ),
+ ],
+ },
+ )
+
+ if isinstance(self.logger, TensorBoardLogger):
+ self.logger.experiment.add_figure(
+ f"sample-{idx}/mels",
+ image_mels,
+ global_step=self.global_step,
+ )
+ self.logger.experiment.add_audio(
+ f"sample-{idx}/wavs/gt",
+ audio[0, :audio_len],
+ self.global_step,
+ sample_rate=self.sampling_rate,
+ )
+ self.logger.experiment.add_audio(
+ f"sample-{idx}/wavs/gen",
+ gen_aux_audio[0, :audio_len],
+ self.global_step,
+ sample_rate=self.sampling_rate,
+ )
+ self.logger.experiment.add_audio(
+ f"sample-{idx}/wavs/recon",
+ recon_audio[0, :audio_len],
+ self.global_step,
+ sample_rate=self.sampling_rate,
+ )
+
+ plt.close(image_mels)
+
+ def encode(self, audios, audio_lengths):
+ audios = audios.float()
+
+ mels = self.encode_mel_transform(audios)
+ mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
+ mel_masks = sequence_mask(mel_lengths, mels.shape[2])
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+ mels = mels * mel_masks_float_conv
+
+ # Encode
+ encoded_features = self.encoder(mels) * mel_masks_float_conv
+ feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+
+ return self.quantizer.encode(encoded_features), feature_lengths
+
+ def decode(self, indices, feature_lengths, return_audios=False):
+ factor = math.prod(self.quantizer.downsample_factor)
+ mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+
+ z = self.quantizer.decode(indices) * mel_masks_float_conv
+ z = (
+ z
+ + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
+ :, :, None
+ ]
+ )
+
+ gen_mel = (
+ self.decoder(
+ torch.randn_like(z) * mel_masks_float_conv,
+ condition=z,
+ )
+ * mel_masks_float_conv
+ )
+
+ if return_audios:
+ return self.vocoder(gen_mel)
+
+ return gen_mel
diff --git a/fish_speech/models/vqgan/modules/discriminator.py b/fish_speech/models/vqgan/modules/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..69c7df41033f2cde22583468731f56b49eb594b7
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/discriminator.py
@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+from torch.nn.utils.parametrizations import weight_norm
+
+
+class Discriminator(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ blocks = []
+ convs = [
+ (1, 64, (3, 9), 1, (1, 4)),
+ (64, 128, (3, 9), (1, 2), (1, 4)),
+ (128, 256, (3, 9), (1, 2), (1, 4)),
+ (256, 512, (3, 9), (1, 2), (1, 4)),
+ (512, 1024, (3, 3), 1, (1, 1)),
+ (1024, 1, (3, 3), 1, (1, 1)),
+ ]
+
+ for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
+ convs
+ ):
+ blocks.append(
+ weight_norm(
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
+ )
+ )
+
+ if idx != len(convs) - 1:
+ blocks.append(nn.SiLU(inplace=True))
+
+ self.blocks = nn.Sequential(*blocks)
+
+ def forward(self, x):
+ return self.blocks(x[:, None])[:, 0]
+
+
+if __name__ == "__main__":
+ model = Discriminator()
+ print(sum(p.numel() for p in model.parameters()) / 1_000_000)
+ x = torch.randn(1, 128, 1024)
+ y = model(x)
+ print(y.shape)
+ print(y)
diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b34958f309588b9c0911a367441042d8f8b47b6
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/firefly.py
@@ -0,0 +1,625 @@
+# A inference only version of the FireflyGAN model
+
+import math
+from functools import partial
+from math import prod
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Conv1d
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+from torch.utils.checkpoint import checkpoint
+
+from fish_speech.models.vqgan.utils import sequence_mask
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return (kernel_size * dilation - dilation) // 2
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super().__init__()
+
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.silu(x)
+ xt = c1(xt)
+ xt = F.silu(xt)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_parametrizations(self):
+ for conv in self.convs1:
+ remove_parametrizations(conv, tensor_name="weight")
+ for conv in self.convs2:
+ remove_parametrizations(conv, tensor_name="weight")
+
+
+class ParralelBlock(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ kernel_sizes: tuple[int] = (3, 7, 11),
+ dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ ):
+ super().__init__()
+
+ assert len(kernel_sizes) == len(dilation_sizes)
+
+ self.blocks = nn.ModuleList()
+ for k, d in zip(kernel_sizes, dilation_sizes):
+ self.blocks.append(ResBlock1(channels, k, d))
+
+ def forward(self, x):
+ return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
+
+ def remove_parametrizations(self):
+ for block in self.blocks:
+ block.remove_parametrizations()
+
+
+class HiFiGANGenerator(nn.Module):
+ def __init__(
+ self,
+ *,
+ hop_length: int = 512,
+ upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
+ upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
+ resblock_kernel_sizes: tuple[int] = (3, 7, 11),
+ resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ num_mels: int = 128,
+ upsample_initial_channel: int = 512,
+ use_template: bool = True,
+ pre_conv_kernel_size: int = 7,
+ post_conv_kernel_size: int = 7,
+ post_activation: Callable = partial(nn.SiLU, inplace=True),
+ ):
+ super().__init__()
+
+ assert (
+ prod(upsample_rates) == hop_length
+ ), f"hop_length must be {prod(upsample_rates)}"
+
+ self.conv_pre = weight_norm(
+ nn.Conv1d(
+ num_mels,
+ upsample_initial_channel,
+ pre_conv_kernel_size,
+ 1,
+ padding=get_padding(pre_conv_kernel_size),
+ )
+ )
+
+ self.num_upsamples = len(upsample_rates)
+ self.num_kernels = len(resblock_kernel_sizes)
+
+ self.noise_convs = nn.ModuleList()
+ self.use_template = use_template
+ self.ups = nn.ModuleList()
+
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
+ self.ups.append(
+ weight_norm(
+ nn.ConvTranspose1d(
+ upsample_initial_channel // (2**i),
+ upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ if not use_template:
+ continue
+
+ if i + 1 < len(upsample_rates):
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
+ self.noise_convs.append(
+ Conv1d(
+ 1,
+ c_cur,
+ kernel_size=stride_f0 * 2,
+ stride=stride_f0,
+ padding=stride_f0 // 2,
+ )
+ )
+ else:
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ self.resblocks.append(
+ ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
+ )
+
+ self.activation_post = post_activation()
+ self.conv_post = weight_norm(
+ nn.Conv1d(
+ ch,
+ 1,
+ post_conv_kernel_size,
+ 1,
+ padding=get_padding(post_conv_kernel_size),
+ )
+ )
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x, template=None):
+ x = self.conv_pre(x)
+
+ for i in range(self.num_upsamples):
+ x = F.silu(x, inplace=True)
+ x = self.ups[i](x)
+
+ if self.use_template:
+ x = x + self.noise_convs[i](template)
+
+ if self.training and self.checkpointing:
+ x = checkpoint(
+ self.resblocks[i],
+ x,
+ use_reentrant=False,
+ )
+ else:
+ x = self.resblocks[i](x)
+
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_parametrizations(self):
+ for up in self.ups:
+ remove_parametrizations(up, tensor_name="weight")
+ for block in self.resblocks:
+ block.remove_parametrizations()
+ remove_parametrizations(self.conv_pre, tensor_name="weight")
+ remove_parametrizations(self.conv_post, tensor_name="weight")
+
+
+# DropPath copied from timm library
+def drop_path(
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """ # noqa: E501
+
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
+
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class LayerNorm(nn.Module):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """ # noqa: E501
+
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape,)
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(
+ x, self.normalized_shape, self.weight, self.bias, self.eps
+ )
+ elif self.data_format == "channels_first":
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None] * x + self.bias[:, None]
+ return x
+
+
+# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
+class ConvNeXtBlock(nn.Module):
+ r"""ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+ kernel_size (int): Kernel size for depthwise conv. Default: 7.
+ dilation (int): Dilation for depthwise conv. Default: 1.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ dim: int,
+ drop_path: float = 0.0,
+ layer_scale_init_value: float = 1e-6,
+ mlp_ratio: float = 4.0,
+ kernel_size: int = 7,
+ dilation: int = 1,
+ ):
+ super().__init__()
+
+ self.dwconv = nn.Conv1d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=int(dilation * (kernel_size - 1) / 2),
+ groups=dim,
+ ) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, int(mlp_ratio * dim)
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x, apply_residual: bool = True):
+ input = x
+
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+
+ if self.gamma is not None:
+ x = self.gamma * x
+
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
+ x = self.drop_path(x)
+
+ if apply_residual:
+ x = input + x
+
+ return x
+
+
+class ConvNeXtEncoder(nn.Module):
+ def __init__(
+ self,
+ input_channels: int = 3,
+ depths: list[int] = [3, 3, 9, 3],
+ dims: list[int] = [96, 192, 384, 768],
+ drop_path_rate: float = 0.0,
+ layer_scale_init_value: float = 1e-6,
+ kernel_size: int = 7,
+ ):
+ super().__init__()
+ assert len(depths) == len(dims)
+
+ self.downsample_layers = nn.ModuleList()
+ stem = nn.Sequential(
+ nn.Conv1d(
+ input_channels,
+ dims[0],
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ padding_mode="zeros",
+ ),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+ )
+ self.downsample_layers.append(stem)
+
+ for i in range(len(depths) - 1):
+ mid_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
+ )
+ self.downsample_layers.append(mid_layer)
+
+ self.stages = nn.ModuleList()
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+
+ cur = 0
+ for i in range(len(depths)):
+ stage = nn.Sequential(
+ *[
+ ConvNeXtBlock(
+ dim=dims[i],
+ drop_path=dp_rates[cur + j],
+ layer_scale_init_value=layer_scale_init_value,
+ kernel_size=kernel_size,
+ )
+ for j in range(depths[i])
+ ]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ for i in range(len(self.downsample_layers)):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+
+ return self.norm(x)
+
+
+class FireflyArchitecture(nn.Module):
+ def __init__(
+ self,
+ backbone: nn.Module,
+ head: nn.Module,
+ quantizer: nn.Module,
+ spec_transform: nn.Module,
+ ):
+ super().__init__()
+
+ self.backbone = backbone
+ self.head = head
+ self.quantizer = quantizer
+ self.spec_transform = spec_transform
+
+ def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
+ if self.spec_transform is not None:
+ x = self.spec_transform(x)
+
+ x = self.backbone(x)
+ if mask is not None:
+ x = x * mask
+
+ if self.quantizer is not None:
+ vq_result = self.quantizer(x)
+ x = vq_result.z
+
+ if mask is not None:
+ x = x * mask
+
+ x = self.head(x, template=template)
+
+ if x.ndim == 2:
+ x = x[:, None, :]
+
+ if self.vq is not None:
+ return x, vq_result
+
+ return x
+
+ def encode(self, audios, audio_lengths):
+ audios = audios.float()
+
+ mels = self.spec_transform(audios)
+ mel_lengths = audio_lengths // self.spec_transform.hop_length
+ mel_masks = sequence_mask(mel_lengths, mels.shape[2])
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+ mels = mels * mel_masks_float_conv
+
+ # Encode
+ encoded_features = self.backbone(mels) * mel_masks_float_conv
+ feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+
+ return self.quantizer.encode(encoded_features), feature_lengths
+
+ def decode(self, indices, feature_lengths) -> torch.Tensor:
+ factor = math.prod(self.quantizer.downsample_factor)
+ mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+ mel_masks_float_conv = mel_masks[:, None, :].float()
+
+ audio_masks = sequence_mask(
+ feature_lengths * factor * self.spec_transform.hop_length,
+ indices.shape[2] * factor * self.spec_transform.hop_length,
+ )
+ audio_masks_float_conv = audio_masks[:, None, :].float()
+
+ z = self.quantizer.decode(indices) * mel_masks_float_conv
+ x = self.head(z) * audio_masks_float_conv
+
+ return x
+
+ def remove_parametrizations(self):
+ if hasattr(self.backbone, "remove_parametrizations"):
+ self.backbone.remove_parametrizations()
+
+ if hasattr(self.head, "remove_parametrizations"):
+ self.head.remove_parametrizations()
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+
+class FireflyBase(nn.Module):
+ def __init__(self, ckpt_path: str = None, pretrained: bool = True):
+ super().__init__()
+
+ self.backbone = ConvNeXtEncoder(
+ input_channels=128,
+ depths=[3, 3, 9, 3],
+ dims=[128, 256, 384, 512],
+ drop_path_rate=0.2,
+ kernel_size=7,
+ )
+
+ self.head = HiFiGANGenerator(
+ hop_length=512,
+ upsample_rates=[8, 8, 2, 2, 2],
+ upsample_kernel_sizes=[16, 16, 4, 4, 4],
+ resblock_kernel_sizes=[3, 7, 11],
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ num_mels=512,
+ upsample_initial_channel=512,
+ use_template=False,
+ pre_conv_kernel_size=13,
+ post_conv_kernel_size=13,
+ )
+
+ if ckpt_path is not None:
+ state_dict = torch.load(ckpt_path, map_location="cpu")
+ elif pretrained:
+ state_dict = torch.hub.load_state_dict_from_url(
+ "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
+ map_location="cpu",
+ model_dir="checkpoints",
+ )
+
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator." in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
+ self.load_state_dict(state_dict, strict=True)
+ self.head.remove_parametrizations()
+
+ @torch.no_grad()
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.backbone(x)
+ x = self.head(x)
+ if x.ndim == 2:
+ x = x[:, None, :]
+ return x
+
+
+if __name__ == "__main__":
+ model = FireflyBase()
+ model.eval()
+ x = torch.randn(1, 128, 128)
+ with torch.no_grad():
+ y = model(x)
+ print(y.shape)
diff --git a/fish_speech/models/vqgan/modules/fsq.py b/fish_speech/models/vqgan/modules/fsq.py
new file mode 100644
index 0000000000000000000000000000000000000000..c837d6aee576d192adcb6c38ec0f1e666b84b6d7
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/fsq.py
@@ -0,0 +1,139 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from vector_quantize_pytorch import GroupedResidualFSQ
+
+from .firefly import ConvNeXtBlock
+
+
+@dataclass
+class FSQResult:
+ z: torch.Tensor
+ codes: torch.Tensor
+ latents: torch.Tensor
+
+
+class DownsampleFiniteScalarQuantize(nn.Module):
+ def __init__(
+ self,
+ input_dim: int = 512,
+ n_codebooks: int = 1,
+ n_groups: int = 1,
+ levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
+ downsample_factor: tuple[int] = (2, 2),
+ downsample_dims: tuple[int] | None = None,
+ ):
+ super().__init__()
+
+ if downsample_dims is None:
+ downsample_dims = [input_dim for _ in range(len(downsample_factor))]
+
+ all_dims = (input_dim,) + tuple(downsample_dims)
+
+ self.residual_fsq = GroupedResidualFSQ(
+ dim=all_dims[-1],
+ levels=levels,
+ num_quantizers=n_codebooks,
+ groups=n_groups,
+ )
+
+ self.downsample_factor = downsample_factor
+ self.downsample_dims = downsample_dims
+
+ self.downsample = nn.Sequential(
+ *[
+ nn.Sequential(
+ nn.Conv1d(
+ all_dims[idx],
+ all_dims[idx + 1],
+ kernel_size=factor,
+ stride=factor,
+ ),
+ ConvNeXtBlock(dim=all_dims[idx + 1]),
+ )
+ for idx, factor in enumerate(downsample_factor)
+ ]
+ )
+
+ self.upsample = nn.Sequential(
+ *[
+ nn.Sequential(
+ nn.ConvTranspose1d(
+ all_dims[idx + 1],
+ all_dims[idx],
+ kernel_size=factor,
+ stride=factor,
+ ),
+ ConvNeXtBlock(dim=all_dims[idx]),
+ )
+ for idx, factor in reversed(list(enumerate(downsample_factor)))
+ ]
+ )
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, z) -> FSQResult:
+ original_shape = z.shape
+ z = self.downsample(z)
+ quantized, indices = self.residual_fsq(z.mT)
+ result = FSQResult(
+ z=quantized.mT,
+ codes=indices.mT,
+ latents=z,
+ )
+ result.z = self.upsample(result.z)
+
+ # Pad or crop z to match original shape
+ diff = original_shape[-1] - result.z.shape[-1]
+ left = diff // 2
+ right = diff - left
+
+ if diff > 0:
+ result.z = F.pad(result.z, (left, right))
+ elif diff < 0:
+ result.z = result.z[..., left:-right]
+
+ return result
+
+ def encode(self, z):
+ z = self.downsample(z)
+ _, indices = self.residual_fsq(z.mT)
+ indices = rearrange(indices, "g b l r -> b (g r) l")
+ return indices
+
+ def decode(self, indices: torch.Tensor):
+ indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
+ z_q = self.residual_fsq.get_output_from_indices(indices)
+ z_q = self.upsample(z_q.mT)
+ return z_q
+
+ # def from_latents(self, latents: torch.Tensor):
+ # z_q, z_p, codes = super().from_latents(latents)
+ # z_q = self.upsample(z_q)
+ # return z_q, z_p, codes
+
+
+if __name__ == "__main__":
+ rvq = DownsampleFiniteScalarQuantize(
+ n_codebooks=1,
+ downsample_factor=(2, 2),
+ )
+ x = torch.randn(16, 512, 80)
+
+ result = rvq(x)
+ print(rvq)
+ print(result.latents.shape, result.codes.shape, result.z.shape)
+
+ # y = rvq.from_codes(result.codes)
+ # print(y[0].shape)
+
+ # y = rvq.from_latents(result.latents)
+ # print(y[0].shape)
diff --git a/fish_speech/models/vqgan/modules/reference.py b/fish_speech/models/vqgan/modules/reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..034d5c5e3572bd3828649fc0f82a1856ccc6b9e1
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/reference.py
@@ -0,0 +1,113 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .wavenet import WaveNet
+
+
+class ReferenceEncoder(WaveNet):
+ def __init__(
+ self,
+ input_channels: Optional[int] = None,
+ output_channels: Optional[int] = None,
+ residual_channels: int = 512,
+ residual_layers: int = 20,
+ dilation_cycle: Optional[int] = 4,
+ num_heads: int = 8,
+ latent_len: int = 4,
+ ):
+ super().__init__(
+ input_channels=input_channels,
+ residual_channels=residual_channels,
+ residual_layers=residual_layers,
+ dilation_cycle=dilation_cycle,
+ )
+
+ self.head_dim = residual_channels // num_heads
+ self.num_heads = num_heads
+
+ self.latent_len = latent_len
+ self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))
+
+ self.q = nn.Linear(residual_channels, residual_channels, bias=True)
+ self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
+ self.q_norm = nn.LayerNorm(self.head_dim)
+ self.k_norm = nn.LayerNorm(self.head_dim)
+ self.proj = nn.Linear(residual_channels, residual_channels)
+ self.proj_drop = nn.Dropout(0.1)
+
+ self.norm = nn.LayerNorm(residual_channels)
+ self.mlp = nn.Sequential(
+ nn.Linear(residual_channels, residual_channels * 4),
+ nn.SiLU(),
+ nn.Linear(residual_channels * 4, residual_channels),
+ )
+ self.output_projection_attn = nn.Linear(residual_channels, output_channels)
+
+ torch.nn.init.trunc_normal_(self.latent, std=0.02)
+ self.apply(self.init_weights)
+
+ def init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ torch.nn.init.trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ torch.nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, attn_mask=None):
+ x = super().forward(x).mT
+ B, N, C = x.shape
+
+ # Calculate mask
+ if attn_mask is not None:
+ assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool
+
+ attn_mask = attn_mask[:, None, None, :].expand(
+ B, self.num_heads, self.latent_len, N
+ )
+
+ q_latent = self.latent.expand(B, -1, -1)
+ q = (
+ self.q(q_latent)
+ .reshape(B, self.latent_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+
+ kv = (
+ self.kv(x)
+ .reshape(B, N, 2, self.num_heads, self.head_dim)
+ .permute(2, 0, 3, 1, 4)
+ )
+ k, v = kv.unbind(0)
+
+ q, k = self.q_norm(q), self.k_norm(k)
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+
+ x = x.transpose(1, 2).reshape(B, self.latent_len, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ x = x + self.mlp(self.norm(x))
+ x = self.output_projection_attn(x)
+ x = x.mean(1)
+
+ return x
+
+
+if __name__ == "__main__":
+ with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+ model = ReferenceEncoder(
+ input_channels=128,
+ output_channels=64,
+ residual_channels=384,
+ residual_layers=20,
+ dilation_cycle=4,
+ num_heads=8,
+ )
+ x = torch.randn(4, 128, 64)
+ mask = torch.ones(4, 64, dtype=torch.bool)
+ y = model(x, mask)
+ print(y.shape)
+ loss = F.mse_loss(y, torch.randn(4, 64))
+ loss.backward()
diff --git a/fish_speech/models/vqgan/modules/wavenet.py b/fish_speech/models/vqgan/modules/wavenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7cc011c3e071067ff36e1aba12c05cff81d94f6
--- /dev/null
+++ b/fish_speech/models/vqgan/modules/wavenet.py
@@ -0,0 +1,225 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Mish(nn.Module):
+ def forward(self, x):
+ return x * torch.tanh(F.softplus(x))
+
+
+class DiffusionEmbedding(nn.Module):
+ """Diffusion Step Embedding"""
+
+ def __init__(self, d_denoiser):
+ super(DiffusionEmbedding, self).__init__()
+ self.dim = d_denoiser
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :]
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+
+class LinearNorm(nn.Module):
+ """LinearNorm Projection"""
+
+ def __init__(self, in_features, out_features, bias=False):
+ super(LinearNorm, self).__init__()
+ self.linear = nn.Linear(in_features, out_features, bias)
+
+ nn.init.xavier_uniform_(self.linear.weight)
+ if bias:
+ nn.init.constant_(self.linear.bias, 0.0)
+
+ def forward(self, x):
+ x = self.linear(x)
+ return x
+
+
+class ConvNorm(nn.Module):
+ """1D Convolution"""
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=None,
+ dilation=1,
+ bias=True,
+ w_init_gain="linear",
+ ):
+ super(ConvNorm, self).__init__()
+
+ if padding is None:
+ assert kernel_size % 2 == 1
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = nn.Conv1d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+ nn.init.kaiming_normal_(self.conv.weight)
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+
+ return conv_signal
+
+
+class ResidualBlock(nn.Module):
+ """Residual Block"""
+
+ def __init__(
+ self,
+ residual_channels,
+ use_linear_bias=False,
+ dilation=1,
+ condition_channels=None,
+ ):
+ super(ResidualBlock, self).__init__()
+ self.conv_layer = ConvNorm(
+ residual_channels,
+ 2 * residual_channels,
+ kernel_size=3,
+ stride=1,
+ padding=dilation,
+ dilation=dilation,
+ )
+
+ if condition_channels is not None:
+ self.diffusion_projection = LinearNorm(
+ residual_channels, residual_channels, use_linear_bias
+ )
+ self.condition_projection = ConvNorm(
+ condition_channels, 2 * residual_channels, kernel_size=1
+ )
+
+ self.output_projection = ConvNorm(
+ residual_channels, 2 * residual_channels, kernel_size=1
+ )
+
+ def forward(self, x, condition=None, diffusion_step=None):
+ y = x
+
+ if diffusion_step is not None:
+ diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+ y = y + diffusion_step
+
+ y = self.conv_layer(y)
+
+ if condition is not None:
+ condition = self.condition_projection(condition)
+ y = y + condition
+
+ gate, filter = torch.chunk(y, 2, dim=1)
+ y = torch.sigmoid(gate) * torch.tanh(filter)
+
+ y = self.output_projection(y)
+ residual, skip = torch.chunk(y, 2, dim=1)
+
+ return (x + residual) / math.sqrt(2.0), skip
+
+
+class WaveNet(nn.Module):
+ def __init__(
+ self,
+ input_channels: Optional[int] = None,
+ output_channels: Optional[int] = None,
+ residual_channels: int = 512,
+ residual_layers: int = 20,
+ dilation_cycle: Optional[int] = 4,
+ is_diffusion: bool = False,
+ condition_channels: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # Input projection
+ self.input_projection = None
+ if input_channels is not None and input_channels != residual_channels:
+ self.input_projection = ConvNorm(
+ input_channels, residual_channels, kernel_size=1
+ )
+
+ if input_channels is None:
+ input_channels = residual_channels
+
+ self.input_channels = input_channels
+
+ # Residual layers
+ self.residual_layers = nn.ModuleList(
+ [
+ ResidualBlock(
+ residual_channels=residual_channels,
+ use_linear_bias=False,
+ dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
+ condition_channels=condition_channels,
+ )
+ for i in range(residual_layers)
+ ]
+ )
+
+ # Skip projection
+ self.skip_projection = ConvNorm(
+ residual_channels, residual_channels, kernel_size=1
+ )
+
+ # Output projection
+ self.output_projection = None
+ if output_channels is not None and output_channels != residual_channels:
+ self.output_projection = ConvNorm(
+ residual_channels, output_channels, kernel_size=1
+ )
+
+ if is_diffusion:
+ self.diffusion_embedding = DiffusionEmbedding(residual_channels)
+ self.mlp = nn.Sequential(
+ LinearNorm(residual_channels, residual_channels * 4, False),
+ Mish(),
+ LinearNorm(residual_channels * 4, residual_channels, False),
+ )
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ if getattr(m, "bias", None) is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, t=None, condition=None):
+ if self.input_projection is not None:
+ x = self.input_projection(x)
+ x = F.silu(x)
+
+ if t is not None:
+ t = self.diffusion_embedding(t)
+ t = self.mlp(t)
+
+ skip = []
+ for layer in self.residual_layers:
+ x, skip_connection = layer(x, condition, t)
+ skip.append(skip_connection)
+
+ x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
+ x = self.skip_projection(x)
+
+ if self.output_projection is not None:
+ x = F.silu(x)
+ x = self.output_projection(x)
+
+ return x
diff --git a/fish_speech/models/vqgan/utils.py b/fish_speech/models/vqgan/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b90c131d214006875476a161cdfd2dffa8949dac
--- /dev/null
+++ b/fish_speech/models/vqgan/utils.py
@@ -0,0 +1,94 @@
+import matplotlib
+import torch
+from matplotlib import pyplot as plt
+
+matplotlib.use("Agg")
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def plot_mel(data, titles=None):
+ fig, axes = plt.subplots(len(data), 1, squeeze=False)
+
+ if titles is None:
+ titles = [None for i in range(len(data))]
+
+ plt.tight_layout()
+
+ for i in range(len(data)):
+ mel = data[i]
+
+ if isinstance(mel, torch.Tensor):
+ mel = mel.float().detach().cpu().numpy()
+
+ axes[i][0].imshow(mel, origin="lower")
+ axes[i][0].set_aspect(2.5, adjustable="box")
+ axes[i][0].set_ylim(0, mel.shape[0])
+ axes[i][0].set_title(titles[i], fontsize="medium")
+ axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
+ axes[i][0].set_anchor("W")
+
+ return fig
+
+
+def slice_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, :, idx_str:idx_end]
+
+ return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
+ ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
+ n_channels_int = n_channels[0]
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+
+ return acts
+
+
+def avg_with_mask(x, mask):
+ assert mask.dtype == torch.float, "Mask should be float"
+
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(1)
+
+ if mask.shape[1] == 1:
+ mask = mask.expand_as(x)
+
+ return (x * mask).sum() / mask.sum()
diff --git a/fish_speech/scheduler.py b/fish_speech/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..43bed6a2210723a7d5e1ea0a48ba61140047ca29
--- /dev/null
+++ b/fish_speech/scheduler.py
@@ -0,0 +1,40 @@
+import math
+
+
+def get_cosine_schedule_with_warmup_lr_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int | float,
+ num_training_steps: int,
+ num_cycles: float = 0.5,
+ final_lr_ratio: float = 0.0,
+):
+ if 0 < num_warmup_steps < 1: # float mode
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+
+ progress = float(current_step - num_warmup_steps) / float(
+ max(1, num_training_steps - num_warmup_steps)
+ )
+
+ return max(
+ final_lr_ratio,
+ 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
+ )
+
+
+def get_constant_schedule_with_warmup_lr_lambda(
+ current_step: int,
+ *,
+ num_warmup_steps: int | float,
+ num_training_steps: int | None = None,
+):
+ if 0 < num_warmup_steps < 1: # float mode
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+
+ return 1.0
diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d740bd8eed447d162e55b165965dec17130377ce
--- /dev/null
+++ b/fish_speech/text/__init__.py
@@ -0,0 +1,4 @@
+from .clean import clean_text
+from .spliter import split_text
+
+__all__ = ["clean_text", "split_text"]
diff --git a/fish_speech/text/chn_text_norm/.gitignore b/fish_speech/text/chn_text_norm/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..75ea58fa4a7bf34fc9ab35afee24684aa6ef4c89
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/.gitignore
@@ -0,0 +1,114 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# JetBrains PyCharm
+.idea
+
+# Customize
+references
+url.txt
+
+# Git
+.git
diff --git a/fish_speech/text/chn_text_norm/README.md b/fish_speech/text/chn_text_norm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8450a2c6c0f8e40f4509f5be196eb9f9d2b9afb6
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/README.md
@@ -0,0 +1,36 @@
+# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.
+
+# Chn Text Norm
+
+this is a repository for chinese text normalization (no longer maintained).
+
+## Quick Start ##
+
+### Git Clone Repo ###
+
+git clone this repo to the root directory of your project which need to use it.
+
+ cd /path/to/proj
+ git clone https://github.com/Joee1995/chn-text-norm.git
+
+after that, your doc tree should be:
+```
+proj # root of your project
+|--- chn_text_norm # this chn-text-norm tool
+ |--- text.py
+ |--- ...
+|--- text_normalize.py # your text normalization code
+|--- ...
+```
+
+### How to Use ? ###
+
+ # text_normalize.py
+ from chn_text_norm.text import *
+
+ raw_text = 'your raw text'
+ text = Text(raw_text=raw_text).normalize()
+
+### How to add quantums ###
+
+打开test.py,然后你就知道怎么做了。
diff --git a/fish_speech/text/chn_text_norm/__init__.py b/fish_speech/text/chn_text_norm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fish_speech/text/chn_text_norm/basic_class.py b/fish_speech/text/chn_text_norm/basic_class.py
new file mode 100644
index 0000000000000000000000000000000000000000..58d8f8eb7fc85d0861f106667d8f4e3e52b54761
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_class.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+"""基本类
+中文字符类
+中文数字/数位类
+中文数字类
+中文数位类
+中文数字系统类
+中文数学符号类
+*中文其他符号类
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-02"
+
+from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES
+
+
+class ChineseChar(object):
+ """
+ 中文字符
+ 每个字符对应简体和繁体,
+ e.g. 简体 = '负', 繁体 = '負'
+ 转换时可转换为简体或繁体
+ """
+
+ def __init__(self, simplified, traditional):
+ self.simplified = simplified
+ self.traditional = traditional
+ self.__repr__ = self.__str__
+
+ def __str__(self):
+ return self.simplified or self.traditional or None
+
+ def __repr__(self):
+ return self.__str__()
+
+
+class ChineseNumberUnit(ChineseChar):
+ """
+ 中文数字/数位字符
+ 每个字符除繁简体外还有一个额外的大写字符
+ e.g. '陆' 和 '陸'
+ """
+
+ def __init__(self, power, simplified, traditional, big_s, big_t):
+ super(ChineseNumberUnit, self).__init__(simplified, traditional)
+ self.power = power
+ self.big_s = big_s
+ self.big_t = big_t
+
+ def __str__(self):
+ return "10^{}".format(self.power)
+
+ @classmethod
+ def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
+
+ if small_unit:
+ return ChineseNumberUnit(
+ power=index + 1,
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[1],
+ big_t=value[1],
+ )
+ elif numbering_type == NUMBERING_TYPES[0]:
+ return ChineseNumberUnit(
+ power=index + 8,
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[0],
+ big_t=value[1],
+ )
+ elif numbering_type == NUMBERING_TYPES[1]:
+ return ChineseNumberUnit(
+ power=(index + 2) * 4,
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[0],
+ big_t=value[1],
+ )
+ elif numbering_type == NUMBERING_TYPES[2]:
+ return ChineseNumberUnit(
+ power=pow(2, index + 3),
+ simplified=value[0],
+ traditional=value[1],
+ big_s=value[0],
+ big_t=value[1],
+ )
+ else:
+ raise ValueError(
+ "Counting type should be in {0} ({1} provided).".format(
+ NUMBERING_TYPES, numbering_type
+ )
+ )
+
+
+class ChineseNumberDigit(ChineseChar):
+ """
+ 中文数字字符
+ """
+
+ def __init__(
+ self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
+ ):
+ super(ChineseNumberDigit, self).__init__(simplified, traditional)
+ self.value = value
+ self.big_s = big_s
+ self.big_t = big_t
+ self.alt_s = alt_s
+ self.alt_t = alt_t
+
+ def __str__(self):
+ return str(self.value)
+
+ @classmethod
+ def create(cls, i, v):
+ return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
+
+
+class ChineseMath(ChineseChar):
+ """
+ 中文数位字符
+ """
+
+ def __init__(self, simplified, traditional, symbol, expression=None):
+ super(ChineseMath, self).__init__(simplified, traditional)
+ self.symbol = symbol
+ self.expression = expression
+ self.big_s = simplified
+ self.big_t = traditional
+
+
+CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
+
+
+class NumberSystem(object):
+ """
+ 中文数字系统
+ """
+
+ pass
+
+
+class MathSymbol(object):
+ """
+ 用于中文数字系统的数学符号 (繁/简体), e.g.
+ positive = ['正', '正']
+ negative = ['负', '負']
+ point = ['点', '點']
+ """
+
+ def __init__(self, positive, negative, point):
+ self.positive = positive
+ self.negative = negative
+ self.point = point
+
+ def __iter__(self):
+ for v in self.__dict__.values():
+ yield v
+
+
+# class OtherSymbol(object):
+# """
+# 其他符号
+# """
+#
+# def __init__(self, sil):
+# self.sil = sil
+#
+# def __iter__(self):
+# for v in self.__dict__.values():
+# yield v
diff --git a/fish_speech/text/chn_text_norm/basic_constant.py b/fish_speech/text/chn_text_norm/basic_constant.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a65991b9a9d349a0571c80508633951e52749ef
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_constant.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+"""基本常量
+中文数字/数位/符号字符常量
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-02"
+
+CHINESE_DIGIS = "零一二三四五六七八九"
+BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
+BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
+SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
+SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
+LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
+LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
+SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
+SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
+
+ZERO_ALT = "〇"
+ONE_ALT = "幺"
+TWO_ALTS = ["两", "兩"]
+
+POSITIVE = ["正", "正"]
+NEGATIVE = ["负", "負"]
+POINT = ["点", "點"]
+# PLUS = [u'加', u'加']
+# SIL = [u'杠', u'槓']
+
+# 中文数字系统类型
+NUMBERING_TYPES = ["low", "mid", "high"]
diff --git a/fish_speech/text/chn_text_norm/basic_util.py b/fish_speech/text/chn_text_norm/basic_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbf6130be87f285eed9998186508ea489d3bac9e
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/basic_util.py
@@ -0,0 +1,342 @@
+# -*- coding: utf-8 -*-
+"""基本方法
+创建中文数字系统 方法
+中文字符串 <=> 数字串 方法
+数字串 <=> 中文字符串 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-02"
+
+from fish_speech.text.chn_text_norm.basic_class import *
+from fish_speech.text.chn_text_norm.basic_constant import *
+
+
+def create_system(numbering_type=NUMBERING_TYPES[1]):
+ """
+ 根据数字系统类型返回创建相应的数字系统,默认为 mid
+ NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
+ low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
+ mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
+ high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
+ 返回对应的数字系统
+ """
+
+ # chinese number units of '亿' and larger
+ all_larger_units = zip(
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
+ LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
+ )
+ larger_units = [
+ CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)
+ ]
+ # chinese number units of '十, 百, 千, 万'
+ all_smaller_units = zip(
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
+ )
+ smaller_units = [
+ CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)
+ ]
+ # digis
+ chinese_digis = zip(
+ CHINESE_DIGIS,
+ CHINESE_DIGIS,
+ BIG_CHINESE_DIGIS_SIMPLIFIED,
+ BIG_CHINESE_DIGIS_TRADITIONAL,
+ )
+ digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
+ digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
+ digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
+ digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
+
+ # symbols
+ positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
+ negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
+ point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
+ # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
+ system = NumberSystem()
+ system.units = smaller_units + larger_units
+ system.digits = digits
+ system.math = MathSymbol(positive_cn, negative_cn, point_cn)
+ # system.symbols = OtherSymbol(sil_cn)
+ return system
+
+
+def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
+
+ def get_symbol(char, system):
+ for u in system.units:
+ if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
+ return u
+ for d in system.digits:
+ if char in [
+ d.traditional,
+ d.simplified,
+ d.big_s,
+ d.big_t,
+ d.alt_s,
+ d.alt_t,
+ ]:
+ return d
+ for m in system.math:
+ if char in [m.traditional, m.simplified]:
+ return m
+
+ def string2symbols(chinese_string, system):
+ int_string, dec_string = chinese_string, ""
+ for p in [system.math.point.simplified, system.math.point.traditional]:
+ if p in chinese_string:
+ int_string, dec_string = chinese_string.split(p)
+ break
+ return [get_symbol(c, system) for c in int_string], [
+ get_symbol(c, system) for c in dec_string
+ ]
+
+ def correct_symbols(integer_symbols, system):
+ """
+ 一百八 to 一百八十
+ 一亿一千三百万 to 一亿 一千万 三百万
+ """
+
+ if integer_symbols and isinstance(integer_symbols[0], CNU):
+ if integer_symbols[0].power == 1:
+ integer_symbols = [system.digits[1]] + integer_symbols
+
+ if len(integer_symbols) > 1:
+ if isinstance(integer_symbols[-1], CND) and isinstance(
+ integer_symbols[-2], CNU
+ ):
+ integer_symbols.append(
+ CNU(integer_symbols[-2].power - 1, None, None, None, None)
+ )
+
+ result = []
+ unit_count = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ result.append(s)
+ unit_count = 0
+ elif isinstance(s, CNU):
+ current_unit = CNU(s.power, None, None, None, None)
+ unit_count += 1
+
+ if unit_count == 1:
+ result.append(current_unit)
+ elif unit_count > 1:
+ for i in range(len(result)):
+ if (
+ isinstance(result[-i - 1], CNU)
+ and result[-i - 1].power < current_unit.power
+ ):
+ result[-i - 1] = CNU(
+ result[-i - 1].power + current_unit.power,
+ None,
+ None,
+ None,
+ None,
+ )
+ return result
+
+ def compute_value(integer_symbols):
+ """
+ Compute the value.
+ When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
+ e.g. '两千万' = 2000 * 10000 not 2000 + 10000
+ """
+ value = [0]
+ last_power = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ value[-1] = s.value
+ elif isinstance(s, CNU):
+ value[-1] *= pow(10, s.power)
+ if s.power > last_power:
+ value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
+ last_power = s.power
+ value.append(0)
+ return sum(value)
+
+ system = create_system(numbering_type)
+ int_part, dec_part = string2symbols(chinese_string, system)
+ int_part = correct_symbols(int_part, system)
+ int_str = str(compute_value(int_part))
+ dec_str = "".join([str(d.value) for d in dec_part])
+ if dec_part:
+ return "{0}.{1}".format(int_str, dec_str)
+ else:
+ return int_str
+
+
+def num2chn(
+ number_string,
+ numbering_type=NUMBERING_TYPES[1],
+ big=False,
+ traditional=False,
+ alt_zero=False,
+ alt_one=False,
+ alt_two=True,
+ use_zeros=True,
+ use_units=True,
+):
+
+ def get_value(value_string, use_zeros=True):
+
+ striped_string = value_string.lstrip("0")
+
+ # record nothing if all zeros
+ if not striped_string:
+ return []
+
+ # record one digits
+ elif len(striped_string) == 1:
+ if use_zeros and len(value_string) != len(striped_string):
+ return [system.digits[0], system.digits[int(striped_string)]]
+ else:
+ return [system.digits[int(striped_string)]]
+
+ # recursively record multiple digits
+ else:
+ result_unit = next(
+ u for u in reversed(system.units) if u.power < len(striped_string)
+ )
+ result_string = value_string[: -result_unit.power]
+ return (
+ get_value(result_string)
+ + [result_unit]
+ + get_value(striped_string[-result_unit.power :])
+ )
+
+ system = create_system(numbering_type)
+
+ int_dec = number_string.split(".")
+ if len(int_dec) == 1:
+ int_string = int_dec[0]
+ dec_string = ""
+ elif len(int_dec) == 2:
+ int_string = int_dec[0]
+ dec_string = int_dec[1]
+ else:
+ raise ValueError(
+ "invalid input num string with more than one dot: {}".format(number_string)
+ )
+
+ if use_units and len(int_string) > 1:
+ result_symbols = get_value(int_string)
+ else:
+ result_symbols = [system.digits[int(c)] for c in int_string]
+ dec_symbols = [system.digits[int(c)] for c in dec_string]
+ if dec_string:
+ result_symbols += [system.math.point] + dec_symbols
+
+ if alt_two:
+ liang = CND(
+ 2,
+ system.digits[2].alt_s,
+ system.digits[2].alt_t,
+ system.digits[2].big_s,
+ system.digits[2].big_t,
+ )
+ for i, v in enumerate(result_symbols):
+ if isinstance(v, CND) and v.value == 2:
+ next_symbol = (
+ result_symbols[i + 1] if i < len(result_symbols) - 1 else None
+ )
+ previous_symbol = result_symbols[i - 1] if i > 0 else None
+ if isinstance(next_symbol, CNU) and isinstance(
+ previous_symbol, (CNU, type(None))
+ ):
+ if next_symbol.power != 1 and (
+ (previous_symbol is None) or (previous_symbol.power != 1)
+ ):
+ result_symbols[i] = liang
+
+ # if big is True, '两' will not be used and `alt_two` has no impact on output
+ if big:
+ attr_name = "big_"
+ if traditional:
+ attr_name += "t"
+ else:
+ attr_name += "s"
+ else:
+ if traditional:
+ attr_name = "traditional"
+ else:
+ attr_name = "simplified"
+
+ result = "".join([getattr(s, attr_name) for s in result_symbols])
+
+ # if not use_zeros:
+ # result = result.strip(getattr(system.digits[0], attr_name))
+
+ if alt_zero:
+ result = result.replace(
+ getattr(system.digits[0], attr_name), system.digits[0].alt_s
+ )
+
+ if alt_one:
+ result = result.replace(
+ getattr(system.digits[1], attr_name), system.digits[1].alt_s
+ )
+
+ for i, p in enumerate(POINT):
+ if result.startswith(p):
+ return CHINESE_DIGIS[0] + result
+
+ # ^10, 11, .., 19
+ if (
+ len(result) >= 2
+ and result[1]
+ in [
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
+ ]
+ and result[0]
+ in [
+ CHINESE_DIGIS[1],
+ BIG_CHINESE_DIGIS_SIMPLIFIED[1],
+ BIG_CHINESE_DIGIS_TRADITIONAL[1],
+ ]
+ ):
+ result = result[1:]
+
+ return result
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ all_chinese_number_string = (
+ CHINESE_DIGIS
+ + BIG_CHINESE_DIGIS_SIMPLIFIED
+ + BIG_CHINESE_DIGIS_TRADITIONAL
+ + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
+ + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
+ + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
+ + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
+ + ZERO_ALT
+ + ONE_ALT
+ + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
+ )
+
+ print("num:", chn2num("一万零四百零三点八零五"))
+ print("num:", chn2num("一亿六点三"))
+ print("num:", chn2num("一亿零六点三"))
+ print("num:", chn2num("两千零一亿六点三"))
+ # print('num:', chn2num('一零零八六'))
+ print("txt:", num2chn("10260.03", alt_zero=True))
+ print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
+ print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
+ print(
+ "txt:",
+ num2chn(
+ "059523810880",
+ alt_one=True,
+ alt_two=False,
+ use_lzeros=True,
+ use_rzeros=True,
+ use_units=False,
+ ),
+ )
+
+ print(all_chinese_number_string)
diff --git a/fish_speech/text/chn_text_norm/cardinal.py b/fish_speech/text/chn_text_norm/cardinal.py
new file mode 100644
index 0000000000000000000000000000000000000000..ace9f5ad8e7f3be3a8e41b11dc0b9f80db799616
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/cardinal.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""CARDINAL类 (包含小数DECIMAL类)
+纯数 <=> 中文字符串 方法
+中文字符串 <=> 纯数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Cardinal:
+ """
+ CARDINAL类
+ """
+
+ def __init__(self, cardinal=None, chntext=None):
+ self.cardinal = cardinal
+ self.chntext = chntext
+
+ def chntext2cardinal(self):
+ return chn2num(self.chntext)
+
+ def cardinal2chntext(self):
+ return num2chn(self.cardinal)
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(Cardinal(cardinal="21357.230").cardinal2chntext())
diff --git a/fish_speech/text/chn_text_norm/date.py b/fish_speech/text/chn_text_norm/date.py
new file mode 100644
index 0000000000000000000000000000000000000000..77acfdb9a91df0fe3c615a0784f61aad87fbe56e
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/date.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+"""DATE类
+日期 <=> 中文字符串 方法
+中文字符串 <=> 日期 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-07"
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+from fish_speech.text.chn_text_norm.digit import Digit
+
+
+class Date:
+ """
+ DATE类
+ """
+
+ def __init__(self, date=None, chntext=None):
+ self.date = date
+ self.chntext = chntext
+
+ # def chntext2date(self):
+ # chntext = self.chntext
+ # try:
+ # year, other = chntext.strip().split('年', maxsplit=1)
+ # year = Digit(chntext=year).digit2chntext() + '年'
+ # except ValueError:
+ # other = chntext
+ # year = ''
+ # if other:
+ # try:
+ # month, day = other.strip().split('月', maxsplit=1)
+ # month = Cardinal(chntext=month).chntext2cardinal() + '月'
+ # except ValueError:
+ # day = chntext
+ # month = ''
+ # if day:
+ # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+ # else:
+ # month = ''
+ # day = ''
+ # date = year + month + day
+ # self.date = date
+ # return self.date
+
+ def date2chntext(self):
+ date = self.date
+ try:
+ year, other = date.strip().split("年", maxsplit=1)
+ year = Digit(digit=year).digit2chntext() + "年"
+ except ValueError:
+ other = date
+ year = ""
+ if other:
+ try:
+ month, day = other.strip().split("月", maxsplit=1)
+ month = Cardinal(cardinal=month).cardinal2chntext() + "月"
+ except ValueError:
+ day = date
+ month = ""
+ if day:
+ day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+ else:
+ month = ""
+ day = ""
+ chntext = year + month + day
+ self.chntext = chntext
+ return self.chntext
+
+
+if __name__ == "__main__":
+
+ # 测试
+ print(Date(date="09年3月16日").date2chntext())
diff --git a/fish_speech/text/chn_text_norm/digit.py b/fish_speech/text/chn_text_norm/digit.py
new file mode 100644
index 0000000000000000000000000000000000000000..47c0cd4ad0c700635f84470bfdacfbdafb4a6185
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/digit.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""DIGIT类
+数字串 <=> 中文字符串 方法
+中文字符串 <=> 数字串 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Digit:
+ """
+ DIGIT类
+ """
+
+ def __init__(self, digit=None, chntext=None):
+ self.digit = digit
+ self.chntext = chntext
+
+ # def chntext2digit(self):
+ # return chn2num(self.chntext)
+
+ def digit2chntext(self):
+ return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(Digit(digit="2016").digit2chntext())
diff --git a/fish_speech/text/chn_text_norm/fraction.py b/fish_speech/text/chn_text_norm/fraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..b43b6a7feb634d346d59a2b4ab84b77ac88df103
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/fraction.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+"""FRACTION类
+分数 <=> 中文字符串 方法
+中文字符串 <=> 分数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Fraction:
+ """
+ FRACTION类
+ """
+
+ def __init__(self, fraction=None, chntext=None):
+ self.fraction = fraction
+ self.chntext = chntext
+
+ def chntext2fraction(self):
+ denominator, numerator = self.chntext.split("分之")
+ return chn2num(numerator) + "/" + chn2num(denominator)
+
+ def fraction2chntext(self):
+ numerator, denominator = self.fraction.split("/")
+ return num2chn(denominator) + "分之" + num2chn(numerator)
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(Fraction(fraction="2135/7230").fraction2chntext())
+ print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction())
diff --git a/fish_speech/text/chn_text_norm/money.py b/fish_speech/text/chn_text_norm/money.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c980d32134e1460e96e5bcbcc73d0d55974d2a
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/money.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+"""MONEY类
+金钱 <=> 中文字符串 方法
+中文字符串 <=> 金钱 方法
+"""
+import re
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-08"
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+
+
+class Money:
+ """
+ MONEY类
+ """
+
+ def __init__(self, money=None, chntext=None):
+ self.money = money
+ self.chntext = chntext
+
+ # def chntext2money(self):
+ # return self.money
+
+ def money2chntext(self):
+ money = self.money
+ pattern = re.compile(r"(\d+(\.\d+)?)")
+ matchers = pattern.findall(money)
+ if matchers:
+ for matcher in matchers:
+ money = money.replace(
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()
+ )
+ self.chntext = money
+ return self.chntext
+
+
+if __name__ == "__main__":
+
+ # 测试
+ print(Money(money="21.5万元").money2chntext())
+ print(Money(money="230块5毛").money2chntext())
diff --git a/fish_speech/text/chn_text_norm/percentage.py b/fish_speech/text/chn_text_norm/percentage.py
new file mode 100644
index 0000000000000000000000000000000000000000..46abbf545af62eb951d8f6fe40bcf684587f81b0
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/percentage.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+"""PERCENTAGE类
+百分数 <=> 中文字符串 方法
+中文字符串 <=> 百分数 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-06"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class Percentage:
+ """
+ PERCENTAGE类
+ """
+
+ def __init__(self, percentage=None, chntext=None):
+ self.percentage = percentage
+ self.chntext = chntext
+
+ def chntext2percentage(self):
+ return chn2num(self.chntext.strip().strip("百分之")) + "%"
+
+ def percentage2chntext(self):
+ return "百分之" + num2chn(self.percentage.strip().strip("%"))
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(Percentage(chntext="百分之五十六点零三").chntext2percentage())
+ print(Percentage(percentage="65.3%").percentage2chntext())
diff --git a/fish_speech/text/chn_text_norm/telephone.py b/fish_speech/text/chn_text_norm/telephone.py
new file mode 100644
index 0000000000000000000000000000000000000000..e72b546db628a3b807dc6235b59b188cae3153ff
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/telephone.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+"""TELEPHONE类
+电话号码 <=> 中文字符串 方法
+中文字符串 <=> 电话号码 方法
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+from fish_speech.text.chn_text_norm.basic_util import *
+
+
+class TelePhone:
+ """
+ TELEPHONE类
+ """
+
+ def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+ self.telephone = telephone
+ self.raw_chntext = raw_chntext
+ self.chntext = chntext
+
+ # def chntext2telephone(self):
+ # sil_parts = self.raw_chntext.split('')
+ # self.telephone = '-'.join([
+ # str(chn2num(p)) for p in sil_parts
+ # ])
+ # return self.telephone
+
+ def telephone2chntext(self, fixed=False):
+
+ if fixed:
+ sil_parts = self.telephone.split("-")
+ self.raw_chntext = "".join(
+ [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
+ )
+ self.chntext = self.raw_chntext.replace("", "")
+ else:
+ sp_parts = self.telephone.strip("+").split()
+ self.raw_chntext = "".join(
+ [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
+ )
+ self.chntext = self.raw_chntext.replace("", "")
+ return self.chntext
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(TelePhone(telephone="0595-23980880").telephone2chntext())
+ # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone())
diff --git a/fish_speech/text/chn_text_norm/text.py b/fish_speech/text/chn_text_norm/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..54086fd933c01e14c3c55cee9adb52eefb58fd31
--- /dev/null
+++ b/fish_speech/text/chn_text_norm/text.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+"""
+TEXT类
+"""
+
+__author__ = "Zhiyang Zhou "
+__data__ = "2019-05-03"
+
+import re
+
+from fish_speech.text.chn_text_norm.cardinal import Cardinal
+from fish_speech.text.chn_text_norm.date import Date
+from fish_speech.text.chn_text_norm.digit import Digit
+from fish_speech.text.chn_text_norm.fraction import Fraction
+from fish_speech.text.chn_text_norm.money import Money
+from fish_speech.text.chn_text_norm.percentage import Percentage
+from fish_speech.text.chn_text_norm.telephone import TelePhone
+
+CURRENCY_NAMES = (
+ "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
+ "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
+)
+CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
+COM_QUANTIFIERS = (
+ "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
+ "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
+ "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
+ "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
+ "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
+ "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)"
+)
+
+
+class Text:
+ """
+ Text类
+ """
+
+ def __init__(self, raw_text, norm_text=None):
+ self.raw_text = "^" + raw_text + "$"
+ self.norm_text = norm_text
+
+ def _particular(self):
+ text = self.norm_text
+ pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('particular')
+ for matcher in matchers:
+ text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
+ self.norm_text = text
+ return self.norm_text
+
+ def normalize(self):
+ text = self.raw_text
+
+ # 规范化日期
+ pattern = re.compile(
+ r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)"
+ )
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('date')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+ # 规范化金钱
+ pattern = re.compile(
+ r"\D+((\d+(\.\d+)?)[多余几]?"
+ + CURRENCY_UNITS
+ + "(\d"
+ + CURRENCY_UNITS
+ + "?)?)"
+ )
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('money')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Money(money=matcher[0]).money2chntext(), 1
+ )
+
+ # 规范化固话/手机号码
+ # 手机
+ # http://www.jihaoba.com/news/show/13680
+ # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
+ # 联通:130、131、132、156、155、186、185、176
+ # 电信:133、153、189、180、181、177
+ pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('telephone')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
+ )
+ # 固话
+ pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('fixed telephone')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0],
+ TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True),
+ 1,
+ )
+
+ # 规范化分数
+ pattern = re.compile(r"(\d+/\d+)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('fraction')
+ for matcher in matchers:
+ text = text.replace(
+ matcher, Fraction(fraction=matcher).fraction2chntext(), 1
+ )
+
+ # 规范化百分数
+ text = text.replace("%", "%")
+ pattern = re.compile(r"(\d+(\.\d+)?%)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('percentage')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0],
+ Percentage(percentage=matcher[0]).percentage2chntext(),
+ 1,
+ )
+
+ # 规范化纯数+量词
+ pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('cardinal+quantifier')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+ )
+
+ # 规范化数字编号
+ pattern = re.compile(r"(\d{4,32})")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('digit')
+ for matcher in matchers:
+ text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+ # 规范化纯数
+ pattern = re.compile(r"(\d+(\.\d+)?)")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('cardinal')
+ for matcher in matchers:
+ text = text.replace(
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
+ )
+
+ self.norm_text = text
+ self._particular()
+
+ return self.norm_text.lstrip("^").rstrip("$")
+
+
+if __name__ == "__main__":
+
+ # 测试程序
+ print(Text(raw_text="固话:0595-23865596或23880880。").normalize())
+ print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize())
+ print(Text(raw_text="分数:32477/76391。").normalize())
+ print(Text(raw_text="百分数:80.03%。").normalize())
+ print(Text(raw_text="编号:31520181154418。").normalize())
+ print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize())
+ print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize())
+ print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize())
+ print(Text(raw_text="特殊:O2O或B2C。").normalize())
diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py
new file mode 100644
index 0000000000000000000000000000000000000000..76d9dc9033dfb3b5b6265a0d1dbb3acafc9e7606
--- /dev/null
+++ b/fish_speech/text/clean.py
@@ -0,0 +1,69 @@
+import itertools
+import re
+
+LANGUAGE_UNICODE_RANGE_MAP = {
+ "ZH": [(0x4E00, 0x9FFF)],
+ "JP": [(0x4E00, 0x9FFF), (0x3040, 0x309F), (0x30A0, 0x30FF), (0x31F0, 0x31FF)],
+ "EN": [(0x0000, 0x007F)],
+}
+
+SYMBOLS_MAPPING = {
+ ":": ",",
+ ";": ",",
+ ",": ",",
+ "。": ".",
+ "!": "!",
+ "?": "?",
+ "\n": ".",
+ "·": ",",
+ "、": ",",
+ "...": "…",
+ "“": "'",
+ "”": "'",
+ "‘": "'",
+ "’": "'",
+ "(": "'",
+ ")": "'",
+ "(": "'",
+ ")": "'",
+ "《": "'",
+ "》": "'",
+ "【": "'",
+ "】": "'",
+ "[": "'",
+ "]": "'",
+ "—": "-",
+ "~": "-",
+ "~": "-",
+ "・": "-",
+ "「": "'",
+ "」": "'",
+ ";": ",",
+ ":": ",",
+}
+
+REPLACE_SYMBOL_REGEX = re.compile(
+ "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
+)
+ALL_KNOWN_UTF8_RANGE = list(
+ itertools.chain.from_iterable(LANGUAGE_UNICODE_RANGE_MAP.values())
+)
+REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
+ "[^"
+ + "".join(
+ f"{re.escape(chr(start))}-{re.escape(chr(end))}"
+ for start, end in ALL_KNOWN_UTF8_RANGE
+ )
+ + "]"
+)
+
+
+def clean_text(text):
+ # Clean the text
+ text = text.strip()
+
+ # Replace all chinese symbols with their english counterparts
+ text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
+ text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
+
+ return text
diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py
new file mode 100644
index 0000000000000000000000000000000000000000..5528cd3a63fe4e6b4f5167776bcfa62fe1f5127a
--- /dev/null
+++ b/fish_speech/text/spliter.py
@@ -0,0 +1,130 @@
+import re
+import string
+
+from fish_speech.text.clean import clean_text
+
+
+def utf_8_len(text):
+ return len(text.encode("utf-8"))
+
+
+def break_text(texts, length, splits: set):
+ for text in texts:
+ if utf_8_len(text) <= length:
+ yield text
+ continue
+
+ curr = ""
+ for char in text:
+ curr += char
+
+ if char in splits:
+ yield curr
+ curr = ""
+
+ if curr:
+ yield curr
+
+
+def break_text_by_length(texts, length):
+ for text in texts:
+ if utf_8_len(text) <= length:
+ yield text
+ continue
+
+ curr = ""
+ for char in text:
+ curr += char
+
+ if utf_8_len(curr) >= length:
+ yield curr
+ curr = ""
+
+ if curr:
+ yield curr
+
+
+def add_cleaned(curr, segments):
+ curr = curr.strip()
+ if curr and not all(c.isspace() or c in string.punctuation for c in curr):
+ segments.append(curr)
+
+
+def protect_float(text):
+ # Turns 3.14 into <3_f_14> to prevent splitting
+ return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
+
+
+def unprotect_float(text):
+ # Turns <3_f_14> into 3.14
+ return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
+
+
+def split_text(text, length):
+ text = clean_text(text)
+
+ # Break the text into pieces with following rules:
+ # 1. Split the text at ".", "!", "?" if text is NOT a float
+ # 2. If the text is longer than length, split at ","
+ # 3. If the text is still longer than length, split at " "
+ # 4. If the text is still longer than length, split at any character to length
+
+ texts = [text]
+ texts = map(protect_float, texts)
+ texts = break_text(texts, length, {".", "!", "?"})
+ texts = map(unprotect_float, texts)
+ texts = break_text(texts, length, {","})
+ texts = break_text(texts, length, {" "})
+ texts = list(break_text_by_length(texts, length))
+
+ # Then, merge the texts into segments with length <= length
+ segments = []
+ curr = ""
+
+ for text in texts:
+ if utf_8_len(curr) + utf_8_len(text) <= length:
+ curr += text
+ else:
+ add_cleaned(curr, segments)
+ curr = text
+
+ if curr:
+ add_cleaned(curr, segments)
+
+ return segments
+
+
+if __name__ == "__main__":
+ # Test the split_text function
+
+ text = "This is a test sentence. This is another test sentence. And a third one."
+
+ assert split_text(text, 50) == [
+ "This is a test sentence.",
+ "This is another test sentence. And a third one.",
+ ]
+ assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
+ assert split_text(" ", 10) == []
+ assert split_text("a", 10) == ["a"]
+
+ text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
+ assert split_text(text, 50) == [
+ "This is a test sentence with only commas,",
+ "and no dots, and no exclamation marks,",
+ "and no question marks, and no newlines.",
+ ]
+
+ text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
+ # First half split at " ", second half split at ","
+ assert split_text(text, 50) == [
+ "This is a test sentence This is a test sentence",
+ "This is a test sentence. This is a test sentence,",
+ "This is a test sentence, This is a test sentence.",
+ ]
+
+ text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
+ assert split_text(text, 50) == [
+ "这是一段很长的中文文本,",
+ "而且没有句号,也没有感叹号,",
+ "也没有问号,也没有换行符.",
+ ]
diff --git a/fish_speech/train.py b/fish_speech/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e61b793c3af812b9e0d5add86b3be210cf27940e
--- /dev/null
+++ b/fish_speech/train.py
@@ -0,0 +1,139 @@
+import os
+import sys
+from typing import Optional
+
+import hydra
+import lightning as L
+import pyrootutils
+import torch
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies import DDPStrategy
+from omegaconf import DictConfig, OmegaConf
+
+os.environ.pop("SLURM_NTASKS", None)
+os.environ.pop("SLURM_JOB_NAME", None)
+os.environ.pop("SLURM_NTASKS_PER_NODE", None)
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+# Allow TF32 on Ampere GPUs
+torch.set_float32_matmul_precision("high")
+torch.backends.cudnn.allow_tf32 = True
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+import fish_speech.utils as utils
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+@utils.task_wrapper
+def train(cfg: DictConfig) -> tuple[dict, dict]:
+ """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+ training.
+ This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+ failure. Useful for multiruns, saving info about the crash, etc.
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ Returns:
+ Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+ """ # noqa: E501
+
+ # set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ L.seed_everything(cfg.seed, workers=False)
+
+ if cfg.get("deterministic"):
+ torch.use_deterministic_algorithms(True)
+
+ log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+ datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+ log.info("Instantiating callbacks...")
+ callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+ log.info("Instantiating loggers...")
+ logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(
+ cfg.trainer,
+ callbacks=callbacks,
+ logger=logger,
+ )
+
+ object_dict = {
+ "cfg": cfg,
+ "datamodule": datamodule,
+ "model": model,
+ "callbacks": callbacks,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ utils.log_hyperparameters(object_dict)
+
+ if cfg.get("train"):
+ log.info("Starting training!")
+
+ ckpt_path = cfg.get("ckpt_path")
+ auto_resume = False
+
+ resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+ if resume_ckpt_path is not None:
+ ckpt_path = resume_ckpt_path
+ auto_resume = True
+
+ if ckpt_path is not None:
+ log.info(f"Resuming from checkpoint: {ckpt_path}")
+
+ # resume weights only is disabled for auto-resume
+ if cfg.get("resume_weights_only") and auto_resume is False:
+ log.info("Resuming weights only!")
+ ckpt = torch.load(ckpt_path, map_location=model.device)
+ if "state_dict" in ckpt:
+ ckpt = ckpt["state_dict"]
+ err = model.load_state_dict(ckpt, strict=False)
+ log.info(f"Error loading state dict: {err}")
+ ckpt_path = None
+
+ trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+
+ train_metrics = trainer.callback_metrics
+
+ if cfg.get("test"):
+ log.info("Starting testing!")
+ ckpt_path = trainer.checkpoint_callback.best_model_path
+ if ckpt_path == "":
+ log.warning("Best ckpt not found! Using current weights for testing...")
+ ckpt_path = cfg.get("ckpt_path")
+
+ trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+ log.info(f"Best ckpt path: {ckpt_path}")
+
+ test_metrics = trainer.callback_metrics
+
+ # merge train and test metrics
+ metric_dict = {**train_metrics, **test_metrics}
+
+ return metric_dict, object_dict
+
+
+@hydra.main(
+ version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
+)
+def main(cfg: DictConfig) -> Optional[float]:
+ # train the model
+ train(cfg)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e0106d7f479b6f3b4ba19117c010bd87f39c2b
--- /dev/null
+++ b/fish_speech/utils/__init__.py
@@ -0,0 +1,21 @@
+from .braceexpand import braceexpand
+from .file import get_latest_checkpoint
+from .instantiators import instantiate_callbacks, instantiate_loggers
+from .logger import RankedLogger
+from .logging_utils import log_hyperparameters
+from .rich_utils import enforce_tags, print_config_tree
+from .utils import extras, get_metric_value, task_wrapper
+
+__all__ = [
+ "enforce_tags",
+ "extras",
+ "get_metric_value",
+ "RankedLogger",
+ "instantiate_callbacks",
+ "instantiate_loggers",
+ "log_hyperparameters",
+ "print_config_tree",
+ "task_wrapper",
+ "braceexpand",
+ "get_latest_checkpoint",
+]
diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ac739f01f7e10e039c68c1157d6c761064f974
--- /dev/null
+++ b/fish_speech/utils/braceexpand.py
@@ -0,0 +1,217 @@
+"""
+Bash-style brace expansion
+Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
+License: MIT
+"""
+
+import re
+import string
+from itertools import chain, product
+from typing import Iterable, Iterator, Optional
+
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
+
+
+class UnbalancedBracesError(ValueError):
+ pass
+
+
+alphabet = string.ascii_uppercase + string.ascii_lowercase
+
+int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
+char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
+escape_re = re.compile(r"\\(.)")
+
+
+def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
+ """braceexpand(pattern) -> iterator over generated strings
+
+ Returns an iterator over the strings resulting from brace expansion
+ of pattern. This function implements Brace Expansion as described in
+ bash(1), with the following limitations:
+
+ * A pattern containing unbalanced braces will raise an
+ UnbalancedBracesError exception. In bash, unbalanced braces will either
+ be partly expanded or ignored.
+
+ * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
+ include the characters '[]^_`' between 'Z' and 'a'.
+
+ When escape is True (the default), characters in pattern can be
+ prefixed with a backslash to cause them not to be interpreted as
+ special characters for brace expansion (such as '{', '}', ',').
+ To pass through a a literal backslash, double it ('\\\\').
+
+ When escape is False, backslashes in pattern have no special
+ meaning and will be preserved in the output.
+
+ Examples:
+
+ >>> from braceexpand import braceexpand
+
+ # Integer range
+ >>> list(braceexpand('item{1..3}'))
+ ['item1', 'item2', 'item3']
+
+ # Character range
+ >>> list(braceexpand('{a..c}'))
+ ['a', 'b', 'c']
+
+ # Sequence
+ >>> list(braceexpand('index.html{,.backup}'))
+ ['index.html', 'index.html.backup']
+
+ # Nested patterns
+ >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
+ ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
+
+ # Prefixing an integer with zero causes all numbers to be padded to
+ # the same width.
+ >>> list(braceexpand('{07..10}'))
+ ['07', '08', '09', '10']
+
+ # An optional increment can be specified for ranges.
+ >>> list(braceexpand('{a..g..2}'))
+ ['a', 'c', 'e', 'g']
+
+ # Ranges can go in both directions.
+ >>> list(braceexpand('{4..1}'))
+ ['4', '3', '2', '1']
+
+ # Numbers can be negative
+ >>> list(braceexpand('{2..-1}'))
+ ['2', '1', '0', '-1']
+
+ # Unbalanced braces raise an exception.
+ >>> list(braceexpand('{1{2,3}'))
+ Traceback (most recent call last):
+ ...
+ UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
+
+ # By default, the backslash is the escape character.
+ >>> list(braceexpand(r'{1\\{2,3}'))
+ ['1{2', '3']
+
+ # Setting 'escape' to False disables backslash escaping.
+ >>> list(braceexpand(r'\\{1,2}', escape=False))
+ ['\\\\1', '\\\\2']
+
+ """
+ return (
+ escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
+ )
+
+
+def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
+ start = 0
+ pos = 0
+ bracketdepth = 0
+ items: list[Iterable[str]] = []
+
+ # print 'pattern:', pattern
+ while pos < len(pattern):
+ if escape and pattern[pos] == "\\":
+ pos += 2
+ continue
+ elif pattern[pos] == "{":
+ if bracketdepth == 0 and pos > start:
+ # print 'literal:', pattern[start:pos]
+ items.append([pattern[start:pos]])
+ start = pos
+ bracketdepth += 1
+ elif pattern[pos] == "}":
+ bracketdepth -= 1
+ if bracketdepth == 0:
+ # print 'expression:', pattern[start+1:pos]
+ expr = pattern[start + 1 : pos]
+ item = parse_expression(expr, escape)
+ if item is None: # not a range or sequence
+ items.extend([["{"], parse_pattern(expr, escape), ["}"]])
+ else:
+ items.append(item)
+ start = pos + 1 # skip the closing brace
+ pos += 1
+
+ if bracketdepth != 0: # unbalanced braces
+ raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
+
+ if start < pos:
+ items.append([pattern[start:]])
+
+ return ("".join(item) for item in product(*items))
+
+
+def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
+ int_range_match = int_range_re.match(expr)
+ if int_range_match:
+ return make_int_range(*int_range_match.groups())
+
+ char_range_match = char_range_re.match(expr)
+ if char_range_match:
+ return make_char_range(*char_range_match.groups())
+
+ return parse_sequence(expr, escape)
+
+
+def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
+ # sequence -> chain(*sequence_items)
+ start = 0
+ pos = 0
+ bracketdepth = 0
+ items: list[Iterable[str]] = []
+
+ # print 'sequence:', seq
+ while pos < len(seq):
+ if escape and seq[pos] == "\\":
+ pos += 2
+ continue
+ elif seq[pos] == "{":
+ bracketdepth += 1
+ elif seq[pos] == "}":
+ bracketdepth -= 1
+ elif seq[pos] == "," and bracketdepth == 0:
+ items.append(parse_pattern(seq[start:pos], escape))
+ start = pos + 1 # skip the comma
+ pos += 1
+
+ if bracketdepth != 0:
+ raise UnbalancedBracesError
+ if not items:
+ return None
+
+ # part after the last comma (may be the empty string)
+ items.append(parse_pattern(seq[start:], escape))
+ return chain(*items)
+
+
+def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
+ if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
+ padding = max(len(left), len(right))
+ else:
+ padding = 0
+ step = (int(incr) or 1) if incr else 1
+ start = int(left)
+ end = int(right)
+ r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
+ fmt = "%0{}d".format(padding)
+ return (fmt % i for i in r)
+
+
+def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
+ step = (int(incr) or 1) if incr else 1
+ start = alphabet.index(left)
+ end = alphabet.index(right)
+ if start < end:
+ return alphabet[start : end + 1 : step]
+ else:
+ end = end or -len(alphabet)
+ return alphabet[start : end - 1 : -step]
+
+
+if __name__ == "__main__":
+ import doctest
+ import sys
+
+ failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
+ if failed:
+ sys.exit(1)
diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..4047aa53f60f45a18e9361ab62e410cbc6575d52
--- /dev/null
+++ b/fish_speech/utils/file.py
@@ -0,0 +1,119 @@
+import os
+from glob import glob
+from pathlib import Path
+from typing import Union
+
+from loguru import logger
+from natsort import natsorted
+
+AUDIO_EXTENSIONS = {
+ ".mp3",
+ ".wav",
+ ".flac",
+ ".ogg",
+ ".m4a",
+ ".wma",
+ ".aac",
+ ".aiff",
+ ".aif",
+ ".aifc",
+}
+
+
+def list_files(
+ path: Union[Path, str],
+ extensions: set[str] = None,
+ recursive: bool = False,
+ sort: bool = True,
+) -> list[Path]:
+ """List files in a directory.
+
+ Args:
+ path (Path): Path to the directory.
+ extensions (set, optional): Extensions to filter. Defaults to None.
+ recursive (bool, optional): Whether to search recursively. Defaults to False.
+ sort (bool, optional): Whether to sort the files. Defaults to True.
+
+ Returns:
+ list: List of files.
+ """
+
+ if isinstance(path, str):
+ path = Path(path)
+
+ if not path.exists():
+ raise FileNotFoundError(f"Directory {path} does not exist.")
+
+ files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
+
+ if sort:
+ files = natsorted(files)
+
+ return files
+
+
+def get_latest_checkpoint(path: Path | str) -> Path | None:
+ # Find the latest checkpoint
+ ckpt_dir = Path(path)
+
+ if ckpt_dir.exists() is False:
+ return None
+
+ ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
+ if len(ckpts) == 0:
+ return None
+
+ return ckpts[-1]
+
+
+def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
+ """
+ Load a Bert-VITS2 style filelist.
+ """
+
+ files = set()
+ results = []
+ count_duplicated, count_not_found = 0, 0
+
+ LANGUAGE_TO_LANGUAGES = {
+ "zh": ["zh", "en"],
+ "jp": ["jp", "en"],
+ "en": ["en"],
+ }
+
+ with open(path, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ splits = line.strip().split("|", maxsplit=3)
+ if len(splits) != 4:
+ logger.warning(f"Invalid line: {line}")
+ continue
+
+ filename, speaker, language, text = splits
+ file = Path(filename)
+ language = language.strip().lower()
+
+ if language == "ja":
+ language = "jp"
+
+ assert language in ["zh", "jp", "en"], f"Invalid language {language}"
+ languages = LANGUAGE_TO_LANGUAGES[language]
+
+ if file in files:
+ logger.warning(f"Duplicated file: {file}")
+ count_duplicated += 1
+ continue
+
+ if not file.exists():
+ logger.warning(f"File not found: {file}")
+ count_not_found += 1
+ continue
+
+ results.append((file, speaker, languages, text))
+
+ if count_duplicated > 0:
+ logger.warning(f"Total duplicated files: {count_duplicated}")
+
+ if count_not_found > 0:
+ logger.warning(f"Total files not found: {count_not_found}")
+
+ return results
diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6ee463924f588a35477937fbe3c3364043bdf3e
--- /dev/null
+++ b/fish_speech/utils/instantiators.py
@@ -0,0 +1,50 @@
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+from pytorch_lightning import Callback
+from pytorch_lightning.loggers import Logger
+
+from .logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+ """Instantiates callbacks from config."""
+
+ callbacks: List[Callback] = []
+
+ if not callbacks_cfg:
+ log.warning("No callback configs found! Skipping..")
+ return callbacks
+
+ if not isinstance(callbacks_cfg, DictConfig):
+ raise TypeError("Callbacks config must be a DictConfig!")
+
+ for _, cb_conf in callbacks_cfg.items():
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+ log.info(f"Instantiating callback <{cb_conf._target_}>")
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+ """Instantiates loggers from config."""
+
+ logger: List[Logger] = []
+
+ if not logger_cfg:
+ log.warning("No logger configs found! Skipping...")
+ return logger
+
+ if not isinstance(logger_cfg, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, lg_conf in logger_cfg.items():
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+ log.info(f"Instantiating logger <{lg_conf._target_}>")
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ return logger
diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..94f94f738d1d87404354d086c30ef0ad9ab04cdc
--- /dev/null
+++ b/fish_speech/utils/logger.py
@@ -0,0 +1,55 @@
+import logging
+from typing import Mapping, Optional
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+class RankedLogger(logging.LoggerAdapter):
+ """A multi-GPU-friendly python command line logger."""
+
+ def __init__(
+ self,
+ name: str = __name__,
+ rank_zero_only: bool = True,
+ extra: Optional[Mapping[str, object]] = None,
+ ) -> None:
+ """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+ with their rank prefixed in the log message.
+
+ :param name: The name of the logger. Default is ``__name__``.
+ :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+ :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+ """
+ logger = logging.getLogger(name)
+ super().__init__(logger=logger, extra=extra)
+ self.rank_zero_only = rank_zero_only
+
+ def log(
+ self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+ ) -> None:
+ """Delegate a log call to the underlying logger, after prefixing its message with the rank
+ of the process it's being logged from. If `'rank'` is provided, then the log will only
+ occur on that rank/process.
+
+ :param level: The level to log at. Look at `logging.__init__.py` for more information.
+ :param msg: The message to log.
+ :param rank: The rank to log at.
+ :param args: Additional args to pass to the underlying logging function.
+ :param kwargs: Any additional keyword args to pass to the underlying logging function.
+ """
+ if self.isEnabledFor(level):
+ msg, kwargs = self.process(msg, kwargs)
+ current_rank = getattr(rank_zero_only, "rank", None)
+ if current_rank is None:
+ raise RuntimeError(
+ "The `rank_zero_only.rank` needs to be set before use"
+ )
+ msg = rank_prefixed_message(msg, current_rank)
+ if self.rank_zero_only:
+ if current_rank == 0:
+ self.logger.log(level, msg, *args, **kwargs)
+ else:
+ if rank is None:
+ self.logger.log(level, msg, *args, **kwargs)
+ elif current_rank == rank:
+ self.logger.log(level, msg, *args, **kwargs)
diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e3b0a2519e12845f09e5fbe86dfccbf5b345429
--- /dev/null
+++ b/fish_speech/utils/logging_utils.py
@@ -0,0 +1,48 @@
+from lightning.pytorch.utilities import rank_zero_only
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+ """Controls which config parts are saved by lightning loggers.
+
+ Additionally saves:
+ - Number of model parameters
+ """
+
+ hparams = {}
+
+ cfg = object_dict["cfg"]
+ model = object_dict["model"]
+ trainer = object_dict["trainer"]
+
+ if not trainer.logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ hparams["model"] = cfg["model"]
+
+ # save number of model parameters
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["model/params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ hparams["data"] = cfg["data"]
+ hparams["trainer"] = cfg["trainer"]
+
+ hparams["callbacks"] = cfg.get("callbacks")
+ hparams["extras"] = cfg.get("extras")
+
+ hparams["task_name"] = cfg.get("task_name")
+ hparams["tags"] = cfg.get("tags")
+ hparams["ckpt_path"] = cfg.get("ckpt_path")
+ hparams["seed"] = cfg.get("seed")
+
+ # send hparams to all loggers
+ for logger in trainer.loggers:
+ logger.log_hyperparams(hparams)
diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a465f54d610779766d51e3d1a020a3b1517fd1f
--- /dev/null
+++ b/fish_speech/utils/rich_utils.py
@@ -0,0 +1,100 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def print_config_tree(
+ cfg: DictConfig,
+ print_order: Sequence[str] = (
+ "data",
+ "model",
+ "callbacks",
+ "logger",
+ "trainer",
+ "paths",
+ "extras",
+ ),
+ resolve: bool = False,
+ save_to_file: bool = False,
+) -> None:
+ """Prints content of DictConfig using Rich library and its tree structure.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+ print_order (Sequence[str], optional): Determines in what order config components are printed.
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+ save_to_file (bool, optional): Whether to export config to the hydra output folder.
+ """ # noqa: E501
+
+ style = "dim"
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+ queue = []
+
+ # add fields from `print_order` to queue
+ for field in print_order:
+ (
+ queue.append(field)
+ if field in cfg
+ else log.warning(
+ f"Field '{field}' not found in config. "
+ + f"Skipping '{field}' config printing..."
+ )
+ )
+
+ # add all the other fields to queue (not specified in `print_order`)
+ for field in cfg:
+ if field not in queue:
+ queue.append(field)
+
+ # generate config tree from queue
+ for field in queue:
+ branch = tree.add(field, style=style, guide_style=style)
+
+ config_group = cfg[field]
+ if isinstance(config_group, DictConfig):
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+ else:
+ branch_content = str(config_group)
+
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+ # print config tree
+ rich.print(tree)
+
+ # save config tree to file
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+ rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+ """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501
+
+ if not cfg.get("tags"):
+ if "id" in HydraConfig().cfg.hydra.job:
+ raise ValueError("Specify tags before launching a multirun!")
+
+ log.warning("No tags provided in config. Prompting user to input tags...")
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+ tags = [t.strip() for t in tags.split(",") if t != ""]
+
+ with open_dict(cfg):
+ cfg.tags = tags
+
+ log.info(f"Tags: {cfg.tags}")
+
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+ rich.print(cfg.tags, file=file)
diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c3d7a2ab0f707ae92dbde0feb173927720c841
--- /dev/null
+++ b/fish_speech/utils/spectrogram.py
@@ -0,0 +1,122 @@
+import torch
+import torchaudio.functional as F
+from torch import Tensor, nn
+from torchaudio.transforms import MelScale
+
+
+class LinearSpectrogram(nn.Module):
+ def __init__(
+ self,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ center=False,
+ mode="pow2_sqrt",
+ ):
+ super().__init__()
+
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.mode = mode
+
+ self.register_buffer("window", torch.hann_window(win_length), persistent=False)
+
+ def forward(self, y: Tensor) -> Tensor:
+ if y.ndim == 3:
+ y = y.squeeze(1)
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (
+ (self.win_length - self.hop_length) // 2,
+ (self.win_length - self.hop_length + 1) // 2,
+ ),
+ mode="reflect",
+ ).squeeze(1)
+
+ spec = torch.stft(
+ y,
+ self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=self.window,
+ center=self.center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+
+ spec = torch.view_as_real(spec)
+
+ if self.mode == "pow2_sqrt":
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+ return spec
+
+
+class LogMelSpectrogram(nn.Module):
+ def __init__(
+ self,
+ sample_rate=44100,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ n_mels=128,
+ center=False,
+ f_min=0.0,
+ f_max=None,
+ ):
+ super().__init__()
+
+ self.sample_rate = sample_rate
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.n_mels = n_mels
+ self.f_min = f_min
+ self.f_max = f_max or float(sample_rate // 2)
+
+ self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
+
+ fb = F.melscale_fbanks(
+ n_freqs=self.n_fft // 2 + 1,
+ f_min=self.f_min,
+ f_max=self.f_max,
+ n_mels=self.n_mels,
+ sample_rate=self.sample_rate,
+ norm="slaney",
+ mel_scale="slaney",
+ )
+ self.register_buffer(
+ "fb",
+ fb,
+ persistent=False,
+ )
+
+ def compress(self, x: Tensor) -> Tensor:
+ return torch.log(torch.clamp(x, min=1e-5))
+
+ def decompress(self, x: Tensor) -> Tensor:
+ return torch.exp(x)
+
+ def apply_mel_scale(self, x: Tensor) -> Tensor:
+ return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
+
+ def forward(
+ self, x: Tensor, return_linear: bool = False, sample_rate: int = None
+ ) -> Tensor:
+ if sample_rate is not None and sample_rate != self.sample_rate:
+ x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
+
+ linear = self.spectrogram(x)
+ x = self.apply_mel_scale(linear)
+ x = self.compress(x)
+
+ if return_linear:
+ return x, self.compress(linear)
+
+ return x
diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c546bfa1eddd2ac6bf484cce1ec06da1d33fb121
--- /dev/null
+++ b/fish_speech/utils/utils.py
@@ -0,0 +1,114 @@
+import warnings
+from importlib.util import find_spec
+from typing import Callable
+
+from omegaconf import DictConfig
+
+from .logger import RankedLogger
+from .rich_utils import enforce_tags, print_config_tree
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def extras(cfg: DictConfig) -> None:
+ """Applies optional utilities before the task is started.
+
+ Utilities:
+ - Ignoring python warnings
+ - Setting tags from command line
+ - Rich config printing
+ """
+
+ # return if no `extras` config
+ if not cfg.get("extras"):
+ log.warning("Extras config not found! ")
+ return
+
+ # disable python warnings
+ if cfg.extras.get("ignore_warnings"):
+ log.info("Disabling python warnings! ")
+ warnings.filterwarnings("ignore")
+
+ # prompt user to input tags from command line if none are provided in the config
+ if cfg.extras.get("enforce_tags"):
+ log.info("Enforcing tags! ")
+ enforce_tags(cfg, save_to_file=True)
+
+ # pretty print config tree using Rich library
+ if cfg.extras.get("print_config"):
+ log.info("Printing config tree with Rich! ")
+ print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+ """Optional decorator that controls the failure behavior when executing the task function.
+
+ This wrapper can be used to:
+ - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+ - save the exception to a `.log` file
+ - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+ - etc. (adjust depending on your needs)
+
+ Example:
+ ```
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> Tuple[dict, dict]:
+
+ ...
+
+ return metric_dict, object_dict
+ ```
+ """ # noqa: E501
+
+ def wrap(cfg: DictConfig):
+ # execute the task
+ try:
+ metric_dict, object_dict = task_func(cfg=cfg)
+
+ # things to do if exception occurs
+ except Exception as ex:
+ # save exception to `.log` file
+ log.exception("")
+
+ # some hyperparameter combinations might be invalid or
+ # cause out-of-memory errors so when using hparam search
+ # plugins like Optuna, you might want to disable
+ # raising the below exception to avoid multirun failure
+ raise ex
+
+ # things to always do after either success or exception
+ finally:
+ # display output dir path in terminal
+ log.info(f"Output dir: {cfg.paths.run_dir}")
+
+ # always close wandb run (even if exception occurs so multirun won't fail)
+ if find_spec("wandb"): # check if wandb is installed
+ import wandb
+
+ if wandb.run:
+ log.info("Closing wandb!")
+ wandb.finish()
+
+ return metric_dict, object_dict
+
+ return wrap
+
+
+def get_metric_value(metric_dict: dict, metric_name: str) -> float:
+ """Safely retrieves value of the metric logged in LightningModule."""
+
+ if not metric_name:
+ log.info("Metric name is None! Skipping metric value retrieval...")
+ return None
+
+ if metric_name not in metric_dict:
+ raise Exception(
+ f"Metric value not found! \n"
+ "Make sure metric name logged in LightningModule is correct!\n"
+ "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+ )
+
+ metric_value = metric_dict[metric_name].item()
+ log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+ return metric_value
diff --git a/fish_speech/webui/css/style.css b/fish_speech/webui/css/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..3c7a22ecc31881a65a76369b0fd889330a0874c7
--- /dev/null
+++ b/fish_speech/webui/css/style.css
@@ -0,0 +1,161 @@
+:root {
+ --my-200: #80eeee;
+ --my-50: #ecfdf5;
+ --water-width: 300px;
+ --water-heigh: 300px;
+}
+
+
+/* general styled components */
+.tools {
+ align-items: center;
+ justify-content: center;
+}
+
+.gradio-button {
+ max-width: 2.2em;
+ min-width: 2.2em !important;
+ height: 2.4em;
+ align-self: end;
+ line-height: 1em;
+ border-radius: 0.5em;
+
+}
+
+.gradio-button.secondary-down, .gradio-button.secondary-down:hover{
+ box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset;
+}
+
+/* replace original footer with ours */
+a{
+ font-weight: bold;
+ cursor: pointer;
+ color: #030C14 !important;
+}
+
+footer {
+ display: none !important;
+}
+
+#footer{
+ text-align: center;
+}
+
+#footer div{
+ display: inline-block;
+}
+
+#footer .versions{
+ font-size: 85%;
+ opacity: 0.85;
+}
+
+/*@keyframes moveBackground {*/
+/* 0% {*/
+/* background-position: 0 0;*/
+/* }*/
+/* 100% {*/
+/* background-position: -100px 100px;*/
+/* }*/
+/*}*/
+@keyframes moveJellyBackground {
+ 0% {
+ background-position: 0% 50%;
+ }
+ 50% {
+ background-position: 100% 50%;
+ }
+ 100% {
+ background-position: 0% 50%;
+ }
+}
+
+.gradio-container {
+ position: absolute;
+ z-index: 10;
+}
+
+
+.quan {
+ position: absolute;
+ bottom: 0;
+ width: var(--water-width);
+ height: var(--water-heigh);
+ border-radius: 0;
+ /*border: 3px solid rgb(246, 247, 248);*/
+ /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/
+ z-index: 0;
+
+}
+
+.quan:last-child {
+ margin-right: 0;
+}
+
+.shui {
+ position: absolute;
+ top: 0;
+ left: 0;
+ width: 100%;
+ height: 100%;
+ background-color: rgb(23, 106, 201);
+ border-radius: 0;
+ overflow: hidden;
+ z-index: 0;
+}
+
+.shui::after {
+
+ content: '';
+ position: absolute;
+ top: 20%;
+ left: 50%;
+ width: 150%;
+ height: 150%;
+ border-radius: 40%;
+ background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%);
+ animation: shi 5s linear infinite;
+}
+
+@keyframes shi {
+ 0% {
+ transform: translate(-50%, -65%) rotate(0deg);
+ }
+ 100% {
+ transform: translate(-50%, -65%) rotate(360deg);
+ }
+}
+
+.shui::before {
+ content: '';
+ position: absolute;
+ top: 20%;
+ left: 50%;
+ width: 150%;
+ height: 150%;
+ border-radius: 42%;
+ background-color: rgb(240, 228, 228, 0.2);
+ animation: xu 7s linear infinite;
+}
+
+@keyframes xu {
+ 0% {
+ transform: translate(-50%, -60%) rotate(0deg);
+ }
+ 100% {
+ transform: translate(-50%, -60%) rotate(360deg);
+ }
+}
+
+fieldset.data_src div.wrap label {
+ background: #f8bffee0 !important;
+}
+
+.scrollable-component {
+ max-height: 100px;
+ overflow-y: auto;
+}
+
+#file_accordion {
+ max-height: 220px !important;
+}
diff --git a/fish_speech/webui/html/footer.html b/fish_speech/webui/html/footer.html
new file mode 100644
index 0000000000000000000000000000000000000000..ac1745aa6f41f86a17e3d95564c2bf7a8d7bb615
--- /dev/null
+++ b/fish_speech/webui/html/footer.html
@@ -0,0 +1,11 @@
+
+
+
+{versions}
+
diff --git a/fish_speech/webui/js/animate.js b/fish_speech/webui/js/animate.js
new file mode 100644
index 0000000000000000000000000000000000000000..0637a541a8e704632a42b89bdf1471b26e7bb868
--- /dev/null
+++ b/fish_speech/webui/js/animate.js
@@ -0,0 +1,69 @@
+
+function createGradioAnimation() {
+ const params = new URLSearchParams(window.location.search);
+ if (!params.has('__theme')) {
+ params.set('__theme', 'light');
+ window.location.search = params.toString();
+ }
+
+ var gradioApp = document.querySelector('gradio-app');
+ if (gradioApp) {
+
+ document.documentElement.style.setProperty('--my-200', '#80eeee');
+ document.documentElement.style.setProperty('--my-50', '#ecfdf5');
+
+ // gradioApp.style.position = 'relative';
+ // gradioApp.style.backgroundSize = '200% 200%';
+ // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite';
+ // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)';
+ // gradioApp.style.display = 'flex';
+ // gradioApp.style.justifyContent = 'flex-start';
+ // gradioApp.style.flexWrap = 'nowrap';
+ // gradioApp.style.overflowX = 'auto';
+
+ // for (let i = 0; i < 6; i++) {
+ // var quan = document.createElement('div');
+ // quan.className = 'quan';
+ // gradioApp.insertBefore(quan, gradioApp.firstChild);
+ // quan.id = 'quan' + i.toString();
+ // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')';
+ // var quanContainer = document.querySelector('.quan');
+ // if (quanContainer) {
+ // var shui = document.createElement('div');
+ // shui.className = 'shui';
+ // quanContainer.insertBefore(shui, quanContainer.firstChild)
+ // }
+ // }
+ }
+
+ var container = document.createElement('div');
+ container.id = 'gradio-animation';
+ container.style.fontSize = '2em';
+ container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace';
+ container.style.fontWeight = 'bold';
+ container.style.textAlign = 'center';
+ container.style.marginBottom = '20px';
+
+ var text = 'Welcome to Fish-Speech!';
+ for (var i = 0; i < text.length; i++) {
+ (function(i){
+ setTimeout(function(){
+ var letter = document.createElement('span');
+ letter.style.opacity = '0';
+ letter.style.transition = 'opacity 0.5s';
+ letter.innerText = text[i];
+
+ container.appendChild(letter);
+
+ setTimeout(function() {
+ letter.style.opacity = '1';
+ }, 50);
+ }, i * 200);
+ })(i);
+ }
+
+ var gradioContainer = document.querySelector('.gradio-container');
+ gradioContainer.insertBefore(container, gradioContainer.firstChild);
+
+ return 'Animation created';
+}
diff --git a/fish_speech/webui/launch_utils.py b/fish_speech/webui/launch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f57b595a20177800dbedd71faef573ee8398418
--- /dev/null
+++ b/fish_speech/webui/launch_utils.py
@@ -0,0 +1,120 @@
+import importlib.util
+import os
+import subprocess
+import sys
+from functools import lru_cache
+from pathlib import Path
+from typing import Iterable
+
+import gradio as gr
+from gradio.themes.base import Base
+from gradio.themes.utils import colors, fonts, sizes
+
+GIT = (
+ (Path(os.environ.get("GIT_HOME", "")) / "git").resolve()
+ if sys.platform == "win32"
+ else "git"
+)
+GIT = str(GIT)
+
+
+def is_module_installed(module_name: str) -> bool:
+ spec = importlib.util.find_spec(module_name)
+ return spec is not None
+
+
+@lru_cache()
+def commit_hash():
+ try:
+ return subprocess.check_output(
+ [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8"
+ ).strip()
+ except Exception:
+ return ""
+
+
+def versions_html():
+ import torch
+
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
+ commit = commit_hash()
+ hash = commit.strip("'").split(" ")[0]
+
+ return f"""
+version: {hash}
+ •
+python: {python_version}
+ •
+torch: {getattr(torch, '__long_version__',torch.__version__)}
+ •
+gradio: {gr.__version__}
+ •
+author: fishaudio
+"""
+
+
+def version_check(commit):
+ try:
+ import requests
+
+ commits = requests.get(
+ "https://api.github.com/repos/fishaudio/fish-speech/branches/main"
+ ).json()
+ if commit != "" and commits["commit"]["sha"] != commit:
+ print("--------------------------------------------------------")
+ print("| You are not up to date with the most recent release. |")
+ print("| Consider running `git pull` to update. |")
+ print("--------------------------------------------------------")
+ elif commits["commit"]["sha"] == commit:
+ print("You are up to date with the most recent release.")
+ else:
+ print("Not a git clone, can't perform version check.")
+ except Exception as e:
+ print("version check failed", e)
+
+
+class Seafoam(Base):
+ def __init__(
+ self,
+ *,
+ primary_hue: colors.Color | str = colors.emerald,
+ secondary_hue: colors.Color | str = colors.blue,
+ neutral_hue: colors.Color | str = colors.blue,
+ spacing_size: sizes.Size | str = sizes.spacing_md,
+ radius_size: sizes.Size | str = sizes.radius_md,
+ text_size: sizes.Size | str = sizes.text_lg,
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
+ fonts.GoogleFont("Quicksand"),
+ "ui-sans-serif",
+ "sans-serif",
+ ),
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
+ fonts.GoogleFont("IBM Plex Mono"),
+ "ui-monospace",
+ "monospace",
+ ),
+ ):
+ super().__init__(
+ primary_hue=primary_hue,
+ secondary_hue=secondary_hue,
+ neutral_hue=neutral_hue,
+ spacing_size=spacing_size,
+ radius_size=radius_size,
+ text_size=text_size,
+ font=font,
+ font_mono=font_mono,
+ )
+ super().set(
+ button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
+ button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
+ button_primary_text_color="white",
+ button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
+ slider_color="*secondary_300",
+ slider_color_dark="*secondary_600",
+ block_title_text_weight="600",
+ block_border_width="3px",
+ block_shadow="*shadow_drop_lg",
+ button_shadow="*shadow_drop_lg",
+ button_small_padding="0px",
+ button_large_padding="3px",
+ )
diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py
new file mode 100644
index 0000000000000000000000000000000000000000..d66e72f714181b7c4ab46cdcaaf0d64a209cbfdb
--- /dev/null
+++ b/fish_speech/webui/manage.py
@@ -0,0 +1,1230 @@
+from __future__ import annotations
+
+import datetime
+import html
+import json
+import os
+import platform
+import shutil
+import signal
+import subprocess
+import sys
+from pathlib import Path
+
+import gradio as gr
+import psutil
+import yaml
+from loguru import logger
+from tqdm import tqdm
+
+PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
+sys.path.insert(0, "")
+print(sys.path)
+cur_work_dir = Path(os.getcwd()).resolve()
+print("You are in ", str(cur_work_dir))
+
+from fish_speech.i18n import i18n
+from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
+
+config_path = cur_work_dir / "fish_speech" / "configs"
+vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
+llama_yml_path = config_path / "text2semantic_finetune.yaml"
+
+env = os.environ.copy()
+env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
+
+seafoam = Seafoam()
+
+
+def build_html_error_message(error):
+ return f"""
+
+ {html.escape(error)}
+
+ """
+
+
+def build_html_ok_message(msg):
+ return f"""
+
+ {html.escape(msg)}
+
+ """
+
+
+def build_html_href(link, desc, msg):
+ return f"""
+
+ {html.escape(msg)}
+ {desc}
+
+ """
+
+
+def load_data_in_raw(path):
+ with open(path, "r", encoding="utf-8") as file:
+ data = file.read()
+ return str(data)
+
+
+def kill_proc_tree(pid, including_parent=True):
+ try:
+ parent = psutil.Process(pid)
+ except psutil.NoSuchProcess:
+ # Process already terminated
+ return
+
+ children = parent.children(recursive=True)
+ for child in children:
+ try:
+ os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
+ except OSError:
+ pass
+ if including_parent:
+ try:
+ os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
+ except OSError:
+ pass
+
+
+system = platform.system()
+p_label = None
+p_infer = None
+p_tensorboard = None
+
+
+def kill_process(pid):
+ if system == "Windows":
+ cmd = "taskkill /t /f /pid %s" % pid
+ # os.system(cmd)
+ subprocess.run(cmd)
+ else:
+ kill_proc_tree(pid)
+
+
+def change_label(if_label):
+ global p_label
+ if if_label == True and p_label is None:
+ url = "http://localhost:3000"
+ remote_url = "https://text-labeler.pages.dev/"
+ try:
+ p_label = subprocess.Popen(
+ [
+ (
+ "asr-label-linux-x64"
+ if sys.platform == "linux"
+ else "asr-label-win-x64.exe"
+ )
+ ]
+ )
+ except FileNotFoundError:
+ logger.warning("asr-label execution not found!")
+
+ yield build_html_href(
+ link=remote_url,
+ desc=i18n("Optional online ver"),
+ msg=i18n("Opened labeler in browser"),
+ )
+
+ elif if_label == False and p_label is not None:
+ kill_process(p_label.pid)
+ p_label = None
+ yield build_html_ok_message("Nothing")
+
+
+def clean_infer_cache():
+ import tempfile
+
+ temp_dir = Path(tempfile.gettempdir())
+ gradio_dir = str(temp_dir / "gradio")
+ try:
+ shutil.rmtree(gradio_dir)
+ logger.info(f"Deleted cached audios: {gradio_dir}")
+ except PermissionError:
+ logger.info(f"Permission denied: Unable to delete {gradio_dir}")
+ except FileNotFoundError:
+ logger.info(f"{gradio_dir} was not found")
+ except Exception as e:
+ logger.info(f"An error occurred: {e}")
+
+
+def change_infer(
+ if_infer,
+ host,
+ port,
+ infer_decoder_model,
+ infer_decoder_config,
+ infer_llama_model,
+ infer_compile,
+):
+ global p_infer
+ if if_infer == True and p_infer == None:
+ env = os.environ.copy()
+
+ env["GRADIO_SERVER_NAME"] = host
+ env["GRADIO_SERVER_PORT"] = port
+ # 启动第二个进程
+ url = f"http://{host}:{port}"
+ yield build_html_ok_message(
+ i18n("Inferring interface is launched at {}").format(url)
+ )
+
+ clean_infer_cache()
+
+ p_infer = subprocess.Popen(
+ [
+ PYTHON,
+ "tools/webui.py",
+ "--decoder-checkpoint-path",
+ infer_decoder_model,
+ "--decoder-config-name",
+ infer_decoder_config,
+ "--llama-checkpoint-path",
+ infer_llama_model,
+ ]
+ + (["--compile"] if infer_compile == "Yes" else []),
+ env=env,
+ )
+
+ elif if_infer == False and p_infer is not None:
+ kill_process(p_infer.pid)
+ p_infer = None
+ yield build_html_error_message(i18n("Infer interface is closed"))
+
+
+js = load_data_in_raw("fish_speech/webui/js/animate.js")
+css = load_data_in_raw("fish_speech/webui/css/style.css")
+
+data_pre_output = (cur_work_dir / "data").resolve()
+default_model_output = (cur_work_dir / "results").resolve()
+default_filelist = data_pre_output / "detect.list"
+data_pre_output.mkdir(parents=True, exist_ok=True)
+
+items = []
+dict_items = {}
+
+
+def load_yaml_data_in_fact(yml_path):
+ with open(yml_path, "r", encoding="utf-8") as file:
+ yml = yaml.safe_load(file)
+ return yml
+
+
+def write_yaml_data_in_fact(yml, yml_path):
+ with open(yml_path, "w", encoding="utf-8") as file:
+ yaml.safe_dump(yml, file, allow_unicode=True)
+ return yml
+
+
+def generate_tree(directory, depth=0, max_depth=None, prefix=""):
+ if max_depth is not None and depth > max_depth:
+ return ""
+
+ tree_str = ""
+ files = []
+ directories = []
+ for item in os.listdir(directory):
+ if os.path.isdir(os.path.join(directory, item)):
+ directories.append(item)
+ else:
+ files.append(item)
+
+ entries = directories + files
+ for i, entry in enumerate(entries):
+ connector = "├── " if i < len(entries) - 1 else "└── "
+ tree_str += f"{prefix}{connector}{entry} "
+ if i < len(directories):
+ extension = "│ " if i < len(entries) - 1 else " "
+ tree_str += generate_tree(
+ os.path.join(directory, entry),
+ depth + 1,
+ max_depth,
+ prefix=prefix + extension,
+ )
+ return tree_str
+
+
+def new_explorer(data_path, max_depth):
+ return gr.Markdown(
+ elem_classes=["scrollable-component"],
+ value=generate_tree(data_path, max_depth=max_depth),
+ )
+
+
+def add_item(
+ folder: str,
+ method: str,
+ label_lang: str,
+ if_initial_prompt: bool,
+ initial_prompt: str | None,
+):
+ folder = folder.strip(" ").strip('"')
+
+ folder_path = Path(folder)
+
+ if folder and folder not in items and data_pre_output not in folder_path.parents:
+ if folder_path.is_dir():
+ items.append(folder)
+ dict_items[folder] = dict(
+ type="folder",
+ method=method,
+ label_lang=label_lang,
+ initial_prompt=initial_prompt if if_initial_prompt else None,
+ )
+ elif folder:
+ err = folder
+ return gr.Checkboxgroup(choices=items), build_html_error_message(
+ i18n("Invalid path: {}").format(err)
+ )
+
+ formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+ logger.info("After Adding: " + formatted_data)
+ gr.Info(formatted_data)
+ return gr.Checkboxgroup(choices=items), build_html_ok_message(
+ i18n("Added path successfully!")
+ )
+
+
+def remove_items(selected_items):
+ global items, dict_items
+ to_remove = [item for item in items if item in selected_items]
+ for item in to_remove:
+ del dict_items[item]
+ items = [item for item in items if item in dict_items.keys()]
+ formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+ logger.info(formatted_data)
+ gr.Warning("After Removing: " + formatted_data)
+ return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
+ i18n("Removed path successfully!")
+ )
+
+
+def show_selected(options):
+ selected_options = ", ".join(options)
+
+ if options:
+ return i18n("Selected: {}").format(selected_options)
+ else:
+ return i18n("No selected options")
+
+
+from pydub import AudioSegment
+
+
+def convert_to_mono_in_place(audio_path: Path):
+ audio = AudioSegment.from_file(audio_path)
+ if audio.channels > 1:
+ mono_audio = audio.set_channels(1)
+ mono_audio.export(audio_path, format=audio_path.suffix[1:])
+ logger.info(f"Convert {audio_path} successfully")
+
+
+def list_copy(list_file_path, method):
+ wav_root = data_pre_output
+ lst = []
+ with list_file_path.open("r", encoding="utf-8") as file:
+ for line in tqdm(file, desc="Processing audio/transcript"):
+ wav_path, speaker_name, language, text = line.strip().split("|")
+ original_wav_path = Path(wav_path)
+ target_wav_path = (
+ wav_root / original_wav_path.parent.name / original_wav_path.name
+ )
+ lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
+ if target_wav_path.is_file():
+ continue
+ target_wav_path.parent.mkdir(parents=True, exist_ok=True)
+ if method == i18n("Copy"):
+ shutil.copy(original_wav_path, target_wav_path)
+ else:
+ shutil.move(original_wav_path, target_wav_path.parent)
+ convert_to_mono_in_place(target_wav_path)
+ original_lab_path = original_wav_path.with_suffix(".lab")
+ target_lab_path = (
+ wav_root
+ / original_wav_path.parent.name
+ / original_wav_path.with_suffix(".lab").name
+ )
+ if target_lab_path.is_file():
+ continue
+ if method == i18n("Copy"):
+ shutil.copy(original_lab_path, target_lab_path)
+ else:
+ shutil.move(original_lab_path, target_lab_path.parent)
+
+ if method == i18n("Move"):
+ with list_file_path.open("w", encoding="utf-8") as file:
+ file.writelines("\n".join(lst))
+
+ del lst
+ return build_html_ok_message(i18n("Use filelist"))
+
+
+def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
+ global dict_items
+ data_path = Path(data_path)
+ gr.Warning("Pre-processing begins...")
+ for item, content in dict_items.items():
+ item_path = Path(item)
+ tar_path = data_path / item_path.name
+
+ if content["type"] == "folder" and item_path.is_dir():
+ if content["method"] == i18n("Copy"):
+ os.makedirs(tar_path, exist_ok=True)
+ shutil.copytree(
+ src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
+ )
+ elif not tar_path.is_dir():
+ shutil.move(src=str(item_path), dst=str(tar_path))
+
+ for suf in ["wav", "flac", "mp3"]:
+ for audio_path in tar_path.glob(f"**/*.{suf}"):
+ convert_to_mono_in_place(audio_path)
+
+ cur_lang = content["label_lang"]
+ initial_prompt = content["initial_prompt"]
+
+ transcribe_cmd = [
+ PYTHON,
+ "tools/whisper_asr.py",
+ "--model-size",
+ label_model,
+ "--device",
+ label_device,
+ "--audio-dir",
+ tar_path,
+ "--save-dir",
+ tar_path,
+ "--language",
+ cur_lang,
+ ]
+
+ if initial_prompt is not None:
+ transcribe_cmd += ["--initial-prompt", initial_prompt]
+
+ if cur_lang != "IGNORE":
+ try:
+ gr.Warning("Begin To Transcribe")
+ subprocess.run(
+ transcribe_cmd,
+ env=env,
+ )
+ except Exception:
+ print("Transcription error occurred")
+
+ elif content["type"] == "file" and item_path.is_file():
+ list_copy(item_path, content["method"])
+
+ return build_html_ok_message(i18n("Move files successfully")), new_explorer(
+ data_path, max_depth=max_depth
+ )
+
+
+def generate_folder_name():
+ now = datetime.datetime.now()
+ folder_name = now.strftime("%Y%m%d_%H%M%S")
+ return folder_name
+
+
+def train_process(
+ data_path: str,
+ option: str,
+ # llama config
+ llama_ckpt,
+ llama_base_config,
+ llama_lr,
+ llama_maxsteps,
+ llama_data_num_workers,
+ llama_data_batch_size,
+ llama_data_max_length,
+ llama_precision,
+ llama_check_interval,
+ llama_grad_batches,
+ llama_use_speaker,
+ llama_use_lora,
+):
+
+ backend = "nccl" if sys.platform == "linux" else "gloo"
+
+ new_project = generate_folder_name()
+ print("New Project Name: ", new_project)
+
+ if option == "VQGAN":
+ msg = "Skipped VQGAN Training."
+ gr.Warning(msg)
+ logger.info(msg)
+
+ if option == "LLAMA":
+ msg = "LLAMA Training begins..."
+ gr.Warning(msg)
+ logger.info(msg)
+ subprocess.run(
+ [
+ PYTHON,
+ "tools/vqgan/extract_vq.py",
+ str(data_pre_output),
+ "--num-workers",
+ "1",
+ "--batch-size",
+ "16",
+ "--config-name",
+ "firefly_gan_vq",
+ "--checkpoint-path",
+ "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ ]
+ )
+
+ subprocess.run(
+ [
+ PYTHON,
+ "tools/llama/build_dataset.py",
+ "--input",
+ str(data_pre_output),
+ "--text-extension",
+ ".lab",
+ "--num-workers",
+ "16",
+ ]
+ )
+ ckpt_path = "checkpoints/fish-speech-1.2-sft/model.pth"
+ lora_prefix = "lora_" if llama_use_lora else ""
+ llama_name = lora_prefix + "text2semantic_" + new_project
+ latest = next(
+ iter(
+ sorted(
+ [
+ str(p.relative_to("results"))
+ for p in Path("results").glob(lora_prefix + "text2sem*/")
+ ],
+ reverse=True,
+ )
+ ),
+ llama_name,
+ )
+ project = (
+ llama_name
+ if llama_ckpt == i18n("new")
+ else (
+ latest
+ if llama_ckpt == i18n("latest")
+ else Path(llama_ckpt).relative_to("results")
+ )
+ )
+ logger.info(project)
+ train_cmd = [
+ PYTHON,
+ "fish_speech/train.py",
+ "--config-name",
+ "text2semantic_finetune",
+ f"project={project}",
+ f"trainer.strategy.process_group_backend={backend}",
+ f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+ f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+ f"model.optimizer.lr={llama_lr}",
+ f"trainer.max_steps={llama_maxsteps}",
+ f"data.num_workers={llama_data_num_workers}",
+ f"data.batch_size={llama_data_batch_size}",
+ f"max_length={llama_data_max_length}",
+ f"trainer.precision={llama_precision}",
+ f"trainer.val_check_interval={llama_check_interval}",
+ f"trainer.accumulate_grad_batches={llama_grad_batches}",
+ f"train_dataset.interactive_prob={llama_use_speaker}",
+ ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
+ logger.info(train_cmd)
+ subprocess.run(train_cmd)
+
+ return build_html_ok_message(i18n("Training stopped"))
+
+
+def tensorboard_process(
+ if_tensorboard: bool,
+ tensorboard_dir: str,
+ host: str,
+ port: str,
+):
+ global p_tensorboard
+ if if_tensorboard == True and p_tensorboard == None:
+ url = f"http://{host}:{port}"
+ yield build_html_ok_message(
+ i18n("Tensorboard interface is launched at {}").format(url)
+ )
+ prefix = ["tensorboard"]
+ if Path("fishenv").exists():
+ prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
+
+ p_tensorboard = subprocess.Popen(
+ prefix
+ + [
+ "--logdir",
+ tensorboard_dir,
+ "--host",
+ host,
+ "--port",
+ port,
+ "--reload_interval",
+ "120",
+ ]
+ )
+ elif if_tensorboard == False and p_tensorboard != None:
+ kill_process(p_tensorboard.pid)
+ p_tensorboard = None
+ yield build_html_error_message(i18n("Tensorboard interface is closed"))
+
+
+def fresh_tb_dir():
+ return gr.Dropdown(
+ choices=[str(p) for p in Path("results").glob("**/tensorboard/")]
+ )
+
+
+def list_decoder_models():
+ paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
+ if not paths:
+ logger.warning("No decoder model found")
+ return paths
+
+
+def list_llama_models():
+ choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
+ choices = sorted(choices, reverse=True)
+ if not choices:
+ logger.warning("No LLaMA model found")
+ return choices
+
+
+def list_lora_llama_models():
+ choices = sorted(
+ [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
+ )
+ if not choices:
+ logger.warning("No LoRA LLaMA model found")
+ return choices
+
+
+def fresh_decoder_model():
+ return gr.Dropdown(choices=list_decoder_models())
+
+
+def fresh_llama_ckpt(llama_use_lora):
+ return gr.Dropdown(
+ choices=[i18n("latest"), i18n("new")]
+ + (
+ [str(p) for p in Path("results").glob("text2sem*/")]
+ if not llama_use_lora
+ else [str(p) for p in Path("results").glob("lora_*/")]
+ )
+ )
+
+
+def fresh_llama_model():
+ return gr.Dropdown(choices=list_llama_models())
+
+
+def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
+ if (
+ lora_weight is None
+ or not Path(lora_weight).exists()
+ or not Path(llama_weight).exists()
+ ):
+ return build_html_error_message(
+ i18n(
+ "Path error, please check the model file exists in the corresponding path"
+ )
+ )
+ gr.Warning("Merging begins...")
+ merge_cmd = [
+ PYTHON,
+ "tools/llama/merge_lora.py",
+ "--lora-config",
+ "r_8_alpha_16",
+ "--lora-weight",
+ lora_weight,
+ "--output",
+ llama_lora_output + "_" + generate_folder_name(),
+ ]
+ logger.info(merge_cmd)
+ subprocess.run(merge_cmd)
+ return build_html_ok_message(i18n("Merge successfully"))
+
+
+def llama_quantify(llama_weight, quantify_mode):
+ if llama_weight is None or not Path(llama_weight).exists():
+ return build_html_error_message(
+ i18n(
+ "Path error, please check the model file exists in the corresponding path"
+ )
+ )
+
+ gr.Warning("Quantifying begins...")
+
+ now = generate_folder_name()
+ quantify_cmd = [
+ PYTHON,
+ "tools/llama/quantize.py",
+ "--checkpoint-path",
+ llama_weight,
+ "--mode",
+ quantify_mode,
+ "--timestamp",
+ now,
+ ]
+ logger.info(quantify_cmd)
+ subprocess.run(quantify_cmd)
+ if quantify_mode == "int8":
+ quantize_path = str(
+ Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
+ )
+ else:
+ quantize_path = str(
+ Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
+ )
+ return build_html_ok_message(
+ i18n("Quantify successfully") + f"Path: {quantize_path}"
+ )
+
+
+init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
+init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
+
+with gr.Blocks(
+ head="",
+ js=js,
+ theme=seafoam,
+ analytics_enabled=False,
+ title="Fish Speech",
+) as demo:
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
+ with gr.Row():
+ textbox = gr.Textbox(
+ label="\U0000270F "
+ + i18n("Input Audio & Source Path for Transcription"),
+ info=i18n("Speaker is identified by the folder name"),
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ with gr.Column():
+ output_radio = gr.Radio(
+ label="\U0001F4C1 "
+ + i18n("Select source file processing method"),
+ choices=[i18n("Copy"), i18n("Move")],
+ value=i18n("Copy"),
+ interactive=True,
+ )
+ with gr.Column():
+ error = gr.HTML(label=i18n("Error Message"))
+ if_label = gr.Checkbox(
+ label=i18n("Open Labeler WebUI"), scale=0, show_label=True
+ )
+
+ with gr.Row():
+ label_device = gr.Dropdown(
+ label=i18n("Labeling Device"),
+ info=i18n(
+ "It is recommended to use CUDA, if you have low configuration, use CPU"
+ ),
+ choices=["cpu", "cuda"],
+ value="cuda",
+ interactive=True,
+ )
+ label_model = gr.Dropdown(
+ label=i18n("Whisper Model"),
+ info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
+ choices=["large-v3", "medium"],
+ value="large-v3",
+ interactive=True,
+ )
+ label_radio = gr.Dropdown(
+ label=i18n("Optional Label Language"),
+ info=i18n(
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
+ ),
+ choices=[
+ (i18n("Chinese"), "zh"),
+ (i18n("English"), "en"),
+ (i18n("Japanese"), "ja"),
+ (i18n("Disabled"), "IGNORE"),
+ (i18n("auto"), "auto"),
+ ],
+ value="IGNORE",
+ interactive=True,
+ )
+
+ with gr.Row():
+ if_initial_prompt = gr.Checkbox(
+ value=False,
+ label=i18n("Enable Initial Prompt"),
+ min_width=120,
+ scale=0,
+ )
+ initial_prompt = gr.Textbox(
+ label=i18n("Initial Prompt"),
+ info=i18n(
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
+ ),
+ placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
+ interactive=False,
+ )
+
+ with gr.Row():
+ add_button = gr.Button(
+ "\U000027A1 " + i18n("Add to Processing Area"),
+ variant="primary",
+ )
+ remove_button = gr.Button(
+ "\U000026D4 " + i18n("Remove Selected Data")
+ )
+
+ with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
+ with gr.Row():
+ model_type_radio = gr.Radio(
+ label=i18n(
+ "Select the model to be trained (Depending on the Tab page you are on)"
+ ),
+ interactive=False,
+ choices=["VQGAN", "LLAMA"],
+ value="VQGAN",
+ )
+ with gr.Row():
+ with gr.Tabs():
+ with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
+ gr.HTML("You don't need to train this model!")
+
+ with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
+ with gr.Row(equal_height=False):
+ llama_use_lora = gr.Checkbox(
+ label=i18n("Use LoRA"),
+ info=i18n(
+ "Use LoRA can save GPU memory, but may reduce the quality of the model"
+ ),
+ value=True,
+ interactive=False,
+ )
+ llama_ckpt = gr.Dropdown(
+ label=i18n("Select LLAMA ckpt"),
+ choices=[i18n("latest"), i18n("new")]
+ + [
+ str(p)
+ for p in Path("results").glob("text2sem*/")
+ ]
+ + [str(p) for p in Path("results").glob("lora*/")],
+ value=i18n("latest"),
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_lr_slider = gr.Slider(
+ label=i18n("Initial Learning Rate"),
+ interactive=True,
+ minimum=1e-5,
+ maximum=1e-4,
+ step=1e-5,
+ value=init_llama_yml["model"]["optimizer"]["lr"],
+ )
+ llama_maxsteps_slider = gr.Slider(
+ label=i18n("Maximum Training Steps"),
+ interactive=True,
+ minimum=50,
+ maximum=10000,
+ step=50,
+ value=init_llama_yml["trainer"]["max_steps"],
+ )
+ with gr.Row(equal_height=False):
+ llama_base_config = gr.Dropdown(
+ label=i18n("Model Size"),
+ choices=[
+ "text2semantic_finetune",
+ ],
+ value="text2semantic_finetune",
+ )
+ llama_data_num_workers_slider = gr.Slider(
+ label=i18n("Number of Workers"),
+ minimum=1,
+ maximum=16,
+ step=1,
+ value=(
+ init_llama_yml["data"]["num_workers"]
+ if sys.platform == "linux"
+ else 1
+ ),
+ )
+ with gr.Row(equal_height=False):
+ llama_data_batch_size_slider = gr.Slider(
+ label=i18n("Batch Size"),
+ interactive=True,
+ minimum=1,
+ maximum=32,
+ step=1,
+ value=init_llama_yml["data"]["batch_size"],
+ )
+ llama_data_max_length_slider = gr.Slider(
+ label=i18n("Maximum Length per Sample"),
+ interactive=True,
+ minimum=1024,
+ maximum=4096,
+ step=128,
+ value=init_llama_yml["max_length"],
+ )
+ with gr.Row(equal_height=False):
+ llama_precision_dropdown = gr.Dropdown(
+ label=i18n("Precision"),
+ info=i18n(
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
+ ),
+ interactive=True,
+ choices=["32", "bf16-true", "16-mixed"],
+ value="bf16-true",
+ )
+ llama_check_interval_slider = gr.Slider(
+ label=i18n("Save model every n steps"),
+ interactive=True,
+ minimum=50,
+ maximum=1000,
+ step=50,
+ value=init_llama_yml["trainer"][
+ "val_check_interval"
+ ],
+ )
+ with gr.Row(equal_height=False):
+ llama_grad_batches = gr.Slider(
+ label=i18n("Accumulate Gradient Batches"),
+ interactive=True,
+ minimum=1,
+ maximum=20,
+ step=1,
+ value=init_llama_yml["trainer"][
+ "accumulate_grad_batches"
+ ],
+ )
+ llama_use_speaker = gr.Slider(
+ label=i18n(
+ "Probability of applying Speaker Condition"
+ ),
+ interactive=True,
+ minimum=0.1,
+ maximum=1.0,
+ step=0.05,
+ value=init_llama_yml["train_dataset"][
+ "interactive_prob"
+ ],
+ )
+
+ with gr.Tab(label=i18n("Merge LoRA"), id=4):
+ with gr.Row(equal_height=False):
+ llama_weight = gr.Dropdown(
+ label=i18n("Base LLAMA Model"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=[
+ "checkpoints/fish-speech-1.2-sft/model.pth",
+ ],
+ value="checkpoints/fish-speech-1.2-sft/model.pth",
+ allow_custom_value=True,
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ lora_weight = gr.Dropdown(
+ label=i18n("LoRA Model to be merged"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=[
+ str(p)
+ for p in Path("results").glob("lora*/**/*.ckpt")
+ ],
+ allow_custom_value=True,
+ interactive=True,
+ )
+ lora_llama_config = gr.Dropdown(
+ label=i18n("LLAMA Model Config"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=[
+ "text2semantic_finetune",
+ ],
+ value="text2semantic_finetune",
+ allow_custom_value=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_lora_output = gr.Dropdown(
+ label=i18n("Output Path"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ value="checkpoints/merged",
+ choices=["checkpoints/merged"],
+ allow_custom_value=True,
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_lora_merge_btn = gr.Button(
+ value=i18n("Merge"), variant="primary"
+ )
+
+ with gr.Tab(label=i18n("Model Quantization"), id=5):
+ with gr.Row(equal_height=False):
+ llama_weight_to_quantify = gr.Dropdown(
+ label=i18n("Base LLAMA Model"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=list_llama_models(),
+ value="checkpoints/fish-speech-1.2-sft",
+ allow_custom_value=True,
+ interactive=True,
+ )
+ quantify_mode = gr.Dropdown(
+ label=i18n("Post-quantification Precision"),
+ info=i18n(
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
+ ),
+ choices=["int8", "int4"],
+ value="int8",
+ allow_custom_value=False,
+ interactive=True,
+ )
+ with gr.Row(equal_height=False):
+ llama_quantify_btn = gr.Button(
+ value=i18n("Quantify"), variant="primary"
+ )
+
+ with gr.Tab(label="Tensorboard", id=6):
+ with gr.Row(equal_height=False):
+ tb_host = gr.Textbox(
+ label=i18n("Tensorboard Host"), value="127.0.0.1"
+ )
+ tb_port = gr.Textbox(
+ label=i18n("Tensorboard Port"), value="11451"
+ )
+ with gr.Row(equal_height=False):
+ tb_dir = gr.Dropdown(
+ label=i18n("Tensorboard Log Path"),
+ allow_custom_value=True,
+ choices=[
+ str(p)
+ for p in Path("results").glob("**/tensorboard/")
+ ],
+ )
+ with gr.Row(equal_height=False):
+ if_tb = gr.Checkbox(
+ label=i18n("Open Tensorboard"),
+ )
+
+ with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")):
+ with gr.Column():
+ with gr.Row():
+ with gr.Accordion(
+ label="\U0001F5A5 "
+ + i18n("Inference Server Configuration"),
+ open=False,
+ ):
+ with gr.Row():
+ infer_host_textbox = gr.Textbox(
+ label=i18n("WebUI Host"), value="127.0.0.1"
+ )
+ infer_port_textbox = gr.Textbox(
+ label=i18n("WebUI Port"), value="7862"
+ )
+ with gr.Row():
+ infer_decoder_model = gr.Dropdown(
+ label=i18n("Decoder Model Path"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ choices=list_decoder_models(),
+ value="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ allow_custom_value=True,
+ )
+ infer_decoder_config = gr.Dropdown(
+ label=i18n("Decoder Model Config"),
+ info=i18n("Changing with the Model Path"),
+ value="firefly_gan_vq",
+ choices=[
+ "firefly_gan_vq",
+ ],
+ allow_custom_value=True,
+ )
+ with gr.Row():
+ infer_llama_model = gr.Dropdown(
+ label=i18n("LLAMA Model Path"),
+ info=i18n(
+ "Type the path or select from the dropdown"
+ ),
+ value="checkpoints/fish-speech-1.2-sft",
+ choices=list_llama_models(),
+ allow_custom_value=True,
+ )
+
+ with gr.Row():
+ infer_compile = gr.Radio(
+ label=i18n("Compile Model"),
+ info=i18n(
+ "Compile the model can significantly reduce the inference time, but will increase cold start time"
+ ),
+ choices=["Yes", "No"],
+ value=(
+ "Yes" if (sys.platform == "linux") else "No"
+ ),
+ interactive=is_module_installed("triton"),
+ )
+
+ with gr.Row():
+ infer_checkbox = gr.Checkbox(
+ label=i18n("Open Inference Server")
+ )
+ infer_error = gr.HTML(label=i18n("Inference Server Error"))
+
+ with gr.Column():
+ train_error = gr.HTML(label=i18n("Training Error"))
+ checkbox_group = gr.CheckboxGroup(
+ label="\U0001F4CA " + i18n("Data Source"),
+ info=i18n(
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list."
+ ),
+ elem_classes=["data_src"],
+ )
+ train_box = gr.Textbox(
+ label=i18n("Data Preprocessing Path"),
+ value=str(data_pre_output),
+ interactive=False,
+ )
+ model_box = gr.Textbox(
+ label="\U0001F4BE " + i18n("Model Output Path"),
+ value=str(default_model_output),
+ interactive=False,
+ )
+
+ with gr.Accordion(
+ i18n(
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)"
+ ),
+ elem_classes=["scrollable-component"],
+ elem_id="file_accordion",
+ ):
+ tree_slider = gr.Slider(
+ minimum=0,
+ maximum=3,
+ value=0,
+ step=1,
+ show_label=False,
+ container=False,
+ )
+ file_markdown = new_explorer(str(data_pre_output), 0)
+ with gr.Row(equal_height=False):
+ admit_btn = gr.Button(
+ "\U00002705 " + i18n("File Preprocessing"),
+ variant="primary",
+ )
+ fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
+ help_button = gr.Button("\U00002753", scale=0, min_width=80) # question
+ train_btn = gr.Button(i18n("Start Training"), variant="primary")
+
+ footer = load_data_in_raw("fish_speech/webui/html/footer.html")
+ footer = footer.format(
+ versions=versions_html(),
+ api_docs="https://speech.fish.audio/inference/#http-api",
+ )
+ gr.HTML(footer, elem_id="footer")
+ vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
+ llama_page.select(lambda: "LLAMA", None, model_type_radio)
+ add_button.click(
+ fn=add_item,
+ inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
+ outputs=[checkbox_group, error],
+ )
+ remove_button.click(
+ fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
+ )
+ checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
+ help_button.click(
+ fn=None,
+ js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
+ 'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
+ )
+ if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
+ if_initial_prompt.change(
+ fn=lambda x: gr.Textbox(value="", interactive=x),
+ inputs=[if_initial_prompt],
+ outputs=[initial_prompt],
+ )
+ train_btn.click(
+ fn=train_process,
+ inputs=[
+ train_box,
+ model_type_radio,
+ # llama config
+ llama_ckpt,
+ llama_base_config,
+ llama_lr_slider,
+ llama_maxsteps_slider,
+ llama_data_num_workers_slider,
+ llama_data_batch_size_slider,
+ llama_data_max_length_slider,
+ llama_precision_dropdown,
+ llama_check_interval_slider,
+ llama_grad_batches,
+ llama_use_speaker,
+ llama_use_lora,
+ ],
+ outputs=[train_error],
+ )
+ if_tb.change(
+ fn=tensorboard_process,
+ inputs=[if_tb, tb_dir, tb_host, tb_port],
+ outputs=[train_error],
+ )
+ tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
+ infer_decoder_model.change(
+ fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
+ )
+ infer_llama_model.change(
+ fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
+ )
+ llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
+ admit_btn.click(
+ fn=check_files,
+ inputs=[train_box, tree_slider, label_model, label_device],
+ outputs=[error, file_markdown],
+ )
+ fresh_btn.click(
+ fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
+ )
+ llama_use_lora.change(
+ fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
+ )
+ llama_ckpt.change(
+ fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
+ )
+ lora_weight.change(
+ fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
+ inputs=[],
+ outputs=[lora_weight],
+ )
+ llama_lora_merge_btn.click(
+ fn=llama_lora_merge,
+ inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
+ outputs=[train_error],
+ )
+ llama_quantify_btn.click(
+ fn=llama_quantify,
+ inputs=[llama_weight_to_quantify, quantify_mode],
+ outputs=[train_error],
+ )
+ infer_checkbox.change(
+ fn=change_infer,
+ inputs=[
+ infer_checkbox,
+ infer_host_textbox,
+ infer_port_textbox,
+ infer_decoder_model,
+ infer_decoder_config,
+ infer_llama_model,
+ infer_compile,
+ ],
+ outputs=[infer_error],
+ )
+
+demo.launch(inbrowser=True)
diff --git a/inference.ipynb b/inference.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..fa396ead4b59709878a9adce396d5bd06e1b7841
--- /dev/null
+++ b/inference.ipynb
@@ -0,0 +1,210 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fish Speech"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### For Windows User / win用户"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "bat"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!chcp 65001"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### For Linux User / Linux 用户"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import locale\n",
+ "locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Prepare Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# For Chinese users, you probably want to use mirror to accelerate downloading\n",
+ "# !set HF_ENDPOINT=https://hf-mirror.com\n",
+ "# !export HF_ENDPOINT=https://hf-mirror.com \n",
+ "\n",
+ "!huggingface-cli download fishaudio/fish-speech-1.2-sft --local-dir checkpoints/fish-speech-1.2-sft/"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## WebUI Inference\n",
+ "\n",
+ "> You can use --compile to fuse CUDA kernels for faster inference (10x)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!python tools/webui.py \\\n",
+ " --llama-checkpoint-path checkpoints/fish-speech-1.2-sft \\\n",
+ " --decoder-checkpoint-path checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth \\\n",
+ " # --compile"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Break-down CLI Inference"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. Encode reference audio: / 从语音生成 prompt: \n",
+ "\n",
+ "You should get a `fake.npy` file.\n",
+ "\n",
+ "你应该能得到一个 `fake.npy` 文件."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "## Enter the path to the audio file here\n",
+ "src_audio = r\"D:\\PythonProject\\\\vo_hutao_draw_appear.wav\"\n",
+ "\n",
+ "!python tools/vqgan/inference.py \\\n",
+ " -i {src_audio} \\\n",
+ " --checkpoint-path \"checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth\"\n",
+ "\n",
+ "from IPython.display import Audio, display\n",
+ "audio = Audio(filename=\"fake.wav\")\n",
+ "display(audio)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. Generate semantic tokens from text: / 从文本生成语义 token:\n",
+ "\n",
+ "> This command will create a codes_N file in the working directory, where N is an integer starting from 0.\n",
+ "\n",
+ "> You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~300 tokens/second).\n",
+ "\n",
+ "> 该命令会在工作目录下创建 codes_N 文件, 其中 N 是从 0 开始的整数.\n",
+ "\n",
+ "> 您可以使用 `--compile` 来融合 cuda 内核以实现更快的推理 (~30 tokens/秒 -> ~300 tokens/秒)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!python tools/llama/generate.py \\\n",
+ " --text \"hello world\" \\\n",
+ " --prompt-text \"The text corresponding to reference audio\" \\\n",
+ " --prompt-tokens \"fake.npy\" \\\n",
+ " --checkpoint-path \"checkpoints/fish-speech-1.2-sft\" \\\n",
+ " --num-samples 2\n",
+ " # --compile"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. Generate speecj from semantic tokens: / 从语义 token 生成人声:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!python tools/vqgan/inference.py \\\n",
+ " -i \"codes_0.npy\" \\\n",
+ " --checkpoint-path \"checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth\"\n",
+ "\n",
+ "from IPython.display import Audio, display\n",
+ "audio = Audio(filename=\"fake.wav\")\n",
+ "display(audio)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/install_env.bat b/install_env.bat
new file mode 100644
index 0000000000000000000000000000000000000000..f59257991b6309aafea3439a82e2f3d4768077c5
--- /dev/null
+++ b/install_env.bat
@@ -0,0 +1,284 @@
+@echo off
+chcp 65001
+
+set USE_MIRROR=true
+set INSTALL_TYPE=preview
+echo "USE_MIRROR: %USE_MIRROR%"
+echo "INSTALL_TYPE: %INSTALL_TYPE%"
+setlocal enabledelayedexpansion
+
+cd /D "%~dp0"
+
+set PATH="%PATH%";%SystemRoot%\system32
+
+echo %PATH%
+
+
+echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && (
+ echo.
+ echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && (
+ goto end
+ )
+)
+
+
+set TMP=%CD%\fishenv
+set TEMP=%CD%\fishenv
+
+(call conda deactivate && call conda deactivate && call conda deactivate) 2>nul
+
+set INSTALL_DIR=%cd%\fishenv
+set CONDA_ROOT_PREFIX=%cd%\fishenv\conda
+set INSTALL_ENV_DIR=%cd%\fishenv\env
+set PIP_CMD=%cd%\fishenv\env\python -m pip
+set PYTHON_CMD=%cd%\fishenv\env\python
+set API_FLAG_PATH=%~dp0API_FLAGS.txt
+set MINICONDA_DOWNLOAD_URL=https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py310_23.3.1-0-Windows-x86_64.exe
+set MINICONDA_CHECKSUM=307194e1f12bbeb52b083634e89cc67db4f7980bd542254b43d3309eaf7cb358
+set conda_exists=F
+
+call "%CONDA_ROOT_PREFIX%\_conda.exe" --version >nul 2>&1
+if "%ERRORLEVEL%" EQU "0" set conda_exists=T
+
+if "%conda_exists%" == "F" (
+ echo.
+ echo Downloading Miniconda...
+ mkdir "%INSTALL_DIR%" 2>nul
+ call curl -Lk "%MINICONDA_DOWNLOAD_URL%" > "%INSTALL_DIR%\miniconda_installer.exe"
+ if errorlevel 1 (
+ echo.
+ echo Failed to download miniconda.
+ goto end
+ )
+ for /f %%a in ('
+ certutil -hashfile "%INSTALL_DIR%\miniconda_installer.exe" sha256
+ ^| find /i /v " "
+ ^| find /i "%MINICONDA_CHECKSUM%"
+ ') do (
+ set "hash=%%a"
+ )
+ if not defined hash (
+ echo.
+ echo Miniconda hash mismatched!
+ del "%INSTALL_DIR%\miniconda_installer.exe"
+ goto end
+ ) else (
+ echo.
+ echo Miniconda hash matched successfully.
+ )
+ echo Downloaded "%CONDA_ROOT_PREFIX%"
+ start /wait "" "%INSTALL_DIR%\miniconda_installer.exe" /InstallationType=JustMe /NoShortcuts=1 /AddToPath=0 /RegisterPython=0 /NoRegistry=1 /S /D=%CONDA_ROOT_PREFIX%
+
+ call "%CONDA_ROOT_PREFIX%\_conda.exe" --version
+ if errorlevel 1 (
+ echo.
+ echo Cannot install Miniconda.
+ goto end
+ ) else (
+ echo.
+ echo Miniconda Install success.
+ )
+
+ del "%INSTALL_DIR%\miniconda_installer.exe"
+)
+
+
+if not exist "%INSTALL_ENV_DIR%" (
+ echo.
+ echo Creating Conda Environment...
+ call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ python=3.10
+
+ if errorlevel 1 (
+ echo.
+ echo Failed to Create Environment.
+ goto end
+ )
+)
+
+if not exist "%INSTALL_ENV_DIR%\python.exe" (
+ echo.
+ echo Conda Env does not exist.
+ goto end
+)
+
+set PYTHONNOUSERSITE=1
+set PYTHONPATH=
+set PYTHONHOME=
+set "CUDA_PATH=%INSTALL_ENV_DIR%"
+set "CUDA_HOME=%CUDA_PATH%"
+
+call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%"
+
+if errorlevel 1 (
+ echo.
+ echo Failed to activate Env.
+ goto end
+) else (
+ echo.
+ echo successfully create env.
+)
+
+
+set "packages=torch torchvision torchaudio openai-whisper fish-speech"
+
+if "!INSTALL_TYPE!" == "preview" (
+ set "packages=!packages! triton_windows"
+)
+
+set "HF_ENDPOINT=https://huggingface.co"
+set "no_proxy="
+if "!USE_MIRROR!" == "true" (
+ set "HF_ENDPOINT=https://hf-mirror.com"
+ set "no_proxy=localhost, 127.0.0.1, 0.0.0.0"
+)
+echo "HF_ENDPOINT: !HF_ENDPOINT!"
+echo "NO_PROXY: !no_proxy!"
+
+set "install_packages="
+for %%p in (%packages%) do (
+ %PIP_CMD% show %%p >nul 2>&1
+ if errorlevel 1 (
+ set "install_packages=!install_packages! %%p"
+ )
+)
+
+if not "!install_packages!"=="" (
+ echo.
+ echo Installing: !install_packages!
+ for %%p in (!install_packages!) do (
+ if "!INSTALL_TYPE!"=="preview" (
+ if "%%p"=="torch" (
+ set "WHEEL_FILE=torch-2.4.0.dev20240427+cu121-cp310-cp310-win_amd64.whl"
+ set "URL=!HF_ENDPOINT!/datasets/SpicyqSama007/windows_compile/resolve/main/torch-2.4.0.dev20240427_cu121-cp310-cp310-win_amd64.whl?download=true"
+ set "CHKSUM=b091308f4cb74e63d0323afd67c92f2279d9e488d8cbf467bcc7b939bcd74e0b"
+ :TORCH_DOWNLOAD
+ echo "%CD%\!WHEEL_FILE!"
+ if not exist "%CD%\!WHEEL_FILE!" (
+ call curl -Lk "!URL!" --output "!WHEEL_FILE!"
+ )
+ for /f "delims=" %%I in ('certutil -hashfile "!WHEEL_FILE!" SHA256 ^| find /i "!CHKSUM!"') do (
+ set "FILE_VALID=true"
+ )
+ if not defined FILE_VALID (
+ echo File checksum does not match, re-downloading...
+ del "!WHEEL_FILE!"
+ goto TORCH_DOWNLOAD
+ )
+ echo "OK for !WHEEL_FILE!"
+ %PIP_CMD% install "%CD%\!WHEEL_FILE!" --no-warn-script-location
+ del "!WHEEL_FILE!"
+ ) else if "%%p"=="torchvision" (
+ set "WHEEL_FILE=torchvision-0.19.0.dev20240428+cu121-cp310-cp310-win_amd64.whl"
+ set "URL=!HF_ENDPOINT!/datasets/SpicyqSama007/windows_compile/resolve/main/torchvision-0.19.0.dev20240428_cu121-cp310-cp310-win_amd64.whl?download=true"
+ set "CHKSUM=7e46d0a89534013f001563d15e80f9eb431089571720c51f2cc595feeb01d785"
+ :TORCHVISION_DOWNLOAD
+ if not exist "!WHEEL_FILE!" (
+ call curl -Lk "!URL!" --output "!WHEEL_FILE!"
+ )
+ for /f "delims=" %%I in ('certutil -hashfile "!WHEEL_FILE!" SHA256 ^| find /i "!CHKSUM!"') do (
+ set "FILE_VALID=true"
+ )
+ if not defined FILE_VALID (
+ echo File checksum does not match, re-downloading...
+ del "!WHEEL_FILE!"
+ goto TORCHVISION_DOWNLOAD
+ )
+ echo "OK for !WHEEL_FILE!"
+ %PIP_CMD% install "%CD%\!WHEEL_FILE!" --no-warn-script-location
+ del "!WHEEL_FILE!"
+ ) else if "%%p"=="torchaudio" (
+ set "WHEEL_FILE=torchaudio-2.2.0.dev20240427+cu121-cp310-cp310-win_amd64.whl"
+ set "URL=!HF_ENDPOINT!/datasets/SpicyqSama007/windows_compile/resolve/main/torchaudio-2.2.0.dev20240427_cu121-cp310-cp310-win_amd64.whl?download=true"
+ set "CHKSUM=abafb4bc82cbc6f58f18e1b95191bc1884c28e404781082db2eb540b4fae8a5d"
+ :TORCHAUDIO_DOWNLOAD
+ if not exist "!WHEEL_FILE!" (
+ call curl -Lk "!URL!" --output "!WHEEL_FILE!"
+ )
+ for /f "delims=" %%I in ('certutil -hashfile "!WHEEL_FILE!" SHA256 ^| find /i "!CHKSUM!"') do (
+ set "FILE_VALID=true"
+ )
+ if not defined FILE_VALID (
+ echo File checksum does not match, re-downloading...
+ del "!WHEEL_FILE!"
+ goto TORCHAUDIO_DOWNLOAD
+ )
+ echo "OK for !WHEEL_FILE!"
+ %PIP_CMD% install "%CD%\!WHEEL_FILE!" --no-warn-script-location
+ del "!WHEEL_FILE!"
+ ) else if "%%p"=="openai-whisper" (
+ %PIP_CMD% install openai-whisper --no-warn-script-location
+ ) else if "%%p"=="fish-speech" (
+ %PIP_CMD% install -e .
+ ) else if "%%p"=="triton_windows" (
+ set "WHEEL_FILE=triton_windows-0.1.0-py3-none-any.whl"
+ set "URL=!HF_ENDPOINT!/datasets/SpicyqSama007/windows_compile/resolve/main/triton_windows-0.1.0-py3-none-any.whl?download=true"
+ set "CHKSUM=2cc998638180f37cf5025ab65e48c7f629aa5a369176cfa32177d2bd9aa26a0a"
+ :TRITON_DOWNLOAD
+ if not exist "!WHEEL_FILE!" (
+ call curl -Lk "!URL!" --output "!WHEEL_FILE!"
+ )
+ for /f "delims=" %%I in ('certutil -hashfile "!WHEEL_FILE!" SHA256 ^| find /i "!CHKSUM!"') do (
+ set "FILE_VALID=true"
+ )
+ if not defined FILE_VALID (
+ echo File checksum does not match, re-downloading...
+ del "!WHEEL_FILE!"
+ goto TRITON_DOWNLOAD
+ )
+ echo "OK for !WHEEL_FILE!"
+ %PIP_CMD% install "%CD%\!WHEEL_FILE!" --no-warn-script-location
+ del "!WHEEL_FILE!"
+ )
+
+ )
+ )
+)
+
+set "install_packages="
+for %%p in (%packages%) do (
+ %PIP_CMD% show %%p >nul 2>&1
+ if errorlevel 1 (
+ set "install_packages=!install_packages! %%p"
+ )
+)
+
+if not "!install_packages!"=="" (
+ echo.
+ echo Installing: !install_packages!
+
+ for %%p in (!install_packages!) do (
+ if "!USE_MIRROR!"=="true" (
+ if "%%p"=="torch" (
+ %PIP_CMD% install torch --index-url https://mirror.sjtu.edu.cn/pytorch-wheels/cu121 --no-warn-script-location
+ ) else if "%%p"=="torchvision" (
+ %PIP_CMD% install torchvision --index-url https://mirror.sjtu.edu.cn/pytorch-wheels/cu121 --no-warn-script-location
+ ) else if "%%p"=="torchaudio" (
+ %PIP_CMD% install torchaudio --index-url https://mirror.sjtu.edu.cn/pytorch-wheels/cu121 --no-warn-script-location
+ ) else if "%%p"=="openai-whisper" (
+ %PIP_CMD% install -i https://pypi.tuna.tsinghua.edu.cn/simple openai-whisper --no-warn-script-location
+ ) else if "%%p"=="fish-speech" (
+ %PIP_CMD% install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
+ )
+ )
+
+ if "!USE_MIRROR!"=="false" (
+ if "%%p"=="torch" (
+ %PIP_CMD% install torch --index-url https://download.pytorch.org/whl/cu121 --no-warn-script-location
+ ) else if "%%p"=="torchvision" (
+ %PIP_CMD% install torchvision --index-url https://download.pytorch.org/whl/cu121 --no-warn-script-location
+ ) else if "%%p"=="torchaudio" (
+ %PIP_CMD% install torchaudio --index-url https://download.pytorch.org/whl/cu121 --no-warn-script-location
+ ) else if "%%p"=="openai-whisper" (
+ %PIP_CMD% install openai-whisper --no-warn-script-location
+ ) else if "%%p"=="fish-speech" (
+ %PIP_CMD% install -e .
+ )
+ )
+
+ )
+)
+echo Environment Check: Success.
+
+endlocal
+:end
+pause
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a4553c4f9357b79ff3f0bca21f5ca79b6c955951
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,104 @@
+site_name: Fish Speech
+site_description: Targeting SOTA TTS solutions.
+site_url: https://speech.fish.audio
+
+# Repository
+repo_name: fishaudio/fish-speech
+repo_url: https://github.com/fishaudio/fish-speech
+edit_uri: blob/main/docs
+
+# Copyright
+copyright: Copyright © 2023-2024 by Fish Audio
+
+theme:
+ name: material
+ language: en
+ features:
+ - content.action.edit
+ - content.action.view
+ - navigation.tracking
+ - navigation.footer
+ # - navigation.tabs
+ - search
+ - search.suggest
+ - search.highlight
+ - search.share
+ - content.code.copy
+ icon:
+ logo: fontawesome/solid/fish
+
+ palette:
+ # Palette toggle for automatic mode
+ - media: "(prefers-color-scheme)"
+ toggle:
+ icon: material/brightness-auto
+ name: Switch to light mode
+
+ # Palette toggle for light mode
+ - media: "(prefers-color-scheme: light)"
+ scheme: default
+ toggle:
+ icon: material/brightness-7
+ name: Switch to dark mode
+ primary: black
+ font:
+ code: Roboto Mono
+
+ # Palette toggle for dark mode
+ - media: "(prefers-color-scheme: dark)"
+ scheme: slate
+ toggle:
+ icon: material/brightness-4
+ name: Switch to light mode
+ primary: black
+ font:
+ code: Roboto Mono
+
+# Plugins
+plugins:
+ - search:
+ separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])'
+ lang:
+ - zh
+ - en
+ - i18n:
+ docs_structure: folder
+ languages:
+ - locale: en
+ name: English
+ build: true
+ - locale: zh
+ default: true
+ name: 简体中文
+ build: true
+ - locale: ja
+ name: 日本語
+ build: true
+
+markdown_extensions:
+ - pymdownx.highlight:
+ anchor_linenums: true
+ line_spans: __span
+ pygments_lang_class: true
+ - pymdownx.inlinehilite
+ - pymdownx.snippets
+ - pymdownx.superfences
+ - admonition
+ - pymdownx.details
+ - pymdownx.superfences
+ - attr_list
+ - md_in_html
+ - pymdownx.superfences
+
+extra_css:
+ - stylesheets/extra.css
+
+extra:
+ social:
+ - icon: fontawesome/brands/discord
+ link: https://discord.gg/Es5qTB9BcN
+ - icon: fontawesome/brands/docker
+ link: https://hub.docker.com/r/lengyue233/fish-speech
+ - icon: fontawesome/brands/qq
+ link: http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=jCKlUP7QgSm9kh95UlBoYv6s1I-Apl1M&authKey=xI5ttVAp3do68IpEYEalwXSYZFdfxZSkah%2BctF5FIMyN2NqAa003vFtLqJyAVRfF&noverify=0&group_code=593946093
+ homepage: https://speech.fish.audio
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..28e6f8df29004dab13238d2f649ddb7e6bedd0aa
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,56 @@
+[project]
+name = "fish-speech"
+version = "0.1.0"
+authors = [
+ {name = "Lengyue", email = "lengyue@lengyue.me"},
+]
+description = "Fish Speech"
+readme = "README.md"
+requires-python = ">=3.10"
+keywords = ["TTS", "Speech"]
+license = {text = "BSD-3-Clause"}
+classifiers = [
+ "Programming Language :: Python :: 3",
+]
+dependencies = [
+ "numpy<=1.26.4",
+ "transformers>=4.35.2",
+ "datasets==2.18.0",
+ "lightning>=2.1.0",
+ "hydra-core>=1.3.2",
+ "tensorboard>=2.14.1",
+ "natsort>=8.4.0",
+ "einops>=0.7.0",
+ "librosa>=0.10.1",
+ "rich>=13.5.3",
+ "gradio>=4.0.0",
+ "wandb>=0.15.11",
+ "grpcio>=1.58.0",
+ "kui>=1.6.0",
+ "uvicorn>=0.30.0",
+ "loguru>=0.6.0",
+ "loralib>=0.1.2",
+ "natsort>=8.4.0",
+ "pyrootutils>=1.0.4",
+ "vector_quantize_pytorch>=1.14.24",
+ "resampy>=0.4.3",
+ "einx[torch]==0.2.2",
+ "zstandard>=0.22.0",
+ "pydub",
+ "faster_whisper",
+ "modelscope==1.16.1",
+ "funasr==1.1.2"
+]
+
+[project.optional-dependencies]
+asr = [
+ "openai-whisper",
+ "modelscope"
+]
+
+[build-system]
+requires = ["setuptools", "setuptools-scm"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["fish_speech"]
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 0000000000000000000000000000000000000000..ad1493530f7f6d8fa476dbe0b76e6239fce2d7e7
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,6 @@
+{
+ "exclude": [
+ "data",
+ "filelists"
+ ]
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5f6903eeff730ab9ed2a2fc4949c3a3601113586
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,27 @@
+numpy<=1.26.4
+transformers>=4.35.2
+datasets==2.18.0
+lightning>=2.1.0
+hydra-core>=1.3.2
+tensorboard>=2.14.1
+natsort>=8.4.0
+einops>=0.7.0
+librosa>=0.10.1
+rich>=13.5.3
+gradio
+wandb>=0.15.11
+grpcio>=1.58.0
+kui>=1.6.0
+uvicorn>=0.30.0
+loguru>=0.6.0
+loralib>=0.1.2
+natsort>=8.4.0
+pyrootutils>=1.0.4
+vector_quantize_pytorch>=1.14.24
+resampy>=0.4.3
+einx[torch]==0.2.2
+zstandard>=0.22.0
+pydub
+faster_whisper
+modelscope==1.16.1
+funasr==1.1.2
diff --git a/run_cmd.bat b/run_cmd.bat
new file mode 100644
index 0000000000000000000000000000000000000000..05fda82d1bbe8c4f898c7d989018bf7bc71e7f3e
--- /dev/null
+++ b/run_cmd.bat
@@ -0,0 +1,50 @@
+@echo off
+chcp 65001
+
+set no_proxy="127.0.0.1, 0.0.0.0, localhost"
+setlocal enabledelayedexpansion
+
+cd /D "%~dp0"
+
+set PATH="%PATH%";%SystemRoot%\system32
+
+
+echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && (
+ echo.
+ echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && (
+ goto end
+ )
+)
+
+
+set TMP=%CD%\fishenv
+set TEMP=%CD%\fishenv
+
+
+(call conda deactivate && call conda deactivate && call conda deactivate) 2>nul
+
+
+set CONDA_ROOT_PREFIX=%cd%\fishenv\conda
+set INSTALL_ENV_DIR=%cd%\fishenv\env
+
+
+set PYTHONNOUSERSITE=1
+set PYTHONPATH=
+set PYTHONHOME=
+
+
+call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%"
+
+if errorlevel 1 (
+ echo.
+ echo Environment activation failed.
+ goto end
+) else (
+ echo.
+ echo Environment activation succeeded.
+)
+
+cmd /k "%*"
+
+:end
+pause
diff --git a/start.bat b/start.bat
new file mode 100644
index 0000000000000000000000000000000000000000..f3b58a6a1af914b1aaf5e51409110d488aef46e9
--- /dev/null
+++ b/start.bat
@@ -0,0 +1,85 @@
+@echo off
+chcp 65001
+
+set USE_MIRROR=true
+set PYTHONPATH=%~dp0
+set PYTHON_CMD=%cd%\fishenv\env\python
+set API_FLAG_PATH=%~dp0API_FLAGS.txt
+set KMP_DUPLICATE_LIB_OK=TRUE
+
+setlocal enabledelayedexpansion
+
+set "HF_ENDPOINT=https://huggingface.co"
+set "no_proxy="
+if "%USE_MIRROR%" == "true" (
+ set "HF_ENDPOINT=https://hf-mirror.com"
+ set "no_proxy=localhost, 127.0.0.1, 0.0.0.0"
+)
+echo "HF_ENDPOINT: !HF_ENDPOINT!"
+echo "NO_PROXY: !no_proxy!"
+%PYTHON_CMD% .\tools\download_models.py
+
+set "API_FLAGS="
+set "flags="
+
+if exist "%API_FLAG_PATH%" (
+ for /f "usebackq tokens=*" %%a in ("%API_FLAG_PATH%") do (
+ set "line=%%a"
+ if not "!line:~0,1!"=="#" (
+ set "line=!line: =!"
+ set "line=!line:\=!"
+ set "line=!line:= !"
+ if not "!line!"=="" (
+ set "API_FLAGS=!API_FLAGS!!line! "
+ )
+ )
+ )
+)
+
+
+if not "!API_FLAGS!"=="" set "API_FLAGS=!API_FLAGS:~0,-1!"
+
+set "flags="
+
+echo !API_FLAGS! | findstr /C:"--api" >nul 2>&1
+if !errorlevel! equ 0 (
+ echo.
+ echo Start HTTP API...
+ set "mode=api"
+ goto process_flags
+)
+
+echo !API_FLAGS! | findstr /C:"--infer" >nul 2>&1
+if !errorlevel! equ 0 (
+ echo.
+ echo Start WebUI Inference...
+ set "mode=infer"
+ goto process_flags
+)
+
+
+:process_flags
+for %%p in (!API_FLAGS!) do (
+ if not "%%p"=="--!mode!" (
+ set "flags=!flags! %%p"
+ )
+)
+
+if not "!flags!"=="" set "flags=!flags:~1!"
+
+echo Debug: flags = !flags!
+
+if "!mode!"=="api" (
+ %PYTHON_CMD% -m tools.api !flags!
+) else if "!mode!"=="infer" (
+ %PYTHON_CMD% -m tools.webui !flags!
+)
+
+echo.
+echo Next launch the page...
+%PYTHON_CMD% fish_speech\webui\manage.py
+
+
+:end
+endlocal
+pause
diff --git a/tools/api.py b/tools/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a6379204bca380c3ef84edeec264c6e505a05a8
--- /dev/null
+++ b/tools/api.py
@@ -0,0 +1,482 @@
+import base64
+import io
+import json
+import queue
+import random
+import traceback
+import wave
+from argparse import ArgumentParser
+from http import HTTPStatus
+from pathlib import Path
+from typing import Annotated, Literal, Optional
+
+import librosa
+import numpy as np
+import pyrootutils
+import soundfile as sf
+import torch
+from kui.asgi import (
+ Body,
+ HTTPException,
+ HttpView,
+ JSONResponse,
+ Kui,
+ OpenAPI,
+ StreamResponse,
+)
+from kui.asgi.routing import MultimethodRoutes
+from loguru import logger
+from pydantic import BaseModel, Field
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+# from fish_speech.models.vqgan.lit_module import VQGAN
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
+from tools.llama.generate import (
+ GenerateRequest,
+ GenerateResponse,
+ WrappedGenerateResponse,
+ launch_thread_safe_queue,
+)
+from tools.vqgan.inference import load_model as load_decoder_model
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+ buffer = io.BytesIO()
+
+ with wave.open(buffer, "wb") as wav_file:
+ wav_file.setnchannels(channels)
+ wav_file.setsampwidth(bit_depth // 8)
+ wav_file.setframerate(sample_rate)
+
+ wav_header_bytes = buffer.getvalue()
+ buffer.close()
+ return wav_header_bytes
+
+
+# Define utils for web server
+async def http_execption_handler(exc: HTTPException):
+ return JSONResponse(
+ dict(
+ statusCode=exc.status_code,
+ message=exc.content,
+ error=HTTPStatus(exc.status_code).phrase,
+ ),
+ exc.status_code,
+ exc.headers,
+ )
+
+
+async def other_exception_handler(exc: "Exception"):
+ traceback.print_exc()
+
+ status = HTTPStatus.INTERNAL_SERVER_ERROR
+ return JSONResponse(
+ dict(statusCode=status, message=str(exc), error=status.phrase),
+ status,
+ )
+
+
+def load_audio(reference_audio, sr):
+ if len(reference_audio) > 255 or not Path(reference_audio).exists():
+ try:
+ audio_data = base64.b64decode(reference_audio)
+ reference_audio = io.BytesIO(audio_data)
+ except base64.binascii.Error:
+ raise ValueError("Invalid path or base64 string")
+
+ audio, _ = librosa.load(reference_audio, sr=sr, mono=True)
+ return audio
+
+
+def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
+ if enable_reference_audio and reference_audio is not None:
+ # Load audios, and prepare basic info here
+ reference_audio_content = load_audio(
+ reference_audio, decoder_model.spec_transform.sample_rate
+ )
+
+ audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
+ None, None, :
+ ]
+ audio_lengths = torch.tensor(
+ [audios.shape[2]], device=decoder_model.device, dtype=torch.long
+ )
+ logger.info(
+ f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
+ )
+
+ # VQ Encoder
+ if isinstance(decoder_model, FireflyArchitecture):
+ prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
+
+ logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+ else:
+ prompt_tokens = None
+ logger.info("No reference audio provided")
+
+ return prompt_tokens
+
+
+def decode_vq_tokens(
+ *,
+ decoder_model,
+ codes,
+):
+ feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
+ logger.info(f"VQ features: {codes.shape}")
+
+ if isinstance(decoder_model, FireflyArchitecture):
+ # VQGAN Inference
+ return decoder_model.decode(
+ indices=codes[None],
+ feature_lengths=feature_lengths,
+ ).squeeze()
+
+ raise ValueError(f"Unknown model type: {type(decoder_model)}")
+
+
+routes = MultimethodRoutes(base_class=HttpView)
+
+
+def get_random_paths(base_path, data, speaker, emotion):
+ if base_path and data and speaker and emotion and (Path(base_path).exists()):
+ if speaker in data and emotion in data[speaker]:
+ files = data[speaker][emotion]
+ lab_files = [f for f in files if f.endswith(".lab")]
+ wav_files = [f for f in files if f.endswith(".wav")]
+
+ if lab_files and wav_files:
+ selected_lab = random.choice(lab_files)
+ selected_wav = random.choice(wav_files)
+
+ lab_path = Path(base_path) / speaker / emotion / selected_lab
+ wav_path = Path(base_path) / speaker / emotion / selected_wav
+ if lab_path.exists() and wav_path.exists():
+ return lab_path, wav_path
+
+ return None, None
+
+
+def load_json(json_file):
+ if not json_file:
+ logger.info("Not using a json file")
+ return None
+ try:
+ with open(json_file, "r", encoding="utf-8") as file:
+ data = json.load(file)
+ except FileNotFoundError:
+ logger.warning(f"ref json not found: {json_file}")
+ data = None
+ except Exception as e:
+ logger.warning(f"Loading json failed: {e}")
+ data = None
+ return data
+
+
+class InvokeRequest(BaseModel):
+ text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
+ reference_text: Optional[str] = None
+ reference_audio: Optional[str] = None
+ max_new_tokens: int = 1024
+ chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
+ top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+ temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+ emotion: Optional[str] = None
+ format: Literal["wav", "mp3", "flac"] = "wav"
+ streaming: bool = False
+ ref_json: Optional[str] = "ref_data.json"
+ ref_base: Optional[str] = "ref_data"
+ speaker: Optional[str] = None
+
+
+def get_content_type(audio_format):
+ if audio_format == "wav":
+ return "audio/wav"
+ elif audio_format == "flac":
+ return "audio/flac"
+ elif audio_format == "mp3":
+ return "audio/mpeg"
+ else:
+ return "application/octet-stream"
+
+
+@torch.inference_mode()
+def inference(req: InvokeRequest):
+ # Parse reference audio aka prompt
+ prompt_tokens = None
+
+ ref_data = load_json(req.ref_json)
+ ref_base = req.ref_base
+
+ lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
+
+ if lab_path and wav_path:
+ with open(lab_path, "r", encoding="utf-8") as lab_file:
+ ref_text = lab_file.read()
+ req.reference_audio = wav_path
+ req.reference_text = ref_text
+ logger.info("ref_path: " + str(wav_path))
+ logger.info("ref_text: " + ref_text)
+
+ # Parse reference audio aka prompt
+ prompt_tokens = encode_reference(
+ decoder_model=decoder_model,
+ reference_audio=req.reference_audio,
+ enable_reference_audio=req.reference_audio is not None,
+ )
+ logger.info(f"ref_text: {req.reference_text}")
+ # LLAMA Inference
+ request = dict(
+ device=decoder_model.device,
+ max_new_tokens=req.max_new_tokens,
+ text=req.text,
+ top_p=req.top_p,
+ repetition_penalty=req.repetition_penalty,
+ temperature=req.temperature,
+ compile=args.compile,
+ iterative_prompt=req.chunk_length > 0,
+ chunk_length=req.chunk_length,
+ max_length=2048,
+ prompt_tokens=prompt_tokens,
+ prompt_text=req.reference_text,
+ )
+
+ response_queue = queue.Queue()
+ llama_queue.put(
+ GenerateRequest(
+ request=request,
+ response_queue=response_queue,
+ )
+ )
+
+ if req.streaming:
+ yield wav_chunk_header()
+
+ segments = []
+ while True:
+ result: WrappedGenerateResponse = response_queue.get()
+ if result.status == "error":
+ raise result.response
+ break
+
+ result: GenerateResponse = result.response
+ if result.action == "next":
+ break
+
+ with torch.autocast(
+ device_type=decoder_model.device.type, dtype=args.precision
+ ):
+ fake_audios = decode_vq_tokens(
+ decoder_model=decoder_model,
+ codes=result.codes,
+ )
+
+ fake_audios = fake_audios.float().cpu().numpy()
+
+ if req.streaming:
+ yield (fake_audios * 32768).astype(np.int16).tobytes()
+ else:
+ segments.append(fake_audios)
+
+ if req.streaming:
+ return
+
+ if len(segments) == 0:
+ raise HTTPException(
+ HTTPStatus.INTERNAL_SERVER_ERROR,
+ content="No audio generated, please check the input text.",
+ )
+
+ fake_audios = np.concatenate(segments, axis=0)
+ yield fake_audios
+
+
+def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
+ if not use_auto_rerank:
+ # 如果不使用 auto_rerank,直接调用原始的 inference 函数
+ return inference(req)
+
+ zh_model, en_model = load_model()
+ max_attempts = 5
+ best_wer = float("inf")
+ best_audio = None
+
+ for attempt in range(max_attempts):
+ # 调用原始的 inference 函数
+ audio_generator = inference(req)
+ fake_audios = next(audio_generator)
+
+ asr_result = batch_asr(
+ zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
+ )[0]
+ wer = calculate_wer(req.text, asr_result["text"])
+
+ if wer <= 0.1 and not asr_result["huge_gap"]:
+ return fake_audios
+
+ if wer < best_wer:
+ best_wer = wer
+ best_audio = fake_audios
+
+ if attempt == max_attempts - 1:
+ break
+
+ return best_audio
+
+
+async def inference_async(req: InvokeRequest):
+ for chunk in inference(req):
+ yield chunk
+
+
+async def buffer_to_async_generator(buffer):
+ yield buffer
+
+
+@routes.http.post("/v1/invoke")
+async def api_invoke_model(
+ req: Annotated[InvokeRequest, Body(exclusive=True)],
+):
+ """
+ Invoke model and generate audio
+ """
+
+ if args.max_text_length > 0 and len(req.text) > args.max_text_length:
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST,
+ content=f"Text is too long, max length is {args.max_text_length}",
+ )
+
+ if req.streaming and req.format != "wav":
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST,
+ content="Streaming only supports WAV format",
+ )
+
+ if req.streaming:
+ return StreamResponse(
+ iterable=inference_async(req),
+ headers={
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
+ },
+ content_type=get_content_type(req.format),
+ )
+ else:
+ fake_audios = next(inference(req))
+ buffer = io.BytesIO()
+ sf.write(
+ buffer,
+ fake_audios,
+ decoder_model.spec_transform.sample_rate,
+ format=req.format,
+ )
+
+ return StreamResponse(
+ iterable=buffer_to_async_generator(buffer.getvalue()),
+ headers={
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
+ },
+ content_type=get_content_type(req.format),
+ )
+
+
+@routes.http.post("/v1/health")
+async def api_health():
+ """
+ Health check
+ """
+
+ return JSONResponse({"status": "ok"})
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=str,
+ default="checkpoints/fish-speech-1.2-sft",
+ )
+ parser.add_argument(
+ "--decoder-checkpoint-path",
+ type=str,
+ default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ )
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-text-length", type=int, default=0)
+ parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
+ parser.add_argument("--workers", type=int, default=1)
+ parser.add_argument("--use-auto-rerank", type=bool, default=True)
+
+ return parser.parse_args()
+
+
+# Define Kui app
+openapi = OpenAPI(
+ {
+ "title": "Fish Speech API",
+ },
+).routes
+
+app = Kui(
+ routes=routes + openapi[1:], # Remove the default route
+ exception_handlers={
+ HTTPException: http_execption_handler,
+ Exception: other_exception_handler,
+ },
+ cors_config={},
+)
+
+
+if __name__ == "__main__":
+ import threading
+
+ import uvicorn
+
+ args = parse_args()
+ args.precision = torch.half if args.half else torch.bfloat16
+
+ logger.info("Loading Llama model...")
+ llama_queue = launch_thread_safe_queue(
+ checkpoint_path=args.llama_checkpoint_path,
+ device=args.device,
+ precision=args.precision,
+ compile=args.compile,
+ )
+ logger.info("Llama model loaded, loading VQ-GAN model...")
+
+ decoder_model = load_decoder_model(
+ config_name=args.decoder_config_name,
+ checkpoint_path=args.decoder_checkpoint_path,
+ device=args.device,
+ )
+
+ logger.info("VQ-GAN model loaded, warming up...")
+
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
+ list(
+ inference(
+ InvokeRequest(
+ text="Hello world.",
+ reference_text=None,
+ reference_audio=None,
+ max_new_tokens=0,
+ top_p=0.7,
+ repetition_penalty=1.2,
+ temperature=0.7,
+ emotion=None,
+ format="wav",
+ ref_base=None,
+ ref_json=None,
+ )
+ )
+ )
+
+ logger.info(f"Warming up done, starting server at http://{args.listen}")
+ host, port = args.listen.split(":")
+ uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
diff --git a/tools/auto_rerank.py b/tools/auto_rerank.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e478ff108faeb39587a1c43abbea6d59f5378bd
--- /dev/null
+++ b/tools/auto_rerank.py
@@ -0,0 +1,126 @@
+import time
+from threading import Lock
+
+import numpy as np
+import torch
+import torchaudio
+from funasr import AutoModel
+from funasr.models.seaco_paraformer.model import SeacoParaformer
+
+# Monkey patching to disable hotwords
+SeacoParaformer.generate_hotwords_list = lambda self, *args, **kwargs: None
+
+
+def load_model(*, device="cuda"):
+ zh_model = AutoModel(
+ model="paraformer-zh",
+ device=device,
+ disable_pbar=True,
+ )
+ en_model = AutoModel(
+ model="paraformer-en",
+ device=device,
+ disable_pbar=True,
+ )
+
+ return zh_model, en_model
+
+
+@torch.no_grad()
+def batch_asr_internal(model, audios, sr):
+ resampled_audios = []
+ for audio in audios:
+ # 将 NumPy 数组转换为 PyTorch 张量
+ if isinstance(audio, np.ndarray):
+ audio = torch.from_numpy(audio).float()
+
+ # 确保音频是一维的
+ if audio.dim() > 1:
+ audio = audio.squeeze()
+
+ audio = torchaudio.functional.resample(audio, sr, 16000)
+ assert audio.dim() == 1
+ resampled_audios.append(audio)
+
+ res = model.generate(input=resampled_audios, batch_size=len(resampled_audios))
+
+ results = []
+ for r, audio in zip(res, audios):
+ text = r["text"]
+ duration = len(audio) / sr * 1000
+ huge_gap = False
+
+ if "timestamp" in r and len(r["timestamp"]) > 2:
+ for timestamp_a, timestamp_b in zip(
+ r["timestamp"][:-1], r["timestamp"][1:]
+ ):
+ # If there is a gap of more than 5 seconds, we consider it as a huge gap
+ if timestamp_b[0] - timestamp_a[1] > 5000:
+ huge_gap = True
+ break
+
+ # Doesn't make sense to have a huge gap at the end
+ if duration - r["timestamp"][-1][1] > 3000:
+ huge_gap = True
+
+ results.append(
+ {
+ "text": text,
+ "duration": duration,
+ "huge_gap": huge_gap,
+ }
+ )
+
+ return results
+
+
+global_lock = Lock()
+
+
+def batch_asr(model, audios, sr):
+ return batch_asr_internal(model, audios, sr)
+
+
+def is_chinese(text):
+ return True
+
+
+def calculate_wer(text1, text2):
+ words1 = text1.split()
+ words2 = text2.split()
+
+ # 计算编辑距离
+ m, n = len(words1), len(words2)
+ dp = [[0] * (n + 1) for _ in range(m + 1)]
+
+ for i in range(m + 1):
+ dp[i][0] = i
+ for j in range(n + 1):
+ dp[0][j] = j
+
+ for i in range(1, m + 1):
+ for j in range(1, n + 1):
+ if words1[i - 1] == words2[j - 1]:
+ dp[i][j] = dp[i - 1][j - 1]
+ else:
+ dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
+
+ # 计算WER
+ edits = dp[m][n]
+ wer = edits / len(words1)
+
+ return wer
+
+
+if __name__ == "__main__":
+ zh_model, en_model = load_model()
+ audios = [
+ torchaudio.load("lengyue.wav")[0][0],
+ torchaudio.load("lengyue.wav")[0][0, : 44100 * 5],
+ ]
+ print(batch_asr(zh_model, audios, 44100))
+
+ start_time = time.time()
+ for _ in range(10):
+ batch_asr(zh_model, audios, 44100)
+ print("Time taken:", time.time() - start_time)
diff --git a/tools/download_models.py b/tools/download_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ae4a4df47e5b72c449c6e7c798c8f40afdf5867
--- /dev/null
+++ b/tools/download_models.py
@@ -0,0 +1,63 @@
+import os
+
+from huggingface_hub import hf_hub_download
+
+
+# Download
+def check_and_download_files(repo_id, file_list, local_dir):
+ os.makedirs(local_dir, exist_ok=True)
+ for file in file_list:
+ file_path = os.path.join(local_dir, file)
+ if not os.path.exists(file_path):
+ print(f"{file} 不存在,从 Hugging Face 仓库下载...")
+ hf_hub_download(
+ repo_id=repo_id,
+ filename=file,
+ resume_download=True,
+ local_dir=local_dir,
+ local_dir_use_symlinks=False,
+ )
+ else:
+ print(f"{file} 已存在,跳过下载。")
+
+
+# 1st
+repo_id_1 = "fishaudio/fish-speech-1.2-sft"
+local_dir_1 = "./checkpoints/fish-speech-1.2-sft"
+files_1 = [
+ "model.pth",
+ "README.md",
+ "special_tokens_map.json",
+ "tokenizer_config.json",
+ "tokenizer.json",
+ "config.json",
+ "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+]
+
+# 2nd
+repo_id_2 = "SpicyqSama007/fish-speech-packed"
+local_dir_2 = ".cache/whisper"
+files_2 = [
+ "medium.pt",
+ "small.pt",
+]
+
+# 3rd
+repo_id_3 = "fishaudio/fish-speech-1"
+local_dir_3 = "./"
+files_3 = [
+ "ffmpeg.exe",
+ "ffprobe.exe",
+]
+
+# 4th
+repo_id_4 = "SpicyqSama007/fish-speech-packed"
+local_dir_4 = "./"
+files_4 = [
+ "asr-label-win-x64.exe",
+]
+
+check_and_download_files(repo_id_1, files_1, local_dir_1)
+check_and_download_files(repo_id_2, files_2, local_dir_2)
+check_and_download_files(repo_id_3, files_3, local_dir_3)
+check_and_download_files(repo_id_4, files_4, local_dir_4)
diff --git a/tools/extract_model.py b/tools/extract_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..97fe62507b7282890319d8dc1eaa3cbca0e1f60a
--- /dev/null
+++ b/tools/extract_model.py
@@ -0,0 +1,21 @@
+import click
+import torch
+from loguru import logger
+
+
+@click.command()
+@click.argument("model_path")
+@click.argument("output_path")
+def main(model_path, output_path):
+ if model_path == output_path:
+ logger.error("Model path and output path are the same")
+ return
+
+ logger.info(f"Loading model from {model_path}")
+ state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
+ torch.save(state_dict, output_path)
+ logger.info(f"Model saved to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/gen_ref.py b/tools/gen_ref.py
new file mode 100644
index 0000000000000000000000000000000000000000..a771903b02c4ae8ce22e08e3db56bb4d0c8b3b9c
--- /dev/null
+++ b/tools/gen_ref.py
@@ -0,0 +1,36 @@
+import json
+from pathlib import Path
+
+
+def scan_folder(base_path):
+ wav_lab_pairs = {}
+
+ base = Path(base_path)
+ for suf in ["wav", "lab"]:
+ for f in base.rglob(f"*.{suf}"):
+ relative_path = f.relative_to(base)
+ parts = relative_path.parts
+ print(parts)
+ if len(parts) >= 3:
+ character = parts[0]
+ emotion = parts[1]
+
+ if character not in wav_lab_pairs:
+ wav_lab_pairs[character] = {}
+ if emotion not in wav_lab_pairs[character]:
+ wav_lab_pairs[character][emotion] = []
+ wav_lab_pairs[character][emotion].append(str(f.name))
+
+ return wav_lab_pairs
+
+
+def save_to_json(data, output_file):
+ with open(output_file, "w", encoding="utf-8") as file:
+ json.dump(data, file, ensure_ascii=False, indent=2)
+
+
+base_path = "ref_data"
+out_ref_file = "ref_data.json"
+
+wav_lab_pairs = scan_folder(base_path)
+save_to_json(wav_lab_pairs, out_ref_file)
diff --git a/tools/llama/build_dataset.py b/tools/llama/build_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e2219956adc419aba91cde5d9097fad4288315
--- /dev/null
+++ b/tools/llama/build_dataset.py
@@ -0,0 +1,169 @@
+import itertools
+import os
+import re
+from collections import defaultdict
+from functools import partial
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
+from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
+from fish_speech.utils.file import load_filelist
+
+# To avoid CPU overload
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+
+
+def task_generator_folder(root: Path, text_extension: str):
+ files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
+ files = sorted(files)
+
+ grouped_files = defaultdict(list)
+ for file in tqdm(files, desc=f"Grouping {root}"):
+ p = str(file.parent)
+ speaker = file.parent.name
+
+ try:
+ if isinstance(text_extension, str):
+ texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
+ else:
+ texts = [
+ file.with_suffix(ext).read_text(encoding="utf-8")
+ for ext in text_extension
+ ]
+ except Exception as e:
+ logger.error(f"Failed to read text {file}: {e}")
+ continue
+
+ grouped_files[p].append((speaker, file, texts))
+
+ logger.info(
+ f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
+ )
+
+ for i in grouped_files.values():
+ subset = [(f, t) for _, f, t in i]
+ yield i[0][0], subset, "folder"
+
+
+def task_generator_filelist(filelist):
+ grouped_files = defaultdict(list)
+ for filename, speaker, _, text in load_filelist(filelist):
+ grouped_files[speaker].append((Path(filename), [text]))
+
+ logger.info(f"Found {len(grouped_files)} groups in {filelist}")
+ for speaker, values in grouped_files.items():
+ yield speaker, values, "filelist"
+
+
+def run_task(task):
+ name, subset, source = task
+
+ # Parse the files
+ sentences = []
+ for file, texts in subset:
+ np_file = file.with_suffix(".npy")
+ if np_file.exists() is False:
+ logger.warning(f"Can't find {np_file}")
+ continue
+
+ new_texts = []
+
+ for text in texts:
+ # Simple cleaning: replace { xxx } and < xxx > with space
+ text = re.sub(r"\{.*?\}", " ", text)
+ text = re.sub(r"<.*?>", " ", text)
+ text = re.sub(r"\s+", " ", text)
+ new_texts.append(text)
+
+ try:
+ semantics = np.load(np_file)
+ except Exception as e:
+ logger.error(f"Failed to parse {file}: {e}")
+ continue
+
+ if isinstance(semantics, np.ndarray):
+ semantics = semantics.tolist()
+
+ sentences.append(
+ Sentence(
+ texts=new_texts,
+ semantics=[Semantics(values=s) for s in semantics],
+ )
+ )
+
+ # Pack the sentences
+ return pack_pb_stream(
+ TextData(
+ source=source,
+ name=name,
+ sentences=sentences,
+ )
+ )
+
+
+@click.command()
+@click.option(
+ "--input",
+ type=click.Path(path_type=Path),
+ required=True,
+ help="A folder containing the dataset or a filelist",
+ multiple=True,
+)
+@click.option(
+ "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
+)
+@click.option("--num-workers", type=int, default=16)
+@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
+@click.option(
+ "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
+)
+def main(input, output, num_workers, text_extension, shard_size):
+ generator_fns = []
+
+ for f in input:
+ assert f.exists(), f"{f} not found"
+
+ if f.is_dir():
+ generator_fn = task_generator_folder(f, text_extension)
+ else:
+ generator_fn = task_generator_filelist(f)
+
+ generator_fns.append(generator_fn)
+
+ generator_fn = itertools.chain(*generator_fns)
+ output.mkdir(parents=True, exist_ok=True)
+
+ dataset_fp = None
+ tar_idx = 0
+ written_size = 0
+
+ with Pool(num_workers) as p:
+ for result in tqdm(p.imap_unordered(run_task, generator_fn)):
+ if dataset_fp is None:
+ dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
+
+ dataset_fp.write(result)
+ written_size += len(result)
+
+ if written_size > shard_size * 1024 * 1024:
+ logger.info(f"Finished writing {tar_idx} shards to {output}")
+ dataset_fp.close()
+ dataset_fp = None
+ written_size = 0
+ tar_idx += 1
+
+ if dataset_fp is not None:
+ dataset_fp.close()
+
+ logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/llama/eval_in_context.py b/tools/llama/eval_in_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..30d70940487388185381246d8210a49a58e55743
--- /dev/null
+++ b/tools/llama/eval_in_context.py
@@ -0,0 +1,171 @@
+import pyrootutils
+import torch
+import torch.nn.functional as F
+from matplotlib import pyplot as plt
+from transformers import AutoTokenizer
+
+# register eval resolver and root
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from torch.utils.data import DataLoader
+
+from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
+from tools.llama.generate import load_model
+
+
+def smooth(
+ scalars: list[float], weight: float
+) -> list[float]: # Weight between 0 and 1
+ last = scalars[0] # First value in the plot (first timestep)
+ smoothed = list()
+ for point in scalars:
+ smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value
+ smoothed.append(smoothed_val) # Save it
+ last = smoothed_val # Anchor the last smoothed value
+
+ return smoothed
+
+
+@torch.inference_mode()
+def analyze_one_model(loader, config, weight, max_length):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = load_model(
+ config,
+ weight,
+ device,
+ torch.bfloat16,
+ max_length,
+ compile=False,
+ )[0]
+
+ current_step = 0
+ model.eval()
+
+ semantic_loss_sum = torch.zeros(
+ max_length,
+ dtype=torch.float32,
+ device=device,
+ )
+ counter = torch.zeros(
+ max_length,
+ dtype=torch.long,
+ device=device,
+ )
+
+ for batch in loader:
+ batch = {k: v.to(device) for k, v in batch.items()}
+
+ labels = batch["labels"]
+ outputs = model(
+ inp=batch["inputs"],
+ key_padding_mask=batch["attention_masks"],
+ )
+
+ token_logits = outputs.token_logits
+ codebook_logits = outputs.codebook_logits
+
+ # Generate labels
+ base_loss = F.cross_entropy(
+ token_logits.reshape(-1, token_logits.size(-1)),
+ labels[:, 0].reshape(-1),
+ ignore_index=-100,
+ reduction="none",
+ )
+
+ codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
+ semantic_loss = F.cross_entropy(
+ codebook_logits.reshape(-1, codebook_logits.size(-1)),
+ codebook_labels.reshape(-1),
+ ignore_index=-100,
+ reduction="none",
+ )
+
+ base_loss = base_loss.reshape(labels[:, 0].shape)
+ semantic_loss = semantic_loss.reshape(codebook_labels.shape)
+
+ semantic_loss_frame = semantic_loss.mean(-1)
+ pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
+
+ for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
+ semantic_loss_sum[~pad] += loss_sample[~pad]
+ counter[~pad] += 1
+
+ current_step += 1
+ if current_step == 10:
+ break
+
+ semantic_loss = semantic_loss.cpu()
+ counter = counter.cpu()
+ xs, ys = [], []
+
+ for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
+ if count > 0:
+ xs.append(i)
+ ys.append((loss / count).item()) # for better loss visualization
+
+ smoothed_ys = smooth(ys, 0.95)
+
+ # Unload model
+ del model
+ torch.cuda.empty_cache()
+
+ return xs, ys, smoothed_ys
+
+
+def main():
+ tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
+ max_length = 4096
+
+ ds = AutoAugTextDataset(
+ ["data/protos/sft/云天河"],
+ tokenizer=tokenizer,
+ use_speaker=False,
+ interactive_prob=1.0,
+ max_length=max_length,
+ )
+
+ loader = DataLoader(
+ ds,
+ batch_size=8,
+ collate_fn=TextDataCollator(tokenizer, max_length=max_length),
+ num_workers=0,
+ shuffle=False,
+ )
+
+ plt.figure(figsize=(10, 5), dpi=200)
+
+ plt.xlabel("Frame")
+ plt.ylabel("Loss")
+ plt.yscale("log")
+ plt.title("Semantic Loss")
+ plt.grid(which="both", axis="both")
+ plt.xlim(0, max_length)
+
+ tests = [
+ (
+ "pertrain-medium",
+ "dual_ar_2_codebook_medium",
+ "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
+ ),
+ (
+ "sft-medium",
+ "dual_ar_2_codebook_medium",
+ "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
+ ),
+ (
+ "sft-large",
+ "dual_ar_2_codebook_large",
+ "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
+ ),
+ ]
+
+ for name, config, weight in tests:
+ xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
+ plt.plot(xs, smoothed_ys, label=name)
+
+ plt.legend()
+ plt.savefig("semantic_loss.png")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/llama/generate.py b/tools/llama/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaf195efacc59bf3c73e2e0dd1d687d6a7a5c963
--- /dev/null
+++ b/tools/llama/generate.py
@@ -0,0 +1,695 @@
+import os
+import queue
+import threading
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Literal, Optional, Tuple, Union
+
+import click
+import hydra
+import numpy as np
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.text import clean_text, split_text
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+ # Experimental feature to reduce compilation times, will be on by default in future
+ torch._inductor.config.fx_graph_cache = True
+
+
+from fish_speech.models.text2semantic.llama import (
+ BaseTransformer,
+ DualARTransformer,
+ NaiveTransformer,
+)
+
+
+def multinomial_sample_one_no_sync(
+ probs_sort,
+): # Does multinomial sampling without a cuda synchronization
+ q = torch.empty_like(probs_sort).exponential_(1)
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ temperature: torch.Tensor = 1.0,
+ top_p: torch.Tensor = 1.0,
+ repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+ # Apply repetition penalty
+ if previous_tokens is not None:
+ previous_tokens = previous_tokens.long()
+ score = torch.gather(logits, dim=0, index=previous_tokens)
+ score = torch.where(
+ score < 0, score * repetition_penalty, score / repetition_penalty
+ )
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
+
+ # Apply top-p sampling
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cum_probs > top_p
+ sorted_indices_to_remove[0] = False # keep at least one option
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
+ )
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+ logits = logits / max(temperature, 1e-5)
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ return probs
+
+
+def sample(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ probs = logits_to_probs(
+ logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
+ )
+ idx_next = multinomial_sample_one_no_sync(probs)
+ return idx_next, probs
+
+
+def decode_one_token_ar(
+ model: DualARTransformer,
+ x: torch.Tensor,
+ input_pos: torch.Tensor,
+ previous_tokens: torch.Tensor = None,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ x = model.forward_generate(x, input_pos)
+ codebooks = [
+ sample(
+ x.logits,
+ previous_tokens=(
+ previous_tokens[0] if previous_tokens is not None else None
+ ), # Disable repetition penalty for the token codebook
+ **sampling_kwargs,
+ )[0]
+ ]
+ x = x.hidden_states
+
+ # Cleanup the cache
+ for layer in model.fast_layers:
+ layer.attention.kv_cache.k_cache.fill_(0)
+ layer.attention.kv_cache.v_cache.fill_(0)
+
+ for codebook_idx in range(model.config.num_codebooks):
+ input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
+ logits = model.forward_generate_fast(x, input_pos)
+ a = sample(
+ logits,
+ previous_tokens=(
+ previous_tokens[codebook_idx + 1]
+ if previous_tokens is not None
+ else None
+ ),
+ **sampling_kwargs,
+ )[0]
+ x = model.fast_embeddings(a)
+ codebooks.append(a)
+
+ return torch.stack(codebooks, dim=0)
+
+
+def decode_one_token_naive(
+ model: NaiveTransformer,
+ x: torch.Tensor,
+ input_pos: torch.Tensor,
+ previous_tokens: torch.Tensor = None,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ x = model.forward_generate(x, input_pos)
+
+ codebooks = [
+ sample(
+ x.token_logits,
+ previous_tokens=None, # Disable repetition penalty for the token codebook
+ **sampling_kwargs,
+ )[0]
+ ]
+
+ for i in range(model.config.num_codebooks):
+ codebooks.append(
+ sample(
+ x.codebook_logits[:, :, i],
+ previous_tokens=(
+ previous_tokens[i + 1] if previous_tokens is not None else None
+ ),
+ **sampling_kwargs,
+ )[0]
+ )
+
+ return torch.stack(codebooks, dim=0)
+
+
+def decode_n_tokens(
+ model: NaiveTransformer,
+ cur_token: torch.Tensor,
+ input_pos: torch.Tensor,
+ num_new_tokens: int,
+ im_end_id: int = 4,
+ decode_one_token=decode_one_token_naive,
+ **sampling_kwargs,
+):
+ previous_tokens = torch.zeros(
+ (model.config.num_codebooks + 1, model.config.max_seq_len),
+ dtype=torch.int,
+ device=cur_token.device,
+ )
+
+ for i in tqdm(range(num_new_tokens)):
+ # We need to get windowed repeat penalty
+ win_size = 16
+ if i < win_size:
+ window = previous_tokens[:, :win_size]
+ else:
+ window = previous_tokens[:, i - win_size : i]
+
+ with torch.backends.cuda.sdp_kernel(
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
+ ): # Actually better for Inductor to codegen attention here
+ next_token = decode_one_token(
+ model=model,
+ x=cur_token,
+ input_pos=input_pos,
+ previous_tokens=window,
+ **sampling_kwargs,
+ )
+
+ input_pos += 1
+ cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
+ previous_tokens[:, i : i + 1] = next_token.view(
+ model.config.num_codebooks + 1, -1
+ )
+
+ if cur_token[0, 0, -1] == im_end_id:
+ break
+
+ return previous_tokens[:, : i + 1]
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def generate(
+ *,
+ model: NaiveTransformer,
+ prompt: torch.Tensor,
+ max_new_tokens: int,
+ im_end_id: int = 4,
+ decode_one_token=decode_one_token_naive,
+ **sampling_kwargs,
+) -> torch.Tensor:
+ """
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+ """
+
+ # create an empty tensor of the expected final shape and fill in the current tokens
+ T = prompt.size(1)
+
+ if max_new_tokens:
+ if T + max_new_tokens > model.config.max_seq_len:
+ max_new_tokens = model.config.max_seq_len - T
+ logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+ T_new = T + max_new_tokens
+ else:
+ T_new = model.config.max_seq_len
+ max_new_tokens = T_new - T
+
+ device, dtype = prompt.device, prompt.dtype
+ with torch.device(device):
+ model.setup_caches(
+ max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
+ )
+
+ codebook_dim = 1 + model.config.num_codebooks
+ # create an empty tensor of the expected final shape and fill in the current tokens
+ empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
+ empty[:, :T] = prompt
+ seq = empty
+ input_pos = torch.arange(0, T, device=device)
+
+ # Use non-accelerated version for now, to avoid compilation overhead
+ prefill_decode = (
+ decode_one_token_naive
+ if isinstance(model, NaiveTransformer)
+ else decode_one_token_ar
+ )
+
+ next_token = prefill_decode(
+ model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
+ )
+ seq[:, T : T + 1] = next_token
+
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
+ x = decode_n_tokens(
+ model,
+ next_token.view(1, codebook_dim, -1),
+ input_pos,
+ max_new_tokens - 1,
+ im_end_id=im_end_id,
+ decode_one_token=decode_one_token,
+ **sampling_kwargs,
+ )
+ # x = torch.cat(generated_tokens, dim=1)
+ seq = seq[:, : T + 1 + x.size(1)]
+ seq[:, T + 1 :] = x
+
+ return seq
+
+
+def encode_tokens(
+ tokenizer,
+ string,
+ device="cuda",
+ prompt_tokens=None,
+ num_codebooks=4,
+):
+ string = clean_text(string)
+ string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
+
+ new_tokens = tokenizer.encode(
+ string,
+ add_special_tokens=False,
+ max_length=10**6,
+ truncation=False,
+ )
+ tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
+
+ # Codebooks
+ zeros = (
+ torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
+ * CODEBOOK_PAD_TOKEN_ID
+ )
+ prompt = torch.cat((tokens, zeros), dim=0)
+
+ if prompt_tokens is None:
+ return prompt
+
+ # Get prompt tokens
+ if prompt_tokens.ndim == 3:
+ assert (
+ prompt_tokens.shape[0] == 1
+ ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
+ prompt_tokens = prompt_tokens[0]
+
+ assert prompt_tokens.ndim == 2
+ data = prompt_tokens + 1
+
+ if prompt_tokens.shape[0] > num_codebooks:
+ logger.warning(
+ f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+ )
+ data = data[:num_codebooks]
+
+ # Add pad token for each codebook
+ data = torch.cat(
+ (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
+ dim=1,
+ )
+
+ # Since 1.0, we use <|semantic|>
+ s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
+ end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+ main_token_ids = (
+ torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
+ )
+ main_token_ids[0, -1] = end_token_id
+
+ data = torch.cat((main_token_ids, data), dim=0)
+ prompt = torch.cat((prompt, data), dim=1)
+
+ return prompt
+
+
+def load_model(checkpoint_path, device, precision, compile=False):
+ model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
+ checkpoint_path, load_weights=True
+ )
+
+ model = model.to(device=device, dtype=precision)
+ logger.info(f"Restored model from checkpoint")
+
+ if isinstance(model, DualARTransformer):
+ decode_one_token = decode_one_token_ar
+ logger.info("Using DualARTransformer")
+ else:
+ decode_one_token = decode_one_token_naive
+ logger.info("Using NaiveTransformer")
+
+ if compile:
+ logger.info("Compiling function...")
+ decode_one_token = torch.compile(
+ decode_one_token, mode="reduce-overhead", fullgraph=True
+ )
+
+ return model.eval(), decode_one_token
+
+
+@dataclass
+class GenerateResponse:
+ action: Literal["sample", "next"]
+ codes: Optional[torch.Tensor] = None
+ text: Optional[str] = None
+
+
+def generate_long(
+ *,
+ model,
+ device: str | torch.device,
+ decode_one_token: callable,
+ text: str,
+ num_samples: int = 1,
+ max_new_tokens: int = 0,
+ top_p: int = 0.7,
+ repetition_penalty: float = 1.5,
+ temperature: float = 0.7,
+ compile: bool = False,
+ iterative_prompt: bool = True,
+ max_length: int = 2048,
+ chunk_length: int = 150,
+ prompt_text: Optional[str | list[str]] = None,
+ prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
+):
+ assert 0 < top_p <= 1, "top_p must be in (0, 1]"
+ assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
+ assert 0 < temperature < 2, "temperature must be in (0, 2)"
+
+ use_prompt = prompt_text is not None and prompt_tokens is not None
+ if use_prompt and isinstance(prompt_text, str):
+ prompt_text = [prompt_text]
+ prompt_tokens = [prompt_tokens]
+
+ assert use_prompt is False or len(prompt_text) == len(
+ prompt_tokens
+ ), "Prompt text and tokens must have the same length"
+
+ model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ tokenizer = model.tokenizer
+ im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+ encoded = []
+ texts = split_text(text, chunk_length) if iterative_prompt else [text]
+ encoded_prompts = []
+
+ if use_prompt:
+ for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
+ encoded_prompts.append(
+ encode_tokens(
+ tokenizer,
+ string=t,
+ device=device,
+ prompt_tokens=c,
+ num_codebooks=model.config.num_codebooks,
+ )
+ )
+
+ for idx, text in enumerate(texts):
+ encoded.append(
+ encode_tokens(
+ tokenizer,
+ string=text,
+ device=device,
+ num_codebooks=model.config.num_codebooks,
+ )
+ )
+ logger.info(f"Encoded text: {text}")
+
+ # Move temperature, top_p, repetition_penalty to device
+ # This is important so that changing params doesn't trigger recompile
+ temperature = torch.tensor(temperature, device=device, dtype=torch.float)
+ top_p = torch.tensor(top_p, device=device, dtype=torch.float)
+ repetition_penalty = torch.tensor(
+ repetition_penalty, device=device, dtype=torch.float
+ )
+
+ for sample_idx in range(num_samples):
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ global_encoded = []
+ seg_idx = 0
+
+ while seg_idx < len(encoded):
+ logger.info(
+ f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
+ )
+
+ seg = encoded[seg_idx]
+ global_encoded.append(seg)
+
+ lengths = reversed([seg.size(1) for seg in global_encoded])
+
+ # Pick last 2000 tokens
+ count = 0
+ for i, length in enumerate(lengths):
+ count += length
+ if count + length > max_length - 1024 - sum(
+ t.shape[1] for t in encoded_prompts
+ ):
+ break
+
+ if i != 0 and i % 2 == 0:
+ i -= 1
+
+ # Rotate the list, always make sure first segment is included to avoid drift
+ if i < len(global_encoded) - 2:
+ partial_encoded = global_encoded[:2] + global_encoded[-i:]
+ else:
+ partial_encoded = global_encoded
+
+ if use_prompt:
+ partial_encoded = encoded_prompts + partial_encoded
+
+ cat_encoded = torch.cat(partial_encoded, dim=1)
+ prompt_length = cat_encoded.size(1)
+
+ t0 = time.perf_counter()
+ y = generate(
+ model=model,
+ prompt=cat_encoded,
+ max_new_tokens=max_new_tokens,
+ im_end_id=im_end_id,
+ decode_one_token=decode_one_token,
+ temperature=temperature,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ )
+
+ if sample_idx == 0 and seg_idx == 0 and compile:
+ logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ t = time.perf_counter() - t0
+
+ tokens_generated = y.size(1) - prompt_length
+ tokens_sec = tokens_generated / t
+ logger.info(
+ f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
+ )
+ logger.info(
+ f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
+ )
+
+ if torch.cuda.is_available():
+ logger.info(
+ f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
+ )
+
+ # Put the generated tokens
+ # since there is and tokens, we remove last 2 tokens
+ codes = y[1:, prompt_length:-1].clone()
+ codes = codes - 1
+ assert (codes >= 0).all(), f"Negative code found"
+
+ decoded = y[:, prompt_length:-1].clone()
+ # But for global encoding, we should keep the token
+
+ global_encoded.append(decoded)
+ assert (codes >= 0).all(), f"Negative code found: {codes}"
+ yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
+ seg_idx += 1
+
+ # This indicates the end of the current sample
+ yield GenerateResponse(action="next")
+
+
+@dataclass
+class WrappedGenerateResponse:
+ status: Literal["success", "error"]
+ response: Optional[GenerateResponse | Exception] = None
+
+
+@dataclass
+class GenerateRequest:
+ request: dict
+ response_queue: queue.Queue
+
+
+def launch_thread_safe_queue(
+ checkpoint_path,
+ device,
+ precision,
+ compile: bool = False,
+):
+ input_queue = queue.Queue()
+ init_event = threading.Event()
+
+ def worker():
+ model, decode_one_token = load_model(
+ checkpoint_path, device, precision, compile=compile
+ )
+ init_event.set()
+
+ while True:
+ item: GenerateRequest | None = input_queue.get()
+ if item is None:
+ break
+
+ kwargs = item.request
+ response_queue = item.response_queue
+
+ try:
+ for chunk in generate_long(
+ model=model, decode_one_token=decode_one_token, **kwargs
+ ):
+ response_queue.put(
+ WrappedGenerateResponse(status="success", response=chunk)
+ )
+ except Exception as e:
+ response_queue.put(WrappedGenerateResponse(status="error", response=e))
+
+ threading.Thread(target=worker, daemon=True).start()
+ init_event.wait()
+
+ return input_queue
+
+
+@click.command()
+@click.option(
+ "--text",
+ type=str,
+ default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+)
+@click.option("--prompt-text", type=str, default=None, multiple=True)
+@click.option(
+ "--prompt-tokens",
+ type=click.Path(path_type=Path, exists=True),
+ default=None,
+ multiple=True,
+)
+@click.option("--num-samples", type=int, default=1)
+@click.option("--max-new-tokens", type=int, default=0)
+@click.option("--top-p", type=float, default=0.7)
+@click.option("--repetition-penalty", type=float, default=1.2)
+@click.option("--temperature", type=float, default=0.7)
+@click.option(
+ "--checkpoint-path",
+ type=click.Path(path_type=Path, exists=True),
+ default="checkpoints/fish-speech-1.2-sft",
+)
+@click.option("--device", type=str, default="cuda")
+@click.option("--compile/--no-compile", default=False)
+@click.option("--seed", type=int, default=42)
+@click.option("--half/--no-half", default=False)
+@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
+@click.option("--chunk-length", type=int, default=100)
+def main(
+ text: str,
+ prompt_text: Optional[list[str]],
+ prompt_tokens: Optional[list[Path]],
+ num_samples: int,
+ max_new_tokens: int,
+ top_p: int,
+ repetition_penalty: float,
+ temperature: float,
+ checkpoint_path: Path,
+ device: str,
+ compile: bool,
+ seed: int,
+ half: bool,
+ iterative_prompt: bool,
+ chunk_length: int,
+) -> None:
+
+ precision = torch.half if half else torch.bfloat16
+
+ if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
+ raise ValueError(
+ f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
+ )
+
+ logger.info("Loading model ...")
+ t0 = time.time()
+ model, decode_one_token = load_model(
+ checkpoint_path, device, precision, compile=compile
+ )
+
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+ if prompt_tokens is not None:
+ prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
+
+ torch.manual_seed(seed)
+
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+
+ generator = generate_long(
+ model=model,
+ device=device,
+ decode_one_token=decode_one_token,
+ text=text,
+ num_samples=num_samples,
+ max_new_tokens=max_new_tokens,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ compile=compile,
+ iterative_prompt=iterative_prompt,
+ chunk_length=chunk_length,
+ prompt_text=prompt_text,
+ prompt_tokens=prompt_tokens,
+ )
+
+ idx = 0
+ codes = []
+
+ for response in generator:
+ if response.action == "sample":
+ codes.append(response.codes)
+ logger.info(f"Sampled text: {response.text}")
+ elif response.action == "next":
+ if codes:
+ np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
+ logger.info(f"Saved codes to codes_{idx}.npy")
+ logger.info(f"Next sample")
+ codes = []
+ idx += 1
+ else:
+ logger.error(f"Error: {response}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/llama/merge_lora.py b/tools/llama/merge_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..f12eece8d2dee9c52d756b1df63503eaca9411b7
--- /dev/null
+++ b/tools/llama/merge_lora.py
@@ -0,0 +1,95 @@
+import shutil
+from copy import deepcopy
+from pathlib import Path
+
+import click
+import hydra
+import torch
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+
+from fish_speech.models.text2semantic.llama import BaseTransformer
+from fish_speech.models.text2semantic.lora import get_merged_state_dict
+
+
+@click.command()
+@click.option("--lora-config", type=str, default="r_8_alpha_16")
+@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.2-sft")
+@click.option("--lora-weight", type=str, required=True)
+@click.option("--output", type=str, required=True)
+def merge(lora_config, base_weight, lora_weight, output):
+ output = Path(output)
+ logger.info(
+ f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
+ )
+
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
+ cfg = compose(config_name=lora_config)
+
+ lora_config = instantiate(cfg)
+ logger.info(f"Loaded lora model with config {lora_config}")
+
+ llama_model = BaseTransformer.from_pretrained(
+ path=base_weight,
+ load_weights=True,
+ lora_config=lora_config,
+ )
+ logger.info(f"Loaded llama model")
+
+ llama_state_dict = llama_model.state_dict()
+ llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
+ llama_state_dict_copy = deepcopy(llama_state_dict)
+ lora_state_dict = torch.load(lora_weight, map_location="cpu")
+
+ if "state_dict" in llama_state_dict:
+ llama_state_dict = llama_state_dict["state_dict"]
+
+ if "state_dict" in lora_state_dict:
+ lora_state_dict = lora_state_dict["state_dict"]
+
+ # remove prefix model.
+ if any(k.startswith("model.") for k in llama_state_dict.keys()):
+ llama_state_dict = {
+ k.replace("model.", ""): v
+ for k, v in llama_state_dict.items()
+ if k.startswith("model.")
+ }
+ if any(k.startswith("model.") for k in lora_state_dict.keys()):
+ lora_state_dict = {
+ k.replace("model.", ""): v
+ for k, v in lora_state_dict.items()
+ if k.startswith("model.")
+ }
+
+ logger.info(f"Found {len(llama_state_dict)} keys in llama model")
+ logger.info(f"Found {len(lora_state_dict)} keys in lora model")
+
+ merged_state_dict = llama_state_dict | lora_state_dict
+ llama_model.load_state_dict(merged_state_dict, strict=True)
+ logger.info(f"Merged model loaded")
+
+ # Trigger eval mode to merge lora
+ llama_model.eval()
+ llama_model.save_pretrained(output, drop_lora=True)
+ logger.info(f"Saved merged model to {output}, validating")
+
+ new_state_dict = torch.load(output / "model.pth", map_location="cpu")
+ original_keys = set(llama_state_dict_copy.keys())
+ merged_keys = set(new_state_dict.keys())
+
+ assert original_keys == merged_keys, "Keys should be same"
+
+ for key in original_keys:
+ diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
+ if diff_l1 != 0:
+ break
+ else:
+ logger.error("Merged model is same as the original model")
+ exit(1)
+
+ logger.info("Merged model is different from the original model, check passed")
+
+
+if __name__ == "__main__":
+ merge()
diff --git a/tools/llama/quantize.py b/tools/llama/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..aae32fcce7ffd9f865e6f3c4b7b281f23345c82a
--- /dev/null
+++ b/tools/llama/quantize.py
@@ -0,0 +1,497 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+import datetime
+import shutil
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import time
+from pathlib import Path
+
+import click
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fish_speech.models.text2semantic.llama import find_multiple
+from tools.llama.generate import load_model
+
+##### Quantization Primitives ######
+
+
+def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
+ # assumes symmetric quantization
+ # assumes axis == 0
+ # assumes dense memory format
+ # TODO(future): relax ^ as needed
+
+ # default setup for affine quantization of activations
+ eps = torch.finfo(torch.float32).eps
+
+ # get min and max
+ min_val, max_val = torch.aminmax(x, dim=1)
+
+ # calculate scales and zero_points based on min and max
+ # reference: https://fburl.com/code/srbiybme
+ min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+ max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+ device = min_val_neg.device
+
+ # reference: https://fburl.com/code/4wll53rk
+ max_val_pos = torch.max(-min_val_neg, max_val_pos)
+ scales = max_val_pos / (float(quant_max - quant_min) / 2)
+ # ensure scales is the same dtype as the original tensor
+ scales = torch.clamp(scales, min=eps).to(x.dtype)
+ zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+ # quantize based on qmin/qmax/scales/zp
+ # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
+ x_div = x / scales.unsqueeze(-1)
+ x_round = torch.round(x_div)
+ x_zp = x_round + zero_points.unsqueeze(-1)
+ quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
+
+ return quant, scales, zero_points
+
+
+def get_group_qparams(w, n_bit=4, groupsize=128):
+ # needed for GPTQ with padding
+ if groupsize > w.shape[-1]:
+ groupsize = w.shape[-1]
+ assert groupsize > 1
+ assert w.shape[-1] % groupsize == 0
+ assert w.dim() == 2
+
+ to_quant = w.reshape(-1, groupsize)
+ assert torch.isnan(to_quant).sum() == 0
+
+ max_val = to_quant.amax(dim=1, keepdim=True)
+ min_val = to_quant.amin(dim=1, keepdim=True)
+ max_int = 2**n_bit - 1
+ scales = (max_val - min_val).clamp(min=1e-6) / max_int
+ zeros = min_val + scales * (2 ** (n_bit - 1))
+ return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
+ torch.bfloat16
+ ).reshape(w.shape[0], -1)
+
+
+def pack_scales_and_zeros(scales, zeros):
+ assert scales.shape == zeros.shape
+ assert scales.dtype == torch.bfloat16
+ assert zeros.dtype == torch.bfloat16
+ return (
+ torch.cat(
+ [
+ scales.reshape(scales.size(0), scales.size(1), 1),
+ zeros.reshape(zeros.size(0), zeros.size(1), 1),
+ ],
+ 2,
+ )
+ .transpose(0, 1)
+ .contiguous()
+ )
+
+
+def unpack_scales_and_zeros(scales_and_zeros):
+ assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
+ assert scales_and_zeros.dtype == torch.float
+ return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
+
+
+def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
+ assert groupsize > 1
+ # needed for GPTQ single column quantize
+ if groupsize > w.shape[-1] and scales.shape[-1] == 1:
+ groupsize = w.shape[-1]
+
+ assert w.shape[-1] % groupsize == 0
+ assert w.dim() == 2
+
+ to_quant = w.reshape(-1, groupsize)
+ assert torch.isnan(to_quant).sum() == 0
+
+ scales = scales.reshape(-1, 1)
+ zeros = zeros.reshape(-1, 1)
+ min_val = zeros - scales * (2 ** (n_bit - 1))
+ max_int = 2**n_bit - 1
+ min_int = 0
+ w_int32 = (
+ to_quant.sub(min_val)
+ .div(scales)
+ .round()
+ .clamp_(min_int, max_int)
+ .to(torch.int32)
+ .reshape_as(w)
+ )
+
+ return w_int32
+
+
+def group_quantize_tensor(w, n_bit=4, groupsize=128):
+ scales, zeros = get_group_qparams(w, n_bit, groupsize)
+ w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
+ scales_and_zeros = pack_scales_and_zeros(scales, zeros)
+ return w_int32, scales_and_zeros
+
+
+def group_dequantize_tensor_from_qparams(
+ w_int32, scales, zeros, n_bit=4, groupsize=128
+):
+ assert groupsize > 1
+ # needed for GPTQ single column dequantize
+ if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
+ groupsize = w_int32.shape[-1]
+ assert w_int32.shape[-1] % groupsize == 0
+ assert w_int32.dim() == 2
+
+ w_int32_grouped = w_int32.reshape(-1, groupsize)
+ scales = scales.reshape(-1, 1)
+ zeros = zeros.reshape(-1, 1)
+
+ w_dq = (
+ w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
+ )
+ return w_dq
+
+
+def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
+ scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
+ return group_dequantize_tensor_from_qparams(
+ w_int32, scales, zeros, n_bit, groupsize
+ )
+
+
+class QuantHandler:
+ def __init__(self, mod):
+ self.mod = mod
+
+ def create_quantized_state_dict(self) -> "StateDict":
+ pass
+
+ def convert_for_runtime(self) -> "nn.Module":
+ pass
+
+
+##### Weight-only int8 per-channel quantized code ######
+
+
+def replace_linear_weight_only_int8_per_channel(module):
+ for name, child in module.named_children():
+ if isinstance(child, nn.Linear):
+ setattr(
+ module,
+ name,
+ WeightOnlyInt8Linear(child.in_features, child.out_features),
+ )
+ else:
+ replace_linear_weight_only_int8_per_channel(child)
+
+
+class WeightOnlyInt8QuantHandler:
+ def __init__(self, mod):
+ self.mod = mod
+
+ @torch.no_grad()
+ def create_quantized_state_dict(self):
+ cur_state_dict = self.mod.state_dict()
+ for fqn, mod in self.mod.named_modules():
+ if isinstance(mod, torch.nn.Linear):
+ int8_weight, scales, _ = dynamically_quantize_per_channel(
+ mod.weight.float(), -128, 127, torch.int8
+ )
+ cur_state_dict[f"{fqn}.weight"] = int8_weight
+ cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
+
+ return cur_state_dict
+
+ def convert_for_runtime(self):
+ replace_linear_weight_only_int8_per_channel(self.mod)
+ return self.mod
+
+
+class WeightOnlyInt8Linear(torch.nn.Module):
+ __constants__ = ["in_features", "out_features"]
+ in_features: int
+ out_features: int
+ weight: torch.Tensor
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.register_buffer(
+ "weight", torch.empty((out_features, in_features), dtype=torch.int8)
+ )
+ self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+
+
+##### weight only int4 per channel groupwise quantized code ######
+
+
+def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
+ weight_int32, scales_and_zeros = group_quantize_tensor(
+ weight_bf16, n_bit=4, groupsize=groupsize
+ )
+ weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+ weight_int32, inner_k_tiles
+ )
+ return weight_int4pack, scales_and_zeros
+
+
+def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
+ origin_x_size = x.size()
+ x = x.reshape(-1, origin_x_size[-1])
+ c = torch.ops.aten._weight_int4pack_mm(
+ x, weight_int4pack, groupsize, scales_and_zeros
+ )
+ new_shape = origin_x_size[:-1] + (out_features,)
+ c = c.reshape(new_shape)
+ return c
+
+
+def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
+ return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
+
+
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
+ for name, child in module.named_children():
+ if isinstance(child, nn.Linear):
+ if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+ setattr(
+ module,
+ name,
+ WeightOnlyInt4Linear(
+ child.in_features,
+ child.out_features,
+ bias=False,
+ groupsize=groupsize,
+ inner_k_tiles=inner_k_tiles,
+ padding=False,
+ ),
+ )
+ elif padding:
+ setattr(
+ module,
+ name,
+ WeightOnlyInt4Linear(
+ child.in_features,
+ child.out_features,
+ bias=False,
+ groupsize=groupsize,
+ inner_k_tiles=inner_k_tiles,
+ padding=True,
+ ),
+ )
+ else:
+ replace_linear_int4(child, groupsize, inner_k_tiles, padding)
+
+
+class WeightOnlyInt4QuantHandler:
+ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+ self.mod = mod
+ self.groupsize = groupsize
+ self.inner_k_tiles = inner_k_tiles
+ self.padding = padding
+ assert groupsize in [32, 64, 128, 256]
+ assert inner_k_tiles in [2, 4, 8]
+
+ @torch.no_grad()
+ def create_quantized_state_dict(self):
+ cur_state_dict = self.mod.state_dict()
+ for fqn, mod in self.mod.named_modules():
+ if isinstance(mod, torch.nn.Linear):
+ assert not mod.bias
+ out_features = mod.out_features
+ in_features = mod.in_features
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
+ print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+ weight = mod.weight.data
+ if not _check_linear_int4_k(
+ in_features, self.groupsize, self.inner_k_tiles
+ ):
+ if self.padding:
+ import torch.nn.functional as F
+
+ print(
+ f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+ )
+ padded_in_features = find_multiple(in_features, 1024)
+ weight = F.pad(
+ weight, pad=(0, padded_in_features - in_features)
+ )
+ else:
+ print(
+ f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+ + "and that groupsize and inner_k_tiles*16 evenly divide into it"
+ )
+ continue
+ (
+ weight_int4pack,
+ scales_and_zeros,
+ ) = prepare_int4_weight_and_scales_and_zeros(
+ weight.to(torch.bfloat16).to("cuda"),
+ self.groupsize,
+ self.inner_k_tiles,
+ )
+ cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
+ cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
+
+ return cur_state_dict
+
+ def convert_for_runtime(self):
+ replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+ return self.mod
+
+
+class WeightOnlyInt4Linear(torch.nn.Module):
+ __constants__ = ["in_features", "out_features"]
+ in_features: int
+ out_features: int
+ weight: torch.Tensor
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias=True,
+ device=None,
+ dtype=None,
+ groupsize: int = 128,
+ inner_k_tiles: int = 8,
+ padding: bool = True,
+ ) -> None:
+ super().__init__()
+ self.padding = padding
+ if padding:
+ self.origin_in_features = in_features
+ in_features = find_multiple(in_features, 1024)
+
+ self.in_features = in_features
+ self.out_features = out_features
+ assert not bias, "require bias=False"
+ self.groupsize = groupsize
+ self.inner_k_tiles = inner_k_tiles
+
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
+ assert (
+ in_features % (inner_k_tiles * 16) == 0
+ ), "require in_features % (innerKTiles * 16) == 0"
+ self.register_buffer(
+ "weight",
+ torch.empty(
+ (
+ out_features // 8,
+ in_features // (inner_k_tiles * 16),
+ 32,
+ inner_k_tiles // 2,
+ ),
+ dtype=torch.int32,
+ ),
+ )
+ self.register_buffer(
+ "scales_and_zeros",
+ torch.empty(
+ (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
+ ),
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ input = input.to(torch.bfloat16)
+ if self.padding:
+ import torch.nn.functional as F
+
+ input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+ return linear_forward_int4(
+ input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+ )
+
+
+def generate_folder_name():
+ now = datetime.datetime.now()
+ folder_name = now.strftime("%Y%m%d_%H%M%S")
+ return folder_name
+
+
+@click.command()
+@click.option(
+ "--checkpoint-path",
+ type=click.Path(path_type=Path, exists=True),
+ default="checkpoints/fish-speech-1.2-sft",
+)
+@click.option(
+ "--mode", type=str, default="int8", help="type of quantization to perform"
+)
+@click.option(
+ "--groupsize", type=int, default=128, help="Group size for int4 quantization."
+)
+@click.option("--timestamp", type=str, default="None", help="When to do quantization")
+def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
+
+ device = "cpu"
+ precision = torch.bfloat16
+
+ print("Loading model ...")
+ t0 = time.time()
+
+ model, _ = load_model(
+ checkpoint_path=checkpoint_path,
+ device=device,
+ precision=precision,
+ compile=False,
+ )
+ vq_model = "firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+ now = timestamp if timestamp != "None" else generate_folder_name()
+
+ if mode == "int8":
+ print(
+ "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
+ )
+ quant_handler = WeightOnlyInt8QuantHandler(model)
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
+
+ dir_name = checkpoint_path
+ dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
+ shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+ if (dst_name / vq_model).exists():
+ (dst_name / vq_model).unlink()
+ quantize_path = dst_name / "model.pth"
+
+ elif mode == "int4":
+ print(
+ "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
+ )
+ quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
+
+ dir_name = checkpoint_path
+ dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
+ shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+ if (dst_name / vq_model).exists():
+ (dst_name / vq_model).unlink()
+ quantize_path = dst_name / "model.pth"
+
+ else:
+ raise ValueError(
+ f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
+ )
+
+ print(f"Writing quantized weights to {quantize_path}")
+ quantize_path.unlink(missing_ok=True) # remove existing file if one already there
+ torch.save(quantized_state_dict, quantize_path)
+ print(f"Quantization complete took {time.time() - t0:.02f} seconds")
+
+
+if __name__ == "__main__":
+ quantize()
diff --git a/tools/llama/rebuild_tokenizer.py b/tools/llama/rebuild_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea64fa6788833000c8dc41e3d570dd5b250fb14b
--- /dev/null
+++ b/tools/llama/rebuild_tokenizer.py
@@ -0,0 +1,57 @@
+from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+
+# Initialize a tokenizer
+tokenizer = Tokenizer(models.BPE())
+
+# Customize pre-tokenization and decoding
+tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
+tokenizer.decoder = decoders.ByteLevel()
+tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+
+# Don't train the tokenizer
+trainer = trainers.BpeTrainer(
+ vocab_size=0,
+ min_frequency=2,
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
+ special_tokens=[
+ "<|begin_of_sequence|>",
+ "<|end_of_sequence|>",
+ "<|im_start|>",
+ "<|im_sep|>", # system, user, assistant, etc.
+ "<|im_end|>",
+ "<|semantic|>", # audio features
+ "<|pad|>",
+ ],
+)
+
+# <|im_start|>user<|im_sep|>...<|im_end|>
+# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
+tokenizer.train_from_iterator([], trainer=trainer)
+
+print(len(tokenizer.get_vocab()))
+x = tokenizer.encode(
+ "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
+).ids
+print(x, len(x))
+print(tokenizer.decode(x, skip_special_tokens=True))
+
+
+tokenizer = PreTrainedTokenizerFast(
+ tokenizer_object=tokenizer,
+ pad_token="<|pad|>",
+ bos_token="<|begin_of_sequence|>",
+ eos_token="<|end_of_sequence|>",
+)
+
+# Try tokenizing a new sequence
+sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
+encoded = tokenizer(sequence).input_ids
+
+print("Test encoding....")
+print(f"\tSentence: {sequence}")
+print(f"\tEncoded: {encoded}")
+print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
+print(f"\tDecoded: {tokenizer.decode(encoded)}")
+
+tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
diff --git a/tools/merge_asr_files.py b/tools/merge_asr_files.py
new file mode 100644
index 0000000000000000000000000000000000000000..d86d29a7a220aafc92cf8cf5ea9689f027b2287c
--- /dev/null
+++ b/tools/merge_asr_files.py
@@ -0,0 +1,55 @@
+import os
+from pathlib import Path
+
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+
+def merge_and_delete_files(save_dir, original_files):
+ save_path = Path(save_dir)
+ audio_slice_files = list_files(
+ path=save_dir, extensions=AUDIO_EXTENSIONS.union([".lab"]), recursive=True
+ )
+ audio_files = {}
+ label_files = {}
+ for file_path in tqdm(audio_slice_files, desc="Merging audio files"):
+ rel_path = Path(file_path).relative_to(save_path)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+ if file_path.suffix == ".wav":
+ prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
+ if prefix == rel_path.parent / file_path.stem:
+ continue
+ audio = AudioSegment.from_wav(file_path)
+ if prefix in audio_files.keys():
+ audio_files[prefix] = audio_files[prefix] + audio
+ else:
+ audio_files[prefix] = audio
+
+ elif file_path.suffix == ".lab":
+ prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
+ if prefix == rel_path.parent / file_path.stem:
+ continue
+ with open(file_path, "r", encoding="utf-8") as f:
+ label = f.read()
+ if prefix in label_files.keys():
+ label_files[prefix] = label_files[prefix] + ", " + label
+ else:
+ label_files[prefix] = label
+
+ for prefix, audio in audio_files.items():
+ output_audio_path = save_path / f"{prefix}.wav"
+ audio.export(output_audio_path, format="wav")
+
+ for prefix, label in label_files.items():
+ output_label_path = save_path / f"{prefix}.lab"
+ with open(output_label_path, "w", encoding="utf-8") as f:
+ f.write(label)
+
+ for file_path in original_files:
+ os.remove(file_path)
+
+
+if __name__ == "__main__":
+ merge_and_delete_files("/made/by/spicysama/laziman", [__file__])
diff --git a/tools/post_api.py b/tools/post_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccf887825fabdd20754ace3f4ab94ac0c1212665
--- /dev/null
+++ b/tools/post_api.py
@@ -0,0 +1,148 @@
+import argparse
+import base64
+import json
+from pathlib import Path
+
+import pyaudio
+import requests
+
+
+def wav_to_base64(file_path):
+ if not file_path or not Path(file_path).exists():
+ return None
+ with open(file_path, "rb") as wav_file:
+ wav_content = wav_file.read()
+ base64_encoded = base64.b64encode(wav_content)
+ return base64_encoded.decode("utf-8")
+
+
+def read_ref_text(ref_text):
+ path = Path(ref_text)
+ if path.exists() and path.is_file():
+ with path.open("r", encoding="utf-8") as file:
+ return file.read()
+ return ref_text
+
+
+def play_audio(audio_content, format, channels, rate):
+ p = pyaudio.PyAudio()
+ stream = p.open(format=format, channels=channels, rate=rate, output=True)
+ stream.write(audio_content)
+ stream.stop_stream()
+ stream.close()
+ p.terminate()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Send a WAV file and text to a server and receive synthesized audio."
+ )
+
+ parser.add_argument(
+ "--url",
+ "-u",
+ type=str,
+ default="http://127.0.0.1:8080/v1/invoke",
+ help="URL of the server",
+ )
+ parser.add_argument(
+ "--text", "-t", type=str, required=True, help="Text to be synthesized"
+ )
+ parser.add_argument(
+ "--reference_audio",
+ "-ra",
+ type=str,
+ default=None,
+ help="Path to the WAV file",
+ )
+ parser.add_argument(
+ "--reference_text",
+ "-rt",
+ type=str,
+ default=None,
+ help="Reference text for voice synthesis",
+ )
+ parser.add_argument(
+ "--max_new_tokens",
+ type=int,
+ default=1024,
+ help="Maximum new tokens to generate",
+ )
+ parser.add_argument(
+ "--chunk_length", type=int, default=100, help="Chunk length for synthesis"
+ )
+ parser.add_argument(
+ "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
+ )
+ parser.add_argument(
+ "--repetition_penalty",
+ type=float,
+ default=1.2,
+ help="Repetition penalty for synthesis",
+ )
+ parser.add_argument(
+ "--temperature", type=float, default=0.7, help="Temperature for sampling"
+ )
+ parser.add_argument(
+ "--speaker", type=str, default=None, help="Speaker ID for voice synthesis"
+ )
+ parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion")
+ parser.add_argument("--format", type=str, default="wav", help="Audio format")
+ parser.add_argument(
+ "--streaming", type=bool, default=False, help="Enable streaming response"
+ )
+ parser.add_argument(
+ "--channels", type=int, default=1, help="Number of audio channels"
+ )
+ parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
+
+ args = parser.parse_args()
+
+ base64_audio = wav_to_base64(args.reference_audio)
+
+ ref_text = args.reference_text
+ if ref_text:
+ ref_text = read_ref_text(ref_text)
+
+ data = {
+ "text": args.text,
+ "reference_text": ref_text,
+ "reference_audio": base64_audio,
+ "max_new_tokens": args.max_new_tokens,
+ "chunk_length": args.chunk_length,
+ "top_p": args.top_p,
+ "repetition_penalty": args.repetition_penalty,
+ "temperature": args.temperature,
+ "speaker": args.speaker,
+ "emotion": args.emotion,
+ "format": args.format,
+ "streaming": args.streaming,
+ }
+
+ response = requests.post(args.url, json=data, stream=args.streaming)
+
+ audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format
+
+ if response.status_code == 200:
+ if args.streaming:
+ p = pyaudio.PyAudio()
+ stream = p.open(
+ format=audio_format, channels=args.channels, rate=args.rate, output=True
+ )
+ for chunk in response.iter_content(chunk_size=1024):
+ if chunk:
+ stream.write(chunk)
+ stream.stop_stream()
+ stream.close()
+ p.terminate()
+ else:
+ audio_content = response.content
+
+ with open("generated_audio.wav", "wb") as audio_file:
+ audio_file.write(audio_content)
+
+ play_audio(audio_content, audio_format, args.channels, args.rate)
+ print("Audio has been saved to 'generated_audio.wav'.")
+ else:
+ print(f"Request failed with status code {response.status_code}")
+ print(response.json())
diff --git a/tools/smart_pad.py b/tools/smart_pad.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fb55d9a034fc963f4243a25704df097924159e7
--- /dev/null
+++ b/tools/smart_pad.py
@@ -0,0 +1,47 @@
+import random
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import librosa
+import torch.nn.functional as F
+import torchaudio
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+threshold = 10 ** (-50 / 20.0)
+
+
+def process(file):
+ waveform, sample_rate = torchaudio.load(str(file), backend="sox")
+ loudness = librosa.feature.rms(
+ y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
+ )[0]
+ for i in range(len(loudness) - 1, 0, -1):
+ if loudness[i] > threshold:
+ break
+
+ silent_time = (len(loudness) - i) * 512 / sample_rate
+
+ if silent_time <= 0.3:
+ random_time = random.uniform(0.3, 0.7)
+ waveform = F.pad(
+ waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
+ )
+
+ torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
+
+
+@click.command()
+@click.argument("source", type=Path)
+@click.option("--num-workers", type=int, default=12)
+def main(source, num_workers):
+ files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
+
+ with Pool(num_workers) as p:
+ list(tqdm(p.imap_unordered(process, files), total=len(files)))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..977afdf3260994ef31d2189a5973a2628b26c0c5
--- /dev/null
+++ b/tools/vqgan/create_train_split.py
@@ -0,0 +1,83 @@
+import math
+from pathlib import Path
+from random import Random
+
+import click
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+
+@click.command()
+@click.argument("root", type=click.Path(exists=True, path_type=Path))
+@click.option("--val-ratio", type=float, default=None)
+@click.option("--val-count", type=int, default=None)
+@click.option("--filelist", default=None, type=Path)
+@click.option("--min-duration", default=None, type=float)
+@click.option("--max-duration", default=None, type=float)
+def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
+
+ if min_duration is None and max_duration is None:
+ filtered_files = list(map(str, [file.relative_to(root) for file in files]))
+ else:
+ filtered_files = []
+ for file in tqdm(files):
+ try:
+ audio = AudioSegment.from_file(str(file))
+ duration = len(audio) / 1000.0
+
+ if min_duration is not None and duration < min_duration:
+ logger.info(
+ f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
+ )
+ continue
+
+ if max_duration is not None and duration > max_duration:
+ logger.info(
+ f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
+ )
+ continue
+
+ filtered_files.append(str(file.relative_to(root)))
+ except Exception as e:
+ logger.info(f"Error processing {file}: {e}")
+
+ logger.info(
+ f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
+ )
+
+ Random(42).shuffle(filtered_files)
+
+ if val_count is None and val_ratio is None:
+ logger.info("Validation ratio and count not specified, using min(20%, 100)")
+ val_size = min(100, math.ceil(len(filtered_files) * 0.2))
+ elif val_count is not None and val_ratio is not None:
+ logger.error("Cannot specify both val_count and val_ratio")
+ return
+ elif val_count is not None:
+ if val_count < 1 or val_count > len(filtered_files):
+ logger.error("val_count must be between 1 and number of files")
+ return
+ val_size = val_count
+ else:
+ val_size = math.ceil(len(filtered_files) * val_ratio)
+
+ logger.info(f"Using {val_size} files for validation")
+
+ with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(filtered_files[val_size:]))
+
+ with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(filtered_files[:val_size]))
+
+ logger.info("Done")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc84f7e6e0ccfa5480de68b61fb9f3d16c7ae99f
--- /dev/null
+++ b/tools/vqgan/extract_vq.py
@@ -0,0 +1,227 @@
+import os
+import subprocess as sp
+import sys
+import time
+from datetime import timedelta
+from functools import lru_cache
+from pathlib import Path
+from random import Random
+
+import click
+import numpy as np
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from lightning import LightningModule
+from loguru import logger
+from omegaconf import OmegaConf
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+# This file is used to convert the audio files to text files using the Whisper model.
+# It's mainly used to generate the training data for the VQ model.
+
+
+RANK = int(os.environ.get("SLURM_PROCID", 0))
+WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
+
+logger_format = (
+ "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
+ "{level: <8} | "
+ "{name} :{function} :{line} | "
+ "{extra[rank]} - {message} "
+)
+logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
+logger.remove()
+logger.add(sys.stderr, format=logger_format)
+
+
+@lru_cache(maxsize=1)
+def get_model(
+ config_name: str = "firefly_gan_vq",
+ checkpoint_path: str = "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ device: str | torch.device = "cuda",
+):
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+ cfg = compose(config_name=config_name)
+
+ model = instantiate(cfg)
+ state_dict = torch.load(
+ checkpoint_path,
+ map_location=device,
+ )
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator" in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
+ model.load_state_dict(state_dict, strict=False)
+ model.eval()
+ model.to(device)
+
+ logger.info(f"Loaded model")
+ return model
+
+
+@torch.inference_mode()
+def process_batch(files: list[Path], model) -> float:
+ wavs = []
+ audio_lengths = []
+ new_files = []
+ max_length = total_time = 0
+
+ for file in files:
+ try:
+ wav, sr = torchaudio.load(
+ str(file), backend="sox" if sys.platform == "linux" else "soundfile"
+ ) # Need to install libsox-dev
+ except Exception as e:
+ logger.error(f"Error reading {file}: {e}")
+ continue
+
+ if wav.shape[0] > 1:
+ wav = wav.mean(dim=0, keepdim=True)
+
+ wav = torchaudio.functional.resample(
+ wav.cuda(), sr, model.spec_transform.sample_rate
+ )[0]
+ total_time += len(wav) / model.spec_transform.sample_rate
+ max_length = max(max_length, len(wav))
+
+ wavs.append(wav)
+ audio_lengths.append(len(wav))
+ new_files.append(file)
+
+ files = new_files
+
+ # Pad to max length
+ for i, wav in enumerate(wavs):
+ wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
+
+ audios = torch.stack(wavs, dim=0)[:, None]
+ audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
+
+ # Calculate lengths
+ indices, feature_lengths = model.encode(audios, audio_lengths)
+
+ # Save to disk
+ outputs = indices.cpu().numpy()
+
+ for file, length, feature, audio_length in zip(
+ files, feature_lengths, outputs, audio_lengths
+ ):
+ feature = feature[:, :length]
+
+ # (T,)
+ with open(file.with_suffix(".npy"), "wb") as f:
+ np.save(f, feature)
+
+ return total_time
+
+
+@click.command()
+@click.argument("folder")
+@click.option("--num-workers", default=1)
+@click.option("--config-name", default="firefly_gan_vq")
+@click.option(
+ "--checkpoint-path",
+ default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+)
+@click.option("--batch-size", default=64)
+@click.option("--filelist", default=None, type=Path)
+def main(
+ folder: str,
+ num_workers: int,
+ config_name: str,
+ checkpoint_path: str,
+ batch_size: int,
+ filelist: Path,
+):
+ if num_workers > 1 and WORLD_SIZE != num_workers:
+ assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
+
+ logger.info(f"Spawning {num_workers} workers")
+
+ if torch.cuda.is_available():
+ visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+ if visible_devices is None:
+ visible_devices = list(range(torch.cuda.device_count()))
+ else:
+ visible_devices = visible_devices.split(",")
+ else:
+ # Set to empty string to avoid using GPU
+ visible_devices = [""]
+
+ processes = []
+ for i in range(num_workers):
+ env = os.environ.copy()
+ env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
+ env["SLURM_PROCID"] = str(i)
+ env["SLURM_NTASKS"] = str(num_workers)
+
+ processes.append(
+ sp.Popen(
+ [sys.executable] + sys.argv.copy(),
+ env=env,
+ )
+ )
+
+ for p in processes:
+ p.wait()
+
+ logger.info(f"All workers finished")
+ return
+
+ # This is a worker
+ logger.info(f"Starting worker")
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
+
+ print(f"Found {len(files)} files")
+ files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
+
+ total_files = len(files)
+ files = files[RANK::WORLD_SIZE]
+ logger.info(f"Processing {len(files)}/{total_files} files")
+
+ # Batch processing
+ total_time = 0
+ begin_time = time.time()
+ processed_files = 0
+ model = get_model(config_name, checkpoint_path)
+
+ for n_batch, idx in enumerate(range(0, len(files), batch_size)):
+ batch = files[idx : idx + batch_size]
+ batch_time = process_batch(batch, model)
+
+ total_time += batch_time
+ processed_files += len(batch)
+
+ if (n_batch + 1) % 10 == 0:
+ eta = (
+ (time.time() - begin_time)
+ / processed_files
+ * (len(files) - processed_files)
+ )
+ logger.info(
+ f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
+ + f"ETA: {timedelta(seconds=round(eta))}s"
+ )
+
+ logger.info(
+ f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/inference.py b/tools/vqgan/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..68b73b53db6e252b2b33ec07842b28886c1ea8ff
--- /dev/null
+++ b/tools/vqgan/inference.py
@@ -0,0 +1,120 @@
+from pathlib import Path
+
+import click
+import hydra
+import numpy as np
+import soundfile as sf
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+from omegaconf import OmegaConf
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+
+def load_model(config_name, checkpoint_path, device="cuda"):
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+ cfg = compose(config_name=config_name)
+
+ model = instantiate(cfg)
+ state_dict = torch.load(
+ checkpoint_path,
+ map_location=device,
+ )
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator" in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
+ result = model.load_state_dict(state_dict, strict=False)
+ model.eval()
+ model.to(device)
+
+ logger.info(f"Loaded model: {result}")
+ return model
+
+
+@torch.no_grad()
+@click.command()
+@click.option(
+ "--input-path",
+ "-i",
+ default="test.wav",
+ type=click.Path(exists=True, path_type=Path),
+)
+@click.option(
+ "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
+)
+@click.option("--config-name", default="firefly_gan_vq")
+@click.option(
+ "--checkpoint-path",
+ default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+)
+@click.option(
+ "--device",
+ "-d",
+ default="cuda",
+)
+def main(input_path, output_path, config_name, checkpoint_path, device):
+ model = load_model(config_name, checkpoint_path, device=device)
+
+ if input_path.suffix in AUDIO_EXTENSIONS:
+ logger.info(f"Processing in-place reconstruction of {input_path}")
+
+ # Load audio
+ audio, sr = torchaudio.load(str(input_path))
+ if audio.shape[0] > 1:
+ audio = audio.mean(0, keepdim=True)
+ audio = torchaudio.functional.resample(
+ audio, sr, model.spec_transform.sample_rate
+ )
+
+ audios = audio[None].to(device)
+ logger.info(
+ f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
+ )
+
+ # VQ Encoder
+ audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
+ indices = model.encode(audios, audio_lengths)[0][0]
+
+ logger.info(f"Generated indices of shape {indices.shape}")
+
+ # Save indices
+ np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
+ elif input_path.suffix == ".npy":
+ logger.info(f"Processing precomputed indices from {input_path}")
+ indices = np.load(input_path)
+ indices = torch.from_numpy(indices).to(device).long()
+ assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
+ else:
+ raise ValueError(f"Unknown input type: {input_path}")
+
+ # Restore
+ feature_lengths = torch.tensor([indices.shape[1]], device=device)
+ fake_audios = model.decode(indices=indices[None], feature_lengths=feature_lengths)
+ audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
+
+ logger.info(
+ f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
+ )
+
+ # Save audio
+ fake_audio = fake_audios[0, 0].float().cpu().numpy()
+ sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
+ logger.info(f"Saved audio to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/webui.py b/tools/webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2af42043f19b3911a024447d0da1332ce944762
--- /dev/null
+++ b/tools/webui.py
@@ -0,0 +1,558 @@
+import gc
+import html
+import io
+import os
+import queue
+import wave
+from argparse import ArgumentParser
+from functools import partial
+from pathlib import Path
+
+import gradio as gr
+import numpy as np
+import pyrootutils
+import torch
+from loguru import logger
+from transformers import AutoTokenizer
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+
+from fish_speech.i18n import i18n
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from tools.api import decode_vq_tokens, encode_reference
+from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
+from tools.llama.generate import (
+ GenerateRequest,
+ GenerateResponse,
+ WrappedGenerateResponse,
+ launch_thread_safe_queue,
+)
+from tools.vqgan.inference import load_model as load_decoder_model
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+HEADER_MD = f"""# Fish Speech
+
+{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
+
+{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).")}
+
+{i18n("Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.")}
+
+{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
+"""
+
+TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
+SPACE_IMPORTED = False
+
+
+def build_html_error_message(error):
+ return f"""
+
+ {html.escape(str(error))}
+
+ """
+
+
+@torch.inference_mode()
+def inference(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ streaming=False,
+):
+ if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
+ return (
+ None,
+ None,
+ i18n("Text is too long, please keep it under {} characters.").format(
+ args.max_gradio_length
+ ),
+ )
+
+ # Parse reference audio aka prompt
+ prompt_tokens = encode_reference(
+ decoder_model=decoder_model,
+ reference_audio=reference_audio,
+ enable_reference_audio=enable_reference_audio,
+ )
+
+ # LLAMA Inference
+ request = dict(
+ device=decoder_model.device,
+ max_new_tokens=max_new_tokens,
+ text=text,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ compile=args.compile,
+ iterative_prompt=chunk_length > 0,
+ chunk_length=chunk_length,
+ max_length=2048,
+ prompt_tokens=prompt_tokens if enable_reference_audio else None,
+ prompt_text=reference_text if enable_reference_audio else None,
+ )
+
+ response_queue = queue.Queue()
+ llama_queue.put(
+ GenerateRequest(
+ request=request,
+ response_queue=response_queue,
+ )
+ )
+
+ if streaming:
+ yield wav_chunk_header(), None, None
+
+ segments = []
+
+ while True:
+ result: WrappedGenerateResponse = response_queue.get()
+ if result.status == "error":
+ yield None, None, build_html_error_message(result.response)
+ break
+
+ result: GenerateResponse = result.response
+ if result.action == "next":
+ break
+
+ with torch.autocast(
+ device_type=(
+ "cpu"
+ if decoder_model.device.type == "mps"
+ else decoder_model.device.type
+ ),
+ dtype=args.precision,
+ ):
+ fake_audios = decode_vq_tokens(
+ decoder_model=decoder_model,
+ codes=result.codes,
+ )
+
+ fake_audios = fake_audios.float().cpu().numpy()
+ segments.append(fake_audios)
+
+ if streaming:
+ yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+
+ if len(segments) == 0:
+ return (
+ None,
+ None,
+ build_html_error_message(
+ i18n("No audio generated, please check the input text.")
+ ),
+ )
+
+ # No matter streaming or not, we need to return the final audio
+ audio = np.concatenate(segments, axis=0)
+ yield None, (decoder_model.spec_transform.sample_rate, audio), None
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+def inference_with_auto_rerank(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ streaming=False,
+ use_auto_rerank=True,
+):
+ if not use_auto_rerank:
+ return inference(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ streaming,
+ )
+
+ zh_model, en_model = load_model()
+ max_attempts = 2
+ best_wer = float("inf")
+ best_audio = None
+ best_sample_rate = None
+
+ for attempt in range(max_attempts):
+ audio_generator = inference(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ streaming=False,
+ )
+
+ # 获取音频数据
+ for _ in audio_generator:
+ pass
+ _, (sample_rate, audio), message = _
+
+ if audio is None:
+ return None, None, message
+
+ asr_result = batch_asr(
+ zh_model if is_chinese(text) else en_model, [audio], sample_rate
+ )[0]
+ wer = calculate_wer(text, asr_result["text"])
+
+ if wer <= 0.3 and not asr_result["huge_gap"]:
+ return None, (sample_rate, audio), None
+
+ if wer < best_wer:
+ best_wer = wer
+ best_audio = audio
+ best_sample_rate = sample_rate
+
+ if attempt == max_attempts - 1:
+ break
+
+ return None, (best_sample_rate, best_audio), None
+
+
+inference_stream = partial(inference_with_auto_rerank, streaming=True)
+
+n_audios = 4
+
+global_audio_list = []
+global_error_list = []
+
+
+def inference_wrapper(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ batch_infer_num,
+):
+ audios = []
+ errors = []
+
+ for _ in range(batch_infer_num):
+ result = inference_with_auto_rerank(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ )
+
+ _, audio_data, error_message = result
+
+ audios.append(
+ gr.Audio(value=audio_data if audio_data else None, visible=True),
+ )
+ errors.append(
+ gr.HTML(value=error_message if error_message else None, visible=True),
+ )
+
+ for _ in range(batch_infer_num, n_audios):
+ audios.append(
+ gr.Audio(value=None, visible=False),
+ )
+ errors.append(
+ gr.HTML(value=None, visible=False),
+ )
+
+ return None, *audios, *errors
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+ buffer = io.BytesIO()
+
+ with wave.open(buffer, "wb") as wav_file:
+ wav_file.setnchannels(channels)
+ wav_file.setsampwidth(bit_depth // 8)
+ wav_file.setframerate(sample_rate)
+
+ wav_header_bytes = buffer.getvalue()
+ buffer.close()
+ return wav_header_bytes
+
+
+def normalize_text(user_input, use_normalization):
+ if use_normalization:
+ return ChnNormedText(raw_text=user_input).normalize()
+ else:
+ return user_input
+
+
+def build_app():
+ with gr.Blocks(theme=gr.themes.Base()) as app:
+ gr.Markdown(HEADER_MD)
+
+ # Use light theme by default
+ app.load(
+ None,
+ None,
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
+ )
+
+ # Inference
+ with gr.Row():
+ with gr.Column(scale=3):
+ text = gr.Textbox(
+ label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
+ )
+ refined_text = gr.Textbox(
+ label=i18n("Realtime Transform Text"),
+ placeholder=i18n(
+ "Normalization Result Preview (Currently Only Chinese)"
+ ),
+ lines=5,
+ interactive=False,
+ )
+
+ with gr.Row():
+ if_refine_text = gr.Checkbox(
+ label=i18n("Text Normalization"),
+ value=True,
+ scale=0,
+ min_width=150,
+ )
+
+ with gr.Row():
+ with gr.Tab(label=i18n("Advanced Config")):
+ chunk_length = gr.Slider(
+ label=i18n("Iterative Prompt Length, 0 means off"),
+ minimum=0,
+ maximum=500,
+ value=100,
+ step=8,
+ )
+
+ max_new_tokens = gr.Slider(
+ label=i18n("Maximum tokens per batch, 0 means no limit"),
+ minimum=0,
+ maximum=2048,
+ value=1024, # 0 means no limit
+ step=8,
+ )
+
+ top_p = gr.Slider(
+ label="Top-P",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+
+ repetition_penalty = gr.Slider(
+ label=i18n("Repetition Penalty"),
+ minimum=1,
+ maximum=1.5,
+ value=1.2,
+ step=0.01,
+ )
+
+ temperature = gr.Slider(
+ label="Temperature",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+
+ with gr.Tab(label=i18n("Reference Audio")):
+ gr.Markdown(
+ i18n(
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
+ )
+ )
+
+ enable_reference_audio = gr.Checkbox(
+ label=i18n("Enable Reference Audio"),
+ )
+ reference_audio = gr.Audio(
+ label=i18n("Reference Audio"),
+ type="filepath",
+ )
+ reference_text = gr.Textbox(
+ label=i18n("Reference Text"),
+ placeholder=i18n("Reference Text"),
+ lines=1,
+ value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+ )
+ with gr.Tab(label=i18n("Batch Inference")):
+ batch_infer_num = gr.Slider(
+ label="Batch infer nums",
+ minimum=1,
+ maximum=n_audios,
+ step=1,
+ value=1,
+ )
+
+ with gr.Column(scale=3):
+ for _ in range(n_audios):
+ with gr.Row():
+ error = gr.HTML(
+ label=i18n("Error Message"),
+ visible=True if _ == 0 else False,
+ )
+ global_error_list.append(error)
+ with gr.Row():
+ audio = gr.Audio(
+ label=i18n("Generated Audio"),
+ type="numpy",
+ interactive=False,
+ visible=True if _ == 0 else False,
+ )
+ global_audio_list.append(audio)
+
+ with gr.Row():
+ stream_audio = gr.Audio(
+ label=i18n("Streaming Audio"),
+ streaming=True,
+ autoplay=True,
+ interactive=False,
+ show_download_button=True,
+ )
+ with gr.Row():
+ with gr.Column(scale=3):
+ generate = gr.Button(
+ value="\U0001F3A7 " + i18n("Generate"), variant="primary"
+ )
+ generate_stream = gr.Button(
+ value="\U0001F3A7 " + i18n("Streaming Generate"),
+ variant="primary",
+ )
+
+ text.input(
+ fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
+ )
+
+ # # Submit
+ generate.click(
+ inference_wrapper,
+ [
+ refined_text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ batch_infer_num,
+ ],
+ [stream_audio, *global_audio_list, *global_error_list],
+ concurrency_limit=1,
+ )
+
+ generate_stream.click(
+ inference_stream,
+ [
+ refined_text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ ],
+ [stream_audio, global_audio_list[0], global_error_list[0]],
+ concurrency_limit=10,
+ )
+ return app
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.2-sft",
+ )
+ parser.add_argument(
+ "--decoder-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ )
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-gradio-length", type=int, default=0)
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ args.precision = torch.half if args.half else torch.bfloat16
+
+ logger.info("Loading Llama model...")
+ llama_queue = launch_thread_safe_queue(
+ checkpoint_path=args.llama_checkpoint_path,
+ device=args.device,
+ precision=args.precision,
+ compile=args.compile,
+ )
+ logger.info("Llama model loaded, loading VQ-GAN model...")
+
+ decoder_model = load_decoder_model(
+ config_name=args.decoder_config_name,
+ checkpoint_path=args.decoder_checkpoint_path,
+ device=args.device,
+ )
+
+ logger.info("Decoder model loaded, warming up...")
+
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
+ list(
+ inference(
+ text="Hello, world!",
+ enable_reference_audio=False,
+ reference_audio=None,
+ reference_text="",
+ max_new_tokens=0,
+ chunk_length=100,
+ top_p=0.7,
+ repetition_penalty=1.2,
+ temperature=0.7,
+ )
+ )
+
+ logger.info("Warming up done, launching the web UI...")
+
+ app = build_app()
+ app.launch(show_api=True)
diff --git a/tools/webui_colab.py b/tools/webui_colab.py
new file mode 100644
index 0000000000000000000000000000000000000000..ded517d50170b1ed59d311c83323174dd7a835f9
--- /dev/null
+++ b/tools/webui_colab.py
@@ -0,0 +1,558 @@
+import gc
+import html
+import io
+import os
+import queue
+import wave
+from argparse import ArgumentParser
+from functools import partial
+from pathlib import Path
+
+import gradio as gr
+import numpy as np
+import pyrootutils
+import torch
+from loguru import logger
+from transformers import AutoTokenizer
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+
+from fish_speech.i18n import i18n
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from tools.api import decode_vq_tokens, encode_reference
+from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
+from tools.llama.generate import (
+ GenerateRequest,
+ GenerateResponse,
+ WrappedGenerateResponse,
+ launch_thread_safe_queue,
+)
+from tools.vqgan.inference import load_model as load_decoder_model
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+HEADER_MD = f"""# Fish Speech
+
+{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
+
+{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).")}
+
+{i18n("Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.")}
+
+{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
+"""
+
+TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
+SPACE_IMPORTED = False
+
+
+def build_html_error_message(error):
+ return f"""
+
+ {html.escape(str(error))}
+
+ """
+
+
+@torch.inference_mode()
+def inference(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ streaming=False,
+):
+ if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
+ return (
+ None,
+ None,
+ i18n("Text is too long, please keep it under {} characters.").format(
+ args.max_gradio_length
+ ),
+ )
+
+ # Parse reference audio aka prompt
+ prompt_tokens = encode_reference(
+ decoder_model=decoder_model,
+ reference_audio=reference_audio,
+ enable_reference_audio=enable_reference_audio,
+ )
+
+ # LLAMA Inference
+ request = dict(
+ device=decoder_model.device,
+ max_new_tokens=max_new_tokens,
+ text=text,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ compile=args.compile,
+ iterative_prompt=chunk_length > 0,
+ chunk_length=chunk_length,
+ max_length=2048,
+ prompt_tokens=prompt_tokens if enable_reference_audio else None,
+ prompt_text=reference_text if enable_reference_audio else None,
+ )
+
+ response_queue = queue.Queue()
+ llama_queue.put(
+ GenerateRequest(
+ request=request,
+ response_queue=response_queue,
+ )
+ )
+
+ if streaming:
+ yield wav_chunk_header(), None, None
+
+ segments = []
+
+ while True:
+ result: WrappedGenerateResponse = response_queue.get()
+ if result.status == "error":
+ yield None, None, build_html_error_message(result.response)
+ break
+
+ result: GenerateResponse = result.response
+ if result.action == "next":
+ break
+
+ with torch.autocast(
+ device_type=(
+ "cpu"
+ if decoder_model.device.type == "mps"
+ else decoder_model.device.type
+ ),
+ dtype=args.precision,
+ ):
+ fake_audios = decode_vq_tokens(
+ decoder_model=decoder_model,
+ codes=result.codes,
+ )
+
+ fake_audios = fake_audios.float().cpu().numpy()
+ segments.append(fake_audios)
+
+ if streaming:
+ yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+
+ if len(segments) == 0:
+ return (
+ None,
+ None,
+ build_html_error_message(
+ i18n("No audio generated, please check the input text.")
+ ),
+ )
+
+ # No matter streaming or not, we need to return the final audio
+ audio = np.concatenate(segments, axis=0)
+ yield None, (decoder_model.spec_transform.sample_rate, audio), None
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+def inference_with_auto_rerank(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ streaming=False,
+ use_auto_rerank=True,
+):
+ if not use_auto_rerank:
+ return inference(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ streaming,
+ )
+
+ zh_model, en_model = load_model()
+ max_attempts = 2
+ best_wer = float("inf")
+ best_audio = None
+ best_sample_rate = None
+
+ for attempt in range(max_attempts):
+ audio_generator = inference(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ streaming=False,
+ )
+
+ # 获取音频数据
+ for _ in audio_generator:
+ pass
+ _, (sample_rate, audio), message = _
+
+ if audio is None:
+ return None, None, message
+
+ asr_result = batch_asr(
+ zh_model if is_chinese(text) else en_model, [audio], sample_rate
+ )[0]
+ wer = calculate_wer(text, asr_result["text"])
+
+ if wer <= 0.3 and not asr_result["huge_gap"]:
+ return None, (sample_rate, audio), None
+
+ if wer < best_wer:
+ best_wer = wer
+ best_audio = audio
+ best_sample_rate = sample_rate
+
+ if attempt == max_attempts - 1:
+ break
+
+ return None, (best_sample_rate, best_audio), None
+
+
+inference_stream = partial(inference_with_auto_rerank, streaming=True)
+
+n_audios = 4
+
+global_audio_list = []
+global_error_list = []
+
+
+def inference_wrapper(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ batch_infer_num,
+):
+ audios = []
+ errors = []
+
+ for _ in range(batch_infer_num):
+ result = inference_with_auto_rerank(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ )
+
+ _, audio_data, error_message = result
+
+ audios.append(
+ gr.Audio(value=audio_data if audio_data else None, visible=True),
+ )
+ errors.append(
+ gr.HTML(value=error_message if error_message else None, visible=True),
+ )
+
+ for _ in range(batch_infer_num, n_audios):
+ audios.append(
+ gr.Audio(value=None, visible=False),
+ )
+ errors.append(
+ gr.HTML(value=None, visible=False),
+ )
+
+ return None, *audios, *errors
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+ buffer = io.BytesIO()
+
+ with wave.open(buffer, "wb") as wav_file:
+ wav_file.setnchannels(channels)
+ wav_file.setsampwidth(bit_depth // 8)
+ wav_file.setframerate(sample_rate)
+
+ wav_header_bytes = buffer.getvalue()
+ buffer.close()
+ return wav_header_bytes
+
+
+def normalize_text(user_input, use_normalization):
+ if use_normalization:
+ return ChnNormedText(raw_text=user_input).normalize()
+ else:
+ return user_input
+
+
+def build_app():
+ with gr.Blocks(theme=gr.themes.Base()) as app:
+ gr.Markdown(HEADER_MD)
+
+ # Use light theme by default
+ app.load(
+ None,
+ None,
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
+ )
+
+ # Inference
+ with gr.Row():
+ with gr.Column(scale=3):
+ text = gr.Textbox(
+ label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
+ )
+ refined_text = gr.Textbox(
+ label=i18n("Realtime Transform Text"),
+ placeholder=i18n(
+ "Normalization Result Preview (Currently Only Chinese)"
+ ),
+ lines=5,
+ interactive=False,
+ )
+
+ with gr.Row():
+ if_refine_text = gr.Checkbox(
+ label=i18n("Text Normalization"),
+ value=True,
+ scale=0,
+ min_width=150,
+ )
+
+ with gr.Row():
+ with gr.Tab(label=i18n("Advanced Config")):
+ chunk_length = gr.Slider(
+ label=i18n("Iterative Prompt Length, 0 means off"),
+ minimum=0,
+ maximum=500,
+ value=100,
+ step=8,
+ )
+
+ max_new_tokens = gr.Slider(
+ label=i18n("Maximum tokens per batch, 0 means no limit"),
+ minimum=0,
+ maximum=2048,
+ value=1024, # 0 means no limit
+ step=8,
+ )
+
+ top_p = gr.Slider(
+ label="Top-P",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+
+ repetition_penalty = gr.Slider(
+ label=i18n("Repetition Penalty"),
+ minimum=1,
+ maximum=1.5,
+ value=1.2,
+ step=0.01,
+ )
+
+ temperature = gr.Slider(
+ label="Temperature",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+
+ with gr.Tab(label=i18n("Reference Audio")):
+ gr.Markdown(
+ i18n(
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
+ )
+ )
+
+ enable_reference_audio = gr.Checkbox(
+ label=i18n("Enable Reference Audio"),
+ )
+ reference_audio = gr.Audio(
+ label=i18n("Reference Audio"),
+ type="filepath",
+ )
+ reference_text = gr.Textbox(
+ label=i18n("Reference Text"),
+ placeholder=i18n("Reference Text"),
+ lines=1,
+ value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+ )
+ with gr.Tab(label=i18n("Batch Inference")):
+ batch_infer_num = gr.Slider(
+ label="Batch infer nums",
+ minimum=1,
+ maximum=n_audios,
+ step=1,
+ value=1,
+ )
+
+ with gr.Column(scale=3):
+ for _ in range(n_audios):
+ with gr.Row():
+ error = gr.HTML(
+ label=i18n("Error Message"),
+ visible=True if _ == 0 else False,
+ )
+ global_error_list.append(error)
+ with gr.Row():
+ audio = gr.Audio(
+ label=i18n("Generated Audio"),
+ type="numpy",
+ interactive=False,
+ visible=True if _ == 0 else False,
+ )
+ global_audio_list.append(audio)
+
+ with gr.Row():
+ stream_audio = gr.Audio(
+ label=i18n("Streaming Audio"),
+ streaming=True,
+ autoplay=True,
+ interactive=False,
+ show_download_button=True,
+ )
+ with gr.Row():
+ with gr.Column(scale=3):
+ generate = gr.Button(
+ value="\U0001F3A7 " + i18n("Generate"), variant="primary"
+ )
+ generate_stream = gr.Button(
+ value="\U0001F3A7 " + i18n("Streaming Generate"),
+ variant="primary",
+ )
+
+ text.input(
+ fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
+ )
+
+ # # Submit
+ generate.click(
+ inference_wrapper,
+ [
+ refined_text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ batch_infer_num,
+ ],
+ [stream_audio, *global_audio_list, *global_error_list],
+ concurrency_limit=1,
+ )
+
+ generate_stream.click(
+ inference_stream,
+ [
+ refined_text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ ],
+ [stream_audio, global_audio_list[0], global_error_list[0]],
+ concurrency_limit=10,
+ )
+ return app
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.2-sft",
+ )
+ parser.add_argument(
+ "--decoder-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ )
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-gradio-length", type=int, default=0)
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ args.precision = torch.half if args.half else torch.bfloat16
+
+ logger.info("Loading Llama model...")
+ llama_queue = launch_thread_safe_queue(
+ checkpoint_path=args.llama_checkpoint_path,
+ device=args.device,
+ precision=args.precision,
+ compile=args.compile,
+ )
+ logger.info("Llama model loaded, loading VQ-GAN model...")
+
+ decoder_model = load_decoder_model(
+ config_name=args.decoder_config_name,
+ checkpoint_path=args.decoder_checkpoint_path,
+ device=args.device,
+ )
+
+ logger.info("Decoder model loaded, warming up...")
+
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
+ list(
+ inference(
+ text="Hello, world!",
+ enable_reference_audio=False,
+ reference_audio=None,
+ reference_text="",
+ max_new_tokens=0,
+ chunk_length=100,
+ top_p=0.7,
+ repetition_penalty=1.2,
+ temperature=0.7,
+ )
+ )
+
+ logger.info("Warming up done, launching the web UI...")
+
+ app = build_app()
+ app.launch(share=True, show_error=True)
diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..6070e73554e63bffaa98ef831697220324df5493
--- /dev/null
+++ b/tools/whisper_asr.py
@@ -0,0 +1,191 @@
+"""
+Used to transcribe all audio files in one folder into another folder.
+e.g.
+Directory structure:
+--pre_data_root
+----SP_1
+------01.wav
+------02.wav
+------......
+----SP_2
+------01.wav
+------02.wav
+------......
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
+to transcribe the first speaker.
+
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
+to transcribe the second speaker.
+
+Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
+"""
+
+import re
+from pathlib import Path
+
+import click
+import soundfile as sf
+from faster_whisper import WhisperModel
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+
+@click.command()
+@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
+@click.option(
+ "--compute-type",
+ default="float16",
+ help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
+)
+@click.option("--audio-dir", required=True, help="Directory containing audio files")
+@click.option(
+ "--save-dir", required=True, help="Directory to save processed audio files"
+)
+@click.option(
+ "--sample-rate",
+ default=44100,
+ type=int,
+ help="Output sample rate, default to input sample rate",
+)
+@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
+@click.option("--language", default="auto", help="Language of the transcription")
+@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
+def main(
+ model_size,
+ compute_type,
+ audio_dir,
+ save_dir,
+ sample_rate,
+ device,
+ language,
+ initial_prompt,
+):
+ logger.info("Loading / Downloading Faster Whisper model...")
+
+ model = WhisperModel(
+ model_size,
+ device=device,
+ compute_type=compute_type,
+ download_root="faster_whisper",
+ )
+
+ logger.info("Model loaded.")
+
+ save_path = Path(save_dir)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+
+ numbered_suffix_pattern = re.compile(r"-\d{3}$")
+
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
+ file_stem = file_path.stem
+ file_suffix = file_path.suffix
+
+ rel_path = Path(file_path).relative_to(audio_dir)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+
+ # Skip files that already have a .lab file or a -{3-digit number} suffix
+ numbered_suffix = numbered_suffix_pattern.search(file_stem)
+ lab_file = file_path.with_suffix(".lab")
+
+ if numbered_suffix and lab_file.exists():
+ continue
+
+ if not numbered_suffix and lab_file.with_stem(lab_file.stem + "-001").exists():
+ if file_path.exists():
+ file_path.unlink()
+ continue
+
+ audio = AudioSegment.from_file(file_path)
+
+ segments, info = model.transcribe(
+ file_path,
+ beam_size=5,
+ language=None if language == "auto" else language,
+ initial_prompt=initial_prompt,
+ )
+
+ print(
+ "Detected language '%s' with probability %f"
+ % (info.language, info.language_probability)
+ )
+ print("Total len(ms): ", len(audio))
+
+ for segment in segments:
+ id, start, end, text = (
+ segment.id,
+ segment.start,
+ segment.end,
+ segment.text,
+ )
+ print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
+ start_ms = int(start * 1000)
+ end_ms = int(end * 1000) + 200 # add 0.2s avoid truncating
+ segment_audio = audio[start_ms:end_ms]
+ audio_save_path = (
+ save_path / rel_path.parent / f"{file_stem}-{id:03d}{file_suffix}"
+ )
+ segment_audio.export(audio_save_path, format=file_suffix[1:])
+ print(f"Exported {audio_save_path}")
+
+ transcript_save_path = (
+ save_path / rel_path.parent / f"{file_stem}-{id:03d}.lab"
+ )
+ with open(
+ transcript_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(segment.text)
+
+ file_path.unlink()
+
+
+if __name__ == "__main__":
+ main()
+ exit(0)
+
+ audio = AudioSegment.from_wav(
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
+ )
+
+ model_size = "large-v3"
+
+ model = WhisperModel(
+ model_size,
+ device="cuda",
+ compute_type="float16",
+ download_root="faster_whisper",
+ )
+
+ segments, info = model.transcribe(
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
+ beam_size=5,
+ )
+
+ print(
+ "Detected language '%s' with probability %f"
+ % (info.language, info.language_probability)
+ )
+ print("Total len(ms): ", len(audio))
+
+ for i, segment in enumerate(segments):
+ print(
+ "Segment %03d [%.2fs -> %.2fs] %s"
+ % (i, segment.start, segment.end, segment.text)
+ )
+ start_ms = int(segment.start * 1000)
+ end_ms = int(segment.end * 1000)
+ segment_audio = audio[start_ms:end_ms]
+ segment_audio.export(f"segment_{i:03d}.wav", format="wav")
+ print(f"Exported segment_{i:03d}.wav")
+
+ print("All segments have been exported.")