Mahiruoshi committed
Commit e738dd9 · verified · 1 Parent(s): 7bac831

Update app.py

Files changed (1):
  1. app.py +721 -114
app.py CHANGED
@@ -3,6 +3,7 @@ import os
 from pathlib import Path
 
 import logging
+import uuid
 import re_matching
 
 logging.getLogger("numba").setLevel(logging.WARNING)
@@ -15,7 +16,8 @@ logging.basicConfig(
 )
 
 logger = logging.getLogger(__name__)
-
+import shutil
+from scipy.io.wavfile import write
 import librosa
 import numpy as np
 import torch
@@ -23,8 +25,6 @@ import torch.nn as nn
 from torch.utils.data import Dataset
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
-from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
-
 
 import gradio as gr
 
@@ -40,9 +40,30 @@ import utils
 from models import SynthesizerTrn
 from text.symbols import symbols
 import sys
+import re
+
+import random
+import hashlib
+
+from fugashi import Tagger
+import jaconv
+import unidic
+import subprocess
 
+import requests
+
+from ebooklib import epub
+import PyPDF2
+from PyPDF2 import PdfReader
+from bs4 import BeautifulSoup
+import jieba
+import romajitable
+
+from flask import Flask, request, jsonify, render_template_string, send_file
+from flask_cors import CORS
+from scipy.io.wavfile import write
 net_g = None
-'''
+
 device = (
     "cuda:0"
     if torch.cuda.is_available()
@@ -52,8 +73,8 @@ device = (
         else "cpu"
     )
 )
-'''
-device = "cpu"
+
+#device = "cpu"
 BandList = {
     "PoppinParty":["香澄","有咲","たえ","りみ","沙綾"],
     "Afterglow":["蘭","モカ","ひまり","巴","つぐみ"],
@@ -70,6 +91,359 @@ BandList = {
     "西克菲尔特音乐学院":["晶","未知留","八千代","栞","美帆"]
 }
 
+webBase = 'https://mahiruoshi-bangdream-bert-vits2.hf.space/'
+
+port = 7860
+
+languages = [ "Auto", "ZH", "JP"]
+modelPaths = []
+modes = ['pyopenjtalk-V2.3-Katakana','fugashi-V2.3-Katakana','pyopenjtalk-V2.3-Katakana-Katakana','fugashi-V2.3-Katakana-Katakana','onnx-V2.3']
+sentence_modes = ['sentence','paragraph']
+for dirpath, dirnames, filenames in os.walk('Data/BangDream/models/'):
+    for filename in filenames:
+        modelPaths.append(os.path.join(dirpath, filename))
+hps = utils.get_hparams_from_file('Data/BangDream/config.json')
+
+def translate(Sentence: str, to_Language: str = "jp", from_Language: str = ""):
+    """
+    :param Sentence: the sentence to translate
+    :param from_Language: language of the source sentence
+    :param to_Language: target language
+    :return: the translated sentence; returns the input unchanged on error
+
+    Common language codes: Chinese zh, English en, Japanese jp
+    """
+    appid = "20231117001883321"
+    key = "lMQbvZHeJveDceLof2wf"
+    if appid == "" or key == "":
+        return "请开发者在config.yml中配置app_key与secret_key"
+    url = "https://fanyi-api.baidu.com/api/trans/vip/translate"
+    texts = Sentence.splitlines()
+    outTexts = []
+    for t in texts:
+        if t != "":
+            # Signature calculation; see https://api.fanyi.baidu.com/product/113
+            salt = str(random.randint(1, 100000))
+            signString = appid + t + salt + key
+            hs = hashlib.md5()
+            hs.update(signString.encode("utf-8"))
+            signString = hs.hexdigest()
+            if from_Language == "":
+                from_Language = "auto"
+            headers = {"Content-Type": "application/x-www-form-urlencoded"}
+            payload = {
+                "q": t,
+                "from": from_Language,
+                "to": to_Language,
+                "appid": appid,
+                "salt": salt,
+                "sign": signString,
+            }
+            # Send the request
+            try:
+                response = requests.post(
+                    url=url, data=payload, headers=headers, timeout=3
+                )
+                response = response.json()
+                if "trans_result" in response.keys():
+                    result = response["trans_result"][0]
+                    if "dst" in result.keys():
+                        dst = result["dst"]
+                        outTexts.append(dst)
+            except Exception:
+                return Sentence
+        else:
+            outTexts.append(t)
+    return "\n".join(outTexts)
+
+# Text-cleaning utilities
+def is_japanese(string):
+    for ch in string:
+        if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
+            return True
+    return False
+
+def is_chinese(string):
+    for ch in string:
+        if '\u4e00' <= ch <= '\u9fff':
+            return True
+    return False
+
+def is_single_language(sentence):
+    # Check whether the sentence is in a single language
+    contains_chinese = re.search(r'[\u4e00-\u9fff]', sentence) is not None
+    contains_japanese = re.search(r'[\u3040-\u30ff\u31f0-\u31ff]', sentence) is not None
+    contains_english = re.search(r'[a-zA-Z]', sentence) is not None
+    language_count = sum([contains_chinese, contains_japanese, contains_english])
+    return language_count == 1
+
+def merge_scattered_parts(sentences):
+    """Merge scattered fragments into adjacent sentences, keeping each result single-language"""
+    merged_sentences = []
+    buffer_sentence = ""
+
+    for sentence in sentences:
+        # Check whether this part is single-language or too short (likely punctuation or a single word)
+        if is_single_language(sentence) and len(sentence) > 1:
+            # If the buffer holds content, flush it to the list first
+            if buffer_sentence:
+                merged_sentences.append(buffer_sentence)
+                buffer_sentence = ""
+            merged_sentences.append(sentence)
+        else:
+            # A scattered fragment: append it to the buffer
+            buffer_sentence += sentence
+
+    # Make sure any remaining buffered content is added
+    if buffer_sentence:
+        merged_sentences.append(buffer_sentence)
+
+    return merged_sentences
+
+def is_only_punctuation(s):
+    """Check whether the string contains only punctuation"""
+    # Common Chinese, Japanese and English punctuation
+    punctuation_pattern = re.compile(r'^[\s。*;,:“”()、!?《》\u3000\.,;:"\'?!()]+$')
+    return punctuation_pattern.match(s) is not None
+
+def split_mixed_language(sentence):
+    # Split a mixed-language sentence:
+    # scan character by character and split off runs of different languages
+    sub_sentences = []
+    current_language = None
+    current_part = ""
+
+    for char in sentence:
+        if re.match(r'[\u4e00-\u9fff]', char):  # Chinese character
+            if current_language != 'chinese':
+                if current_part:
+                    sub_sentences.append(current_part)
+                current_part = char
+                current_language = 'chinese'
+            else:
+                current_part += char
+        elif re.match(r'[\u3040-\u30ff\u31f0-\u31ff]', char):  # Japanese character
+            if current_language != 'japanese':
+                if current_part:
+                    sub_sentences.append(current_part)
+                current_part = char
+                current_language = 'japanese'
+            else:
+                current_part += char
+        elif re.match(r'[a-zA-Z]', char):  # English character
+            if current_language != 'english':
+                if current_part:
+                    sub_sentences.append(current_part)
+                current_part = char
+                current_language = 'english'
+            else:
+                current_part += char
+        else:
+            current_part += char  # For punctuation and other characters
+
+    if current_part:
+        sub_sentences.append(current_part)
+
+    return sub_sentences
+
+def replace_quotes(text):
+    # Replace Chinese/Japanese quotation marks with English quotes
+    text = re.sub(r'[“”‘’『』「」()()]', '"', text)
+    return text
+
+def remove_numeric_annotations(text):
+    # Regular expression matching numeric annotations:
+    # digits wrapped in “”, 【】 or 〔〕
+    pattern = r'“\d+”|【\d+】|〔\d+〕'
+    # Strip those annotations
+    cleaned_text = re.sub(pattern, '', text)
+    return cleaned_text
+
+def merge_adjacent_japanese(sentences):
+    """Merge adjacent sentences that are both Japanese-only"""
+    merged_sentences = []
+    i = 0
+    while i < len(sentences):
+        current_sentence = sentences[i]
+        if i + 1 < len(sentences) and is_japanese(current_sentence) and is_japanese(sentences[i + 1]):
+            # The current sentence and the next are both Japanese: merge them
+            while i + 1 < len(sentences) and is_japanese(sentences[i + 1]):
+                current_sentence += sentences[i + 1]
+                i += 1
+        merged_sentences.append(current_sentence)
+        i += 1
+    return merged_sentences
+
+def extrac(text):
+    text = replace_quotes(remove_numeric_annotations(text))  # Normalize quotes
+    text = re.sub("<[^>]*>", "", text)  # Strip HTML tags
+    # First-pass split on newlines and sentence-ending punctuation
+    preliminary_sentences = re.split(r'([\n。;!?\.\?!])', text)
+    final_sentences = []
+
+    for piece in preliminary_sentences:
+        if is_single_language(piece):
+            final_sentences.append(piece)
+        else:
+            sub_sentences = split_mixed_language(piece)
+            final_sentences.extend(sub_sentences)
+
+    # Split long sentences, using jieba for word segmentation
+    split_sentences = []
+    for sentence in final_sentences:
+        split_sentences.extend(split_long_sentences(sentence))
+
+    # Merge adjacent Japanese sentences
+    merged_japanese_sentences = merge_adjacent_japanese(split_sentences)
+
+    # Drop elements that contain only punctuation
+    clean_sentences = [s for s in merged_japanese_sentences if not is_only_punctuation(s)]
+
+    # Drop empty strings and stray quotes
+    return [s.replace('"','').strip() for s in clean_sentences if s]
+
+def is_mixed_language(sentence):
+    contains_chinese = re.search(r'[\u4e00-\u9fff]', sentence) is not None
+    contains_japanese = re.search(r'[\u3040-\u30ff\u31f0-\u31ff]', sentence) is not None
+    contains_english = re.search(r'[a-zA-Z]', sentence) is not None
+    languages_count = sum([contains_chinese, contains_japanese, contains_english])
+    return languages_count > 1
+
+def split_mixed_language(sentence):
+    # Split a mixed-language sentence (regex variant; overrides the scanner above)
+    sub_sentences = re.split(r'(?<=[。!?\.\?!])(?=")|(?<=")(?=[\u4e00-\u9fff\u3040-\u30ff\u31f0-\u31ff]|[a-zA-Z])', sentence)
+    return [s.strip() for s in sub_sentences if s.strip()]
+
+def seconds_to_ass_time(seconds):
+    """Convert seconds to the ASS timestamp format"""
+    hours = int(seconds / 3600)
+    minutes = int((seconds % 3600) / 60)
+    # Take the fractional part before truncating, otherwise milliseconds is always 0
+    milliseconds = int((seconds - int(seconds)) * 1000)
+    seconds = int(seconds) % 60
+    return "{:01d}:{:02d}:{:02d}.{:02d}".format(hours, minutes, seconds, int(milliseconds / 10))
+
+def extract_text_from_epub(file_path):
+    book = epub.read_epub(file_path)
+    content = []
+    for item in book.items:
+        if isinstance(item, epub.EpubHtml):
+            soup = BeautifulSoup(item.content, 'html.parser')
+            content.append(soup.get_text())
+    return '\n'.join(content)
+
+def extract_text_from_pdf(file_path):
+    with open(file_path, 'rb') as file:
+        reader = PdfReader(file)
+        content = [page.extract_text() for page in reader.pages]
+    return '\n'.join(content)
+
+def remove_annotations(text):
+    # Strip content inside square brackets, angle brackets and citation residue
+    text = re.sub(r'\[.*?\]', '', text)
+    text = re.sub(r'\<.*?\>', '', text)
+    text = re.sub(r'&#8203;``【oaicite:1】``&#8203;', '', text)
+    return text
+
+def extract_text_from_file(inputFile):
+    file_extension = os.path.splitext(inputFile)[1].lower()
+    if file_extension == ".epub":
+        return extract_text_from_epub(inputFile)
+    elif file_extension == ".pdf":
+        return extract_text_from_pdf(inputFile)
+    elif file_extension == ".txt":
+        with open(inputFile, 'r', encoding='utf-8') as f:
+            return f.read()
+    else:
+        raise ValueError(f"Unsupported file format: {file_extension}")
+
+def split_by_punctuation(sentence):
+    """Split a sentence on secondary Chinese punctuation"""
+    # Common secondary separators: commas, semicolons, etc.
+    parts = re.split(r'([,,;;])', sentence)
+    # Merge each punctuation mark into the preceding part so it never stands alone
+    merged_parts = []
+    for part in parts:
+        if part and not part in ',,;;':
+            merged_parts.append(part)
+        elif merged_parts:
+            merged_parts[-1] += part
+    return merged_parts
+
+def split_long_sentences(sentence, max_length=30):
+    """If a Chinese sentence is too long, split on punctuation first, then fall back to jieba segmentation"""
+    if len(sentence) > max_length and is_chinese(sentence):
+        # First try splitting on secondary punctuation
+        preliminary_parts = split_by_punctuation(sentence)
+        new_sentences = []
+
+        for part in preliminary_parts:
+            # If a part is still too long, segment it with jieba
+            if len(part) > max_length:
+                words = jieba.lcut(part)
+                current_sentence = ""
+                for word in words:
+                    if len(current_sentence) + len(word) > max_length:
+                        new_sentences.append(current_sentence)
+                        current_sentence = word
+                    else:
+                        current_sentence += word
+                if current_sentence:
+                    new_sentences.append(current_sentence)
+            else:
+                new_sentences.append(part)
+
+        return new_sentences
+    return [sentence]  # Return the sentence as-is if it is short or not Chinese
+
+def extract_and_convert(text):
+    # Find all English words (\b marks word boundaries)
+    english_parts = re.findall(r'\b[A-Za-z]+\b', text)
+
+    # Convert each English word to katakana
+    kana_parts = ['\n{}\n'.format(romajitable.to_kana(word).katakana) for word in english_parts]
+
+    # Substitute the English parts back into the text
+    for eng, kana in zip(english_parts, kana_parts):
+        text = text.replace(eng, kana, 1)  # Replace one occurrence at a time
+
+    return text
+
+# Inference utilities
+def download_unidic():
+    try:
+        Tagger()
+        print("Tagger launch successfully.")
+    except Exception as e:
+        print("UNIDIC dictionary not found, downloading...")
+        subprocess.run([sys.executable, "-m", "unidic", "download"])
+        print("Download completed.")
+
+def kanji_to_hiragana(text):
+    global tagger
+    output = ""
+
+    # Regex distinguishing text runs from punctuation
+    segments = re.findall(r'[一-龥ぁ-んァ-ン\w]+|[^\一-龥ぁ-んァ-ン\w\s]', text, re.UNICODE)
+
+    for segment in segments:
+        if re.match(r'[一-龥ぁ-んァ-ン\w]+', segment):
+            # Words and kanji: convert to hiragana
+            for word in tagger(segment):
+                kana = word.feature.kana or word.surface
+                hiragana = jaconv.kata2hira(kana)  # Convert katakana to hiragana
+                output += hiragana
+        else:
+            # Punctuation passes through unchanged
+            output += segment
+
+    return output
+
 def get_net_g(model_path: str, device: str, hps):
     net_g = SynthesizerTrn(
         len(symbols),
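A minimal usage sketch for the translate() helper added above — illustrative only, assuming the Baidu credentials embedded in the function are still valid and the host has network access:

    # Hypothetical call; translate() falls back to returning the input on any request error.
    jp = translate("为什么要演奏春日影!", to_Language="jp")
    print(jp)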
@@ -82,8 +456,8 @@ def get_net_g(model_path: str, device: str, hps):
     _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
     return net_g
 
-def get_text(text, language_str, hps, device):
-    # implement the current version of get_text here
+def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
+    style_text = None if style_text == "" else style_text
     norm_text, phone, tone, word2ph = clean_text(text, language_str)
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
 
@@ -94,18 +468,24 @@ def get_text(text, language_str, hps, device)
     for i in range(len(word2ph)):
        word2ph[i] = word2ph[i] * 2
     word2ph[0] += 1
-    bert_ori = get_bert(norm_text, word2ph, language_str, device)
+    bert_ori = get_bert(
+        norm_text, word2ph, language_str, device, style_text, style_weight
+    )
     del word2ph
     assert bert_ori.shape[-1] == len(phone), phone
 
     if language_str == "ZH":
         bert = bert_ori
-        ja_bert = torch.zeros(1024, len(phone))
-        en_bert = torch.zeros(1024, len(phone))
+        ja_bert = torch.randn(1024, len(phone))
+        en_bert = torch.randn(1024, len(phone))
     elif language_str == "JP":
-        bert = torch.zeros(1024, len(phone))
+        bert = torch.randn(1024, len(phone))
         ja_bert = bert_ori
-        en_bert = torch.zeros(1024, len(phone))
+        en_bert = torch.randn(1024, len(phone))
+    elif language_str == "EN":
+        bert = torch.randn(1024, len(phone))
+        ja_bert = torch.randn(1024, len(phone))
+        en_bert = bert_ori
     else:
         raise ValueError("language_str should be ZH, JP or EN")
 
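A sketch of how the reworked get_text() is meant to be called — illustrative only, with hps and device assumed to be initialized as in __main__ below:

    # style_text/style_weight are forwarded to get_bert; for "JP" input,
    # ja_bert carries the real BERT features and the other two slots are
    # 1024 x len(phone) random placeholders per the branch above.
    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
        "こんにちは", "JP", hps, device, style_text=None, style_weight=0.7
    )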
@@ -125,19 +505,47 @@ def infer(
     noise_scale_w,
     length_scale,
     sid,
-    reference_audio=None,
-    emotion='Happy',
+    style_text=None,
+    style_weight=0.7,
+    language = "Auto",
+    mode = 'pyopenjtalk-V2.3-Katakana',
+    skip_start=False,
+    skip_end=False,
 ):
-
-    language= 'JP' if is_japanese(text) else 'ZH'
-    if isinstance(reference_audio, np.ndarray):
-        emo = get_clap_audio_feature(reference_audio, device)
-    else:
-        emo = get_clap_text_feature(emotion, device)
-    emo = torch.squeeze(emo, dim=1)
+    if style_text == None:
+        style_text = ""
+        style_weight = 0
+    if mode == 'fugashi-V2.3-Katakana':
+        text = kanji_to_hiragana(text) if is_japanese(text) else text
+    if language == "JP":
+        text = translate(text,"jp")
+    if language == "ZH":
+        text = translate(text,"zh")
+    if language == "Auto":
+        language = 'JP' if is_japanese(text) else 'ZH'
+    #print(f'{text}:{sdp_ratio}:{noise_scale}:{noise_scale_w}:{length_scale}:{length_scale}:{sid}:{language}:{mode}:{skip_start}:{skip_end}')
     bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
-        text, language, hps, device
+        text,
+        language,
+        hps,
+        device,
+        style_text=style_text,
+        style_weight=style_weight,
     )
+    if skip_start:
+        phones = phones[3:]
+        tones = tones[3:]
+        lang_ids = lang_ids[3:]
+        bert = bert[:, 3:]
+        ja_bert = ja_bert[:, 3:]
+        en_bert = en_bert[:, 3:]
+    if skip_end:
+        phones = phones[:-2]
+        tones = tones[:-2]
+        lang_ids = lang_ids[:-2]
+        bert = bert[:, :-2]
+        ja_bert = ja_bert[:, :-2]
+        en_bert = en_bert[:, :-2]
     with torch.no_grad():
         x_tst = phones.to(device).unsqueeze(0)
         tones = tones.to(device).unsqueeze(0)
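With the CLAP emotion path removed, infer() now returns the raw float waveform instead of a (sampling_rate, int16) tuple. A hedged usage sketch — the speaker name is assumed to exist in hps.data.spk2id:

    audio = infer(
        "为什么要演奏春日影!",
        sdp_ratio=0.2,
        noise_scale=0.6,
        noise_scale_w=0.8,
        length_scale=1.0,
        sid="香澄",
        language="Auto",
    )
    # Callers convert to 16-bit themselves before writing a wav:
    wav16 = gr.processing_utils.convert_to_16_bit_wav(audio)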
@@ -146,7 +554,7 @@ def infer(
         ja_bert = ja_bert.to(device).unsqueeze(0)
         en_bert = en_bert.to(device).unsqueeze(0)
         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
-        emo = emo.to(device).unsqueeze(0)
+        # emo = emo.to(device).unsqueeze(0)
         del phones
         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
         audio = (
@@ -159,7 +567,6 @@ def infer(
                 bert,
                 ja_bert,
                 en_bert,
-                emo,
                 sdp_ratio=sdp_ratio,
                 noise_scale=noise_scale,
                 noise_scale_w=noise_scale_w,
@@ -169,109 +576,309 @@ def infer(
         .float()
         .numpy()
     )
-    del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
+    del (
+        x_tst,
+        tones,
+        lang_ids,
+        bert,
+        x_tst_lengths,
+        speakers,
+        ja_bert,
+        en_bert,
+    )  # , emo
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-    return (hps.data.sampling_rate,gr.processing_utils.convert_to_16_bit_wav(audio))
-
-def is_japanese(string):
-    for ch in string:
-        if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
-            return True
-    return False
+    print("Success.")
+    return audio
 
 def loadmodel(model):
     _ = net_g.eval()
     _ = utils.load_checkpoint(model, net_g, None, skip_optimizer=True)
     return "success"
 
-if __name__ == "__main__":
-    languages = [ "Auto", "ZH", "JP"]
-    modelPaths = []
-    for dirpath, dirnames, filenames in os.walk('Data/BangDreamV22/models/'):
-        for filename in filenames:
-            modelPaths.append(os.path.join(dirpath, filename))
-    hps = utils.get_hparams_from_file('Data/BangDreamV22/configs/config.json')
-    net_g = get_net_g(
-        model_path=modelPaths[-1], device=device, hps=hps
-    )
-    speaker_ids = hps.data.spk2id
-    speakers = list(speaker_ids.keys())
-    with gr.Blocks() as app:
-        for band in BandList:
-            with gr.TabItem(band):
-                for name in BandList[band]:
-                    with gr.TabItem(name):
-                        classifiedPaths = []
-                        for dirpath, dirnames, filenames in os.walk("Data/Bushiroad/classifedSample/"+name):
-                            for filename in filenames:
-                                classifiedPaths.append(os.path.join(dirpath, filename))
-                        with gr.Row():
-                            with gr.Column():
-                                with gr.Row():
-                                    gr.Markdown(
-                                        '<div align="center">'
-                                        f'<img style="width:auto;height:400px;" src="https://mahiruoshi-bangdream-bert-vits2.hf.space/file/image/{name}.png">'
-                                        '</div>'
-                                    )
-                                length_scale = gr.Slider(
-                                    minimum=0.1, maximum=2, value=1, step=0.01, label="语速调节"
-                                )
-                                emotion = gr.Textbox(
-                                    label="Text prompt",
-                                    placeholder="用文字描述生成风格。如:Happy",
-                                    value="Happy",
-                                    visible=True,
-                                )
-                                with gr.Accordion(label="参数设定", open=False):
-                                    sdp_ratio = gr.Slider(
-                                        minimum=0, maximum=1, value=0.2, step=0.01, label="SDP/DP混合比"
-                                    )
-                                    noise_scale = gr.Slider(
-                                        minimum=0.1, maximum=2, value=0.6, step=0.01, label="感情调节"
-                                    )
-                                    noise_scale_w = gr.Slider(
-                                        minimum=0.1, maximum=2, value=0.8, step=0.01, label="音素长度"
-                                    )
-                                    speaker = gr.Dropdown(
-                                        choices=speakers, value=name, label="说话人"
-                                    )
-                                with gr.Accordion(label="切换模型", open=False):
-                                    modelstrs = gr.Dropdown(label = "模型", choices = modelPaths, value = modelPaths[0], type = "value")
-                                    btnMod = gr.Button("载入模型")
-                                    statusa = gr.TextArea()
-                                    btnMod.click(loadmodel, inputs=[modelstrs], outputs = [statusa])
-                            with gr.Column():
-                                text = gr.TextArea(
-                                    label="输入纯日语或者中文",
-                                    placeholder="输入纯日语或者中文",
-                                    value="为什么要演奏春日影!",
-                                )
-                                try:
-                                    reference_audio = gr.Dropdown(label = "情感参考", choices = classifiedPaths, value = classifiedPaths[0], type = "value")
-                                except:
-                                    reference_audio = gr.Audio(label="情感参考音频", type="filepath")
-                                btn = gr.Button("点击生成", variant="primary")
-                                audio_output = gr.Audio(label="Output Audio")
-                                '''
-                                btntran = gr.Button("快速中翻日")
-                                translateResult = gr.TextArea("从这复制翻译后的文本")
-                                btntran.click(translate, inputs=[text], outputs = [translateResult])
-                                '''
-                                btn.click(
-                                    infer,
-                                    inputs=[
+def generate_audio_and_srt_for_group(
+    group,
+    outputPath,
+    group_index,
+    sampling_rate,
+    speaker,
+    sdp_ratio,
+    noise_scale,
+    noise_scale_w,
+    length_scale,
+    speakerList,
+    silenceTime,
+    language,
+    mode,
+    skip_start,
+    skip_end,
+    style_text,
+    style_weight,
+):
+    audio_fin = []
+    ass_entries = []
+    start_time = 0
+    #speaker = random.choice(cara_list)
+    ass_header = """[Script Info]
+; 我没意见
+Title: Audiobook
+ScriptType: v4.00+
+WrapStyle: 0
+PlayResX: 640
+PlayResY: 360
+ScaledBorderAndShadow: yes
+[V4+ Styles]
+Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
+Style: Default,Arial,20,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,1,1,2,10,10,10,1
+[Events]
+Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
+"""
+
+    for sentence in group:
+
+        if len(sentence) > 1:
+            FakeSpeaker = sentence.split("|")[0]
+            print(FakeSpeaker)
+            SpeakersList = re.split('\n', speakerList)
+            if FakeSpeaker in list(hps.data.spk2id.keys()):
+                speaker = FakeSpeaker
+            for i in SpeakersList:
+                if FakeSpeaker == i.split("|")[1]:
+                    speaker = i.split("|")[0]
+        if sentence != '\n':
+            text = (remove_annotations(sentence.split("|")[-1]).replace(" ","")+"。").replace(",。","。")
+            if mode == 'pyopenjtalk-V2.3-Katakana' or mode == 'fugashi-V2.3-Katakana':
+                #print(f'{text}:{sdp_ratio}:{noise_scale}:{noise_scale_w}:{length_scale}:{length_scale}:{speaker}:{language}:{mode}:{skip_start}:{skip_end}')
+                audio = infer(
                     text,
                     sdp_ratio,
                     noise_scale,
                     noise_scale_w,
                     length_scale,
                     speaker,
-                    reference_audio,
-                    emotion,
-                ],
-                outputs=[audio_output],
+                    style_text,
+                    style_weight,
+                    language,
+                    mode,
+                    skip_start,
+                    skip_end,
+                )
+            silence_frames = int(silenceTime * 44010)
+            silence_data = np.zeros((silence_frames,), dtype=audio.dtype)
+            audio_fin.append(audio)
+            audio_fin.append(silence_data)
+            duration = len(audio) / sampling_rate
+            print(duration)
+            end_time = start_time + duration + silenceTime
+            ass_entries.append("Dialogue: 0,{},{},".format(seconds_to_ass_time(start_time), seconds_to_ass_time(end_time)) + "Default,,0,0,0,,{}".format(sentence.replace("|",":")))
+            start_time = end_time
+
+    wav_filename = os.path.join(outputPath, f'audiobook_part_{group_index}.wav')
+    ass_filename = os.path.join(outputPath, f'audiobook_part_{group_index}.ass')
+    write(wav_filename, sampling_rate, gr.processing_utils.convert_to_16_bit_wav(np.concatenate(audio_fin)))
+
+    with open(ass_filename, 'w', encoding='utf-8') as f:
+        f.write(ass_header + '\n'.join(ass_entries))
+    return (hps.data.sampling_rate, gr.processing_utils.convert_to_16_bit_wav(np.concatenate(audio_fin)))
+
+def generate_audio(
+    inputFile,
+    groupsize,
+    filepath,
+    silenceTime,
+    speakerList,
+    text,
+    sdp_ratio,
+    noise_scale,
+    noise_scale_w,
+    length_scale,
+    sid,
+    style_text=None,
+    style_weight=0.7,
+    language = "Auto",
+    mode = 'pyopenjtalk-V2.3-Katakana',
+    sentence_mode = 'sentence',
+    skip_start=False,
+    skip_end=False,
+):
+    if mode == 'pyopenjtalk-V2.3-Katakana' or mode == 'fugashi-V2.3-Katakana':
+        if sentence_mode == 'sentence':
+            audio = infer(
+                text,
+                sdp_ratio,
+                noise_scale,
+                noise_scale_w,
+                length_scale,
+                sid,
+                style_text,
+                style_weight,
+                language,
+                mode,
+                skip_start,
+                skip_end,
+            )
+            return (hps.data.sampling_rate,gr.processing_utils.convert_to_16_bit_wav(audio))
+        if sentence_mode == 'paragraph':
+            GROUP_SIZE = groupsize
+            directory_path = filepath if torch.cuda.is_available() else "books"
+            if os.path.exists(directory_path):
+                shutil.rmtree(directory_path)
+            os.makedirs(directory_path)
+            if inputFile:
+                text = extract_text_from_file(inputFile.name)
+            if language == 'Auto':
+                sentences = extrac(extract_and_convert(text))
+            else:
+                sentences = extrac(text)
+            for i in range(0, len(sentences), GROUP_SIZE):
+                group = sentences[i:i+GROUP_SIZE]
+                if speakerList == "":
+                    speakerList = "无"
+                result = generate_audio_and_srt_for_group(
+                    group,
+                    directory_path,
+                    i//GROUP_SIZE + 1,
+                    44100,
+                    sid,
+                    sdp_ratio,
+                    noise_scale,
+                    noise_scale_w,
+                    length_scale,
+                    speakerList,
+                    silenceTime,
+                    language,
+                    mode,
+                    skip_start,
+                    skip_end,
+                    style_text,
+                    style_weight,
                 )
+                if not torch.cuda.is_available():
+                    return result
+            return result
+
+Flaskapp = Flask(__name__)
+CORS(Flaskapp)
+@Flaskapp.route('/', methods=['GET', 'POST'])
+def tts():
+    if request.method == 'POST':
+        input = request.json
+        inputFile = None
+        filepath = input['filepath']
+        groupSize = input['groupSize']
+        text = input['text']
+        sdp_ratio = input['sdp_ratio']
+        noise_scale = input['noise_scale']
+        noise_scale_w = input['noise_scale_w']
+        length_scale = input['length_scale']
+        sid = input['speaker']
+        style_text = input['style_text']
+        style_weight = input['style_weight']
+        language = input['language']
+        mode = input['mode']
+        sentence_mode = input['sentence_mode']
+        skip_start = input['skip_start']
+        skip_end = input['skip_end']
+        speakerList = input['speakerList']
+        silenceTime = input['silenceTime']
+        samplerate, audio = generate_audio(
+            inputFile,
+            groupSize,
+            filepath,
+            silenceTime,
+            speakerList,
+            text,
+            sdp_ratio,
+            noise_scale,
+            noise_scale_w,
+            length_scale,
+            sid,
+            style_text,
+            style_weight,
+            language,
+            mode,
+            sentence_mode,
+            skip_start,
+            skip_end,
+        )
+        unique_filename = f"temp{uuid.uuid4()}.wav"
+        write(unique_filename, samplerate, audio)
+        with open(unique_filename, 'rb') as bit:
+            wav_bytes = bit.read()
+        os.remove(unique_filename)
+        headers = {
+            'Content-Type': 'audio/wav',
+            'Text': unique_filename.encode('utf-8')}
+        return wav_bytes, 200, headers
+    groupSize = request.args.get('groupSize', default = 50, type = int)
+    text = request.args.get('text', default = '', type = str)
+    sdp_ratio = request.args.get('sdp_ratio', default = 0.5, type = float)
+    noise_scale = request.args.get('noise_scale', default = 0.6, type = float)
+    noise_scale_w = request.args.get('noise_scale_w', default = 0.667, type = float)
+    length_scale = request.args.get('length_scale', default = 1, type = float)
+    sid = request.args.get('speaker', default = '八千代', type = str)
+    style_text = request.args.get('style_text', default = '', type = str)
+    style_weight = request.args.get('style_weight', default = 0.7, type = float)
+    language = request.args.get('language', default = 'Auto', type = str)
+    mode = request.args.get('mode', default = 'pyopenjtalk-V2.3-Katakana', type = str)
+    sentence_mode = request.args.get('sentence_mode', default = 'sentence', type = str)
+    skip_start = request.args.get('skip_start', default = False, type = bool)
+    skip_end = request.args.get('skip_end', default = False, type = bool)
+    speakerList = request.args.get('speakerList', default = '', type = str)
+    silenceTime = request.args.get('silenceTime', default = 0.1, type = float)
+    inputFile = None
+    if not sid or not text:
+        return render_template_string(f"""
+        <!DOCTYPE html>
+        <html>
+        <head>
+            <title>TTS API Documentation</title>
+        </head>
+        <body>
+            <iframe src={webBase} style="width:100%; height:100vh; border:none;"></iframe>
+        </body>
+        </html>
+        """)
+    samplerate, audio = generate_audio(
+        inputFile,
+        groupSize,
+        None,
+        silenceTime,
+        speakerList,
+        text,
+        sdp_ratio,
+        noise_scale,
+        noise_scale_w,
+        length_scale,
+        sid,
+        style_text,
+        style_weight,
+        language,
+        mode,
+        sentence_mode,
+        skip_start,
+        skip_end,
+    )
+    unique_filename = f"temp{uuid.uuid4()}.wav"
+    write(unique_filename, samplerate, audio)
+    with open(unique_filename, 'rb') as bit:
+        wav_bytes = bit.read()
+    os.remove(unique_filename)
+    headers = {
+        'Content-Type': 'audio/wav',
+        'Text': unique_filename.encode('utf-8')}
+    return wav_bytes, 200, headers
+
 
+if __name__ == "__main__":
+    download_unidic()
+    tagger = Tagger()
+    net_g = get_net_g(
+        model_path=modelPaths[-1], device=device, hps=hps
+    )
+    speaker_ids = hps.data.spk2id
+    speakers = list(speaker_ids.keys())
+
     print("推理页面已开启!")
-    app.launch(share=True)
+    Flaskapp.run(host="0.0.0.0", port=port, debug=True)
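A client-side sketch for the new Flask endpoint — illustrative only, assuming the app is reachable on localhost at the port configured above (7860); parameter names mirror the request.args handling in tts():

    import requests

    resp = requests.get(
        "http://127.0.0.1:7860/",
        params={"text": "今日もいい天気ですね。", "speaker": "八千代"},
    )
    with open("out.wav", "wb") as f:
        f.write(resp.content)  # the route returns raw audio/wav bytes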
 