TomatoCocotree committed · Commit 8056b16 · 1 parent: 19400f8
Update server.py
server.py CHANGED
@@ -86,13 +86,23 @@ parser.add_argument('--chroma-persist', help="ChromaDB persistence", default=Tru
 parser.add_argument(
     "--secure", action="store_true", help="Enforces the use of an API key"
 )
+parser.add_argument("--talkinghead-gpu", action="store_true", help="Run the talkinghead animation on the GPU (CPU is default)")
+
+parser.add_argument("--coqui-gpu", action="store_true", help="Run the voice models on the GPU (CPU is default)")
+parser.add_argument("--coqui-models", help="Install given Coqui-api TTS model at launch (comma separated list, last one will be loaded at start)")
+
+parser.add_argument("--max-content-length", help="Set the max")
+parser.add_argument("--rvc-save-file", action="store_true", help="Save the last rvc input/output audio file into data/tmp/ folder (for research)")
+
+parser.add_argument("--stt-vosk-model-path", help="Load a custom vosk speech-to-text model")
+parser.add_argument("--stt-whisper-model-path", help="Load a custom vosk speech-to-text model")
 sd_group = parser.add_mutually_exclusive_group()
 
-local_sd =
+local_sd = parser.add_argument_group("sd-local")
 local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
 local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU", action="store_true")
 
-remote_sd =
+remote_sd = parser.add_argument_group("sd-remote")
 remote_sd.add_argument(
     "--sd-remote", action="store_true", help="Use a remote backend for SD"
 )
@@ -119,7 +129,7 @@ parser.add_argument(
 )
 
 args = parser.parse_args()
-# [HF, Huggingface] Set port to 7860, set host to remote.
+# [HF, Huggingface] Set port to 7860, set host to remote.
 port = 7860
 host = "0.0.0.0"
 summarization_model = (
@@ -170,6 +180,28 @@ if not torch.cuda.is_available() and not args.cpu:
 
 print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
 
+if "talkinghead" in modules:
+    import sys
+    import threading
+    mode = "cuda" if args.talkinghead_gpu else "cpu"
+    print("Initializing talkinghead pipeline in " + mode + " mode....")
+    talkinghead_path = os.path.abspath(os.path.join(os.getcwd(), "talkinghead"))
+    sys.path.append(talkinghead_path) # Add the path to the 'tha3' module to the sys.path list
+
+    try:
+        import talkinghead.tha3.app.app as talkinghead
+        from talkinghead import *
+        def launch_talkinghead_gui():
+            talkinghead.launch_gui(mode, "separable_float")
+            #choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'],
+            #choices='The device to use for PyTorch ("cuda" for GPU, "cpu" for CPU).'
+        talkinghead_thread = threading.Thread(target=launch_talkinghead_gui)
+        talkinghead_thread.daemon = True # Set the thread as a daemon thread
+        talkinghead_thread.start()
+
+    except ModuleNotFoundError:
+        print("Error: Could not import the 'talkinghead' module.")
+
 if "caption" in modules:
     print("Initializing an image captioning model...")
     captioning_processor = AutoProcessor.from_pretrained(captioning_model)
@@ -189,16 +221,6 @@ if "summarize" in modules:
         summarization_model, torch_dtype=torch_dtype
     ).to(device)
 
-if "classify" in modules:
-    print("Initializing a sentiment classification pipeline...")
-    classification_pipe = pipeline(
-        "text-classification",
-        model=classification_model,
-        top_k=None,
-        device=device,
-        torch_dtype=torch_dtype,
-    )
-
 if "sd" in modules and not sd_use_remote:
     from diffusers import StableDiffusionPipeline
     from diffusers import EulerAncestralDiscreteScheduler
@@ -251,7 +273,6 @@ if "silero-tts" in modules:
     tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
     tts_service.generate_samples()
 
-
 if "edge-tts" in modules:
     print("Initializing Edge TTS client")
     import tts_edge as edge
@@ -295,8 +316,112 @@ if "chromadb" in modules:
 app = Flask(__name__)
 CORS(app) # allow cross-domain requests
 Compress(app) # compress responses
-app.config["MAX_CONTENT_LENGTH"] =
+app.config["MAX_CONTENT_LENGTH"] = 500 * 1024 * 1024
+
+max_content_length = (
+    args.max_content_length
+    if args.max_content_length
+    else None)
+
+if max_content_length is not None:
+    print("Setting MAX_CONTENT_LENGTH to",max_content_length,"Mb")
+    app.config["MAX_CONTENT_LENGTH"] = int(max_content_length) * 1024 * 1024
+
+if "classify" in modules:
+    import modules.classify.classify_module as classify_module
+    classify_module.init_text_emotion_classifier(classification_model, device, torch_dtype)
+
+if "vosk-stt" in modules:
+    print("Initializing Vosk speech-recognition (from ST request file)")
+    vosk_model_path = (
+        args.stt_vosk_model_path
+        if args.stt_vosk_model_path
+        else None)
+
+    import modules.speech_recognition.vosk_module as vosk_module
+
+    vosk_module.model = vosk_module.load_model(file_path=vosk_model_path)
+    app.add_url_rule("/api/speech-recognition/vosk/process-audio", view_func=vosk_module.process_audio, methods=["POST"])
+
+if "whisper-stt" in modules:
+    print("Initializing Whisper speech-recognition (from ST request file)")
+    whisper_model_path = (
+        args.stt_whisper_model_path
+        if args.stt_whisper_model_path
+        else None)
+
+    import modules.speech_recognition.whisper_module as whisper_module
+
+    whisper_module.model = whisper_module.load_model(file_path=whisper_model_path)
+    app.add_url_rule("/api/speech-recognition/whisper/process-audio", view_func=whisper_module.process_audio, methods=["POST"])
+
+if "streaming-stt" in modules:
+    print("Initializing vosk/whisper speech-recognition (from extras server microphone)")
+    whisper_model_path = (
+        args.stt_whisper_model_path
+        if args.stt_whisper_model_path
+        else None)
+
+    import modules.speech_recognition.streaming_module as streaming_module
+
+    streaming_module.whisper_model, streaming_module.vosk_model = streaming_module.load_model(file_path=whisper_model_path)
+    app.add_url_rule("/api/speech-recognition/streaming/record-and-transcript", view_func=streaming_module.record_and_transcript, methods=["POST"])
+
+if "rvc" in modules:
+    print("Initializing RVC voice conversion (from ST request file)")
+    print("Increasing server upload limit")
+    rvc_save_file = (
+        args.rvc_save_file
+        if args.rvc_save_file
+        else False)
+
+    if rvc_save_file:
+        print("RVC saving file option detected, input/output audio will be savec into data/tmp/ folder")
+
+    import sys
+    sys.path.insert(0,'modules/voice_conversion')
+
+    import modules.voice_conversion.rvc_module as rvc_module
+    rvc_module.save_file = rvc_save_file
+
+    if "classify" in modules:
+        rvc_module.classification_mode = True
+
+    rvc_module.fix_model_install()
+    app.add_url_rule("/api/voice-conversion/rvc/get-models-list", view_func=rvc_module.rvc_get_models_list, methods=["POST"])
+    app.add_url_rule("/api/voice-conversion/rvc/upload-models", view_func=rvc_module.rvc_upload_models, methods=["POST"])
+    app.add_url_rule("/api/voice-conversion/rvc/process-audio", view_func=rvc_module.rvc_process_audio, methods=["POST"])
+
+
+if "coqui-tts" in modules:
+    mode = "GPU" if args.coqui_gpu else "CPU"
+    print("Initializing Coqui TTS client in " + mode + " mode")
+    import modules.text_to_speech.coqui.coqui_module as coqui_module
+
+    if mode == "GPU":
+        coqui_module.gpu_mode = True
+
+    coqui_models = (
+        args.coqui_models
+        if args.coqui_models
+        else None
+    )
+
+    if coqui_models is not None:
+        coqui_models = coqui_models.split(",")
+        for i in coqui_models:
+            if not coqui_module.install_model(i):
+                raise ValueError("Coqui model loading failed, most likely a wrong model name in --coqui-models argument, check log above to see which one")
+
+    # Coqui-api models
+    app.add_url_rule("/api/text-to-speech/coqui/coqui-api/check-model-state", view_func=coqui_module.coqui_check_model_state, methods=["POST"])
+    app.add_url_rule("/api/text-to-speech/coqui/coqui-api/install-model", view_func=coqui_module.coqui_install_model, methods=["POST"])
+
+    # Users models
+    app.add_url_rule("/api/text-to-speech/coqui/local/get-models", view_func=coqui_module.coqui_get_local_models, methods=["POST"])
+
+    # Handle both coqui-api/users models
+    app.add_url_rule("/api/text-to-speech/coqui/generate-tts", view_func=coqui_module.coqui_generate_tts, methods=["POST"])
 
 
 def require_module(name):
     def wrapper(fn):
@@ -313,12 +438,7 @@ def require_module(name):
 
 # AI stuff
 def classify_text(text: str) -> list:
-    output = classification_pipe(
-        text,
-        truncation=True,
-        max_length=classification_pipe.model.config.max_position_embeddings,
-    )[0]
-    return sorted(output, key=lambda x: x["score"], reverse=True)
+    return classify_module.classify_text_emotion(text)
 
 
 def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
@@ -417,7 +537,7 @@ def image_to_base64(image: Image, quality: int = 75) -> str:
     return img_str
 
 
-ignore_auth = []
+ignore_auth = []
 # [HF, Huggingface] Get password instead of text file.
 api_key = os.environ.get("password")
 
@@ -429,6 +549,7 @@ def is_authorize_ignored(request):
             return True
     return False
 
+
 @app.before_request
 def before_request():
     # Request time measuring
@@ -532,6 +653,8 @@ def api_classify():
     classification = classify_text(data["text"])
     print("Classification output:", classification, sep="\n")
     gc.collect()
+    if "talkinghead" in modules: #send emotion to talkinghead
+        talkinghead.setEmotion(classification)
     return jsonify({"classification": classification})
 
 
@@ -540,8 +663,31 @@ def api_classify():
 def api_classify_labels():
     classification = classify_text("")
     labels = [x["label"] for x in classification]
+    if "talkinghead" in modules:
+        labels.append('talkinghead') # Add 'talkinghead' to the labels list
     return jsonify({"labels": labels})
 
+@app.route("/api/talkinghead/load", methods=["POST"])
+def live_load():
+    file = request.files['file']
+    # convert stream to bytes and pass to talkinghead_load
+    return talkinghead.talkinghead_load_file(file.stream)
+
+@app.route('/api/talkinghead/unload')
+def live_unload():
+    return talkinghead.unload()
+
+@app.route('/api/talkinghead/start_talking')
+def start_talking():
+    return talkinghead.start_talking()
+
+@app.route('/api/talkinghead/stop_talking')
+def stop_talking():
+    return talkinghead.stop_talking()
+
+@app.route('/api/talkinghead/result_feed')
+def result_feed():
+    return talkinghead.result_feed()
 
 @app.route("/api/image", methods=["POST"])
 @require_module("sd")
@@ -958,7 +1104,8 @@ if args.share:
         cloudflare = _run_cloudflared(port, metrics_port)
     else:
         cloudflare = _run_cloudflared(port)
-    print("Running on
+    print(f"{Fore.GREEN}{Style.NORMAL}Running on: {cloudflare}{Style.RESET_ALL}")
 
 ignore_auth.append(tts_play_sample)
+ignore_auth.append(result_feed)
 app.run(host=host, port=port)
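
For reference, a minimal sketch of how the talkinghead routes registered in this commit could be exercised once the Space is running. It assumes the server is reachable at http://localhost:7860 (the port hard-coded above), that the talkinghead module is enabled, and that --secure is not in force; the requests dependency and the avatar.png filename are illustrative, not part of the commit.

import requests

BASE = "http://localhost:7860"  # port set by the HF-specific block above

# Upload a character image to the talkinghead module; the "file" field name
# matches request.files['file'] in the live_load() route added by this commit.
with open("avatar.png", "rb") as f:  # illustrative filename
    resp = requests.post(f"{BASE}/api/talkinghead/load", files={"file": f})
print(resp.status_code, resp.text)

# Start and stop the talking animation via the new GET routes.
requests.get(f"{BASE}/api/talkinghead/start_talking")
requests.get(f"{BASE}/api/talkinghead/stop_talking")

Note that result_feed is appended to ignore_auth at the bottom of the diff, so /api/talkinghead/result_feed (the rendered animation stream) stays reachable even when an API key is enforced.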