zhzluke96 committed
Commit f34bda5 · Parent(s): b44532e

update
Files changed:
- modules/api/impl/openai_api.py  +50 -4
- modules/normalization.py  +21 -2
- modules/utils/zh_normalization/num.py  +15 -6
- webui.py  +3 -0
modules/api/impl/openai_api.py
CHANGED
@@ -1,4 +1,4 @@
-from fastapi import HTTPException, Body
+from fastapi import File, Form, HTTPException, Body, UploadFile
 from fastapi.responses import StreamingResponse
 
 import io
@@ -14,7 +14,7 @@ from modules.normalization import text_normalize
 from modules import generate_audio as generate
 
 
-from typing import Literal
+from typing import List, Literal, Optional, Union
 import pyrubberband as pyrb
 
 from modules.api import utils as api_utils
@@ -106,8 +106,29 @@ async def openai_speech_api(
         raise HTTPException(status_code=500, detail=str(e))
 
 
-def setup(app: APIManager):
-    app.post(
+class TranscribeSegment(BaseModel):
+    id: int
+    seek: float
+    start: float
+    end: float
+    text: str
+    tokens: List[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
+    no_speech_prob: float
+
+
+class TranscriptionsVerboseResponse(BaseModel):
+    task: str
+    language: str
+    duration: float
+    text: str
+    segments: List[TranscribeSegment]
+
+
+def setup(app: APIManager):
+    app.post(
         "/v1/audio/speech",
         response_class=FileResponse,
         description="""
@@ -122,3 +143,28 @@ openai api document:
 > model may be any value
 """,
 )(openai_speech_api)
+
+    @app.post(
+        "/v1/audio/transcriptions",
+        response_class=TranscriptionsVerboseResponse,
+        description="WIP",
+    )
+    async def transcribe(
+        file: UploadFile = File(...),
+        model: str = Form(...),
+        language: Optional[str] = Form(None),
+        prompt: Optional[str] = Form(None),
+        response_format: str = Form("json"),
+        temperature: float = Form(0),
+        timestamp_granularities: List[str] = Form(["segment"]),
+    ):
+        # TODO: Implement transcribe
+        return {
+            "file": file.filename,
+            "model": model,
+            "language": language,
+            "prompt": prompt,
+            "response_format": response_format,
+            "temperature": temperature,
+            "timestamp_granularities": timestamp_granularities,
+        }
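For reference, a minimal client sketch against the new transcription stub (the host, port, and file name are illustrative assumptions, not part of this commit). Until transcribe() is implemented, the endpoint simply echoes the submitted form fields:

# Hypothetical local client for the WIP /v1/audio/transcriptions stub.
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:8000/v1/audio/transcriptions",
        files={"file": ("sample.wav", f, "audio/wav")},
        data={"model": "whisper-1", "response_format": "json"},
    )
print(resp.json())  # echoes the form fields back; see the TODO in transcribe()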
modules/normalization.py
CHANGED
@@ -1,3 +1,4 @@
+from functools import lru_cache
 from modules.utils.zh_normalization.text_normlization import *
 import emojiswitch
 from modules.utils.markdown import markdown_to_text
@@ -5,12 +6,28 @@ from modules import models
 import re
 
 
+@lru_cache(maxsize=64)
 def is_chinese(text):
     # Chinese characters fall in the Unicode range \u4e00-\u9fff
     chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
     return bool(chinese_pattern.search(text))
 
 
+@lru_cache(maxsize=64)
+def is_eng(text):
+    eng_pattern = re.compile(r"[a-zA-Z]")
+    return bool(eng_pattern.search(text))
+
+
+@lru_cache(maxsize=64)
+def guess_lang(text):
+    if is_chinese(text):
+        return "zh"
+    if is_eng(text):
+        return "en"
+    return "zh"
+
+
 post_normalize_pipeline = []
 pre_normalize_pipeline = []
 
@@ -123,7 +140,7 @@ def apply_character_map(text):
 
 @post_normalize()
 def apply_emoji_map(text):
-    lang = "zh" if is_chinese(text) else "en"
+    lang = guess_lang(text)
     return emojiswitch.demojize(text, delimiters=("", ""), lang=lang)
 
 
@@ -144,6 +161,8 @@ def replace_unk_tokens(text):
     """
     chat_tts = models.load_chat_tts()
     if "tokenizer" not in chat_tts.pretrain_models:
+        # This branch only triggers on Hugging Face Spaces:
+        # Spaces loads/unloads models automatically, so if the tokenizer is unavailable, just give up...
        return text
     tokenizer = chat_tts.pretrain_models["tokenizer"]
     vocab = tokenizer.get_vocab()
@@ -223,7 +242,7 @@ def sentence_normalize(sentence_text: str):
     pattern = re.compile(r"(\[.+?\])|([^[]+)")
 
     def normalize_part(part):
-        sentences = tx.normalize(part) if is_chinese(part) else [part]
+        sentences = tx.normalize(part) if guess_lang(part) == "zh" else [part]
         dest_text = ""
         for sentence in sentences:
             sentence = apply_post_normalize(sentence)
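A self-contained sketch of the guess_lang heuristic added above (restated outside the module for illustration): a single CJK character classifies the text as Chinese, otherwise any ASCII letter classifies it as English, and anything else falls back to Chinese.

import re
from functools import lru_cache

@lru_cache(maxsize=64)  # memoized, so repeated inputs skip the regex scan
def guess_lang(text: str) -> str:
    if re.search(r"[\u4e00-\u9fff]", text):  # any CJK unified ideograph
        return "zh"
    if re.search(r"[a-zA-Z]", text):  # any ASCII letter
        return "en"
    return "zh"  # default: treat unknown scripts as Chinese

assert guess_lang("你好, world") == "zh"  # the Chinese check wins
assert guess_lang("hello") == "en"
assert guess_lang("12345") == "zh"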
modules/utils/zh_normalization/num.py
CHANGED
@@ -144,13 +144,22 @@ def replace_number(match) -> str:
     sign = match.group(1)
     number = match.group(2)
     pure_decimal = match.group(5)
-    if pure_decimal:
-        result = num2str(pure_decimal)
-    else:
-        sign: str = "负" if sign else ""
-        number: str = num2str(number)
-        result = f"{sign}{number}"
+
+    # TODO: maybe num2str can be replaced entirely with cn2an
+    import cn2an
+    text = pure_decimal if pure_decimal else f"{sign}{number}"
+    try:
+        result = cn2an.an2cn(text, "low")
+    except ValueError:
+        if pure_decimal:
+            result = num2str(pure_decimal)
+        else:
+            sign: str = "负" if sign else ""
+            number: str = num2str(number)
+            result = f"{sign}{number}"
     return result
+
+
 
 
 # Range expressions
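For context, a sketch of what the new cn2an path produces; the expected outputs in the comments follow the cn2an documentation ("low" selects lowercase Chinese numerals):

import cn2an

print(cn2an.an2cn("123", "low"))   # 一百二十三
print(cn2an.an2cn("-1.5", "low"))  # 负一点五
# Inputs cn2an cannot parse raise ValueError, which is why the hunk above
# keeps the original num2str logic as a fallback.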
webui.py
CHANGED
@@ -45,6 +45,9 @@ from modules import refiner, config
 from modules.utils import env, audio
 from modules.SentenceSplitter import SentenceSplitter
 
+# fix: if a system proxy is enabled on Windows, these local addresses must bypass it
+os.environ["NO_PROXY"] = "localhost,127.0.0.1,0.0.0.0"
+
 torch._dynamo.config.cache_size_limit = 64
 torch._dynamo.config.suppress_errors = True
 torch.set_float32_matmul_precision("high")
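A quick way to confirm the NO_PROXY fix takes effect (a sketch; requests and its should_bypass_proxies helper are used only for illustration, and 7860 is an assumed local port):

import os
os.environ["NO_PROXY"] = "localhost,127.0.0.1,0.0.0.0"

import requests

# requests consults NO_PROXY when deciding whether to route through a proxy,
# so local webui/API calls stay direct even with a Windows system proxy enabled.
print(requests.utils.should_bypass_proxies("http://127.0.0.1:7860", no_proxy=None))  # True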