TomatoCocotree committed on
Commit 8056b16 · 1 Parent(s): 19400f8

Update server.py

Files changed (1)
  1. server.py +170 -23
server.py CHANGED
@@ -86,13 +86,23 @@ parser.add_argument('--chroma-persist', help="ChromaDB persistence", default=True
 parser.add_argument(
     "--secure", action="store_true", help="Enforces the use of an API key"
 )
+parser.add_argument("--talkinghead-gpu", action="store_true", help="Run the talkinghead animation on the GPU (CPU is default)")
+
+parser.add_argument("--coqui-gpu", action="store_true", help="Run the voice models on the GPU (CPU is default)")
+parser.add_argument("--coqui-models", help="Install given Coqui-api TTS model at launch (comma separated list, last one will be loaded at start)")
+
+parser.add_argument("--max-content-length", help="Set the max content length (in MB) the server accepts")
+parser.add_argument("--rvc-save-file", action="store_true", help="Save the last rvc input/output audio file into data/tmp/ folder (for research)")
+
+parser.add_argument("--stt-vosk-model-path", help="Load a custom vosk speech-to-text model")
+parser.add_argument("--stt-whisper-model-path", help="Load a custom whisper speech-to-text model")
 sd_group = parser.add_mutually_exclusive_group()

-local_sd = sd_group.add_argument_group("sd-local")
+local_sd = parser.add_argument_group("sd-local")
 local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
 local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU", action="store_true")

-remote_sd = sd_group.add_argument_group("sd-remote")
+remote_sd = parser.add_argument_group("sd-remote")
 remote_sd.add_argument(
     "--sd-remote", action="store_true", help="Use a remote backend for SD"
 )
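The `local_sd`/`remote_sd` change in this hunk moves the two argument groups off the mutually exclusive group and onto the parser itself. argparse never officially supported nesting an argument group inside a mutually exclusive group (CPython deprecated that nesting in 3.11), so this is the safe pattern. A minimal runnable sketch of the fixed layout:

```python
import argparse

parser = argparse.ArgumentParser()
sd_group = parser.add_mutually_exclusive_group()

# Groups attached to the parser render correctly in --help; attaching them
# to sd_group (the old code) relied on undocumented nesting that CPython
# deprecated in Python 3.11.
local_sd = parser.add_argument_group("sd-local")
local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")

remote_sd = parser.add_argument_group("sd-remote")
remote_sd.add_argument("--sd-remote", action="store_true", help="Use a remote backend for SD")

print(parser.parse_args(["--sd-remote"]))  # Namespace(sd_model=None, sd_remote=True)
```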
@@ -119,7 +129,7 @@ parser.add_argument(
 )

 args = parser.parse_args()
-# [HF, Huggingface] Set port to 7860, set host to remote.
+# [HF, Huggingface] Set port to 7860, set host to remote.
 port = 7860
 host = "0.0.0.0"
 summarization_model = (
@@ -170,6 +180,28 @@ if not torch.cuda.is_available() and not args.cpu:

 print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")

+if "talkinghead" in modules:
+    import sys
+    import threading
+    mode = "cuda" if args.talkinghead_gpu else "cpu"
+    print("Initializing talkinghead pipeline in " + mode + " mode...")
+    talkinghead_path = os.path.abspath(os.path.join(os.getcwd(), "talkinghead"))
+    sys.path.append(talkinghead_path)  # Add the path to the 'tha3' module to the sys.path list
+
+    try:
+        import talkinghead.tha3.app.app as talkinghead
+        from talkinghead import *
+        def launch_talkinghead_gui():
+            talkinghead.launch_gui(mode, "separable_float")
+            # model choices: 'standard_float', 'separable_float', 'standard_half', 'separable_half'
+            # mode: the PyTorch device ("cuda" for GPU, "cpu" for CPU)
+        talkinghead_thread = threading.Thread(target=launch_talkinghead_gui)
+        talkinghead_thread.daemon = True  # Set the thread as a daemon thread
+        talkinghead_thread.start()
+
+    except ModuleNotFoundError:
+        print("Error: Could not import the 'talkinghead' module.")
+
 if "caption" in modules:
     print("Initializing an image captioning model...")
     captioning_processor = AutoProcessor.from_pretrained(captioning_model)
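The talkinghead block added above starts the blocking `launch_gui` render loop on a daemon thread so Flask can keep the main thread; a daemon thread dies automatically when the process exits. A minimal sketch of the same pattern (the sleep loop stands in for the render loop):

```python
import threading
import time

def render_loop():
    # Stands in for talkinghead.launch_gui(mode, "separable_float"),
    # which blocks for the lifetime of the process.
    while True:
        time.sleep(1)

t = threading.Thread(target=render_loop, daemon=True)
t.start()
# The main thread (here, Flask's app.run) keeps the process alive; when it
# exits, the daemon render thread is killed with it.
```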
@@ -189,16 +221,6 @@ if "summarize" in modules:
         summarization_model, torch_dtype=torch_dtype
     ).to(device)

-if "classify" in modules:
-    print("Initializing a sentiment classification pipeline...")
-    classification_pipe = pipeline(
-        "text-classification",
-        model=classification_model,
-        top_k=None,
-        device=device,
-        torch_dtype=torch_dtype,
-    )
-
 if "sd" in modules and not sd_use_remote:
     from diffusers import StableDiffusionPipeline
     from diffusers import EulerAncestralDiscreteScheduler
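The inline classifier deleted here is not dropped from the project: later in this diff the `classify` module is re-initialized through `classify_module.init_text_emotion_classifier(...)`. A plausible sketch of what that module wraps, mirroring the deleted code (the module body itself is not part of this commit, so this is an assumption):

```python
# modules/classify/classify_module.py -- hypothetical reconstruction
from transformers import pipeline

classification_pipe = None

def init_text_emotion_classifier(classification_model, device, torch_dtype):
    # Same pipeline the deleted inline block built in server.py.
    global classification_pipe
    classification_pipe = pipeline(
        "text-classification",
        model=classification_model,
        top_k=None,
        device=device,
        torch_dtype=torch_dtype,
    )
```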
@@ -251,7 +273,6 @@ if "silero-tts" in modules:
     tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
     tts_service.generate_samples()

-
 if "edge-tts" in modules:
     print("Initializing Edge TTS client")
     import tts_edge as edge
@@ -295,8 +316,112 @@ if "chromadb" in modules:
 app = Flask(__name__)
 CORS(app) # allow cross-domain requests
 Compress(app) # compress responses
-app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
+app.config["MAX_CONTENT_LENGTH"] = 500 * 1024 * 1024
+
+max_content_length = (
+    args.max_content_length
+    if args.max_content_length
+    else None)
+
+if max_content_length is not None:
+    print("Setting MAX_CONTENT_LENGTH to", max_content_length, "MB")
+    app.config["MAX_CONTENT_LENGTH"] = int(max_content_length) * 1024 * 1024
+
+if "classify" in modules:
+    import modules.classify.classify_module as classify_module
+    classify_module.init_text_emotion_classifier(classification_model, device, torch_dtype)
+
+if "vosk-stt" in modules:
+    print("Initializing Vosk speech-recognition (from ST request file)")
+    vosk_model_path = (
+        args.stt_vosk_model_path
+        if args.stt_vosk_model_path
+        else None)
+
+    import modules.speech_recognition.vosk_module as vosk_module
+
+    vosk_module.model = vosk_module.load_model(file_path=vosk_model_path)
+    app.add_url_rule("/api/speech-recognition/vosk/process-audio", view_func=vosk_module.process_audio, methods=["POST"])
+
+if "whisper-stt" in modules:
+    print("Initializing Whisper speech-recognition (from ST request file)")
+    whisper_model_path = (
+        args.stt_whisper_model_path
+        if args.stt_whisper_model_path
+        else None)
+
+    import modules.speech_recognition.whisper_module as whisper_module
+
+    whisper_module.model = whisper_module.load_model(file_path=whisper_model_path)
+    app.add_url_rule("/api/speech-recognition/whisper/process-audio", view_func=whisper_module.process_audio, methods=["POST"])
+
+if "streaming-stt" in modules:
+    print("Initializing vosk/whisper speech-recognition (from extras server microphone)")
+    whisper_model_path = (
+        args.stt_whisper_model_path
+        if args.stt_whisper_model_path
+        else None)
+
+    import modules.speech_recognition.streaming_module as streaming_module

+    streaming_module.whisper_model, streaming_module.vosk_model = streaming_module.load_model(file_path=whisper_model_path)
+    app.add_url_rule("/api/speech-recognition/streaming/record-and-transcript", view_func=streaming_module.record_and_transcript, methods=["POST"])
+
+if "rvc" in modules:
+    print("Initializing RVC voice conversion (from ST request file)")
+    print("Increasing server upload limit")
+    rvc_save_file = (
+        args.rvc_save_file
+        if args.rvc_save_file
+        else False)
+
+    if rvc_save_file:
+        print("RVC saving file option detected, input/output audio will be saved into data/tmp/ folder")
+
+    import sys
+    sys.path.insert(0, 'modules/voice_conversion')
+
+    import modules.voice_conversion.rvc_module as rvc_module
+    rvc_module.save_file = rvc_save_file
+
+    if "classify" in modules:
+        rvc_module.classification_mode = True
+
+    rvc_module.fix_model_install()
+    app.add_url_rule("/api/voice-conversion/rvc/get-models-list", view_func=rvc_module.rvc_get_models_list, methods=["POST"])
+    app.add_url_rule("/api/voice-conversion/rvc/upload-models", view_func=rvc_module.rvc_upload_models, methods=["POST"])
+    app.add_url_rule("/api/voice-conversion/rvc/process-audio", view_func=rvc_module.rvc_process_audio, methods=["POST"])
+
+
+if "coqui-tts" in modules:
+    mode = "GPU" if args.coqui_gpu else "CPU"
+    print("Initializing Coqui TTS client in " + mode + " mode")
+    import modules.text_to_speech.coqui.coqui_module as coqui_module
+
+    if mode == "GPU":
+        coqui_module.gpu_mode = True
+
+    coqui_models = (
+        args.coqui_models
+        if args.coqui_models
+        else None
+    )
+
+    if coqui_models is not None:
+        coqui_models = coqui_models.split(",")
+        for i in coqui_models:
+            if not coqui_module.install_model(i):
+                raise ValueError("Coqui model loading failed, most likely a wrong model name in --coqui-models argument, check log above to see which one")
+
+    # Coqui-api models
+    app.add_url_rule("/api/text-to-speech/coqui/coqui-api/check-model-state", view_func=coqui_module.coqui_check_model_state, methods=["POST"])
+    app.add_url_rule("/api/text-to-speech/coqui/coqui-api/install-model", view_func=coqui_module.coqui_install_model, methods=["POST"])
+
+    # Users models
+    app.add_url_rule("/api/text-to-speech/coqui/local/get-models", view_func=coqui_module.coqui_get_local_models, methods=["POST"])
+
+    # Handle both coqui-api/users models
+    app.add_url_rule("/api/text-to-speech/coqui/generate-tts", view_func=coqui_module.coqui_generate_tts, methods=["POST"])

 def require_module(name):
     def wrapper(fn):
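Every block added in this hunk follows the same shape: resolve an optional CLI override, import the module lazily so unused extras cost nothing, then register its view functions with `app.add_url_rule`. A hypothetical client call against one of the new routes (only the URL comes from the diff; the multipart field name is a guess):

```python
import requests

BASE = "http://localhost:7860"

# Transcribe a local WAV file via the new whisper-stt route.
with open("sample.wav", "rb") as f:
    resp = requests.post(
        f"{BASE}/api/speech-recognition/whisper/process-audio",
        files={"AudioFile": f},  # field name is an assumption; see whisper_module.process_audio
    )
print(resp.status_code, resp.text)
```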
@@ -313,12 +438,7 @@ def require_module(name):

 # AI stuff
 def classify_text(text: str) -> list:
-    output = classification_pipe(
-        text,
-        truncation=True,
-        max_length=classification_pipe.model.config.max_position_embeddings,
-    )[0]
-    return sorted(output, key=lambda x: x["score"], reverse=True)
+    return classify_module.classify_text_emotion(text)


 def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
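The one-line replacement keeps the old contract: the deleted implementation returned the pipeline's label/score dicts sorted by descending score, so callers such as `api_classify` are unaffected. A hypothetical client call (the `/api/classify` path is assumed from the handler name) and the expected shape:

```python
import requests

resp = requests.post(
    "http://localhost:7860/api/classify",  # assumed route for api_classify
    json={"text": "I am so happy today!"},
)
# Per the deleted inline implementation, labels arrive sorted by score, e.g.
# {"classification": [{"label": "joy", "score": 0.93},
#                     {"label": "neutral", "score": 0.04}, ...]}
print(resp.json()["classification"][0]["label"])  # top emotion
```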
@@ -417,7 +537,7 @@ def image_to_base64(image: Image, quality: int = 75) -> str:
     return img_str


-ignore_auth = []
+ignore_auth = []
 # [HF, Huggingface] Get password instead of text file.
 api_key = os.environ.get("password")

@@ -429,6 +549,7 @@ def is_authorize_ignored(request):
         return True
     return False

+
 @app.before_request
 def before_request():
     # Request time measuring
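For context on `ignore_auth` (restored earlier in this diff): view functions appended to it bypass the API-key check in `before_request`. Only the two `return` lines of `is_authorize_ignored` are visible in this hunk; a plausible reconstruction, assuming it resolves the view function from Flask's endpoint map:

```python
def is_authorize_ignored(request):
    # Hypothetical body -- only the returns below appear in the diff.
    view_func = app.view_functions.get(request.endpoint)
    if view_func is not None and view_func in ignore_auth:
        return True
    return False
```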
@@ -532,6 +653,8 @@ def api_classify():
     classification = classify_text(data["text"])
     print("Classification output:", classification, sep="\n")
     gc.collect()
+    if "talkinghead" in modules:  # send emotion to talkinghead
+        talkinghead.setEmotion(classification)
     return jsonify({"classification": classification})

@@ -540,8 +663,31 @@ def api_classify():
 def api_classify_labels():
     classification = classify_text("")
     labels = [x["label"] for x in classification]
+    if "talkinghead" in modules:
+        labels.append('talkinghead')  # Add 'talkinghead' to the labels list
     return jsonify({"labels": labels})

+@app.route("/api/talkinghead/load", methods=["POST"])
+def live_load():
+    file = request.files['file']
+    # convert stream to bytes and pass to talkinghead_load
+    return talkinghead.talkinghead_load_file(file.stream)
+
+@app.route('/api/talkinghead/unload')
+def live_unload():
+    return talkinghead.unload()
+
+@app.route('/api/talkinghead/start_talking')
+def start_talking():
+    return talkinghead.start_talking()
+
+@app.route('/api/talkinghead/stop_talking')
+def stop_talking():
+    return talkinghead.stop_talking()
+
+@app.route('/api/talkinghead/result_feed')
+def result_feed():
+    return talkinghead.result_feed()

 @app.route("/api/image", methods=["POST"])
 @require_module("sd")
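A hypothetical client session against the new talkinghead routes (the URLs and the `file` field come straight from the diff; the PNG name and the streaming read are assumptions):

```python
import requests

BASE = "http://localhost:7860"

# Load a character image to animate (the route reads request.files['file']).
with open("character.png", "rb") as f:
    requests.post(f"{BASE}/api/talkinghead/load", files={"file": f})

requests.get(f"{BASE}/api/talkinghead/start_talking")

# result_feed is exempted from auth at the bottom of this diff, which
# suggests it streams the rendered animation; read it incrementally.
with requests.get(f"{BASE}/api/talkinghead/result_feed", stream=True) as r:
    first_chunk = next(r.iter_content(chunk_size=4096))

requests.get(f"{BASE}/api/talkinghead/stop_talking")
```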
@@ -958,7 +1104,8 @@ if args.share:
         cloudflare = _run_cloudflared(port, metrics_port)
     else:
         cloudflare = _run_cloudflared(port)
-    print("Running on", cloudflare)
+    print(f"{Fore.GREEN}{Style.NORMAL}Running on: {cloudflare}{Style.RESET_ALL}")

 ignore_auth.append(tts_play_sample)
+ignore_auth.append(result_feed)
 app.run(host=host, port=port)
 