eswardivi committed
Commit 5fe16b1 · 1 Parent(s): cf0d18e

Ported the Space to use the API; the local version still exists in a different branch.

Files changed (50)
  1. app.py +38 -148
  2. bckp.py +176 -0
  3. conver.py +156 -0
  4. logo.png +0 -0
  5. melo/__init__.py +0 -0
  6. melo/api.py +0 -121
  7. melo/attentions.py +0 -459
  8. melo/commons.py +0 -160
  9. melo/download_utils.py +0 -47
  10. melo/mel_processing.py +0 -174
  11. melo/models.py +0 -1038
  12. melo/modules.py +0 -598
  13. melo/split_utils.py +0 -131
  14. melo/text/__init__.py +0 -35
  15. melo/text/chinese.py +0 -199
  16. melo/text/chinese_bert.py +0 -107
  17. melo/text/chinese_mix.py +0 -253
  18. melo/text/cleaner.py +0 -36
  19. melo/text/cleaner_multiling.py +0 -110
  20. melo/text/cmudict.rep +0 -0
  21. melo/text/cmudict_cache.pickle +0 -3
  22. melo/text/english.py +0 -284
  23. melo/text/english_bert.py +0 -39
  24. melo/text/english_utils/__init__.py +0 -0
  25. melo/text/english_utils/abbreviations.py +0 -35
  26. melo/text/english_utils/number_norm.py +0 -97
  27. melo/text/english_utils/time_norm.py +0 -47
  28. melo/text/es_phonemizer/__init__.py +0 -0
  29. melo/text/es_phonemizer/base.py +0 -140
  30. melo/text/es_phonemizer/cleaner.py +0 -109
  31. melo/text/es_phonemizer/es_symbols.json +0 -79
  32. melo/text/es_phonemizer/es_symbols.txt +0 -1
  33. melo/text/es_phonemizer/es_symbols_v2.json +0 -83
  34. melo/text/es_phonemizer/es_to_ipa.py +0 -12
  35. melo/text/es_phonemizer/gruut_wrapper.py +0 -253
  36. melo/text/es_phonemizer/punctuation.py +0 -174
  37. melo/text/es_phonemizer/spanish_symbols.txt +0 -1
  38. melo/text/es_phonemizer/test.ipynb +0 -124
  39. melo/text/fr_phonemizer/__init__.py +0 -0
  40. melo/text/fr_phonemizer/base.py +0 -140
  41. melo/text/fr_phonemizer/cleaner.py +0 -122
  42. melo/text/fr_phonemizer/en_symbols.json +0 -78
  43. melo/text/fr_phonemizer/fr_symbols.json +0 -89
  44. melo/text/fr_phonemizer/fr_to_ipa.py +0 -30
  45. melo/text/fr_phonemizer/french_abbreviations.py +0 -48
  46. melo/text/fr_phonemizer/french_symbols.txt +0 -1
  47. melo/text/fr_phonemizer/gruut_wrapper.py +0 -258
  48. melo/text/fr_phonemizer/punctuation.py +0 -172
  49. melo/text/french.py +0 -94
  50. melo/text/french_bert.py +0 -39
app.py CHANGED
@@ -1,150 +1,30 @@
  import gradio as gr
- import spaces
- import os, torch, io
- import json
- import re
- os.system("python -m unidic download")
- import httpx
- # print("Make sure you've downloaded unidic (python -m unidic download) for this WebUI to work.")
- from melo.api import TTS
- import tempfile
- import wave
- from pydub import AudioSegment
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     TextIteratorStreamer,
-     BitsAndBytesConfig,
- )
- from threading import Thread
-
- from gradio_client import Client
-
- # client = Client("eswardivi/AIO_Chat")
- quantization_config = BitsAndBytesConfig(
-     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
- )
-
- model = AutoModelForCausalLM.from_pretrained(
-     "NousResearch/Hermes-2-Pro-Llama-3-8B", quantization_config=quantization_config
- )
- tok = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B", revision='8ab73a6800796d84448bc936db9bac5ad9f984ae')
- terminators = [
-     tok.eos_token_id,
-     tok.convert_tokens_to_ids("<|eot_id|>")
- ]
-
- def validate_url(url):
-     try:
-         response = httpx.get(url, timeout=60.0)
-         response.raise_for_status()
-         return response.text
-     except httpx.RequestError as e:
-         return f"An error occurred while requesting {url}: {str(e)}"
-     except httpx.HTTPStatusError as e:
-         return f"Error response {e.response.status_code} while requesting {url}"
-     except Exception as e:
-         return f"An unexpected error occurred: {str(e)}"
-
- def fetch_text(url):
-     print("Entered Webpage Extraction")
-     prefix_url = "https://r.jina.ai/"
-     full_url = prefix_url + url
-     print(full_url)
-     print("Exited Webpage Extraction")
-     return validate_url(full_url)
-
- @spaces.GPU(duration=100)
- def synthesize(article_url, progress_audio=gr.Progress()):
-     if not article_url.startswith("http://") and not article_url.startswith("https://"):
-         return "URL must start with 'http://' or 'https://'", None
-
-     text = fetch_text(article_url)
-     if "Error" in text:
-         return text, None
-
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     template = """
-     {
-         "conversation": [
-             {"speaker": "", "text": ""},
-             {"speaker": "", "text": ""}
-         ]
-     }
-     """
-     chat = []
-     chat.append(
-         {
-             "role": "user",
-             "content": text + """\n Convert the provided text into a short, informative podcast conversation between two experts. The tone should be professional and engaging. Please adhere to the following format and return only JSON:
-             {
-                 "conversation": [
-                     {"speaker": "", "text": ""},
-                     {"speaker": "", "text": ""}
-                 ]
-             }
-             """,
-         }
-     )
-
-     messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
-     model_inputs = tok([messages], return_tensors="pt").to(device)
-     streamer = TextIteratorStreamer(
-         tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
-     )
-     generate_kwargs = dict(
-         model_inputs,
-         streamer=streamer,
-         max_new_tokens=1024,
-         do_sample=True,
-         temperature=0.9,
-         eos_token_id=terminators,
-     )
-     print("Entered Generation")
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     partial_text = ""
-     for new_text in streamer:
-         partial_text += new_text
-
-     # print("Calling API")
-     # result = client.predict(
-     #     f"{text} \n Convert the text as Elaborate Conversation between two people as Podcast.\nfollowing this template and return only JSON \n {template}",
-     #     0.9,
-     #     True,
-     #     1024,
-     #     api_name="/chat"
-     # )
-     # print("API Call Completed")
-     pattern = r"\{(?:[^{}]|(?:\{[^{}]*\}))*\}"
-     json_match = re.search(pattern, partial_text)
-     print("Exited Generation")
-     if json_match:
-         conversation = json_match.group()
-     else:
-         conversation = template
-     print(partial_text)
-     print(conversation)
-     speed = 1.0
-     models = {"EN": TTS(language="EN", device=device)}
-     speakers = ["EN-Default", "EN-US"]
-     combined_audio = AudioSegment.empty()
-
-     conversation_dict = json.loads(conversation)
-     for i, turn in enumerate(conversation_dict["conversation"]):
-         bio = io.BytesIO()
-         text = turn["text"]
-         speaker = speakers[i % 2]
-         speaker_id = models["EN"].hps.data.spk2id[speaker]
-         models["EN"].tts_to_file(text, speaker_id, bio, speed=1.0, pbar=progress_audio.tqdm, format="wav")
-         bio.seek(0)
-         audio_segment = AudioSegment.from_file(bio, format="wav")
-         combined_audio += audio_segment
-     final_audio_path = "final.mp3"
-     combined_audio.export(final_audio_path, format="mp3")
-     return conversation, final_audio_path
-
+ import os
+ import asyncio
+ from conver import ConversationConfig, URLToAudioConverter
+ from dotenv import load_dotenv
+
+ load_dotenv()
+ def synthesize_sync(article_url):
+     return asyncio.run(synthesize(article_url))
+
+ async def synthesize(article_url):
+     if not article_url:
+         return "Please provide a valid URL.", None
+
+     try:
+         config = ConversationConfig()
+         converter = URLToAudioConverter(config, llm_api_key=os.environ.get("TOGETHER_API_KEY"))
+
+         output_file, conversation = await converter.url_to_audio(
+             article_url,
+             "en-US-AvaMultilingualNeural",
+             "en-US-AndrewMultilingualNeural"
+         )
+
+         return conversation, output_file
+     except Exception as e:
+         return f"Error: {str(e)}", None
  
  with gr.Blocks(theme='gstaff/sketch') as demo:
      gr.Markdown("# Turn Any Article into a Podcast")
@@ -158,19 +38,29 @@ with gr.Blocks(theme='gstaff/sketch') as demo:
      gr.Markdown("""
      - View the code at [GitHub - NarrateIt](https://github.com/EswarDivi/NarrateIt).
      """)
+
      with gr.Group():
-         text = gr.Textbox(label="Article Link")
+         text = gr.Textbox(label="Article Link", placeholder="Enter the article URL here...")
          btn = gr.Button("Podcastify", variant="primary")
+
      with gr.Row():
-         conv_display = gr.Textbox(label="Conversation", interactive=False)
-         aud = gr.Audio(interactive=False)
-     btn.click(synthesize, inputs=[text], outputs=[conv_display, aud])
+         conv_display = gr.Textbox(label="Conversation", interactive=False, lines=10)
+         aud = gr.Audio(label="Generated Podcast", interactive=False)
+
+     gr.Examples(
+         examples=["https://huggingface.co/blog/gradio-mcp"],
+         inputs=text,
+         fn=synthesize_sync,
+         outputs=[conv_display, aud]
+     )
+
+     btn.click(synthesize_sync, inputs=[text], outputs=[conv_display, aud])
+
      gr.Markdown("""
      Special thanks to:
  
      - [gstaff/sketch](https://huggingface.co/spaces/gstaff/sketch) for the Sketch Theme.
-     - [mrfakename/MeloTTS](https://huggingface.co/spaces/mrfakename/MeloTTS) and [GitHub](https://github.com/myshell-ai/MeloTTS) for MeloTTS.
-     - [Hermes-2-Pro-Llama-3-8B](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B) for Function Calling Support.
      - [Jina AI](https://jina.ai/reader/) for the web page parsing.
      """)
- demo.queue(api_open=True, default_concurrency_limit=10).launch(show_api=True, share=True)
+
+ demo.queue(api_open=True, default_concurrency_limit=15).launch(show_api=True)
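
For reference, the new app.py is a thin Gradio wrapper around conver.py's URLToAudioConverter. A minimal sketch of driving the same pipeline from a plain script, assuming conver.py from this commit is on the import path and TOGETHER_API_KEY is set in the environment (the URL and voice names are the ones the Space itself uses):

    # sketch.py - minimal driver; assumes conver.py from this commit is
    # importable and TOGETHER_API_KEY is set in the environment.
    import asyncio
    import os

    from conver import ConversationConfig, URLToAudioConverter

    async def main() -> None:
        converter = URLToAudioConverter(
            ConversationConfig(),
            llm_api_key=os.environ["TOGETHER_API_KEY"],
        )
        # Same Edge TTS voice names the Space passes above.
        audio_path, conversation = await converter.url_to_audio(
            "https://huggingface.co/blog/gradio-mcp",
            "en-US-AvaMultilingualNeural",
            "en-US-AndrewMultilingualNeural",
        )
        print(conversation)
        print("Combined audio written to", audio_path)

    if __name__ == "__main__":
        asyncio.run(main())

The synthesize_sync wrapper in app.py bridges Gradio's callback into this async pipeline by giving each request its own event loop via asyncio.run.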
bckp.py ADDED
@@ -0,0 +1,176 @@
+ import gradio as gr
+ import spaces
+ import os, torch, io
+ import json
+ import re
+ os.system("python -m unidic download")
+ import httpx
+ # print("Make sure you've downloaded unidic (python -m unidic download) for this WebUI to work.")
+ from melo.api import TTS
+ import tempfile
+ import wave
+ from pydub import AudioSegment
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+     BitsAndBytesConfig,
+ )
+ from threading import Thread
+
+ from gradio_client import Client
+
+ # client = Client("eswardivi/AIO_Chat")
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "NousResearch/Hermes-2-Pro-Llama-3-8B", quantization_config=quantization_config
+ )
+ tok = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B", revision='8ab73a6800796d84448bc936db9bac5ad9f984ae')
+ terminators = [
+     tok.eos_token_id,
+     tok.convert_tokens_to_ids("<|eot_id|>")
+ ]
+
+ def validate_url(url):
+     try:
+         response = httpx.get(url, timeout=60.0)
+         response.raise_for_status()
+         return response.text
+     except httpx.RequestError as e:
+         return f"An error occurred while requesting {url}: {str(e)}"
+     except httpx.HTTPStatusError as e:
+         return f"Error response {e.response.status_code} while requesting {url}"
+     except Exception as e:
+         return f"An unexpected error occurred: {str(e)}"
+
+ def fetch_text(url):
+     print("Entered Webpage Extraction")
+     prefix_url = "https://r.jina.ai/"
+     full_url = prefix_url + url
+     print(full_url)
+     print("Exited Webpage Extraction")
+     return validate_url(full_url)
+
+ @spaces.GPU(duration=100)
+ def synthesize(article_url, progress_audio=gr.Progress()):
+     if not article_url.startswith("http://") and not article_url.startswith("https://"):
+         return "URL must start with 'http://' or 'https://'", None
+
+     text = fetch_text(article_url)
+     if "Error" in text:
+         return text, None
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     template = """
+     {
+         "conversation": [
+             {"speaker": "", "text": ""},
+             {"speaker": "", "text": ""}
+         ]
+     }
+     """
+     chat = []
+     chat.append(
+         {
+             "role": "user",
+             "content": text + """\n Convert the provided text into a short, informative podcast conversation between two experts. The tone should be professional and engaging. Please adhere to the following format and return only JSON:
+             {
+                 "conversation": [
+                     {"speaker": "", "text": ""},
+                     {"speaker": "", "text": ""}
+                 ]
+             }
+             """,
+         }
+     )
+
+     messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+     model_inputs = tok([messages], return_tensors="pt").to(device)
+     streamer = TextIteratorStreamer(
+         tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+     )
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         temperature=0.9,
+         eos_token_id=terminators,
+     )
+     print("Entered Generation")
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     partial_text = ""
+     for new_text in streamer:
+         partial_text += new_text
+
+     # print("Calling API")
+     # result = client.predict(
+     #     f"{text} \n Convert the text as Elaborate Conversation between two people as Podcast.\nfollowing this template and return only JSON \n {template}",
+     #     0.9,
+     #     True,
+     #     1024,
+     #     api_name="/chat"
+     # )
+     # print("API Call Completed")
+     pattern = r"\{(?:[^{}]|(?:\{[^{}]*\}))*\}"
+     json_match = re.search(pattern, partial_text)
+     print("Exited Generation")
+     if json_match:
+         conversation = json_match.group()
+     else:
+         conversation = template
+     print(partial_text)
+     print(conversation)
+     speed = 1.0
+     models = {"EN": TTS(language="EN", device=device)}
+     speakers = ["EN-Default", "EN-US"]
+     combined_audio = AudioSegment.empty()
+
+     conversation_dict = json.loads(conversation)
+     for i, turn in enumerate(conversation_dict["conversation"]):
+         bio = io.BytesIO()
+         text = turn["text"]
+         speaker = speakers[i % 2]
+         speaker_id = models["EN"].hps.data.spk2id[speaker]
+         models["EN"].tts_to_file(text, speaker_id, bio, speed=1.0, pbar=progress_audio.tqdm, format="wav")
+         bio.seek(0)
+         audio_segment = AudioSegment.from_file(bio, format="wav")
+         combined_audio += audio_segment
+     final_audio_path = "final.mp3"
+     combined_audio.export(final_audio_path, format="mp3")
+     return conversation, final_audio_path
+
+
+ with gr.Blocks(theme='gstaff/sketch') as demo:
+     gr.Markdown("# Turn Any Article into a Podcast")
+     gr.Markdown("## Easily convert articles from URLs into listenable audio podcasts.")
+     gr.Markdown("### Instructions")
+     gr.Markdown("""
+     - **Step 1:** Paste the URL of the article you want to convert into the textbox.
+     - **Step 2:** Click on "Podcastify" to generate the podcast.
+     - **Step 3:** Listen to the podcast or view the conversation.
+     """)
+     gr.Markdown("""
+     - View the code at [GitHub - NarrateIt](https://github.com/EswarDivi/NarrateIt).
+     """)
+     with gr.Group():
+         text = gr.Textbox(label="Article Link")
+         btn = gr.Button("Podcastify", variant="primary")
+     with gr.Row():
+         conv_display = gr.Textbox(label="Conversation", interactive=False)
+         aud = gr.Audio(interactive=False)
+     btn.click(synthesize, inputs=[text], outputs=[conv_display, aud])
+     gr.Markdown("""
+     Special thanks to:
+
+     - [gstaff/sketch](https://huggingface.co/spaces/gstaff/sketch) for the Sketch Theme.
+     - [mrfakename/MeloTTS](https://huggingface.co/spaces/mrfakename/MeloTTS) and [GitHub](https://github.com/myshell-ai/MeloTTS) for MeloTTS.
+     - [Hermes-2-Pro-Llama-3-8B](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B) for Function Calling Support.
+     - [Jina AI](https://jina.ai/reader/) for the web page parsing.
+     """)
+ demo.queue(api_open=True, default_concurrency_limit=10).launch(show_api=True, share=True)
conver.py ADDED
@@ -0,0 +1,156 @@
+ from dataclasses import dataclass
+ from typing import List, Tuple, Dict
+ import os
+ import re
+ import httpx
+ import json
+ from openai import OpenAI
+ import edge_tts
+ import tempfile
+ import wave
+ from pydub import AudioSegment
+ import base64
+ from pathlib import Path
+
+
+ @dataclass
+ class ConversationConfig:
+     max_words: int = 3000
+     prefix_url: str = "https://r.jina.ai/"
+     model_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
+
+
+ class URLToAudioConverter:
+     def __init__(self, config: ConversationConfig, llm_api_key: str):
+         self.config = config
+         self.llm_client = OpenAI(api_key=llm_api_key, base_url="https://api.together.xyz/v1")
+         self.llm_out = None
+
+     def fetch_text(self, url: str) -> str:
+         if not url:
+             raise ValueError("URL cannot be empty")
+
+         full_url = f"{self.config.prefix_url}{url}"
+         try:
+             response = httpx.get(full_url, timeout=60.0)
+             response.raise_for_status()
+             return response.text
+         except httpx.HTTPError as e:
+             raise RuntimeError(f"Failed to fetch URL: {e}")
+
+     def extract_conversation(self, text: str) -> Dict:
+         if not text:
+             raise ValueError("Input text cannot be empty")
+
+         try:
+             chat_completion = self.llm_client.chat.completions.create(
+                 messages=[{"role": "user", "content": self._build_prompt(text)}],
+                 model=self.config.model_name,
+             )
+
+             pattern = r"\{(?:[^{}]|(?:\{[^{}]*\}))*\}"
+             json_match = re.search(pattern, chat_completion.choices[0].message.content)
+
+             if not json_match:
+                 raise ValueError("No valid JSON found in response")
+
+             return json.loads(json_match.group())
+         except Exception as e:
+             raise RuntimeError(f"Failed to extract conversation: {e}")
+
+     def _build_prompt(self, text: str) -> str:
+         template = """
+         {
+             "conversation": [
+                 {"speaker": "", "text": ""},
+                 {"speaker": "", "text": ""}
+             ]
+         }
+         """
+         return (
+             f"{text}\nConvert the provided text into a short informative and crisp "
+             f"podcast conversation between two experts. The tone should be "
+             f"professional and engaging. Please adhere to the following "
+             f"format and return the conversation in JSON:\n{template}"
+         )
+
+     async def text_to_speech(self, conversation_json: Dict, voice_1: str, voice_2: str) -> Tuple[List[str], str]:
+         output_dir = Path(self._create_output_directory())
+         filenames = []
+
+         try:
+             for i, turn in enumerate(conversation_json["conversation"]):
+                 filename = output_dir / f"output_{i}.wav"
+                 voice = voice_1 if i % 2 == 0 else voice_2
+
+                 tmp_path, error = await self._generate_audio(turn["text"], voice)
+                 if error:
+                     raise RuntimeError(f"Text-to-speech failed: {error}")
+
+                 os.rename(tmp_path, filename)
+                 filenames.append(str(filename))
+
+             return filenames, str(output_dir)
+         except Exception as e:
+             raise RuntimeError(f"Failed to convert text to speech: {e}")
+
+     async def _generate_audio(self, text: str, voice: str, rate: int = 0, pitch: int = 0) -> Tuple[str, str]:
+         if not text.strip():
+             return None, "Text cannot be empty"
+         if not voice:
+             return None, "Voice cannot be empty"
+
+         voice_short_name = voice.split(" - ")[0]
+         rate_str = f"{rate:+d}%"
+         pitch_str = f"{pitch:+d}Hz"
+         communicate = edge_tts.Communicate(text, voice_short_name, rate=rate_str, pitch=pitch_str)
+
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
+             tmp_path = tmp_file.name
+             await communicate.save(tmp_path)
+
+         return tmp_path, None
+
+     def _create_output_directory(self) -> str:
+         random_bytes = os.urandom(8)
+         folder_name = base64.urlsafe_b64encode(random_bytes).decode("utf-8")
+         os.makedirs(folder_name, exist_ok=True)
+         return folder_name
+
+     def combine_audio_files(self, filenames: List[str], output_file: str) -> None:
+         if not filenames:
+             raise ValueError("No input files provided")
+
+         try:
+             audio_segments = []
+
+             for filename in filenames:
+                 audio_segment = AudioSegment.from_mp3(filename)
+                 audio_segments.append(audio_segment)
+
+             combined = sum(audio_segments)
+
+             combined.export(output_file, format="wav")
+
+             for filename in filenames:
+                 os.remove(filename)
+
+         except Exception as e:
+             raise RuntimeError(f"Failed to combine audio files: {e}")
+
+     async def url_to_audio(self, url: str, voice_1: str, voice_2: str) -> Tuple[str, Dict]:
+         text = self.fetch_text(url)
+
+         words = text.split()
+         if len(words) > self.config.max_words:
+             text = " ".join(words[: self.config.max_words])
+
+         conversation_json = self.extract_conversation(text)
+         self.llm_out = conversation_json
+         audio_files, folder_name = await self.text_to_speech(
+             conversation_json, voice_1, voice_2
+         )
+
+         final_output = os.path.join(folder_name, "combined_output.wav")
+         self.combine_audio_files(audio_files, final_output)
+         # Return both the audio path and the conversation so that app.py's
+         # `output_file, conversation = await converter.url_to_audio(...)`
+         # unpacking works.
+         return final_output, conversation_json
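
The voice names passed into url_to_audio are edge-tts short names. A small sketch for discovering alternatives, assuming only the edge-tts package this module already imports (list_voices is the library's coroutine for querying the voice catalogue):

    # list_voices_sketch.py - enumerate English edge-tts voices that could
    # be swapped into app.py; assumes the edge-tts package used above.
    import asyncio

    import edge_tts

    async def list_english_voices() -> None:
        voices = await edge_tts.list_voices()
        for voice in voices:
            if voice["Locale"].startswith("en-"):
                # ShortName is the identifier Communicate/url_to_audio expect.
                print(voice["ShortName"], voice["Gender"])

    if __name__ == "__main__":
        asyncio.run(list_english_voices())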
logo.png DELETED
Binary file (146 kB)
 
melo/__init__.py DELETED
File without changes
melo/api.py DELETED
@@ -1,121 +0,0 @@
- import os
- import re
- import json
- import torch
- import librosa
- import soundfile
- import torchaudio
- import numpy as np
- import torch.nn as nn
- from tqdm import tqdm
- from . import utils
- from . import commons
- from .models import SynthesizerTrn
- from .split_utils import split_sentence
- from .mel_processing import spectrogram_torch, spectrogram_torch_conv
- from .download_utils import load_or_download_config, load_or_download_model
-
- class TTS(nn.Module):
-     def __init__(self,
-                  language,
-                  device='cuda:0'):
-         super().__init__()
-         if 'cuda' in device:
-             assert torch.cuda.is_available()
-
-         # config_path =
-         hps = load_or_download_config(language)
-
-         num_languages = hps.num_languages
-         num_tones = hps.num_tones
-         symbols = hps.symbols
-
-         model = SynthesizerTrn(
-             len(symbols),
-             hps.data.filter_length // 2 + 1,
-             hps.train.segment_size // hps.data.hop_length,
-             n_speakers=hps.data.n_speakers,
-             num_tones=num_tones,
-             num_languages=num_languages,
-             **hps.model,
-         ).to(device)
-
-         model.eval()
-         self.model = model
-         self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
-         self.hps = hps
-         self.device = device
-
-         # load state_dict
-         checkpoint_dict = load_or_download_model(language, device)
-         self.model.load_state_dict(checkpoint_dict['model'], strict=True)
-
-         language = language.split('_')[0]
-         self.language = 'ZH_MIX_EN' if language == 'ZH' else language  # we support a ZH_MIX_EN model
-
-     @staticmethod
-     def audio_numpy_concat(segment_data_list, sr, speed=1.):
-         audio_segments = []
-         for segment_data in segment_data_list:
-             audio_segments += segment_data.reshape(-1).tolist()
-             audio_segments += [0] * int((sr * 0.05) / speed)
-         audio_segments = np.array(audio_segments).astype(np.float32)
-         return audio_segments
-
-     @staticmethod
-     def split_sentences_into_pieces(text, language):
-         texts = split_sentence(text, language_str=language)
-         # print(" > Text splitted to sentences.")
-         # print('\n'.join(texts))
-         # print(" > ===========================")
-         return texts
-
-     def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None):
-         language = self.language
-         texts = self.split_sentences_into_pieces(text, language)
-         audio_list = []
-         tx = texts
-         if pbar:
-             tx = pbar(texts)
-         else:
-             if position:
-                 tx = tqdm(texts, position=position)
-             else:
-                 tx = tqdm(texts)
-         for t in tx:
-             if language in ['EN', 'ZH_MIX_EN']:
-                 t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
-             device = self.device
-             bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
-             with torch.no_grad():
-                 x_tst = phones.to(device).unsqueeze(0)
-                 tones = tones.to(device).unsqueeze(0)
-                 lang_ids = lang_ids.to(device).unsqueeze(0)
-                 bert = bert.to(device).unsqueeze(0)
-                 ja_bert = ja_bert.to(device).unsqueeze(0)
-                 x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
-                 del phones
-                 speakers = torch.LongTensor([speaker_id]).to(device)
-                 audio = self.model.infer(
-                     x_tst,
-                     x_tst_lengths,
-                     speakers,
-                     tones,
-                     lang_ids,
-                     bert,
-                     ja_bert,
-                     sdp_ratio=sdp_ratio,
-                     noise_scale=noise_scale,
-                     noise_scale_w=noise_scale_w,
-                     length_scale=1. / speed,
-                 )[0][0, 0].data.cpu().float().numpy()
-                 del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
-                 #
-                 audio_list.append(audio)
-             torch.cuda.empty_cache()
-         audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
-
-         if output_path is None:
-             return audio
-         else:
-             soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
melo/attentions.py DELETED
@@ -1,459 +0,0 @@
- import math
- import torch
- from torch import nn
- from torch.nn import functional as F
-
- from . import commons
- import logging
-
- logger = logging.getLogger(__name__)
-
-
- class LayerNorm(nn.Module):
-     def __init__(self, channels, eps=1e-5):
-         super().__init__()
-         self.channels = channels
-         self.eps = eps
-
-         self.gamma = nn.Parameter(torch.ones(channels))
-         self.beta = nn.Parameter(torch.zeros(channels))
-
-     def forward(self, x):
-         x = x.transpose(1, -1)
-         x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-         return x.transpose(1, -1)
-
-
- @torch.jit.script
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-     n_channels_int = n_channels[0]
-     in_act = input_a + input_b
-     t_act = torch.tanh(in_act[:, :n_channels_int, :])
-     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
-     acts = t_act * s_act
-     return acts
-
-
- class Encoder(nn.Module):
-     def __init__(
-         self,
-         hidden_channels,
-         filter_channels,
-         n_heads,
-         n_layers,
-         kernel_size=1,
-         p_dropout=0.0,
-         window_size=4,
-         isflow=True,
-         **kwargs
-     ):
-         super().__init__()
-         self.hidden_channels = hidden_channels
-         self.filter_channels = filter_channels
-         self.n_heads = n_heads
-         self.n_layers = n_layers
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.window_size = window_size
-
-         self.cond_layer_idx = self.n_layers
-         if "gin_channels" in kwargs:
-             self.gin_channels = kwargs["gin_channels"]
-             if self.gin_channels != 0:
-                 self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
-                 self.cond_layer_idx = (
-                     kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
-                 )
-                 assert (
-                     self.cond_layer_idx < self.n_layers
-                 ), "cond_layer_idx should be less than n_layers"
-         self.drop = nn.Dropout(p_dropout)
-         self.attn_layers = nn.ModuleList()
-         self.norm_layers_1 = nn.ModuleList()
-         self.ffn_layers = nn.ModuleList()
-         self.norm_layers_2 = nn.ModuleList()
-
-         for i in range(self.n_layers):
-             self.attn_layers.append(
-                 MultiHeadAttention(
-                     hidden_channels,
-                     hidden_channels,
-                     n_heads,
-                     p_dropout=p_dropout,
-                     window_size=window_size,
-                 )
-             )
-             self.norm_layers_1.append(LayerNorm(hidden_channels))
-             self.ffn_layers.append(
-                 FFN(
-                     hidden_channels,
-                     hidden_channels,
-                     filter_channels,
-                     kernel_size,
-                     p_dropout=p_dropout,
-                 )
-             )
-             self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-     def forward(self, x, x_mask, g=None):
-         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-         x = x * x_mask
-         for i in range(self.n_layers):
-             if i == self.cond_layer_idx and g is not None:
-                 g = self.spk_emb_linear(g.transpose(1, 2))
-                 g = g.transpose(1, 2)
-                 x = x + g
-                 x = x * x_mask
-             y = self.attn_layers[i](x, x, attn_mask)
-             y = self.drop(y)
-             x = self.norm_layers_1[i](x + y)
-
-             y = self.ffn_layers[i](x, x_mask)
-             y = self.drop(y)
-             x = self.norm_layers_2[i](x + y)
-         x = x * x_mask
-         return x
-
-
- class Decoder(nn.Module):
-     def __init__(
-         self,
-         hidden_channels,
-         filter_channels,
-         n_heads,
-         n_layers,
-         kernel_size=1,
-         p_dropout=0.0,
-         proximal_bias=False,
-         proximal_init=True,
-         **kwargs
-     ):
-         super().__init__()
-         self.hidden_channels = hidden_channels
-         self.filter_channels = filter_channels
-         self.n_heads = n_heads
-         self.n_layers = n_layers
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.proximal_bias = proximal_bias
-         self.proximal_init = proximal_init
-
-         self.drop = nn.Dropout(p_dropout)
-         self.self_attn_layers = nn.ModuleList()
-         self.norm_layers_0 = nn.ModuleList()
-         self.encdec_attn_layers = nn.ModuleList()
-         self.norm_layers_1 = nn.ModuleList()
-         self.ffn_layers = nn.ModuleList()
-         self.norm_layers_2 = nn.ModuleList()
-         for i in range(self.n_layers):
-             self.self_attn_layers.append(
-                 MultiHeadAttention(
-                     hidden_channels,
-                     hidden_channels,
-                     n_heads,
-                     p_dropout=p_dropout,
-                     proximal_bias=proximal_bias,
-                     proximal_init=proximal_init,
-                 )
-             )
-             self.norm_layers_0.append(LayerNorm(hidden_channels))
-             self.encdec_attn_layers.append(
-                 MultiHeadAttention(
-                     hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
-                 )
-             )
-             self.norm_layers_1.append(LayerNorm(hidden_channels))
-             self.ffn_layers.append(
-                 FFN(
-                     hidden_channels,
-                     hidden_channels,
-                     filter_channels,
-                     kernel_size,
-                     p_dropout=p_dropout,
-                     causal=True,
-                 )
-             )
-             self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-     def forward(self, x, x_mask, h, h_mask):
-         """
-         x: decoder input
-         h: encoder output
-         """
-         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
-             device=x.device, dtype=x.dtype
-         )
-         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-         x = x * x_mask
-         for i in range(self.n_layers):
-             y = self.self_attn_layers[i](x, x, self_attn_mask)
-             y = self.drop(y)
-             x = self.norm_layers_0[i](x + y)
-
-             y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
-             y = self.drop(y)
-             x = self.norm_layers_1[i](x + y)
-
-             y = self.ffn_layers[i](x, x_mask)
-             y = self.drop(y)
-             x = self.norm_layers_2[i](x + y)
-         x = x * x_mask
-         return x
-
-
- class MultiHeadAttention(nn.Module):
-     def __init__(
-         self,
-         channels,
-         out_channels,
-         n_heads,
-         p_dropout=0.0,
-         window_size=None,
-         heads_share=True,
-         block_length=None,
-         proximal_bias=False,
-         proximal_init=False,
-     ):
-         super().__init__()
-         assert channels % n_heads == 0
-
-         self.channels = channels
-         self.out_channels = out_channels
-         self.n_heads = n_heads
-         self.p_dropout = p_dropout
-         self.window_size = window_size
-         self.heads_share = heads_share
-         self.block_length = block_length
-         self.proximal_bias = proximal_bias
-         self.proximal_init = proximal_init
-         self.attn = None
-
-         self.k_channels = channels // n_heads
-         self.conv_q = nn.Conv1d(channels, channels, 1)
-         self.conv_k = nn.Conv1d(channels, channels, 1)
-         self.conv_v = nn.Conv1d(channels, channels, 1)
-         self.conv_o = nn.Conv1d(channels, out_channels, 1)
-         self.drop = nn.Dropout(p_dropout)
-
-         if window_size is not None:
-             n_heads_rel = 1 if heads_share else n_heads
-             rel_stddev = self.k_channels**-0.5
-             self.emb_rel_k = nn.Parameter(
-                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                 * rel_stddev
-             )
-             self.emb_rel_v = nn.Parameter(
-                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                 * rel_stddev
-             )
-
-         nn.init.xavier_uniform_(self.conv_q.weight)
-         nn.init.xavier_uniform_(self.conv_k.weight)
-         nn.init.xavier_uniform_(self.conv_v.weight)
-         if proximal_init:
-             with torch.no_grad():
-                 self.conv_k.weight.copy_(self.conv_q.weight)
-                 self.conv_k.bias.copy_(self.conv_q.bias)
-
-     def forward(self, x, c, attn_mask=None):
-         q = self.conv_q(x)
-         k = self.conv_k(c)
-         v = self.conv_v(c)
-
-         x, self.attn = self.attention(q, k, v, mask=attn_mask)
-
-         x = self.conv_o(x)
-         return x
-
-     def attention(self, query, key, value, mask=None):
-         # reshape [b, d, t] -> [b, n_h, t, d_k]
-         b, d, t_s, t_t = (*key.size(), query.size(2))
-         query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
-         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-         value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-
-         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
-         if self.window_size is not None:
-             assert (
-                 t_s == t_t
-             ), "Relative attention is only available for self-attention."
-             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-             rel_logits = self._matmul_with_relative_keys(
-                 query / math.sqrt(self.k_channels), key_relative_embeddings
-             )
-             scores_local = self._relative_position_to_absolute_position(rel_logits)
-             scores = scores + scores_local
-         if self.proximal_bias:
-             assert t_s == t_t, "Proximal bias is only available for self-attention."
-             scores = scores + self._attention_bias_proximal(t_s).to(
-                 device=scores.device, dtype=scores.dtype
-             )
-         if mask is not None:
-             scores = scores.masked_fill(mask == 0, -1e4)
-             if self.block_length is not None:
-                 assert (
-                     t_s == t_t
-                 ), "Local attention is only available for self-attention."
-                 block_mask = (
-                     torch.ones_like(scores)
-                     .triu(-self.block_length)
-                     .tril(self.block_length)
-                 )
-                 scores = scores.masked_fill(block_mask == 0, -1e4)
-         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
-         p_attn = self.drop(p_attn)
-         output = torch.matmul(p_attn, value)
-         if self.window_size is not None:
-             relative_weights = self._absolute_position_to_relative_position(p_attn)
-             value_relative_embeddings = self._get_relative_embeddings(
-                 self.emb_rel_v, t_s
-             )
-             output = output + self._matmul_with_relative_values(
-                 relative_weights, value_relative_embeddings
-             )
-         output = (
-             output.transpose(2, 3).contiguous().view(b, d, t_t)
-         )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
-         return output, p_attn
-
-     def _matmul_with_relative_values(self, x, y):
-         """
-         x: [b, h, l, m]
-         y: [h or 1, m, d]
-         ret: [b, h, l, d]
-         """
-         ret = torch.matmul(x, y.unsqueeze(0))
-         return ret
-
-     def _matmul_with_relative_keys(self, x, y):
-         """
-         x: [b, h, l, d]
-         y: [h or 1, m, d]
-         ret: [b, h, l, m]
-         """
-         ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
-         return ret
-
-     def _get_relative_embeddings(self, relative_embeddings, length):
-         2 * self.window_size + 1
-         # Pad first before slice to avoid using cond ops.
-         pad_length = max(length - (self.window_size + 1), 0)
-         slice_start_position = max((self.window_size + 1) - length, 0)
-         slice_end_position = slice_start_position + 2 * length - 1
-         if pad_length > 0:
-             padded_relative_embeddings = F.pad(
-                 relative_embeddings,
-                 commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-             )
-         else:
-             padded_relative_embeddings = relative_embeddings
-         used_relative_embeddings = padded_relative_embeddings[
-             :, slice_start_position:slice_end_position
-         ]
-         return used_relative_embeddings
-
-     def _relative_position_to_absolute_position(self, x):
-         """
-         x: [b, h, l, 2*l-1]
-         ret: [b, h, l, l]
-         """
-         batch, heads, length, _ = x.size()
-         # Concat columns of pad to shift from relative to absolute indexing.
-         x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
-
-         # Concat extra elements so to add up to shape (len+1, 2*len-1).
-         x_flat = x.view([batch, heads, length * 2 * length])
-         x_flat = F.pad(
-             x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
-         )
-
-         # Reshape and slice out the padded elements.
-         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-             :, :, :length, length - 1 :
-         ]
-         return x_final
-
-     def _absolute_position_to_relative_position(self, x):
-         """
-         x: [b, h, l, l]
-         ret: [b, h, l, 2*l-1]
-         """
-         batch, heads, length, _ = x.size()
-         # pad along column
-         x = F.pad(
-             x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
-         )
-         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
-         # add 0's in the beginning that will skew the elements after reshape
-         x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
-         return x_final
-
-     def _attention_bias_proximal(self, length):
-         """Bias for self-attention to encourage attention to close positions.
-         Args:
-             length: an integer scalar.
-         Returns:
-             a Tensor with shape [1, 1, length, length]
-         """
-         r = torch.arange(length, dtype=torch.float32)
-         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
-
-
- class FFN(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         out_channels,
-         filter_channels,
-         kernel_size,
-         p_dropout=0.0,
-         activation=None,
-         causal=False,
-     ):
-         super().__init__()
-         self.in_channels = in_channels
-         self.out_channels = out_channels
-         self.filter_channels = filter_channels
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.activation = activation
-         self.causal = causal
-
-         if causal:
-             self.padding = self._causal_padding
-         else:
-             self.padding = self._same_padding
-
-         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
-         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
-         self.drop = nn.Dropout(p_dropout)
-
-     def forward(self, x, x_mask):
-         x = self.conv_1(self.padding(x * x_mask))
-         if self.activation == "gelu":
-             x = x * torch.sigmoid(1.702 * x)
-         else:
-             x = torch.relu(x)
-         x = self.drop(x)
-         x = self.conv_2(self.padding(x * x_mask))
-         return x * x_mask
-
-     def _causal_padding(self, x):
-         if self.kernel_size == 1:
-             return x
-         pad_l = self.kernel_size - 1
-         pad_r = 0
-         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-         x = F.pad(x, commons.convert_pad_shape(padding))
-         return x
-
-     def _same_padding(self, x):
-         if self.kernel_size == 1:
-             return x
-         pad_l = (self.kernel_size - 1) // 2
-         pad_r = self.kernel_size // 2
-         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-         x = F.pad(x, commons.convert_pad_shape(padding))
-         return x
melo/commons.py DELETED
@@ -1,160 +0,0 @@
- import math
- import torch
- from torch.nn import functional as F
-
-
- def init_weights(m, mean=0.0, std=0.01):
-     classname = m.__class__.__name__
-     if classname.find("Conv") != -1:
-         m.weight.data.normal_(mean, std)
-
-
- def get_padding(kernel_size, dilation=1):
-     return int((kernel_size * dilation - dilation) / 2)
-
-
- def convert_pad_shape(pad_shape):
-     layer = pad_shape[::-1]
-     pad_shape = [item for sublist in layer for item in sublist]
-     return pad_shape
-
-
- def intersperse(lst, item):
-     result = [item] * (len(lst) * 2 + 1)
-     result[1::2] = lst
-     return result
-
-
- def kl_divergence(m_p, logs_p, m_q, logs_q):
-     """KL(P||Q)"""
-     kl = (logs_q - logs_p) - 0.5
-     kl += (
-         0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
-     )
-     return kl
-
-
- def rand_gumbel(shape):
-     """Sample from the Gumbel distribution, protect from overflows."""
-     uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
-     return -torch.log(-torch.log(uniform_samples))
-
-
- def rand_gumbel_like(x):
-     g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
-     return g
-
-
- def slice_segments(x, ids_str, segment_size=4):
-     ret = torch.zeros_like(x[:, :, :segment_size])
-     for i in range(x.size(0)):
-         idx_str = ids_str[i]
-         idx_end = idx_str + segment_size
-         ret[i] = x[i, :, idx_str:idx_end]
-     return ret
-
-
- def rand_slice_segments(x, x_lengths=None, segment_size=4):
-     b, d, t = x.size()
-     if x_lengths is None:
-         x_lengths = t
-     ids_str_max = x_lengths - segment_size + 1
-     ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
-     ret = slice_segments(x, ids_str, segment_size)
-     return ret, ids_str
-
-
- def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
-     position = torch.arange(length, dtype=torch.float)
-     num_timescales = channels // 2
-     log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
-         num_timescales - 1
-     )
-     inv_timescales = min_timescale * torch.exp(
-         torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
-     )
-     scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
-     signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
-     signal = F.pad(signal, [0, 0, 0, channels % 2])
-     signal = signal.view(1, channels, length)
-     return signal
-
-
- def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
-     b, channels, length = x.size()
-     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-     return x + signal.to(dtype=x.dtype, device=x.device)
-
-
- def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
-     b, channels, length = x.size()
-     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-     return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
-
-
- def subsequent_mask(length):
-     mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
-     return mask
-
-
- @torch.jit.script
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-     n_channels_int = n_channels[0]
-     in_act = input_a + input_b
-     t_act = torch.tanh(in_act[:, :n_channels_int, :])
-     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
-     acts = t_act * s_act
-     return acts
-
-
- def convert_pad_shape(pad_shape):
-     layer = pad_shape[::-1]
-     pad_shape = [item for sublist in layer for item in sublist]
-     return pad_shape
-
-
- def shift_1d(x):
-     x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
-     return x
-
-
- def sequence_mask(length, max_length=None):
-     if max_length is None:
-         max_length = length.max()
-     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-     return x.unsqueeze(0) < length.unsqueeze(1)
-
-
- def generate_path(duration, mask):
-     """
-     duration: [b, 1, t_x]
-     mask: [b, 1, t_y, t_x]
-     """
-
-     b, _, t_y, t_x = mask.shape
-     cum_duration = torch.cumsum(duration, -1)
-
-     cum_duration_flat = cum_duration.view(b * t_x)
-     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
-     path = path.view(b, t_x, t_y)
-     path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
-     path = path.unsqueeze(1).transpose(2, 3) * mask
-     return path
-
-
- def clip_grad_value_(parameters, clip_value, norm_type=2):
-     if isinstance(parameters, torch.Tensor):
-         parameters = [parameters]
-     parameters = list(filter(lambda p: p.grad is not None, parameters))
-     norm_type = float(norm_type)
-     if clip_value is not None:
-         clip_value = float(clip_value)
-
-     total_norm = 0
-     for p in parameters:
-         param_norm = p.grad.data.norm(norm_type)
-         total_norm += param_norm.item() ** norm_type
-         if clip_value is not None:
-             p.grad.data.clamp_(min=-clip_value, max=clip_value)
-     total_norm = total_norm ** (1.0 / norm_type)
-     return total_norm
melo/download_utils.py DELETED
@@ -1,47 +0,0 @@
- import torch
- import os
- from . import utils
-
- DOWNLOAD_CKPT_URLS = {
-     'EN': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/EN/checkpoint.pth',
-     'EN_V2': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/EN_V2/checkpoint.pth',
-     'FR': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/FR/checkpoint.pth',
-     'JP': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/JP/checkpoint.pth',
-     'ES': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/ES/checkpoint.pth',
-     'ZH': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/ZH/checkpoint.pth',
-     'KR': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/KR/checkpoint.pth',
- }
-
- DOWNLOAD_CONFIG_URLS = {
-     'EN': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/EN/config.json',
-     'EN_V2': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/EN_V2/config.json',
-     'FR': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/FR/config.json',
-     'JP': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/JP/config.json',
-     'ES': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/ES/config.json',
-     'ZH': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/ZH/config.json',
-     'KR': 'https://myshell-public-repo-host.s3.amazonaws.com/openvoice/basespeakers/KR/config.json',
- }
-
- def load_or_download_config(locale):
-     language = locale.split('-')[0].upper()
-     assert language in DOWNLOAD_CONFIG_URLS
-     config_path = os.path.expanduser(f'~/.local/share/openvoice/basespeakers/{language}/config.json')
-     try:
-         return utils.get_hparams_from_file(config_path)
-     except:
-         # download
-         os.makedirs(os.path.dirname(config_path), exist_ok=True)
-         os.system(f'wget {DOWNLOAD_CONFIG_URLS[language]} -O {config_path}')
-         return utils.get_hparams_from_file(config_path)
-
- def load_or_download_model(locale, device):
-     language = locale.split('-')[0].upper()
-     assert language in DOWNLOAD_CKPT_URLS
-     ckpt_path = os.path.expanduser(f'~/.local/share/openvoice/basespeakers/{language}/checkpoint.pth')
-     try:
-         return torch.load(ckpt_path, map_location=device)
-     except:
-         # download
-         os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
-         os.system(f'wget {DOWNLOAD_CKPT_URLS[language]} -O {ckpt_path}')
-         return torch.load(ckpt_path, map_location=device)
melo/mel_processing.py DELETED
@@ -1,174 +0,0 @@
1
- import torch
2
- import torch.utils.data
3
- import librosa
4
- from librosa.filters import mel as librosa_mel_fn
5
-
6
- MAX_WAV_VALUE = 32768.0
7
-
8
-
9
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
10
- """
11
- PARAMS
12
- ------
13
- C: compression factor
14
- """
15
- return torch.log(torch.clamp(x, min=clip_val) * C)
16
-
17
-
18
- def dynamic_range_decompression_torch(x, C=1):
19
- """
20
- PARAMS
21
- ------
22
- C: compression factor used to compress
23
- """
24
- return torch.exp(x) / C
25
-
26
-
27
- def spectral_normalize_torch(magnitudes):
28
- output = dynamic_range_compression_torch(magnitudes)
29
- return output
30
-
31
-
32
- def spectral_de_normalize_torch(magnitudes):
33
- output = dynamic_range_decompression_torch(magnitudes)
34
- return output
35
-
36
-
37
- mel_basis = {}
38
- hann_window = {}
39
-
40
-
41
- def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
42
- if torch.min(y) < -1.1:
43
- print("min value is ", torch.min(y))
44
- if torch.max(y) > 1.1:
45
- print("max value is ", torch.max(y))
46
-
47
- global hann_window
48
- dtype_device = str(y.dtype) + "_" + str(y.device)
49
- wnsize_dtype_device = str(win_size) + "_" + dtype_device
50
- if wnsize_dtype_device not in hann_window:
51
- hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
52
- dtype=y.dtype, device=y.device
53
- )
54
-
55
- y = torch.nn.functional.pad(
56
- y.unsqueeze(1),
57
- (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
58
- mode="reflect",
59
- )
60
- y = y.squeeze(1)
61
-
62
- spec = torch.stft(
63
- y,
64
- n_fft,
65
- hop_length=hop_size,
66
- win_length=win_size,
67
- window=hann_window[wnsize_dtype_device],
68
- center=center,
69
- pad_mode="reflect",
70
- normalized=False,
71
- onesided=True,
72
- return_complex=False,
73
- )
74
-
75
- spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
76
- return spec
77
-
78
-
79
- def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False):
-     global hann_window
-     dtype_device = str(y.dtype) + '_' + str(y.device)
-     wnsize_dtype_device = str(win_size) + '_' + dtype_device
-     if wnsize_dtype_device not in hann_window:
-         hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
-
-     y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
-
-     # ******************** original ************************#
-     # y = y.squeeze(1)
-     # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
-     #                    center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
-
-     # ******************** ConvSTFT ************************#
-     freq_cutoff = n_fft // 2 + 1
-     fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft)))
-     forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
-     forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float()
-
-     import torch.nn.functional as F
-
-     # if center:
-     #     signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1)
-     assert center is False
-
-     forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size)
-     spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1)
-
-
-     # ******************** Verification ************************#
-     spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
-                        center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
-     assert torch.allclose(spec1, spec2, atol=1e-4)
-
-     spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6)
-     return spec
-
-
- def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
-     global mel_basis
-     dtype_device = str(spec.dtype) + "_" + str(spec.device)
-     fmax_dtype_device = str(fmax) + "_" + dtype_device
-     if fmax_dtype_device not in mel_basis:
-         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
-         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
-             dtype=spec.dtype, device=spec.device
-         )
-     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
-     spec = spectral_normalize_torch(spec)
-     return spec
-
-
- def mel_spectrogram_torch(
-     y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
- ):
-     global mel_basis, hann_window
-     dtype_device = str(y.dtype) + "_" + str(y.device)
-     fmax_dtype_device = str(fmax) + "_" + dtype_device
-     wnsize_dtype_device = str(win_size) + "_" + dtype_device
-     if fmax_dtype_device not in mel_basis:
-         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
-         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
-             dtype=y.dtype, device=y.device
-         )
-     if wnsize_dtype_device not in hann_window:
-         hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
-             dtype=y.dtype, device=y.device
-         )
-
-     y = torch.nn.functional.pad(
-         y.unsqueeze(1),
-         (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
-         mode="reflect",
-     )
-     y = y.squeeze(1)
-
-     spec = torch.stft(
-         y,
-         n_fft,
-         hop_length=hop_size,
-         win_length=win_size,
-         window=hann_window[wnsize_dtype_device],
-         center=center,
-         pad_mode="reflect",
-         normalized=False,
-         onesided=True,
-         return_complex=False,
-     )
-
-     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
-
-     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
-     spec = spectral_normalize_torch(spec)
-
-     return spec
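
For reference, the mel frontend removed above was typically driven as follows. A minimal sketch, assuming VITS-style STFT settings (the hyperparameter values here are illustrative, not read from this Space's config) and the pre-port branch where melo/mel_processing.py still exists:

```python
import torch
from melo.mel_processing import mel_spectrogram_torch  # pre-port branch only

wav = torch.randn(1, 44100)  # [batch, samples]: one second of dummy 44.1 kHz audio
mel = mel_spectrogram_torch(
    wav,
    n_fft=1024, num_mels=80, sampling_rate=44100,
    hop_size=256, win_size=1024, fmin=0.0, fmax=None,
)
print(mel.shape)  # torch.Size([1, 80, 172]) -- [batch, n_mels, frames]
```
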
melo/models.py DELETED
@@ -1,1038 +0,0 @@
- import math
- import torch
- from torch import nn
- from torch.nn import functional as F
-
- from . import commons
- from . import modules
- from . import attentions
-
- from torch.nn import Conv1d, ConvTranspose1d, Conv2d
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-
- from .commons import init_weights, get_padding
-
-
- class DurationDiscriminator(nn.Module):  # vits2
-     def __init__(
-         self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
-     ):
-         super().__init__()
-         self.in_channels = in_channels
-         self.filter_channels = filter_channels
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.gin_channels = gin_channels
-
-         self.drop = nn.Dropout(p_dropout)
-         self.conv_1 = nn.Conv1d(
-             in_channels, filter_channels, kernel_size, padding=kernel_size // 2
-         )
-         self.norm_1 = modules.LayerNorm(filter_channels)
-         self.conv_2 = nn.Conv1d(
-             filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
-         )
-         self.norm_2 = modules.LayerNorm(filter_channels)
-         self.dur_proj = nn.Conv1d(1, filter_channels, 1)
-
-         self.pre_out_conv_1 = nn.Conv1d(
-             2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
-         )
-         self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
-         self.pre_out_conv_2 = nn.Conv1d(
-             filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
-         )
-         self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
-
-         if gin_channels != 0:
-             self.cond = nn.Conv1d(gin_channels, in_channels, 1)
-
-         self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
-
-     def forward_probability(self, x, x_mask, dur, g=None):
-         dur = self.dur_proj(dur)
-         x = torch.cat([x, dur], dim=1)
-         x = self.pre_out_conv_1(x * x_mask)
-         x = torch.relu(x)
-         x = self.pre_out_norm_1(x)
-         x = self.drop(x)
-         x = self.pre_out_conv_2(x * x_mask)
-         x = torch.relu(x)
-         x = self.pre_out_norm_2(x)
-         x = self.drop(x)
-         x = x * x_mask
-         x = x.transpose(1, 2)
-         output_prob = self.output_layer(x)
-         return output_prob
-
-     def forward(self, x, x_mask, dur_r, dur_hat, g=None):
-         x = torch.detach(x)
-         if g is not None:
-             g = torch.detach(g)
-             x = x + self.cond(g)
-         x = self.conv_1(x * x_mask)
-         x = torch.relu(x)
-         x = self.norm_1(x)
-         x = self.drop(x)
-         x = self.conv_2(x * x_mask)
-         x = torch.relu(x)
-         x = self.norm_2(x)
-         x = self.drop(x)
-
-         output_probs = []
-         for dur in [dur_r, dur_hat]:
-             output_prob = self.forward_probability(x, x_mask, dur, g)
-             output_probs.append(output_prob)
-
-         return output_probs
-
-
- class TransformerCouplingBlock(nn.Module):
-     def __init__(
-         self,
-         channels,
-         hidden_channels,
-         filter_channels,
-         n_heads,
-         n_layers,
-         kernel_size,
-         p_dropout,
-         n_flows=4,
-         gin_channels=0,
-         share_parameter=False,
-     ):
-         super().__init__()
-         self.channels = channels
-         self.hidden_channels = hidden_channels
-         self.kernel_size = kernel_size
-         self.n_layers = n_layers
-         self.n_flows = n_flows
-         self.gin_channels = gin_channels
-
-         self.flows = nn.ModuleList()
-
-         self.wn = (
-             attentions.FFT(
-                 hidden_channels,
-                 filter_channels,
-                 n_heads,
-                 n_layers,
-                 kernel_size,
-                 p_dropout,
-                 isflow=True,
-                 gin_channels=self.gin_channels,
-             )
-             if share_parameter
-             else None
-         )
-
-         for i in range(n_flows):
-             self.flows.append(
-                 modules.TransformerCouplingLayer(
-                     channels,
-                     hidden_channels,
-                     kernel_size,
-                     n_layers,
-                     n_heads,
-                     p_dropout,
-                     filter_channels,
-                     mean_only=True,
-                     wn_sharing_parameter=self.wn,
-                     gin_channels=self.gin_channels,
-                 )
-             )
-             self.flows.append(modules.Flip())
-
-     def forward(self, x, x_mask, g=None, reverse=False):
-         if not reverse:
-             for flow in self.flows:
-                 x, _ = flow(x, x_mask, g=g, reverse=reverse)
-         else:
-             for flow in reversed(self.flows):
-                 x = flow(x, x_mask, g=g, reverse=reverse)
-         return x
-
-
- class StochasticDurationPredictor(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         filter_channels,
-         kernel_size,
-         p_dropout,
-         n_flows=4,
-         gin_channels=0,
-     ):
-         super().__init__()
-         filter_channels = in_channels  # it needs to be removed in a future version.
-         self.in_channels = in_channels
-         self.filter_channels = filter_channels
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.n_flows = n_flows
-         self.gin_channels = gin_channels
-
-         self.log_flow = modules.Log()
-         self.flows = nn.ModuleList()
-         self.flows.append(modules.ElementwiseAffine(2))
-         for i in range(n_flows):
-             self.flows.append(
-                 modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
-             )
-             self.flows.append(modules.Flip())
-
-         self.post_pre = nn.Conv1d(1, filter_channels, 1)
-         self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
-         self.post_convs = modules.DDSConv(
-             filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
-         )
-         self.post_flows = nn.ModuleList()
-         self.post_flows.append(modules.ElementwiseAffine(2))
-         for i in range(4):
-             self.post_flows.append(
-                 modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
-             )
-             self.post_flows.append(modules.Flip())
-
-         self.pre = nn.Conv1d(in_channels, filter_channels, 1)
-         self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
-         self.convs = modules.DDSConv(
-             filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
-         )
-         if gin_channels != 0:
-             self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
-
-     def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
-         x = torch.detach(x)
-         x = self.pre(x)
-         if g is not None:
-             g = torch.detach(g)
-             x = x + self.cond(g)
-         x = self.convs(x, x_mask)
-         x = self.proj(x) * x_mask
-
-         if not reverse:
-             flows = self.flows
-             assert w is not None
-
-             logdet_tot_q = 0
-             h_w = self.post_pre(w)
-             h_w = self.post_convs(h_w, x_mask)
-             h_w = self.post_proj(h_w) * x_mask
-             e_q = (
-                 torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
-                 * x_mask
-             )
-             z_q = e_q
-             for flow in self.post_flows:
-                 z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
-                 logdet_tot_q += logdet_q
-             z_u, z1 = torch.split(z_q, [1, 1], 1)
-             u = torch.sigmoid(z_u) * x_mask
-             z0 = (w - u) * x_mask
-             logdet_tot_q += torch.sum(
-                 (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
-             )
-             logq = (
-                 torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
-                 - logdet_tot_q
-             )
-
-             logdet_tot = 0
-             z0, logdet = self.log_flow(z0, x_mask)
-             logdet_tot += logdet
-             z = torch.cat([z0, z1], 1)
-             for flow in flows:
-                 z, logdet = flow(z, x_mask, g=x, reverse=reverse)
-                 logdet_tot = logdet_tot + logdet
-             nll = (
-                 torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
-                 - logdet_tot
-             )
-             return nll + logq  # [b]
-         else:
-             flows = list(reversed(self.flows))
-             flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
-             z = (
-                 torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
-                 * noise_scale
-             )
-             for flow in flows:
-                 z = flow(z, x_mask, g=x, reverse=reverse)
-             z0, z1 = torch.split(z, [1, 1], 1)
-             logw = z0
-             return logw
-
-
- class DurationPredictor(nn.Module):
-     def __init__(
-         self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
-     ):
-         super().__init__()
-
-         self.in_channels = in_channels
-         self.filter_channels = filter_channels
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.gin_channels = gin_channels
-
-         self.drop = nn.Dropout(p_dropout)
-         self.conv_1 = nn.Conv1d(
-             in_channels, filter_channels, kernel_size, padding=kernel_size // 2
-         )
-         self.norm_1 = modules.LayerNorm(filter_channels)
-         self.conv_2 = nn.Conv1d(
-             filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
-         )
-         self.norm_2 = modules.LayerNorm(filter_channels)
-         self.proj = nn.Conv1d(filter_channels, 1, 1)
-
-         if gin_channels != 0:
-             self.cond = nn.Conv1d(gin_channels, in_channels, 1)
-
-     def forward(self, x, x_mask, g=None):
-         x = torch.detach(x)
-         if g is not None:
-             g = torch.detach(g)
-             x = x + self.cond(g)
-         x = self.conv_1(x * x_mask)
-         x = torch.relu(x)
-         x = self.norm_1(x)
-         x = self.drop(x)
-         x = self.conv_2(x * x_mask)
-         x = torch.relu(x)
-         x = self.norm_2(x)
-         x = self.drop(x)
-         x = self.proj(x * x_mask)
-         return x * x_mask
-
-
- class TextEncoder(nn.Module):
-     def __init__(
-         self,
-         n_vocab,
-         out_channels,
-         hidden_channels,
-         filter_channels,
-         n_heads,
-         n_layers,
-         kernel_size,
-         p_dropout,
-         gin_channels=0,
-         num_languages=None,
-         num_tones=None,
-     ):
-         super().__init__()
-         if num_languages is None:
-             from text import num_languages
-         if num_tones is None:
-             from text import num_tones
-         self.n_vocab = n_vocab
-         self.out_channels = out_channels
-         self.hidden_channels = hidden_channels
-         self.filter_channels = filter_channels
-         self.n_heads = n_heads
-         self.n_layers = n_layers
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.gin_channels = gin_channels
-         self.emb = nn.Embedding(n_vocab, hidden_channels)
-         nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
-         self.tone_emb = nn.Embedding(num_tones, hidden_channels)
-         nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
-         self.language_emb = nn.Embedding(num_languages, hidden_channels)
-         nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
-         self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
-         self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)
-
-         self.encoder = attentions.Encoder(
-             hidden_channels,
-             filter_channels,
-             n_heads,
-             n_layers,
-             kernel_size,
-             p_dropout,
-             gin_channels=self.gin_channels,
-         )
-         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-     def forward(self, x, x_lengths, tone, language, bert, ja_bert, g=None):
-         bert_emb = self.bert_proj(bert).transpose(1, 2)
-         ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
-         x = (
-             self.emb(x)
-             + self.tone_emb(tone)
-             + self.language_emb(language)
-             + bert_emb
-             + ja_bert_emb
-         ) * math.sqrt(
-             self.hidden_channels
-         )  # [b, t, h]
-         x = torch.transpose(x, 1, -1)  # [b, h, t]
-         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
-             x.dtype
-         )
-
-         x = self.encoder(x * x_mask, x_mask, g=g)
-         stats = self.proj(x) * x_mask
-
-         m, logs = torch.split(stats, self.out_channels, dim=1)
-         return x, m, logs, x_mask
-
-
- class ResidualCouplingBlock(nn.Module):
-     def __init__(
-         self,
-         channels,
-         hidden_channels,
-         kernel_size,
-         dilation_rate,
-         n_layers,
-         n_flows=4,
-         gin_channels=0,
-     ):
-         super().__init__()
-         self.channels = channels
-         self.hidden_channels = hidden_channels
-         self.kernel_size = kernel_size
-         self.dilation_rate = dilation_rate
-         self.n_layers = n_layers
-         self.n_flows = n_flows
-         self.gin_channels = gin_channels
-
-         self.flows = nn.ModuleList()
-         for i in range(n_flows):
-             self.flows.append(
-                 modules.ResidualCouplingLayer(
-                     channels,
-                     hidden_channels,
-                     kernel_size,
-                     dilation_rate,
-                     n_layers,
-                     gin_channels=gin_channels,
-                     mean_only=True,
-                 )
-             )
-             self.flows.append(modules.Flip())
-
-     def forward(self, x, x_mask, g=None, reverse=False):
-         if not reverse:
-             for flow in self.flows:
-                 x, _ = flow(x, x_mask, g=g, reverse=reverse)
-         else:
-             for flow in reversed(self.flows):
-                 x = flow(x, x_mask, g=g, reverse=reverse)
-         return x
-
-
- class PosteriorEncoder(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         out_channels,
-         hidden_channels,
-         kernel_size,
-         dilation_rate,
-         n_layers,
-         gin_channels=0,
-     ):
-         super().__init__()
-         self.in_channels = in_channels
-         self.out_channels = out_channels
-         self.hidden_channels = hidden_channels
-         self.kernel_size = kernel_size
-         self.dilation_rate = dilation_rate
-         self.n_layers = n_layers
-         self.gin_channels = gin_channels
-
-         self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
-         self.enc = modules.WN(
-             hidden_channels,
-             kernel_size,
-             dilation_rate,
-             n_layers,
-             gin_channels=gin_channels,
-         )
-         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-     def forward(self, x, x_lengths, g=None, tau=1.0):
-         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
-             x.dtype
-         )
-         x = self.pre(x) * x_mask
-         x = self.enc(x, x_mask, g=g)
-         stats = self.proj(x) * x_mask
-         m, logs = torch.split(stats, self.out_channels, dim=1)
-         z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
-         return z, m, logs, x_mask
-
-
- class Generator(torch.nn.Module):
-     def __init__(
-         self,
-         initial_channel,
-         resblock,
-         resblock_kernel_sizes,
-         resblock_dilation_sizes,
-         upsample_rates,
-         upsample_initial_channel,
-         upsample_kernel_sizes,
-         gin_channels=0,
-     ):
-         super(Generator, self).__init__()
-         self.num_kernels = len(resblock_kernel_sizes)
-         self.num_upsamples = len(upsample_rates)
-         self.conv_pre = Conv1d(
-             initial_channel, upsample_initial_channel, 7, 1, padding=3
-         )
-         resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
-
-         self.ups = nn.ModuleList()
-         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-             self.ups.append(
-                 weight_norm(
-                     ConvTranspose1d(
-                         upsample_initial_channel // (2**i),
-                         upsample_initial_channel // (2 ** (i + 1)),
-                         k,
-                         u,
-                         padding=(k - u) // 2,
-                     )
-                 )
-             )
-
-         self.resblocks = nn.ModuleList()
-         for i in range(len(self.ups)):
-             ch = upsample_initial_channel // (2 ** (i + 1))
-             for j, (k, d) in enumerate(
-                 zip(resblock_kernel_sizes, resblock_dilation_sizes)
-             ):
-                 self.resblocks.append(resblock(ch, k, d))
-
-         self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
-         self.ups.apply(init_weights)
-
-         if gin_channels != 0:
-             self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
-
-     def forward(self, x, g=None):
-         x = self.conv_pre(x)
-         if g is not None:
-             x = x + self.cond(g)
-
-         for i in range(self.num_upsamples):
-             x = F.leaky_relu(x, modules.LRELU_SLOPE)
-             x = self.ups[i](x)
-             xs = None
-             for j in range(self.num_kernels):
-                 if xs is None:
-                     xs = self.resblocks[i * self.num_kernels + j](x)
-                 else:
-                     xs += self.resblocks[i * self.num_kernels + j](x)
-             x = xs / self.num_kernels
-         x = F.leaky_relu(x)
-         x = self.conv_post(x)
-         x = torch.tanh(x)
-
-         return x
-
-     def remove_weight_norm(self):
-         print("Removing weight norm...")
-         for layer in self.ups:
-             remove_weight_norm(layer)
-         for layer in self.resblocks:
-             layer.remove_weight_norm()
-
-
- class DiscriminatorP(torch.nn.Module):
-     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-         super(DiscriminatorP, self).__init__()
-         self.period = period
-         self.use_spectral_norm = use_spectral_norm
-         norm_f = weight_norm if use_spectral_norm is False else spectral_norm
-         self.convs = nn.ModuleList(
-             [
-                 norm_f(
-                     Conv2d(
-                         1,
-                         32,
-                         (kernel_size, 1),
-                         (stride, 1),
-                         padding=(get_padding(kernel_size, 1), 0),
-                     )
-                 ),
-                 norm_f(
-                     Conv2d(
-                         32,
-                         128,
-                         (kernel_size, 1),
-                         (stride, 1),
-                         padding=(get_padding(kernel_size, 1), 0),
-                     )
-                 ),
-                 norm_f(
-                     Conv2d(
-                         128,
-                         512,
-                         (kernel_size, 1),
-                         (stride, 1),
-                         padding=(get_padding(kernel_size, 1), 0),
-                     )
-                 ),
-                 norm_f(
-                     Conv2d(
-                         512,
-                         1024,
-                         (kernel_size, 1),
-                         (stride, 1),
-                         padding=(get_padding(kernel_size, 1), 0),
-                     )
-                 ),
-                 norm_f(
-                     Conv2d(
-                         1024,
-                         1024,
-                         (kernel_size, 1),
-                         1,
-                         padding=(get_padding(kernel_size, 1), 0),
-                     )
-                 ),
-             ]
-         )
-         self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-
-     def forward(self, x):
-         fmap = []
-
-         # 1d to 2d
-         b, c, t = x.shape
-         if t % self.period != 0:  # pad first
-             n_pad = self.period - (t % self.period)
-             x = F.pad(x, (0, n_pad), "reflect")
-             t = t + n_pad
-         x = x.view(b, c, t // self.period, self.period)
-
-         for layer in self.convs:
-             x = layer(x)
-             x = F.leaky_relu(x, modules.LRELU_SLOPE)
-             fmap.append(x)
-         x = self.conv_post(x)
-         fmap.append(x)
-         x = torch.flatten(x, 1, -1)
-
-         return x, fmap
-
-
- class DiscriminatorS(torch.nn.Module):
-     def __init__(self, use_spectral_norm=False):
-         super(DiscriminatorS, self).__init__()
-         norm_f = weight_norm if use_spectral_norm is False else spectral_norm
-         self.convs = nn.ModuleList(
-             [
-                 norm_f(Conv1d(1, 16, 15, 1, padding=7)),
-                 norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
-                 norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
-                 norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
-                 norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
-                 norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
-             ]
-         )
-         self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
-
-     def forward(self, x):
-         fmap = []
-
-         for layer in self.convs:
-             x = layer(x)
-             x = F.leaky_relu(x, modules.LRELU_SLOPE)
-             fmap.append(x)
-         x = self.conv_post(x)
-         fmap.append(x)
-         x = torch.flatten(x, 1, -1)
-
-         return x, fmap
-
-
- class MultiPeriodDiscriminator(torch.nn.Module):
-     def __init__(self, use_spectral_norm=False):
-         super(MultiPeriodDiscriminator, self).__init__()
-         periods = [2, 3, 5, 7, 11]
-
-         discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
-         discs = discs + [
-             DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
-         ]
-         self.discriminators = nn.ModuleList(discs)
-
-     def forward(self, y, y_hat):
-         y_d_rs = []
-         y_d_gs = []
-         fmap_rs = []
-         fmap_gs = []
-         for i, d in enumerate(self.discriminators):
-             y_d_r, fmap_r = d(y)
-             y_d_g, fmap_g = d(y_hat)
-             y_d_rs.append(y_d_r)
-             y_d_gs.append(y_d_g)
-             fmap_rs.append(fmap_r)
-             fmap_gs.append(fmap_g)
-
-         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
- class ReferenceEncoder(nn.Module):
-     """
-     inputs --- [N, Ty/r, n_mels*r] mels
-     outputs --- [N, ref_enc_gru_size]
-     """
-
-     def __init__(self, spec_channels, gin_channels=0, layernorm=False):
-         super().__init__()
-         self.spec_channels = spec_channels
-         ref_enc_filters = [32, 32, 64, 64, 128, 128]
-         K = len(ref_enc_filters)
-         filters = [1] + ref_enc_filters
-         convs = [
-             weight_norm(
-                 nn.Conv2d(
-                     in_channels=filters[i],
-                     out_channels=filters[i + 1],
-                     kernel_size=(3, 3),
-                     stride=(2, 2),
-                     padding=(1, 1),
-                 )
-             )
-             for i in range(K)
-         ]
-         self.convs = nn.ModuleList(convs)
-         # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])  # noqa: E501
-
-         out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
-         self.gru = nn.GRU(
-             input_size=ref_enc_filters[-1] * out_channels,
-             hidden_size=256 // 2,
-             batch_first=True,
-         )
-         self.proj = nn.Linear(128, gin_channels)
-         if layernorm:
-             self.layernorm = nn.LayerNorm(self.spec_channels)
-             print('[Ref Enc]: using layer norm')
-         else:
-             self.layernorm = None
-
-     def forward(self, inputs, mask=None):
-         N = inputs.size(0)
-
-         out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
-         if self.layernorm is not None:
-             out = self.layernorm(out)
-
-         for conv in self.convs:
-             out = conv(out)
-             # out = wn(out)
-             out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]
-
-         out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
-         T = out.size(1)
-         N = out.size(0)
-         out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]
-
-         self.gru.flatten_parameters()
-         memory, out = self.gru(out)  # out --- [1, N, 128]
-
-         return self.proj(out.squeeze(0))
-
-     def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
-         for i in range(n_convs):
-             L = (L - kernel_size + 2 * pad) // stride + 1
-         return L
-
-
- class SynthesizerTrn(nn.Module):
-     """
-     Synthesizer for Training
-     """
-
-     def __init__(
-         self,
-         n_vocab,
-         spec_channels,
-         segment_size,
-         inter_channels,
-         hidden_channels,
-         filter_channels,
-         n_heads,
-         n_layers,
-         kernel_size,
-         p_dropout,
-         resblock,
-         resblock_kernel_sizes,
-         resblock_dilation_sizes,
-         upsample_rates,
-         upsample_initial_channel,
-         upsample_kernel_sizes,
-         n_speakers=256,
-         gin_channels=256,
-         use_sdp=True,
-         n_flow_layer=4,
-         n_layers_trans_flow=6,
-         flow_share_parameter=False,
-         use_transformer_flow=True,
-         use_vc=False,
-         num_languages=None,
-         num_tones=None,
-         norm_refenc=False,
-         use_se=False,
-         **kwargs
-     ):
-         super().__init__()
-         self.n_vocab = n_vocab
-         self.spec_channels = spec_channels
-         self.inter_channels = inter_channels
-         self.hidden_channels = hidden_channels
-         self.filter_channels = filter_channels
-         self.n_heads = n_heads
-         self.n_layers = n_layers
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.resblock = resblock
-         self.resblock_kernel_sizes = resblock_kernel_sizes
-         self.resblock_dilation_sizes = resblock_dilation_sizes
-         self.upsample_rates = upsample_rates
-         self.upsample_initial_channel = upsample_initial_channel
-         self.upsample_kernel_sizes = upsample_kernel_sizes
-         self.segment_size = segment_size
-         self.n_speakers = n_speakers
-         self.gin_channels = gin_channels
-         self.n_layers_trans_flow = n_layers_trans_flow
-         self.use_spk_conditioned_encoder = kwargs.get(
-             "use_spk_conditioned_encoder", True
-         )
-         self.use_sdp = use_sdp
-         self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
-         self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
-         self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
-         self.current_mas_noise_scale = self.mas_noise_scale_initial
-         if self.use_spk_conditioned_encoder and gin_channels > 0:
-             self.enc_gin_channels = gin_channels
-         else:
-             self.enc_gin_channels = 0
-         self.enc_p = TextEncoder(
-             n_vocab,
-             inter_channels,
-             hidden_channels,
-             filter_channels,
-             n_heads,
-             n_layers,
-             kernel_size,
-             p_dropout,
-             gin_channels=self.enc_gin_channels,
-             num_languages=num_languages,
-             num_tones=num_tones,
-         )
-         self.dec = Generator(
-             inter_channels,
-             resblock,
-             resblock_kernel_sizes,
-             resblock_dilation_sizes,
-             upsample_rates,
-             upsample_initial_channel,
-             upsample_kernel_sizes,
-             gin_channels=gin_channels,
-         )
-         self.enc_q = PosteriorEncoder(
-             spec_channels,
-             inter_channels,
-             hidden_channels,
-             5,
-             1,
-             16,
-             gin_channels=gin_channels,
-         )
-         if use_transformer_flow:
-             self.flow = TransformerCouplingBlock(
-                 inter_channels,
-                 hidden_channels,
-                 filter_channels,
-                 n_heads,
-                 n_layers_trans_flow,
-                 5,
-                 p_dropout,
-                 n_flow_layer,
-                 gin_channels=gin_channels,
-                 share_parameter=flow_share_parameter,
-             )
-         else:
-             self.flow = ResidualCouplingBlock(
-                 inter_channels,
-                 hidden_channels,
-                 5,
-                 1,
-                 n_flow_layer,
-                 gin_channels=gin_channels,
-             )
-         self.sdp = StochasticDurationPredictor(
-             hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
-         )
-         self.dp = DurationPredictor(
-             hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
-         )
-
-         if n_speakers > 1:
-             if use_se:
-                 emb_dim = 512
-                 self.emb_g = nn.Linear(emb_dim, gin_channels)
-             else:
-                 self.emb_g = nn.Embedding(n_speakers, gin_channels)
-         else:
-             self.ref_enc = ReferenceEncoder(spec_channels, gin_channels, layernorm=norm_refenc)
-         self.use_vc = use_vc
-         self.use_se = use_se
-
-     def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert, ja_bert):
-         if self.n_speakers > 0:
-             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
-         else:
-             g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
-         if self.use_vc:
-             g_p = None
-         else:
-             g_p = g
-         x, m_p, logs_p, x_mask = self.enc_p(
-             x, x_lengths, tone, language, bert, ja_bert, g=g_p
-         )
-         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
-         z_p = self.flow(z, y_mask, g=g)
-
-         with torch.no_grad():
-             # negative cross-entropy
-             s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
-             neg_cent1 = torch.sum(
-                 -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
-             )  # [b, 1, t_s]
-             neg_cent2 = torch.matmul(
-                 -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
-             )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
-             neg_cent3 = torch.matmul(
-                 z_p.transpose(1, 2), (m_p * s_p_sq_r)
-             )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
-             neg_cent4 = torch.sum(
-                 -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
-             )  # [b, 1, t_s]
-             neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
-             if self.use_noise_scaled_mas:
-                 epsilon = (
-                     torch.std(neg_cent)
-                     * torch.randn_like(neg_cent)
-                     * self.current_mas_noise_scale
-                 )
-                 neg_cent = neg_cent + epsilon
-
-             attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
-             # NOTE: monotonic_align is never imported in this file, so this
-             # training-only path would raise a NameError as written.
-             attn = (
-                 monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
-                 .unsqueeze(1)
-                 .detach()
-             )
-
-         w = attn.sum(2)
-
-         l_length_sdp = self.sdp(x, x_mask, w, g=g)
-         l_length_sdp = l_length_sdp / torch.sum(x_mask)
-
-         logw_ = torch.log(w + 1e-6) * x_mask
-         logw = self.dp(x, x_mask, g=g)
-         l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
-             x_mask
-         )  # for averaging
-
-         l_length = l_length_dp + l_length_sdp
-
-         # expand prior
-         m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
-         logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
-
-         z_slice, ids_slice = commons.rand_slice_segments(
-             z, y_lengths, self.segment_size
-         )
-         o = self.dec(z_slice, g=g)
-         return (
-             o,
-             l_length,
-             attn,
-             ids_slice,
-             x_mask,
-             y_mask,
-             (z, z_p, m_p, logs_p, m_q, logs_q),
-             (x, logw, logw_),
-         )
-
-     def infer(
-         self,
-         x,
-         x_lengths,
-         sid,
-         tone,
-         language,
-         bert,
-         ja_bert,
-         noise_scale=0.667,
-         length_scale=1,
-         noise_scale_w=0.8,
-         max_len=None,
-         sdp_ratio=0,
-         y=None,
-         g=None,
-     ):
-         # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
-         # g = self.gst(y)
-         if g is None:
-             if self.n_speakers > 0:
-                 g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
-             else:
-                 g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
-         if self.use_vc:
-             g_p = None
-         else:
-             g_p = g
-         x, m_p, logs_p, x_mask = self.enc_p(
-             x, x_lengths, tone, language, bert, ja_bert, g=g_p
-         )
-         logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
-             sdp_ratio
-         ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
-         w = torch.exp(logw) * x_mask * length_scale
-
-         w_ceil = torch.ceil(w)
-         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
-         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
-             x_mask.dtype
-         )
-         attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
-         attn = commons.generate_path(w_ceil, attn_mask)
-
-         m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
-             1, 2
-         )  # [b, t', t], [b, t, d] -> [b, d, t']
-         logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
-             1, 2
-         )  # [b, t', t], [b, t, d] -> [b, d, t']
-
-         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
-         z = self.flow(z_p, y_mask, g=g, reverse=True)
-         o = self.dec((z * y_mask)[:, :, :max_len], g=g)
-         # print('max/min of o:', o.max(), o.min())
-         return o, attn, y_mask, (z, z_p, m_p, logs_p)
-
-     def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
-         if self.use_se:
-             sid_src = self.emb_g(sid_src).unsqueeze(-1)
-             sid_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
-
-         g_src = sid_src
-         g_tgt = sid_tgt
-         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau)
-         z_p = self.flow(z, y_mask, g=g_src)
-         z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
-         o_hat = self.dec(z_hat * y_mask, g=g_tgt)
-         return o_hat, y_mask, (z, z_p, z_hat)
melo/modules.py DELETED
@@ -1,598 +0,0 @@
- import math
- import torch
- from torch import nn
- from torch.nn import functional as F
-
- from torch.nn import Conv1d
- from torch.nn.utils import weight_norm, remove_weight_norm
-
- from . import commons
- from .commons import init_weights, get_padding
- from .transforms import piecewise_rational_quadratic_transform
- from .attentions import Encoder
-
- LRELU_SLOPE = 0.1
-
-
- class LayerNorm(nn.Module):
-     def __init__(self, channels, eps=1e-5):
-         super().__init__()
-         self.channels = channels
-         self.eps = eps
-
-         self.gamma = nn.Parameter(torch.ones(channels))
-         self.beta = nn.Parameter(torch.zeros(channels))
-
-     def forward(self, x):
-         x = x.transpose(1, -1)
-         x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-         return x.transpose(1, -1)
-
-
- class ConvReluNorm(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         hidden_channels,
-         out_channels,
-         kernel_size,
-         n_layers,
-         p_dropout,
-     ):
-         super().__init__()
-         self.in_channels = in_channels
-         self.hidden_channels = hidden_channels
-         self.out_channels = out_channels
-         self.kernel_size = kernel_size
-         self.n_layers = n_layers
-         self.p_dropout = p_dropout
-         assert n_layers > 1, "Number of layers should be larger than 1."
-
-         self.conv_layers = nn.ModuleList()
-         self.norm_layers = nn.ModuleList()
-         self.conv_layers.append(
-             nn.Conv1d(
-                 in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
-             )
-         )
-         self.norm_layers.append(LayerNorm(hidden_channels))
-         self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
-         for _ in range(n_layers - 1):
-             self.conv_layers.append(
-                 nn.Conv1d(
-                     hidden_channels,
-                     hidden_channels,
-                     kernel_size,
-                     padding=kernel_size // 2,
-                 )
-             )
-             self.norm_layers.append(LayerNorm(hidden_channels))
-         self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
-         self.proj.weight.data.zero_()
-         self.proj.bias.data.zero_()
-
-     def forward(self, x, x_mask):
-         x_org = x
-         for i in range(self.n_layers):
-             x = self.conv_layers[i](x * x_mask)
-             x = self.norm_layers[i](x)
-             x = self.relu_drop(x)
-         x = x_org + self.proj(x)
-         return x * x_mask
-
-
- class DDSConv(nn.Module):
-     """
-     Dilated and Depth-Separable Convolution
-     """
-
-     def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
-         super().__init__()
-         self.channels = channels
-         self.kernel_size = kernel_size
-         self.n_layers = n_layers
-         self.p_dropout = p_dropout
-
-         self.drop = nn.Dropout(p_dropout)
-         self.convs_sep = nn.ModuleList()
-         self.convs_1x1 = nn.ModuleList()
-         self.norms_1 = nn.ModuleList()
-         self.norms_2 = nn.ModuleList()
-         for i in range(n_layers):
-             dilation = kernel_size**i
-             padding = (kernel_size * dilation - dilation) // 2
-             self.convs_sep.append(
-                 nn.Conv1d(
-                     channels,
-                     channels,
-                     kernel_size,
-                     groups=channels,
-                     dilation=dilation,
-                     padding=padding,
-                 )
-             )
-             self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
-             self.norms_1.append(LayerNorm(channels))
-             self.norms_2.append(LayerNorm(channels))
-
-     def forward(self, x, x_mask, g=None):
-         if g is not None:
-             x = x + g
-         for i in range(self.n_layers):
-             y = self.convs_sep[i](x * x_mask)
-             y = self.norms_1[i](y)
-             y = F.gelu(y)
-             y = self.convs_1x1[i](y)
-             y = self.norms_2[i](y)
-             y = F.gelu(y)
-             y = self.drop(y)
-             x = x + y
-         return x * x_mask
-
-
- class WN(torch.nn.Module):
-     def __init__(
-         self,
-         hidden_channels,
-         kernel_size,
-         dilation_rate,
-         n_layers,
-         gin_channels=0,
-         p_dropout=0,
-     ):
-         super(WN, self).__init__()
-         assert kernel_size % 2 == 1
-         self.hidden_channels = hidden_channels
-         self.kernel_size = (kernel_size,)
-         self.dilation_rate = dilation_rate
-         self.n_layers = n_layers
-         self.gin_channels = gin_channels
-         self.p_dropout = p_dropout
-
-         self.in_layers = torch.nn.ModuleList()
-         self.res_skip_layers = torch.nn.ModuleList()
-         self.drop = nn.Dropout(p_dropout)
-
-         if gin_channels != 0:
-             cond_layer = torch.nn.Conv1d(
-                 gin_channels, 2 * hidden_channels * n_layers, 1
-             )
-             self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
-
-         for i in range(n_layers):
-             dilation = dilation_rate**i
-             padding = int((kernel_size * dilation - dilation) / 2)
-             in_layer = torch.nn.Conv1d(
-                 hidden_channels,
-                 2 * hidden_channels,
-                 kernel_size,
-                 dilation=dilation,
-                 padding=padding,
-             )
-             in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
-             self.in_layers.append(in_layer)
-
-             # last one is not necessary
-             if i < n_layers - 1:
-                 res_skip_channels = 2 * hidden_channels
-             else:
-                 res_skip_channels = hidden_channels
-
-             res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-             res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
-             self.res_skip_layers.append(res_skip_layer)
-
-     def forward(self, x, x_mask, g=None, **kwargs):
-         output = torch.zeros_like(x)
-         n_channels_tensor = torch.IntTensor([self.hidden_channels])
-
-         if g is not None:
-             g = self.cond_layer(g)
-
-         for i in range(self.n_layers):
-             x_in = self.in_layers[i](x)
-             if g is not None:
-                 cond_offset = i * 2 * self.hidden_channels
-                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-             else:
-                 g_l = torch.zeros_like(x_in)
-
-             acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
-             acts = self.drop(acts)
-
-             res_skip_acts = self.res_skip_layers[i](acts)
-             if i < self.n_layers - 1:
-                 res_acts = res_skip_acts[:, : self.hidden_channels, :]
-                 x = (x + res_acts) * x_mask
-                 output = output + res_skip_acts[:, self.hidden_channels :, :]
-             else:
-                 output = output + res_skip_acts
-         return output * x_mask
-
-     def remove_weight_norm(self):
-         if self.gin_channels != 0:
-             torch.nn.utils.remove_weight_norm(self.cond_layer)
-         for l in self.in_layers:
-             torch.nn.utils.remove_weight_norm(l)
-         for l in self.res_skip_layers:
-             torch.nn.utils.remove_weight_norm(l)
-
-
- class ResBlock1(torch.nn.Module):
-     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-         super(ResBlock1, self).__init__()
-         self.convs1 = nn.ModuleList(
-             [
-                 weight_norm(
-                     Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=dilation[0],
-                         padding=get_padding(kernel_size, dilation[0]),
-                     )
-                 ),
-                 weight_norm(
-                     Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=dilation[1],
-                         padding=get_padding(kernel_size, dilation[1]),
-                     )
-                 ),
-                 weight_norm(
-                     Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=dilation[2],
-                         padding=get_padding(kernel_size, dilation[2]),
-                     )
-                 ),
-             ]
-         )
-         self.convs1.apply(init_weights)
-
-         self.convs2 = nn.ModuleList(
-             [
-                 weight_norm(
-                     Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=1,
-                         padding=get_padding(kernel_size, 1),
-                     )
-                 ),
-                 weight_norm(
-                     Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=1,
-                         padding=get_padding(kernel_size, 1),
-                     )
-                 ),
-                 weight_norm(
-                     Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=1,
-                         padding=get_padding(kernel_size, 1),
-                     )
-                 ),
-             ]
-         )
-         self.convs2.apply(init_weights)
-
-     def forward(self, x, x_mask=None):
-         for c1, c2 in zip(self.convs1, self.convs2):
-             xt = F.leaky_relu(x, LRELU_SLOPE)
-             if x_mask is not None:
-                 xt = xt * x_mask
-             xt = c1(xt)
-             xt = F.leaky_relu(xt, LRELU_SLOPE)
-             if x_mask is not None:
-                 xt = xt * x_mask
-             xt = c2(xt)
-             x = xt + x
-         if x_mask is not None:
-             x = x * x_mask
-         return x
-
-     def remove_weight_norm(self):
-         for l in self.convs1:
-             remove_weight_norm(l)
-         for l in self.convs2:
-             remove_weight_norm(l)
-
-
- class ResBlock2(torch.nn.Module):
-     def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-         super(ResBlock2, self).__init__()
-         self.convs = nn.ModuleList(
-             [
-                 weight_norm(
-                     Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=dilation[0],
-                         padding=get_padding(kernel_size, dilation[0]),
-                     )
-                 ),
-                 weight_norm(
-                     Conv1d(
-                         channels,
-                         channels,
-                         kernel_size,
-                         1,
-                         dilation=dilation[1],
-                         padding=get_padding(kernel_size, dilation[1]),
-                     )
-                 ),
-             ]
-         )
-         self.convs.apply(init_weights)
-
-     def forward(self, x, x_mask=None):
-         for c in self.convs:
-             xt = F.leaky_relu(x, LRELU_SLOPE)
-             if x_mask is not None:
-                 xt = xt * x_mask
-             xt = c(xt)
-             x = xt + x
-         if x_mask is not None:
-             x = x * x_mask
-         return x
-
-     def remove_weight_norm(self):
-         for l in self.convs:
-             remove_weight_norm(l)
-
-
- class Log(nn.Module):
-     def forward(self, x, x_mask, reverse=False, **kwargs):
-         if not reverse:
-             y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
-             logdet = torch.sum(-y, [1, 2])
-             return y, logdet
-         else:
-             x = torch.exp(x) * x_mask
-             return x
-
-
- class Flip(nn.Module):
-     def forward(self, x, *args, reverse=False, **kwargs):
-         x = torch.flip(x, [1])
-         if not reverse:
-             logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
-             return x, logdet
-         else:
-             return x
-
-
- class ElementwiseAffine(nn.Module):
-     def __init__(self, channels):
-         super().__init__()
-         self.channels = channels
-         self.m = nn.Parameter(torch.zeros(channels, 1))
-         self.logs = nn.Parameter(torch.zeros(channels, 1))
-
-     def forward(self, x, x_mask, reverse=False, **kwargs):
-         if not reverse:
-             y = self.m + torch.exp(self.logs) * x
-             y = y * x_mask
-             logdet = torch.sum(self.logs * x_mask, [1, 2])
-             return y, logdet
-         else:
-             x = (x - self.m) * torch.exp(-self.logs) * x_mask
-             return x
-
-
- class ResidualCouplingLayer(nn.Module):
-     def __init__(
-         self,
-         channels,
-         hidden_channels,
-         kernel_size,
-         dilation_rate,
-         n_layers,
-         p_dropout=0,
-         gin_channels=0,
-         mean_only=False,
-     ):
-         assert channels % 2 == 0, "channels should be divisible by 2"
-         super().__init__()
-         self.channels = channels
-         self.hidden_channels = hidden_channels
-         self.kernel_size = kernel_size
-         self.dilation_rate = dilation_rate
-         self.n_layers = n_layers
-         self.half_channels = channels // 2
-         self.mean_only = mean_only
-
-         self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-         self.enc = WN(
-             hidden_channels,
-             kernel_size,
-             dilation_rate,
-             n_layers,
-             p_dropout=p_dropout,
-             gin_channels=gin_channels,
-         )
-         self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-         self.post.weight.data.zero_()
-         self.post.bias.data.zero_()
-
-     def forward(self, x, x_mask, g=None, reverse=False):
-         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-         h = self.pre(x0) * x_mask
-         h = self.enc(h, x_mask, g=g)
-         stats = self.post(h) * x_mask
-         if not self.mean_only:
-             m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-         else:
-             m = stats
-             logs = torch.zeros_like(m)
-
-         if not reverse:
-             x1 = m + x1 * torch.exp(logs) * x_mask
-             x = torch.cat([x0, x1], 1)
-             logdet = torch.sum(logs, [1, 2])
-             return x, logdet
-         else:
-             x1 = (x1 - m) * torch.exp(-logs) * x_mask
-             x = torch.cat([x0, x1], 1)
-             return x
-
-
- class ConvFlow(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         filter_channels,
-         kernel_size,
-         n_layers,
-         num_bins=10,
-         tail_bound=5.0,
-     ):
-         super().__init__()
-         self.in_channels = in_channels
-         self.filter_channels = filter_channels
-         self.kernel_size = kernel_size
-         self.n_layers = n_layers
-         self.num_bins = num_bins
-         self.tail_bound = tail_bound
-         self.half_channels = in_channels // 2
-
-         self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
-         self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
-         self.proj = nn.Conv1d(
-             filter_channels, self.half_channels * (num_bins * 3 - 1), 1
-         )
-         self.proj.weight.data.zero_()
-         self.proj.bias.data.zero_()
-
-     def forward(self, x, x_mask, g=None, reverse=False):
-         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-         h = self.pre(x0)
-         h = self.convs(h, x_mask, g=g)
-         h = self.proj(h) * x_mask
-
-         b, c, t = x0.shape
-         h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]
-
-         unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
-         unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
-             self.filter_channels
-         )
-         unnormalized_derivatives = h[..., 2 * self.num_bins :]
-
-         x1, logabsdet = piecewise_rational_quadratic_transform(
-             x1,
-             unnormalized_widths,
-             unnormalized_heights,
-             unnormalized_derivatives,
-             inverse=reverse,
-             tails="linear",
-             tail_bound=self.tail_bound,
-         )
-
-         x = torch.cat([x0, x1], 1) * x_mask
-         logdet = torch.sum(logabsdet * x_mask, [1, 2])
-         if not reverse:
-             return x, logdet
-         else:
-             return x
-
-
- class TransformerCouplingLayer(nn.Module):
-     def __init__(
-         self,
-         channels,
-         hidden_channels,
-         kernel_size,
-         n_layers,
-         n_heads,
-         p_dropout=0,
-         filter_channels=0,
-         mean_only=False,
-         wn_sharing_parameter=None,
-         gin_channels=0,
-     ):
-         assert n_layers == 3, n_layers
-         assert channels % 2 == 0, "channels should be divisible by 2"
-         super().__init__()
-         self.channels = channels
-         self.hidden_channels = hidden_channels
-         self.kernel_size = kernel_size
-         self.n_layers = n_layers
-         self.half_channels = channels // 2
-         self.mean_only = mean_only
-
-         self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-         self.enc = (
-             Encoder(
-                 hidden_channels,
-                 filter_channels,
-                 n_heads,
-                 n_layers,
-                 kernel_size,
-                 p_dropout,
-                 isflow=True,
-                 gin_channels=gin_channels,
-             )
-             if wn_sharing_parameter is None
-             else wn_sharing_parameter
-         )
-         self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-         self.post.weight.data.zero_()
-         self.post.bias.data.zero_()
-
-     def forward(self, x, x_mask, g=None, reverse=False):
-         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-         h = self.pre(x0) * x_mask
-         h = self.enc(h, x_mask, g=g)
-         stats = self.post(h) * x_mask
-         if not self.mean_only:
-             m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-         else:
-             m = stats
-             logs = torch.zeros_like(m)
-
-         if not reverse:
-             x1 = m + x1 * torch.exp(logs) * x_mask
-             x = torch.cat([x0, x1], 1)
-             logdet = torch.sum(logs, [1, 2])
-             return x, logdet
-         else:
-             x1 = (x1 - m) * torch.exp(-logs) * x_mask
-             x = torch.cat([x0, x1], 1)
-             return x
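
Every coupling layer above is invertible by construction: the reverse branch algebraically undoes the forward branch, which is what lets the same flow run forward during training and backward during inference. A self-contained numeric check of that identity, with m and logs standing in for the `stats` a coupling layer would predict:

```python
import torch

x1 = torch.randn(2, 4, 10)                       # half of the channel-split input
m, logs = torch.randn_like(x1), 0.1 * torch.randn_like(x1)
x_mask = torch.ones(2, 1, 10)                    # no padding in this toy batch

y1 = m + x1 * torch.exp(logs) * x_mask           # forward branch of the coupling
x1_rec = (y1 - m) * torch.exp(-logs) * x_mask    # reverse branch undoes it
assert torch.allclose(x1, x1_rec, atol=1e-5)
```
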
melo/split_utils.py DELETED
@@ -1,131 +0,0 @@
- import re
- import os
- import glob
- import numpy as np
- import soundfile as sf
- import torchaudio
- from txtsplit import txtsplit
-
- def split_sentence(text, min_len=10, language_str='EN'):
-     if language_str in ['EN', 'FR', 'ES', 'SP', 'DE', 'RU']:
-         sentences = split_sentences_latin(text, min_len=min_len)
-     else:
-         sentences = split_sentences_zh(text, min_len=min_len)
-     return sentences
-
- def split_sentences_latin(text, min_len=10):
-     text = re.sub('[。!?;]', '.', text)
-     text = re.sub('[,]', ',', text)
-     text = re.sub('[“”]', '"', text)
-     text = re.sub('[‘’]', "'", text)
-     text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
-     return [item.strip() for item in txtsplit(text, 512, 512) if item.strip()]
-     # Replace newlines, spaces, and tabs in the text with a single space
-     # text = re.sub('[\n\t ]+', ' ', text)
-     # # Insert a split marker after each punctuation mark
-     # text = re.sub('([,.!?;])', r'\1 $#!', text)
-     # # Split into sentences and strip surrounding whitespace
-     # sentences = [s.strip() for s in text.split('$#!')]
-     # if len(sentences[-1]) == 0: del sentences[-1]
-
-     # new_sentences = []
-     # new_sent = []
-     # count_len = 0
-     # for ind, sent in enumerate(sentences):
-     #     # print(sent)
-     #     new_sent.append(sent)
-     #     count_len += len(sent.split(" "))
-     #     if count_len > min_len or ind == len(sentences) - 1:
-     #         count_len = 0
-     #         new_sentences.append(' '.join(new_sent))
-     #         new_sent = []
-     # return merge_short_sentences_en(new_sentences)
-
- def split_sentences_zh(text, min_len=10):
-     text = re.sub('[。!?;]', '.', text)
-     text = re.sub('[,]', ',', text)
-     # Replace newlines, spaces, and tabs in the text with a single space
-     text = re.sub('[\n\t ]+', ' ', text)
-     # Insert a split marker after each punctuation mark
-     text = re.sub('([,.!?;])', r'\1 $#!', text)
-     # Split into sentences and strip surrounding whitespace
-     # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
-     sentences = [s.strip() for s in text.split('$#!')]
-     if len(sentences[-1]) == 0: del sentences[-1]
-
-     new_sentences = []
-     new_sent = []
-     count_len = 0
-     for ind, sent in enumerate(sentences):
-         new_sent.append(sent)
-         count_len += len(sent)
-         if count_len > min_len or ind == len(sentences) - 1:
-             count_len = 0
-             new_sentences.append(' '.join(new_sent))
-             new_sent = []
-     return merge_short_sentences_zh(new_sentences)
-
- def merge_short_sentences_en(sens):
-     """Avoid short sentences by merging them with the following sentence.
-
-     Args:
-         List[str]: list of input sentences.
-
-     Returns:
-         List[str]: list of output sentences.
-     """
-     sens_out = []
-     for s in sens:
-         # If the previous sentence is too short, merge it with
-         # the current sentence.
-         if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
-             sens_out[-1] = sens_out[-1] + " " + s
-         else:
-             sens_out.append(s)
-     try:
-         if len(sens_out[-1].split(" ")) <= 2:
-             sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
-             sens_out.pop(-1)
-     except:
-         pass
-     return sens_out
-
- def merge_short_sentences_zh(sens):
-     # return sens
-     """Avoid short sentences by merging them with the following sentence.
-
-     Args:
-         List[str]: list of input sentences.
-
-     Returns:
-         List[str]: list of output sentences.
-     """
-     sens_out = []
-     for s in sens:
-         # If the previous sentence is too short, merge it with
-         # the current sentence.
-         if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
-             sens_out[-1] = sens_out[-1] + " " + s
-         else:
-             sens_out.append(s)
-     try:
-         if len(sens_out[-1]) <= 2:
-             sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
-             sens_out.pop(-1)
-     except:
-         pass
-     return sens_out
-
-
- if __name__ == '__main__':
-     zh_text = "好的,我来给你讲一个故事吧。从前有一个小姑娘,她叫做小红。小红非常喜欢在森林里玩耍,她经常会和她的小伙伴们一起去探险。有一天,小红和她的小伙伴们走到了森林深处,突然遇到了一只凶猛的野兽。小红的小伙伴们都吓得不敢动弹,但是小红并没有被吓倒,她勇敢地走向野兽,用她的智慧和勇气成功地制服了野兽,保护了她的小伙伴们。从那以后,小红变得更加勇敢和自信,成为了她小伙伴们心中的英雄。"
-     en_text = "I didn’t know what to do. I said please kill her because it would be better than being kidnapped,” Ben, whose surname CNN is not using for security concerns, said on Wednesday. “It’s a nightmare. I said ‘please kill her, don’t take her there.’"
-     sp_text = "¡Claro! ¿En qué tema te gustaría que te hable en español? Puedo proporcionarte información o conversar contigo sobre una amplia variedad de temas, desde cultura y comida hasta viajes y tecnología. ¿Tienes alguna preferencia en particular?"
-     fr_text = "Bien sûr ! En quelle matière voudriez-vous que je vous parle en français ? Je peux vous fournir des informations ou discuter avec vous sur une grande variété de sujets, que ce soit la culture, la nourriture, les voyages ou la technologie. Avez-vous une préférence particulière ?"
-     de_text = 'Es war das Wichtigste was wir sichern wollten da es keine Möglichkeit gab eine 20 Megatonnen- H- Bombe ab zu werfen von einem 30, C124.'
-     ru_text = 'Но он был во многом, как-бы, всё равно что сын плантатора, так как являлся сыном человека, у которого было в собственности много чего.'
-     print(split_sentence(zh_text, language_str='ZH'))
-     print(split_sentence(en_text, language_str='EN'))
-     print(split_sentence(sp_text, language_str='SP'))
-     print(split_sentence(fr_text, language_str='FR'))
-     print(split_sentence(de_text, language_str='DE'))
-     print(split_sentence(ru_text, language_str='RU'))
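
Typical usage of the splitter removed above; a sketch assuming the pre-port branch (and its txtsplit dependency) is importable:

```python
from melo.split_utils import split_sentence  # pre-port branch only

text = "Hello there. This is a short demo of the sentence splitter, nothing more."
for chunk in split_sentence(text, min_len=10, language_str='EN'):
    print(chunk)
```
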
melo/text/__init__.py DELETED
@@ -1,35 +0,0 @@
1
- from .symbols import *
2
-
3
-
4
- _symbol_to_id = {s: i for i, s in enumerate(symbols)}
5
-
6
-
7
- def cleaned_text_to_sequence(cleaned_text, tones, language, symbol_to_id=None):
8
- """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
9
- Args:
10
- text: string to convert to a sequence
11
- Returns:
12
- List of integers corresponding to the symbols in the text
13
- """
14
- symbol_to_id_map = symbol_to_id if symbol_to_id else _symbol_to_id
15
- phones = [symbol_to_id_map[symbol] for symbol in cleaned_text]
16
- tone_start = language_tone_start_map[language]
17
- tones = [i + tone_start for i in tones]
18
- lang_id = language_id_map[language]
19
-    lang_ids = [lang_id] * len(phones)
20
- return phones, tones, lang_ids
21
-
22
-
23
- def get_bert(norm_text, word2ph, language, device):
24
- from .chinese_bert import get_bert_feature as zh_bert
25
- from .english_bert import get_bert_feature as en_bert
26
- from .japanese_bert import get_bert_feature as jp_bert
27
- from .chinese_mix import get_bert_feature as zh_mix_en_bert
28
- from .spanish_bert import get_bert_feature as sp_bert
29
- from .french_bert import get_bert_feature as fr_bert
30
- from .korean import get_bert_feature as kr_bert
31
-
32
- lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert, 'ZH_MIX_EN': zh_mix_en_bert,
33
- 'FR': fr_bert, 'SP': sp_bert, 'ES': sp_bert, "KR": kr_bert}
34
- bert = lang_bert_func_map[language](norm_text, word2ph, device)
35
- return bert
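
A toy illustration of the ID mapping performed by cleaned_text_to_sequence (the symbol map here is made up; "EN" is assumed to be registered in language_tone_start_map and language_id_map, as the code above implies):

    from melo.text import cleaned_text_to_sequence

    toy_map = {"_": 0, "h": 1, "i": 2}
    phones, tones, lang_ids = cleaned_text_to_sequence(
        ["_", "h", "i", "_"], [0, 1, 0, 0], "EN", symbol_to_id=toy_map
    )
    # phones   -> [0, 1, 2, 0]
    # tones    -> each input tone shifted by language_tone_start_map["EN"]
    # lang_ids -> language_id_map["EN"] repeated once per phone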
 
 
melo/text/chinese.py DELETED
@@ -1,199 +0,0 @@
1
- import os
2
- import re
3
-
4
- import cn2an
5
- from pypinyin import lazy_pinyin, Style
6
-
7
- from .symbols import punctuation
8
- from .tone_sandhi import ToneSandhi
9
-
10
- current_file_path = os.path.dirname(__file__)
11
- pinyin_to_symbol_map = {
12
- line.split("\t")[0]: line.strip().split("\t")[1]
13
- for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
14
- }
15
-
16
- import jieba.posseg as psg
17
-
18
-
19
- rep_map = {
20
- ":": ",",
21
- ";": ",",
22
- ",": ",",
23
- "。": ".",
24
- "!": "!",
25
- "?": "?",
26
- "\n": ".",
27
- "·": ",",
28
- "、": ",",
29
- "...": "…",
30
- "$": ".",
31
- "“": "'",
32
- "”": "'",
33
- "‘": "'",
34
- "’": "'",
35
- "(": "'",
36
- ")": "'",
37
- "(": "'",
38
- ")": "'",
39
- "《": "'",
40
- "》": "'",
41
- "【": "'",
42
- "】": "'",
43
- "[": "'",
44
- "]": "'",
45
- "—": "-",
46
- "~": "-",
47
- "~": "-",
48
- "「": "'",
49
- "」": "'",
50
- }
51
-
52
- tone_modifier = ToneSandhi()
53
-
54
-
55
- def replace_punctuation(text):
56
- text = text.replace("嗯", "恩").replace("呣", "母")
57
- pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
58
-
59
- replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
60
-
61
- replaced_text = re.sub(
62
- r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
63
- )
64
-
65
- return replaced_text
66
-
67
-
68
- def g2p(text):
69
- pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
70
- sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
71
- phones, tones, word2ph = _g2p(sentences)
72
- assert sum(word2ph) == len(phones)
73
-    assert len(word2ph) == len(text)  # This assert occasionally fails on some inputs; wrap the call in try/except if needed.
74
- phones = ["_"] + phones + ["_"]
75
- tones = [0] + tones + [0]
76
- word2ph = [1] + word2ph + [1]
77
- return phones, tones, word2ph
78
-
79
-
80
- def _get_initials_finals(word):
81
- initials = []
82
- finals = []
83
- orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
84
- orig_finals = lazy_pinyin(
85
- word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
86
- )
87
- for c, v in zip(orig_initials, orig_finals):
88
- initials.append(c)
89
- finals.append(v)
90
- return initials, finals
91
-
92
-
93
- def _g2p(segments):
94
- phones_list = []
95
- tones_list = []
96
- word2ph = []
97
- for seg in segments:
98
- # Replace all English words in the sentence
99
- seg = re.sub("[a-zA-Z]+", "", seg)
100
- seg_cut = psg.lcut(seg)
101
- initials = []
102
- finals = []
103
- seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
104
- for word, pos in seg_cut:
105
- if pos == "eng":
106
-            # English words were already stripped above, so this branch should not normally be reached
107
- continue
108
- sub_initials, sub_finals = _get_initials_finals(word)
109
- sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
110
- initials.append(sub_initials)
111
- finals.append(sub_finals)
112
-
113
- # assert len(sub_initials) == len(sub_finals) == len(word)
114
- initials = sum(initials, [])
115
- finals = sum(finals, [])
116
- #
117
- for c, v in zip(initials, finals):
118
- raw_pinyin = c + v
119
- # NOTE: post process for pypinyin outputs
120
- # we discriminate i, ii and iii
121
- if c == v:
122
- assert c in punctuation
123
- phone = [c]
124
- tone = "0"
125
- word2ph.append(1)
126
- else:
127
- v_without_tone = v[:-1]
128
- tone = v[-1]
129
-
130
- pinyin = c + v_without_tone
131
- assert tone in "12345"
132
-
133
- if c:
134
- # 多音节
135
- v_rep_map = {
136
- "uei": "ui",
137
- "iou": "iu",
138
- "uen": "un",
139
- }
140
- if v_without_tone in v_rep_map.keys():
141
- pinyin = c + v_rep_map[v_without_tone]
142
- else:
143
- # 单音节
144
- pinyin_rep_map = {
145
- "ing": "ying",
146
- "i": "yi",
147
- "in": "yin",
148
- "u": "wu",
149
- }
150
- if pinyin in pinyin_rep_map.keys():
151
- pinyin = pinyin_rep_map[pinyin]
152
- else:
153
- single_rep_map = {
154
- "v": "yu",
155
- "e": "e",
156
- "i": "y",
157
- "u": "w",
158
- }
159
- if pinyin[0] in single_rep_map.keys():
160
- pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
161
-
162
- assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
163
- phone = pinyin_to_symbol_map[pinyin].split(" ")
164
- word2ph.append(len(phone))
165
-
166
- phones_list += phone
167
- tones_list += [int(tone)] * len(phone)
168
- return phones_list, tones_list, word2ph
169
-
170
-
171
- def text_normalize(text):
172
- numbers = re.findall(r"\d+(?:\.?\d+)?", text)
173
- for number in numbers:
174
- text = text.replace(number, cn2an.an2cn(number), 1)
175
- text = replace_punctuation(text)
176
- return text
177
-
178
-
179
- def get_bert_feature(text, word2ph, device=None):
180
- from text import chinese_bert
181
-
182
- return chinese_bert.get_bert_feature(text, word2ph, device=device)
183
-
184
-
185
- if __name__ == "__main__":
186
- from text.chinese_bert import get_bert_feature
187
-
188
- text = "啊!chemistry 但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
189
- text = text_normalize(text)
190
- print(text)
191
- phones, tones, word2ph = g2p(text)
192
- bert = get_bert_feature(text, word2ph)
193
-
194
- print(phones, tones, word2ph, bert.shape)
195
-
196
-
197
- # # 示例用法
198
- # text = "这是一个示例文本:,你好!这是一个测试...."
199
- # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
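
How the Chinese front end above was typically driven (sketch; the in-function asserts can fail on unusual inputs, as the comment in g2p warns):

    from melo.text.chinese import text_normalize, g2p

    text = text_normalize("我今天吃了3个苹果")   # cn2an rewrites "3" as "三"
    phones, tones, word2ph = g2p(text)
    assert sum(word2ph) == len(phones)
    assert len(word2ph) == len(text) + 2       # "+2" for the "_" padding on both ends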
 
 
melo/text/chinese_bert.py DELETED
@@ -1,107 +0,0 @@
1
- import torch
2
- import sys
3
- from transformers import AutoTokenizer, AutoModelForMaskedLM
4
-
5
-
6
- # model_id = 'hfl/chinese-roberta-wwm-ext-large'
7
- local_path = "./bert/chinese-roberta-wwm-ext-large"
8
-
9
-
10
- tokenizers = {}
11
- models = {}
12
-
13
- def get_bert_feature(text, word2ph, device=None, model_id='hfl/chinese-roberta-wwm-ext-large'):
14
- if model_id not in models:
15
- models[model_id] = AutoModelForMaskedLM.from_pretrained(
16
- model_id
17
- ).to(device)
18
- tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
19
- model = models[model_id]
20
- tokenizer = tokenizers[model_id]
21
-
22
- if (
23
- sys.platform == "darwin"
24
- and torch.backends.mps.is_available()
25
- and device == "cpu"
26
- ):
27
- device = "mps"
28
- if not device:
29
- device = "cuda"
30
-
31
- with torch.no_grad():
32
- inputs = tokenizer(text, return_tensors="pt")
33
- for i in inputs:
34
- inputs[i] = inputs[i].to(device)
35
- res = model(**inputs, output_hidden_states=True)
36
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
37
- # import pdb; pdb.set_trace()
38
- # assert len(word2ph) == len(text) + 2
39
- word2phone = word2ph
40
- phone_level_feature = []
41
- for i in range(len(word2phone)):
42
- repeat_feature = res[i].repeat(word2phone[i], 1)
43
- phone_level_feature.append(repeat_feature)
44
-
45
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
46
- return phone_level_feature.T
47
-
48
-
49
- if __name__ == "__main__":
50
- import torch
51
-
52
- word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
53
- word2phone = [
54
- 1,
55
- 2,
56
- 1,
57
- 2,
58
- 2,
59
- 1,
60
- 2,
61
- 2,
62
- 1,
63
- 2,
64
- 2,
65
- 1,
66
- 2,
67
- 2,
68
- 2,
69
- 2,
70
- 2,
71
- 1,
72
- 1,
73
- 2,
74
- 2,
75
- 1,
76
- 2,
77
- 2,
78
- 2,
79
- 2,
80
- 1,
81
- 2,
82
- 2,
83
- 2,
84
- 2,
85
- 2,
86
- 1,
87
- 2,
88
- 2,
89
- 2,
90
- 2,
91
- 1,
92
- ]
93
-
94
- # 计算总帧数
95
- total_frames = sum(word2phone)
96
- print(word_level_feature.shape)
97
- print(word2phone)
98
- phone_level_feature = []
99
- for i in range(len(word2phone)):
100
- print(word_level_feature[i].shape)
101
-
102
- # 对每个词重复word2phone[i]次
103
- repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
104
- phone_level_feature.append(repeat_feature)
105
-
106
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
107
- print(phone_level_feature.shape) # torch.Size([36, 1024])
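
A shape sketch for the feature extractor above (assumes a CUDA device and character-level tokenization, with one word2ph entry per token including [CLS]/[SEP]):

    from melo.text.chinese_bert import get_bert_feature

    text = "今天天气真好"
    word2ph = [1] + [2] * len(text) + [1]
    feat = get_bert_feature(text, word2ph, device="cuda")
    assert feat.shape == (1024, sum(word2ph))   # hidden size 1024 for the -large checkpoint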
 
 
melo/text/chinese_mix.py DELETED
@@ -1,253 +0,0 @@
1
- import os
2
- import re
3
-
4
- import cn2an
5
- from pypinyin import lazy_pinyin, Style
6
-
7
- # from text.symbols import punctuation
8
- from .symbols import language_tone_start_map
9
- from .tone_sandhi import ToneSandhi
10
- from .english import g2p as g2p_en
11
- from transformers import AutoTokenizer
12
-
13
- punctuation = ["!", "?", "…", ",", ".", "'", "-"]
14
- current_file_path = os.path.dirname(__file__)
15
- pinyin_to_symbol_map = {
16
- line.split("\t")[0]: line.strip().split("\t")[1]
17
- for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
18
- }
19
-
20
- import jieba.posseg as psg
21
-
22
-
23
- rep_map = {
24
- ":": ",",
25
- ";": ",",
26
- ",": ",",
27
- "。": ".",
28
- "!": "!",
29
- "?": "?",
30
- "\n": ".",
31
- "·": ",",
32
- "、": ",",
33
- "...": "…",
34
- "$": ".",
35
- "“": "'",
36
- "”": "'",
37
- "‘": "'",
38
- "’": "'",
39
- "(": "'",
40
- ")": "'",
41
- "(": "'",
42
- ")": "'",
43
- "《": "'",
44
- "》": "'",
45
- "【": "'",
46
- "】": "'",
47
- "[": "'",
48
- "]": "'",
49
- "—": "-",
50
- "~": "-",
51
- "~": "-",
52
- "「": "'",
53
- "」": "'",
54
- }
55
-
56
- tone_modifier = ToneSandhi()
57
-
58
-
59
- def replace_punctuation(text):
60
- text = text.replace("嗯", "恩").replace("呣", "母")
61
- pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
62
- replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
63
- replaced_text = re.sub(r"[^\u4e00-\u9fa5_a-zA-Z\s" + "".join(punctuation) + r"]+", "", replaced_text)
64
- replaced_text = re.sub(r"[\s]+", " ", replaced_text)
65
-
66
- return replaced_text
67
-
68
-
69
- def g2p(text, impl='v2'):
70
- pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
71
- sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
72
- if impl == 'v1':
73
- _func = _g2p
74
- elif impl == 'v2':
75
- _func = _g2p_v2
76
- else:
77
- raise NotImplementedError()
78
- phones, tones, word2ph = _func(sentences)
79
- assert sum(word2ph) == len(phones)
80
- # assert len(word2ph) == len(text) # Sometimes it will crash,you can add a try-catch.
81
- phones = ["_"] + phones + ["_"]
82
- tones = [0] + tones + [0]
83
- word2ph = [1] + word2ph + [1]
84
- return phones, tones, word2ph
85
-
86
-
87
- def _get_initials_finals(word):
88
- initials = []
89
- finals = []
90
- orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
91
- orig_finals = lazy_pinyin(
92
- word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
93
- )
94
- for c, v in zip(orig_initials, orig_finals):
95
- initials.append(c)
96
- finals.append(v)
97
- return initials, finals
98
-
99
- model_id = 'bert-base-multilingual-uncased'
100
- tokenizer = AutoTokenizer.from_pretrained(model_id)
101
- def _g2p(segments):
102
- phones_list = []
103
- tones_list = []
104
- word2ph = []
105
- for seg in segments:
106
- # Replace all English words in the sentence
107
- # seg = re.sub("[a-zA-Z]+", "", seg)
108
- seg_cut = psg.lcut(seg)
109
- initials = []
110
- finals = []
111
- seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
112
- for word, pos in seg_cut:
113
- if pos == "eng":
114
- initials.append(['EN_WORD'])
115
- finals.append([word])
116
- else:
117
- sub_initials, sub_finals = _get_initials_finals(word)
118
- sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
119
- initials.append(sub_initials)
120
- finals.append(sub_finals)
121
-
122
- # assert len(sub_initials) == len(sub_finals) == len(word)
123
- initials = sum(initials, [])
124
- finals = sum(finals, [])
125
- #
126
- for c, v in zip(initials, finals):
127
- if c == 'EN_WORD':
128
- tokenized_en = tokenizer.tokenize(v)
129
- phones_en, tones_en, word2ph_en = g2p_en(text=None, pad_start_end=False, tokenized=tokenized_en)
130
- # apply offset to tones_en
131
- tones_en = [t + language_tone_start_map['EN'] for t in tones_en]
132
- phones_list += phones_en
133
- tones_list += tones_en
134
- word2ph += word2ph_en
135
- else:
136
- raw_pinyin = c + v
137
- # NOTE: post process for pypinyin outputs
138
- # we discriminate i, ii and iii
139
- if c == v:
140
- assert c in punctuation
141
- phone = [c]
142
- tone = "0"
143
- word2ph.append(1)
144
- else:
145
- v_without_tone = v[:-1]
146
- tone = v[-1]
147
-
148
- pinyin = c + v_without_tone
149
- assert tone in "12345"
150
-
151
- if c:
152
- # 多音节
153
- v_rep_map = {
154
- "uei": "ui",
155
- "iou": "iu",
156
- "uen": "un",
157
- }
158
- if v_without_tone in v_rep_map.keys():
159
- pinyin = c + v_rep_map[v_without_tone]
160
- else:
161
- # 单音节
162
- pinyin_rep_map = {
163
- "ing": "ying",
164
- "i": "yi",
165
- "in": "yin",
166
- "u": "wu",
167
- }
168
- if pinyin in pinyin_rep_map.keys():
169
- pinyin = pinyin_rep_map[pinyin]
170
- else:
171
- single_rep_map = {
172
- "v": "yu",
173
- "e": "e",
174
- "i": "y",
175
- "u": "w",
176
- }
177
- if pinyin[0] in single_rep_map.keys():
178
- pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
179
-
180
- assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
181
- phone = pinyin_to_symbol_map[pinyin].split(" ")
182
- word2ph.append(len(phone))
183
-
184
- phones_list += phone
185
- tones_list += [int(tone)] * len(phone)
186
- return phones_list, tones_list, word2ph
187
-
188
-
189
- def text_normalize(text):
190
- numbers = re.findall(r"\d+(?:\.?\d+)?", text)
191
- for number in numbers:
192
- text = text.replace(number, cn2an.an2cn(number), 1)
193
- text = replace_punctuation(text)
194
- return text
195
-
196
-
197
- def get_bert_feature(text, word2ph, device):
198
- from . import chinese_bert
199
- return chinese_bert.get_bert_feature(text, word2ph, model_id='bert-base-multilingual-uncased', device=device)
200
-
201
- from .chinese import _g2p as _chinese_g2p
202
- def _g2p_v2(segments):
203
- spliter = '#$&^!@'
204
-
205
- phones_list = []
206
- tones_list = []
207
- word2ph = []
208
-
209
- for text in segments:
210
- assert spliter not in text
211
- # replace all english words
212
-        text = re.sub(r'([a-zA-Z\s]+)', lambda x: f'{spliter}{x.group(1)}{spliter}', text)
213
- texts = text.split(spliter)
214
- texts = [t for t in texts if len(t) > 0]
215
-
216
-
217
- for text in texts:
218
-            if re.match(r'[a-zA-Z\s]+', text):
219
- # english
220
- tokenized_en = tokenizer.tokenize(text)
221
- phones_en, tones_en, word2ph_en = g2p_en(text=None, pad_start_end=False, tokenized=tokenized_en)
222
- # apply offset to tones_en
223
- tones_en = [t + language_tone_start_map['EN'] for t in tones_en]
224
- phones_list += phones_en
225
- tones_list += tones_en
226
- word2ph += word2ph_en
227
- else:
228
- phones_zh, tones_zh, word2ph_zh = _chinese_g2p([text])
229
- phones_list += phones_zh
230
- tones_list += tones_zh
231
- word2ph += word2ph_zh
232
- return phones_list, tones_list, word2ph
233
-
234
-
235
-
236
- if __name__ == "__main__":
237
- # from text.chinese_bert import get_bert_feature
238
-
239
- text = "NFT啊!chemistry 但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
240
- text = '我最近在学习machine learning,希望能够在未来的artificial intelligence领域有所建树。'
241
- text = '今天下午,我们准备去shopping mall购物,然后晚上去看一场movie。'
242
- text = '我们现在 also 能够 help 很多公司 use some machine learning 的 algorithms 啊!'
243
- text = text_normalize(text)
244
- print(text)
245
- phones, tones, word2ph = g2p(text, impl='v2')
246
- bert = get_bert_feature(text, word2ph, device='cuda:0')
247
- print(phones)
248
-    # inspect phones/tones/word2ph here when debugging
249
-
250
-
251
- # # 示例用法
252
- # text = "这是一个示例文本:,你好!这是一个测试...."
253
- # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
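
The v2 mixed-language path above splits on ASCII spans before dispatching; a usage sketch:

    from melo.text.chinese_mix import text_normalize, g2p

    text = text_normalize("我们用machine learning解决问题")
    phones, tones, word2ph = g2p(text, impl="v2")
    # English spans go through the BERT tokenizer + g2p_en branch with tones
    # offset by language_tone_start_map["EN"]; Chinese spans use the pinyin branch.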
 
 
melo/text/cleaner.py DELETED
@@ -1,36 +0,0 @@
1
- from . import chinese, japanese, english, chinese_mix, korean, french, spanish
2
- from . import cleaned_text_to_sequence
3
- import copy
4
-
5
- language_module_map = {"ZH": chinese, "JP": japanese, "EN": english, 'ZH_MIX_EN': chinese_mix, 'KR': korean,
6
- 'FR': french, 'SP': spanish, 'ES': spanish}
7
-
8
-
9
- def clean_text(text, language):
10
- language_module = language_module_map[language]
11
- norm_text = language_module.text_normalize(text)
12
- phones, tones, word2ph = language_module.g2p(norm_text)
13
- return norm_text, phones, tones, word2ph
14
-
15
-
16
- def clean_text_bert(text, language, device=None):
17
- language_module = language_module_map[language]
18
- norm_text = language_module.text_normalize(text)
19
- phones, tones, word2ph = language_module.g2p(norm_text)
20
-
21
- word2ph_bak = copy.deepcopy(word2ph)
22
- for i in range(len(word2ph)):
23
- word2ph[i] = word2ph[i] * 2
24
- word2ph[0] += 1
25
- bert = language_module.get_bert_feature(norm_text, word2ph, device=device)
26
-
27
- return norm_text, phones, tones, word2ph_bak, bert
28
-
29
-
30
- def text_to_sequence(text, language):
31
- norm_text, phones, tones, word2ph = clean_text(text, language)
32
- return cleaned_text_to_sequence(phones, tones, language)
33
-
34
-
35
- if __name__ == "__main__":
36
- pass
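
A typical call into the dispatcher above (sketch):

    from melo.text.cleaner import clean_text

    norm_text, phones, tones, word2ph = clean_text("Hello there!", "EN")
    # clean_text_bert additionally doubles each word2ph entry (and bumps the first)
    # before extracting BERT features, while returning the original word2ph.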
 
 
melo/text/cleaner_multiling.py DELETED
@@ -1,110 +0,0 @@
1
- """Set of default text cleaners"""
2
- # TODO: pick the cleaner for languages dynamically
3
-
4
- import re
5
-
6
- # Regular expression matching whitespace:
7
- _whitespace_re = re.compile(r"\s+")
8
-
9
- rep_map = {
10
- ":": ",",
11
- ";": ",",
12
- ",": ",",
13
- "。": ".",
14
- "!": "!",
15
- "?": "?",
16
- "\n": ".",
17
- "·": ",",
18
- "、": ",",
19
- "...": ".",
20
- "…": ".",
21
- "$": ".",
22
- "“": "'",
23
- "”": "'",
24
- "‘": "'",
25
- "’": "'",
26
- "(": "'",
27
- ")": "'",
28
- "(": "'",
29
- ")": "'",
30
- "《": "'",
31
- "》": "'",
32
- "【": "'",
33
- "】": "'",
34
- "[": "'",
35
- "]": "'",
36
- "—": "",
37
- "~": "-",
38
- "~": "-",
39
- "「": "'",
40
- "」": "'",
41
- }
42
-
43
- def replace_punctuation(text):
44
- pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
45
- replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
46
- return replaced_text
47
-
48
- def lowercase(text):
49
- return text.lower()
50
-
51
-
52
- def collapse_whitespace(text):
53
- return re.sub(_whitespace_re, " ", text).strip()
54
-
55
- def remove_punctuation_at_begin(text):
56
- return re.sub(r'^[,.!?]+', '', text)
57
-
58
- def remove_aux_symbols(text):
59
- text = re.sub(r"[\<\>\(\)\[\]\"\«\»\']+", "", text)
60
- return text
61
-
62
-
63
- def replace_symbols(text, lang="en"):
64
- """Replace symbols based on the lenguage tag.
65
-
66
- Args:
67
- text:
68
- Input text.
69
- lang:
70
-            Language identifier, e.g. "en", "fr", "pt", "ca".
71
-
72
- Returns:
73
- The modified text
74
- example:
75
- input args:
76
- text: "si l'avi cau, diguem-ho"
77
- lang: "ca"
78
- Output:
79
- text: "si lavi cau, diguemho"
80
- """
81
- text = text.replace(";", ",")
82
- text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
83
- text = text.replace(":", ",")
84
- if lang == "en":
85
- text = text.replace("&", " and ")
86
- elif lang == "fr":
87
- text = text.replace("&", " et ")
88
- elif lang == "pt":
89
- text = text.replace("&", " e ")
90
- elif lang == "ca":
91
- text = text.replace("&", " i ")
92
- text = text.replace("'", "")
93
- elif lang== "es":
94
- text=text.replace("&","y")
95
- text = text.replace("'", "")
96
- return text
97
-
98
- def unicleaners(text, cased=False, lang='en'):
99
- """Basic pipeline for Portuguese text. There is no need to expand abbreviation and
100
- numbers, phonemizer already does that"""
101
- if not cased:
102
- text = lowercase(text)
103
- text = replace_punctuation(text)
104
- text = replace_symbols(text, lang=lang)
105
- text = remove_aux_symbols(text)
106
- text = remove_punctuation_at_begin(text)
107
- text = collapse_whitespace(text)
108
- text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text)
109
- return text
110
-
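
unicleaners in action (sketch):

    from melo.text.cleaner_multiling import unicleaners

    print(unicleaners("Bonjour, ça va", lang="fr"))
    # -> "bonjour, ça va."  (lowercased, punctuation normalised, final "." appended)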
 
 
melo/text/cmudict.rep DELETED
The diff for this file is too large to render. See raw diff
 
melo/text/cmudict_cache.pickle DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b9b21b20325471934ba92f2e4a5976989e7d920caa32e7a286eacb027d197949
3
- size 6212655
 
 
melo/text/english.py DELETED
@@ -1,284 +0,0 @@
1
- import pickle
2
- import os
3
- import re
4
- from g2p_en import G2p
5
-
6
- from . import symbols
7
-
8
- from .english_utils.abbreviations import expand_abbreviations
9
- from .english_utils.time_norm import expand_time_english
10
- from .english_utils.number_norm import normalize_numbers
11
- from .japanese import distribute_phone
12
-
13
- from transformers import AutoTokenizer
14
-
15
- current_file_path = os.path.dirname(__file__)
16
- CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
17
- CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
18
- _g2p = G2p()
19
-
20
- arpa = {
21
- "AH0",
22
- "S",
23
- "AH1",
24
- "EY2",
25
- "AE2",
26
- "EH0",
27
- "OW2",
28
- "UH0",
29
- "NG",
30
- "B",
31
- "G",
32
- "AY0",
33
- "M",
34
- "AA0",
35
- "F",
36
- "AO0",
37
- "ER2",
38
- "UH1",
39
- "IY1",
40
- "AH2",
41
- "DH",
42
- "IY0",
43
- "EY1",
44
- "IH0",
45
- "K",
46
- "N",
47
- "W",
48
- "IY2",
49
- "T",
50
- "AA1",
51
- "ER1",
52
- "EH2",
53
- "OY0",
54
- "UH2",
55
- "UW1",
56
- "Z",
57
- "AW2",
58
- "AW1",
59
- "V",
60
- "UW2",
61
- "AA2",
62
- "ER",
63
- "AW0",
64
- "UW0",
65
- "R",
66
- "OW1",
67
- "EH1",
68
- "ZH",
69
- "AE0",
70
- "IH2",
71
- "IH",
72
- "Y",
73
- "JH",
74
- "P",
75
- "AY1",
76
- "EY0",
77
- "OY2",
78
- "TH",
79
- "HH",
80
- "D",
81
- "ER0",
82
- "CH",
83
- "AO1",
84
- "AE1",
85
- "AO2",
86
- "OY1",
87
- "AY2",
88
- "IH1",
89
- "OW0",
90
- "L",
91
- "SH",
92
- }
93
-
94
-
95
- def post_replace_ph(ph):
96
- rep_map = {
97
- ":": ",",
98
- ";": ",",
99
- ",": ",",
100
- "。": ".",
101
- "!": "!",
102
- "?": "?",
103
- "\n": ".",
104
- "·": ",",
105
- "、": ",",
106
- "...": "…",
107
- "v": "V",
108
- }
109
- if ph in rep_map.keys():
110
- ph = rep_map[ph]
111
- if ph in symbols:
112
- return ph
113
- if ph not in symbols:
114
- ph = "UNK"
115
- return ph
116
-
117
-
118
- def read_dict():
119
- g2p_dict = {}
120
- start_line = 49
121
- with open(CMU_DICT_PATH) as f:
122
- line = f.readline()
123
- line_index = 1
124
- while line:
125
- if line_index >= start_line:
126
- line = line.strip()
127
- word_split = line.split(" ")
128
- word = word_split[0]
129
-
130
- syllable_split = word_split[1].split(" - ")
131
- g2p_dict[word] = []
132
- for syllable in syllable_split:
133
- phone_split = syllable.split(" ")
134
- g2p_dict[word].append(phone_split)
135
-
136
- line_index = line_index + 1
137
- line = f.readline()
138
-
139
- return g2p_dict
140
-
141
-
142
- def cache_dict(g2p_dict, file_path):
143
- with open(file_path, "wb") as pickle_file:
144
- pickle.dump(g2p_dict, pickle_file)
145
-
146
-
147
- def get_dict():
148
- if os.path.exists(CACHE_PATH):
149
- with open(CACHE_PATH, "rb") as pickle_file:
150
- g2p_dict = pickle.load(pickle_file)
151
- else:
152
- g2p_dict = read_dict()
153
- cache_dict(g2p_dict, CACHE_PATH)
154
-
155
- return g2p_dict
156
-
157
-
158
- eng_dict = get_dict()
159
-
160
-
161
- def refine_ph(phn):
162
- tone = 0
163
- if re.search(r"\d$", phn):
164
- tone = int(phn[-1]) + 1
165
- phn = phn[:-1]
166
- return phn.lower(), tone
167
-
168
-
169
- def refine_syllables(syllables):
170
- tones = []
171
- phonemes = []
172
- for phn_list in syllables:
173
- for i in range(len(phn_list)):
174
- phn = phn_list[i]
175
- phn, tone = refine_ph(phn)
176
- phonemes.append(phn)
177
- tones.append(tone)
178
- return phonemes, tones
179
-
180
-
181
- def text_normalize(text):
182
- text = text.lower()
183
- text = expand_time_english(text)
184
- text = normalize_numbers(text)
185
- text = expand_abbreviations(text)
186
- return text
187
-
188
- model_id = 'bert-base-uncased'
189
- tokenizer = AutoTokenizer.from_pretrained(model_id)
190
- def g2p_old(text):
191
- tokenized = tokenizer.tokenize(text)
192
- # import pdb; pdb.set_trace()
193
- phones = []
194
- tones = []
195
- words = re.split(r"([,;.\-\?\!\s+])", text)
196
- for w in words:
197
- if w.upper() in eng_dict:
198
- phns, tns = refine_syllables(eng_dict[w.upper()])
199
- phones += phns
200
- tones += tns
201
- else:
202
- phone_list = list(filter(lambda p: p != " ", _g2p(w)))
203
- for ph in phone_list:
204
- if ph in arpa:
205
- ph, tn = refine_ph(ph)
206
- phones.append(ph)
207
- tones.append(tn)
208
- else:
209
- phones.append(ph)
210
- tones.append(0)
211
- # todo: implement word2ph
212
- word2ph = [1 for i in phones]
213
-
214
- phones = [post_replace_ph(i) for i in phones]
215
- return phones, tones, word2ph
216
-
217
- def g2p(text, pad_start_end=True, tokenized=None):
218
- if tokenized is None:
219
- tokenized = tokenizer.tokenize(text)
220
- # import pdb; pdb.set_trace()
221
- phs = []
222
- ph_groups = []
223
- for t in tokenized:
224
- if not t.startswith("#"):
225
- ph_groups.append([t])
226
- else:
227
- ph_groups[-1].append(t.replace("#", ""))
228
-
229
- phones = []
230
- tones = []
231
- word2ph = []
232
- for group in ph_groups:
233
- w = "".join(group)
234
- phone_len = 0
235
- word_len = len(group)
236
- if w.upper() in eng_dict:
237
- phns, tns = refine_syllables(eng_dict[w.upper()])
238
- phones += phns
239
- tones += tns
240
- phone_len += len(phns)
241
- else:
242
- phone_list = list(filter(lambda p: p != " ", _g2p(w)))
243
- for ph in phone_list:
244
- if ph in arpa:
245
- ph, tn = refine_ph(ph)
246
- phones.append(ph)
247
- tones.append(tn)
248
- else:
249
- phones.append(ph)
250
- tones.append(0)
251
- phone_len += 1
252
- aaa = distribute_phone(phone_len, word_len)
253
- word2ph += aaa
254
- phones = [post_replace_ph(i) for i in phones]
255
-
256
- if pad_start_end:
257
- phones = ["_"] + phones + ["_"]
258
- tones = [0] + tones + [0]
259
- word2ph = [1] + word2ph + [1]
260
- return phones, tones, word2ph
261
-
262
- def get_bert_feature(text, word2ph, device=None):
263
- from text import english_bert
264
-
265
- return english_bert.get_bert_feature(text, word2ph, device=device)
266
-
267
- if __name__ == "__main__":
268
- # print(get_dict())
269
- # print(eng_word_to_phoneme("hello"))
270
- from text.english_bert import get_bert_feature
271
- text = "In this paper, we propose 1 DSPGAN, a N-F-T GAN-based universal vocoder."
272
- text = text_normalize(text)
273
- phones, tones, word2ph = g2p(text)
274
-    # inspect phones/tones/word2ph here when debugging
275
- bert = get_bert_feature(text, word2ph)
276
-
277
- print(phones, tones, word2ph, bert.shape)
278
-
279
- # all_phones = set()
280
- # for k, syllables in eng_dict.items():
281
- # for group in syllables:
282
- # for ph in group:
283
- # all_phones.add(ph)
284
- # print(all_phones)
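
End-to-end sketch for the English front end above (the normalized string shown is approximate):

    from melo.text.english import text_normalize, g2p

    text = text_normalize("Dr. Smith arrived at 10:30 with $5.")
    # roughly: "doctor smith arrived at ten thirty a m with five dollars."
    phones, tones, word2ph = g2p(text)
    assert sum(word2ph) == len(phones)   # phones are distributed over word-pieces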
 
 
melo/text/english_bert.py DELETED
@@ -1,39 +0,0 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForMaskedLM
3
- import sys
4
-
5
- model_id = 'bert-base-uncased'
6
- tokenizer = AutoTokenizer.from_pretrained(model_id)
7
- model = None
8
-
9
- def get_bert_feature(text, word2ph, device=None):
10
- global model
11
- if (
12
- sys.platform == "darwin"
13
- and torch.backends.mps.is_available()
14
- and device == "cpu"
15
- ):
16
- device = "mps"
17
- if not device:
18
- device = "cuda"
19
- if model is None:
20
- model = AutoModelForMaskedLM.from_pretrained(model_id).to(
21
- device
22
- )
23
- with torch.no_grad():
24
- inputs = tokenizer(text, return_tensors="pt")
25
- for i in inputs:
26
- inputs[i] = inputs[i].to(device)
27
- res = model(**inputs, output_hidden_states=True)
28
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
29
-
30
- assert inputs["input_ids"].shape[-1] == len(word2ph)
31
- word2phone = word2ph
32
- phone_level_feature = []
33
- for i in range(len(word2phone)):
34
- repeat_feature = res[i].repeat(word2phone[i], 1)
35
- phone_level_feature.append(repeat_feature)
36
-
37
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
38
-
39
- return phone_level_feature.T
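
A shape sketch (assumes a CUDA device); the assert above requires exactly one word2ph entry per word-piece, including [CLS] and [SEP]:

    from melo.text.english_bert import get_bert_feature

    feat = get_bert_feature("hello world", [1, 2, 2, 1], device="cuda")
    # tokenization: [CLS] hello world [SEP] -> 4 entries; feat.shape == (768, 6)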
 
 
melo/text/english_utils/__init__.py DELETED
File without changes
melo/text/english_utils/abbreviations.py DELETED
@@ -1,35 +0,0 @@
1
- import re
2
-
3
- # List of (regular expression, replacement) pairs for abbreviations in english:
4
- abbreviations_en = [
5
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
6
- for x in [
7
- ("mrs", "misess"),
8
- ("mr", "mister"),
9
- ("dr", "doctor"),
10
- ("st", "saint"),
11
- ("co", "company"),
12
- ("jr", "junior"),
13
- ("maj", "major"),
14
- ("gen", "general"),
15
- ("drs", "doctors"),
16
- ("rev", "reverend"),
17
- ("lt", "lieutenant"),
18
- ("hon", "honorable"),
19
- ("sgt", "sergeant"),
20
- ("capt", "captain"),
21
- ("esq", "esquire"),
22
- ("ltd", "limited"),
23
- ("col", "colonel"),
24
- ("ft", "fort"),
25
- ]
26
- ]
27
-
28
- def expand_abbreviations(text, lang="en"):
29
- if lang == "en":
30
- _abbreviations = abbreviations_en
31
- else:
32
- raise NotImplementedError()
33
- for regex, replacement in _abbreviations:
34
- text = re.sub(regex, replacement, text)
35
- return text
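
Example of the abbreviation expansion above (the patterns are case-insensitive, but the replacements are lowercase):

    from melo.text.english_utils.abbreviations import expand_abbreviations

    print(expand_abbreviations("dr. smith met mr. jones"))
    # -> "doctor smith met mister jones"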
 
 
melo/text/english_utils/number_norm.py DELETED
@@ -1,97 +0,0 @@
1
- """ from https://github.com/keithito/tacotron """
2
-
3
- import re
4
- from typing import Dict
5
-
6
- import inflect
7
-
8
- _inflect = inflect.engine()
9
- _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
10
- _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
11
- _currency_re = re.compile(r"(£|\$|€|¥)([0-9\,\.]*[0-9]+)")  # € added so the euro entry in _expand_currency is reachable
12
- _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
13
- _number_re = re.compile(r"-?[0-9]+")
14
-
15
-
16
- def _remove_commas(m):
17
- return m.group(1).replace(",", "")
18
-
19
-
20
- def _expand_decimal_point(m):
21
- return m.group(1).replace(".", " point ")
22
-
23
-
24
- def __expand_currency(value: str, inflection: Dict[float, str]) -> str:
25
- parts = value.replace(",", "").split(".")
26
- if len(parts) > 2:
27
- return f"{value} {inflection[2]}" # Unexpected format
28
- text = []
29
- integer = int(parts[0]) if parts[0] else 0
30
- if integer > 0:
31
- integer_unit = inflection.get(integer, inflection[2])
32
- text.append(f"{integer} {integer_unit}")
33
- fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0
34
- if fraction > 0:
35
- fraction_unit = inflection.get(fraction / 100, inflection[0.02])
36
- text.append(f"{fraction} {fraction_unit}")
37
- if len(text) == 0:
38
- return f"zero {inflection[2]}"
39
- return " ".join(text)
40
-
41
-
42
- def _expand_currency(m: "re.Match") -> str:
43
- currencies = {
44
- "$": {
45
- 0.01: "cent",
46
- 0.02: "cents",
47
- 1: "dollar",
48
- 2: "dollars",
49
- },
50
- "€": {
51
- 0.01: "cent",
52
- 0.02: "cents",
53
- 1: "euro",
54
- 2: "euros",
55
- },
56
- "£": {
57
- 0.01: "penny",
58
- 0.02: "pence",
59
- 1: "pound sterling",
60
- 2: "pounds sterling",
61
- },
62
- "¥": {
63
- # TODO rin
64
- 0.02: "sen",
65
- 2: "yen",
66
- },
67
- }
68
- unit = m.group(1)
69
- currency = currencies[unit]
70
- value = m.group(2)
71
- return __expand_currency(value, currency)
72
-
73
-
74
- def _expand_ordinal(m):
75
- return _inflect.number_to_words(m.group(0))
76
-
77
-
78
- def _expand_number(m):
79
- num = int(m.group(0))
80
- if 1000 < num < 3000:
81
- if num == 2000:
82
- return "two thousand"
83
- if 2000 < num < 2010:
84
- return "two thousand " + _inflect.number_to_words(num % 100)
85
- if num % 100 == 0:
86
- return _inflect.number_to_words(num // 100) + " hundred"
87
- return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
88
- return _inflect.number_to_words(num, andword="")
89
-
90
-
91
- def normalize_numbers(text):
92
- text = re.sub(_comma_number_re, _remove_commas, text)
93
- text = re.sub(_currency_re, _expand_currency, text)
94
- text = re.sub(_decimal_number_re, _expand_decimal_point, text)
95
- text = re.sub(_ordinal_re, _expand_ordinal, text)
96
- text = re.sub(_number_re, _expand_number, text)
97
- return text
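
The normalizers above compose in the order shown in normalize_numbers; a sketch:

    from melo.text.english_utils.number_norm import normalize_numbers

    print(normalize_numbers("The 2nd run costs $5 and takes 15 minutes"))
    # -> "The second run costs five dollars and takes fifteen minutes"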
 
 
melo/text/english_utils/time_norm.py DELETED
@@ -1,47 +0,0 @@
1
- import re
2
-
3
- import inflect
4
-
5
- _inflect = inflect.engine()
6
-
7
- _time_re = re.compile(
8
- r"""\b
9
- ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours
10
- :
11
- ([0-5][0-9]) # minutes
12
-     \s*(a\.m\.|am|pm|p\.m\.|a\.m|p\.m)? # am/pm (single backslashes; the doubled ones matched a literal backslash)
13
- \b""",
14
- re.IGNORECASE | re.X,
15
- )
16
-
17
-
18
- def _expand_num(n: int) -> str:
19
- return _inflect.number_to_words(n)
20
-
21
-
22
- def _expand_time_english(match: "re.Match") -> str:
23
- hour = int(match.group(1))
24
- past_noon = hour >= 12
25
- time = []
26
- if hour > 12:
27
- hour -= 12
28
- elif hour == 0:
29
- hour = 12
30
- past_noon = True
31
- time.append(_expand_num(hour))
32
-
33
- minute = int(match.group(6))
34
- if minute > 0:
35
- if minute < 10:
36
- time.append("oh")
37
- time.append(_expand_num(minute))
38
- am_pm = match.group(7)
39
- if am_pm is None:
40
- time.append("p m" if past_noon else "a m")
41
- else:
42
- time.extend(list(am_pm.replace(".", "")))
43
- return " ".join(time)
44
-
45
-
46
- def expand_time_english(text: str) -> str:
47
- return re.sub(_time_re, _expand_time_english, text)
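
Time expansion sketch:

    from melo.text.english_utils.time_norm import expand_time_english

    print(expand_time_english("the train leaves at 6:05 pm"))
    # -> "the train leaves at six oh five p m"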
 
 
melo/text/es_phonemizer/__init__.py DELETED
File without changes
melo/text/es_phonemizer/base.py DELETED
@@ -1,140 +0,0 @@
1
- import abc
2
- from typing import List, Tuple
3
-
4
- from .punctuation import Punctuation
5
-
6
-
7
- class BasePhonemizer(abc.ABC):
8
- """Base phonemizer class
9
-
10
-    Phonemization proceeds in the following steps:
11
- 1. Preprocessing:
12
- - remove empty lines
13
- - remove punctuation
14
- - keep track of punctuation marks
15
-
16
- 2. Phonemization:
17
- - convert text to phonemes
18
-
19
- 3. Postprocessing:
20
- - join phonemes
21
- - restore punctuation marks
22
-
23
- Args:
24
- language (str):
25
- Language used by the phonemizer.
26
-
27
- punctuations (List[str]):
28
- List of punctuation marks to be preserved.
29
-
30
- keep_puncs (bool):
31
- Whether to preserve punctuation marks or not.
32
- """
33
-
34
- def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
35
- # ensure the backend is installed on the system
36
- if not self.is_available():
37
- raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover
38
-
39
- # ensure the backend support the requested language
40
- self._language = self._init_language(language)
41
-
42
- # setup punctuation processing
43
- self._keep_puncs = keep_puncs
44
- self._punctuator = Punctuation(punctuations)
45
-
46
- def _init_language(self, language):
47
- """Language initialization
48
-
49
- This method may be overloaded in child classes (see Segments backend)
50
-
51
- """
52
- if not self.is_supported_language(language):
53
- raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend")
54
- return language
55
-
56
- @property
57
- def language(self):
58
- """The language code configured to be used for phonemization"""
59
- return self._language
60
-
61
- @staticmethod
62
- @abc.abstractmethod
63
- def name():
64
- """The name of the backend"""
65
- ...
66
-
67
- @classmethod
68
- @abc.abstractmethod
69
- def is_available(cls):
70
- """Returns True if the backend is installed, False otherwise"""
71
- ...
72
-
73
- @classmethod
74
- @abc.abstractmethod
75
- def version(cls):
76
- """Return the backend version as a tuple (major, minor, patch)"""
77
- ...
78
-
79
- @staticmethod
80
- @abc.abstractmethod
81
- def supported_languages():
82
- """Return a dict of language codes -> name supported by the backend"""
83
- ...
84
-
85
- def is_supported_language(self, language):
86
- """Returns True if `language` is supported by the backend"""
87
- return language in self.supported_languages()
88
-
89
- @abc.abstractmethod
90
- def _phonemize(self, text, separator):
91
- """The main phonemization method"""
92
-
93
- def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
94
- """Preprocess the text before phonemization
95
-
96
- 1. remove spaces
97
- 2. remove punctuation
98
-
99
- Override this if you need a different behaviour
100
- """
101
- text = text.strip()
102
- if self._keep_puncs:
103
- # a tuple (text, punctuation marks)
104
- return self._punctuator.strip_to_restore(text)
105
- return [self._punctuator.strip(text)], []
106
-
107
- def _phonemize_postprocess(self, phonemized, punctuations) -> str:
108
- """Postprocess the raw phonemized output
109
-
110
- Override this if you need a different behaviour
111
- """
112
- if self._keep_puncs:
113
- return self._punctuator.restore(phonemized, punctuations)[0]
114
- return phonemized[0]
115
-
116
- def phonemize(self, text: str, separator="|", language: str = None) -> str: # pylint: disable=unused-argument
117
- """Returns the `text` phonemized for the given language
118
-
119
- Args:
120
- text (str):
121
- Text to be phonemized.
122
-
123
- separator (str):
124
-            String separator used between phonemes. Defaults to '|'.
125
-
126
- Returns:
127
- (str): Phonemized text
128
- """
129
- text, punctuations = self._phonemize_preprocess(text)
130
- phonemized = []
131
- for t in text:
132
- p = self._phonemize(t, separator)
133
- phonemized.append(p)
134
- phonemized = self._phonemize_postprocess(phonemized, punctuations)
135
- return phonemized
136
-
137
- def print_logs(self, level: int = 0):
138
- indent = "\t" * level
139
- print(f"{indent}| > phoneme language: {self.language}")
140
- print(f"{indent}| > phoneme backend: {self.name()}")
 
 
melo/text/es_phonemizer/cleaner.py DELETED
@@ -1,109 +0,0 @@
1
- """Set of default text cleaners"""
2
- # TODO: pick the cleaner for languages dynamically
3
-
4
- import re
5
-
6
- # Regular expression matching whitespace:
7
- _whitespace_re = re.compile(r"\s+")
8
-
9
- rep_map = {
10
- ":": ",",
11
- ";": ",",
12
- ",": ",",
13
- "。": ".",
14
- "!": "!",
15
- "?": "?",
16
- "\n": ".",
17
- "·": ",",
18
- "、": ",",
19
- "...": ".",
20
- "…": ".",
21
- "$": ".",
22
- "“": "'",
23
- "”": "'",
24
- "‘": "'",
25
- "’": "'",
26
- "(": "'",
27
- ")": "'",
28
- "(": "'",
29
- ")": "'",
30
- "《": "'",
31
- "》": "'",
32
- "【": "'",
33
- "】": "'",
34
- "[": "'",
35
- "]": "'",
36
- "—": "",
37
- "~": "-",
38
- "~": "-",
39
- "「": "'",
40
- "」": "'",
41
- }
42
-
43
- def replace_punctuation(text):
44
- pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
45
- replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
46
- return replaced_text
47
-
48
- def lowercase(text):
49
- return text.lower()
50
-
51
-
52
- def collapse_whitespace(text):
53
- return re.sub(_whitespace_re, " ", text).strip()
54
-
55
- def remove_punctuation_at_begin(text):
56
- return re.sub(r'^[,.!?]+', '', text)
57
-
58
- def remove_aux_symbols(text):
59
- text = re.sub(r"[\<\>\(\)\[\]\"\«\»\']+", "", text)
60
- return text
61
-
62
-
63
- def replace_symbols(text, lang="en"):
64
- """Replace symbols based on the lenguage tag.
65
-
66
- Args:
67
- text:
68
- Input text.
69
- lang:
70
-            Language identifier, e.g. "en", "fr", "pt", "ca".
71
-
72
- Returns:
73
- The modified text
74
- example:
75
- input args:
76
- text: "si l'avi cau, diguem-ho"
77
- lang: "ca"
78
- Output:
79
- text: "si lavi cau, diguemho"
80
- """
81
- text = text.replace(";", ",")
82
- text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
83
- text = text.replace(":", ",")
84
- if lang == "en":
85
- text = text.replace("&", " and ")
86
- elif lang == "fr":
87
- text = text.replace("&", " et ")
88
- elif lang == "pt":
89
- text = text.replace("&", " e ")
90
- elif lang == "ca":
91
- text = text.replace("&", " i ")
92
- text = text.replace("'", "")
93
- elif lang== "es":
94
- text=text.replace("&","y")
95
- text = text.replace("'", "")
96
- return text
97
-
98
- def spanish_cleaners(text):
99
- """Basic pipeline for Portuguese text. There is no need to expand abbreviation and
100
- numbers, phonemizer already does that"""
101
- text = lowercase(text)
102
- text = replace_symbols(text, lang="es")
103
- text = replace_punctuation(text)
104
- text = remove_aux_symbols(text)
105
- text = remove_punctuation_at_begin(text)
106
- text = collapse_whitespace(text)
107
- text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text)
108
- return text
109
-
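
spanish_cleaners in action (sketch):

    from melo.text.es_phonemizer.cleaner import spanish_cleaners

    print(spanish_cleaners("Hola & adiós"))
    # -> "hola y adiós."  ("&" -> "y", lowercased, terminal "." appended)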
 
 
melo/text/es_phonemizer/es_symbols.json DELETED
@@ -1,79 +0,0 @@
1
- {
2
- "symbols": [
3
- "_",
4
- ",",
5
- ".",
6
- "!",
7
- "?",
8
- "-",
9
- "~",
10
- "\u2026",
11
- "N",
12
- "Q",
13
- "a",
14
- "b",
15
- "d",
16
- "e",
17
- "f",
18
- "g",
19
- "h",
20
- "i",
21
- "j",
22
- "k",
23
- "l",
24
- "m",
25
- "n",
26
- "o",
27
- "p",
28
- "s",
29
- "t",
30
- "u",
31
- "v",
32
- "w",
33
- "x",
34
- "y",
35
- "z",
36
- "\u0251",
37
- "\u00e6",
38
- "\u0283",
39
- "\u0291",
40
- "\u00e7",
41
- "\u026f",
42
- "\u026a",
43
- "\u0254",
44
- "\u025b",
45
- "\u0279",
46
- "\u00f0",
47
- "\u0259",
48
- "\u026b",
49
- "\u0265",
50
- "\u0278",
51
- "\u028a",
52
- "\u027e",
53
- "\u0292",
54
- "\u03b8",
55
- "\u03b2",
56
- "\u014b",
57
- "\u0266",
58
- "\u207c",
59
- "\u02b0",
60
- "`",
61
- "^",
62
- "#",
63
- "*",
64
- "=",
65
- "\u02c8",
66
- "\u02cc",
67
- "\u2192",
68
- "\u2193",
69
- "\u2191",
70
- " ",
71
- "\u0263",
72
- "\u0261",
73
- "r",
74
- "\u0272",
75
- "\u029d",
76
- "\u028e",
77
- "\u02d0"
78
- ]
79
- }
 
 
melo/text/es_phonemizer/es_symbols.txt DELETED
@@ -1 +0,0 @@
1
- _,.!?-~…NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ɡrɲʝɣʎː—¿¡
 
 
melo/text/es_phonemizer/es_symbols_v2.json DELETED
@@ -1,83 +0,0 @@
1
- {
2
- "symbols": [
3
- "_",
4
- ",",
5
- ".",
6
- "!",
7
- "?",
8
- "-",
9
- "~",
10
- "\u2026",
11
- "N",
12
- "Q",
13
- "a",
14
- "b",
15
- "d",
16
- "e",
17
- "f",
18
- "g",
19
- "h",
20
- "i",
21
- "j",
22
- "k",
23
- "l",
24
- "m",
25
- "n",
26
- "o",
27
- "p",
28
- "s",
29
- "t",
30
- "u",
31
- "v",
32
- "w",
33
- "x",
34
- "y",
35
- "z",
36
- "\u0251",
37
- "\u00e6",
38
- "\u0283",
39
- "\u0291",
40
- "\u00e7",
41
- "\u026f",
42
- "\u026a",
43
- "\u0254",
44
- "\u025b",
45
- "\u0279",
46
- "\u00f0",
47
- "\u0259",
48
- "\u026b",
49
- "\u0265",
50
- "\u0278",
51
- "\u028a",
52
- "\u027e",
53
- "\u0292",
54
- "\u03b8",
55
- "\u03b2",
56
- "\u014b",
57
- "\u0266",
58
- "\u207c",
59
- "\u02b0",
60
- "`",
61
- "^",
62
- "#",
63
- "*",
64
- "=",
65
- "\u02c8",
66
- "\u02cc",
67
- "\u2192",
68
- "\u2193",
69
- "\u2191",
70
- " ",
71
- "\u0261",
72
- "r",
73
- "\u0272",
74
- "\u029d",
75
- "\u0263",
76
- "\u028e",
77
- "\u02d0",
78
-
79
- "\u2014",
80
- "\u00bf",
81
- "\u00a1"
82
- ]
83
- }
 
 
melo/text/es_phonemizer/es_to_ipa.py DELETED
@@ -1,12 +0,0 @@
1
- from .cleaner import spanish_cleaners
2
- from .gruut_wrapper import Gruut
3
-
4
- def es2ipa(text):
5
- e = Gruut(language="es-es", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
6
- # text = spanish_cleaners(text)
7
- phonemes = e.phonemize(text, separator="")
8
- return phonemes
9
-
10
-
11
- if __name__ == '__main__':
12
- print(es2ipa('¿Y a quién echaría de menos, en el mundo si no fuese a vos?'))
 
 
melo/text/es_phonemizer/gruut_wrapper.py DELETED
@@ -1,253 +0,0 @@
1
- import importlib
2
- from typing import List
3
-
4
- import gruut
5
- from gruut_ipa import IPA # pip install gruut_ipa
6
-
7
- from .base import BasePhonemizer
8
- from .punctuation import Punctuation
9
-
10
- # Table for str.translate to fix gruut/TTS phoneme mismatch
11
- GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")
12
-
13
-
14
- class Gruut(BasePhonemizer):
15
- """Gruut wrapper for G2P
16
-
17
- Args:
18
- language (str):
19
- Valid language code for the used backend.
20
-
21
- punctuations (str):
22
- Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`.
23
-
24
- keep_puncs (bool):
25
- If true, keep the punctuations after phonemization. Defaults to True.
26
-
27
- use_espeak_phonemes (bool):
28
- If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False.
29
-
30
- keep_stress (bool):
31
- If true, keep the stress characters after phonemization. Defaults to False.
32
-
33
- Example:
34
-
35
- >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
36
- >>> phonemizer = Gruut('en-us')
37
- >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|")
38
- 'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?'
39
- """
40
-
41
- def __init__(
42
- self,
43
- language: str,
44
- punctuations=Punctuation.default_puncs(),
45
- keep_puncs=True,
46
- use_espeak_phonemes=False,
47
- keep_stress=False,
48
- ):
49
- super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
50
- self.use_espeak_phonemes = use_espeak_phonemes
51
- self.keep_stress = keep_stress
52
-
53
- @staticmethod
54
- def name():
55
- return "gruut"
56
-
57
- def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: # pylint: disable=unused-argument
58
- """Convert input text to phonemes.
59
-
60
-        Gruut phonemizes the given `str` by separating each phoneme character with `separator`, even for characters
61
-        that constitute a single sound.
62
-
63
- It doesn't affect 🐸TTS since it individually converts each character to token IDs.
64
-
65
- Examples::
66
- "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ`
67
-
68
- Args:
69
- text (str):
70
- Text to be converted to phonemes.
71
-
72
- tie (bool, optional) : When True use a '͡' character between
73
- consecutive characters of a single phoneme. Else separate phoneme
74
- with '_'. This option requires espeak>=1.49. Default to False.
75
- """
76
- ph_list = []
77
- for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes):
78
- for word in sentence:
79
- if word.is_break:
80
- # Use actual character for break phoneme (e.g., comma)
81
- if ph_list:
82
- # Join with previous word
83
- ph_list[-1].append(word.text)
84
- else:
85
- # First word is punctuation
86
- ph_list.append([word.text])
87
- elif word.phonemes:
88
- # Add phonemes for word
89
- word_phonemes = []
90
-
91
- for word_phoneme in word.phonemes:
92
- if not self.keep_stress:
93
- # Remove primary/secondary stress
94
- word_phoneme = IPA.without_stress(word_phoneme)
95
-
96
- word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
97
-
98
- if word_phoneme:
99
- # Flatten phonemes
100
- word_phonemes.extend(word_phoneme)
101
-
102
- if word_phonemes:
103
- ph_list.append(word_phonemes)
104
-
105
- ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list]
106
- ph = f"{separator} ".join(ph_words)
107
- return ph
108
-
109
- def _phonemize(self, text, separator):
110
- return self.phonemize_gruut(text, separator, tie=False)
111
-
112
- def is_supported_language(self, language):
113
- """Returns True if `language` is supported by the backend"""
114
- return gruut.is_language_supported(language)
115
-
116
- @staticmethod
117
- def supported_languages() -> List:
118
- """Get a dictionary of supported languages.
119
-
120
- Returns:
121
- List: List of language codes.
122
- """
123
- return list(gruut.get_supported_languages())
124
-
125
- def version(self):
126
- """Get the version of the used backend.
127
-
128
- Returns:
129
- str: Version of the used backend.
130
- """
131
- return gruut.__version__
132
-
133
- @classmethod
134
- def is_available(cls):
135
- """Return true if ESpeak is available else false"""
136
- return importlib.util.find_spec("gruut") is not None
137
-
138
-
139
- if __name__ == "__main__":
140
- from es_to_ipa import es2ipa
141
- import json
142
-
143
- e = Gruut(language="es-es", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
144
- symbols = [
145
- "_",
146
- ",",
147
- ".",
148
- "!",
149
- "?",
150
- "-",
151
- "~",
152
- "\u2026",
153
- "N",
154
- "Q",
155
- "a",
156
- "b",
157
- "d",
158
- "e",
159
- "f",
160
- "g",
161
- "h",
162
- "i",
163
- "j",
164
- "k",
165
- "l",
166
- "m",
167
- "n",
168
- "o",
169
- "p",
170
- "s",
171
- "t",
172
- "u",
173
- "v",
174
- "w",
175
- "x",
176
- "y",
177
- "z",
178
- "\u0251",
179
- "\u00e6",
180
- "\u0283",
181
- "\u0291",
182
- "\u00e7",
183
- "\u026f",
184
- "\u026a",
185
- "\u0254",
186
- "\u025b",
187
- "\u0279",
188
- "\u00f0",
189
- "\u0259",
190
- "\u026b",
191
- "\u0265",
192
- "\u0278",
193
- "\u028a",
194
- "\u027e",
195
- "\u0292",
196
- "\u03b8",
197
- "\u03b2",
198
- "\u014b",
199
- "\u0266",
200
- "\u207c",
201
- "\u02b0",
202
- "`",
203
- "^",
204
- "#",
205
- "*",
206
- "=",
207
- "\u02c8",
208
- "\u02cc",
209
- "\u2192",
210
- "\u2193",
211
- "\u2191",
212
- " ",
213
- ]
214
- with open('./text/es_phonemizer/spanish_text.txt', 'r') as f:
215
- lines = f.readlines()
216
-
217
-
218
- used_sym = []
219
- not_existed_sym = []
220
- phonemes = []
221
-
222
- for line in lines[:400]:
223
- text = line.split('|')[-1].strip()
224
- ipa = es2ipa(text)
225
- phonemes.append(ipa + '\n')
226
- for s in ipa:
227
- if s not in symbols:
228
- if s not in not_existed_sym:
229
- print(f'not_existed char: {s}')
230
- not_existed_sym.append(s)
231
- else:
232
- if s not in used_sym:
233
- # print(f'used char: {s}')
234
- used_sym.append(s)
235
-
236
- print(used_sym)
237
- print(not_existed_sym)
238
-
239
-
240
- with open('./text/es_phonemizer/es_symbols.txt', 'w') as g:
241
- g.writelines(symbols + not_existed_sym)
242
-
243
- with open('./text/es_phonemizer/example_ipa.txt', 'w') as g:
244
- g.writelines(phonemes)
245
-
246
- data = {'symbols': symbols + not_existed_sym}
247
- with open('./text/es_phonemizer/es_symbols_v2.json', 'w') as f:
248
- json.dump(data, f, indent=4)
249
-
250
-
251
-
252
-
253
-
 
 
melo/text/es_phonemizer/punctuation.py DELETED
@@ -1,174 +0,0 @@
1
- import collections
2
- import re
3
- from enum import Enum
4
-
5
- import six
6
-
7
- _DEF_PUNCS = ';:,.!?¡¿—…"«»“”'
8
-
9
- _PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"])
10
-
11
-
12
- class PuncPosition(Enum):
13
- """Enum for the punctuations positions"""
14
-
15
- BEGIN = 0
16
- END = 1
17
- MIDDLE = 2
18
- ALONE = 3
19
-
20
-
21
- class Punctuation:
22
- """Handle punctuations in text.
23
-
24
- Just strip punctuations from text or strip and restore them later.
25
-
26
- Args:
27
- puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`.
28
-
29
- Example:
30
- >>> punc = Punctuation()
31
- >>> punc.strip("This is. example !")
32
- 'This is example'
33
-
34
- >>> text_striped, punc_map = punc.strip_to_restore("This is. example !")
35
- >>> ' '.join(text_striped)
36
- 'This is example'
37
-
38
- >>> text_restored = punc.restore(text_striped, punc_map)
39
- >>> text_restored[0]
40
- 'This is. example !'
41
- """
42
-
43
- def __init__(self, puncs: str = _DEF_PUNCS):
44
- self.puncs = puncs
45
-
46
- @staticmethod
47
- def default_puncs():
48
- """Return default set of punctuations."""
49
- return _DEF_PUNCS
50
-
51
- @property
52
- def puncs(self):
53
- return self._puncs
54
-
55
- @puncs.setter
56
- def puncs(self, value):
57
- if not isinstance(value, six.string_types):
58
- raise ValueError("[!] Punctuations must be of type str.")
59
- self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder
60
- self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
61
-
62
- def strip(self, text):
63
- """Remove all the punctuations by replacing with `space`.
64
-
65
- Args:
66
- text (str): The text to be processed.
67
-
68
- Example::
69
-
70
- "This is. example !" -> "This is example "
71
- """
72
- return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip()
73
-
74
- def strip_to_restore(self, text):
75
- """Remove punctuations from text to restore them later.
76
-
77
- Args:
78
- text (str): The text to be processed.
79
-
80
- Examples ::
81
-
82
- "This is. example !" -> [["This is", "example"], [".", "!"]]
83
-
84
- """
85
- text, puncs = self._strip_to_restore(text)
86
- return text, puncs
87
-
88
- def _strip_to_restore(self, text):
89
- """Auxiliary method for Punctuation.preserve()"""
90
- matches = list(re.finditer(self.puncs_regular_exp, text))
91
- if not matches:
92
- return [text], []
93
- # the text is only punctuations
94
- if len(matches) == 1 and matches[0].group() == text:
95
- return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
96
- # build a punctuation map to be used later to restore punctuations
97
- puncs = []
98
- for match in matches:
99
- position = PuncPosition.MIDDLE
100
- if match == matches[0] and text.startswith(match.group()):
101
- position = PuncPosition.BEGIN
102
- elif match == matches[-1] and text.endswith(match.group()):
103
- position = PuncPosition.END
104
- puncs.append(_PUNC_IDX(match.group(), position))
105
- # convert str text to a List[str], each item is separated by a punctuation
106
- splitted_text = []
107
- for idx, punc in enumerate(puncs):
108
- split = text.split(punc.punc)
109
- prefix, suffix = split[0], punc.punc.join(split[1:])
110
- splitted_text.append(prefix)
111
- # if the text does not end with a punctuation, add it to the last item
112
- if idx == len(puncs) - 1 and len(suffix) > 0:
113
- splitted_text.append(suffix)
114
- text = suffix
115
- while splitted_text[0] == '':
116
- splitted_text = splitted_text[1:]
117
- return splitted_text, puncs
118
-
119
- @classmethod
120
- def restore(cls, text, puncs):
121
- """Restore punctuation in a text.
122
-
123
- Args:
124
- text (str): The text to be processed.
125
- puncs (List[str]): The list of punctuations map to be used for restoring.
126
-
127
- Examples ::
128
-
129
- ['This is', 'example'], ['.', '!'] -> "This is. example!"
130
-
131
- """
132
- return cls._restore(text, puncs, 0)
133
-
134
- @classmethod
135
- def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements
136
- """Auxiliary method for Punctuation.restore()"""
137
- if not puncs:
138
- return text
139
-
140
- # nothing has been phonemized, return the puncs alone
141
- if not text:
142
- return ["".join(m.punc for m in puncs)]
143
-
144
- current = puncs[0]
145
-
146
- if current.position == PuncPosition.BEGIN:
147
- return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
148
-
149
- if current.position == PuncPosition.END:
150
- return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)
151
-
152
- if current.position == PuncPosition.ALONE:
153
- return [current.punc] + cls._restore(text, puncs[1:], num + 1)
154
-
155
- # POSITION == MIDDLE
156
- if len(text) == 1: # pragma: nocover
157
- # a corner case where the final part of an intermediate
158
- # mark (I) has not been phonemized
159
- return cls._restore([text[0] + current.punc], puncs[1:], num)
160
-
161
- return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
162
-
163
-
164
- # if __name__ == "__main__":
165
- # punc = Punctuation()
166
- # text = "This is. This is, example!"
167
-
168
- # print(punc.strip(text))
169
-
170
- # split_text, puncs = punc.strip_to_restore(text)
171
- # print(split_text, " ---- ", puncs)
172
-
173
- # restored_text = punc.restore(split_text, puncs)
174
- # print(restored_text)
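
The class docstring above already sketches the API; here is a runnable version of that example, assuming the pre-port import path (this module is deleted by this commit):

```python
# Strip punctuation, or strip it and restore it later, mirroring the
# docstring example of the deleted Punctuation class.
from melo.text.es_phonemizer.punctuation import Punctuation

punc = Punctuation()
print(punc.strip("This is. example !"))          # 'This is example'

text_striped, punc_map = punc.strip_to_restore("This is. example !")
print(' '.join(text_striped))                    # 'This is example'
print(punc.restore(text_striped, punc_map)[0])   # 'This is. example !'
```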
 
melo/text/es_phonemizer/spanish_symbols.txt DELETED
@@ -1 +0,0 @@
1
- dˌaβˈiðkopeɾfjl unθsbmtʃwɛxɪŋʊɣɡrɲʝʎː
 
 
melo/text/es_phonemizer/test.ipynb DELETED
@@ -1,124 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {},
7
- "outputs": [
8
- {
9
- "ename": "ImportError",
10
- "evalue": "attempted relative import with no known parent package",
11
- "output_type": "error",
12
- "traceback": [
13
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
14
- "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
15
- "\u001b[1;32m/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb Cell 1\u001b[0m line \u001b[0;36m5\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mos\u001b[39;00m\u001b[39m,\u001b[39m \u001b[39msys\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a>\u001b[0m sys\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mappend(\u001b[39m'\u001b[39m\u001b[39m/home/xumin/workspace/MyShell-VC-Training/text/es_phonemizer/\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mes_to_ipa\u001b[39;00m \u001b[39mimport\u001b[39;00m es2ipa\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=8'>9</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39msplit_sentences_en\u001b[39m(text, min_len\u001b[39m=\u001b[39m\u001b[39m10\u001b[39m):\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=9'>10</a>\u001b[0m \u001b[39m# 将文本中的换行符、空格和制表符替换为空格\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=10'>11</a>\u001b[0m text \u001b[39m=\u001b[39m re\u001b[39m.\u001b[39msub(\u001b[39m'\u001b[39m\u001b[39m[\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m ]+\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39m'\u001b[39m, text)\n",
16
- "File \u001b[0;32m/data/workspace/Bert-VITS2/text/es_phonemizer/es_to_ipa.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m\u001b[39mcleaner\u001b[39;00m \u001b[39mimport\u001b[39;00m spanish_cleaners\n\u001b[1;32m 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m\u001b[39mgruut_wrapper\u001b[39;00m \u001b[39mimport\u001b[39;00m Gruut\n\u001b[1;32m 4\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mes2ipa\u001b[39m(text):\n",
17
- "\u001b[0;31mImportError\u001b[0m: attempted relative import with no known parent package"
18
- ]
19
- }
20
- ],
21
- "source": [
22
- "import re\n",
23
- "import os\n",
24
- "import os, sys\n",
25
- "sys.path.append('/home/xumin/workspace/MyShell-VC-Training/text/es_phonemizer/')\n",
26
- "from es_to_ipa import es2ipa\n",
27
- "\n",
28
- "\n",
29
- "\n",
30
- "def split_sentences_en(text, min_len=10):\n",
31
- "    # Replace newlines, tabs and spaces in the text with a single space\n",
32
- " text = re.sub('[\\n\\t ]+', ' ', text)\n",
33
- "    # Insert a split marker after the marks ¿ — ¡\n",
34
- " text = re.sub('([¿—¡])', r'\\1 $#!', text)\n",
35
- "    # Split into sentences and strip leading/trailing spaces\n",
36
- " \n",
37
- " sentences = [s.strip() for s in text.split(' $#!')]\n",
38
- " if len(sentences[-1]) == 0: del sentences[-1]\n",
39
- "\n",
40
- " new_sentences = []\n",
41
- " new_sent = []\n",
42
- " for ind, sent in enumerate(sentences):\n",
43
- " if sent in ['¿', '—', '¡']:\n",
44
- " new_sent.append(sent)\n",
45
- " else:\n",
46
- " new_sent.append(es2ipa(sent))\n",
47
- " \n",
48
- " \n",
49
- " new_sentences = ''.join(new_sent)\n",
50
- "\n",
51
- " return new_sentences"
52
- ]
53
- },
54
- {
55
- "cell_type": "code",
56
- "execution_count": 3,
57
- "metadata": {},
58
- "outputs": [
59
- {
60
- "data": {
61
- "text/plain": [
62
- "'—¿aβˈeis estˈaðo kasˈaða alɣˈuna bˈeθ?'"
63
- ]
64
- },
65
- "execution_count": 3,
66
- "metadata": {},
67
- "output_type": "execute_result"
68
- }
69
- ],
70
- "source": [
71
- "split_sentences_en('—¿Habéis estado casada alguna vez?')"
72
- ]
73
- },
74
- {
75
- "cell_type": "code",
76
- "execution_count": 4,
77
- "metadata": {},
78
- "outputs": [
79
- {
80
- "data": {
81
- "text/plain": [
82
- "'aβˈeis estˈaðo kasˈaða alɣˈuna bˈeθ?'"
83
- ]
84
- },
85
- "execution_count": 4,
86
- "metadata": {},
87
- "output_type": "execute_result"
88
- }
89
- ],
90
- "source": [
91
- "es2ipa('—¿Habéis estado casada alguna vez?')"
92
- ]
93
- },
94
- {
95
- "cell_type": "code",
96
- "execution_count": null,
97
- "metadata": {},
98
- "outputs": [],
99
- "source": []
100
- }
101
- ],
102
- "metadata": {
103
- "kernelspec": {
104
- "display_name": "base",
105
- "language": "python",
106
- "name": "python3"
107
- },
108
- "language_info": {
109
- "codemirror_mode": {
110
- "name": "ipython",
111
- "version": 3
112
- },
113
- "file_extension": ".py",
114
- "mimetype": "text/x-python",
115
- "name": "python",
116
- "nbconvert_exporter": "python",
117
- "pygments_lexer": "ipython3",
118
- "version": "3.8.18"
119
- },
120
- "orig_nbformat": 4
121
- },
122
- "nbformat": 4,
123
- "nbformat_minor": 2
124
- }
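
The ImportError in the first cell is the usual relative-import failure when a package module is loaded as a top-level script; importing through the package avoids it, and the later cells confirm the expected output:

```python
# es_to_ipa uses relative imports (`from .cleaner import ...`), so it must
# be imported through its package rather than via a bare sys.path entry.
from melo.text.es_phonemizer.es_to_ipa import es2ipa  # pre-port layout

print(es2ipa('—¿Habéis estado casada alguna vez?'))
# 'aβˈeis estˈaðo kasˈaða alɣˈuna bˈeθ?'  (matches the notebook output)
```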
 
melo/text/fr_phonemizer/__init__.py DELETED
File without changes
melo/text/fr_phonemizer/base.py DELETED
@@ -1,140 +0,0 @@
1
- import abc
2
- from typing import List, Tuple
3
-
4
- from .punctuation import Punctuation
5
-
6
-
7
- class BasePhonemizer(abc.ABC):
8
- """Base phonemizer class
9
-
10
- Phonemization follows the following steps:
11
- 1. Preprocessing:
12
- - remove empty lines
13
- - remove punctuation
14
- - keep track of punctuation marks
15
-
16
- 2. Phonemization:
17
- - convert text to phonemes
18
-
19
- 3. Postprocessing:
20
- - join phonemes
21
- - restore punctuation marks
22
-
23
- Args:
24
- language (str):
25
- Language used by the phonemizer.
26
-
27
- punctuations (List[str]):
28
- List of punctuation marks to be preserved.
29
-
30
- keep_puncs (bool):
31
- Whether to preserve punctuation marks or not.
32
- """
33
-
34
- def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
35
- # ensure the backend is installed on the system
36
- if not self.is_available():
37
- raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover
38
-
39
- # ensure the backend support the requested language
40
- self._language = self._init_language(language)
41
-
42
- # setup punctuation processing
43
- self._keep_puncs = keep_puncs
44
- self._punctuator = Punctuation(punctuations)
45
-
46
- def _init_language(self, language):
47
- """Language initialization
48
-
49
- This method may be overloaded in child classes (see Segments backend)
50
-
51
- """
52
- if not self.is_supported_language(language):
53
- raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend")
54
- return language
55
-
56
- @property
57
- def language(self):
58
- """The language code configured to be used for phonemization"""
59
- return self._language
60
-
61
- @staticmethod
62
- @abc.abstractmethod
63
- def name():
64
- """The name of the backend"""
65
- ...
66
-
67
- @classmethod
68
- @abc.abstractmethod
69
- def is_available(cls):
70
- """Returns True if the backend is installed, False otherwise"""
71
- ...
72
-
73
- @classmethod
74
- @abc.abstractmethod
75
- def version(cls):
76
- """Return the backend version as a tuple (major, minor, patch)"""
77
- ...
78
-
79
- @staticmethod
80
- @abc.abstractmethod
81
- def supported_languages():
82
- """Return a dict of language codes -> name supported by the backend"""
83
- ...
84
-
85
- def is_supported_language(self, language):
86
- """Returns True if `language` is supported by the backend"""
87
- return language in self.supported_languages()
88
-
89
- @abc.abstractmethod
90
- def _phonemize(self, text, separator):
91
- """The main phonemization method"""
92
-
93
- def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
94
- """Preprocess the text before phonemization
95
-
96
- 1. remove spaces
97
- 2. remove punctuation
98
-
99
- Override this if you need a different behaviour
100
- """
101
- text = text.strip()
102
- if self._keep_puncs:
103
- # a tuple (text, punctuation marks)
104
- return self._punctuator.strip_to_restore(text)
105
- return [self._punctuator.strip(text)], []
106
-
107
- def _phonemize_postprocess(self, phonemized, punctuations) -> str:
108
- """Postprocess the raw phonemized output
109
-
110
- Override this if you need a different behaviour
111
- """
112
- if self._keep_puncs:
113
- return self._punctuator.restore(phonemized, punctuations)[0]
114
- return phonemized[0]
115
-
116
- def phonemize(self, text: str, separator="|", language: str = None) -> str: # pylint: disable=unused-argument
117
- """Returns the `text` phonemized for the given language
118
-
119
- Args:
120
- text (str):
121
- Text to be phonemized.
122
-
123
- separator (str):
124
- string separator used between phonemes. Defaults to '|'.
125
-
126
- Returns:
127
- (str): Phonemized text
128
- """
129
- text, punctuations = self._phonemize_preprocess(text)
130
- phonemized = []
131
- for t in text:
132
- p = self._phonemize(t, separator)
133
- phonemized.append(p)
134
- phonemized = self._phonemize_postprocess(phonemized, punctuations)
135
- return phonemized
136
-
137
- def print_logs(self, level: int = 0):
138
- indent = "\t" * level
139
- print(f"{indent}| > phoneme language: {self.language}")
140
- print(f"{indent}| > phoneme backend: {self.name()}")
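
BasePhonemizer is abstract; a toy concrete subclass shows the contract the Gruut wrapper below fills in. The uppercase "G2P" is purely illustrative, not a real backend:

```python
# Minimal sketch of a concrete BasePhonemizer subclass (pre-port layout).
from melo.text.fr_phonemizer.base import BasePhonemizer

class UpperPhonemizer(BasePhonemizer):
    @staticmethod
    def name():
        return "upper"

    @classmethod
    def is_available(cls):
        return True

    @classmethod
    def version(cls):
        return (0, 0, 1)

    @staticmethod
    def supported_languages():
        return {"en": "English"}

    def _phonemize(self, text, separator):
        # toy stand-in for grapheme-to-phoneme conversion
        return separator.join(text.upper())

p = UpperPhonemizer("en", keep_puncs=True)
print(p.phonemize("Be a voice, not an echo!", separator=""))
# 'BE A VOICE, NOT AN ECHO!'  (punctuation stripped, then restored)
```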
 
melo/text/fr_phonemizer/cleaner.py DELETED
@@ -1,122 +0,0 @@
1
- """Set of default text cleaners"""
2
- # TODO: pick the cleaner for languages dynamically
3
-
4
- import re
5
- from .french_abbreviations import abbreviations_fr
6
-
7
- # Regular expression matching whitespace:
8
- _whitespace_re = re.compile(r"\s+")
9
-
10
-
11
- rep_map = {
12
- ":": ",",
13
- ";": ",",
14
- ",": ",",
15
- "。": ".",
16
- "!": "!",
17
- "?": "?",
18
- "\n": ".",
19
- "·": ",",
20
- "、": ",",
21
- "...": ".",
22
- "…": ".",
23
- "$": ".",
24
- "“": "",
25
- "”": "",
26
- "‘": "",
27
- "’": "",
28
- "(": "",
29
- ")": "",
30
- "(": "",
31
- ")": "",
32
- "《": "",
33
- "》": "",
34
- "【": "",
35
- "】": "",
36
- "[": "",
37
- "]": "",
38
- "—": "",
39
- "~": "-",
40
- "~": "-",
41
- "「": "",
42
- "」": "",
43
- "¿" : "",
44
- "¡" : ""
45
- }
46
-
47
-
48
- def replace_punctuation(text):
49
- pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
50
- replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
51
- return replaced_text
52
-
53
- def expand_abbreviations(text, lang="fr"):
54
- if lang == "fr":
55
- _abbreviations = abbreviations_fr
56
- for regex, replacement in _abbreviations:
57
- text = re.sub(regex, replacement, text)
58
- return text
59
-
60
-
61
- def lowercase(text):
62
- return text.lower()
63
-
64
-
65
- def collapse_whitespace(text):
66
- return re.sub(_whitespace_re, " ", text).strip()
67
-
68
- def remove_punctuation_at_begin(text):
69
- return re.sub(r'^[,.!?]+', '', text)
70
-
71
- def remove_aux_symbols(text):
72
- text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
73
- return text
74
-
75
-
76
- def replace_symbols(text, lang="en"):
77
- """Replace symbols based on the language tag.
78
-
79
- Args:
80
- text:
81
- Input text.
82
- lang:
83
- Language identifier. ex: "en", "fr", "pt", "ca".
84
-
85
- Returns:
86
- The modified text
87
- example:
88
- input args:
89
- text: "si l'avi cau, diguem-ho"
90
- lang: "ca"
91
- Output:
92
- text: "si lavi cau, diguemho"
93
- """
94
- text = text.replace(";", ",")
95
- text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
96
- text = text.replace(":", ",")
97
- if lang == "en":
98
- text = text.replace("&", " and ")
99
- elif lang == "fr":
100
- text = text.replace("&", " et ")
101
- elif lang == "pt":
102
- text = text.replace("&", " e ")
103
- elif lang == "ca":
104
- text = text.replace("&", " i ")
105
- text = text.replace("'", "")
106
- elif lang== "es":
107
- text=text.replace("&","y")
108
- text = text.replace("'", "")
109
- return text
110
-
111
- def french_cleaners(text):
112
- """Pipeline for French text. There is no need to expand numbers; the phonemizer already does that."""
113
- text = expand_abbreviations(text, lang="fr")
114
- # text = lowercase(text) # as we use the cased bert
115
- text = replace_punctuation(text)
116
- text = replace_symbols(text, lang="fr")
117
- text = remove_aux_symbols(text)
118
- text = remove_punctuation_at_begin(text)
119
- text = collapse_whitespace(text)
120
- text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text)
121
- return text
122
-
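
A quick spot-check of the pipeline, assuming the pre-port import path; the sample sentence is illustrative:

```python
# Abbreviation expansion, punctuation normalisation, symbol cleanup and
# final-period insertion in one call. Note replace_symbols turns hyphens
# into spaces for French, so "rendez-vous" comes out as "rendez vous".
from melo.text.fr_phonemizer.cleaner import french_cleaners

print(french_cleaners("M. Dupont a rdv. avec le Dr. Martin à 10h"))
# expected: 'monsieur Dupont a rendez vous avec le docteur Martin à 10h.'
```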
 
melo/text/fr_phonemizer/en_symbols.json DELETED
@@ -1,78 +0,0 @@
1
- {"symbols": [
2
- "_",
3
- ",",
4
- ".",
5
- "!",
6
- "?",
7
- "-",
8
- "~",
9
- "\u2026",
10
- "N",
11
- "Q",
12
- "a",
13
- "b",
14
- "d",
15
- "e",
16
- "f",
17
- "g",
18
- "h",
19
- "i",
20
- "j",
21
- "k",
22
- "l",
23
- "m",
24
- "n",
25
- "o",
26
- "p",
27
- "s",
28
- "t",
29
- "u",
30
- "v",
31
- "w",
32
- "x",
33
- "y",
34
- "z",
35
- "\u0251",
36
- "\u00e6",
37
- "\u0283",
38
- "\u0291",
39
- "\u00e7",
40
- "\u026f",
41
- "\u026a",
42
- "\u0254",
43
- "\u025b",
44
- "\u0279",
45
- "\u00f0",
46
- "\u0259",
47
- "\u026b",
48
- "\u0265",
49
- "\u0278",
50
- "\u028a",
51
- "\u027e",
52
- "\u0292",
53
- "\u03b8",
54
- "\u03b2",
55
- "\u014b",
56
- "\u0266",
57
- "\u207c",
58
- "\u02b0",
59
- "`",
60
- "^",
61
- "#",
62
- "*",
63
- "=",
64
- "\u02c8",
65
- "\u02cc",
66
- "\u2192",
67
- "\u2193",
68
- "\u2191",
69
- " ",
70
- "ɣ",
71
- "ɡ",
72
- "r",
73
- "ɲ",
74
- "ʝ",
75
- "ʎ",
76
- "ː"
77
- ]
78
- }
 
melo/text/fr_phonemizer/fr_symbols.json DELETED
@@ -1,89 +0,0 @@
1
- {
2
- "symbols": [
3
- "_",
4
- ",",
5
- ".",
6
- "!",
7
- "?",
8
- "-",
9
- "~",
10
- "\u2026",
11
- "N",
12
- "Q",
13
- "a",
14
- "b",
15
- "d",
16
- "e",
17
- "f",
18
- "g",
19
- "h",
20
- "i",
21
- "j",
22
- "k",
23
- "l",
24
- "m",
25
- "n",
26
- "o",
27
- "p",
28
- "s",
29
- "t",
30
- "u",
31
- "v",
32
- "w",
33
- "x",
34
- "y",
35
- "z",
36
- "\u0251",
37
- "\u00e6",
38
- "\u0283",
39
- "\u0291",
40
- "\u00e7",
41
- "\u026f",
42
- "\u026a",
43
- "\u0254",
44
- "\u025b",
45
- "\u0279",
46
- "\u00f0",
47
- "\u0259",
48
- "\u026b",
49
- "\u0265",
50
- "\u0278",
51
- "\u028a",
52
- "\u027e",
53
- "\u0292",
54
- "\u03b8",
55
- "\u03b2",
56
- "\u014b",
57
- "\u0266",
58
- "\u207c",
59
- "\u02b0",
60
- "`",
61
- "^",
62
- "#",
63
- "*",
64
- "=",
65
- "\u02c8",
66
- "\u02cc",
67
- "\u2192",
68
- "\u2193",
69
- "\u2191",
70
- " ",
71
- "\u0263",
72
- "\u0261",
73
- "r",
74
- "\u0272",
75
- "\u029d",
76
- "\u028e",
77
- "\u02d0",
78
-
79
- "\u0303",
80
- "\u0153",
81
- "\u00f8",
82
- "\u0281",
83
- "\u0252",
84
- "\u028c",
85
- "\u2014",
86
- "\u025c",
87
- "\u0250"
88
- ]
89
- }
 
melo/text/fr_phonemizer/fr_to_ipa.py DELETED
@@ -1,30 +0,0 @@
1
- from .cleaner import french_cleaners
2
- from .gruut_wrapper import Gruut
3
-
4
-
5
- def remove_consecutive_t(input_str):
6
- result = []
7
- count = 0
8
-
9
- for char in input_str:
10
- if char == 't':
11
- count += 1
12
- else:
13
- if count < 3:
14
- result.extend(['t'] * count)
15
- count = 0
16
- result.append(char)
17
-
18
- if count < 3:
19
- result.extend(['t'] * count)
20
-
21
- return ''.join(result)
22
-
23
- def fr2ipa(text):
24
- e = Gruut(language="fr-fr", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
25
- # text = french_cleaners(text)
26
- phonemes = e.phonemize(text, separator="")
27
- # print(phonemes)
28
- phonemes = remove_consecutive_t(phonemes)
29
- # print(phonemes)
30
- return phonemes
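
Note that `remove_consecutive_t` above only resets `count` inside the `if count < 3` branch, so once a run of three or more 't's is skipped, every later 't' is silently dropped as well. A corrected standalone version:

```python
# Drop runs of three or more consecutive 't's, keep shorter runs.
# Unlike the original, `count` is reset unconditionally after each run.
def remove_consecutive_t(input_str: str) -> str:
    result = []
    count = 0
    for char in input_str:
        if char == 't':
            count += 1
            continue
        if count < 3:
            result.extend(['t'] * count)
        count = 0                 # reset even after a long run
        result.append(char)
    if count < 3:
        result.extend(['t'] * count)
    return ''.join(result)

assert remove_consecutive_t("attente") == "attente"
assert remove_consecutive_t("tttt a t") == " a t"   # original drops the last 't'
```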
 
melo/text/fr_phonemizer/french_abbreviations.py DELETED
@@ -1,48 +0,0 @@
1
- import re
2
-
3
- # List of (regular expression, replacement) pairs for abbreviations in french:
4
- abbreviations_fr = [
5
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
6
- for x in [
7
- ("M", "monsieur"),
8
- ("Mlle", "mademoiselle"),
9
- ("Mlles", "mesdemoiselles"),
10
- ("Mme", "Madame"),
11
- ("Mmes", "Mesdames"),
12
- ("N.B", "nota bene"),
13
- ("M", "monsieur"),
14
- ("p.c.q", "parce que"),
15
- ("Pr", "professeur"),
16
- ("qqch", "quelque chose"),
17
- ("rdv", "rendez-vous"),
18
- ("max", "maximum"),
19
- ("min", "minimum"),
20
- ("no", "numéro"),
21
- ("adr", "adresse"),
22
- ("dr", "docteur"),
23
- ("st", "saint"),
24
- ("co", "compagnie"),
25
- ("jr", "junior"),
26
- ("sgt", "sergent"),
27
- ("capt", "capitaine"),
28
- ("col", "colonel"),
29
- ("av", "avenue"),
30
- ("av. J.-C", "avant Jésus-Christ"),
31
- ("apr. J.-C", "après Jésus-Christ"),
32
- ("art", "article"),
33
- ("boul", "boulevard"),
34
- ("c.-à-d", "c’est-à-dire"),
35
- ("etc", "et cetera"),
36
- ("ex", "exemple"),
37
- ("excl", "exclusivement"),
38
- ("boul", "boulevard"),
39
- ]
40
- ] + [
41
- (re.compile("\\b%s" % x[0]), x[1])
42
- for x in [
43
- ("Mlle", "mademoiselle"),
44
- ("Mlles", "mesdemoiselles"),
45
- ("Mme", "Madame"),
46
- ("Mmes", "Mesdames"),
47
- ]
48
- ]
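
The table entries are (compiled regex, replacement) pairs, so applying them is a simple fold; this mirrors expand_abbreviations in the deleted cleaner.py (pre-port import path):

```python
# Expand French abbreviations using the deleted lookup table.
from melo.text.fr_phonemizer.french_abbreviations import abbreviations_fr

def expand(text: str) -> str:
    for regex, replacement in abbreviations_fr:
        text = regex.sub(replacement, text)
    return text

print(expand("Mme Leclerc et M. Petit ont rdv. boul. Saint-Michel"))
# expected: 'Madame Leclerc et monsieur Petit ont rendez-vous boulevard Saint-Michel'
```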
 
melo/text/fr_phonemizer/french_symbols.txt DELETED
@@ -1 +0,0 @@
1
- _,.!?-~…NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ɣɡrɲʝʎː̃œøʁɒʌ—ɜɐ
 
 
melo/text/fr_phonemizer/gruut_wrapper.py DELETED
@@ -1,258 +0,0 @@
1
- import importlib
2
- from typing import List
3
-
4
- import gruut
5
- from gruut_ipa import IPA # pip install gruut_ipa
6
-
7
- from .base import BasePhonemizer
8
- from .punctuation import Punctuation
9
-
10
- # Table for str.translate to fix gruut/TTS phoneme mismatch
11
- GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")
12
-
13
-
14
- class Gruut(BasePhonemizer):
15
- """Gruut wrapper for G2P
16
-
17
- Args:
18
- language (str):
19
- Valid language code for the used backend.
20
-
21
- punctuations (str):
22
- Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`.
23
-
24
- keep_puncs (bool):
25
- If true, keep the punctuations after phonemization. Defaults to True.
26
-
27
- use_espeak_phonemes (bool):
28
- If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False.
29
-
30
- keep_stress (bool):
31
- If true, keep the stress characters after phonemization. Defaults to False.
32
-
33
- Example:
34
-
35
- >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
36
- >>> phonemizer = Gruut('en-us')
37
- >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|")
38
- 'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?'
39
- """
40
-
41
- def __init__(
42
- self,
43
- language: str,
44
- punctuations=Punctuation.default_puncs(),
45
- keep_puncs=True,
46
- use_espeak_phonemes=False,
47
- keep_stress=False,
48
- ):
49
- super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
50
- self.use_espeak_phonemes = use_espeak_phonemes
51
- self.keep_stress = keep_stress
52
-
53
- @staticmethod
54
- def name():
55
- return "gruut"
56
-
57
- def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: # pylint: disable=unused-argument
58
- """Convert input text to phonemes.
59
-
60
- Gruut phonemizes the given `str` by separating each phoneme character with `separator`, even for characters
61
- that constitute a single sound.
62
-
63
- It doesn't affect 🐸TTS since it individually converts each character to token IDs.
64
-
65
- Examples::
66
- "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ`
67
-
68
- Args:
69
- text (str):
70
- Text to be converted to phonemes.
71
-
72
- tie (bool, optional) : When True use a '͡' character between
73
- consecutive characters of a single phoneme. Else separate phoneme
74
- with '_'. This option requires espeak>=1.49. Default to False.
75
- """
76
- ph_list = []
77
- for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes):
78
- for word in sentence:
79
- if word.is_break:
80
- # Use actual character for break phoneme (e.g., comma)
81
- if ph_list:
82
- # Join with previous word
83
- ph_list[-1].append(word.text)
84
- else:
85
- # First word is punctuation
86
- ph_list.append([word.text])
87
- elif word.phonemes:
88
- # Add phonemes for word
89
- word_phonemes = []
90
-
91
- for word_phoneme in word.phonemes:
92
- if not self.keep_stress:
93
- # Remove primary/secondary stress
94
- word_phoneme = IPA.without_stress(word_phoneme)
95
-
96
- word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
97
-
98
- if word_phoneme:
99
- # Flatten phonemes
100
- word_phonemes.extend(word_phoneme)
101
-
102
- if word_phonemes:
103
- ph_list.append(word_phonemes)
104
-
105
- ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list]
106
- ph = f"{separator} ".join(ph_words)
107
- return ph
108
-
109
- def _phonemize(self, text, separator):
110
- return self.phonemize_gruut(text, separator, tie=False)
111
-
112
- def is_supported_language(self, language):
113
- """Returns True if `language` is supported by the backend"""
114
- return gruut.is_language_supported(language)
115
-
116
- @staticmethod
117
- def supported_languages() -> List:
118
- """Get a dictionary of supported languages.
119
-
120
- Returns:
121
- List: List of language codes.
122
- """
123
- return list(gruut.get_supported_languages())
124
-
125
- def version(self):
126
- """Get the version of the used backend.
127
-
128
- Returns:
129
- str: Version of the used backend.
130
- """
131
- return gruut.__version__
132
-
133
- @classmethod
134
- def is_available(cls):
135
- """Return True if gruut is installed, else False."""
136
- return importlib.util.find_spec("gruut") is not None
137
-
138
-
139
- if __name__ == "__main__":
140
- from cleaner import french_cleaners
141
- import json
142
-
143
- e = Gruut(language="fr-fr", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
144
- symbols = [ # en + sp
145
- "_",
146
- ",",
147
- ".",
148
- "!",
149
- "?",
150
- "-",
151
- "~",
152
- "\u2026",
153
- "N",
154
- "Q",
155
- "a",
156
- "b",
157
- "d",
158
- "e",
159
- "f",
160
- "g",
161
- "h",
162
- "i",
163
- "j",
164
- "k",
165
- "l",
166
- "m",
167
- "n",
168
- "o",
169
- "p",
170
- "s",
171
- "t",
172
- "u",
173
- "v",
174
- "w",
175
- "x",
176
- "y",
177
- "z",
178
- "\u0251",
179
- "\u00e6",
180
- "\u0283",
181
- "\u0291",
182
- "\u00e7",
183
- "\u026f",
184
- "\u026a",
185
- "\u0254",
186
- "\u025b",
187
- "\u0279",
188
- "\u00f0",
189
- "\u0259",
190
- "\u026b",
191
- "\u0265",
192
- "\u0278",
193
- "\u028a",
194
- "\u027e",
195
- "\u0292",
196
- "\u03b8",
197
- "\u03b2",
198
- "\u014b",
199
- "\u0266",
200
- "\u207c",
201
- "\u02b0",
202
- "`",
203
- "^",
204
- "#",
205
- "*",
206
- "=",
207
- "\u02c8",
208
- "\u02cc",
209
- "\u2192",
210
- "\u2193",
211
- "\u2191",
212
- " ",
213
- "ɣ",
214
- "ɡ",
215
- "r",
216
- "ɲ",
217
- "ʝ",
218
- "ʎ",
219
- "ː"
220
- ]
221
- with open('/home/xumin/workspace/VITS-Training-Multiling/230715_fr/metadata.txt', 'r') as f:
222
- lines = f.readlines()
223
-
224
-
225
- used_sym = []
226
- not_existed_sym = []
227
- phonemes = []
228
-
229
- for line in lines:
230
- text = line.split('|')[-1].strip()
231
- text = french_cleaners(text)
232
- ipa = e.phonemize(text, separator="")
233
- phonemes.append(ipa)
234
- for s in ipa:
235
- if s not in symbols:
236
- if s not in not_existed_sym:
237
- print(f'not_existed char: {s}')
238
- not_existed_sym.append(s)
239
- else:
240
- if s not in used_sym:
241
- # print(f'used char: {s}')
242
- used_sym.append(s)
243
-
244
- print(used_sym)
245
- print(not_existed_sym)
246
-
247
-
248
- with open('./text/fr_phonemizer/french_symbols.txt', 'w') as g:
249
- g.writelines(symbols + not_existed_sym)
250
-
251
- with open('./text/fr_phonemizer/example_ipa.txt', 'w') as g:
252
- g.writelines(phonemes)
253
-
254
- data = {'symbols': symbols + not_existed_sym}
255
-
256
- with open('./text/fr_phonemizer/fr_symbols.json', 'w') as f:
257
- json.dump(data, f, indent=4)
258
-
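
The same configuration the deleted fr_to_ipa.py uses, as a one-off check; it requires gruut and gruut_ipa with French data installed, and assumes the pre-port import path:

```python
from melo.text.fr_phonemizer.gruut_wrapper import Gruut

e = Gruut(language="fr-fr", keep_puncs=True, keep_stress=True,
          use_espeak_phonemes=True)
print(e.phonemize("Bonjour, comment allez-vous ?", separator=""))
```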
 
melo/text/fr_phonemizer/punctuation.py DELETED
@@ -1,172 +0,0 @@
1
- import collections
2
- import re
3
- from enum import Enum
4
-
5
- import six
6
-
7
- _DEF_PUNCS = ';:,.!?¡¿—…"«»“”'
8
-
9
- _PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"])
10
-
11
-
12
- class PuncPosition(Enum):
13
- """Enum for the punctuations positions"""
14
-
15
- BEGIN = 0
16
- END = 1
17
- MIDDLE = 2
18
- ALONE = 3
19
-
20
-
21
- class Punctuation:
22
- """Handle punctuations in text.
23
-
24
- Just strip punctuations from text or strip and restore them later.
25
-
26
- Args:
27
- puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`.
28
-
29
- Example:
30
- >>> punc = Punctuation()
31
- >>> punc.strip("This is. example !")
32
- 'This is example'
33
-
34
- >>> text_striped, punc_map = punc.strip_to_restore("This is. example !")
35
- >>> ' '.join(text_striped)
36
- 'This is example'
37
-
38
- >>> text_restored = punc.restore(text_striped, punc_map)
39
- >>> text_restored[0]
40
- 'This is. example !'
41
- """
42
-
43
- def __init__(self, puncs: str = _DEF_PUNCS):
44
- self.puncs = puncs
45
-
46
- @staticmethod
47
- def default_puncs():
48
- """Return default set of punctuations."""
49
- return _DEF_PUNCS
50
-
51
- @property
52
- def puncs(self):
53
- return self._puncs
54
-
55
- @puncs.setter
56
- def puncs(self, value):
57
- if not isinstance(value, six.string_types):
58
- raise ValueError("[!] Punctuations must be of type str.")
59
- self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the order
60
- self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
61
-
62
- def strip(self, text):
63
- """Remove all the punctuations by replacing with `space`.
64
-
65
- Args:
66
- text (str): The text to be processed.
67
-
68
- Example::
69
-
70
- "This is. example !" -> "This is example "
71
- """
72
- return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip()
73
-
74
- def strip_to_restore(self, text):
75
- """Remove punctuations from text to restore them later.
76
-
77
- Args:
78
- text (str): The text to be processed.
79
-
80
- Examples ::
81
-
82
- "This is. example !" -> [["This is", "example"], [".", "!"]]
83
-
84
- """
85
- text, puncs = self._strip_to_restore(text)
86
- return text, puncs
87
-
88
- def _strip_to_restore(self, text):
89
- """Auxiliary method for Punctuation.preserve()"""
90
- matches = list(re.finditer(self.puncs_regular_exp, text))
91
- if not matches:
92
- return [text], []
93
- # the text is only punctuations
94
- if len(matches) == 1 and matches[0].group() == text:
95
- return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
96
- # build a punctuation map to be used later to restore punctuations
97
- puncs = []
98
- for match in matches:
99
- position = PuncPosition.MIDDLE
100
- if match == matches[0] and text.startswith(match.group()):
101
- position = PuncPosition.BEGIN
102
- elif match == matches[-1] and text.endswith(match.group()):
103
- position = PuncPosition.END
104
- puncs.append(_PUNC_IDX(match.group(), position))
105
- # convert str text to a List[str], each item is separated by a punctuation
106
- splitted_text = []
107
- for idx, punc in enumerate(puncs):
108
- split = text.split(punc.punc)
109
- prefix, suffix = split[0], punc.punc.join(split[1:])
110
- splitted_text.append(prefix)
111
- # if the text does not end with a punctuation, add it to the last item
112
- if idx == len(puncs) - 1 and len(suffix) > 0:
113
- splitted_text.append(suffix)
114
- text = suffix
115
- return splitted_text, puncs
116
-
117
- @classmethod
118
- def restore(cls, text, puncs):
119
- """Restore punctuation in a text.
120
-
121
- Args:
122
- text (str): The text to be processed.
123
- puncs (List[str]): The list of punctuations map to be used for restoring.
124
-
125
- Examples ::
126
-
127
- ['This is', 'example'], ['.', '!'] -> "This is. example!"
128
-
129
- """
130
- return cls._restore(text, puncs, 0)
131
-
132
- @classmethod
133
- def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements
134
- """Auxiliary method for Punctuation.restore()"""
135
- if not puncs:
136
- return text
137
-
138
- # nothing has been phonemized, return the puncs alone
139
- if not text:
140
- return ["".join(m.punc for m in puncs)]
141
-
142
- current = puncs[0]
143
-
144
- if current.position == PuncPosition.BEGIN:
145
- return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
146
-
147
- if current.position == PuncPosition.END:
148
- return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)
149
-
150
- if current.position == PuncPosition.ALONE:
151
- return [current.punc] + cls._restore(text, puncs[1:], num + 1)
152
-
153
- # POSITION == MIDDLE
154
- if len(text) == 1: # pragma: nocover
155
- # a corner case where the final part of an intermediate
156
- # mark (I) has not been phonemized
157
- return cls._restore([text[0] + current.punc], puncs[1:], num)
158
-
159
- return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
160
-
161
-
162
- # if __name__ == "__main__":
163
- # punc = Punctuation()
164
- # text = "This is. This is, example!"
165
-
166
- # print(punc.strip(text))
167
-
168
- # split_text, puncs = punc.strip_to_restore(text)
169
- # print(split_text, " ---- ", puncs)
170
-
171
- # restored_text = punc.restore(split_text, puncs)
172
- # print(restored_text)
 
melo/text/french.py DELETED
@@ -1,94 +0,0 @@
1
- import pickle
2
- import os
3
- import re
4
-
5
- from . import symbols
6
- from .fr_phonemizer import cleaner as fr_cleaner
7
- from .fr_phonemizer import fr_to_ipa
8
- from transformers import AutoTokenizer
9
-
10
-
11
- def distribute_phone(n_phone, n_word):
12
- phones_per_word = [0] * n_word
13
- for task in range(n_phone):
14
- min_tasks = min(phones_per_word)
15
- min_index = phones_per_word.index(min_tasks)
16
- phones_per_word[min_index] += 1
17
- return phones_per_word
18
-
19
- def text_normalize(text):
20
- text = fr_cleaner.french_cleaners(text)
21
- return text
22
-
23
- model_id = 'dbmdz/bert-base-french-europeana-cased'
24
- tokenizer = AutoTokenizer.from_pretrained(model_id)
25
-
26
- def g2p(text, pad_start_end=True, tokenized=None):
27
- if tokenized is None:
28
- tokenized = tokenizer.tokenize(text)
29
- # import pdb; pdb.set_trace()
30
- phs = []
31
- ph_groups = []
32
- for t in tokenized:
33
- if not t.startswith("#"):
34
- ph_groups.append([t])
35
- else:
36
- ph_groups[-1].append(t.replace("#", ""))
37
-
38
- phones = []
39
- tones = []
40
- word2ph = []
41
- # print(ph_groups)
42
- for group in ph_groups:
43
- w = "".join(group)
44
- phone_len = 0
45
- word_len = len(group)
46
- if w == '[UNK]':
47
- phone_list = ['UNK']
48
- else:
49
- phone_list = list(filter(lambda p: p != " ", fr_to_ipa.fr2ipa(w)))
50
-
51
- for ph in phone_list:
52
- phones.append(ph)
53
- tones.append(0)
54
- phone_len += 1
55
- aaa = distribute_phone(phone_len, word_len)
56
- word2ph += aaa
57
- # print(phone_list, aaa)
58
- # print('=' * 10)
59
-
60
- if pad_start_end:
61
- phones = ["_"] + phones + ["_"]
62
- tones = [0] + tones + [0]
63
- word2ph = [1] + word2ph + [1]
64
- return phones, tones, word2ph
65
-
66
- def get_bert_feature(text, word2ph, device=None):
67
- from text import french_bert
68
- return french_bert.get_bert_feature(text, word2ph, device=device)
69
-
70
- if __name__ == "__main__":
71
- ori_text = 'Ce service gratuit est“”"" 【disponible》 en chinois 【simplifié] et autres 123'
72
- # ori_text = "Ils essayaient vainement de faire comprendre à ma mère qu'avec les cent mille francs que m'avait laissé mon père,"
73
- # print(ori_text)
74
- text = text_normalize(ori_text)
75
- print(text)
76
- phoneme = fr_to_ipa.fr2ipa(text)
77
- print(phoneme)
78
-
79
-
80
- from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
81
- from text.cleaner_multiling import unicleaners
82
-
83
- def text_normalize(text):
84
- text = unicleaners(text, cased=True, lang='fr')
85
- return text
86
-
87
- # print(ori_text)
88
- text = text_normalize(ori_text)
89
- print(text)
90
- phonemizer = MultiPhonemizer({"fr-fr": "espeak"})
91
- # phonemizer.lang_to_phonemizer['fr'].keep_stress = True
92
- # phonemizer.lang_to_phonemizer['fr'].use_espeak_phonemes = True
93
- phoneme = phonemizer.phonemize(text, separator="", language='fr-fr')
94
- print(phoneme)
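
distribute_phone does the real alignment work in g2p above: it spreads a word's phones across its BERT subtokens as evenly as possible, greedily topping up the least-loaded token. Standalone, with a worked example:

```python
# Assign n_phone phonemes to n_word subtokens as evenly as possible.
def distribute_phone(n_phone: int, n_word: int) -> list:
    phones_per_word = [0] * n_word
    for _ in range(n_phone):
        # always give the next phone to the least-loaded token
        min_index = phones_per_word.index(min(phones_per_word))
        phones_per_word[min_index] += 1
    return phones_per_word

assert distribute_phone(7, 3) == [3, 2, 2]
```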
 
melo/text/french_bert.py DELETED
@@ -1,39 +0,0 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForMaskedLM
3
- import sys
4
-
5
- model_id = 'dbmdz/bert-base-french-europeana-cased'
6
- tokenizer = AutoTokenizer.from_pretrained(model_id)
7
- model = None
8
-
9
- def get_bert_feature(text, word2ph, device=None):
10
- global model
11
- if (
12
- sys.platform == "darwin"
13
- and torch.backends.mps.is_available()
14
- and device == "cpu"
15
- ):
16
- device = "mps"
17
- if not device:
18
- device = "cuda"
19
- if model is None:
20
- model = AutoModelForMaskedLM.from_pretrained(model_id).to(
21
- device
22
- )
23
- with torch.no_grad():
24
- inputs = tokenizer(text, return_tensors="pt")
25
- for i in inputs:
26
- inputs[i] = inputs[i].to(device)
27
- res = model(**inputs, output_hidden_states=True)
28
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
29
-
30
- assert inputs["input_ids"].shape[-1] == len(word2ph)
31
- word2phone = word2ph
32
- phone_level_feature = []
33
- for i in range(len(word2phone)):
34
- repeat_feature = res[i].repeat(word2phone[i], 1)
35
- phone_level_feature.append(repeat_feature)
36
-
37
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
38
-
39
- return phone_level_feature.T
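
A usage sketch for the function above, assuming the pre-port import path; the `word2ph` values are hypothetical and must have one entry per BERT input id (including the special tokens), or the assert in get_bert_feature will fire:

```python
# Expand token-level BERT features to phone level.
from melo.text.french_bert import get_bert_feature

text = "Bonjour le monde"
word2ph = [1, 3, 2, 3, 1]          # hypothetical token-to-phone counts
feat = get_bert_feature(text, word2ph, device="cpu")
print(feat.shape)                   # torch.Size([768, sum(word2ph)]) for this bert-base model
```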